feat(ui): allow to cancel ops (#7264)

* feat(ui): allow to cancel ops

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Improve progress text

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Cancel queued ops, don't show up message cancellation always

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: fixup displaying of total progress over multiple files

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-11-13 18:41:47 +01:00
committed by GitHub
parent b1d1f2a37d
commit 735ca757fa
29 changed files with 648 additions and 199 deletions

View File

@@ -62,12 +62,12 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}
if err := coreStartup.InstallModels(application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models")
}
for _, backend := range options.ExternalBackends {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend")
}
}

View File

@@ -45,7 +45,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err

View File

@@ -1,6 +1,7 @@
package cli
import (
"context"
"encoding/json"
"fmt"
@@ -102,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}

View File

@@ -135,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallModels(galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -42,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
log.Error().Err(err).Msg("failed loading galleries")
return "", err
}
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
if err != nil {
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
return "", err

View File

@@ -3,6 +3,7 @@
package gallery
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -69,7 +70,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
}
// InstallBackendFromGallery installs a backend from the gallery.
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force {
// check if we already have the backend installed
backends, err := ListSystemBackends(systemState)
@@ -109,7 +110,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
// Then, let's install the best backend
if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
return err
}
@@ -134,10 +135,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
return nil
}
return InstallBackend(systemState, modelLoader, backend, downloadStatus)
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
}
func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil {
@@ -164,11 +165,17 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
}
} else {
uri := downloader.URI(config.URI)
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
success := false
// Try to download from mirrors
for _, mirror := range config.Mirrors {
if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
// Check for cancellation before trying next mirror
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
success = true
break
}

View File

@@ -1,6 +1,7 @@
package gallery
import (
"context"
"encoding/json"
"os"
"path/filepath"
@@ -55,7 +56,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
)
must(err)
sysDefault.GPUVendor = "" // force default selection
backs, err := ListSystemBackends(sysDefault)
backs, err := ListSystemBackends(sysDefault)
must(err)
aliasBack, ok := backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
@@ -77,7 +78,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
must(err)
sysNvidia.GPUVendor = "nvidia"
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
backs, err = ListSystemBackends(sysNvidia)
backs, err = ListSystemBackends(sysNvidia)
must(err)
aliasBack, ok = backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
@@ -116,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {
Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
})
It("should install backend from gallery", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
})
@@ -298,7 +299,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -378,7 +379,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -462,7 +463,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -561,7 +562,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(newPath),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
})
@@ -593,7 +594,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -626,7 +627,7 @@ var _ = Describe("Gallery Backends", func() {
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
@@ -647,7 +648,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())

View File

@@ -1,6 +1,7 @@
package gallery
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -28,6 +29,19 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
return config, nil
}
func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
var config T
uri := downloader.URI(url)
err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
return config, err
}
return config, nil
}
func ReadConfigFile[T any](filePath string) (*T, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)

View File

@@ -1,6 +1,7 @@
package gallery
import (
"context"
"errors"
"fmt"
"os"
@@ -72,6 +73,7 @@ type PromptTemplate struct {
// Installs a model from the gallery
func InstallModelFromGallery(
ctx context.Context,
modelGalleries, backendGalleries []config.Gallery,
systemState *system.SystemState,
modelLoader *model.ModelLoader,
@@ -84,7 +86,7 @@ func InstallModelFromGallery(
if len(model.URL) > 0 {
var err error
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
@@ -125,7 +127,7 @@ func InstallModelFromGallery(
return err
}
installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
@@ -133,7 +135,7 @@ func InstallModelFromGallery(
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
@@ -154,7 +156,7 @@ func InstallModelFromGallery(
return applyModel(model)
}
func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750)
@@ -168,6 +170,13 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
// Download files and verify their SHA
for i, file := range config.Files {
// Check for cancellation before each file
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
@@ -185,7 +194,7 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
}
}
uri := downloader.URI(file.URI)
if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
return nil, err
}
}

View File

@@ -1,6 +1,7 @@
package gallery_test
import (
"context"
"errors"
"os"
"path/filepath"
@@ -34,7 +35,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -88,7 +89,7 @@ var _ = Describe("Model test", func() {
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@@ -129,7 +130,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -149,7 +150,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -179,7 +180,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred())
})
})

View File

@@ -85,7 +85,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})

View File

@@ -1,6 +1,7 @@
package routes
import (
"context"
"fmt"
"math"
"net/url"
@@ -35,23 +36,35 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
progress := 0
isDeletion := false
isQueued := false
isCancelled := false
isCancellable := false
message := ""
if status != nil {
// Skip completed operations
if status.Processed {
// Skip completed operations (unless cancelled and not yet cleaned up)
if status.Processed && !status.Cancelled {
continue
}
// Skip cancelled operations that are processed (they're done, no need to show)
if status.Processed && status.Cancelled {
continue
}
progress = int(status.Progress)
isDeletion = status.Deletion
isCancelled = status.Cancelled
isCancellable = status.Cancellable
message = status.Message
if isDeletion {
taskType = "deletion"
}
if isCancelled {
taskType = "cancelled"
}
} else {
// Job is queued but hasn't started
isQueued = true
isCancellable = true
message = "Operation queued"
}
@@ -76,16 +89,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
operations = append(operations, fiber.Map{
"id": galleryID,
"name": displayName,
"fullName": galleryID,
"jobID": jobID,
"progress": progress,
"taskType": taskType,
"isDeletion": isDeletion,
"isBackend": isBackend,
"isQueued": isQueued,
"message": message,
"id": galleryID,
"name": displayName,
"fullName": galleryID,
"jobID": jobID,
"progress": progress,
"taskType": taskType,
"isDeletion": isDeletion,
"isBackend": isBackend,
"isQueued": isQueued,
"isCancelled": isCancelled,
"cancellable": isCancellable,
"message": message,
})
}
@@ -108,6 +123,28 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
})
// Cancel operation endpoint
app.Post("/api/operations/:jobID/cancel", func(c *fiber.Ctx) error {
jobID := strings.Clone(c.Params("jobID"))
log.Debug().Msgf("API request to cancel operation: %s", jobID)
err := galleryService.CancelOperation(jobID)
if err != nil {
log.Error().Err(err).Msgf("Failed to cancel operation: %s", jobID)
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": err.Error(),
})
}
// Clean up opcache for cancelled operation
opcache.DeleteUUID(jobID)
return c.JSON(fiber.Map{
"success": true,
"message": "Operation cancelled",
})
})
// Model Gallery APIs
app.Get("/api/models", func(c *fiber.Ctx) error {
term := c.Query("term")
@@ -248,12 +285,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
uid := id.String()
opcache.Set(galleryID, uid)
ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
ID: uid,
GalleryElementName: galleryID,
Galleries: appConfig.Galleries,
BackendGalleries: appConfig.BackendGalleries,
Context: ctx,
CancelFunc: cancelFunc,
}
// Store cancellation function immediately so queued operations can be cancelled
galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.ModelGalleryChannel <- op
}()
@@ -291,13 +333,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
opcache.Set(galleryID, uid)
ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
ID: uid,
Delete: true,
GalleryElementName: galleryName,
Galleries: appConfig.Galleries,
BackendGalleries: appConfig.BackendGalleries,
Context: ctx,
CancelFunc: cancelFunc,
}
// Store cancellation function immediately so queued operations can be cancelled
galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.ModelGalleryChannel <- op
cl.RemoveModelConfig(galleryName)
@@ -341,7 +388,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
_, err = gallery.InstallModel(appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
_, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"error": err.Error(),
@@ -526,11 +573,16 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
uid := id.String()
opcache.Set(backendID, uid)
ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryBackend, any]{
ID: uid,
GalleryElementName: backendID,
Galleries: appConfig.BackendGalleries,
Context: ctx,
CancelFunc: cancelFunc,
}
// Store cancellation function immediately so queued operations can be cancelled
galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.BackendGalleryChannel <- op
}()
@@ -568,12 +620,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
opcache.Set(backendID, uid)
ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryBackend, any]{
ID: uid,
Delete: true,
GalleryElementName: backendName,
Galleries: appConfig.BackendGalleries,
Context: ctx,
CancelFunc: cancelFunc,
}
// Store cancellation function immediately so queued operations can be cancelled
galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.BackendGalleryChannel <- op
}()

View File

@@ -71,15 +71,34 @@
Queued
</span>
</template>
<template x-if="!operation.isQueued && operation.message">
<template x-if="operation.isCancelled">
<span class="text-xs text-red-400 flex items-center">
<i class="fas fa-ban mr-1"></i>
Cancelling...
</span>
</template>
<template x-if="!operation.isQueued && !operation.isCancelled && operation.message">
<span class="text-xs text-[#94A3B8] truncate" x-text="operation.message"></span>
</template>
</div>
</div>
<!-- Progress percentage -->
<div class="flex-shrink-0 text-right">
<!-- Progress percentage and cancel button -->
<div class="flex-shrink-0 text-right flex items-center space-x-2">
<span class="text-[#E5E7EB] font-bold text-lg" x-text="operation.progress + '%'"></span>
<template x-if="operation.cancellable && !operation.isCancelled">
<button @click="cancelOperation(operation.jobID, operation.id)"
class="text-red-400 hover:text-red-300 transition-colors p-1 rounded hover:bg-red-500/20"
title="Cancel operation">
<i class="fas fa-times"></i>
</button>
</template>
<template x-if="operation.isCancelled">
<span class="text-red-400 text-xs flex items-center">
<i class="fas fa-ban mr-1"></i>
Cancelled
</span>
</template>
</div>
</div>
</div>
@@ -88,8 +107,8 @@
<div class="w-full bg-[#101827] rounded-full h-2 overflow-hidden border border-[#1E293B]">
<div class="h-full rounded-full transition-all duration-300 ease-out"
:class="{
'bg-gradient-to-r from-[#38BDF8] to-[#8B5CF6]': !operation.isDeletion,
'bg-gradient-to-r from-red-500 to-red-600': operation.isDeletion
'bg-gradient-to-r from-[#38BDF8] to-[#8B5CF6]': !operation.isDeletion && !operation.isCancelled,
'bg-gradient-to-r from-red-500 to-red-600': operation.isDeletion || operation.isCancelled
}"
:style="'width: ' + operation.progress + '%'">
</div>
@@ -141,6 +160,57 @@ function operationsStatus() {
}
},
async cancelOperation(jobID, operationID) {
// Check if operation is already cancelled
const operation = this.operations.find(op => op.jobID === jobID);
if (operation && operation.isCancelled) {
// Already cancelled, no need to do anything
return;
}
try {
const response = await fetch(`/api/operations/${jobID}/cancel`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
});
if (!response.ok) {
const error = await response.json();
const errorMessage = error.error || 'Failed to cancel operation';
// Don't show alert for "already cancelled" - just update UI silently
if (errorMessage.includes('already cancelled')) {
if (operation) {
operation.isCancelled = true;
operation.cancellable = false;
}
this.fetchOperations();
return;
}
throw new Error(errorMessage);
}
// Update the operation status immediately
if (operation) {
operation.isCancelled = true;
operation.cancellable = false;
operation.message = 'Cancelling...';
}
// Refresh operations to get updated status
this.fetchOperations();
} catch (error) {
console.error('Error cancelling operation:', error);
// Only show alert if it's not an "already cancelled" error
if (!error.message.includes('already cancelled')) {
alert('Failed to cancel operation: ' + error.message);
}
}
},
destroy() {
if (this.pollInterval) {
clearInterval(this.pollInterval);

View File

@@ -1,6 +1,10 @@
package services
import (
"context"
"errors"
"fmt"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/system"
@@ -10,14 +14,43 @@ import (
func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, any], systemState *system.SystemState) error {
utils.ResetDownloadTimers()
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0})
// Check if already cancelled
if op.Context != nil {
select {
case <-op.Context.Done():
g.UpdateStatus(op.ID, &GalleryOpStatus{
Cancelled: true,
Processed: true,
Message: "cancelled",
GalleryElementName: op.GalleryElementName,
})
return op.Context.Err()
default:
}
}
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf("processing backend: %s", op.GalleryElementName), Progress: 0, Cancellable: true})
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
// Check for cancellation during progress updates
if op.Context != nil {
select {
case <-op.Context.Done():
return
default:
}
}
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf(processingMessage, fileName, total, current), FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current, Cancellable: true})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
ctx := op.Context
if ctx == nil {
ctx = context.Background()
}
var err error
if op.Delete {
err = gallery.DeleteBackendFromSystem(g.appConfig.SystemState, op.GalleryElementName)
@@ -25,9 +58,19 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, an
} else {
log.Warn().Msgf("installing backend %s", op.GalleryElementName)
log.Debug().Msgf("backend galleries: %v", g.appConfig.BackendGalleries)
err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, g.modelLoader, op.GalleryElementName, progressCallback, true)
err = gallery.InstallBackendFromGallery(ctx, g.appConfig.BackendGalleries, systemState, g.modelLoader, op.GalleryElementName, progressCallback, true)
}
if err != nil {
// Check if error is due to cancellation
if op.Context != nil && errors.Is(err, op.Context.Err()) {
g.UpdateStatus(op.ID, &GalleryOpStatus{
Cancelled: true,
Processed: true,
Message: "cancelled",
GalleryElementName: op.GalleryElementName,
})
return err
}
log.Error().Err(err).Msgf("error installing backend %s", op.GalleryElementName)
if !op.Delete {
// If we didn't install the backend, we need to make sure we don't have a leftover directory
@@ -42,6 +85,7 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, an
Processed: true,
GalleryElementName: op.GalleryElementName,
Message: "completed",
Progress: 100})
Progress: 100,
Cancellable: false})
return nil
}

View File

@@ -1,88 +1,166 @@
package services
import (
"context"
"fmt"
"sync"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
)
type GalleryService struct {
appConfig *config.ApplicationConfig
sync.Mutex
ModelGalleryChannel chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]
BackendGalleryChannel chan GalleryOp[gallery.GalleryBackend, any]
modelLoader *model.ModelLoader
statuses map[string]*GalleryOpStatus
}
func NewGalleryService(appConfig *config.ApplicationConfig, ml *model.ModelLoader) *GalleryService {
return &GalleryService{
appConfig: appConfig,
ModelGalleryChannel: make(chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]),
BackendGalleryChannel: make(chan GalleryOp[gallery.GalleryBackend, any]),
modelLoader: ml,
statuses: make(map[string]*GalleryOpStatus),
}
}
func (g *GalleryService) UpdateStatus(s string, op *GalleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *GalleryService) GetStatus(s string) *GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error {
// updates the status with an error
var updateError func(id string, e error)
if !g.appConfig.OpaqueErrors {
updateError = func(id string, e error) {
g.UpdateStatus(id, &GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
} else {
updateError = func(id string, _ error) {
g.UpdateStatus(id, &GalleryOpStatus{Error: fmt.Errorf("an error occurred"), Processed: true})
}
}
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.BackendGalleryChannel:
err := g.backendHandler(&op, systemState)
if err != nil {
updateError(op.ID, err)
}
case op := <-g.ModelGalleryChannel:
err := g.modelHandler(&op, cl, systemState)
if err != nil {
updateError(op.ID, err)
}
}
}
}()
return nil
}
package services
import (
"context"
"fmt"
"sync"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
)
type GalleryService struct {
appConfig *config.ApplicationConfig
sync.Mutex
ModelGalleryChannel chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]
BackendGalleryChannel chan GalleryOp[gallery.GalleryBackend, any]
modelLoader *model.ModelLoader
statuses map[string]*GalleryOpStatus
cancellations map[string]context.CancelFunc
}
func NewGalleryService(appConfig *config.ApplicationConfig, ml *model.ModelLoader) *GalleryService {
return &GalleryService{
appConfig: appConfig,
ModelGalleryChannel: make(chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]),
BackendGalleryChannel: make(chan GalleryOp[gallery.GalleryBackend, any]),
modelLoader: ml,
statuses: make(map[string]*GalleryOpStatus),
cancellations: make(map[string]context.CancelFunc),
}
}
func (g *GalleryService) UpdateStatus(s string, op *GalleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *GalleryService) GetStatus(s string) *GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
// CancelOperation cancels an in-progress operation by its ID
func (g *GalleryService) CancelOperation(id string) error {
g.Lock()
defer g.Unlock()
// Check if operation is already cancelled
if status, ok := g.statuses[id]; ok && status.Cancelled {
return fmt.Errorf("operation %q is already cancelled", id)
}
cancelFunc, exists := g.cancellations[id]
if !exists {
return fmt.Errorf("operation %q not found or already completed", id)
}
// Cancel the operation
cancelFunc()
// Update status to reflect cancellation
if status, ok := g.statuses[id]; ok {
status.Cancelled = true
status.Processed = true
status.Message = "cancelled"
} else {
// Create status for queued operations that haven't started yet
g.statuses[id] = &GalleryOpStatus{
Cancelled: true,
Processed: true,
Message: "cancelled",
Cancellable: false,
}
}
// Clean up cancellation function
delete(g.cancellations, id)
return nil
}
// storeCancellation stores a cancellation function for an operation
func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelFunc) {
g.Lock()
defer g.Unlock()
g.cancellations[id] = cancelFunc
}
// StoreCancellation is a public method to store a cancellation function for an operation
// This allows cancellation functions to be stored immediately when operations are created,
// enabling cancellation of queued operations that haven't started processing yet.
func (g *GalleryService) StoreCancellation(id string, cancelFunc context.CancelFunc) {
g.storeCancellation(id, cancelFunc)
}
// removeCancellation removes a cancellation function when operation completes
func (g *GalleryService) removeCancellation(id string) {
g.Lock()
defer g.Unlock()
delete(g.cancellations, id)
}
func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error {
// updates the status with an error
var updateError func(id string, e error)
if !g.appConfig.OpaqueErrors {
updateError = func(id string, e error) {
g.UpdateStatus(id, &GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
} else {
updateError = func(id string, _ error) {
g.UpdateStatus(id, &GalleryOpStatus{Error: fmt.Errorf("an error occurred"), Processed: true})
}
}
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.BackendGalleryChannel:
// Create context if not provided
if op.Context == nil {
op.Context, op.CancelFunc = context.WithCancel(c)
g.storeCancellation(op.ID, op.CancelFunc)
} else if op.CancelFunc != nil {
g.storeCancellation(op.ID, op.CancelFunc)
}
err := g.backendHandler(&op, systemState)
if err != nil {
updateError(op.ID, err)
}
g.removeCancellation(op.ID)
case op := <-g.ModelGalleryChannel:
// Create context if not provided
if op.Context == nil {
op.Context, op.CancelFunc = context.WithCancel(c)
g.storeCancellation(op.ID, op.CancelFunc)
} else if op.CancelFunc != nil {
g.storeCancellation(op.ID, op.CancelFunc)
}
err := g.modelHandler(&op, cl, systemState)
if err != nil {
updateError(op.ID, err)
}
g.removeCancellation(op.ID)
}
}
}()
return nil
}

View File

@@ -1,7 +1,10 @@
package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"github.com/mudler/LocalAI/core/config"
@@ -13,22 +16,74 @@ import (
"gopkg.in/yaml.v2"
)
const (
processingMessage = "processing file: %s. Total: %s. Current: %s"
)
func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel, gallery.ModelConfig], cl *config.ModelConfigLoader, systemState *system.SystemState) error {
utils.ResetDownloadTimers()
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0})
// Check if already cancelled
if op.Context != nil {
select {
case <-op.Context.Done():
g.UpdateStatus(op.ID, &GalleryOpStatus{
Cancelled: true,
Processed: true,
Message: "cancelled",
GalleryElementName: op.GalleryElementName,
})
return op.Context.Err()
default:
}
}
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf("processing model: %s", op.GalleryElementName), Progress: 0, Cancellable: true})
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
// Check for cancellation during progress updates
if op.Context != nil {
select {
case <-op.Context.Done():
return
default:
}
}
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf(processingMessage, fileName, total, current), FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current, Cancellable: true})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
err := processModelOperation(op, systemState, g.modelLoader, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback)
if err != nil {
// Check if error is due to cancellation
if op.Context != nil && errors.Is(err, op.Context.Err()) {
g.UpdateStatus(op.ID, &GalleryOpStatus{
Cancelled: true,
Processed: true,
Message: "cancelled",
GalleryElementName: op.GalleryElementName,
})
return err
}
return err
}
// Check for cancellation before final steps
if op.Context != nil {
select {
case <-op.Context.Done():
g.UpdateStatus(op.ID, &GalleryOpStatus{
Cancelled: true,
Processed: true,
Message: "cancelled",
GalleryElementName: op.GalleryElementName,
})
return op.Context.Err()
default:
}
}
// Reload models
err = cl.LoadModelConfigsFromPath(systemState.Model.ModelsPath)
if err != nil {
@@ -46,26 +101,27 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel, galler
Processed: true,
GalleryElementName: op.GalleryElementName,
Message: "completed",
Progress: 100})
Progress: 100,
Cancellable: false})
return nil
}
func installModelFromRemoteConfig(systemState *system.SystemState, modelLoader *model.ModelLoader, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error {
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, systemState.Model.ModelsPath)
func installModelFromRemoteConfig(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error {
config, err := gallery.GetGalleryConfigFromURLWithContext[gallery.ModelConfig](ctx, req.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
installedModel, err := gallery.InstallModel(systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
installedModel, err := gallery.InstallModel(ctx, systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
if automaticallyInstallBackend && installedModel.Backend != "" {
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
if err := gallery.InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
@@ -79,15 +135,16 @@ type galleryModel struct {
}
func processRequests(systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error {
ctx := context.Background()
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = installModelFromRemoteConfig(systemState, modelLoader, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries)
err = installModelFromRemoteConfig(ctx, systemState, modelLoader, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries)
} else {
err = gallery.InstallModelFromGallery(
galleries, backendGalleries, systemState, modelLoader, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
ctx, galleries, backendGalleries, systemState, modelLoader, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
}
}
return err
@@ -126,25 +183,40 @@ func processModelOperation(
automaticallyInstallBackend bool,
progressCallback func(string, string, string, float64),
) error {
ctx := op.Context
if ctx == nil {
ctx = context.Background()
}
// Check for cancellation before starting
select {
case <-ctx.Done():
return ctx.Err()
default:
}
switch {
case op.Delete:
return gallery.DeleteModelFromSystem(systemState, op.GalleryElementName)
case op.GalleryElement != nil:
installedModel, err := gallery.InstallModel(
systemState, op.GalleryElement.Name,
ctx, systemState, op.GalleryElement.Name,
op.GalleryElement,
op.Req.Overrides,
progressCallback, enforcePredownloadScans)
if err != nil {
return err
}
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
if err := gallery.InstallBackendFromGallery(op.BackendGalleries, systemState, modelLoader, installedModel.Backend, progressCallback, false); err != nil {
if err := gallery.InstallBackendFromGallery(ctx, op.BackendGalleries, systemState, modelLoader, installedModel.Backend, progressCallback, false); err != nil {
return err
}
}
return err
return nil
case op.GalleryElementName != "":
return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, systemState, modelLoader, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
return gallery.InstallModelFromGallery(ctx, op.Galleries, op.BackendGalleries, systemState, modelLoader, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
default:
return installModelFromRemoteConfig(systemState, modelLoader, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries)
return installModelFromRemoteConfig(ctx, systemState, modelLoader, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries)
}
}

View File

@@ -1,6 +1,8 @@
package services
import (
"context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/xsync"
)
@@ -17,6 +19,10 @@ type GalleryOp[T any, E any] struct {
Galleries []config.Gallery
BackendGalleries []config.Gallery
// Context for cancellation support
Context context.Context
CancelFunc context.CancelFunc
}
type GalleryOpStatus struct {
@@ -29,6 +35,8 @@ type GalleryOpStatus struct {
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
GalleryElementName string `json:"gallery_element_name"`
Cancelled bool `json:"cancelled"` // Cancelled is true if the operation was cancelled
Cancellable bool `json:"cancellable"` // Cancellable is true if the operation can be cancelled
}
type OpCache struct {

View File

@@ -1,6 +1,7 @@
package startup
import (
"context"
"fmt"
"path/filepath"
"strings"
@@ -13,7 +14,7 @@ import (
"github.com/rs/zerolog/log"
)
func InstallExternalBackends(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
func InstallExternalBackends(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
uri := downloader.URI(backend)
switch {
case uri.LooksLikeDir():
@@ -21,7 +22,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
name = filepath.Base(backend)
}
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from path")
if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -35,7 +36,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
return fmt.Errorf("specifying a name is required for OCI images")
}
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -53,7 +54,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
name = strings.TrimSuffix(name, filepath.Ext(name))
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -66,7 +67,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
if name != "" || alias != "" {
return fmt.Errorf("specifying a name or alias is not supported for this backend")
}
err := gallery.InstallBackendFromGallery(galleries, systemState, modelLoader, backend, downloadStatus, true)
err := gallery.InstallBackendFromGallery(ctx, galleries, systemState, modelLoader, backend, downloadStatus, true)
if err != nil {
return fmt.Errorf("error installing backend %s: %w", backend, err)
}

View File

@@ -1,6 +1,7 @@
package startup
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -30,7 +31,7 @@ const (
// InstallModels will preload models from the given list of URLs and galleries
// It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration
func InstallModels(galleryService *services.GalleryService, galleries, backendGalleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
func InstallModels(ctx context.Context, galleryService *services.GalleryService, galleries, backendGalleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
// create an error that groups all errors
var err error
@@ -53,7 +54,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
return nil
}
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, modelLoader, model.Backend, downloadStatus, false); err != nil {
if err := gallery.InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, model.Backend, downloadStatus, false); err != nil {
log.Error().Err(err).Str("backend", model.Backend).Msg("error installing backend")
return err
}
@@ -153,7 +154,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
}
} else {
// Check if it's a model gallery, or print a warning
e, found := installModel(galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries)
e, found := installModel(ctx, galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries)
if e != nil && found {
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url)
err = errors.Join(err, e)
@@ -210,7 +211,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
return err
}
func installModel(galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
func installModel(ctx context.Context, galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
models, err := gallery.AvailableGalleryModels(galleries, systemState)
if err != nil {
return err, false
@@ -226,7 +227,7 @@ func installModel(galleries, backendGalleries []config.Gallery, modelName string
}
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
err = gallery.InstallModelFromGallery(galleries, backendGalleries, systemState, modelLoader, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
err = gallery.InstallModelFromGallery(ctx, galleries, backendGalleries, systemState, modelLoader, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
if err != nil {
return err, true
}

View File

@@ -1,6 +1,7 @@
package startup_test
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -33,7 +34,7 @@ var _ = Describe("Preload test", func() {
url := "https://raw.githubusercontent.com/mudler/LocalAI-examples/main/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", "phi-2")
InstallModels(nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
InstallModels(context.TODO(), nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
resultFile := filepath.Join(tmpdir, fileName)
@@ -46,7 +47,7 @@ var _ = Describe("Preload test", func() {
url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K")
err := InstallModels(nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
err := InstallModels(context.TODO(), nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
Expect(err).ToNot(HaveOccurred())
resultFile := filepath.Join(tmpdir, fileName)

View File

@@ -1,6 +1,9 @@
package downloader
import "hash"
import (
"context"
"hash"
)
type progressWriter struct {
fileName string
@@ -10,23 +13,45 @@ type progressWriter struct {
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
ctx context.Context
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
// Check for cancellation before writing
if pw.ctx != nil {
select {
case <-pw.ctx.Done():
return 0, pw.ctx.Err()
default:
}
}
n, err = pw.hash.Write(p)
if err != nil {
return n, err
}
pw.written += int64(n)
// Check for cancellation after writing chunk
if pw.ctx != nil {
select {
case <-pw.ctx.Done():
return n, pw.ctx.Err()
default:
}
}
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
if pw.totalFiles > 1 {
// This is a multi-file download
// so we need to adjust the percentage
// to reflect the progress of the whole download
// This is the file pw.fileNo of pw.totalFiles files. We assume that
// This is the file pw.fileNo (0-indexed) of pw.totalFiles files. We assume that
// the files before successfully downloaded.
percentage = percentage / float64(pw.totalFiles)
if pw.fileNo > 1 {
percentage += float64(pw.fileNo-1) * 100 / float64(pw.totalFiles)
if pw.fileNo > 0 {
percentage += float64(pw.fileNo) * 100 / float64(pw.totalFiles)
}
}
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)

View File

@@ -1,6 +1,7 @@
package downloader
import (
"context"
"crypto/sha256"
"errors"
"fmt"
@@ -49,10 +50,10 @@ func loadConfig() string {
}
func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
return uri.DownloadWithAuthorizationAndCallback(basePath, "", f)
return uri.DownloadWithAuthorizationAndCallback(context.Background(), basePath, "", f)
}
func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error {
func (uri URI) DownloadWithAuthorizationAndCallback(ctx context.Context, basePath string, authorization string, f func(url string, i []byte) error) error {
url := uri.ResolveURL()
if strings.HasPrefix(url, LocalPrefix) {
@@ -83,8 +84,7 @@ func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorizati
}
// Send a GET request to the URL
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
@@ -264,6 +264,10 @@ func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
}
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
return uri.DownloadFileWithContext(context.Background(), filePath, sha, fileN, total, downloadStatus)
}
func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
url := uri.ResolveURL()
if uri.LooksLikeOCI() {
@@ -285,7 +289,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
}
if url, ok := strings.CutPrefix(url, OllamaPrefix); ok {
return oci.OllamaFetchModel(url, filePath, progressStatus)
return oci.OllamaFetchModel(ctx, url, filePath, progressStatus)
}
if url, ok := strings.CutPrefix(url, OCIFilePrefix); ok {
@@ -295,7 +299,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
return fmt.Errorf("failed to open tarball: %s", err.Error())
}
return oci.ExtractOCIImage(img, url, filePath, downloadStatus)
return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus)
}
url = strings.TrimPrefix(url, OCIPrefix)
@@ -304,7 +308,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
return fmt.Errorf("failed to get image %q: %v", url, err)
}
return oci.ExtractOCIImage(img, url, filePath, downloadStatus)
return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus)
}
// We need to check if url looks like an URL or bail out
@@ -312,6 +316,13 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
return fmt.Errorf("url %q does not look like an HTTP URL", url)
}
// Check for cancellation before starting
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Check if the file already exists
_, err := os.Stat(filePath)
if err == nil {
@@ -346,7 +357,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
log.Info().Msgf("Downloading %q", url)
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request for %q: %v", filePath, err)
}
@@ -375,6 +386,12 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
// Start the request
resp, err := http.DefaultClient.Do(req)
if err != nil {
// Check if error is due to context cancellation
if errors.Is(err, context.Canceled) {
// Clean up partial file on cancellation
removePartialFile(tmpFilePath)
return err
}
return fmt.Errorf("failed to download file %q: %v", filePath, err)
}
defer resp.Body.Close()
@@ -406,12 +423,27 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
fileNo: fileN,
totalFiles: total,
downloadStatus: downloadStatus,
ctx: ctx,
}
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
if err != nil {
// Check if error is due to context cancellation
if errors.Is(err, context.Canceled) {
// Clean up partial file on cancellation
removePartialFile(tmpFilePath)
return err
}
return fmt.Errorf("failed to write file %q: %v", filePath, err)
}
// Check for cancellation before finalizing
select {
case <-ctx.Done():
removePartialFile(tmpFilePath)
return ctx.Err()
default:
}
err = os.Rename(tmpFilePath, filePath)
if err != nil {
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)

View File

@@ -6,13 +6,14 @@ import (
"io"
"os"
"github.com/mudler/LocalAI/pkg/xio"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
oras "oras.land/oras-go/v2"
"oras.land/oras-go/v2/registry/remote"
)
func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descriptor) io.Writer) error {
func FetchImageBlob(ctx context.Context, r, reference, dst string, statusReader func(ocispec.Descriptor) io.Writer) error {
// 0. Create a file store for the output
fs, err := os.Create(dst)
if err != nil {
@@ -21,7 +22,6 @@ func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descript
defer fs.Close()
// 1. Connect to a remote repository
ctx := context.Background()
repo, err := remote.NewRepository(r)
if err != nil {
return fmt.Errorf("failed to create repository: %v", err)
@@ -37,12 +37,12 @@ func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descript
if statusReader != nil {
// 3. Write the file to the file store
_, err = io.Copy(io.MultiWriter(fs, statusReader(desc)), reader)
_, err = xio.Copy(ctx, io.MultiWriter(fs, statusReader(desc)), reader)
if err != nil {
return err
}
} else {
_, err = io.Copy(fs, reader)
_, err = xio.Copy(ctx, fs, reader)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package oci_test
import (
"context"
"os"
. "github.com/mudler/LocalAI/pkg/oci" // Update with your module path
@@ -14,7 +15,7 @@ var _ = Describe("OCI", func() {
f, err := os.CreateTemp("", "ollama")
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(f.Name())
err = FetchImageBlob("registry.ollama.ai/library/gemma", "sha256:c1864a5eb19305c40519da12cc543519e48a0697ecd30e15d5ac228644957d12", f.Name(), nil)
err = FetchImageBlob(context.TODO(), "registry.ollama.ai/library/gemma", "sha256:c1864a5eb19305c40519da12cc543519e48a0697ecd30e15d5ac228644957d12", f.Name(), nil)
Expect(err).NotTo(HaveOccurred())
})
})

View File

@@ -23,6 +23,7 @@ import (
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/remote/transport"
"github.com/google/go-containerregistry/pkg/v1/tarball"
"github.com/mudler/LocalAI/pkg/xio"
)
// ref: https://github.com/mudler/luet/blob/master/pkg/helpers/docker/docker.go#L117
@@ -97,7 +98,7 @@ func (pw *progressWriter) Write(p []byte) (int, error) {
}
// ExtractOCIImage will extract a given targetImage into a given targetDestination
func ExtractOCIImage(img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error {
func ExtractOCIImage(ctx context.Context, img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error {
// Create a temporary tar file
tmpTarFile, err := os.CreateTemp("", "localai-oci-*.tar")
if err != nil {
@@ -107,13 +108,13 @@ func ExtractOCIImage(img v1.Image, imageRef string, targetDestination string, do
defer tmpTarFile.Close()
// Download the image as tar with progress tracking
err = DownloadOCIImageTar(img, imageRef, tmpTarFile.Name(), downloadStatus)
err = DownloadOCIImageTar(ctx, img, imageRef, tmpTarFile.Name(), downloadStatus)
if err != nil {
return fmt.Errorf("failed to download image tar: %v", err)
}
// Extract the tar file to the target destination
err = ExtractOCIImageFromTar(tmpTarFile.Name(), imageRef, targetDestination, downloadStatus)
err = ExtractOCIImageFromTar(ctx, tmpTarFile.Name(), imageRef, targetDestination, downloadStatus)
if err != nil {
return fmt.Errorf("failed to extract image tar: %v", err)
}
@@ -207,7 +208,7 @@ func GetOCIImageSize(targetImage, targetPlatform string, auth *registrytypes.Aut
// DownloadOCIImageTar downloads the compressed layers of an image and then creates an uncompressed tar
// This provides accurate size estimation and allows for later extraction
func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error {
func DownloadOCIImageTar(ctx context.Context, img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error {
// Get layers to calculate total compressed size for estimation
layers, err := img.Layers()
if err != nil {
@@ -267,7 +268,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
return fmt.Errorf("failed to get compressed layer: %v", err)
}
_, err = io.Copy(writer, layerReader)
_, err = xio.Copy(ctx, writer, layerReader)
file.Close()
if err != nil {
return fmt.Errorf("failed to download layer %d: %v", i, err)
@@ -298,7 +299,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
// Extract uncompressed tar from local image
extractReader := mutate.Extract(localImg)
_, err = io.Copy(tarFile, extractReader)
_, err = xio.Copy(ctx, tarFile, extractReader)
if err != nil {
return fmt.Errorf("failed to extract uncompressed tar: %v", err)
}
@@ -307,7 +308,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
}
// ExtractOCIImageFromTar extracts an image from a previously downloaded tar file
func ExtractOCIImageFromTar(tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error {
func ExtractOCIImageFromTar(ctx context.Context, tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error {
// Open the tar file
tarFile, err := os.Open(tarFilePath)
if err != nil {
@@ -331,7 +332,7 @@ func ExtractOCIImageFromTar(tarFilePath, imageRef, targetDestination string, dow
}
// Extract the tar file
_, err = archive.Apply(context.Background(),
_, err = archive.Apply(ctx,
targetDestination, reader,
archive.WithNoSameOwner())

View File

@@ -1,6 +1,7 @@
package oci_test
import (
"context"
"os"
"runtime"
@@ -30,7 +31,7 @@ var _ = Describe("OCI", func() {
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(dir)
err = ExtractOCIImage(img, imageName, dir, nil)
err = ExtractOCIImage(context.TODO(), img, imageName, dir, nil)
Expect(err).NotTo(HaveOccurred())
})
})

View File

@@ -1,6 +1,7 @@
package oci
import (
"context"
"encoding/json"
"fmt"
"io"
@@ -76,7 +77,7 @@ func OllamaModelBlob(image string) (string, error) {
return "", nil
}
func OllamaFetchModel(image string, output string, statusWriter func(ocispec.Descriptor) io.Writer) error {
func OllamaFetchModel(ctx context.Context, image string, output string, statusWriter func(ocispec.Descriptor) io.Writer) error {
_, repository, imageNoTag := ParseImageParts(image)
blobID, err := OllamaModelBlob(image)
@@ -84,5 +85,5 @@ func OllamaFetchModel(image string, output string, statusWriter func(ocispec.Des
return err
}
return FetchImageBlob(fmt.Sprintf("registry.ollama.ai/%s/%s", repository, imageNoTag), blobID, output, statusWriter)
return FetchImageBlob(ctx, fmt.Sprintf("registry.ollama.ai/%s/%s", repository, imageNoTag), blobID, output, statusWriter)
}

View File

@@ -1,6 +1,7 @@
package oci_test
import (
"context"
"os"
. "github.com/mudler/LocalAI/pkg/oci" // Update with your module path
@@ -14,7 +15,7 @@ var _ = Describe("OCI", func() {
f, err := os.CreateTemp("", "ollama")
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(f.Name())
err = OllamaFetchModel("gemma:2b", f.Name(), nil)
err = OllamaFetchModel(context.TODO(), "gemma:2b", f.Name(), nil)
Expect(err).NotTo(HaveOccurred())
})
})

21
pkg/xio/copy.go Normal file
View File

@@ -0,0 +1,21 @@
package xio
import (
"context"
"io"
)
type readerFunc func(p []byte) (n int, err error)
func (rf readerFunc) Read(p []byte) (n int, err error) { return rf(p) }
func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
return io.Copy(dst, readerFunc(func(p []byte) (int, error) {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
return src.Read(p)
}
}))
}