From f895d066055728e2744044ce6390a222bc24d095 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 13 Mar 2024 10:05:30 +0100 Subject: [PATCH] fix(config): set better defaults for inferencing (#1822) * fix(defaults): set better defaults for inferencing This changeset aim to have better defaults and to properly detect when no inference settings are provided with the model. If not specified, we defaults to mirostat sampling, and offload all the GPU layers (if a GPU is detected). Related to https://github.com/mudler/LocalAI/issues/1373 and https://github.com/mudler/LocalAI/issues/1723 * Adapt tests * Also pre-initialize default seed --- core/backend/embeddings.go | 2 +- core/backend/image.go | 6 +- core/backend/llm.go | 6 +- core/backend/options.go | 45 ++--- core/backend/transcript.go | 4 +- core/config/backend_config.go | 240 +++++++++++++++++++------- core/http/api_test.go | 2 +- core/http/endpoints/localai/tts.go | 9 +- core/http/endpoints/openai/image.go | 2 +- core/http/endpoints/openai/request.go | 33 ++-- core/schema/prediction.go | 17 +- main.go | 2 +- 12 files changed, 235 insertions(+), 133 deletions(-) diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 0a74ea4ca..943108540 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -23,7 +23,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(backendConfig.Threads)), + model.WithThreads(uint32(*backendConfig.Threads)), model.WithAssetDir(appConfig.AssetsDestination), model.WithModel(modelFile), model.WithContext(appConfig.Context), diff --git a/core/backend/image.go b/core/backend/image.go index 79b8d4ba1..b0cffb0b8 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -9,14 +9,14 @@ import ( func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { threads := backendConfig.Threads - if threads == 0 && appConfig.Threads != 0 { - threads = appConfig.Threads + if *threads == 0 && appConfig.Threads != 0 { + threads = &appConfig.Threads } gRPCOpts := gRPCModelOpts(backendConfig) opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(backendConfig.Backend), model.WithAssetDir(appConfig.AssetsDestination), - model.WithThreads(uint32(threads)), + model.WithThreads(uint32(*threads)), model.WithContext(appConfig.Context), model.WithModel(backendConfig.Model), model.WithLoadGRPCLoadModelOpts(gRPCOpts), diff --git a/core/backend/llm.go b/core/backend/llm.go index 54e261889..d5e14df01 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -29,8 +29,8 @@ type TokenUsage struct { func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model threads := c.Threads - if threads == 0 && o.Threads != 0 { - threads = o.Threads + if *threads == 0 && o.Threads != 0 { + threads = &o.Threads } grpcOpts := gRPCModelOpts(c) @@ -39,7 +39,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode opts := modelOpts(c, o, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(threads)), // 
some models uses this to allocate threads during startup + model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup model.WithAssetDir(o.AssetsDestination), model.WithModel(modelFile), model.WithContext(o.Context), diff --git a/core/backend/options.go b/core/backend/options.go index 3af6f6797..bc7fa5a41 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -46,15 +46,15 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { CFGScale: c.Diffusers.CFGScale, LoraAdapter: c.LoraAdapter, LoraScale: c.LoraScale, - F16Memory: c.F16, + F16Memory: *c.F16, LoraBase: c.LoraBase, IMG2IMG: c.Diffusers.IMG2IMG, CLIPModel: c.Diffusers.ClipModel, CLIPSubfolder: c.Diffusers.ClipSubFolder, CLIPSkip: int32(c.Diffusers.ClipSkip), ControlNet: c.Diffusers.ControlNet, - ContextSize: int32(c.ContextSize), - Seed: int32(c.Seed), + ContextSize: int32(*c.ContextSize), + Seed: int32(*c.Seed), NBatch: int32(b), NoMulMatQ: c.NoMulMatQ, DraftModel: c.DraftModel, @@ -72,18 +72,18 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { YarnBetaSlow: c.YarnBetaSlow, NGQA: c.NGQA, RMSNormEps: c.RMSNormEps, - MLock: c.MMlock, + MLock: *c.MMlock, RopeFreqBase: c.RopeFreqBase, RopeScaling: c.RopeScaling, Type: c.ModelType, RopeFreqScale: c.RopeFreqScale, NUMA: c.NUMA, Embeddings: c.Embeddings, - LowVRAM: c.LowVRAM, - NGPULayers: int32(c.NGPULayers), - MMap: c.MMap, + LowVRAM: *c.LowVRAM, + NGPULayers: int32(*c.NGPULayers), + MMap: *c.MMap, MainGPU: c.MainGPU, - Threads: int32(c.Threads), + Threads: int32(*c.Threads), TensorSplit: c.TensorSplit, // AutoGPTQ ModelBaseName: c.AutoGPTQ.ModelBaseName, @@ -102,36 +102,37 @@ func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOption os.MkdirAll(filepath.Dir(p), 0755) promptCachePath = p } + return &pb.PredictOptions{ - Temperature: float32(c.Temperature), - TopP: float32(c.TopP), + Temperature: float32(*c.Temperature), + TopP: float32(*c.TopP), NDraft: c.NDraft, - TopK: int32(c.TopK), - Tokens: int32(c.Maxtokens), - Threads: int32(c.Threads), + TopK: int32(*c.TopK), + Tokens: int32(*c.Maxtokens), + Threads: int32(*c.Threads), PromptCacheAll: c.PromptCacheAll, PromptCacheRO: c.PromptCacheRO, PromptCachePath: promptCachePath, - F16KV: c.F16, - DebugMode: c.Debug, + F16KV: *c.F16, + DebugMode: *c.Debug, Grammar: c.Grammar, NegativePromptScale: c.NegativePromptScale, RopeFreqBase: c.RopeFreqBase, RopeFreqScale: c.RopeFreqScale, NegativePrompt: c.NegativePrompt, - Mirostat: int32(c.LLMConfig.Mirostat), - MirostatETA: float32(c.LLMConfig.MirostatETA), - MirostatTAU: float32(c.LLMConfig.MirostatTAU), - Debug: c.Debug, + Mirostat: int32(*c.LLMConfig.Mirostat), + MirostatETA: float32(*c.LLMConfig.MirostatETA), + MirostatTAU: float32(*c.LLMConfig.MirostatTAU), + Debug: *c.Debug, StopPrompts: c.StopWords, Repeat: int32(c.RepeatPenalty), NKeep: int32(c.Keep), Batch: int32(c.Batch), IgnoreEOS: c.IgnoreEOS, - Seed: int32(c.Seed), + Seed: int32(*c.Seed), FrequencyPenalty: float32(c.FrequencyPenalty), - MLock: c.MMlock, - MMap: c.MMap, + MLock: *c.MMlock, + MMap: *c.MMap, MainGPU: c.MainGPU, TensorSplit: c.TensorSplit, TailFreeSamplingZ: float32(c.TFZ), diff --git a/core/backend/transcript.go b/core/backend/transcript.go index bbb4f4b4c..4c3859dfe 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -17,7 +17,7 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo model.WithBackendString(model.WhisperBackend), model.WithModel(backendConfig.Model), 
model.WithContext(appConfig.Context), - model.WithThreads(uint32(backendConfig.Threads)), + model.WithThreads(uint32(*backendConfig.Threads)), model.WithAssetDir(appConfig.AssetsDestination), }) @@ -33,6 +33,6 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ Dst: audio, Language: language, - Threads: uint32(backendConfig.Threads), + Threads: uint32(*backendConfig.Threads), }) } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 63e5855cd..53326b3f1 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io/fs" + "math/rand" "os" "path/filepath" "strings" @@ -20,9 +21,9 @@ type BackendConfig struct { schema.PredictionOptions `yaml:"parameters"` Name string `yaml:"name"` - F16 bool `yaml:"f16"` - Threads int `yaml:"threads"` - Debug bool `yaml:"debug"` + F16 *bool `yaml:"f16"` + Threads *int `yaml:"threads"` + Debug *bool `yaml:"debug"` Roles map[string]string `yaml:"roles"` Embeddings bool `yaml:"embeddings"` Backend string `yaml:"backend"` @@ -105,20 +106,20 @@ type LLMConfig struct { PromptCachePath string `yaml:"prompt_cache_path"` PromptCacheAll bool `yaml:"prompt_cache_all"` PromptCacheRO bool `yaml:"prompt_cache_ro"` - MirostatETA float64 `yaml:"mirostat_eta"` - MirostatTAU float64 `yaml:"mirostat_tau"` - Mirostat int `yaml:"mirostat"` - NGPULayers int `yaml:"gpu_layers"` - MMap bool `yaml:"mmap"` - MMlock bool `yaml:"mmlock"` - LowVRAM bool `yaml:"low_vram"` + MirostatETA *float64 `yaml:"mirostat_eta"` + MirostatTAU *float64 `yaml:"mirostat_tau"` + Mirostat *int `yaml:"mirostat"` + NGPULayers *int `yaml:"gpu_layers"` + MMap *bool `yaml:"mmap"` + MMlock *bool `yaml:"mmlock"` + LowVRAM *bool `yaml:"low_vram"` Grammar string `yaml:"grammar"` StopWords []string `yaml:"stopwords"` Cutstrings []string `yaml:"cutstrings"` TrimSpace []string `yaml:"trimspace"` TrimSuffix []string `yaml:"trimsuffix"` - ContextSize int `yaml:"context_size"` + ContextSize *int `yaml:"context_size"` NUMA bool `yaml:"numa"` LoraAdapter string `yaml:"lora_adapter"` LoraBase string `yaml:"lora_base"` @@ -185,19 +186,96 @@ func (c *BackendConfig) FunctionToCall() string { return c.functionCallNameString } -func defaultPredictOptions(modelFile string) schema.PredictionOptions { - return schema.PredictionOptions{ - TopP: 0.7, - TopK: 80, - Maxtokens: 512, - Temperature: 0.9, - Model: modelFile, - } -} +func (cfg *BackendConfig) SetDefaults(debug bool, threads, ctx int, f16 bool) { + defaultTopP := 0.7 + defaultTopK := 80 + defaultTemp := 0.9 + defaultMaxTokens := 2048 + defaultMirostat := 2 + defaultMirostatTAU := 5.0 + defaultMirostatETA := 0.1 -func DefaultConfig(modelFile string) *BackendConfig { - return &BackendConfig{ - PredictionOptions: defaultPredictOptions(modelFile), + // Try to offload all GPU layers (if GPU is found) + defaultNGPULayers := 99999999 + + trueV := true + falseV := false + + if cfg.Seed == nil { + // random number generator seed + defaultSeed := int(rand.Int31()) + cfg.Seed = &defaultSeed + } + + if cfg.TopK == nil { + cfg.TopK = &defaultTopK + } + + if cfg.MMap == nil { + // MMap is enabled by default + cfg.MMap = &trueV + } + + if cfg.MMlock == nil { + // MMlock is disabled by default + cfg.MMlock = &falseV + } + + if cfg.TopP == nil { + cfg.TopP = &defaultTopP + } + if cfg.Temperature == nil { + cfg.Temperature = &defaultTemp + } + + if cfg.Maxtokens == nil { + cfg.Maxtokens = 
&defaultMaxTokens + } + + if cfg.Mirostat == nil { + cfg.Mirostat = &defaultMirostat + } + + if cfg.MirostatETA == nil { + cfg.MirostatETA = &defaultMirostatETA + } + + if cfg.MirostatTAU == nil { + cfg.MirostatTAU = &defaultMirostatTAU + } + if cfg.NGPULayers == nil { + cfg.NGPULayers = &defaultNGPULayers + } + + if cfg.LowVRAM == nil { + cfg.LowVRAM = &falseV + } + + // Value passed by the top level are treated as default (no implicit defaults) + // defaults are set by the user + if ctx == 0 { + ctx = 1024 + } + + if cfg.ContextSize == nil { + cfg.ContextSize = &ctx + } + + if threads == 0 { + // Threads can't be 0 + threads = 4 + } + + if cfg.Threads == nil { + cfg.Threads = &threads + } + + if cfg.F16 == nil { + cfg.F16 = &f16 + } + + if debug { + cfg.Debug = &debug } } @@ -208,23 +286,63 @@ type BackendConfigLoader struct { sync.Mutex } +type LoadOptions struct { + debug bool + threads, ctxSize int + f16 bool +} + +func LoadOptionDebug(debug bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.debug = debug + } +} + +func LoadOptionThreads(threads int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.threads = threads + } +} + +func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.ctxSize = ctxSize + } +} + +func LoadOptionF16(f16 bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.f16 = f16 + } +} + +type ConfigLoaderOption func(*LoadOptions) + +func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { + for _, l := range options { + l(lo) + } +} + // Load a config file for a model -func LoadBackendConfigFileByName(modelName, modelPath string, cl *BackendConfigLoader, debug bool, threads, ctx int, f16 bool) (*BackendConfig, error) { +func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + + lo := &LoadOptions{} + lo.Apply(opts...) 
+ // Load a config file if present after the model name - modelConfig := filepath.Join(modelPath, modelName+".yaml") - - var cfg *BackendConfig - - defaults := func() { - cfg = DefaultConfig(modelName) - cfg.ContextSize = ctx - cfg.Threads = threads - cfg.F16 = f16 - cfg.Debug = debug + cfg := &BackendConfig{ + PredictionOptions: schema.PredictionOptions{ + Model: modelName, + }, } cfgExisting, exists := cl.GetBackendConfig(modelName) - if !exists { + if exists { + cfg = &cfgExisting + } else { + // Try loading a model config file + modelConfig := filepath.Join(modelPath, modelName+".yaml") if _, err := os.Stat(modelConfig); err == nil { if err := cl.LoadBackendConfig(modelConfig); err != nil { return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) @@ -232,32 +350,11 @@ func LoadBackendConfigFileByName(modelName, modelPath string, cl *BackendConfigL cfgExisting, exists = cl.GetBackendConfig(modelName) if exists { cfg = &cfgExisting - } else { - defaults() } - } else { - defaults() - } - } else { - cfg = &cfgExisting - } - - // Set the parameters for the language model prediction - //updateConfig(cfg, input) - - // Don't allow 0 as setting - if cfg.Threads == 0 { - if threads != 0 { - cfg.Threads = threads - } else { - cfg.Threads = 4 } } - // Enforce debug flag if passed from CLI - if debug { - cfg.Debug = true - } + cfg.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16) return cfg, nil } @@ -267,7 +364,10 @@ func NewBackendConfigLoader() *BackendConfigLoader { configs: make(map[string]BackendConfig), } } -func ReadBackendConfigFile(file string) ([]*BackendConfig, error) { +func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { + lo := &LoadOptions{} + lo.Apply(opts...) + c := &[]*BackendConfig{} f, err := os.ReadFile(file) if err != nil { @@ -277,10 +377,17 @@ func ReadBackendConfigFile(file string) ([]*BackendConfig, error) { return nil, fmt.Errorf("cannot unmarshal config file: %w", err) } + for _, cc := range *c { + cc.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16) + } + return *c, nil } -func ReadBackendConfig(file string) (*BackendConfig, error) { +func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + lo := &LoadOptions{} + lo.Apply(opts...) + c := &BackendConfig{} f, err := os.ReadFile(file) if err != nil { @@ -290,13 +397,14 @@ func ReadBackendConfig(file string) (*BackendConfig, error) { return nil, fmt.Errorf("cannot unmarshal config file: %w", err) } + c.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16) return c, nil } -func (cm *BackendConfigLoader) LoadBackendConfigFile(file string) error { +func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { cm.Lock() defer cm.Unlock() - c, err := ReadBackendConfigFile(file) + c, err := ReadBackendConfigFile(file, opts...) if err != nil { return fmt.Errorf("cannot load config file: %w", err) } @@ -307,10 +415,10 @@ func (cm *BackendConfigLoader) LoadBackendConfigFile(file string) error { return nil } -func (cl *BackendConfigLoader) LoadBackendConfig(file string) error { +func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { cl.Lock() defer cl.Unlock() - c, err := ReadBackendConfig(file) + c, err := ReadBackendConfig(file, opts...) 
if err != nil { return fmt.Errorf("cannot read config file: %w", err) } @@ -407,7 +515,9 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { return nil } -func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string) error { +// LoadBackendConfigsFromPath reads all the configurations of the models from a path +// (non-recursive) +func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { cm.Lock() defer cm.Unlock() entries, err := os.ReadDir(path) @@ -427,7 +537,7 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string) error { if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { continue } - c, err := ReadBackendConfig(filepath.Join(path, file.Name())) + c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) if err == nil { cm.configs[c.Name] = *c } diff --git a/core/http/api_test.go b/core/http/api_test.go index 8f3cfc91c..b0579a19d 100644 --- a/core/http/api_test.go +++ b/core/http/api_test.go @@ -386,7 +386,7 @@ var _ = Describe("API test", func() { var res map[string]string err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) Expect(err).ToNot(HaveOccurred()) - Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res)) + Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 84fb7a555..9c3f890de 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -26,7 +26,14 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) } - cfg, err := config.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, cl, false, 0, 0, false) + + cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + config.LoadOptionDebug(appConfig.Debug), + config.LoadOptionThreads(appConfig.Threads), + config.LoadOptionContextSize(appConfig.ContextSize), + config.LoadOptionF16(appConfig.F16), + ) + if err != nil { modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 8f535801f..d59b1051a 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -196,7 +196,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon baseURL := c.BaseURL() - fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) + fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) if err != nil { return err } diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index 46ff2438c..505244c45 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -74,10 +74,10 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque if input.Echo { config.Echo = input.Echo } - if input.TopK != 0 { + if input.TopK != nil { config.TopK = input.TopK } - if 
input.TopP != 0 { + if input.TopP != nil { config.TopP = input.TopP } @@ -117,11 +117,11 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque config.Grammar = input.Grammar } - if input.Temperature != 0 { + if input.Temperature != nil { config.Temperature = input.Temperature } - if input.Maxtokens != 0 { + if input.Maxtokens != nil { config.Maxtokens = input.Maxtokens } @@ -193,30 +193,14 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque config.Batch = input.Batch } - if input.F16 { - config.F16 = input.F16 - } - if input.IgnoreEOS { config.IgnoreEOS = input.IgnoreEOS } - if input.Seed != 0 { + if input.Seed != nil { config.Seed = input.Seed } - if input.Mirostat != 0 { - config.LLMConfig.Mirostat = input.Mirostat - } - - if input.MirostatETA != 0 { - config.LLMConfig.MirostatETA = input.MirostatETA - } - - if input.MirostatTAU != 0 { - config.LLMConfig.MirostatTAU = input.MirostatTAU - } - if input.TypicalP != 0 { config.TypicalP = input.TypicalP } @@ -272,7 +256,12 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque } func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) { - cfg, err := config.LoadBackendConfigFileByName(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16) + cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath, + config.LoadOptionDebug(debug), + config.LoadOptionThreads(threads), + config.LoadOptionContextSize(ctx), + config.LoadOptionF16(f16), + ) // Set the parameters for the language model prediction updateRequestConfig(cfg, input) diff --git a/core/schema/prediction.go b/core/schema/prediction.go index efd085a4a..d75e5eb85 100644 --- a/core/schema/prediction.go +++ b/core/schema/prediction.go @@ -12,28 +12,23 @@ type PredictionOptions struct { N int `json:"n"` // Common options between all the API calls, part of the OpenAI spec - TopP float64 `json:"top_p" yaml:"top_p"` - TopK int `json:"top_k" yaml:"top_k"` - Temperature float64 `json:"temperature" yaml:"temperature"` - Maxtokens int `json:"max_tokens" yaml:"max_tokens"` - Echo bool `json:"echo"` + TopP *float64 `json:"top_p" yaml:"top_p"` + TopK *int `json:"top_k" yaml:"top_k"` + Temperature *float64 `json:"temperature" yaml:"temperature"` + Maxtokens *int `json:"max_tokens" yaml:"max_tokens"` + Echo bool `json:"echo"` // Custom parameters - not present in the OpenAI API Batch int `json:"batch" yaml:"batch"` - F16 bool `json:"f16" yaml:"f16"` IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` Keep int `json:"n_keep" yaml:"n_keep"` - MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` - MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` - Mirostat int `json:"mirostat" yaml:"mirostat"` - FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` TFZ float64 `json:"tfz" yaml:"tfz"` TypicalP float64 `json:"typical_p" yaml:"typical_p"` - Seed int `json:"seed" yaml:"seed"` + Seed *int `json:"seed" yaml:"seed"` NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"` RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"` diff --git a/main.go b/main.go index 237191cfa..21560e5a7 100644 --- a/main.go +++ b/main.go @@ -497,7 +497,7 @@ For a list of compatible model, check out: 
https://localai.io/model-compatibilit return errors.New("model not found") } - c.Threads = threads + c.Threads = &threads defer ml.StopAllGRPC()
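
Note: the sketch below is a minimal, standalone illustration of the two patterns this patch leans on: pointer-typed optional settings, so that SetDefaults can tell "unset" (nil) apart from an explicit zero, and functional options for passing top-level defaults (threads, context size, and so on) into the config loader. The type and function names here (Config, LoadOption, WithThreads, WithContextSize) are simplified stand-ins for illustration only, not the actual LocalAI structs or API.

```go
// Sketch of nil-aware defaulting plus functional options, mirroring the
// intent of BackendConfig.SetDefaults and the ConfigLoaderOption helpers
// introduced in this patch (simplified, hypothetical types).
package main

import (
	"fmt"
	"math/rand"
)

type Config struct {
	Threads     *int
	ContextSize *int
	Seed        *int
	Mirostat    *int
	MMap        *bool
}

type loadOpts struct{ threads, ctxSize int }

type LoadOption func(*loadOpts)

func WithThreads(n int) LoadOption     { return func(o *loadOpts) { o.threads = n } }
func WithContextSize(n int) LoadOption { return func(o *loadOpts) { o.ctxSize = n } }

// SetDefaults fills only the fields the user left unset (nil); values set
// explicitly in a model config are never overridden.
func (c *Config) SetDefaults(opts ...LoadOption) {
	lo := &loadOpts{}
	for _, o := range opts {
		o(lo)
	}
	if lo.ctxSize == 0 {
		lo.ctxSize = 1024 // fallback context size when nothing is passed
	}
	if lo.threads == 0 {
		lo.threads = 4 // threads can't be 0
	}
	if c.Threads == nil {
		c.Threads = &lo.threads
	}
	if c.ContextSize == nil {
		c.ContextSize = &lo.ctxSize
	}
	if c.Seed == nil {
		seed := int(rand.Int31()) // pre-initialize a random seed
		c.Seed = &seed
	}
	if c.Mirostat == nil {
		m := 2 // default to mirostat sampling, as in the patch
		c.Mirostat = &m
	}
	if c.MMap == nil {
		t := true // mmap enabled by default
		c.MMap = &t
	}
}

func main() {
	// A config that only sets threads: the explicit value is kept, every
	// nil field is filled in by SetDefaults.
	threads := 8
	cfg := &Config{Threads: &threads}
	cfg.SetDefaults(WithContextSize(2048))
	fmt.Println(*cfg.Threads, *cfg.ContextSize, *cfg.Mirostat, *cfg.MMap)
	// Output: 8 2048 2 true
}
```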