diff --git a/src/backend/src/modules/puterai/AIChatService.js b/src/backend/src/modules/puterai/AIChatService.js index 5cbc2589..c8abfa5c 100644 --- a/src/backend/src/modules/puterai/AIChatService.js +++ b/src/backend/src/modules/puterai/AIChatService.js @@ -354,7 +354,8 @@ class AIChatService extends BaseService { } if ( parameters.messages ) { - Messages.normalize_messages(parameters.messages); + parameters.messages = + Messages.normalize_messages(parameters.messages); } if ( ! test_mode && ! await this.moderate(parameters) ) { diff --git a/src/backend/src/modules/puterai/ClaudeService.js b/src/backend/src/modules/puterai/ClaudeService.js index 462fec19..7ca6e82b 100644 --- a/src/backend/src/modules/puterai/ClaudeService.js +++ b/src/backend/src/modules/puterai/ClaudeService.js @@ -24,6 +24,7 @@ const { whatis } = require("../../util/langutil"); const { PassThrough } = require("stream"); const { TypedValue } = require("../../services/drivers/meta/Runtime"); const FunctionCalling = require("./lib/FunctionCalling"); +const Messages = require("./lib/Messages"); const { TeePromise } = require('@heyputer/putility').libs.promise; const PUTER_PROMPT = ` @@ -116,41 +117,10 @@ class ClaudeService extends BaseService { * @returns {TypedValue|Object} Returns either a TypedValue with streaming response or a completion object */ async complete ({ messages, stream, model, tools }) { - const adapted_messages = []; - tools = FunctionCalling.make_claude_tools(tools); - const system_prompts = []; - let previous_was_user = false; - for ( const message of messages ) { - if ( typeof message.content === 'string' ) { - message.content = { - type: 'text', - text: message.content, - }; - } - if ( whatis(message.content) !== 'array' ) { - message.content = [message.content]; - } - if ( ! message.role ) message.role = 'user'; - if ( message.role === 'user' && previous_was_user ) { - const last_msg = adapted_messages[adapted_messages.length-1]; - last_msg.content.push( - ...(Array.isArray ? message.content : [message.content]) - ); - continue; - } - if ( message.role === 'system' ) { - system_prompts.push(...message.content); - continue; - } - adapted_messages.push(message); - if ( message.role === 'user' ) { - previous_was_user = true; - } else { - previous_was_user = false; - } - } + let system_prompts; + [system_prompts, messages] = Messages.extract_and_remove_system_messages(messages); if ( stream ) { let usage_promise = new TeePromise(); @@ -167,7 +137,7 @@ class ClaudeService extends BaseService { max_tokens: (model === 'claude-3-5-sonnet-20241022' || model === 'claude-3-5-sonnet-20240620') ? 8192 : 4096, temperature: 0, system: PUTER_PROMPT + JSON.stringify(system_prompts), - messages: adapted_messages, + messages, ...(tools ? { tools } : {}), }); const counts = { input_tokens: 0, output_tokens: 0 }; @@ -278,7 +248,7 @@ class ClaudeService extends BaseService { max_tokens: (model === 'claude-3-5-sonnet-20241022' || model === 'claude-3-5-sonnet-20240620') ? 8192 : 4096, temperature: 0, system: PUTER_PROMPT + JSON.stringify(system_prompts), - messages: adapted_messages, + messages, ...(tools ? { tools } : {}), }); return { diff --git a/src/backend/src/modules/puterai/lib/Messages.js b/src/backend/src/modules/puterai/lib/Messages.js index e3ab6234..4a2eaf22 100644 --- a/src/backend/src/modules/puterai/lib/Messages.js +++ b/src/backend/src/modules/puterai/lib/Messages.js @@ -45,7 +45,35 @@ module.exports = class Messages { for ( let i=0 ; i < messages.length ; i++ ) { messages[i] = this.normalize_single_message(messages[i], params); } + + // If multiple messages are from the same role, merge them + let merged_messages = []; + let current_role = null; + for ( let i=0 ; i < messages.length ; i++ ) { + if ( current_role === messages[i].role ) { + merged_messages[merged_messages.length - 1].content.push(...messages[i].content); + } else { + merged_messages.push(messages[i]); + current_role = messages[i].role; + } + } + + return merged_messages; } + + static extract_and_remove_system_messages (messages) { + let system_messages = []; + let new_messages = []; + for ( let i=0 ; i < messages.length ; i++ ) { + if ( messages[i].role === 'system' ) { + system_messages.push(messages[i]); + } else { + new_messages.push(messages[i]); + } + } + return [system_messages, new_messages]; + } + static extract_text (messages) { return messages.map(m => { if ( whatis(m) === 'string' ) {