diff --git a/core/gallery/backends.go b/core/gallery/backends.go index 0da7a68ac..7515514f9 100644 --- a/core/gallery/backends.go +++ b/core/gallery/backends.go @@ -162,7 +162,7 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func return fmt.Errorf("failed to create backend path %q: %v", backendPath, err) } - if err := oci.ExtractOCIImage(img, config.URI, backendPath, downloadStatus); err != nil { + if err := oci.ExtractOCIImage(img, backendPath, downloadStatus); err != nil { return fmt.Errorf("failed to extract image %q: %v", config.URI, err) } @@ -246,15 +246,6 @@ func ListSystemBackends(basePath string) (map[string]string, error) { for _, backend := range backends { if backend.IsDir() { runFile := filepath.Join(basePath, backend.Name(), runFile) - // Skip if runfile and metadata file don't exist - if _, err := os.Stat(runFile); os.IsNotExist(err) { - continue - } - metadataFilePath := filepath.Join(basePath, backend.Name(), metadataFile) - if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) { - continue - } - backendsNames[backend.Name()] = runFile // Check for alias in metadata diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index dbd895932..2759f081f 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -121,12 +121,7 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) (Galler // Get models from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryElements[*GalleryModel](gallery, basePath, func(model *GalleryModel) bool { - if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { - return true - } - return false - }) + galleryModels, err := getGalleryElements[*GalleryModel](gallery, basePath) if err != nil { return nil, err } @@ -142,14 +137,7 @@ func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryElem // Get models from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryElements[*GalleryBackend](gallery, basePath, func(backend *GalleryBackend) bool { - backends, err := ListSystemBackends(basePath) - if err != nil { - return false - } - _, exists := backends[backend.GetName()] - return exists - }) + galleryModels, err := getGalleryElements[*GalleryBackend](gallery, basePath) if err != nil { return nil, err } @@ -174,7 +162,7 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) return refFile, err } -func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string, isInstalledCallback func(T) bool) ([]T, error) { +func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string) ([]T, error) { var models []T = []T{} if strings.HasSuffix(gallery.URL, ".ref") { @@ -199,7 +187,15 @@ func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath strin // Add gallery to models for _, model := range models { model.SetGallery(gallery) - model.SetInstalled(isInstalledCallback(model)) + // we check if the model was already installed by checking if the config file exists + // TODO: (what to do if the model doesn't install a config file?) + // TODO: This is sub-optimal now that the gallery handles both backends and models - we need to abstract this away + if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { + model.SetInstalled(true) + } + if _, err := os.Stat(filepath.Join(basePath, model.GetName())); err == nil { + model.SetInstalled(true) + } } return models, nil } diff --git a/core/http/routes/ui_backend_gallery.go b/core/http/routes/ui_backend_gallery.go index 6b6ba40e3..8d69b5dad 100644 --- a/core/http/routes/ui_backend_gallery.go +++ b/core/http/routes/ui_backend_gallery.go @@ -223,7 +223,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC return c.SendString(elements.ProgressBar("0")) } - if status.Progress == 100 && status.Processed && status.Message == "completed" { + if status.Progress == 100 { c.Set("HX-Trigger", "done") // this triggers /browse/backend/job/:uid return c.SendString(elements.ProgressBar("100")) } diff --git a/core/http/routes/ui_gallery.go b/core/http/routes/ui_gallery.go index d9b0c43d6..1cc629cac 100644 --- a/core/http/routes/ui_gallery.go +++ b/core/http/routes/ui_gallery.go @@ -243,7 +243,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo return c.SendString(elements.ProgressBar("0")) } - if status.Progress == 100 && status.Processed && status.Message == "completed" { + if status.Progress == 100 { c.Set("HX-Trigger", "done") // this triggers /browse/job/:uid (which is when the job is done) return c.SendString(elements.ProgressBar("100")) } diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index a4da4f574..94c2e13af 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -256,7 +256,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat return fmt.Errorf("failed to get image %q: %v", url, err) } - return oci.ExtractOCIImage(img, url, filepath.Dir(filePath), downloadStatus) + return oci.ExtractOCIImage(img, filepath.Dir(filePath), downloadStatus) } // Check if the file already exists diff --git a/pkg/oci/image.go b/pkg/oci/image.go index e06442a97..3efbe189d 100644 --- a/pkg/oci/image.go +++ b/pkg/oci/image.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "os" "runtime" "strconv" "strings" @@ -22,7 +21,6 @@ import ( "github.com/google/go-containerregistry/pkg/v1/mutate" "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/remote/transport" - "github.com/google/go-containerregistry/pkg/v1/tarball" ) // ref: https://github.com/mudler/luet/blob/master/pkg/helpers/docker/docker.go#L117 @@ -97,30 +95,31 @@ func (pw *progressWriter) Write(p []byte) (int, error) { } // ExtractOCIImage will extract a given targetImage into a given targetDestination -func ExtractOCIImage(img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error { - // Create a temporary tar file - tmpTarFile, err := os.CreateTemp("", "localai-oci-*.tar") - if err != nil { - return fmt.Errorf("failed to create temporary tar file: %v", err) - } - defer os.Remove(tmpTarFile.Name()) - defer tmpTarFile.Close() +func ExtractOCIImage(img v1.Image, targetDestination string, downloadStatus func(string, string, string, float64)) error { + var reader io.Reader + reader = mutate.Extract(img) - // Download the image as tar with progress tracking - err = DownloadOCIImageTar(img, imageRef, tmpTarFile.Name(), downloadStatus) - if err != nil { - return fmt.Errorf("failed to download image tar: %v", err) + if downloadStatus != nil { + var totalSize int64 + layers, err := img.Layers() + if err != nil { + return err + } + for _, layer := range layers { + size, err := layer.Size() + if err != nil { + return err + } + totalSize += size + } + reader = io.TeeReader(reader, &progressWriter{total: totalSize, downloadStatus: downloadStatus}) } - downloadStatus("Extracting", "", "", 0) + _, err := archive.Apply(context.Background(), + targetDestination, reader, + archive.WithNoSameOwner()) - // Extract the tar file to the target destination - err = ExtractOCIImageFromTar(tmpTarFile.Name(), imageRef, targetDestination, downloadStatus) - if err != nil { - return fmt.Errorf("failed to extract image tar: %v", err) - } - - return nil + return err } func ParseImageParts(image string) (tag, repository, dstimage string) { @@ -206,164 +205,3 @@ func GetOCIImageSize(targetImage, targetPlatform string, auth *registrytypes.Aut return size, nil } - -// DownloadOCIImageTar downloads the compressed layers of an image and then creates an uncompressed tar -// This provides accurate size estimation and allows for later extraction -func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error { - // Get layers to calculate total compressed size for estimation - layers, err := img.Layers() - if err != nil { - return fmt.Errorf("failed to get layers: %v", err) - } - - // Calculate total compressed size for progress tracking - var totalCompressedSize int64 - for _, layer := range layers { - size, err := layer.Size() - if err != nil { - return fmt.Errorf("failed to get layer size: %v", err) - } - totalCompressedSize += size - } - - // Create a temporary directory to store the compressed layers - tmpDir, err := os.MkdirTemp("", "localai-oci-layers-*") - if err != nil { - return fmt.Errorf("failed to create temporary directory: %v", err) - } - defer os.RemoveAll(tmpDir) - - // Download all compressed layers with progress tracking - var downloadedLayers []v1.Layer - var downloadedSize int64 - - // Extract image name from the reference for display - imageName := imageRef - for i, layer := range layers { - layerSize, err := layer.Size() - if err != nil { - return fmt.Errorf("failed to get layer size: %v", err) - } - - // Create a temporary file for this layer - layerFile := fmt.Sprintf("%s/layer-%d.tar.gz", tmpDir, i) - file, err := os.Create(layerFile) - if err != nil { - return fmt.Errorf("failed to create layer file: %v", err) - } - - // Create progress writer for this layer - var writer io.Writer = file - if downloadStatus != nil { - writer = io.MultiWriter(file, &progressWriter{ - total: totalCompressedSize, - fileName: fmt.Sprintf("Downloading %d/%d %s", i+1, len(layers), imageName), - downloadStatus: downloadStatus, - }) - } - - // Download the compressed layer - layerReader, err := layer.Compressed() - if err != nil { - file.Close() - return fmt.Errorf("failed to get compressed layer: %v", err) - } - - _, err = io.Copy(writer, layerReader) - file.Close() - if err != nil { - return fmt.Errorf("failed to download layer %d: %v", i, err) - } - - // Load the downloaded layer - downloadedLayer, err := tarball.LayerFromFile(layerFile) - if err != nil { - return fmt.Errorf("failed to load downloaded layer: %v", err) - } - - downloadedLayers = append(downloadedLayers, downloadedLayer) - downloadedSize += layerSize - } - - // Create a local image from the downloaded layers - localImg, err := mutate.AppendLayers(img, downloadedLayers...) - if err != nil { - return fmt.Errorf("failed to create local image: %v", err) - } - - // Now extract the uncompressed tar from the local image - tarFile, err := os.Create(tarFilePath) - if err != nil { - return fmt.Errorf("failed to create tar file: %v", err) - } - defer tarFile.Close() - - // Extract uncompressed tar from local image - extractReader := mutate.Extract(localImg) - _, err = io.Copy(tarFile, extractReader) - if err != nil { - return fmt.Errorf("failed to extract uncompressed tar: %v", err) - } - - return nil -} - -// ExtractOCIImageFromTar extracts an image from a previously downloaded tar file -func ExtractOCIImageFromTar(tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error { - // Open the tar file - tarFile, err := os.Open(tarFilePath) - if err != nil { - return fmt.Errorf("failed to open tar file: %v", err) - } - defer tarFile.Close() - - // Get file size for progress tracking - fileInfo, err := tarFile.Stat() - if err != nil { - return fmt.Errorf("failed to get file info: %v", err) - } - - var reader io.Reader = tarFile - if downloadStatus != nil { - reader = io.TeeReader(tarFile, &progressWriter{ - total: fileInfo.Size(), - fileName: fmt.Sprintf("Extracting %s", imageRef), - downloadStatus: downloadStatus, - }) - } - - // Extract the tar file - _, err = archive.Apply(context.Background(), - targetDestination, reader, - archive.WithNoSameOwner()) - - return err -} - -// GetOCIImageUncompressedSize returns the total uncompressed size of an image -func GetOCIImageUncompressedSize(targetImage, targetPlatform string, auth *registrytypes.AuthConfig, t http.RoundTripper) (int64, error) { - var totalSize int64 - var img v1.Image - var err error - - img, err = GetImage(targetImage, targetPlatform, auth, t) - if err != nil { - return totalSize, err - } - - layers, err := img.Layers() - if err != nil { - return totalSize, err - } - - for _, layer := range layers { - // Use compressed size as an approximation since uncompressed size is not directly available - size, err := layer.Size() - if err != nil { - return totalSize, err - } - totalSize += size - } - - return totalSize, nil -} diff --git a/pkg/oci/image_test.go b/pkg/oci/image_test.go index 1e59d762f..3fc31c20a 100644 --- a/pkg/oci/image_test.go +++ b/pkg/oci/image_test.go @@ -30,7 +30,7 @@ var _ = Describe("OCI", func() { Expect(err).NotTo(HaveOccurred()) defer os.RemoveAll(dir) - err = ExtractOCIImage(img, imageName, dir, nil) + err = ExtractOCIImage(img, dir, nil) Expect(err).NotTo(HaveOccurred()) }) })