From ca17f389290682d960fbc15f95893da3c95f5fc8 Mon Sep 17 00:00:00 2001
From: KernelDeimos
Date: Fri, 7 Feb 2025 15:27:28 -0500
Subject: [PATCH] dev: normalize Mistral, get tool calls working

Mistral has a lot of unnecessary quirks because the SDK coerces to
camelCase. We should see if the OpenAI client works with Mistral, and
if it does we'll use that instead so we don't have to maintain all
these exceptional cases.
---
 .../src/modules/puterai/MistralAIService.js | 97 ++++++++-----------
 .../src/modules/puterai/lib/OpenAIUtil.js   | 12 ++-
 2 files changed, 48 insertions(+), 61 deletions(-)

diff --git a/src/backend/src/modules/puterai/MistralAIService.js b/src/backend/src/modules/puterai/MistralAIService.js
index e03506e9..67a1a0cc 100644
--- a/src/backend/src/modules/puterai/MistralAIService.js
+++ b/src/backend/src/modules/puterai/MistralAIService.js
@@ -24,6 +24,7 @@
 const { TypedValue } = require("../../services/drivers/meta/Runtime");
 const { nou } = require("../../util/langutil");
 const axios = require('axios');
+const OpenAIUtil = require("./lib/OpenAIUtil");
 
 const { TeePromise } = require('@heyputer/putility').libs.promise;
 
@@ -215,71 +216,51 @@ class MistralAIService extends BaseService {
              * AI Chat completion method.
              * See AIChatService for more details.
              */
-            async complete ({ messages, stream, model }) {
+            async complete ({ messages, stream, model, tools }) {
-                for ( let i = 0; i < messages.length; i++ ) {
-                    const message = messages[i];
-                    if ( ! message.role ) message.role = 'user';
+                messages = await OpenAIUtil.process_input_messages(messages);
+                for ( const message of messages ) {
+                    if ( message.tool_calls ) {
+                        message.toolCalls = message.tool_calls;
+                        delete message.tool_calls;
+                    }
+                    if ( message.tool_call_id ) {
+                        message.toolCallId = message.tool_call_id;
+                        delete message.tool_call_id;
+                    }
                 }
 
-                if ( stream ) {
-                    let usage_promise = new TeePromise();
+                console.log('MESSAGES TO MISTRAL', messages);
 
-                    const stream = new PassThrough();
-                    const retval = new TypedValue({
-                        $: 'stream',
-                        content_type: 'application/x-ndjson',
-                        chunked: true,
-                    }, stream);
-                    const completion = await this.client.chat.stream({
-                        model: model ?? this.get_default_model(),
-                        messages,
-                    });
-                    (async () => {
-                        for await ( let chunk of completion ) {
-                            // just because Mistral wants to be different
-                            chunk = chunk.data;
-
-                            if ( chunk.usage ) {
-                                usage_promise.resolve({
-                                    input_tokens: chunk.usage.promptTokens,
-                                    output_tokens: chunk.usage.completionTokens,
-                                });
-                                continue;
-                            }
-
-                            if ( chunk.choices.length < 1 ) continue;
-                            if ( chunk.choices[0].finish_reason ) {
-                                stream.end();
-                                break;
-                            }
-                            if ( nou(chunk.choices[0].delta.content) ) continue;
-                            const str = JSON.stringify({
-                                text: chunk.choices[0].delta.content
-                            });
-                            stream.write(str + '\n');
-                        }
-                        stream.end();
-                    })();
-
-                    return new TypedValue({ $: 'ai-chat-intermediate' }, {
-                        stream: true,
-                        response: retval,
-                        usage_promise: usage_promise,
-                    });
-                }
-
-                const completion = await this.client.chat.complete({
+                const completion = await this.client.chat[
+                    stream ? 'stream' : 'complete'
+                ]({
                     model: model ?? this.get_default_model(),
+                    ...(tools ? { tools } : {}),
                     messages,
                 });
-                // Expected case when mistralai/client-ts#23 is fixed
-                const ret = completion.choices[0];
-                ret.usage = {
-                    input_tokens: completion.usage.promptTokens,
-                    output_tokens: completion.usage.completionTokens,
-                };
-                return ret;
+
+                return await OpenAIUtil.handle_completion_output({
+                    deviations: {
+                        index_usage_from_stream_chunk: chunk => {
+                            if ( ! chunk.usage ) return;
+
+                            const snake_usage = {};
+                            for ( const key in chunk.usage ) {
+                                const snakeKey = key.replace(/([A-Z])/g, "_$1").toLowerCase();
+                                snake_usage[snakeKey] = chunk.usage[key];
+                            }
+
+                            return snake_usage;
+                        },
+                        chunk_but_like_actually: chunk => chunk.data,
+                        index_tool_calls_from_stream_choice: choice => choice.delta.toolCalls,
+                    },
+                    completion, stream,
+                    usage_calculator: OpenAIUtil.create_usage_calculator({
+                        model_details: this.models_array_.find(m => m.id === model),
+                    }),
+                });
             }
         }
     }
 
diff --git a/src/backend/src/modules/puterai/lib/OpenAIUtil.js b/src/backend/src/modules/puterai/lib/OpenAIUtil.js
index bdf0fade..76e6b025 100644
--- a/src/backend/src/modules/puterai/lib/OpenAIUtil.js
+++ b/src/backend/src/modules/puterai/lib/OpenAIUtil.js
@@ -84,7 +84,11 @@ module.exports = class OpenAIUtil {
         completion, usage_promise,
     }) => async ({ chatStream }) => {
         deviations = Object.assign({
+            // affected by: Groq
             index_usage_from_stream_chunk: chunk => chunk.usage,
+            // affected by: Mistral
+            chunk_but_like_actually: chunk => chunk,
+            index_tool_calls_from_stream_choice: choice => choice.delta.tool_calls,
         }, deviations);
 
         const message = chatStream.message();
@@ -94,7 +98,8 @@ module.exports = class OpenAIUtil {
         const tool_call_blocks = [];
 
         let last_usage = null;
-        for await ( const chunk of completion ) {
+        for await ( let chunk of completion ) {
+            chunk = deviations.chunk_but_like_actually(chunk);
             if ( process.env.DEBUG ) {
                 const delta = chunk?.choices?.[0]?.delta;
                 console.log(
@@ -119,12 +124,13 @@ module.exports = class OpenAIUtil {
                 continue;
             }
 
-            if ( ! nou(choice.delta.tool_calls) ) {
+            const tool_calls = deviations.index_tool_calls_from_stream_choice(choice);
+            if ( ! nou(tool_calls) ) {
                 if ( mode === 'text' ) {
                     mode = 'tool';
                     textblock.end();
                 }
-                for ( const tool_call of choice.delta.tool_calls ) {
+                for ( const tool_call of tool_calls ) {
                     if ( ! tool_call_blocks[tool_call.index] ) {
                         toolblock = message.contentBlock({
                             type: 'tool_use',
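
Note on the usage normalization: the index_usage_from_stream_chunk deviation
above undoes the SDK's casing by rewriting each camelCase usage key as
snake_case. A standalone sketch of that conversion (the sample usage object
here is made up, shaped like the Mistral SDK's camelCase usage):

    // Rewrite camelCase keys (e.g. promptTokens) as the snake_case keys
    // (prompt_tokens) that the OpenAI-shaped code paths expect.
    const toSnakeCase = (obj) => {
        const out = {};
        for ( const key in obj ) {
            out[key.replace(/([A-Z])/g, '_$1').toLowerCase()] = obj[key];
        }
        return out;
    };

    console.log(toSnakeCase({ promptTokens: 12, completionTokens: 34 }));
    // -> { prompt_tokens: 12, completion_tokens: 34 }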
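
On the commit message's suggestion of using the OpenAI client directly:
Mistral advertises an OpenAI-compatible chat endpoint, so a first experiment
could be as small as the sketch below. This is untested against Mistral and
not something this patch relies on; the base URL and model name are
assumptions to verify.

    const OpenAI = require('openai');

    // Point the stock OpenAI client at Mistral's (assumed) OpenAI-compatible
    // endpoint. If this works, the camelCase/snake_case deviations go away.
    const client = new OpenAI({
        apiKey: process.env.MISTRAL_API_KEY,
        baseURL: 'https://api.mistral.ai/v1',
    });

    (async () => {
        const completion = await client.chat.completions.create({
            model: 'mistral-large-latest',
            messages: [{ role: 'user', content: 'ping' }],
        });
        console.log(completion.choices[0].message.content);
    })();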