mirror of https://github.com/HeyPuter/puter.git
dev: normalize Mistral, get tool calls working
Mistral has a lot of unnecessary quirks because the SDK coerces keys to camelCase. We should see if the OpenAI client works with Mistral, and if it does we'll use that instead so we don't have to maintain all these exceptional cases.
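For context, the quirks are a casing mismatch: Puter's internal messages use OpenAI-style snake_case keys (tool_calls, prompt_tokens) while the Mistral SDK speaks camelCase (toolCalls, promptTokens). A minimal sketch of the key conversion the diff below performs for usage objects; the helper name is hypothetical, but the regex is the one from the diff:

    // Hypothetical helper: convert the SDK's camelCase keys back to the
    // snake_case keys Puter uses internally (same regex as in the diff below).
    const toSnakeCase = (obj) => {
        const out = {};
        for ( const key in obj ) {
            const snakeKey = key.replace(/([A-Z])/g, "_$1").toLowerCase();
            out[snakeKey] = obj[key];
        }
        return out;
    };

    // toSnakeCase({ promptTokens: 10, completionTokens: 3 })
    // -> { prompt_tokens: 10, completion_tokens: 3 }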
@@ -24,6 +24,7 @@ const { TypedValue } = require("../../services/drivers/meta/Runtime");
 const { nou } = require("../../util/langutil");
 
 const axios = require('axios');
+const OpenAIUtil = require("./lib/OpenAIUtil");
 const { TeePromise } = require('@heyputer/putility').libs.promise;
 
 
@@ -215,71 +216,51 @@ class MistralAIService extends BaseService {
             * AI Chat completion method.
             * See AIChatService for more details.
             */
-            async complete ({ messages, stream, model }) {
+            async complete ({ messages, stream, model, tools }) {
 
-                for ( let i = 0; i < messages.length; i++ ) {
-                    const message = messages[i];
-                    if ( ! message.role ) message.role = 'user';
+                messages = await OpenAIUtil.process_input_messages(messages);
+                for ( const message of messages ) {
+                    if ( message.tool_calls ) {
+                        message.toolCalls = message.tool_calls;
+                        delete message.tool_calls;
+                    }
+                    if ( message.tool_call_id ) {
+                        message.toolCallId = message.tool_call_id;
+                        delete message.tool_call_id;
+                    }
                 }
 
-                if ( stream ) {
-                    let usage_promise = new TeePromise();
-                    console.log('MESSAGES TO MISTRAL', messages);
-
-                    const stream = new PassThrough();
-                    const retval = new TypedValue({
-                        $: 'stream',
-                        content_type: 'application/x-ndjson',
-                        chunked: true,
-                    }, stream);
-                    const completion = await this.client.chat.stream({
-                        model: model ?? this.get_default_model(),
-                        messages,
-                    });
-                    (async () => {
-                        for await ( let chunk of completion ) {
-                            // just because Mistral wants to be different
-                            chunk = chunk.data;
-
-                            if ( chunk.usage ) {
-                                usage_promise.resolve({
-                                    input_tokens: chunk.usage.promptTokens,
-                                    output_tokens: chunk.usage.completionTokens,
-                                });
-                                continue;
-                            }
-
-                            if ( chunk.choices.length < 1 ) continue;
-                            if ( chunk.choices[0].finish_reason ) {
-                                stream.end();
-                                break;
-                            }
-                            if ( nou(chunk.choices[0].delta.content) ) continue;
-                            const str = JSON.stringify({
-                                text: chunk.choices[0].delta.content
-                            });
-                            stream.write(str + '\n');
-                        }
-                        stream.end();
-                    })();
-
-                    return new TypedValue({ $: 'ai-chat-intermediate' }, {
-                        stream: true,
-                        response: retval,
-                        usage_promise: usage_promise,
-                    });
-                }
-
-                const completion = await this.client.chat.complete({
+                const completion = await this.client.chat[
+                    stream ? 'stream' : 'complete'
+                ]({
                     model: model ?? this.get_default_model(),
+                    ...(tools ? { tools } : {}),
                     messages,
                 });
-                // Expected case when mistralai/client-ts#23 is fixed
-                const ret = completion.choices[0];
-                ret.usage = {
-                    input_tokens: completion.usage.promptTokens,
-                    output_tokens: completion.usage.completionTokens,
-                };
-                return ret;
+                return await OpenAIUtil.handle_completion_output({
+                    deviations: {
+                        index_usage_from_stream_chunk: chunk => {
+                            if ( ! chunk.usage ) return;
+
+                            const snake_usage = {};
+                            for ( const key in chunk.usage ) {
+                                const snakeKey = key.replace(/([A-Z])/g, "_$1").toLowerCase();
+                                snake_usage[snakeKey] = chunk.usage[key];
+                            }
+
+                            return snake_usage;
+                        },
+                        chunk_but_like_actually: chunk => chunk.data,
+                        index_tool_calls_from_stream_choice: choice => choice.delta.toolCalls,
+                    },
+                    completion, stream,
+                    usage_calculator: OpenAIUtil.create_usage_calculator({
+                        model_details: this.models_array_.find(m => m.id === model),
+                    }),
+                });
             }
         }
     }
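A note on the merged call site above: indexing this.client.chat with a computed name folds the streaming and non-streaming paths into one request, which is what lets the old hand-rolled streaming branch be deleted. A minimal standalone sketch of the pattern, assuming client is a Mistral SDK instance (the wrapper function name is hypothetical):

    // Dynamic method dispatch: chat.stream and chat.complete take the
    // same request shape, so one call site replaces both branches.
    const chatCall = async (client, { messages, model, stream, tools }) => {
        return await client.chat[stream ? 'stream' : 'complete']({
            model,
            // spread in optional keys so `tools: undefined` is never sent
            ...(tools ? { tools } : {}),
            messages,
        });
    };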
@@ -84,7 +84,11 @@ module.exports = class OpenAIUtil {
         completion, usage_promise,
     }) => async ({ chatStream }) => {
+        deviations = Object.assign({
+            // affected by: Groq
+            index_usage_from_stream_chunk: chunk => chunk.usage,
+            // affected by: Mistral
+            chunk_but_like_actually: chunk => chunk,
+            index_tool_calls_from_stream_choice: choice => choice.delta.tool_calls,
+        }, deviations);
 
         const message = chatStream.message();
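The deviations parameter acts as a small strategy table: Object.assign installs OpenAI-shaped defaults, and a provider passes overrides only for the accessors its SDK deviates on. A minimal sketch of the merge semantics (sample values, not from the repo):

    // Defaults apply unless the caller supplies its own accessor.
    const defaults = {
        index_usage_from_stream_chunk: chunk => chunk.usage,
        chunk_but_like_actually: chunk => chunk,
    };
    // Mistral's override: its SDK nests each event under `data`.
    const mistralDeviations = { chunk_but_like_actually: chunk => chunk.data };

    // Later sources win, so provider overrides replace the defaults.
    const deviations = Object.assign({}, defaults, mistralDeviations);
    console.log(deviations.chunk_but_like_actually({ data: 'chunk' }));  // 'chunk'
    console.log(deviations.index_usage_from_stream_chunk({ usage: 7 })); // 7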
@@ -94,7 +98,8 @@ module.exports = class OpenAIUtil {
         const tool_call_blocks = [];
 
         let last_usage = null;
-        for await ( const chunk of completion ) {
+        for await ( let chunk of completion ) {
+            chunk = deviations.chunk_but_like_actually(chunk);
             if ( process.env.DEBUG ) {
                 const delta = chunk?.choices?.[0]?.delta;
                 console.log(
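The chunk_but_like_actually hook exists because Mistral's stream yields wrapper events rather than bare chunks, so every chunk is unwrapped before any other indexing happens. A small sketch of the two shapes (the event objects here are illustrative assumptions):

    // OpenAI-style providers yield the chunk directly; Mistral nests it.
    const openaiEvent = { choices: [ { delta: { content: 'hi' } } ] };
    const mistralEvent = { data: { choices: [ { delta: { content: 'hi' } } ] } };

    const identity = chunk => chunk;      // default deviation
    const unwrap = chunk => chunk.data;   // Mistral's deviation

    console.log(identity(openaiEvent).choices[0].delta.content); // 'hi'
    console.log(unwrap(mistralEvent).choices[0].delta.content);  // 'hi'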
@@ -119,12 +124,13 @@ module.exports = class OpenAIUtil {
                 continue;
             }
 
-            if ( ! nou(choice.delta.tool_calls) ) {
+            const tool_calls = deviations.index_tool_calls_from_stream_choice(choice);
+            if ( ! nou(tool_calls) ) {
                 if ( mode === 'text' ) {
                     mode = 'tool';
                     textblock.end();
                 }
-                for ( const tool_call of choice.delta.tool_calls ) {
+                for ( const tool_call of tool_calls ) {
                     if ( ! tool_call_blocks[tool_call.index] ) {
                         toolblock = message.contentBlock({
                             type: 'tool_use',
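For context on tool_call_blocks[tool_call.index] in the last hunk: streamed tool calls arrive as partial deltas tagged with an index, and argument text must be concatenated per index across chunks. A minimal sketch of that accumulation under assumed camelCase delta shapes; the real code routes the fragments into contentBlock objects rather than a plain array:

    // Accumulate fragments keyed by `index`; names arrive once, argument
    // strings arrive in pieces and are concatenated.
    const calls = [];
    const applyDelta = delta => {
        for ( const tc of delta.toolCalls ?? [] ) {
            calls[tc.index] ??= { name: '', arguments: '' };
            if ( tc.function?.name ) calls[tc.index].name = tc.function.name;
            calls[tc.index].arguments += tc.function?.arguments ?? '';
        }
    };

    applyDelta({ toolCalls: [ { index: 0, function: { name: 'get_weather', arguments: '{"city":' } } ] });
    applyDelta({ toolCalls: [ { index: 0, function: { arguments: '"Paris"}' } } ] });
    // calls[0] -> { name: 'get_weather', arguments: '{"city":"Paris"}' }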