mirror of
https://github.com/Wisser/Jailer.git
synced 2026-05-06 11:20:14 -05:00
AI Integration, first steps
This commit is contained in:
Binary file not shown.
@@ -863,7 +863,7 @@ public class UIUtil {
|
||||
if (!(t instanceof CancellationException)) {
|
||||
t.printStackTrace();
|
||||
}
|
||||
if (!(t instanceof ClassNotFoundException)) {
|
||||
if (!(t instanceof ClassNotFoundException) && !(t instanceof IOException)) {
|
||||
while (t.getCause() != null && t != t.getCause() && !(t instanceof SqlException)) {
|
||||
t = t.getCause();
|
||||
}
|
||||
|
||||
@@ -54,11 +54,13 @@ public class AIProviderConfig {
|
||||
public final String apiUrl;
|
||||
public final String apiKey;
|
||||
public final String model;
|
||||
public final int maxTokens;
|
||||
|
||||
public AIProviderConfig(ProviderType providerType, String apiUrl, String apiKey, String model) {
|
||||
public AIProviderConfig(ProviderType providerType, String apiUrl, String apiKey, String model, int maxTokens) {
|
||||
this.providerType = providerType;
|
||||
this.apiUrl = (apiUrl != null && !apiUrl.isEmpty()) ? apiUrl : providerType.defaultApiUrl;
|
||||
this.apiKey = apiKey != null ? apiKey : "";
|
||||
this.model = (model != null && !model.isEmpty()) ? model : providerType.defaultModel;
|
||||
this.maxTokens = maxTokens > 0 ? maxTokens : 1024;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,9 +65,30 @@ public class AIQueryAssistant {
|
||||
ObjectNode body = buildRequestBody(question, history, schema, dbmsName, config, isAnthropic);
|
||||
JsonNode response = post(config.apiUrl, config.apiKey, body, isAnthropic);
|
||||
if (isAnthropic) {
|
||||
return response.path("content").get(0).path("text").asText("").trim();
|
||||
JsonNode contentNode = response.path("content");
|
||||
if (contentNode.isArray() && contentNode.size() > 0) {
|
||||
return contentNode.get(0).path("text").asText("").trim();
|
||||
}
|
||||
throw new IOException("Unexpected response format: missing 'content' array. Response: " + response.toString());
|
||||
} else {
|
||||
return response.path("choices").get(0).path("message").path("content").asText("").trim();
|
||||
// OpenAI-compatible: choices[0].message.content
|
||||
JsonNode choicesNode = response.path("choices");
|
||||
if (choicesNode.isArray() && choicesNode.size() > 0) {
|
||||
JsonNode messageNode = choicesNode.get(0).path("message");
|
||||
String content = messageNode.path("content").asText("");
|
||||
if (!content.isEmpty()) {
|
||||
return content.trim();
|
||||
}
|
||||
}
|
||||
// Ollama-compatible: message.content (streaming response, single object)
|
||||
JsonNode messageNode = response.path("message");
|
||||
if (!messageNode.isMissingNode() && !messageNode.isNull()) {
|
||||
String content = messageNode.path("content").asText("");
|
||||
if (!content.isEmpty()) {
|
||||
return content.trim();
|
||||
}
|
||||
}
|
||||
throw new IOException("Unexpected response format: missing 'choices' or 'message'. Response: " + response.toString());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,7 +100,7 @@ public class AIQueryAssistant {
|
||||
String schema, String dbmsName, AIProviderConfig config, boolean isAnthropic) {
|
||||
ObjectNode body = MAPPER.createObjectNode();
|
||||
body.put("model", config.model);
|
||||
body.put("max_tokens", 1024);
|
||||
body.put("max_tokens", config.maxTokens);
|
||||
// Schema lives in the system prompt so it is sent once, not repeated per user message.
|
||||
String systemPrompt = buildSystemPrompt(schema, dbmsName);
|
||||
|
||||
@@ -165,16 +186,67 @@ public class AIQueryAssistant {
|
||||
byte[] responseBytes;
|
||||
if (status >= 400) {
|
||||
InputStream es = conn.getErrorStream();
|
||||
responseBytes = (es != null) ? readAllBytes(es) : new byte[0];
|
||||
if (es != null) {
|
||||
responseBytes = readAllBytes(es);
|
||||
} else {
|
||||
try (InputStream is = conn.getInputStream()) {
|
||||
responseBytes = readAllBytes(is);
|
||||
} catch (IOException ignored) {
|
||||
responseBytes = new byte[0];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
try (InputStream is = conn.getInputStream()) {
|
||||
responseBytes = readAllBytes(is);
|
||||
}
|
||||
}
|
||||
_log.debug("RESPONSE {}\n Body: {}", status, new String(responseBytes, StandardCharsets.UTF_8).trim());
|
||||
String responseBody = new String(responseBytes, StandardCharsets.UTF_8).trim();
|
||||
_log.debug("RESPONSE {}\n Body: {}", status, responseBody);
|
||||
if (status >= 400) {
|
||||
throw new IOException("API error " + status + ": " + parseErrorMessage(responseBytes, status));
|
||||
}
|
||||
// Check if response is streamed (multiple JSON objects, one per line)
|
||||
String[] lines = responseBody.split("\\r?\\n");
|
||||
if (lines.length > 1 && responseBody.contains("\"done\":")) {
|
||||
// Streaming response - concatenate all message contents until done
|
||||
StringBuilder fullContent = new StringBuilder();
|
||||
for (String line : lines) {
|
||||
line = line.trim();
|
||||
if (line.isEmpty()) continue;
|
||||
try {
|
||||
JsonNode lineNode = MAPPER.readTree(line);
|
||||
JsonNode doneNode = lineNode.path("done");
|
||||
if (doneNode.asBoolean()) {
|
||||
// Last chunk, stop here
|
||||
break;
|
||||
}
|
||||
JsonNode messageNode = lineNode.path("message");
|
||||
String content = messageNode.path("content").asText("");
|
||||
if (!content.isEmpty()) {
|
||||
fullContent.append(content);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
// skip invalid line
|
||||
}
|
||||
}
|
||||
// Build synthetic response matching expected format
|
||||
ObjectNode synthResponse = MAPPER.createObjectNode();
|
||||
if (isAnthropic) {
|
||||
// Anthropic: content is an array of text blocks
|
||||
ArrayNode contentArray = synthResponse.putArray("content");
|
||||
ObjectNode textBlock = contentArray.addObject();
|
||||
textBlock.put("type", "text");
|
||||
textBlock.put("text", fullContent.toString());
|
||||
} else {
|
||||
// OpenAI-compatible: choices[0].message.content
|
||||
ArrayNode choices = synthResponse.putArray("choices");
|
||||
ObjectNode choice = choices.addObject();
|
||||
ObjectNode message = choice.putObject("message");
|
||||
message.put("role", "assistant");
|
||||
message.put("content", fullContent.toString());
|
||||
}
|
||||
return synthResponse;
|
||||
}
|
||||
return MAPPER.readTree(responseBytes);
|
||||
} finally {
|
||||
conn.disconnect();
|
||||
@@ -188,6 +260,7 @@ public class AIQueryAssistant {
|
||||
List<String> cmd = new ArrayList<>();
|
||||
cmd.add("curl");
|
||||
cmd.add("-s");
|
||||
cmd.add("-f");
|
||||
cmd.add("-X"); cmd.add("POST");
|
||||
cmd.add("-H"); cmd.add("Content-Type: application/json");
|
||||
if (isAnthropic) {
|
||||
@@ -213,6 +286,16 @@ public class AIQueryAssistant {
|
||||
process.destroy();
|
||||
throw new IOException("curl timed out");
|
||||
}
|
||||
int exitCode = process.exitValue();
|
||||
if (exitCode != 0) {
|
||||
byte[] errBytes = readAllBytes(process.getErrorStream());
|
||||
String errStr = new String(errBytes, StandardCharsets.UTF_8).trim();
|
||||
if (errStr.length() > 0) {
|
||||
_log.debug("RESPONSE (curl) exitCode={} Body: {}", exitCode, errStr);
|
||||
throw new IOException("API error " + exitCode + ": " + errStr);
|
||||
}
|
||||
throw new IOException("curl failed with exit code " + exitCode);
|
||||
}
|
||||
if (responseBytes.length == 0) {
|
||||
byte[] errBytes = readAllBytes(process.getErrorStream());
|
||||
String curlErr = new String(errBytes, StandardCharsets.UTF_8).trim();
|
||||
@@ -247,19 +330,25 @@ public class AIQueryAssistant {
|
||||
if (responseBytes.length == 0) {
|
||||
return "HTTP " + status;
|
||||
}
|
||||
String responseJson = new String(responseBytes, StandardCharsets.UTF_8);
|
||||
try {
|
||||
JsonNode node = MAPPER.readTree(responseBytes);
|
||||
String msg = node.path("error").path("message").asText(null);
|
||||
if (msg == null) {
|
||||
msg = node.path("error").asText(null);
|
||||
}
|
||||
if (msg == null) {
|
||||
msg = node.path("message").asText(null);
|
||||
}
|
||||
if (msg != null && !msg.isEmpty()) {
|
||||
return msg;
|
||||
return msg + " (" + status + ")";
|
||||
}
|
||||
// Include full response body if no specific message found
|
||||
return responseJson.trim() + " (" + status + ")";
|
||||
} catch (IOException ignored) {
|
||||
// not JSON — fall through
|
||||
}
|
||||
String raw = new String(responseBytes, StandardCharsets.UTF_8).trim();
|
||||
String raw = responseJson.trim();
|
||||
if (raw.startsWith("<") || raw.toLowerCase(Locale.ROOT).contains("<html")) {
|
||||
File htmlFile = new File(System.getProperty("java.io.tmpdir"), "jailer-ai-error.html");
|
||||
try (FileOutputStream fos = new FileOutputStream(htmlFile)) {
|
||||
@@ -268,7 +357,8 @@ public class AIQueryAssistant {
|
||||
}
|
||||
return "HTTP " + status + " (HTML response saved to: " + htmlFile.getAbsolutePath() + ")";
|
||||
}
|
||||
return raw.length() > 300 ? raw.substring(0, 300) + "..." : raw;
|
||||
// Include full response body in error message
|
||||
return "HTTP " + status + " - Response: " + raw;
|
||||
}
|
||||
|
||||
private static String buildSystemPrompt(String schema, String dbmsName) {
|
||||
@@ -374,4 +464,4 @@ public class AIQueryAssistant {
|
||||
}
|
||||
|
||||
// TODO
|
||||
// TODO put comments into context
|
||||
// TODO session management: if the provider supports it, we could keep a session ID and reuse it for subsequent calls to maintain context without resending the full schema each time.
|
||||
|
||||
@@ -49,6 +49,7 @@ import org.fife.ui.rtextarea.RTextScrollPane;
|
||||
import net.sf.jailer.ui.syntaxtextarea.RSyntaxTextAreaWithSQLSyntaxStyle;
|
||||
|
||||
import net.sf.jailer.datamodel.DataModel;
|
||||
import net.sf.jailer.ui.UIUtil;
|
||||
import net.sf.jailer.ui.ai.AIProviderConfig;
|
||||
import net.sf.jailer.ui.ai.AIProviderConfig.ProviderType;
|
||||
import net.sf.jailer.ui.ai.AIQueryAssistant;
|
||||
@@ -67,6 +68,7 @@ public class AIQueryDialog extends JDialog {
|
||||
private static final String SETTING_PROVIDER = "aiProviderType";
|
||||
private static final String SETTING_API_URL = "aiApiUrl";
|
||||
private static final String SETTING_MODEL = "aiModel";
|
||||
private static final String SETTING_MAX_TOKENS = "aiMaxTokens";
|
||||
private static final String SETTING_API_KEY_PREFIX = "aiApiKey_";
|
||||
|
||||
private final DataModel dataModel;
|
||||
@@ -86,6 +88,7 @@ public class AIQueryDialog extends JDialog {
|
||||
private JComboBox<ProviderType> providerCombo;
|
||||
private JTextField urlField;
|
||||
private JTextField modelField;
|
||||
private JTextField maxTokensField;
|
||||
private JPasswordField apiKeyField;
|
||||
private JCheckBox saveBox;
|
||||
|
||||
@@ -121,6 +124,7 @@ public class AIQueryDialog extends JDialog {
|
||||
questionPanel.add(new JScrollPane(questionArea), BorderLayout.CENTER);
|
||||
|
||||
generateButton = new JButton("Generate SQL");
|
||||
generateButton.setEnabled(false);
|
||||
statusLabel = new JLabel(" ");
|
||||
generateButton.addActionListener(e -> onGenerate());
|
||||
JPanel genRow = new JPanel(new FlowLayout(FlowLayout.LEFT, 6, 0));
|
||||
@@ -128,6 +132,24 @@ public class AIQueryDialog extends JDialog {
|
||||
genRow.add(statusLabel);
|
||||
questionPanel.add(genRow, BorderLayout.SOUTH);
|
||||
|
||||
questionArea.getDocument().addDocumentListener(new javax.swing.event.DocumentListener() {
|
||||
@Override
|
||||
public void insertUpdate(javax.swing.event.DocumentEvent e) {
|
||||
updateGenerateButton();
|
||||
}
|
||||
@Override
|
||||
public void removeUpdate(javax.swing.event.DocumentEvent e) {
|
||||
updateGenerateButton();
|
||||
}
|
||||
@Override
|
||||
public void changedUpdate(javax.swing.event.DocumentEvent e) {
|
||||
updateGenerateButton();
|
||||
}
|
||||
private void updateGenerateButton() {
|
||||
generateButton.setEnabled(!questionArea.getText().trim().isEmpty());
|
||||
}
|
||||
});
|
||||
|
||||
// SQL result area
|
||||
JPanel resultPanel = new JPanel(new BorderLayout(4, 4));
|
||||
resultPanel.add(new JLabel("Generated SQL:"), BorderLayout.NORTH);
|
||||
@@ -162,7 +184,9 @@ public class AIQueryDialog extends JDialog {
|
||||
insertButton.addActionListener(e -> {
|
||||
String sql = sqlArea.getText().trim();
|
||||
if (!sql.isEmpty()) {
|
||||
sqlConsumer.accept(sql);
|
||||
String comment = buildCommentForHistory();
|
||||
String combined = comment + "\n" + sql;
|
||||
sqlConsumer.accept(combined);
|
||||
dispose();
|
||||
}
|
||||
});
|
||||
@@ -201,6 +225,10 @@ public class AIQueryDialog extends JDialog {
|
||||
|
||||
urlField = new JTextField(savedUrl != null ? savedUrl : savedProvider.defaultApiUrl, 36);
|
||||
modelField = new JTextField(savedModel != null ? savedModel : savedProvider.defaultModel, 18);
|
||||
maxTokensField = new JTextField((String) UISettings.restore(SETTING_MAX_TOKENS), 6);
|
||||
if (maxTokensField.getText().isEmpty()) {
|
||||
maxTokensField.setText("1024");
|
||||
}
|
||||
apiKeyField = new JPasswordField(36);
|
||||
if (savedKey != null) {
|
||||
apiKeyField.setText(savedKey);
|
||||
@@ -219,6 +247,9 @@ public class AIQueryDialog extends JDialog {
|
||||
fc.gridwidth = 1; fc.weightx = 0;
|
||||
fc.gridx = 5; fc.gridy = 1; panel.add(saveBox, fc);
|
||||
|
||||
lc.gridx = 0; lc.gridy = 2; panel.add(new JLabel("Max Tokens:"), lc);
|
||||
fc.gridx = 1; fc.gridy = 2; panel.add(maxTokensField, fc);
|
||||
|
||||
ProviderType[] prev = { savedProvider };
|
||||
providerCombo.addItemListener(e -> {
|
||||
if (e.getStateChange() != ItemEvent.SELECTED) {
|
||||
@@ -252,17 +283,29 @@ public class AIQueryDialog extends JDialog {
|
||||
return;
|
||||
}
|
||||
|
||||
int maxTokens = 1024;
|
||||
try {
|
||||
maxTokens = Integer.parseInt(maxTokensField.getText().trim());
|
||||
if (maxTokens <= 0) {
|
||||
maxTokens = 1024;
|
||||
}
|
||||
} catch (NumberFormatException e) {
|
||||
// ignore, use default
|
||||
}
|
||||
|
||||
AIProviderConfig config = new AIProviderConfig(
|
||||
(ProviderType) providerCombo.getSelectedItem(),
|
||||
urlField.getText().trim(),
|
||||
apiKey,
|
||||
modelField.getText().trim()
|
||||
modelField.getText().trim(),
|
||||
maxTokens
|
||||
);
|
||||
|
||||
if (saveBox.isSelected()) {
|
||||
UISettings.store(SETTING_PROVIDER, config.providerType.name());
|
||||
UISettings.store(SETTING_API_URL, config.apiUrl);
|
||||
UISettings.store(SETTING_MODEL, config.model);
|
||||
UISettings.store(SETTING_MAX_TOKENS, String.valueOf(config.maxTokens));
|
||||
UISettings.store(SETTING_API_KEY_PREFIX + config.providerType.name(), config.apiKey);
|
||||
}
|
||||
|
||||
@@ -295,11 +338,9 @@ public class AIQueryDialog extends JDialog {
|
||||
updateHistoryDisplay();
|
||||
}
|
||||
} catch (ExecutionException ex) {
|
||||
String msg = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage();
|
||||
sqlArea.setText("Error: " + msg);
|
||||
UIUtil.showException(AIQueryDialog.this, "SQL Generation Error", ex);
|
||||
} catch (InterruptedException ex) {
|
||||
Thread.currentThread().interrupt();
|
||||
sqlArea.setText("Request interrupted.");
|
||||
}
|
||||
}
|
||||
}.execute();
|
||||
@@ -354,4 +395,33 @@ public class AIQueryDialog extends JDialog {
|
||||
Object legacy = UISettings.restore("aiApiKey");
|
||||
return legacy instanceof String ? (String) legacy : null;
|
||||
}
|
||||
|
||||
private String buildCommentForHistory() {
|
||||
// Collect only user messages
|
||||
List<String> userMessages = new ArrayList<>();
|
||||
for (ConversationMessage msg : conversationHistory) {
|
||||
if ("user".equals(msg.role)) {
|
||||
// Replace newlines with spaces to keep each message on one line in the comment
|
||||
String cleanedContent = msg.content.replaceAll("[\\r\\n]+", " ");
|
||||
userMessages.add(cleanedContent);
|
||||
}
|
||||
}
|
||||
if (userMessages.isEmpty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("/* Ask AI:\n");
|
||||
if (userMessages.size() == 1) {
|
||||
sb.append(userMessages.get(0));
|
||||
} else {
|
||||
for (String msg : userMessages) {
|
||||
sb.append("- ").append(msg).append("\n");
|
||||
}
|
||||
// Remove trailing newline
|
||||
sb.setLength(sb.length() - 1);
|
||||
}
|
||||
sb.append("\n */");
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -525,7 +525,7 @@ public abstract class SQLConsole extends javax.swing.JPanel {
|
||||
new AIQueryDialog(SwingUtilities.getWindowAncestor(SQLConsole.this), dm, dbmsName,
|
||||
sql -> editorPane.setText(sql)).setVisible(true);
|
||||
});
|
||||
// jToolBar1.add(aiButton, 4); TODO
|
||||
jToolBar1.add(aiButton, 4);
|
||||
jToolBar1.add(new JToolBar.Separator(), 5);
|
||||
|
||||
limitComboBox.setModel(new DefaultComboBoxModel(DataBrowser.ROW_LIMITS));
|
||||
|
||||
Reference in New Issue
Block a user