mirror of https://github.com/mudler/LocalAI.git
synced 2025-12-30 22:20:20 -06:00

feat: add grammar and functions call support
@@ -46,12 +46,16 @@ type Config struct {
 	PromptCacheAll bool `yaml:"prompt_cache_all"`
 	PromptCacheRO  bool `yaml:"prompt_cache_ro"`

-	PromptStrings, InputStrings []string
-	InputToken                  [][]int
+	Grammar string `yaml:"grammar"`
+
+	PromptStrings, InputStrings                []string
+	InputToken                                 [][]int
+	functionCallString, functionCallNameString string
 }

 type TemplateConfig struct {
 	Completion string `yaml:"completion"`
+	Functions  string `yaml:"function"`
 	Chat       string `yaml:"chat"`
 	Edit       string `yaml:"edit"`
 }
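For illustration only: judging from the yaml tags above, a model's YAML config could wire these fields up roughly as follows. The template names are hypothetical, and the nesting under template: is an assumption about the usual LocalAI model-config layout, not something taken from this commit.

	grammar: ""                    # optional GBNF grammar; typically supplied per request instead
	template:
	  completion: my-completion    # hypothetical template names
	  function: my-function
	  chat: my-chat
	  edit: my-edit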
@@ -181,6 +185,10 @@ func updateConfig(config *Config, input *OpenAIRequest) {
 		config.TopP = input.TopP
 	}

+	if input.Grammar != "" {
+		config.Grammar = input.Grammar
+	}
+
 	if input.Temperature != 0 {
 		config.Temperature = input.Temperature
 	}
@@ -262,6 +270,24 @@ func updateConfig(config *Config, input *OpenAIRequest) {
 		}
 	}

+	// Can be either a string or an object
+	switch fnc := input.FunctionCall.(type) {
+	case string:
+		if fnc != "" {
+			config.functionCallString = fnc
+		}
+	case map[string]interface{}:
+		var name string
+		n, exists := fnc["name"]
+		if exists {
+			nn, e := n.(string)
+			if e {
+				name = nn
+			}
+		}
+		config.functionCallNameString = name
+	}
+
 	switch p := input.Prompt.(type) {
 	case string:
 		config.PromptStrings = append(config.PromptStrings, p)
api/openai.go (150 lines changed)
@@ -17,6 +17,7 @@ import (
 	"strings"

 	"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+	"github.com/go-skynet/LocalAI/pkg/grammar"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
 	llama "github.com/go-skynet/go-llama.cpp"
@@ -73,8 +74,12 @@ type Choice struct {
 }

 type Message struct {
-	Role    string `json:"role,omitempty" yaml:"role"`
+	// The message role
+	Role string `json:"role,omitempty" yaml:"role"`
+	// The message content
 	Content string `json:"content,omitempty" yaml:"content"`
+	// A result of a function call
+	FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
 }

 type OpenAIModel struct {
@@ -104,6 +109,10 @@ type OpenAIRequest struct {
 	// Messages is read only by chat/completion API calls
 	Messages []Message `json:"messages" yaml:"messages"`

+	// A list of available functions to call
+	Functions []grammar.Function `json:"functions" yaml:"functions"`
+	FunctionCall interface{}     `json:"function_call" yaml:"function_call"` // might be a string or an object
+
 	Stream bool `json:"stream"`
 	Echo   bool `json:"echo"`
 	// Common options between all the API calls
@@ -134,6 +143,9 @@ type OpenAIRequest struct {
 	Mode int `json:"mode"`
 	Step int `json:"step"`

+	// A grammar to constrain the LLM output
+	Grammar string `json:"grammar" yaml:"grammar"`
+
 	TypicalP float64 `json:"typical_p" yaml:"typical_p"`
 }
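Taken together, a caller could exercise these new fields along these lines. This is a sketch: the weather function and its parameters are made up, while the grammar.Function field names follow the noActionGrammar literal shown further down.

	req := OpenAIRequest{
		Messages: []Message{{Role: "user", Content: "What's the weather in Rome?"}},
		Functions: []grammar.Function{{
			Name:        "get_weather", // hypothetical function
			Description: "Get the current weather for a city",
			Parameters: map[string]interface{}{
				"properties": map[string]interface{}{
					"city": map[string]interface{}{"type": "string"},
				},
			},
		}},
		// FunctionCall is either a string ("none" disables function processing)
		// or an object like map[string]interface{}{"name": "get_weather"}
		// to force a specific function.
		FunctionCall: map[string]interface{}{"name": "get_weather"},
	}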
@@ -345,6 +357,23 @@ func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 }

 func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
+	// TODO: replace this with config settings
+	// Allow the user to set custom actions via config file
+	// to be "embedded" in each model
+	const noActionName = "answer"
+	const noActionDescription = "use this action to answer without performing any action"
+
+	noActionGrammar := grammar.Function{
+		Name:        noActionName,
+		Description: noActionDescription,
+		Parameters: map[string]interface{}{
+			"properties": map[string]interface{}{
+				"message": map[string]interface{}{
+					"type":        "string",
+					"description": "The message to reply the user with",
+				}},
+		},
+	}
+
 	process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
 		initialMessage := OpenAIResponse{
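The "answer" pseudo-function is an escape hatch: the grammar always allows it, so even when no real function applies the model can still emit valid JSON, along the lines of (illustrative output, not taken from the source):

	{"function": "answer", "arguments": {"message": "Hello! How can I help?"}}

The result handler further down special-cases this name and converts it back into a plain assistant reply.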
@@ -368,6 +397,8 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		close(responses)
 	}
 	return func(c *fiber.Ctx) error {
+		processFunctions := false
+		funcs := grammar.Functions{}
 		model, input, err := readInput(c, o.loader, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -377,8 +408,33 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
 		log.Debug().Msgf("Configuration read: %+v", config)

+		log.Debug().Msgf("Parameter Config: %+v", config)
+		// process functions if we have any defined or if we have a function call string
+		if len(input.Functions) > 0 && ((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) {
+			log.Debug().Msgf("Response needs to process functions")
+
+			funcs = input.Functions
+			processFunctions = true
+
+			// Force picking one of the functions by the request
+			if config.functionCallNameString != "" {
+				funcs = funcs.Select(config.functionCallNameString)
+			}
+
+			// Append the no action function
+			funcs = append(funcs, noActionGrammar)
+
+			// Update input grammar
+			jsStruct := funcs.ToJSONStructure()
+			config.Grammar = jsStruct.Grammar("")
+		}
+
+		// functions are not supported in stream mode (yet?)
+		toStream := input.Stream && !processFunctions
+
+		log.Debug().Msgf("Parameters: %+v", config)

 		var predInput string
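At this point config.Grammar constrains sampling so the model can only emit a JSON object naming one of the declared functions together with its arguments. Judging from the result-processing code below, the raw shape is roughly as follows (function name and arguments made up):

	{"function": "get_weather", "arguments": {"city": "Rome"}}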
@@ -397,7 +453,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {

 		predInput = strings.Join(mess, "\n")

-		if input.Stream {
+		if toStream {
 			log.Debug().Msgf("Stream request received")
 			c.Context().SetContentType("text/event-stream")
 			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
@@ -409,20 +465,35 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {

 		templateFile := config.Model

-		if config.TemplateConfig.Chat != "" {
+		if config.TemplateConfig.Chat != "" && !processFunctions {
 			templateFile = config.TemplateConfig.Chat
 		}

+		if config.TemplateConfig.Functions != "" && processFunctions {
+			templateFile = config.TemplateConfig.Functions
+		}
+
 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
 		templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
-			Input string
-		}{Input: predInput})
+			Input     string
+			Functions []grammar.Function
+		}{
+			Input:     predInput,
+			Functions: funcs,
+		})
 		if err == nil {
 			predInput = templatedInput
 			log.Debug().Msgf("Template found, input modified to: %s", predInput)
 		} else {
 			log.Debug().Msgf("Template failed loading: %s", err.Error())
 		}

-		if input.Stream {
+		log.Debug().Msgf("Prompt: %s", predInput)
+		if processFunctions {
+			log.Debug().Msgf("Grammar: %+v", config.Grammar)
+		}
+
+		if toStream {
 			responses := make(chan OpenAIResponse)

 			go process(predInput, input, config, o.loader, responses)
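Because the template data now carries Functions alongside Input, a functions prompt template can render the available tools into the prompt. A hypothetical file.bin.tmpl in Go text/template syntax (everything below is illustrative, not from the commit):

	You can call the following functions:
	{{range .Functions}}- {{.Name}}: {{.Description}}
	{{end}}
	Request: {{.Input}}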
@@ -459,6 +530,71 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		}

 		result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) {
+			if processFunctions {
+				// As we have to change the result before processing, we can't stream the answer (yet?)
+				ss := map[string]interface{}{}
+				json.Unmarshal([]byte(s), &ss)
+				log.Debug().Msgf("Function return: %s %+v", s, ss)
+
+				// The grammar defines the function name as "function", while OpenAI returns "name"
+				func_name := ss["function"]
+				// Similarly, while arguments is a map[string]interface{} here, OpenAI actually wants a stringified object
+				args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
+				d, _ := json.Marshal(args)
+
+				ss["arguments"] = string(d)
+				ss["name"] = func_name
+
+				// if there is nothing to do, reply with a message
+				if func_name == noActionName {
+					log.Debug().Msgf("nothing to do, computing a reply")
+
+					// If there is a message that the LLM already sends as part of the JSON reply, use it
+					arguments := map[string]interface{}{}
+					json.Unmarshal(d, &arguments)
+					m, exists := arguments["message"]
+					if exists {
+						switch message := m.(type) {
+						case string:
+							if message != "" {
+								log.Debug().Msgf("Reply received from LLM: %s", message)
+								message = Finetune(*config, predInput, message)
+								log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
+
+								*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: message}})
+								return
+							}
+						}
+					}
+
+					log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
+					// Otherwise ask the LLM to understand the JSON output and the context, and return a message
+					// Note: this costs (in terms of CPU) another computation
+					config.Grammar = ""
+					predFunc, err := ModelInference(predInput, o.loader, *config, o, nil)
+					if err != nil {
+						log.Error().Msgf("inference error: %s", err.Error())
+						return
+					}
+
+					prediction, err := predFunc()
+					if err != nil {
+						log.Error().Msgf("inference error: %s", err.Error())
+						return
+					}
+
+					prediction = Finetune(*config, predInput, prediction)
+					*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: prediction}})
+				} else {
+					// otherwise reply with the function call
+					*c = append(*c, Choice{
+						FinishReason: "function_call",
+						Message:      &Message{Role: "function", FunctionCall: ss},
+					})
+				}
+
+				return
+			}
 			*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
 		}, nil)
 		if err != nil {
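After the remapping above, the client receives OpenAI's wire shape, in which arguments is a stringified JSON object. Roughly (illustrative values; field names assume the usual OpenAI-style struct tags):

	{
	  "finish_reason": "function_call",
	  "message": {"role": "function", "function_call": {"name": "get_weather", "arguments": "{\"city\":\"Rome\"}"}}
	}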
@@ -189,6 +189,8 @@ func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption
 		predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
 	}

+	predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar))
+
 	if c.PromptCachePath != "" {
 		// Create parent directory
 		p := filepath.Join(modelPath, c.PromptCachePath)
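Note that llama.WithGrammar is applied unconditionally, which relies on the binding treating an empty grammar string as "no constraint". A more defensive variant (an editorial sketch, not part of this commit) would guard it:

	if c.Grammar != "" {
		predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar))
	}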