AI Integration, first steps

This commit is contained in:
Ralf Wisser
2026-04-30 15:15:21 +02:00
parent 0da151b839
commit 33ae4781c0
6 changed files with 179 additions and 17 deletions
Binary file not shown.
+1 -1
View File
@@ -863,7 +863,7 @@ public class UIUtil {
if (!(t instanceof CancellationException)) {
t.printStackTrace();
}
if (!(t instanceof ClassNotFoundException)) {
if (!(t instanceof ClassNotFoundException) && !(t instanceof IOException)) {
while (t.getCause() != null && t != t.getCause() && !(t instanceof SqlException)) {
t = t.getCause();
}
@@ -54,11 +54,13 @@ public class AIProviderConfig {
public final String apiUrl;
public final String apiKey;
public final String model;
public final int maxTokens;
public AIProviderConfig(ProviderType providerType, String apiUrl, String apiKey, String model) {
public AIProviderConfig(ProviderType providerType, String apiUrl, String apiKey, String model, int maxTokens) {
this.providerType = providerType;
this.apiUrl = (apiUrl != null && !apiUrl.isEmpty()) ? apiUrl : providerType.defaultApiUrl;
this.apiKey = apiKey != null ? apiKey : "";
this.model = (model != null && !model.isEmpty()) ? model : providerType.defaultModel;
this.maxTokens = maxTokens > 0 ? maxTokens : 1024;
}
}
@@ -65,9 +65,30 @@ public class AIQueryAssistant {
ObjectNode body = buildRequestBody(question, history, schema, dbmsName, config, isAnthropic);
JsonNode response = post(config.apiUrl, config.apiKey, body, isAnthropic);
if (isAnthropic) {
return response.path("content").get(0).path("text").asText("").trim();
JsonNode contentNode = response.path("content");
if (contentNode.isArray() && contentNode.size() > 0) {
return contentNode.get(0).path("text").asText("").trim();
}
throw new IOException("Unexpected response format: missing 'content' array. Response: " + response.toString());
} else {
return response.path("choices").get(0).path("message").path("content").asText("").trim();
// OpenAI-compatible: choices[0].message.content
JsonNode choicesNode = response.path("choices");
if (choicesNode.isArray() && choicesNode.size() > 0) {
JsonNode messageNode = choicesNode.get(0).path("message");
String content = messageNode.path("content").asText("");
if (!content.isEmpty()) {
return content.trim();
}
}
// Ollama-compatible: message.content (streaming response, single object)
JsonNode messageNode = response.path("message");
if (!messageNode.isMissingNode() && !messageNode.isNull()) {
String content = messageNode.path("content").asText("");
if (!content.isEmpty()) {
return content.trim();
}
}
throw new IOException("Unexpected response format: missing 'choices' or 'message'. Response: " + response.toString());
}
}
@@ -79,7 +100,7 @@ public class AIQueryAssistant {
String schema, String dbmsName, AIProviderConfig config, boolean isAnthropic) {
ObjectNode body = MAPPER.createObjectNode();
body.put("model", config.model);
body.put("max_tokens", 1024);
body.put("max_tokens", config.maxTokens);
// Schema lives in the system prompt so it is sent once, not repeated per user message.
String systemPrompt = buildSystemPrompt(schema, dbmsName);
@@ -165,16 +186,67 @@ public class AIQueryAssistant {
byte[] responseBytes;
if (status >= 400) {
InputStream es = conn.getErrorStream();
responseBytes = (es != null) ? readAllBytes(es) : new byte[0];
if (es != null) {
responseBytes = readAllBytes(es);
} else {
try (InputStream is = conn.getInputStream()) {
responseBytes = readAllBytes(is);
} catch (IOException ignored) {
responseBytes = new byte[0];
}
}
} else {
try (InputStream is = conn.getInputStream()) {
responseBytes = readAllBytes(is);
}
}
_log.debug("RESPONSE {}\n Body: {}", status, new String(responseBytes, StandardCharsets.UTF_8).trim());
String responseBody = new String(responseBytes, StandardCharsets.UTF_8).trim();
_log.debug("RESPONSE {}\n Body: {}", status, responseBody);
if (status >= 400) {
throw new IOException("API error " + status + ": " + parseErrorMessage(responseBytes, status));
}
// Check if response is streamed (multiple JSON objects, one per line)
String[] lines = responseBody.split("\\r?\\n");
if (lines.length > 1 && responseBody.contains("\"done\":")) {
// Streaming response - concatenate all message contents until done
StringBuilder fullContent = new StringBuilder();
for (String line : lines) {
line = line.trim();
if (line.isEmpty()) continue;
try {
JsonNode lineNode = MAPPER.readTree(line);
JsonNode doneNode = lineNode.path("done");
if (doneNode.asBoolean()) {
// Last chunk, stop here
break;
}
JsonNode messageNode = lineNode.path("message");
String content = messageNode.path("content").asText("");
if (!content.isEmpty()) {
fullContent.append(content);
}
} catch (IOException e) {
// skip invalid line
}
}
// Build synthetic response matching expected format
ObjectNode synthResponse = MAPPER.createObjectNode();
if (isAnthropic) {
// Anthropic: content is an array of text blocks
ArrayNode contentArray = synthResponse.putArray("content");
ObjectNode textBlock = contentArray.addObject();
textBlock.put("type", "text");
textBlock.put("text", fullContent.toString());
} else {
// OpenAI-compatible: choices[0].message.content
ArrayNode choices = synthResponse.putArray("choices");
ObjectNode choice = choices.addObject();
ObjectNode message = choice.putObject("message");
message.put("role", "assistant");
message.put("content", fullContent.toString());
}
return synthResponse;
}
return MAPPER.readTree(responseBytes);
} finally {
conn.disconnect();
@@ -188,6 +260,7 @@ public class AIQueryAssistant {
List<String> cmd = new ArrayList<>();
cmd.add("curl");
cmd.add("-s");
cmd.add("-f");
cmd.add("-X"); cmd.add("POST");
cmd.add("-H"); cmd.add("Content-Type: application/json");
if (isAnthropic) {
@@ -213,6 +286,16 @@ public class AIQueryAssistant {
process.destroy();
throw new IOException("curl timed out");
}
int exitCode = process.exitValue();
if (exitCode != 0) {
byte[] errBytes = readAllBytes(process.getErrorStream());
String errStr = new String(errBytes, StandardCharsets.UTF_8).trim();
if (errStr.length() > 0) {
_log.debug("RESPONSE (curl) exitCode={} Body: {}", exitCode, errStr);
throw new IOException("API error " + exitCode + ": " + errStr);
}
throw new IOException("curl failed with exit code " + exitCode);
}
if (responseBytes.length == 0) {
byte[] errBytes = readAllBytes(process.getErrorStream());
String curlErr = new String(errBytes, StandardCharsets.UTF_8).trim();
@@ -247,19 +330,25 @@ public class AIQueryAssistant {
if (responseBytes.length == 0) {
return "HTTP " + status;
}
String responseJson = new String(responseBytes, StandardCharsets.UTF_8);
try {
JsonNode node = MAPPER.readTree(responseBytes);
String msg = node.path("error").path("message").asText(null);
if (msg == null) {
msg = node.path("error").asText(null);
}
if (msg == null) {
msg = node.path("message").asText(null);
}
if (msg != null && !msg.isEmpty()) {
return msg;
return msg + " (" + status + ")";
}
// Include full response body if no specific message found
return responseJson.trim() + " (" + status + ")";
} catch (IOException ignored) {
// not JSON — fall through
}
String raw = new String(responseBytes, StandardCharsets.UTF_8).trim();
String raw = responseJson.trim();
if (raw.startsWith("<") || raw.toLowerCase(Locale.ROOT).contains("<html")) {
File htmlFile = new File(System.getProperty("java.io.tmpdir"), "jailer-ai-error.html");
try (FileOutputStream fos = new FileOutputStream(htmlFile)) {
@@ -268,7 +357,8 @@ public class AIQueryAssistant {
}
return "HTTP " + status + " (HTML response saved to: " + htmlFile.getAbsolutePath() + ")";
}
return raw.length() > 300 ? raw.substring(0, 300) + "..." : raw;
// Include full response body in error message
return "HTTP " + status + " - Response: " + raw;
}
private static String buildSystemPrompt(String schema, String dbmsName) {
@@ -374,4 +464,4 @@ public class AIQueryAssistant {
}
// TODO
// TODO put comments into context
// TODO session management: if the provider supports it, we could keep a session ID and reuse it for subsequent calls to maintain context without resending the full schema each time.
@@ -49,6 +49,7 @@ import org.fife.ui.rtextarea.RTextScrollPane;
import net.sf.jailer.ui.syntaxtextarea.RSyntaxTextAreaWithSQLSyntaxStyle;
import net.sf.jailer.datamodel.DataModel;
import net.sf.jailer.ui.UIUtil;
import net.sf.jailer.ui.ai.AIProviderConfig;
import net.sf.jailer.ui.ai.AIProviderConfig.ProviderType;
import net.sf.jailer.ui.ai.AIQueryAssistant;
@@ -67,6 +68,7 @@ public class AIQueryDialog extends JDialog {
private static final String SETTING_PROVIDER = "aiProviderType";
private static final String SETTING_API_URL = "aiApiUrl";
private static final String SETTING_MODEL = "aiModel";
private static final String SETTING_MAX_TOKENS = "aiMaxTokens";
private static final String SETTING_API_KEY_PREFIX = "aiApiKey_";
private final DataModel dataModel;
@@ -86,6 +88,7 @@ public class AIQueryDialog extends JDialog {
private JComboBox<ProviderType> providerCombo;
private JTextField urlField;
private JTextField modelField;
private JTextField maxTokensField;
private JPasswordField apiKeyField;
private JCheckBox saveBox;
@@ -121,6 +124,7 @@ public class AIQueryDialog extends JDialog {
questionPanel.add(new JScrollPane(questionArea), BorderLayout.CENTER);
generateButton = new JButton("Generate SQL");
generateButton.setEnabled(false);
statusLabel = new JLabel(" ");
generateButton.addActionListener(e -> onGenerate());
JPanel genRow = new JPanel(new FlowLayout(FlowLayout.LEFT, 6, 0));
@@ -128,6 +132,24 @@ public class AIQueryDialog extends JDialog {
genRow.add(statusLabel);
questionPanel.add(genRow, BorderLayout.SOUTH);
questionArea.getDocument().addDocumentListener(new javax.swing.event.DocumentListener() {
@Override
public void insertUpdate(javax.swing.event.DocumentEvent e) {
updateGenerateButton();
}
@Override
public void removeUpdate(javax.swing.event.DocumentEvent e) {
updateGenerateButton();
}
@Override
public void changedUpdate(javax.swing.event.DocumentEvent e) {
updateGenerateButton();
}
private void updateGenerateButton() {
generateButton.setEnabled(!questionArea.getText().trim().isEmpty());
}
});
// SQL result area
JPanel resultPanel = new JPanel(new BorderLayout(4, 4));
resultPanel.add(new JLabel("Generated SQL:"), BorderLayout.NORTH);
@@ -162,7 +184,9 @@ public class AIQueryDialog extends JDialog {
insertButton.addActionListener(e -> {
String sql = sqlArea.getText().trim();
if (!sql.isEmpty()) {
sqlConsumer.accept(sql);
String comment = buildCommentForHistory();
String combined = comment + "\n" + sql;
sqlConsumer.accept(combined);
dispose();
}
});
@@ -201,6 +225,10 @@ public class AIQueryDialog extends JDialog {
urlField = new JTextField(savedUrl != null ? savedUrl : savedProvider.defaultApiUrl, 36);
modelField = new JTextField(savedModel != null ? savedModel : savedProvider.defaultModel, 18);
maxTokensField = new JTextField((String) UISettings.restore(SETTING_MAX_TOKENS), 6);
if (maxTokensField.getText().isEmpty()) {
maxTokensField.setText("1024");
}
apiKeyField = new JPasswordField(36);
if (savedKey != null) {
apiKeyField.setText(savedKey);
@@ -219,6 +247,9 @@ public class AIQueryDialog extends JDialog {
fc.gridwidth = 1; fc.weightx = 0;
fc.gridx = 5; fc.gridy = 1; panel.add(saveBox, fc);
lc.gridx = 0; lc.gridy = 2; panel.add(new JLabel("Max Tokens:"), lc);
fc.gridx = 1; fc.gridy = 2; panel.add(maxTokensField, fc);
ProviderType[] prev = { savedProvider };
providerCombo.addItemListener(e -> {
if (e.getStateChange() != ItemEvent.SELECTED) {
@@ -252,17 +283,29 @@ public class AIQueryDialog extends JDialog {
return;
}
int maxTokens = 1024;
try {
maxTokens = Integer.parseInt(maxTokensField.getText().trim());
if (maxTokens <= 0) {
maxTokens = 1024;
}
} catch (NumberFormatException e) {
// ignore, use default
}
AIProviderConfig config = new AIProviderConfig(
(ProviderType) providerCombo.getSelectedItem(),
urlField.getText().trim(),
apiKey,
modelField.getText().trim()
modelField.getText().trim(),
maxTokens
);
if (saveBox.isSelected()) {
UISettings.store(SETTING_PROVIDER, config.providerType.name());
UISettings.store(SETTING_API_URL, config.apiUrl);
UISettings.store(SETTING_MODEL, config.model);
UISettings.store(SETTING_MAX_TOKENS, String.valueOf(config.maxTokens));
UISettings.store(SETTING_API_KEY_PREFIX + config.providerType.name(), config.apiKey);
}
@@ -295,11 +338,9 @@ public class AIQueryDialog extends JDialog {
updateHistoryDisplay();
}
} catch (ExecutionException ex) {
String msg = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage();
sqlArea.setText("Error: " + msg);
UIUtil.showException(AIQueryDialog.this, "SQL Generation Error", ex);
} catch (InterruptedException ex) {
Thread.currentThread().interrupt();
sqlArea.setText("Request interrupted.");
}
}
}.execute();
@@ -354,4 +395,33 @@ public class AIQueryDialog extends JDialog {
Object legacy = UISettings.restore("aiApiKey");
return legacy instanceof String ? (String) legacy : null;
}
private String buildCommentForHistory() {
// Collect only user messages
List<String> userMessages = new ArrayList<>();
for (ConversationMessage msg : conversationHistory) {
if ("user".equals(msg.role)) {
// Replace newlines with spaces to keep each message on one line in the comment
String cleanedContent = msg.content.replaceAll("[\\r\\n]+", " ");
userMessages.add(cleanedContent);
}
}
if (userMessages.isEmpty()) {
return "";
}
StringBuilder sb = new StringBuilder();
sb.append("/* Ask AI:\n");
if (userMessages.size() == 1) {
sb.append(userMessages.get(0));
} else {
for (String msg : userMessages) {
sb.append("- ").append(msg).append("\n");
}
// Remove trailing newline
sb.setLength(sb.length() - 1);
}
sb.append("\n */");
return sb.toString();
}
}
@@ -525,7 +525,7 @@ public abstract class SQLConsole extends javax.swing.JPanel {
new AIQueryDialog(SwingUtilities.getWindowAncestor(SQLConsole.this), dm, dbmsName,
sql -> editorPane.setText(sql)).setVisible(true);
});
// jToolBar1.add(aiButton, 4); TODO
jToolBar1.add(aiButton, 4);
jToolBar1.add(new JToolBar.Separator(), 5);
limitComboBox.setModel(new DefaultComboBoxModel(DataBrowser.ROW_LIMITS));