feat(puterai): add groq

This commit is contained in:
KernelDeimos
2024-08-23 00:16:29 -04:00
parent 937528f767
commit 53e7a91f18
5 changed files with 119 additions and 0 deletions

View File

@@ -36,6 +36,7 @@
"express": "^4.18.2",
"file-type": "^18.5.0",
"form-data": "^4.0.0",
"groq-sdk": "^0.5.0",
"handlebars": "^4.7.8",
"helmet": "^7.0.0",
"hi-base32": "^0.5.1",

View File

@@ -0,0 +1,69 @@
const { PassThrough } = require("stream");
const BaseService = require("../../services/BaseService");
const { TypedValue } = require("../../services/drivers/meta/Runtime");
const { nou } = require("../../util/langutil");
class GroqAIService extends BaseService {
static MODULES = {
Groq: require('groq-sdk'),
}
async _init () {
const Groq = require('groq-sdk');
this.client = new Groq({
apiKey: this.config.apiKey,
});
}
static IMPLEMENTS = {
'puter-chat-completion': {
async list () {
// They send: { "object": "list", data }
const funny_wrapper = await this.client.models.list();
return funny_wrapper.data;
},
async complete ({ messages, model, stream }) {
for ( let i = 0; i < messages.length; i++ ) {
const message = messages[i];
if ( ! message.role ) message.role = 'user';
}
const completion = await this.client.chat.completions.create({
messages,
model,
stream,
});
if ( stream ) {
const stream = new PassThrough();
const retval = new TypedValue({
$: 'stream',
content_type: 'application/x-ndjson',
chunked: true,
}, stream);
(async () => {
for await ( const chunk of completion ) {
if ( chunk.choices.length < 1 ) continue;
if ( chunk.choices[0].finish_reason ) {
stream.end();
break;
}
if ( nou(chunk.choices[0].delta.content) ) continue;
const str = JSON.stringify({
text: chunk.choices[0].delta.content
});
stream.write(str + '\n');
}
})();
return retval;
}
return completion.choices[0];
}
}
};
}
module.exports = {
GroqAIService,
};

View File

@@ -43,6 +43,11 @@ class PuterAIModule extends AdvancedBase {
const { MistralAIService } = require('./MistralAIService');
services.registerService('mistral', MistralAIService);
}
if ( !! config?.services?.['groq'] ) {
const { GroqAIService } = require('./GroqAIService');
services.registerService('groq', GroqAIService);
}
}
}

View File

@@ -229,6 +229,9 @@ class AI{
if ( options.model === 'mistral' ) {
options.model = 'mistral-large-latest';
}
if ( options.model === 'groq' ) {
options.model = 'llama3-8b-8192';
}
// map model to the appropriate driver
if (!options.model || options.model === 'gpt-4o' || options.model === 'gpt-4o-mini') {
@@ -239,6 +242,21 @@ class AI{
driver = 'together-ai';
}else if(options.model === 'mistral-large-latest' || options.model === 'codestral-latest'){
driver = 'mistral';
}else if([
"distil-whisper-large-v3-en",
"gemma2-9b-it",
"gemma-7b-it",
"llama-3.1-70b-versatile",
"llama-3.1-8b-instant",
"llama3-70b-8192",
"llama3-8b-8192",
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview",
"llama-guard-3-8b",
"mixtral-8x7b-32768",
"whisper-large-v3"
].includes(options.model)) {
driver = 'groq';
}
// stream flag from settings