diff --git a/core/application/startup.go b/core/application/startup.go index 0789f2a65..1fdd1ad50 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -57,7 +57,7 @@ func New(opts ...config.AppOption) (*Application, error) { } } - if err := pkgStartup.InstallModels(options.Galleries, options.ModelPath, options.EnforcePredownloadScans, nil, options.ModelsURL...); err != nil { + if err := pkgStartup.InstallModels(options.Galleries, options.BackendGalleries, options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { log.Error().Err(err).Msg("error installing models") } @@ -86,13 +86,13 @@ func New(opts ...config.AppOption) (*Application, error) { } if options.PreloadJSONModels != "" { - if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.EnforcePredownloadScans, options.Galleries); err != nil { + if err := services.ApplyGalleryFromString(options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil { return nil, err } } if options.PreloadModelsFromPath != "" { - if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.EnforcePredownloadScans, options.Galleries); err != nil { + if err := services.ApplyGalleryFromFile(options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil { return nil, err } } diff --git a/core/backend/llm.go b/core/backend/llm.go index 9d6f771f0..cf8a4f861 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -42,9 +42,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im if _, err := os.Stat(modelFile); os.IsNotExist(err) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGallery(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans) + err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, modelFile, loader.ModelPath, o.BackendsPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries) if err != nil { - return nil, err + log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile) + //return nil, err } } } diff --git a/core/cli/models.go b/core/cli/models.go index ae0bedca0..94a838961 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -16,8 +16,10 @@ import ( ) type ModelsCMDFlags struct { - Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` - ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` + Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` + BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` + ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` + BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"` } type ModelsList struct { @@ -25,8 +27,9 @@ type ModelsList struct { } type ModelsInstall struct { - DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"` - ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"` + DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"` + AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES" help:"If true, automatically loads backend galleries" group:"backends" default:"true"` + ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"` ModelsCMDFlags `embed:""` } @@ -62,6 +65,11 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { log.Error().Err(err).Msg("unable to load galleries") } + var backendGalleries []config.Gallery + if err := json.Unmarshal([]byte(mi.BackendGalleries), &backendGalleries); err != nil { + log.Error().Err(err).Msg("unable to load backend galleries") + } + for _, modelName := range mi.ModelArgs { progressBar := progressbar.NewOptions( @@ -100,7 +108,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") } - err = startup.InstallModels(galleries, mi.ModelsPath, !mi.DisablePredownloadScan, progressCallback, modelName) + err = startup.InstallModels(galleries, backendGalleries, mi.ModelsPath, mi.BackendsPath, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName) if err != nil { return err } diff --git a/core/cli/run.go b/core/cli/run.go index 57b3c4a9c..481d89448 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -30,13 +30,14 @@ type RunCMD struct { LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"` LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"` // The alias on this option is there to preserve functionality with the old `--config-file` parameter - ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"` - BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` - Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` - AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models"` - PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"` - Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"` - PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"` + ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"` + BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` + Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` + AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"` + AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"` + PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"` + Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"` + PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"` F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"` Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"` @@ -192,6 +193,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { opts = append(opts, config.EnableGalleriesAutoload) } + if r.AutoloadBackendGalleries { + opts = append(opts, config.EnableBackendGalleriesAutoload) + } + if r.PreloadBackendOnly { _, err := application.New(opts...) return err diff --git a/core/config/application_config.go b/core/config/application_config.go index 45e2f25c7..662bddc6a 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -55,7 +55,7 @@ type ApplicationConfig struct { ExternalGRPCBackends map[string]string - AutoloadGalleries bool + AutoloadGalleries, AutoloadBackendGalleries bool SingleBackend bool ParallelBackendRequests bool @@ -192,6 +192,10 @@ var EnableGalleriesAutoload = func(o *ApplicationConfig) { o.AutoloadGalleries = true } +var EnableBackendGalleriesAutoload = func(o *ApplicationConfig) { + o.AutoloadBackendGalleries = true +} + func WithExternalBackend(name string, uri string) AppOption { return func(o *ApplicationConfig) { if o.ExternalGRPCBackends == nil { diff --git a/core/gallery/backends.go b/core/gallery/backends.go index 1b703e700..7515514f9 100644 --- a/core/gallery/backends.go +++ b/core/gallery/backends.go @@ -71,7 +71,22 @@ func findBestBackendFromMeta(backend *GalleryBackend, systemState *system.System } // Installs a model from the gallery -func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, basePath string, downloadStatus func(string, string, string, float64)) error { +func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, basePath string, downloadStatus func(string, string, string, float64), force bool) error { + if !force { + // check if we already have the backend installed + backends, err := ListSystemBackends(basePath) + if err != nil { + return err + } + if _, ok := backends[name]; ok { + return nil + } + } + + if name == "" { + return fmt.Errorf("backend name is empty") + } + log.Debug().Interface("galleries", galleries).Str("name", name).Msg("Installing backend from gallery") backends, err := AvailableBackends(galleries, basePath) diff --git a/core/gallery/backends_test.go b/core/gallery/backends_test.go index 864ed3b5f..be79b701f 100644 --- a/core/gallery/backends_test.go +++ b/core/gallery/backends_test.go @@ -42,13 +42,13 @@ var _ = Describe("Gallery Backends", func() { Describe("InstallBackendFromGallery", func() { It("should return error when backend is not found", func() { - err := InstallBackendFromGallery(galleries, nil, "non-existent", tempDir, nil) + err := InstallBackendFromGallery(galleries, nil, "non-existent", tempDir, nil, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\"")) }) It("should install backend from gallery", func() { - err := InstallBackendFromGallery(galleries, nil, "test-backend", tempDir, nil) + err := InstallBackendFromGallery(galleries, nil, "test-backend", tempDir, nil, true) Expect(err).ToNot(HaveOccurred()) Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile()) }) @@ -181,7 +181,7 @@ var _ = Describe("Gallery Backends", func() { // Test with NVIDIA system state nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia"} - err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil) + err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true) Expect(err).NotTo(HaveOccurred()) metaBackendPath := filepath.Join(tempDir, "meta-backend") diff --git a/core/gallery/models.go b/core/gallery/models.go index 428f51153..a1c8a4b75 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -10,6 +10,7 @@ import ( "dario.cat/mergo" "github.com/mudler/LocalAI/core/config" lconfig "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/system" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/utils" @@ -69,7 +70,9 @@ type PromptTemplate struct { } // Installs a model from the gallery -func InstallModelFromGallery(galleries []config.Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan bool) error { +func InstallModelFromGallery( + modelGalleries, backendGalleries []config.Gallery, + name string, basePath, backendBasePath string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error { applyModel := func(model *GalleryModel) error { name = strings.ReplaceAll(name, string(os.PathSeparator), "__") @@ -119,14 +122,26 @@ func InstallModelFromGallery(galleries []config.Gallery, name string, basePath s return err } - if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus, enforceScan); err != nil { + installedModel, err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus, enforceScan) + if err != nil { return err } + if automaticallyInstallBackend && installedModel.Backend != "" { + systemState, err := system.GetSystemState() + if err != nil { + return err + } + + if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, backendBasePath, downloadStatus, false); err != nil { + return err + } + } + return nil } - models, err := AvailableGalleryModels(galleries, basePath) + models, err := AvailableGalleryModels(modelGalleries, basePath) if err != nil { return err } @@ -139,11 +154,11 @@ func InstallModelFromGallery(galleries []config.Gallery, name string, basePath s return applyModel(model) } -func InstallModel(basePath, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) error { +func InstallModel(basePath, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.BackendConfig, error) { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0750) if err != nil { - return fmt.Errorf("failed to create base path: %v", err) + return nil, fmt.Errorf("failed to create base path: %v", err) } if len(configOverrides) > 0 { @@ -155,7 +170,7 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) if err := utils.VerifyPath(file.Filename, basePath); err != nil { - return err + return nil, err } // Create file path @@ -165,19 +180,19 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI)) if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) { log.Error().Str("model", config.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!") - return err + return nil, err } } uri := downloader.URI(file.URI) if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { - return err + return nil, err } } // Write prompt template contents to separate files for _, template := range config.PromptTemplates { if err := utils.VerifyPath(template.Name+".tmpl", basePath); err != nil { - return err + return nil, err } // Create file path filePath := filepath.Join(basePath, template.Name+".tmpl") @@ -185,12 +200,12 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver // Create parent directory err := os.MkdirAll(filepath.Dir(filePath), 0750) if err != nil { - return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err) + return nil, fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err) } // Create and write file content err = os.WriteFile(filePath, []byte(template.Content), 0600) if err != nil { - return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err) + return nil, fmt.Errorf("failed to write prompt template %q: %v", template.Name, err) } log.Debug().Msgf("Prompt template %q written", template.Name) @@ -202,9 +217,11 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver } if err := utils.VerifyPath(name+".yaml", basePath); err != nil { - return err + return nil, err } + backendConfig := lconfig.BackendConfig{} + // write config file if len(configOverrides) != 0 || len(config.ConfigFile) != 0 { configFilePath := filepath.Join(basePath, name+".yaml") @@ -213,33 +230,33 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver configMap := make(map[string]interface{}) err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap) if err != nil { - return fmt.Errorf("failed to unmarshal config YAML: %v", err) + return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err) } configMap["name"] = name if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil { - return err + return nil, err } // Write updated config file updatedConfigYAML, err := yaml.Marshal(configMap) if err != nil { - return fmt.Errorf("failed to marshal updated config YAML: %v", err) + return nil, fmt.Errorf("failed to marshal updated config YAML: %v", err) } - backendConfig := lconfig.BackendConfig{} err = yaml.Unmarshal(updatedConfigYAML, &backendConfig) if err != nil { - return fmt.Errorf("failed to unmarshal updated config YAML: %v", err) + return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err) } + if !backendConfig.Validate() { - return fmt.Errorf("failed to validate updated config YAML") + return nil, fmt.Errorf("failed to validate updated config YAML") } err = os.WriteFile(configFilePath, updatedConfigYAML, 0600) if err != nil { - return fmt.Errorf("failed to write updated config file: %v", err) + return nil, fmt.Errorf("failed to write updated config file: %v", err) } log.Debug().Msgf("Written config file %s", configFilePath) @@ -249,14 +266,12 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver modelFile := filepath.Join(basePath, galleryFileName(name)) data, err := yaml.Marshal(config) if err != nil { - return err + return nil, err } log.Debug().Msgf("Written gallery file %s", modelFile) - return os.WriteFile(modelFile, data, 0600) - - //return nil + return &backendConfig, os.WriteFile(modelFile, data, 0600) } func galleryFileName(name string) string { diff --git a/core/gallery/models_test.go b/core/gallery/models_test.go index 26151aa83..5ffd675d1 100644 --- a/core/gallery/models_test.go +++ b/core/gallery/models_test.go @@ -29,7 +29,7 @@ var _ = Describe("Model test", func() { defer os.RemoveAll(tempdir) c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + _, err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { @@ -79,7 +79,7 @@ var _ = Describe("Model test", func() { Expect(models[0].URL).To(Equal(bertEmbeddingsURL)) Expect(models[0].Installed).To(BeFalse()) - err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true) + err = InstallModelFromGallery(galleries, []config.Gallery{}, "test@bert", tempdir, "", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true) Expect(err).ToNot(HaveOccurred()) dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) @@ -116,7 +116,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + _, err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -132,7 +132,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) + _, err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -158,7 +158,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) + _, err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).To(HaveOccurred()) }) }) diff --git a/core/services/backends.go b/core/services/backends.go index ae9ecbda8..855d1c0ad 100644 --- a/core/services/backends.go +++ b/core/services/backends.go @@ -24,7 +24,7 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend], s g.modelLoader.DeleteExternalBackend(op.GalleryElementName) } else { log.Warn().Msgf("installing backend %s", op.GalleryElementName) - err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, op.GalleryElementName, g.appConfig.BackendsPath, progressCallback) + err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, op.GalleryElementName, g.appConfig.BackendsPath, progressCallback, true) if err == nil { err = gallery.RegisterBackends(g.appConfig.BackendsPath, g.modelLoader) } diff --git a/core/services/models.go b/core/services/models.go index b0b6ede3d..beb981c8e 100644 --- a/core/services/models.go +++ b/core/services/models.go @@ -7,6 +7,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/system" "github.com/mudler/LocalAI/pkg/utils" "gopkg.in/yaml.v2" ) @@ -22,7 +23,7 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *c utils.DisplayDownloadFunction(fileName, current, total, percentage) } - err := processModelOperation(op, g.appConfig.ModelPath, g.appConfig.EnforcePredownloadScans, progressCallback) + err := processModelOperation(op, g.appConfig.ModelPath, g.appConfig.BackendsPath, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback) if err != nil { return err } @@ -49,7 +50,7 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *c return nil } -func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan bool) error { +func installModelFromRemoteConfig(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery, backendBasePath string) error { config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, modelPath) if err != nil { return err @@ -57,7 +58,23 @@ func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus fun config.Files = append(config.Files, req.AdditionalFiles...) - return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus, enforceScan) + installedModel, err := gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus, enforceScan) + if err != nil { + return err + } + + if automaticallyInstallBackend && installedModel.Backend != "" { + systemState, err := system.GetSystemState() + if err != nil { + return err + } + + if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, backendBasePath, downloadStatus, false); err != nil { + return err + } + } + + return nil } type galleryModel struct { @@ -65,22 +82,22 @@ type galleryModel struct { ID string `json:"id"` } -func processRequests(modelPath string, enforceScan bool, galleries []config.Gallery, requests []galleryModel) error { +func processRequests(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error { var err error for _, r := range requests { utils.ResetDownloadTimers() if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan) + err = installModelFromRemoteConfig(modelPath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries, backendBasePath) } else { err = gallery.InstallModelFromGallery( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan) + galleries, backendGalleries, r.ID, modelPath, backendBasePath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend) } } return err } -func ApplyGalleryFromFile(modelPath, s string, enforceScan bool, galleries []config.Gallery) error { +func ApplyGalleryFromFile(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error { dat, err := os.ReadFile(s) if err != nil { return err @@ -91,24 +108,26 @@ func ApplyGalleryFromFile(modelPath, s string, enforceScan bool, galleries []con return err } - return processRequests(modelPath, enforceScan, galleries, requests) + return processRequests(modelPath, backendBasePath, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests) } -func ApplyGalleryFromString(modelPath, s string, enforceScan bool, galleries []config.Gallery) error { +func ApplyGalleryFromString(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error { var requests []galleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { return err } - return processRequests(modelPath, enforceScan, galleries, requests) + return processRequests(modelPath, backendBasePath, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests) } // processModelOperation handles the installation or deletion of a model func processModelOperation( op *GalleryOp[gallery.GalleryModel], modelPath string, + backendBasePath string, enforcePredownloadScans bool, + automaticallyInstallBackend bool, progressCallback func(string, string, string, float64), ) error { // delete a model @@ -140,7 +159,7 @@ func processModelOperation( // if the request contains a gallery name, we apply the gallery from the gallery list if op.GalleryElementName != "" { - return gallery.InstallModelFromGallery(op.Galleries, op.GalleryElementName, modelPath, op.Req, progressCallback, enforcePredownloadScans) + return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, op.GalleryElementName, modelPath, backendBasePath, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend) // } else if op.ConfigURL != "" { // err := startup.InstallModels(op.Galleries, modelPath, enforcePredownloadScans, progressCallback, op.ConfigURL) // if err != nil { @@ -148,6 +167,6 @@ func processModelOperation( // } // return cl.Preload(modelPath) } else { - return prepareModel(modelPath, op.Req, progressCallback, enforcePredownloadScans) + return installModelFromRemoteConfig(modelPath, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries, backendBasePath) } } diff --git a/core/services/operation.go b/core/services/operation.go index e8d29f5de..962921807 100644 --- a/core/services/operation.go +++ b/core/services/operation.go @@ -10,8 +10,9 @@ type GalleryOp[T any] struct { GalleryElementName string Delete bool - Req T - Galleries []config.Gallery + Req T + Galleries []config.Gallery + BackendGalleries []config.Gallery } type GalleryOpStatus struct { diff --git a/pkg/startup/backend_preload.go b/pkg/startup/backend_preload.go index cbc37ca05..b68a811dd 100644 --- a/pkg/startup/backend_preload.go +++ b/pkg/startup/backend_preload.go @@ -27,7 +27,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow errs = errors.Join(err, fmt.Errorf("error installing backend %s", backend)) } default: - err := gallery.InstallBackendFromGallery(galleries, systemState, backend, backendPath, downloadStatus) + err := gallery.InstallBackendFromGallery(galleries, systemState, backend, backendPath, downloadStatus, true) if err != nil { errs = errors.Join(err, fmt.Errorf("error installing backend %s", backend)) } diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index 93da36284..ffad50d29 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -17,7 +17,7 @@ import ( // InstallModels will preload models from the given list of URLs and galleries // It will download the model if it is not already present in the model path // It will also try to resolve if the model is an embedded model YAML configuration -func InstallModels(galleries []config.Gallery, modelPath string, enforceScan bool, downloadStatus func(string, string, string, float64), models ...string) error { +func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, backendBasePath string, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error { // create an error that groups all errors var err error @@ -99,7 +99,7 @@ func InstallModels(galleries []config.Gallery, modelPath string, enforceScan boo } } else { // Check if it's a model gallery, or print a warning - e, found := installModel(galleries, url, modelPath, downloadStatus, enforceScan) + e, found := installModel(galleries, backendGalleries, url, modelPath, backendBasePath, downloadStatus, enforceScan, autoloadBackendGalleries) if e != nil && found { log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url) err = errors.Join(err, e) @@ -113,7 +113,7 @@ func InstallModels(galleries []config.Gallery, modelPath string, enforceScan boo return err } -func installModel(galleries []config.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64), enforceScan bool) (error, bool) { +func installModel(galleries, backendGalleries []config.Gallery, modelName, modelPath, backendBasePath string, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) { models, err := gallery.AvailableGalleryModels(galleries, modelPath) if err != nil { return err, false @@ -129,7 +129,7 @@ func installModel(galleries []config.Gallery, modelName, modelPath string, downl } log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") - err = gallery.InstallModelFromGallery(galleries, modelName, modelPath, gallery.GalleryModel{}, downloadStatus, enforceScan) + err = gallery.InstallModelFromGallery(galleries, backendGalleries, modelName, modelPath, backendBasePath, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries) if err != nil { return err, true } diff --git a/pkg/startup/model_preload_test.go b/pkg/startup/model_preload_test.go index 51e6d7026..d8e257d7a 100644 --- a/pkg/startup/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -21,7 +21,7 @@ var _ = Describe("Preload test", func() { url := "https://raw.githubusercontent.com/mudler/LocalAI-examples/main/configurations/phi-2.yaml" fileName := fmt.Sprintf("%s.yaml", "phi-2") - InstallModels([]config.Gallery{}, tmpdir, true, nil, url) + InstallModels([]config.Gallery{}, []config.Gallery{}, tmpdir, "", true, true, nil, url) resultFile := filepath.Join(tmpdir, fileName) @@ -36,7 +36,7 @@ var _ = Describe("Preload test", func() { url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf" fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K") - err = InstallModels([]config.Gallery{}, tmpdir, false, nil, url) + err = InstallModels([]config.Gallery{}, []config.Gallery{}, tmpdir, "", false, true, nil, url) Expect(err).ToNot(HaveOccurred()) resultFile := filepath.Join(tmpdir, fileName)