mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-13 05:29:11 -06:00
feat(ui): allow to cancel ops (#7264)
* feat(ui): allow to cancel ops Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Improve progress text Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Cancel queued ops, don't show up message cancellation always Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix: fixup displaying of total progress over multiple files Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
b1d1f2a37d
commit
735ca757fa
@@ -62,12 +62,12 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallModels(application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
log.Error().Err(err).Msg("error installing models")
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
|
||||
log.Error().Err(err).Msg("error installing external backend")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
if !slices.Contains(modelNames, c.Name) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
|
||||
//return nil, err
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
@@ -102,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState, true)
|
||||
err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState, true)
|
||||
err = startup.InstallModels(galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
|
||||
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -42,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
|
||||
log.Error().Err(err).Msg("failed loading galleries")
|
||||
return "", err
|
||||
}
|
||||
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
|
||||
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
|
||||
return "", err
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -69,7 +70,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
|
||||
}
|
||||
|
||||
// InstallBackendFromGallery installs a backend from the gallery.
|
||||
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
|
||||
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
|
||||
if !force {
|
||||
// check if we already have the backend installed
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
@@ -109,7 +110,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
|
||||
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
|
||||
|
||||
// Then, let's install the best backend
|
||||
if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
|
||||
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -134,10 +135,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
|
||||
return nil
|
||||
}
|
||||
|
||||
return InstallBackend(systemState, modelLoader, backend, downloadStatus)
|
||||
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
|
||||
}
|
||||
|
||||
func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
|
||||
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
|
||||
if err != nil {
|
||||
@@ -164,11 +165,17 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
|
||||
}
|
||||
} else {
|
||||
uri := downloader.URI(config.URI)
|
||||
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||
success := false
|
||||
// Try to download from mirrors
|
||||
for _, mirror := range config.Mirrors {
|
||||
if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
// Check for cancellation before trying next mirror
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
success = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -55,7 +56,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
|
||||
)
|
||||
must(err)
|
||||
sysDefault.GPUVendor = "" // force default selection
|
||||
backs, err := ListSystemBackends(sysDefault)
|
||||
backs, err := ListSystemBackends(sysDefault)
|
||||
must(err)
|
||||
aliasBack, ok := backs.Get("llama-cpp")
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -77,7 +78,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
|
||||
must(err)
|
||||
sysNvidia.GPUVendor = "nvidia"
|
||||
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
|
||||
backs, err = ListSystemBackends(sysNvidia)
|
||||
backs, err = ListSystemBackends(sysNvidia)
|
||||
must(err)
|
||||
aliasBack, ok = backs.Get("llama-cpp")
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -116,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {
|
||||
|
||||
Describe("InstallBackendFromGallery", func() {
|
||||
It("should return error when backend is not found", func() {
|
||||
err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
|
||||
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
|
||||
})
|
||||
|
||||
It("should install backend from gallery", func() {
|
||||
err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
|
||||
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
|
||||
})
|
||||
@@ -298,7 +299,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
@@ -378,7 +379,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
@@ -462,7 +463,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
@@ -561,7 +562,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
system.WithBackendPath(newPath),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, ml, &backend, nil)
|
||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
||||
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
|
||||
Expect(newPath).To(BeADirectory())
|
||||
})
|
||||
@@ -593,7 +594,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, ml, &backend, nil)
|
||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
|
||||
@@ -626,7 +627,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
|
||||
|
||||
err = InstallBackend(systemState, ml, &backend, nil)
|
||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
})
|
||||
@@ -647,7 +648,7 @@ var _ = Describe("Gallery Backends", func() {
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, ml, &backend, nil)
|
||||
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -28,6 +29,19 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
|
||||
var config T
|
||||
uri := downloader.URI(url)
|
||||
err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
|
||||
return config, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func ReadConfigFile[T any](filePath string) (*T, error) {
|
||||
// Read the YAML file
|
||||
yamlFile, err := os.ReadFile(filePath)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -72,6 +73,7 @@ type PromptTemplate struct {
|
||||
|
||||
// Installs a model from the gallery
|
||||
func InstallModelFromGallery(
|
||||
ctx context.Context,
|
||||
modelGalleries, backendGalleries []config.Gallery,
|
||||
systemState *system.SystemState,
|
||||
modelLoader *model.ModelLoader,
|
||||
@@ -84,7 +86,7 @@ func InstallModelFromGallery(
|
||||
|
||||
if len(model.URL) > 0 {
|
||||
var err error
|
||||
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
|
||||
config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,7 +127,7 @@ func InstallModelFromGallery(
|
||||
return err
|
||||
}
|
||||
|
||||
installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
|
||||
installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -133,7 +135,7 @@ func InstallModelFromGallery(
|
||||
if automaticallyInstallBackend && installedModel.Backend != "" {
|
||||
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
|
||||
|
||||
if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
|
||||
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -154,7 +156,7 @@ func InstallModelFromGallery(
|
||||
return applyModel(model)
|
||||
}
|
||||
|
||||
func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
||||
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
||||
basePath := systemState.Model.ModelsPath
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(basePath, 0750)
|
||||
@@ -168,6 +170,13 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
|
||||
|
||||
// Download files and verify their SHA
|
||||
for i, file := range config.Files {
|
||||
// Check for cancellation before each file
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
|
||||
|
||||
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
|
||||
@@ -185,7 +194,7 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
|
||||
}
|
||||
}
|
||||
uri := downloader.URI(file.URI)
|
||||
if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
|
||||
if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gallery_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -34,7 +35,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
||||
@@ -88,7 +89,7 @@ var _ = Describe("Model test", func() {
|
||||
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
|
||||
Expect(models[0].Installed).To(BeFalse())
|
||||
|
||||
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
|
||||
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
||||
@@ -129,7 +130,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||
@@ -149,7 +150,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||
@@ -179,7 +180,7 @@ var _ = Describe("Model test", func() {
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -85,7 +85,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
|
||||
response := []gallery.GalleryModel{}
|
||||
uri := downloader.URI(url)
|
||||
// TODO: No tests currently seem to exercise file:// urls. Fix?
|
||||
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
|
||||
err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
|
||||
// Unmarshal YAML data into a struct
|
||||
return json.Unmarshal(i, &response)
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/url"
|
||||
@@ -35,23 +36,35 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
progress := 0
|
||||
isDeletion := false
|
||||
isQueued := false
|
||||
isCancelled := false
|
||||
isCancellable := false
|
||||
message := ""
|
||||
|
||||
if status != nil {
|
||||
// Skip completed operations
|
||||
if status.Processed {
|
||||
// Skip completed operations (unless cancelled and not yet cleaned up)
|
||||
if status.Processed && !status.Cancelled {
|
||||
continue
|
||||
}
|
||||
// Skip cancelled operations that are processed (they're done, no need to show)
|
||||
if status.Processed && status.Cancelled {
|
||||
continue
|
||||
}
|
||||
|
||||
progress = int(status.Progress)
|
||||
isDeletion = status.Deletion
|
||||
isCancelled = status.Cancelled
|
||||
isCancellable = status.Cancellable
|
||||
message = status.Message
|
||||
if isDeletion {
|
||||
taskType = "deletion"
|
||||
}
|
||||
if isCancelled {
|
||||
taskType = "cancelled"
|
||||
}
|
||||
} else {
|
||||
// Job is queued but hasn't started
|
||||
isQueued = true
|
||||
isCancellable = true
|
||||
message = "Operation queued"
|
||||
}
|
||||
|
||||
@@ -76,16 +89,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
|
||||
operations = append(operations, fiber.Map{
|
||||
"id": galleryID,
|
||||
"name": displayName,
|
||||
"fullName": galleryID,
|
||||
"jobID": jobID,
|
||||
"progress": progress,
|
||||
"taskType": taskType,
|
||||
"isDeletion": isDeletion,
|
||||
"isBackend": isBackend,
|
||||
"isQueued": isQueued,
|
||||
"message": message,
|
||||
"id": galleryID,
|
||||
"name": displayName,
|
||||
"fullName": galleryID,
|
||||
"jobID": jobID,
|
||||
"progress": progress,
|
||||
"taskType": taskType,
|
||||
"isDeletion": isDeletion,
|
||||
"isBackend": isBackend,
|
||||
"isQueued": isQueued,
|
||||
"isCancelled": isCancelled,
|
||||
"cancellable": isCancellable,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -108,6 +123,28 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
})
|
||||
|
||||
// Cancel operation endpoint
|
||||
app.Post("/api/operations/:jobID/cancel", func(c *fiber.Ctx) error {
|
||||
jobID := strings.Clone(c.Params("jobID"))
|
||||
log.Debug().Msgf("API request to cancel operation: %s", jobID)
|
||||
|
||||
err := galleryService.CancelOperation(jobID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to cancel operation: %s", jobID)
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Clean up opcache for cancelled operation
|
||||
opcache.DeleteUUID(jobID)
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"success": true,
|
||||
"message": "Operation cancelled",
|
||||
})
|
||||
})
|
||||
|
||||
// Model Gallery APIs
|
||||
app.Get("/api/models", func(c *fiber.Ctx) error {
|
||||
term := c.Query("term")
|
||||
@@ -248,12 +285,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
uid := id.String()
|
||||
opcache.Set(galleryID, uid)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
ID: uid,
|
||||
GalleryElementName: galleryID,
|
||||
Galleries: appConfig.Galleries,
|
||||
BackendGalleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
go func() {
|
||||
galleryService.ModelGalleryChannel <- op
|
||||
}()
|
||||
@@ -291,13 +333,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
opcache.Set(galleryID, uid)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
ID: uid,
|
||||
Delete: true,
|
||||
GalleryElementName: galleryName,
|
||||
Galleries: appConfig.Galleries,
|
||||
BackendGalleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
go func() {
|
||||
galleryService.ModelGalleryChannel <- op
|
||||
cl.RemoveModelConfig(galleryName)
|
||||
@@ -341,7 +388,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
_, err = gallery.InstallModel(appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
|
||||
_, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": err.Error(),
|
||||
@@ -526,11 +573,16 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
uid := id.String()
|
||||
opcache.Set(backendID, uid)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := services.GalleryOp[gallery.GalleryBackend, any]{
|
||||
ID: uid,
|
||||
GalleryElementName: backendID,
|
||||
Galleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
go func() {
|
||||
galleryService.BackendGalleryChannel <- op
|
||||
}()
|
||||
@@ -568,12 +620,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
opcache.Set(backendID, uid)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
op := services.GalleryOp[gallery.GalleryBackend, any]{
|
||||
ID: uid,
|
||||
Delete: true,
|
||||
GalleryElementName: backendName,
|
||||
Galleries: appConfig.BackendGalleries,
|
||||
Context: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
// Store cancellation function immediately so queued operations can be cancelled
|
||||
galleryService.StoreCancellation(uid, cancelFunc)
|
||||
go func() {
|
||||
galleryService.BackendGalleryChannel <- op
|
||||
}()
|
||||
|
||||
@@ -71,15 +71,34 @@
|
||||
Queued
|
||||
</span>
|
||||
</template>
|
||||
<template x-if="!operation.isQueued && operation.message">
|
||||
<template x-if="operation.isCancelled">
|
||||
<span class="text-xs text-red-400 flex items-center">
|
||||
<i class="fas fa-ban mr-1"></i>
|
||||
Cancelling...
|
||||
</span>
|
||||
</template>
|
||||
<template x-if="!operation.isQueued && !operation.isCancelled && operation.message">
|
||||
<span class="text-xs text-[#94A3B8] truncate" x-text="operation.message"></span>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Progress percentage -->
|
||||
<div class="flex-shrink-0 text-right">
|
||||
<!-- Progress percentage and cancel button -->
|
||||
<div class="flex-shrink-0 text-right flex items-center space-x-2">
|
||||
<span class="text-[#E5E7EB] font-bold text-lg" x-text="operation.progress + '%'"></span>
|
||||
<template x-if="operation.cancellable && !operation.isCancelled">
|
||||
<button @click="cancelOperation(operation.jobID, operation.id)"
|
||||
class="text-red-400 hover:text-red-300 transition-colors p-1 rounded hover:bg-red-500/20"
|
||||
title="Cancel operation">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</template>
|
||||
<template x-if="operation.isCancelled">
|
||||
<span class="text-red-400 text-xs flex items-center">
|
||||
<i class="fas fa-ban mr-1"></i>
|
||||
Cancelled
|
||||
</span>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -88,8 +107,8 @@
|
||||
<div class="w-full bg-[#101827] rounded-full h-2 overflow-hidden border border-[#1E293B]">
|
||||
<div class="h-full rounded-full transition-all duration-300 ease-out"
|
||||
:class="{
|
||||
'bg-gradient-to-r from-[#38BDF8] to-[#8B5CF6]': !operation.isDeletion,
|
||||
'bg-gradient-to-r from-red-500 to-red-600': operation.isDeletion
|
||||
'bg-gradient-to-r from-[#38BDF8] to-[#8B5CF6]': !operation.isDeletion && !operation.isCancelled,
|
||||
'bg-gradient-to-r from-red-500 to-red-600': operation.isDeletion || operation.isCancelled
|
||||
}"
|
||||
:style="'width: ' + operation.progress + '%'">
|
||||
</div>
|
||||
@@ -141,6 +160,57 @@ function operationsStatus() {
|
||||
}
|
||||
},
|
||||
|
||||
async cancelOperation(jobID, operationID) {
|
||||
// Check if operation is already cancelled
|
||||
const operation = this.operations.find(op => op.jobID === jobID);
|
||||
if (operation && operation.isCancelled) {
|
||||
// Already cancelled, no need to do anything
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/operations/${jobID}/cancel`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.json();
|
||||
const errorMessage = error.error || 'Failed to cancel operation';
|
||||
|
||||
// Don't show alert for "already cancelled" - just update UI silently
|
||||
if (errorMessage.includes('already cancelled')) {
|
||||
if (operation) {
|
||||
operation.isCancelled = true;
|
||||
operation.cancellable = false;
|
||||
}
|
||||
this.fetchOperations();
|
||||
return;
|
||||
}
|
||||
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
// Update the operation status immediately
|
||||
if (operation) {
|
||||
operation.isCancelled = true;
|
||||
operation.cancellable = false;
|
||||
operation.message = 'Cancelling...';
|
||||
}
|
||||
|
||||
// Refresh operations to get updated status
|
||||
this.fetchOperations();
|
||||
} catch (error) {
|
||||
console.error('Error cancelling operation:', error);
|
||||
// Only show alert if it's not an "already cancelled" error
|
||||
if (!error.message.includes('already cancelled')) {
|
||||
alert('Failed to cancel operation: ' + error.message);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
destroy() {
|
||||
if (this.pollInterval) {
|
||||
clearInterval(this.pollInterval);
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
@@ -10,14 +14,43 @@ import (
|
||||
|
||||
func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, any], systemState *system.SystemState) error {
|
||||
utils.ResetDownloadTimers()
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0})
|
||||
|
||||
// Check if already cancelled
|
||||
if op.Context != nil {
|
||||
select {
|
||||
case <-op.Context.Done():
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{
|
||||
Cancelled: true,
|
||||
Processed: true,
|
||||
Message: "cancelled",
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
})
|
||||
return op.Context.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf("processing backend: %s", op.GalleryElementName), Progress: 0, Cancellable: true})
|
||||
|
||||
// displayDownload displays the download progress
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||
// Check for cancellation during progress updates
|
||||
if op.Context != nil {
|
||||
select {
|
||||
case <-op.Context.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf(processingMessage, fileName, total, current), FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current, Cancellable: true})
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||
}
|
||||
|
||||
ctx := op.Context
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
var err error
|
||||
if op.Delete {
|
||||
err = gallery.DeleteBackendFromSystem(g.appConfig.SystemState, op.GalleryElementName)
|
||||
@@ -25,9 +58,19 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, an
|
||||
} else {
|
||||
log.Warn().Msgf("installing backend %s", op.GalleryElementName)
|
||||
log.Debug().Msgf("backend galleries: %v", g.appConfig.BackendGalleries)
|
||||
err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, g.modelLoader, op.GalleryElementName, progressCallback, true)
|
||||
err = gallery.InstallBackendFromGallery(ctx, g.appConfig.BackendGalleries, systemState, g.modelLoader, op.GalleryElementName, progressCallback, true)
|
||||
}
|
||||
if err != nil {
|
||||
// Check if error is due to cancellation
|
||||
if op.Context != nil && errors.Is(err, op.Context.Err()) {
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{
|
||||
Cancelled: true,
|
||||
Processed: true,
|
||||
Message: "cancelled",
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
})
|
||||
return err
|
||||
}
|
||||
log.Error().Err(err).Msgf("error installing backend %s", op.GalleryElementName)
|
||||
if !op.Delete {
|
||||
// If we didn't install the backend, we need to make sure we don't have a leftover directory
|
||||
@@ -42,6 +85,7 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, an
|
||||
Processed: true,
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
Message: "completed",
|
||||
Progress: 100})
|
||||
Progress: 100,
|
||||
Cancellable: false})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,88 +1,166 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
type GalleryService struct {
|
||||
appConfig *config.ApplicationConfig
|
||||
sync.Mutex
|
||||
ModelGalleryChannel chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]
|
||||
BackendGalleryChannel chan GalleryOp[gallery.GalleryBackend, any]
|
||||
|
||||
modelLoader *model.ModelLoader
|
||||
statuses map[string]*GalleryOpStatus
|
||||
}
|
||||
|
||||
func NewGalleryService(appConfig *config.ApplicationConfig, ml *model.ModelLoader) *GalleryService {
|
||||
return &GalleryService{
|
||||
appConfig: appConfig,
|
||||
ModelGalleryChannel: make(chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]),
|
||||
BackendGalleryChannel: make(chan GalleryOp[gallery.GalleryBackend, any]),
|
||||
modelLoader: ml,
|
||||
statuses: make(map[string]*GalleryOpStatus),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GalleryService) UpdateStatus(s string, op *GalleryOpStatus) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
g.statuses[s] = op
|
||||
}
|
||||
|
||||
func (g *GalleryService) GetStatus(s string) *GalleryOpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
return g.statuses[s]
|
||||
}
|
||||
|
||||
func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
return g.statuses
|
||||
}
|
||||
|
||||
func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error {
|
||||
// updates the status with an error
|
||||
var updateError func(id string, e error)
|
||||
if !g.appConfig.OpaqueErrors {
|
||||
updateError = func(id string, e error) {
|
||||
g.UpdateStatus(id, &GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
||||
}
|
||||
} else {
|
||||
updateError = func(id string, _ error) {
|
||||
g.UpdateStatus(id, &GalleryOpStatus{Error: fmt.Errorf("an error occurred"), Processed: true})
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
case op := <-g.BackendGalleryChannel:
|
||||
err := g.backendHandler(&op, systemState)
|
||||
if err != nil {
|
||||
updateError(op.ID, err)
|
||||
}
|
||||
|
||||
case op := <-g.ModelGalleryChannel:
|
||||
err := g.modelHandler(&op, cl, systemState)
|
||||
if err != nil {
|
||||
updateError(op.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
type GalleryService struct {
|
||||
appConfig *config.ApplicationConfig
|
||||
sync.Mutex
|
||||
ModelGalleryChannel chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]
|
||||
BackendGalleryChannel chan GalleryOp[gallery.GalleryBackend, any]
|
||||
|
||||
modelLoader *model.ModelLoader
|
||||
statuses map[string]*GalleryOpStatus
|
||||
cancellations map[string]context.CancelFunc
|
||||
}
|
||||
|
||||
func NewGalleryService(appConfig *config.ApplicationConfig, ml *model.ModelLoader) *GalleryService {
|
||||
return &GalleryService{
|
||||
appConfig: appConfig,
|
||||
ModelGalleryChannel: make(chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]),
|
||||
BackendGalleryChannel: make(chan GalleryOp[gallery.GalleryBackend, any]),
|
||||
modelLoader: ml,
|
||||
statuses: make(map[string]*GalleryOpStatus),
|
||||
cancellations: make(map[string]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GalleryService) UpdateStatus(s string, op *GalleryOpStatus) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
g.statuses[s] = op
|
||||
}
|
||||
|
||||
func (g *GalleryService) GetStatus(s string) *GalleryOpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
return g.statuses[s]
|
||||
}
|
||||
|
||||
func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
return g.statuses
|
||||
}
|
||||
|
||||
// CancelOperation cancels an in-progress operation by its ID
|
||||
func (g *GalleryService) CancelOperation(id string) error {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
// Check if operation is already cancelled
|
||||
if status, ok := g.statuses[id]; ok && status.Cancelled {
|
||||
return fmt.Errorf("operation %q is already cancelled", id)
|
||||
}
|
||||
|
||||
cancelFunc, exists := g.cancellations[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("operation %q not found or already completed", id)
|
||||
}
|
||||
|
||||
// Cancel the operation
|
||||
cancelFunc()
|
||||
|
||||
// Update status to reflect cancellation
|
||||
if status, ok := g.statuses[id]; ok {
|
||||
status.Cancelled = true
|
||||
status.Processed = true
|
||||
status.Message = "cancelled"
|
||||
} else {
|
||||
// Create status for queued operations that haven't started yet
|
||||
g.statuses[id] = &GalleryOpStatus{
|
||||
Cancelled: true,
|
||||
Processed: true,
|
||||
Message: "cancelled",
|
||||
Cancellable: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up cancellation function
|
||||
delete(g.cancellations, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// storeCancellation stores a cancellation function for an operation
|
||||
func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelFunc) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
g.cancellations[id] = cancelFunc
|
||||
}
|
||||
|
||||
// StoreCancellation is a public method to store a cancellation function for an operation
|
||||
// This allows cancellation functions to be stored immediately when operations are created,
|
||||
// enabling cancellation of queued operations that haven't started processing yet.
|
||||
func (g *GalleryService) StoreCancellation(id string, cancelFunc context.CancelFunc) {
|
||||
g.storeCancellation(id, cancelFunc)
|
||||
}
|
||||
|
||||
// removeCancellation removes a cancellation function when operation completes
|
||||
func (g *GalleryService) removeCancellation(id string) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
delete(g.cancellations, id)
|
||||
}
|
||||
|
||||
func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error {
|
||||
// updates the status with an error
|
||||
var updateError func(id string, e error)
|
||||
if !g.appConfig.OpaqueErrors {
|
||||
updateError = func(id string, e error) {
|
||||
g.UpdateStatus(id, &GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
||||
}
|
||||
} else {
|
||||
updateError = func(id string, _ error) {
|
||||
g.UpdateStatus(id, &GalleryOpStatus{Error: fmt.Errorf("an error occurred"), Processed: true})
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
case op := <-g.BackendGalleryChannel:
|
||||
// Create context if not provided
|
||||
if op.Context == nil {
|
||||
op.Context, op.CancelFunc = context.WithCancel(c)
|
||||
g.storeCancellation(op.ID, op.CancelFunc)
|
||||
} else if op.CancelFunc != nil {
|
||||
g.storeCancellation(op.ID, op.CancelFunc)
|
||||
}
|
||||
err := g.backendHandler(&op, systemState)
|
||||
if err != nil {
|
||||
updateError(op.ID, err)
|
||||
}
|
||||
g.removeCancellation(op.ID)
|
||||
|
||||
case op := <-g.ModelGalleryChannel:
|
||||
// Create context if not provided
|
||||
if op.Context == nil {
|
||||
op.Context, op.CancelFunc = context.WithCancel(c)
|
||||
g.storeCancellation(op.ID, op.CancelFunc)
|
||||
} else if op.CancelFunc != nil {
|
||||
g.storeCancellation(op.ID, op.CancelFunc)
|
||||
}
|
||||
err := g.modelHandler(&op, cl, systemState)
|
||||
if err != nil {
|
||||
updateError(op.ID, err)
|
||||
}
|
||||
g.removeCancellation(op.ID)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -13,22 +16,74 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
processingMessage = "processing file: %s. Total: %s. Current: %s"
|
||||
)
|
||||
|
||||
func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel, gallery.ModelConfig], cl *config.ModelConfigLoader, systemState *system.SystemState) error {
|
||||
utils.ResetDownloadTimers()
|
||||
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0})
|
||||
// Check if already cancelled
|
||||
if op.Context != nil {
|
||||
select {
|
||||
case <-op.Context.Done():
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{
|
||||
Cancelled: true,
|
||||
Processed: true,
|
||||
Message: "cancelled",
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
})
|
||||
return op.Context.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf("processing model: %s", op.GalleryElementName), Progress: 0, Cancellable: true})
|
||||
|
||||
// displayDownload displays the download progress
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||
// Check for cancellation during progress updates
|
||||
if op.Context != nil {
|
||||
select {
|
||||
case <-op.Context.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf(processingMessage, fileName, total, current), FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current, Cancellable: true})
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||
}
|
||||
|
||||
err := processModelOperation(op, systemState, g.modelLoader, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback)
|
||||
if err != nil {
|
||||
// Check if error is due to cancellation
|
||||
if op.Context != nil && errors.Is(err, op.Context.Err()) {
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{
|
||||
Cancelled: true,
|
||||
Processed: true,
|
||||
Message: "cancelled",
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
})
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for cancellation before final steps
|
||||
if op.Context != nil {
|
||||
select {
|
||||
case <-op.Context.Done():
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{
|
||||
Cancelled: true,
|
||||
Processed: true,
|
||||
Message: "cancelled",
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
})
|
||||
return op.Context.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Reload models
|
||||
err = cl.LoadModelConfigsFromPath(systemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
@@ -46,26 +101,27 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel, galler
|
||||
Processed: true,
|
||||
GalleryElementName: op.GalleryElementName,
|
||||
Message: "completed",
|
||||
Progress: 100})
|
||||
Progress: 100,
|
||||
Cancellable: false})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func installModelFromRemoteConfig(systemState *system.SystemState, modelLoader *model.ModelLoader, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error {
|
||||
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, systemState.Model.ModelsPath)
|
||||
func installModelFromRemoteConfig(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error {
|
||||
config, err := gallery.GetGalleryConfigFromURLWithContext[gallery.ModelConfig](ctx, req.URL, systemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
|
||||
installedModel, err := gallery.InstallModel(systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
|
||||
installedModel, err := gallery.InstallModel(ctx, systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if automaticallyInstallBackend && installedModel.Backend != "" {
|
||||
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
|
||||
if err := gallery.InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -79,15 +135,16 @@ type galleryModel struct {
|
||||
}
|
||||
|
||||
func processRequests(systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error {
|
||||
ctx := context.Background()
|
||||
var err error
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = installModelFromRemoteConfig(systemState, modelLoader, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries)
|
||||
err = installModelFromRemoteConfig(ctx, systemState, modelLoader, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries)
|
||||
|
||||
} else {
|
||||
err = gallery.InstallModelFromGallery(
|
||||
galleries, backendGalleries, systemState, modelLoader, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
|
||||
ctx, galleries, backendGalleries, systemState, modelLoader, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
|
||||
}
|
||||
}
|
||||
return err
|
||||
@@ -126,25 +183,40 @@ func processModelOperation(
|
||||
automaticallyInstallBackend bool,
|
||||
progressCallback func(string, string, string, float64),
|
||||
) error {
|
||||
ctx := op.Context
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Check for cancellation before starting
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
switch {
|
||||
case op.Delete:
|
||||
return gallery.DeleteModelFromSystem(systemState, op.GalleryElementName)
|
||||
case op.GalleryElement != nil:
|
||||
installedModel, err := gallery.InstallModel(
|
||||
systemState, op.GalleryElement.Name,
|
||||
ctx, systemState, op.GalleryElement.Name,
|
||||
op.GalleryElement,
|
||||
op.Req.Overrides,
|
||||
progressCallback, enforcePredownloadScans)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if automaticallyInstallBackend && installedModel.Backend != "" {
|
||||
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
|
||||
if err := gallery.InstallBackendFromGallery(op.BackendGalleries, systemState, modelLoader, installedModel.Backend, progressCallback, false); err != nil {
|
||||
if err := gallery.InstallBackendFromGallery(ctx, op.BackendGalleries, systemState, modelLoader, installedModel.Backend, progressCallback, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
case op.GalleryElementName != "":
|
||||
return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, systemState, modelLoader, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
|
||||
return gallery.InstallModelFromGallery(ctx, op.Galleries, op.BackendGalleries, systemState, modelLoader, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
|
||||
default:
|
||||
return installModelFromRemoteConfig(systemState, modelLoader, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries)
|
||||
return installModelFromRemoteConfig(ctx, systemState, modelLoader, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/xsync"
|
||||
)
|
||||
@@ -17,6 +19,10 @@ type GalleryOp[T any, E any] struct {
|
||||
|
||||
Galleries []config.Gallery
|
||||
BackendGalleries []config.Gallery
|
||||
|
||||
// Context for cancellation support
|
||||
Context context.Context
|
||||
CancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
type GalleryOpStatus struct {
|
||||
@@ -29,6 +35,8 @@ type GalleryOpStatus struct {
|
||||
TotalFileSize string `json:"file_size"`
|
||||
DownloadedFileSize string `json:"downloaded_size"`
|
||||
GalleryElementName string `json:"gallery_element_name"`
|
||||
Cancelled bool `json:"cancelled"` // Cancelled is true if the operation was cancelled
|
||||
Cancellable bool `json:"cancellable"` // Cancellable is true if the operation can be cancelled
|
||||
}
|
||||
|
||||
type OpCache struct {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package startup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func InstallExternalBackends(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
|
||||
func InstallExternalBackends(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
|
||||
uri := downloader.URI(backend)
|
||||
switch {
|
||||
case uri.LooksLikeDir():
|
||||
@@ -21,7 +22,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
|
||||
name = filepath.Base(backend)
|
||||
}
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from path")
|
||||
if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
|
||||
if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
@@ -35,7 +36,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
|
||||
return fmt.Errorf("specifying a name is required for OCI images")
|
||||
}
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
|
||||
if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
|
||||
if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
@@ -53,7 +54,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
|
||||
name = strings.TrimSuffix(name, filepath.Ext(name))
|
||||
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
|
||||
if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
|
||||
if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
@@ -66,7 +67,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
|
||||
if name != "" || alias != "" {
|
||||
return fmt.Errorf("specifying a name or alias is not supported for this backend")
|
||||
}
|
||||
err := gallery.InstallBackendFromGallery(galleries, systemState, modelLoader, backend, downloadStatus, true)
|
||||
err := gallery.InstallBackendFromGallery(ctx, galleries, systemState, modelLoader, backend, downloadStatus, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error installing backend %s: %w", backend, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package startup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -30,7 +31,7 @@ const (
|
||||
// InstallModels will preload models from the given list of URLs and galleries
|
||||
// It will download the model if it is not already present in the model path
|
||||
// It will also try to resolve if the model is an embedded model YAML configuration
|
||||
func InstallModels(galleryService *services.GalleryService, galleries, backendGalleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
|
||||
func InstallModels(ctx context.Context, galleryService *services.GalleryService, galleries, backendGalleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
|
||||
// create an error that groups all errors
|
||||
var err error
|
||||
|
||||
@@ -53,7 +54,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, modelLoader, model.Backend, downloadStatus, false); err != nil {
|
||||
if err := gallery.InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, model.Backend, downloadStatus, false); err != nil {
|
||||
log.Error().Err(err).Str("backend", model.Backend).Msg("error installing backend")
|
||||
return err
|
||||
}
|
||||
@@ -153,7 +154,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
|
||||
}
|
||||
} else {
|
||||
// Check if it's a model gallery, or print a warning
|
||||
e, found := installModel(galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries)
|
||||
e, found := installModel(ctx, galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries)
|
||||
if e != nil && found {
|
||||
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url)
|
||||
err = errors.Join(err, e)
|
||||
@@ -210,7 +211,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
|
||||
return err
|
||||
}
|
||||
|
||||
func installModel(galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
|
||||
func installModel(ctx context.Context, galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
|
||||
models, err := gallery.AvailableGalleryModels(galleries, systemState)
|
||||
if err != nil {
|
||||
return err, false
|
||||
@@ -226,7 +227,7 @@ func installModel(galleries, backendGalleries []config.Gallery, modelName string
|
||||
}
|
||||
|
||||
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
|
||||
err = gallery.InstallModelFromGallery(galleries, backendGalleries, systemState, modelLoader, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
|
||||
err = gallery.InstallModelFromGallery(ctx, galleries, backendGalleries, systemState, modelLoader, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
|
||||
if err != nil {
|
||||
return err, true
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package startup_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -33,7 +34,7 @@ var _ = Describe("Preload test", func() {
|
||||
url := "https://raw.githubusercontent.com/mudler/LocalAI-examples/main/configurations/phi-2.yaml"
|
||||
fileName := fmt.Sprintf("%s.yaml", "phi-2")
|
||||
|
||||
InstallModels(nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
|
||||
InstallModels(context.TODO(), nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
@@ -46,7 +47,7 @@ var _ = Describe("Preload test", func() {
|
||||
url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
|
||||
fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K")
|
||||
|
||||
err := InstallModels(nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
|
||||
err := InstallModels(context.TODO(), nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package downloader
|
||||
|
||||
import "hash"
|
||||
import (
|
||||
"context"
|
||||
"hash"
|
||||
)
|
||||
|
||||
type progressWriter struct {
|
||||
fileName string
|
||||
@@ -10,23 +13,45 @@ type progressWriter struct {
|
||||
written int64
|
||||
downloadStatus func(string, string, string, float64)
|
||||
hash hash.Hash
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
||||
// Check for cancellation before writing
|
||||
if pw.ctx != nil {
|
||||
select {
|
||||
case <-pw.ctx.Done():
|
||||
return 0, pw.ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
n, err = pw.hash.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
pw.written += int64(n)
|
||||
|
||||
// Check for cancellation after writing chunk
|
||||
if pw.ctx != nil {
|
||||
select {
|
||||
case <-pw.ctx.Done():
|
||||
return n, pw.ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if pw.total > 0 {
|
||||
percentage := float64(pw.written) / float64(pw.total) * 100
|
||||
if pw.totalFiles > 1 {
|
||||
// This is a multi-file download
|
||||
// so we need to adjust the percentage
|
||||
// to reflect the progress of the whole download
|
||||
// This is the file pw.fileNo of pw.totalFiles files. We assume that
|
||||
// This is the file pw.fileNo (0-indexed) of pw.totalFiles files. We assume that
|
||||
// the files before successfully downloaded.
|
||||
percentage = percentage / float64(pw.totalFiles)
|
||||
if pw.fileNo > 1 {
|
||||
percentage += float64(pw.fileNo-1) * 100 / float64(pw.totalFiles)
|
||||
if pw.fileNo > 0 {
|
||||
percentage += float64(pw.fileNo) * 100 / float64(pw.totalFiles)
|
||||
}
|
||||
}
|
||||
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -49,10 +50,10 @@ func loadConfig() string {
|
||||
}
|
||||
|
||||
func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
|
||||
return uri.DownloadWithAuthorizationAndCallback(basePath, "", f)
|
||||
return uri.DownloadWithAuthorizationAndCallback(context.Background(), basePath, "", f)
|
||||
}
|
||||
|
||||
func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error {
|
||||
func (uri URI) DownloadWithAuthorizationAndCallback(ctx context.Context, basePath string, authorization string, f func(url string, i []byte) error) error {
|
||||
url := uri.ResolveURL()
|
||||
|
||||
if strings.HasPrefix(url, LocalPrefix) {
|
||||
@@ -83,8 +84,7 @@ func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorizati
|
||||
}
|
||||
|
||||
// Send a GET request to the URL
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -264,6 +264,10 @@ func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
|
||||
}
|
||||
|
||||
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
|
||||
return uri.DownloadFileWithContext(context.Background(), filePath, sha, fileN, total, downloadStatus)
|
||||
}
|
||||
|
||||
func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
|
||||
url := uri.ResolveURL()
|
||||
if uri.LooksLikeOCI() {
|
||||
|
||||
@@ -285,7 +289,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
||||
}
|
||||
|
||||
if url, ok := strings.CutPrefix(url, OllamaPrefix); ok {
|
||||
return oci.OllamaFetchModel(url, filePath, progressStatus)
|
||||
return oci.OllamaFetchModel(ctx, url, filePath, progressStatus)
|
||||
}
|
||||
|
||||
if url, ok := strings.CutPrefix(url, OCIFilePrefix); ok {
|
||||
@@ -295,7 +299,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
||||
return fmt.Errorf("failed to open tarball: %s", err.Error())
|
||||
}
|
||||
|
||||
return oci.ExtractOCIImage(img, url, filePath, downloadStatus)
|
||||
return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus)
|
||||
}
|
||||
|
||||
url = strings.TrimPrefix(url, OCIPrefix)
|
||||
@@ -304,7 +308,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
||||
return fmt.Errorf("failed to get image %q: %v", url, err)
|
||||
}
|
||||
|
||||
return oci.ExtractOCIImage(img, url, filePath, downloadStatus)
|
||||
return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus)
|
||||
}
|
||||
|
||||
// We need to check if url looks like an URL or bail out
|
||||
@@ -312,6 +316,13 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
||||
return fmt.Errorf("url %q does not look like an HTTP URL", url)
|
||||
}
|
||||
|
||||
// Check for cancellation before starting
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Check if the file already exists
|
||||
_, err := os.Stat(filePath)
|
||||
if err == nil {
|
||||
@@ -346,7 +357,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
||||
|
||||
log.Info().Msgf("Downloading %q", url)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request for %q: %v", filePath, err)
|
||||
}
|
||||
@@ -375,6 +386,12 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
||||
// Start the request
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
// Check if error is due to context cancellation
|
||||
if errors.Is(err, context.Canceled) {
|
||||
// Clean up partial file on cancellation
|
||||
removePartialFile(tmpFilePath)
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to download file %q: %v", filePath, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -406,12 +423,27 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
||||
fileNo: fileN,
|
||||
totalFiles: total,
|
||||
downloadStatus: downloadStatus,
|
||||
ctx: ctx,
|
||||
}
|
||||
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
|
||||
if err != nil {
|
||||
// Check if error is due to context cancellation
|
||||
if errors.Is(err, context.Canceled) {
|
||||
// Clean up partial file on cancellation
|
||||
removePartialFile(tmpFilePath)
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to write file %q: %v", filePath, err)
|
||||
}
|
||||
|
||||
// Check for cancellation before finalizing
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
removePartialFile(tmpFilePath)
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
err = os.Rename(tmpFilePath, filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)
|
||||
|
||||
@@ -6,13 +6,14 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/xio"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
|
||||
oras "oras.land/oras-go/v2"
|
||||
"oras.land/oras-go/v2/registry/remote"
|
||||
)
|
||||
|
||||
func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descriptor) io.Writer) error {
|
||||
func FetchImageBlob(ctx context.Context, r, reference, dst string, statusReader func(ocispec.Descriptor) io.Writer) error {
|
||||
// 0. Create a file store for the output
|
||||
fs, err := os.Create(dst)
|
||||
if err != nil {
|
||||
@@ -21,7 +22,6 @@ func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descript
|
||||
defer fs.Close()
|
||||
|
||||
// 1. Connect to a remote repository
|
||||
ctx := context.Background()
|
||||
repo, err := remote.NewRepository(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create repository: %v", err)
|
||||
@@ -37,12 +37,12 @@ func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descript
|
||||
|
||||
if statusReader != nil {
|
||||
// 3. Write the file to the file store
|
||||
_, err = io.Copy(io.MultiWriter(fs, statusReader(desc)), reader)
|
||||
_, err = xio.Copy(ctx, io.MultiWriter(fs, statusReader(desc)), reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
_, err = io.Copy(fs, reader)
|
||||
_, err = xio.Copy(ctx, fs, reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package oci_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/oci" // Update with your module path
|
||||
@@ -14,7 +15,7 @@ var _ = Describe("OCI", func() {
|
||||
f, err := os.CreateTemp("", "ollama")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer os.RemoveAll(f.Name())
|
||||
err = FetchImageBlob("registry.ollama.ai/library/gemma", "sha256:c1864a5eb19305c40519da12cc543519e48a0697ecd30e15d5ac228644957d12", f.Name(), nil)
|
||||
err = FetchImageBlob(context.TODO(), "registry.ollama.ai/library/gemma", "sha256:c1864a5eb19305c40519da12cc543519e48a0697ecd30e15d5ac228644957d12", f.Name(), nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote"
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote/transport"
|
||||
"github.com/google/go-containerregistry/pkg/v1/tarball"
|
||||
"github.com/mudler/LocalAI/pkg/xio"
|
||||
)
|
||||
|
||||
// ref: https://github.com/mudler/luet/blob/master/pkg/helpers/docker/docker.go#L117
|
||||
@@ -97,7 +98,7 @@ func (pw *progressWriter) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
// ExtractOCIImage will extract a given targetImage into a given targetDestination
|
||||
func ExtractOCIImage(img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error {
|
||||
func ExtractOCIImage(ctx context.Context, img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error {
|
||||
// Create a temporary tar file
|
||||
tmpTarFile, err := os.CreateTemp("", "localai-oci-*.tar")
|
||||
if err != nil {
|
||||
@@ -107,13 +108,13 @@ func ExtractOCIImage(img v1.Image, imageRef string, targetDestination string, do
|
||||
defer tmpTarFile.Close()
|
||||
|
||||
// Download the image as tar with progress tracking
|
||||
err = DownloadOCIImageTar(img, imageRef, tmpTarFile.Name(), downloadStatus)
|
||||
err = DownloadOCIImageTar(ctx, img, imageRef, tmpTarFile.Name(), downloadStatus)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download image tar: %v", err)
|
||||
}
|
||||
|
||||
// Extract the tar file to the target destination
|
||||
err = ExtractOCIImageFromTar(tmpTarFile.Name(), imageRef, targetDestination, downloadStatus)
|
||||
err = ExtractOCIImageFromTar(ctx, tmpTarFile.Name(), imageRef, targetDestination, downloadStatus)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract image tar: %v", err)
|
||||
}
|
||||
@@ -207,7 +208,7 @@ func GetOCIImageSize(targetImage, targetPlatform string, auth *registrytypes.Aut
|
||||
|
||||
// DownloadOCIImageTar downloads the compressed layers of an image and then creates an uncompressed tar
|
||||
// This provides accurate size estimation and allows for later extraction
|
||||
func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error {
|
||||
func DownloadOCIImageTar(ctx context.Context, img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error {
|
||||
// Get layers to calculate total compressed size for estimation
|
||||
layers, err := img.Layers()
|
||||
if err != nil {
|
||||
@@ -267,7 +268,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
|
||||
return fmt.Errorf("failed to get compressed layer: %v", err)
|
||||
}
|
||||
|
||||
_, err = io.Copy(writer, layerReader)
|
||||
_, err = xio.Copy(ctx, writer, layerReader)
|
||||
file.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download layer %d: %v", i, err)
|
||||
@@ -298,7 +299,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
|
||||
|
||||
// Extract uncompressed tar from local image
|
||||
extractReader := mutate.Extract(localImg)
|
||||
_, err = io.Copy(tarFile, extractReader)
|
||||
_, err = xio.Copy(ctx, tarFile, extractReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract uncompressed tar: %v", err)
|
||||
}
|
||||
@@ -307,7 +308,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
|
||||
}
|
||||
|
||||
// ExtractOCIImageFromTar extracts an image from a previously downloaded tar file
|
||||
func ExtractOCIImageFromTar(tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error {
|
||||
func ExtractOCIImageFromTar(ctx context.Context, tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error {
|
||||
// Open the tar file
|
||||
tarFile, err := os.Open(tarFilePath)
|
||||
if err != nil {
|
||||
@@ -331,7 +332,7 @@ func ExtractOCIImageFromTar(tarFilePath, imageRef, targetDestination string, dow
|
||||
}
|
||||
|
||||
// Extract the tar file
|
||||
_, err = archive.Apply(context.Background(),
|
||||
_, err = archive.Apply(ctx,
|
||||
targetDestination, reader,
|
||||
archive.WithNoSameOwner())
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package oci_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
@@ -30,7 +31,7 @@ var _ = Describe("OCI", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
err = ExtractOCIImage(img, imageName, dir, nil)
|
||||
err = ExtractOCIImage(context.TODO(), img, imageName, dir, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package oci
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -76,7 +77,7 @@ func OllamaModelBlob(image string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func OllamaFetchModel(image string, output string, statusWriter func(ocispec.Descriptor) io.Writer) error {
|
||||
func OllamaFetchModel(ctx context.Context, image string, output string, statusWriter func(ocispec.Descriptor) io.Writer) error {
|
||||
_, repository, imageNoTag := ParseImageParts(image)
|
||||
|
||||
blobID, err := OllamaModelBlob(image)
|
||||
@@ -84,5 +85,5 @@ func OllamaFetchModel(image string, output string, statusWriter func(ocispec.Des
|
||||
return err
|
||||
}
|
||||
|
||||
return FetchImageBlob(fmt.Sprintf("registry.ollama.ai/%s/%s", repository, imageNoTag), blobID, output, statusWriter)
|
||||
return FetchImageBlob(ctx, fmt.Sprintf("registry.ollama.ai/%s/%s", repository, imageNoTag), blobID, output, statusWriter)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package oci_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/oci" // Update with your module path
|
||||
@@ -14,7 +15,7 @@ var _ = Describe("OCI", func() {
|
||||
f, err := os.CreateTemp("", "ollama")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer os.RemoveAll(f.Name())
|
||||
err = OllamaFetchModel("gemma:2b", f.Name(), nil)
|
||||
err = OllamaFetchModel(context.TODO(), "gemma:2b", f.Name(), nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
21
pkg/xio/copy.go
Normal file
21
pkg/xio/copy.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package xio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
type readerFunc func(p []byte) (n int, err error)
|
||||
|
||||
func (rf readerFunc) Read(p []byte) (n int, err error) { return rf(p) }
|
||||
|
||||
func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||
return io.Copy(dst, readerFunc(func(p []byte) (int, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
default:
|
||||
return src.Read(p)
|
||||
}
|
||||
}))
|
||||
}
|
||||
Reference in New Issue
Block a user