diff --git a/core/gallery/backends.go b/core/gallery/backends.go index 7515514f9..0da7a68ac 100644 --- a/core/gallery/backends.go +++ b/core/gallery/backends.go @@ -162,7 +162,7 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func return fmt.Errorf("failed to create backend path %q: %v", backendPath, err) } - if err := oci.ExtractOCIImage(img, backendPath, downloadStatus); err != nil { + if err := oci.ExtractOCIImage(img, config.URI, backendPath, downloadStatus); err != nil { return fmt.Errorf("failed to extract image %q: %v", config.URI, err) } @@ -246,6 +246,15 @@ func ListSystemBackends(basePath string) (map[string]string, error) { for _, backend := range backends { if backend.IsDir() { runFile := filepath.Join(basePath, backend.Name(), runFile) + // Skip if runfile and metadata file don't exist + if _, err := os.Stat(runFile); os.IsNotExist(err) { + continue + } + metadataFilePath := filepath.Join(basePath, backend.Name(), metadataFile) + if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) { + continue + } + backendsNames[backend.Name()] = runFile // Check for alias in metadata diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index 2759f081f..dbd895932 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -121,7 +121,12 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) (Galler // Get models from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryElements[*GalleryModel](gallery, basePath) + galleryModels, err := getGalleryElements[*GalleryModel](gallery, basePath, func(model *GalleryModel) bool { + if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { + return true + } + return false + }) if err != nil { return nil, err } @@ -137,7 +142,14 @@ func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryElem // Get models from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryElements[*GalleryBackend](gallery, basePath) + galleryModels, err := getGalleryElements[*GalleryBackend](gallery, basePath, func(backend *GalleryBackend) bool { + backends, err := ListSystemBackends(basePath) + if err != nil { + return false + } + _, exists := backends[backend.GetName()] + return exists + }) if err != nil { return nil, err } @@ -162,7 +174,7 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) return refFile, err } -func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string) ([]T, error) { +func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string, isInstalledCallback func(T) bool) ([]T, error) { var models []T = []T{} if strings.HasSuffix(gallery.URL, ".ref") { @@ -187,15 +199,7 @@ func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath strin // Add gallery to models for _, model := range models { model.SetGallery(gallery) - // we check if the model was already installed by checking if the config file exists - // TODO: (what to do if the model doesn't install a config file?) - // TODO: This is sub-optimal now that the gallery handles both backends and models - we need to abstract this away - if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { - model.SetInstalled(true) - } - if _, err := os.Stat(filepath.Join(basePath, model.GetName())); err == nil { - model.SetInstalled(true) - } + model.SetInstalled(isInstalledCallback(model)) } return models, nil } diff --git a/core/http/routes/ui_backend_gallery.go b/core/http/routes/ui_backend_gallery.go index 8d69b5dad..6b6ba40e3 100644 --- a/core/http/routes/ui_backend_gallery.go +++ b/core/http/routes/ui_backend_gallery.go @@ -223,7 +223,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC return c.SendString(elements.ProgressBar("0")) } - if status.Progress == 100 { + if status.Progress == 100 && status.Processed && status.Message == "completed" { c.Set("HX-Trigger", "done") // this triggers /browse/backend/job/:uid return c.SendString(elements.ProgressBar("100")) } diff --git a/core/http/routes/ui_gallery.go b/core/http/routes/ui_gallery.go index 1cc629cac..d9b0c43d6 100644 --- a/core/http/routes/ui_gallery.go +++ b/core/http/routes/ui_gallery.go @@ -243,7 +243,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo return c.SendString(elements.ProgressBar("0")) } - if status.Progress == 100 { + if status.Progress == 100 && status.Processed && status.Message == "completed" { c.Set("HX-Trigger", "done") // this triggers /browse/job/:uid (which is when the job is done) return c.SendString(elements.ProgressBar("100")) } diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 94c2e13af..a4da4f574 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -256,7 +256,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat return fmt.Errorf("failed to get image %q: %v", url, err) } - return oci.ExtractOCIImage(img, filepath.Dir(filePath), downloadStatus) + return oci.ExtractOCIImage(img, url, filepath.Dir(filePath), downloadStatus) } // Check if the file already exists diff --git a/pkg/oci/image.go b/pkg/oci/image.go index 3efbe189d..e06442a97 100644 --- a/pkg/oci/image.go +++ b/pkg/oci/image.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "os" "runtime" "strconv" "strings" @@ -21,6 +22,7 @@ import ( "github.com/google/go-containerregistry/pkg/v1/mutate" "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "github.com/google/go-containerregistry/pkg/v1/tarball" ) // ref: https://github.com/mudler/luet/blob/master/pkg/helpers/docker/docker.go#L117 @@ -95,31 +97,30 @@ func (pw *progressWriter) Write(p []byte) (int, error) { } // ExtractOCIImage will extract a given targetImage into a given targetDestination -func ExtractOCIImage(img v1.Image, targetDestination string, downloadStatus func(string, string, string, float64)) error { - var reader io.Reader - reader = mutate.Extract(img) +func ExtractOCIImage(img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error { + // Create a temporary tar file + tmpTarFile, err := os.CreateTemp("", "localai-oci-*.tar") + if err != nil { + return fmt.Errorf("failed to create temporary tar file: %v", err) + } + defer os.Remove(tmpTarFile.Name()) + defer tmpTarFile.Close() - if downloadStatus != nil { - var totalSize int64 - layers, err := img.Layers() - if err != nil { - return err - } - for _, layer := range layers { - size, err := layer.Size() - if err != nil { - return err - } - totalSize += size - } - reader = io.TeeReader(reader, &progressWriter{total: totalSize, downloadStatus: downloadStatus}) + // Download the image as tar with progress tracking + err = DownloadOCIImageTar(img, imageRef, tmpTarFile.Name(), downloadStatus) + if err != nil { + return fmt.Errorf("failed to download image tar: %v", err) } - _, err := archive.Apply(context.Background(), - targetDestination, reader, - archive.WithNoSameOwner()) + downloadStatus("Extracting", "", "", 0) - return err + // Extract the tar file to the target destination + err = ExtractOCIImageFromTar(tmpTarFile.Name(), imageRef, targetDestination, downloadStatus) + if err != nil { + return fmt.Errorf("failed to extract image tar: %v", err) + } + + return nil } func ParseImageParts(image string) (tag, repository, dstimage string) { @@ -205,3 +206,164 @@ func GetOCIImageSize(targetImage, targetPlatform string, auth *registrytypes.Aut return size, nil } + +// DownloadOCIImageTar downloads the compressed layers of an image and then creates an uncompressed tar +// This provides accurate size estimation and allows for later extraction +func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error { + // Get layers to calculate total compressed size for estimation + layers, err := img.Layers() + if err != nil { + return fmt.Errorf("failed to get layers: %v", err) + } + + // Calculate total compressed size for progress tracking + var totalCompressedSize int64 + for _, layer := range layers { + size, err := layer.Size() + if err != nil { + return fmt.Errorf("failed to get layer size: %v", err) + } + totalCompressedSize += size + } + + // Create a temporary directory to store the compressed layers + tmpDir, err := os.MkdirTemp("", "localai-oci-layers-*") + if err != nil { + return fmt.Errorf("failed to create temporary directory: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Download all compressed layers with progress tracking + var downloadedLayers []v1.Layer + var downloadedSize int64 + + // Extract image name from the reference for display + imageName := imageRef + for i, layer := range layers { + layerSize, err := layer.Size() + if err != nil { + return fmt.Errorf("failed to get layer size: %v", err) + } + + // Create a temporary file for this layer + layerFile := fmt.Sprintf("%s/layer-%d.tar.gz", tmpDir, i) + file, err := os.Create(layerFile) + if err != nil { + return fmt.Errorf("failed to create layer file: %v", err) + } + + // Create progress writer for this layer + var writer io.Writer = file + if downloadStatus != nil { + writer = io.MultiWriter(file, &progressWriter{ + total: totalCompressedSize, + fileName: fmt.Sprintf("Downloading %d/%d %s", i+1, len(layers), imageName), + downloadStatus: downloadStatus, + }) + } + + // Download the compressed layer + layerReader, err := layer.Compressed() + if err != nil { + file.Close() + return fmt.Errorf("failed to get compressed layer: %v", err) + } + + _, err = io.Copy(writer, layerReader) + file.Close() + if err != nil { + return fmt.Errorf("failed to download layer %d: %v", i, err) + } + + // Load the downloaded layer + downloadedLayer, err := tarball.LayerFromFile(layerFile) + if err != nil { + return fmt.Errorf("failed to load downloaded layer: %v", err) + } + + downloadedLayers = append(downloadedLayers, downloadedLayer) + downloadedSize += layerSize + } + + // Create a local image from the downloaded layers + localImg, err := mutate.AppendLayers(img, downloadedLayers...) + if err != nil { + return fmt.Errorf("failed to create local image: %v", err) + } + + // Now extract the uncompressed tar from the local image + tarFile, err := os.Create(tarFilePath) + if err != nil { + return fmt.Errorf("failed to create tar file: %v", err) + } + defer tarFile.Close() + + // Extract uncompressed tar from local image + extractReader := mutate.Extract(localImg) + _, err = io.Copy(tarFile, extractReader) + if err != nil { + return fmt.Errorf("failed to extract uncompressed tar: %v", err) + } + + return nil +} + +// ExtractOCIImageFromTar extracts an image from a previously downloaded tar file +func ExtractOCIImageFromTar(tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error { + // Open the tar file + tarFile, err := os.Open(tarFilePath) + if err != nil { + return fmt.Errorf("failed to open tar file: %v", err) + } + defer tarFile.Close() + + // Get file size for progress tracking + fileInfo, err := tarFile.Stat() + if err != nil { + return fmt.Errorf("failed to get file info: %v", err) + } + + var reader io.Reader = tarFile + if downloadStatus != nil { + reader = io.TeeReader(tarFile, &progressWriter{ + total: fileInfo.Size(), + fileName: fmt.Sprintf("Extracting %s", imageRef), + downloadStatus: downloadStatus, + }) + } + + // Extract the tar file + _, err = archive.Apply(context.Background(), + targetDestination, reader, + archive.WithNoSameOwner()) + + return err +} + +// GetOCIImageUncompressedSize returns the total uncompressed size of an image +func GetOCIImageUncompressedSize(targetImage, targetPlatform string, auth *registrytypes.AuthConfig, t http.RoundTripper) (int64, error) { + var totalSize int64 + var img v1.Image + var err error + + img, err = GetImage(targetImage, targetPlatform, auth, t) + if err != nil { + return totalSize, err + } + + layers, err := img.Layers() + if err != nil { + return totalSize, err + } + + for _, layer := range layers { + // Use compressed size as an approximation since uncompressed size is not directly available + size, err := layer.Size() + if err != nil { + return totalSize, err + } + totalSize += size + } + + return totalSize, nil +} diff --git a/pkg/oci/image_test.go b/pkg/oci/image_test.go index 3fc31c20a..1e59d762f 100644 --- a/pkg/oci/image_test.go +++ b/pkg/oci/image_test.go @@ -30,7 +30,7 @@ var _ = Describe("OCI", func() { Expect(err).NotTo(HaveOccurred()) defer os.RemoveAll(dir) - err = ExtractOCIImage(img, dir, nil) + err = ExtractOCIImage(img, imageName, dir, nil) Expect(err).NotTo(HaveOccurred()) }) })