feat(backends): add system backend, refactor (#6059)

- Add a system backend path
- Refactor and consolidate system information in system state
- Use system state in all the components to figure out the system paths
  to used whenever needed
- Refactor BackendConfig -> ModelConfig. This was otherway misleading as
  now we do have a backend configuration which is not the model config.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-08-14 19:38:26 +02:00
committed by GitHub
parent 253b7537dc
commit 089efe05fd
85 changed files with 999 additions and 652 deletions

View File

@@ -7,7 +7,7 @@ import (
)
type Application struct {
backendLoader *config.BackendConfigLoader
backendLoader *config.ModelConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
templatesEvaluator *templates.Evaluator
@@ -15,14 +15,14 @@ type Application struct {
func newApplication(appConfig *config.ApplicationConfig) *Application {
return &Application{
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
modelLoader: model.NewModelLoader(appConfig.ModelPath, appConfig.SingleBackend),
backendLoader: config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath),
modelLoader: model.NewModelLoader(appConfig.SystemState, appConfig.SingleBackend),
applicationConfig: appConfig,
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
templatesEvaluator: templates.NewEvaluator(appConfig.SystemState.Model.ModelsPath),
}
}
func (a *Application) BackendLoader() *config.BackendConfigLoader {
func (a *Application) BackendLoader() *config.ModelConfigLoader {
return a.backendLoader
}

View File

@@ -20,7 +20,7 @@ func New(opts ...config.AppOption) (*Application, error) {
options := config.NewApplicationConfig(opts...)
application := newApplication(options)
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.SystemState.Model.ModelsPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
caps, err := xsysinfo.CPUCapabilities()
if err == nil {
@@ -35,10 +35,11 @@ func New(opts ...config.AppOption) (*Application, error) {
}
// Make sure directories exists
if options.ModelPath == "" {
return nil, fmt.Errorf("options.ModelPath cannot be empty")
if options.SystemState.Model.ModelsPath == "" {
return nil, fmt.Errorf("models path cannot be empty")
}
err = os.MkdirAll(options.ModelPath, 0750)
err = os.MkdirAll(options.SystemState.Model.ModelsPath, 0750)
if err != nil {
return nil, fmt.Errorf("unable to create ModelPath: %q", err)
}
@@ -55,50 +56,50 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models")
}
for _, backend := range options.ExternalBackends {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.BackendsPath, nil, backend, "", ""); err != nil {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend")
}
}
configLoaderOpts := options.ToConfigLoaderOptions()
if err := application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
if err := application.BackendLoader().LoadModelConfigsFromPath(options.SystemState.Model.ModelsPath, configLoaderOpts...); err != nil {
log.Error().Err(err).Msg("error loading config files")
}
if err := gallery.RegisterBackends(options.BackendsPath, application.ModelLoader()); err != nil {
if err := gallery.RegisterBackends(options.SystemState, application.ModelLoader()); err != nil {
log.Error().Err(err).Msg("error registering external backends")
}
if options.ConfigFile != "" {
if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
if err := application.BackendLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
log.Error().Err(err).Msg("error loading config file")
}
}
if err := application.BackendLoader().Preload(options.ModelPath); err != nil {
if err := application.BackendLoader().Preload(options.SystemState.Model.ModelsPath); err != nil {
log.Error().Err(err).Msg("error downloading models")
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
if err := services.ApplyGalleryFromString(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
return nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
if err := services.ApplyGalleryFromFile(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
return nil, err
}
}
if options.Debug {
for _, v := range application.BackendLoader().GetAllBackendConfigs() {
for _, v := range application.BackendLoader().GetAllModelsConfigs() {
log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
}
}
@@ -131,7 +132,7 @@ func New(opts ...config.AppOption) (*Application, error) {
if options.LoadToMemory != nil && !options.SingleBackend {
for _, m := range options.LoadToMemory {
cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
cfg, err := application.BackendLoader().LoadModelConfigFileByNameDefaultOptions(m, options)
if err != nil {
return nil, err
}

View File

@@ -13,9 +13,9 @@ func Detection(
sourceFile string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
modelConfig config.ModelConfig,
) (*proto.DetectResponse, error) {
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(modelConfig, appConfig)
detectionModel, err := loader.Load(opts...)
if err != nil {
return nil, err

View File

@@ -9,9 +9,9 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err := loader.Load(opts...)
if err != nil {
@@ -23,7 +23,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
switch model := inferenceModel.(type) {
case grpc.Backend:
fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath)
if len(tokens) > 0 {
embeds := []int32{}

View File

@@ -7,9 +7,9 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err := loader.Load(
opts...,
)
@@ -27,12 +27,12 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
CLIPSkip: int32(modelConfig.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: backendConfig.Diffusers.EnableParameters,
EnableParameters: modelConfig.Diffusers.EnableParameters,
RefImages: refImages,
})
return err

View File

@@ -35,7 +35,7 @@ type TokenUsage struct {
TimingTokenGeneration float64
}
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, cl *config.BackendConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, c.Name, loader.ModelPath, o.BackendsPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err
@@ -201,7 +201,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
func Finetune(config config.BackendConfig, input, prediction string) string {
func Finetune(config config.ModelConfig, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}

View File

@@ -12,14 +12,14 @@ import (
var _ = Describe("LLM tests", func() {
Context("Finetune LLM output", func() {
var (
testConfig config.BackendConfig
testConfig config.ModelConfig
input string
prediction string
result string
)
BeforeEach(func() {
testConfig = config.BackendConfig{
testConfig = config.ModelConfig{
PredictionOptions: schema.PredictionOptions{
Echo: false,
},

View File

@@ -11,7 +11,7 @@ import (
"github.com/rs/zerolog/log"
)
func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
name := c.Name
if name == "" {
name = c.Model
@@ -58,7 +58,7 @@ func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...
return append(defOpts, opts...)
}
func getSeed(c config.BackendConfig) int32 {
func getSeed(c config.ModelConfig) int32 {
var seed int32 = config.RAND_SEED
if c.Seed != nil {
@@ -72,7 +72,7 @@ func getSeed(c config.BackendConfig) int32 {
return seed
}
func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
@@ -195,7 +195,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
}
}
func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions {
promptCachePath := ""
if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath)

View File

@@ -9,8 +9,8 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
opts := ModelOptions(backendConfig, appConfig)
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
opts := ModelOptions(modelConfig, appConfig)
rerankModel, err := loader.Load(opts...)
if err != nil {
return nil, err

View File

@@ -21,10 +21,10 @@ func SoundGeneration(
sourceDivisor *int32,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
modelConfig config.ModelConfig,
) (string, *proto.Result, error) {
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(modelConfig, appConfig)
soundGenModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
@@ -49,7 +49,7 @@ func SoundGeneration(
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
Text: text,
Model: backendConfig.Model,
Model: modelConfig.Model,
Dst: filePath,
Sample: doSample,
Duration: duration,

View File

@@ -13,9 +13,9 @@ func TokenMetrics(
modelFile string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {
modelConfig config.ModelConfig) (*proto.MetricsResponse, error) {
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
opts := ModelOptions(modelConfig, appConfig, model.WithModel(modelFile))
model, err := loader.Load(opts...)
if err != nil {
return nil, err

View File

@@ -7,19 +7,19 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)
func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
var inferenceModel grpc.Backend
var err error
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err = loader.Load(opts...)
if err != nil {
return schema.TokenizeResponse{}, err
}
defer loader.Close()
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath)
predictOptions.Prompt = s
// tokenize the string

View File

@@ -12,13 +12,13 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
if backendConfig.Backend == "" {
backendConfig.Backend = model.WhisperBackend
if modelConfig.Backend == "" {
modelConfig.Backend = model.WhisperBackend
}
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(modelConfig, appConfig)
transcriptionModel, err := ml.Load(opts...)
if err != nil {
@@ -34,7 +34,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
Dst: audio,
Language: language,
Translate: translate,
Threads: uint32(*backendConfig.Threads),
Threads: uint32(*modelConfig.Threads),
})
if err != nil {
return nil, err

View File

@@ -19,9 +19,9 @@ func ModelTTS(
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
modelConfig config.ModelConfig,
) (string, *proto.Result, error) {
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(modelConfig, appConfig)
ttsModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
@@ -29,7 +29,7 @@ func ModelTTS(
defer loader.Close()
if ttsModel == nil {
return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
return "", nil, fmt.Errorf("could not load tts model %q", modelConfig.Model)
}
audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
@@ -47,14 +47,14 @@ func ModelTTS(
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, backendConfig.Model)
mp := filepath.Join(loader.ModelPath, modelConfig.Model)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil {
return "", nil, err
}
modelPath = mp
} else {
modelPath = backendConfig.Model // skip this step if it fails?????
modelPath = modelConfig.Model // skip this step if it fails?????
}
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{

View File

@@ -13,8 +13,8 @@ func VAD(request *schema.VADRequest,
ctx context.Context,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*schema.VADResponse, error) {
opts := ModelOptions(backendConfig, appConfig)
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
opts := ModelOptions(modelConfig, appConfig)
vadModel, err := ml.Load(opts...)
if err != nil {
return nil, err

View File

@@ -7,9 +7,9 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err := loader.Load(
opts...,
)

View File

@@ -6,6 +6,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/startup"
@@ -14,8 +15,9 @@ import (
)
type BackendsCMDFlags struct {
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
}
type BackendsList struct {
@@ -48,7 +50,15 @@ func (bl *BackendsList) Run(ctx *cliContext.Context) error {
log.Error().Err(err).Msg("unable to load galleries")
}
backends, err := gallery.AvailableBackends(galleries, bl.BackendsPath)
systemState, err := system.GetSystemState(
system.WithBackendSystemPath(bl.BackendsSystemPath),
system.WithBackendPath(bl.BackendsPath),
)
if err != nil {
return err
}
backends, err := gallery.AvailableBackends(galleries, systemState)
if err != nil {
return err
}
@@ -68,6 +78,14 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
log.Error().Err(err).Msg("unable to load galleries")
}
systemState, err := system.GetSystemState(
system.WithBackendSystemPath(bi.BackendsSystemPath),
system.WithBackendPath(bi.BackendsPath),
)
if err != nil {
return err
}
progressBar := progressbar.NewOptions(
1000,
progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", bi.BackendArgs)),
@@ -82,7 +100,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}
}
err := startup.InstallExternalBackends(galleries, bi.BackendsPath, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
err = startup.InstallExternalBackends(galleries, systemState, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}
@@ -94,7 +112,15 @@ func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error {
for _, backendName := range bu.BackendArgs {
log.Info().Str("backend", backendName).Msg("uninstalling backend")
err := gallery.DeleteBackendFromSystem(bu.BackendsPath, backendName)
systemState, err := system.GetSystemState(
system.WithBackendSystemPath(bu.BackendsSystemPath),
system.WithBackendPath(bu.BackendsPath),
)
if err != nil {
return err
}
err = gallery.DeleteBackendFromSystem(systemState, backendName)
if err != nil {
return err
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
"github.com/schollz/progressbar/v3"
)
@@ -45,7 +46,14 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error {
log.Error().Err(err).Msg("unable to load galleries")
}
models, err := gallery.AvailableGalleryModels(galleries, ml.ModelsPath)
systemState, err := system.GetSystemState(
system.WithModelPath(ml.ModelsPath),
system.WithBackendPath(ml.BackendsPath),
)
if err != nil {
return err
}
models, err := gallery.AvailableGalleryModels(galleries, systemState)
if err != nil {
return err
}
@@ -60,6 +68,15 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error {
}
func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
systemState, err := system.GetSystemState(
system.WithModelPath(mi.ModelsPath),
system.WithBackendPath(mi.BackendsPath),
)
if err != nil {
return err
}
var galleries []config.Gallery
if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil {
log.Error().Err(err).Msg("unable to load galleries")
@@ -86,7 +103,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
}
}
//startup.InstallModels()
models, err := gallery.AvailableGalleryModels(galleries, mi.ModelsPath)
models, err := gallery.AvailableGalleryModels(galleries, systemState)
if err != nil {
return err
}
@@ -94,7 +111,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
modelURI := downloader.URI(modelName)
if !modelURI.LooksLikeOCI() {
model := gallery.FindGalleryElement(models, modelName, mi.ModelsPath)
model := gallery.FindGalleryElement(models, modelName)
if model == nil {
log.Error().Str("model", modelName).Msg("model not found")
return err
@@ -108,7 +125,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
}
err = startup.InstallModels(galleries, backendGalleries, mi.ModelsPath, mi.BackendsPath, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
err = startup.InstallModels(galleries, backendGalleries, systemState, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil {
return err
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
@@ -22,6 +23,7 @@ type RunCMD struct {
ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"`
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"`
UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
@@ -77,12 +79,20 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
os.MkdirAll(r.BackendsPath, 0750)
os.MkdirAll(r.ModelsPath, 0750)
systemState, err := system.GetSystemState(
system.WithBackendSystemPath(r.BackendsSystemPath),
system.WithModelPath(r.ModelsPath),
system.WithBackendPath(r.BackendsPath),
)
if err != nil {
return err
}
opts := []config.AppOption{
config.WithConfigFile(r.ModelsConfigFile),
config.WithJSONStringPreload(r.PreloadModels),
config.WithYAMLConfigPreload(r.PreloadModelsConfig),
config.WithModelPath(r.ModelsPath),
config.WithBackendsPath(r.BackendsPath),
config.WithSystemState(systemState),
config.WithContextSize(r.ContextSize),
config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel),
config.WithGeneratedContentDir(r.GeneratedContentPath),

View File

@@ -12,6 +12,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
@@ -56,6 +57,13 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
}
text := strings.Join(t.Text, " ")
systemState, err := system.GetSystemState(
system.WithModelPath(t.ModelsPath),
)
if err != nil {
return err
}
externalBackends := make(map[string]string)
// split ":" to get backend name and the uri
for _, v := range t.ExternalGRPCBackends {
@@ -66,12 +74,12 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
}
opts := &config.ApplicationConfig{
ModelPath: t.ModelsPath,
SystemState: systemState,
Context: context.Background(),
GeneratedContentDir: outputDir,
ExternalGRPCBackends: externalBackends,
}
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
ml := model.NewModelLoader(systemState, opts.SingleBackend)
defer func() {
err := ml.StopAllGRPC()
@@ -80,7 +88,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
}
}()
options := config.BackendConfig{}
options := config.ModelConfig{}
options.SetDefaults()
options.Backend = t.Backend
options.Model = t.Model

View File

@@ -9,6 +9,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
@@ -24,18 +25,24 @@ type TranscriptCMD struct {
}
func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
systemState, err := system.GetSystemState(
system.WithModelPath(t.ModelsPath),
)
if err != nil {
return err
}
opts := &config.ApplicationConfig{
ModelPath: t.ModelsPath,
Context: context.Background(),
SystemState: systemState,
Context: context.Background(),
}
cl := config.NewBackendConfigLoader(t.ModelsPath)
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
if err := cl.LoadBackendConfigsFromPath(t.ModelsPath); err != nil {
cl := config.NewModelConfigLoader(t.ModelsPath)
ml := model.NewModelLoader(systemState, opts.SingleBackend)
if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil {
return err
}
c, exists := cl.GetBackendConfig(t.Model)
c, exists := cl.GetModelConfig(t.Model)
if !exists {
return errors.New("model not found")
}

View File

@@ -11,6 +11,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
@@ -34,12 +35,20 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
text := strings.Join(t.Text, " ")
systemState, err := system.GetSystemState(
system.WithModelPath(t.ModelsPath),
)
if err != nil {
return err
}
opts := &config.ApplicationConfig{
ModelPath: t.ModelsPath,
SystemState: systemState,
Context: context.Background(),
GeneratedContentDir: outputDir,
}
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
ml := model.NewModelLoader(systemState, opts.SingleBackend)
defer func() {
err := ml.StopAllGRPC()
@@ -48,7 +57,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
}
}()
options := config.BackendConfig{}
options := config.ModelConfig{}
options.SetDefaults()
options.Backend = t.Backend
options.Model = t.Model

View File

@@ -17,6 +17,7 @@ import (
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/oci"
"github.com/mudler/LocalAI/pkg/system"
)
type UtilCMD struct {
@@ -108,6 +109,14 @@ func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error {
}
func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error {
systemState, err := system.GetSystemState(
system.WithModelPath(hfscmd.ModelsPath),
)
if err != nil {
return err
}
log.Info().Msg("LocalAI Security Scanner - This is BEST EFFORT functionality! Currently limited to huggingface models!")
if len(hfscmd.ToScan) == 0 {
log.Info().Msg("Checking all installed models against galleries")
@@ -116,7 +125,7 @@ func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error {
log.Error().Err(err).Msg("unable to load galleries")
}
err := gallery.SafetyScanGalleryModels(galleries, hfscmd.ModelsPath)
err := gallery.SafetyScanGalleryModels(galleries, systemState)
if err == nil {
log.Info().Msg("No security warnings were detected for your installed models. Please note that this is a BEST EFFORT tool, and all issues may not be detected.")
} else {
@@ -150,17 +159,17 @@ func (uhcmd *UsecaseHeuristicCMD) Run(ctx *cliContext.Context) error {
log.Error().Msg("ModelsPath is a required parameter")
return fmt.Errorf("model path is a required parameter")
}
bcl := config.NewBackendConfigLoader(uhcmd.ModelsPath)
err := bcl.LoadBackendConfig(uhcmd.ConfigName)
bcl := config.NewModelConfigLoader(uhcmd.ModelsPath)
err := bcl.ReadModelConfig(uhcmd.ConfigName)
if err != nil {
log.Error().Err(err).Str("ConfigName", uhcmd.ConfigName).Msg("error while loading backend")
return err
}
bc, exists := bcl.GetBackendConfig(uhcmd.ConfigName)
bc, exists := bcl.GetModelConfig(uhcmd.ConfigName)
if !exists {
log.Error().Str("ConfigName", uhcmd.ConfigName).Msg("ConfigName not found")
}
for name, uc := range config.GetAllBackendConfigUsecases() {
for name, uc := range config.GetAllModelConfigUsecases() {
if bc.HasUsecases(uc) {
log.Info().Str("Usecase", name)
}

View File

@@ -1,8 +1,9 @@
package worker
type WorkerFlags struct {
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
}
type Worker struct {

View File

@@ -10,6 +10,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
@@ -21,20 +22,19 @@ const (
llamaCPPRPCBinaryName = "llama-cpp-rpc-server"
)
func findLLamaCPPBackend(backendSystemPath string) (string, error) {
backends, err := gallery.ListSystemBackends(backendSystemPath)
func findLLamaCPPBackend(systemState *system.SystemState) (string, error) {
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
log.Warn().Msgf("Failed listing system backends: %s", err)
return "", err
}
log.Debug().Msgf("System backends: %v", backends)
backendPath := ""
backend, ok := backends.Get("llama-cpp")
if !ok {
return "", errors.New("llama-cpp backend not found, install it first")
}
backendPath = filepath.Dir(backend.RunFile)
backendPath := filepath.Dir(backend.RunFile)
if backendPath == "" {
return "", errors.New("llama-cpp backend not found, install it first")
@@ -54,7 +54,14 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
return fmt.Errorf("usage: local-ai worker llama-cpp-rpc -- <llama-rpc-server-args>")
}
grpcProcess, err := findLLamaCPPBackend(r.BackendsPath)
systemState, err := system.GetSystemState(
system.WithBackendPath(r.BackendsPath),
system.WithBackendSystemPath(r.BackendsSystemPath),
)
if err != nil {
return err
}
grpcProcess, err := findLLamaCPPBackend(systemState)
if err != nil {
return err
}

View File

@@ -10,6 +10,7 @@ import (
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/pkg/system"
"github.com/phayes/freeport"
"github.com/rs/zerolog/log"
)
@@ -25,6 +26,14 @@ type P2P struct {
func (r *P2P) Run(ctx *cliContext.Context) error {
systemState, err := system.GetSystemState(
system.WithBackendPath(r.BackendsPath),
system.WithBackendSystemPath(r.BackendsSystemPath),
)
if err != nil {
return err
}
// Check if the token is set
// as we always need it.
if r.Token == "" {
@@ -60,7 +69,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
for {
log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port)
grpcProcess, err := findLLamaCPPBackend(r.BackendsPath)
grpcProcess, err := findLLamaCPPBackend(systemState)
if err != nil {
log.Error().Err(err).Msg("Failed to find llama-cpp-rpc-server")
return

View File

@@ -6,6 +6,7 @@ import (
"regexp"
"time"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/rs/zerolog/log"
)
@@ -13,8 +14,7 @@ import (
type ApplicationConfig struct {
Context context.Context
ConfigFile string
ModelPath string
BackendsPath string
SystemState *system.SystemState
ExternalBackends []string
UploadLimitMB, Threads, ContextSize int
F16 bool
@@ -86,15 +86,9 @@ func WithModelsURL(urls ...string) AppOption {
}
}
func WithModelPath(path string) AppOption {
func WithSystemState(state *system.SystemState) AppOption {
return func(o *ApplicationConfig) {
o.ModelPath = path
}
}
func WithBackendsPath(path string) AppOption {
return func(o *ApplicationConfig) {
o.BackendsPath = path
o.SystemState = state
}
}
@@ -379,7 +373,7 @@ func (o *ApplicationConfig) ToConfigLoaderOptions() []ConfigLoaderOption {
LoadOptionDebug(o.Debug),
LoadOptionF16(o.F16),
LoadOptionThreads(o.Threads),
ModelPath(o.ModelPath),
ModelPath(o.SystemState.Model.ModelsPath),
}
}

View File

@@ -24,20 +24,20 @@ type TTSConfig struct {
AudioPath string `yaml:"audio_path"`
}
type BackendConfig struct {
type ModelConfig struct {
schema.PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
F16 *bool `yaml:"f16"`
Threads *int `yaml:"threads"`
Debug *bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings *bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases"`
KnownUsecases *BackendConfigUsecases `yaml:"-"`
Pipeline Pipeline `yaml:"pipeline"`
F16 *bool `yaml:"f16"`
Threads *int `yaml:"threads"`
Debug *bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings *bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases"`
KnownUsecases *ModelConfigUsecases `yaml:"-"`
Pipeline Pipeline `yaml:"pipeline"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
@@ -217,18 +217,18 @@ type TemplateConfig struct {
ReplyPrefix string `yaml:"reply_prefix"`
}
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
type BCAlias BackendConfig
func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
type BCAlias ModelConfig
var aux BCAlias
if err := value.Decode(&aux); err != nil {
return err
}
*c = BackendConfig(aux)
*c = ModelConfig(aux)
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
// Make sure the usecases are valid, we rewrite with what we identified
c.KnownUsecaseStrings = []string{}
for k, usecase := range GetAllBackendConfigUsecases() {
for k, usecase := range GetAllModelConfigUsecases() {
if c.HasUsecases(usecase) {
c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
}
@@ -236,25 +236,25 @@ func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
return nil
}
func (c *BackendConfig) SetFunctionCallString(s string) {
func (c *ModelConfig) SetFunctionCallString(s string) {
c.functionCallString = s
}
func (c *BackendConfig) SetFunctionCallNameString(s string) {
func (c *ModelConfig) SetFunctionCallNameString(s string) {
c.functionCallNameString = s
}
func (c *BackendConfig) ShouldUseFunctions() bool {
func (c *ModelConfig) ShouldUseFunctions() bool {
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}
func (c *BackendConfig) ShouldCallSpecificFunction() bool {
func (c *ModelConfig) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0
}
// MMProjFileName returns the filename of the MMProj file
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) MMProjFileName() string {
func (c *ModelConfig) MMProjFileName() string {
uri := downloader.URI(c.MMProj)
if uri.LooksLikeURL() {
f, _ := uri.FilenameFromUrl()
@@ -264,19 +264,19 @@ func (c *BackendConfig) MMProjFileName() string {
return c.MMProj
}
func (c *BackendConfig) IsMMProjURL() bool {
func (c *ModelConfig) IsMMProjURL() bool {
uri := downloader.URI(c.MMProj)
return uri.LooksLikeURL()
}
func (c *BackendConfig) IsModelURL() bool {
func (c *ModelConfig) IsModelURL() bool {
uri := downloader.URI(c.Model)
return uri.LooksLikeURL()
}
// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) ModelFileName() string {
func (c *ModelConfig) ModelFileName() string {
uri := downloader.URI(c.Model)
if uri.LooksLikeURL() {
f, _ := uri.FilenameFromUrl()
@@ -286,7 +286,7 @@ func (c *BackendConfig) ModelFileName() string {
return c.Model
}
func (c *BackendConfig) FunctionToCall() string {
func (c *ModelConfig) FunctionToCall() string {
if c.functionCallNameString != "" &&
c.functionCallNameString != "none" && c.functionCallNameString != "auto" {
return c.functionCallNameString
@@ -295,7 +295,7 @@ func (c *BackendConfig) FunctionToCall() string {
return c.functionCallString
}
func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
lo := &LoadOptions{}
lo.Apply(opts...)
@@ -411,7 +411,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
guessDefaultsFromFile(cfg, lo.modelPath, ctx)
}
func (c *BackendConfig) Validate() bool {
func (c *ModelConfig) Validate() bool {
downloadedFileNames := []string{}
for _, f := range c.DownloadFiles {
downloadedFileNames = append(downloadedFileNames, f.Filename)
@@ -438,34 +438,34 @@ func (c *BackendConfig) Validate() bool {
return true
}
func (c *BackendConfig) HasTemplate() bool {
func (c *ModelConfig) HasTemplate() bool {
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
}
type BackendConfigUsecases int
type ModelConfigUsecases int
const (
FLAG_ANY BackendConfigUsecases = 0b000000000000
FLAG_CHAT BackendConfigUsecases = 0b000000000001
FLAG_COMPLETION BackendConfigUsecases = 0b000000000010
FLAG_EDIT BackendConfigUsecases = 0b000000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000000001000
FLAG_RERANK BackendConfigUsecases = 0b000000010000
FLAG_IMAGE BackendConfigUsecases = 0b000000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b000001000000
FLAG_TTS BackendConfigUsecases = 0b000010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b000100000000
FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000
FLAG_VAD BackendConfigUsecases = 0b010000000000
FLAG_VIDEO BackendConfigUsecases = 0b100000000000
FLAG_DETECTION BackendConfigUsecases = 0b1000000000000
FLAG_ANY ModelConfigUsecases = 0b000000000000
FLAG_CHAT ModelConfigUsecases = 0b000000000001
FLAG_COMPLETION ModelConfigUsecases = 0b000000000010
FLAG_EDIT ModelConfigUsecases = 0b000000000100
FLAG_EMBEDDINGS ModelConfigUsecases = 0b000000001000
FLAG_RERANK ModelConfigUsecases = 0b000000010000
FLAG_IMAGE ModelConfigUsecases = 0b000000100000
FLAG_TRANSCRIPT ModelConfigUsecases = 0b000001000000
FLAG_TTS ModelConfigUsecases = 0b000010000000
FLAG_SOUND_GENERATION ModelConfigUsecases = 0b000100000000
FLAG_TOKENIZE ModelConfigUsecases = 0b001000000000
FLAG_VAD ModelConfigUsecases = 0b010000000000
FLAG_VIDEO ModelConfigUsecases = 0b100000000000
FLAG_DETECTION ModelConfigUsecases = 0b1000000000000
// Common Subsets
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
FLAG_LLM ModelConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
return map[string]BackendConfigUsecases{
func GetAllModelConfigUsecases() map[string]ModelConfigUsecases {
return map[string]ModelConfigUsecases{
"FLAG_ANY": FLAG_ANY,
"FLAG_CHAT": FLAG_CHAT,
"FLAG_COMPLETION": FLAG_COMPLETION,
@@ -488,12 +488,12 @@ func stringToFlag(s string) string {
return "FLAG_" + strings.ToUpper(s)
}
func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
func GetUsecasesFromYAML(input []string) *ModelConfigUsecases {
if len(input) == 0 {
return nil
}
result := FLAG_ANY
flags := GetAllBackendConfigUsecases()
flags := GetAllModelConfigUsecases()
for _, str := range input {
flag, exists := flags[stringToFlag(str)]
if exists {
@@ -503,8 +503,8 @@ func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
return &result
}
// HasUsecases examines a BackendConfig and determines which endpoints have a chance of success.
func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
// HasUsecases examines a ModelConfig and determines which endpoints have a chance of success.
func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool {
if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
return true
}
@@ -514,7 +514,7 @@ func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
// GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at.
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
if (u & FLAG_CHAT) == FLAG_CHAT {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
return false

View File

@@ -2,11 +2,11 @@ package config
import "regexp"
type BackendConfigFilterFn func(string, *BackendConfig) bool
type ModelConfigFilterFn func(string, *ModelConfig) bool
func NoFilterFn(_ string, _ *BackendConfig) bool { return true }
func NoFilterFn(_ string, _ *ModelConfig) bool { return true }
func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) {
if filter == "" {
return NoFilterFn, nil
}
@@ -14,7 +14,7 @@ func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
if err != nil {
return nil, err
}
return func(name string, config *BackendConfig) bool {
return func(name string, config *ModelConfig) bool {
if config != nil {
return rxp.MatchString(config.Name)
}
@@ -22,11 +22,11 @@ func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
}, nil
}
func BuildUsecaseFilterFn(usecases BackendConfigUsecases) BackendConfigFilterFn {
func BuildUsecaseFilterFn(usecases ModelConfigUsecases) ModelConfigFilterFn {
if usecases == FLAG_ANY {
return NoFilterFn
}
return func(name string, config *BackendConfig) bool {
return func(name string, config *ModelConfig) bool {
if config == nil {
return false // TODO: Potentially make this a param, for now, no known usecase to include
}

View File

@@ -18,15 +18,15 @@ import (
"gopkg.in/yaml.v3"
)
type BackendConfigLoader struct {
configs map[string]BackendConfig
type ModelConfigLoader struct {
configs map[string]ModelConfig
modelPath string
sync.Mutex
}
func NewBackendConfigLoader(modelPath string) *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
func NewModelConfigLoader(modelPath string) *ModelConfigLoader {
return &ModelConfigLoader{
configs: make(map[string]ModelConfig),
modelPath: modelPath,
}
}
@@ -77,14 +77,14 @@ func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
}
// TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []BackendConfig or BackendConfig
func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
c := &[]*BackendConfig{}
func readMultipleModelConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*ModelConfig, error) {
c := &[]*ModelConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err)
return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot read config file %q: %w", file, err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err)
return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot unmarshal config file %q: %w", file, err)
}
for _, cc := range *c {
@@ -94,17 +94,17 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption)
return *c, nil
}
func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
func readModelConfigFromFile(file string, opts ...ConfigLoaderOption) (*ModelConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)
c := &BackendConfig{}
c := &ModelConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err)
return nil, fmt.Errorf("readModelConfigFromFile cannot read config file %q: %w", file, err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err)
return nil, fmt.Errorf("readModelConfigFromFile cannot unmarshal config file %q: %w", file, err)
}
c.SetDefaults(opts...)
@@ -112,10 +112,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen
}
// Load a config file for a model
func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
func (bcl *ModelConfigLoader) LoadModelConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*ModelConfig, error) {
// Load a config file if present after the model name
cfg := &BackendConfig{
cfg := &ModelConfig{
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: modelName,
@@ -123,19 +123,19 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
},
}
cfgExisting, exists := bcl.GetBackendConfig(modelName)
cfgExisting, exists := bcl.GetModelConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := bcl.LoadBackendConfig(
if err := bcl.ReadModelConfig(
modelConfig, opts...,
); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = bcl.GetBackendConfig(modelName)
cfgExisting, exists = bcl.GetModelConfig(modelName)
if exists {
cfg = &cfgExisting
}
@@ -147,20 +147,20 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
return cfg, nil
}
func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
func (bcl *ModelConfigLoader) LoadModelConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*ModelConfig, error) {
return bcl.LoadModelConfigFileByName(modelName, appConfig.SystemState.Model.ModelsPath,
LoadOptionDebug(appConfig.Debug),
LoadOptionThreads(appConfig.Threads),
LoadOptionContextSize(appConfig.ContextSize),
LoadOptionF16(appConfig.F16),
ModelPath(appConfig.ModelPath))
ModelPath(appConfig.SystemState.Model.ModelsPath))
}
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
c, err := readMultipleBackendConfigsFromFile(file, opts...)
c, err := readMultipleModelConfigsFromFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
@@ -173,12 +173,12 @@ func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string
return nil
}
func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
c, err := readBackendConfigFromFile(file, opts...)
c, err := readModelConfigFromFile(file, opts...)
if err != nil {
return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err)
return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err)
}
if c.Validate() {
@@ -190,17 +190,17 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa
return nil
}
func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
func (bcl *ModelConfigLoader) GetModelConfig(m string) (ModelConfig, bool) {
bcl.Lock()
defer bcl.Unlock()
v, exists := bcl.configs[m]
return v, exists
}
func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig {
bcl.Lock()
defer bcl.Unlock()
var res []BackendConfig
var res []ModelConfig
for _, v := range bcl.configs {
res = append(res, v)
}
@@ -212,10 +212,10 @@ func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
return res
}
func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFilterFn) []BackendConfig {
func (bcl *ModelConfigLoader) GetModelConfigsByFilter(filter ModelConfigFilterFn) []ModelConfig {
bcl.Lock()
defer bcl.Unlock()
var res []BackendConfig
var res []ModelConfig
if filter == nil {
filter = NoFilterFn
@@ -232,14 +232,14 @@ func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFi
return res
}
func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) {
func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
bcl.Lock()
defer bcl.Unlock()
delete(bcl.configs, m)
}
// Preload prepare models if they are not local but url or huggingface repositories
func (bcl *BackendConfigLoader) Preload(modelPath string) error {
func (bcl *ModelConfigLoader) Preload(modelPath string) error {
bcl.Lock()
defer bcl.Unlock()
@@ -330,15 +330,15 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
return nil
}
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
// LoadModelConfigsFromPath reads all the configurations of the models from a path
// (non-recursive)
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err)
return fmt.Errorf("LoadModelConfigsFromPath cannot read directory '%s': %w", path, err)
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
@@ -354,9 +354,9 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...
strings.HasPrefix(file.Name(), ".") {
continue
}
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
c, err := readModelConfigFromFile(filepath.Join(path, file.Name()), opts...)
if err != nil {
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file")
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadModelConfigsFromPath cannot read config file")
continue
}
if c.Validate() {

View File

@@ -25,7 +25,7 @@ known_usecases:
- COMPLETION
`)
Expect(err).ToNot(HaveOccurred())
config, err := readBackendConfigFromFile(tmp.Name())
config, err := readModelConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
Expect(config.Validate()).To(BeFalse())
@@ -41,7 +41,7 @@ backend: "foo-bar"
parameters:
model: "foo-bar"`)
Expect(err).ToNot(HaveOccurred())
config, err := readBackendConfigFromFile(tmp.Name())
config, err := readModelConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
@@ -58,7 +58,7 @@ parameters:
defer os.Remove(tmp.Name())
_, err = io.Copy(tmp, resp.Body)
Expect(err).To(BeNil())
config, err = readBackendConfigFromFile(tmp.Name())
config, err = readModelConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
@@ -68,12 +68,12 @@ parameters:
})
It("Properly handles backend usecase matching", func() {
a := BackendConfig{
a := ModelConfig{
Name: "a",
}
Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially.
b := BackendConfig{
b := ModelConfig{
Name: "b",
Backend: "stablediffusion",
}
@@ -81,7 +81,7 @@ parameters:
Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue())
Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse())
c := BackendConfig{
c := ModelConfig{
Name: "c",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
@@ -93,7 +93,7 @@ parameters:
Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue())
d := BackendConfig{
d := ModelConfig{
Name: "d",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
@@ -107,7 +107,7 @@ parameters:
Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue())
trueValue := true
e := BackendConfig{
e := ModelConfig{
Name: "e",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
@@ -122,7 +122,7 @@ parameters:
Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse())
Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue())
f := BackendConfig{
f := ModelConfig{
Name: "f",
Backend: "piper",
}
@@ -130,7 +130,7 @@ parameters:
Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue())
Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse())
g := BackendConfig{
g := ModelConfig{
Name: "g",
Backend: "whisper",
}
@@ -138,7 +138,7 @@ parameters:
Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse())
h := BackendConfig{
h := ModelConfig{
Name: "h",
Backend: "transformers-musicgen",
}
@@ -148,7 +148,7 @@ parameters:
Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue())
knownUsecases := FLAG_CHAT | FLAG_COMPLETION
i := BackendConfig{
i := ModelConfig{
Name: "i",
Backend: "whisper",
// Earlier test checks parsing, this just needs to set final values

View File

@@ -16,7 +16,7 @@ var _ = Describe("Test cases for config related functions", func() {
Context("Test Read configuration functions", func() {
configFile = os.Getenv("CONFIG_FILE")
It("Test readConfigFile", func() {
config, err := readMultipleBackendConfigsFromFile(configFile)
config, err := readMultipleModelConfigsFromFile(configFile)
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
@@ -26,11 +26,11 @@ var _ = Describe("Test cases for config related functions", func() {
It("Test LoadConfigs", func() {
bcl := NewBackendConfigLoader(os.Getenv("MODELS_PATH"))
err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH"))
err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH"))
Expect(err).To(BeNil())
configs := bcl.GetAllBackendConfigs()
configs := bcl.GetAllModelsConfigs()
loadedModelNames := []string{}
for _, v := range configs {
loadedModelNames = append(loadedModelNames, v.Name)
@@ -51,10 +51,10 @@ var _ = Describe("Test cases for config related functions", func() {
It("Test new loadconfig", func() {
bcl := NewBackendConfigLoader(os.Getenv("MODELS_PATH"))
err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH"))
err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH"))
Expect(err).To(BeNil())
configs := bcl.GetAllBackendConfigs()
configs := bcl.GetAllModelsConfigs()
loadedModelNames := []string{}
for _, v := range configs {
loadedModelNames = append(loadedModelNames, v.Name)
@@ -90,14 +90,14 @@ options:
err = os.WriteFile(modelFile, []byte(model), 0644)
Expect(err).ToNot(HaveOccurred())
err = bcl.LoadBackendConfigsFromPath(tmpdir)
err = bcl.LoadModelConfigsFromPath(tmpdir)
Expect(err).ToNot(HaveOccurred())
configs = bcl.GetAllBackendConfigs()
configs = bcl.GetAllModelsConfigs()
Expect(len(configs)).ToNot(Equal(totalModels))
loadedModelNames = []string{}
var testModel BackendConfig
var testModel ModelConfig
for _, v := range configs {
loadedModelNames = append(loadedModelNames, v.Name)
if v.Name == "test-model" {

View File

@@ -146,7 +146,7 @@ var knownTemplates = map[string]familyType{
`{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}`: Mistral03,
}
func guessGGUFFromFile(cfg *BackendConfig, f *gguf.GGUFFile, defaultCtx int) {
func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
if defaultCtx == 0 && cfg.ContextSize == nil {
ctxSize := f.EstimateLLaMACppRun().ContextSize

View File

@@ -8,7 +8,7 @@ import (
"github.com/rs/zerolog/log"
)
func guessDefaultsFromFile(cfg *BackendConfig, modelPath string, defaultCtx int) {
func guessDefaultsFromFile(cfg *ModelConfig, modelPath string, defaultCtx int) {
if os.Getenv("LOCALAI_DISABLE_GUESSING") == "true" {
log.Debug().Msgf("guessDefaultsFromFile: %s", "guessing disabled with LOCALAI_DISABLE_GUESSING")
return

View File

@@ -59,10 +59,10 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
}
// Installs a model from the gallery
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, basePath string, downloadStatus func(string, string, string, float64), force bool) error {
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force {
// check if we already have the backend installed
backends, err := ListSystemBackends(basePath)
backends, err := ListSystemBackends(systemState)
if err != nil {
return err
}
@@ -77,12 +77,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
log.Debug().Interface("galleries", galleries).Str("name", name).Msg("Installing backend from gallery")
backends, err := AvailableBackends(galleries, basePath)
backends, err := AvailableBackends(galleries, systemState)
if err != nil {
return err
}
backend := FindGalleryElement(backends, name, basePath)
backend := FindGalleryElement(backends, name)
if backend == nil {
return fmt.Errorf("no backend found with name %q", name)
}
@@ -99,12 +99,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
// Then, let's install the best backend
if err := InstallBackend(basePath, bestBackend, downloadStatus); err != nil {
if err := InstallBackend(systemState, bestBackend, downloadStatus); err != nil {
return err
}
// we need now to create a path for the meta backend, with the alias to the installed ones so it can be used to remove it
metaBackendPath := filepath.Join(basePath, name)
metaBackendPath := filepath.Join(systemState.Backend.BackendsPath, name)
if err := os.MkdirAll(metaBackendPath, 0750); err != nil {
return fmt.Errorf("failed to create meta backend path %q: %v", metaBackendPath, err)
}
@@ -124,12 +124,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
return nil
}
return InstallBackend(basePath, backend, downloadStatus)
return InstallBackend(systemState, backend, downloadStatus)
}
func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
func InstallBackend(systemState *system.SystemState, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750)
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil {
return fmt.Errorf("failed to create base path: %v", err)
}
@@ -139,7 +139,7 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func
}
name := config.Name
backendPath := filepath.Join(basePath, name)
backendPath := filepath.Join(systemState.Backend.BackendsPath, name)
err = os.MkdirAll(backendPath, 0750)
if err != nil {
return fmt.Errorf("failed to create base path: %v", err)
@@ -188,14 +188,28 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func
return nil
}
func DeleteBackendFromSystem(basePath string, name string) error {
backendDirectory := filepath.Join(basePath, name)
func DeleteBackendFromSystem(systemState *system.SystemState, name string) error {
backends, err := ListSystemBackends(systemState)
if err != nil {
return err
}
backend, ok := backends.Get(name)
if !ok {
return fmt.Errorf("backend %q not found", name)
}
if backend.IsSystem {
return fmt.Errorf("system backend %q cannot be deleted", name)
}
backendDirectory := filepath.Join(systemState.Backend.BackendsPath, name)
// check if the backend dir exists
if _, err := os.Stat(backendDirectory); os.IsNotExist(err) {
// if doesn't exist, it might be an alias, so we need to check if we have a matching alias in
// all the backends in the basePath
backends, err := os.ReadDir(basePath)
backends, err := os.ReadDir(systemState.Backend.BackendsPath)
if err != nil {
return err
}
@@ -203,12 +217,12 @@ func DeleteBackendFromSystem(basePath string, name string) error {
for _, backend := range backends {
if backend.IsDir() {
metadata, err := readBackendMetadata(filepath.Join(basePath, backend.Name()))
metadata, err := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, backend.Name()))
if err != nil {
return err
}
if metadata != nil && metadata.Alias == name {
backendDirectory = filepath.Join(basePath, backend.Name())
backendDirectory = filepath.Join(systemState.Backend.BackendsPath, backend.Name())
foundBackend = true
break
}
@@ -228,7 +242,7 @@ func DeleteBackendFromSystem(basePath string, name string) error {
}
if metadata != nil && metadata.MetaBackendFor != "" {
metaBackendDirectory := filepath.Join(basePath, metadata.MetaBackendFor)
metaBackendDirectory := filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor)
log.Debug().Str("backendDirectory", metaBackendDirectory).Msg("Deleting meta backend")
if _, err := os.Stat(metaBackendDirectory); os.IsNotExist(err) {
return fmt.Errorf("meta backend %q not found", metadata.MetaBackendFor)
@@ -243,6 +257,7 @@ type SystemBackend struct {
Name string
RunFile string
IsMeta bool
IsSystem bool
Metadata *BackendMetadata
}
@@ -266,30 +281,51 @@ func (b SystemBackends) GetAll() []SystemBackend {
return backends
}
func ListSystemBackends(basePath string) (SystemBackends, error) {
potentialBackends, err := os.ReadDir(basePath)
func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error) {
potentialBackends, err := os.ReadDir(systemState.Backend.BackendsPath)
if err != nil {
return nil, err
}
backends := make(SystemBackends)
systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath)
if err == nil {
// system backends are special, they are provided by the system and not managed by LocalAI
for _, systemBackend := range systemBackends {
if systemBackend.IsDir() {
systemBackendRunFile := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile)
if _, err := os.Stat(systemBackendRunFile); err == nil {
backends[systemBackend.Name()] = SystemBackend{
Name: systemBackend.Name(),
RunFile: filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile),
IsMeta: false,
IsSystem: true,
Metadata: nil,
}
}
}
}
} else {
log.Warn().Err(err).Msg("Failed to read system backends, but that's ok, we will just use the backends managed by LocalAI")
}
for _, potentialBackend := range potentialBackends {
if potentialBackend.IsDir() {
potentialBackendRunFile := filepath.Join(basePath, potentialBackend.Name(), runFile)
potentialBackendRunFile := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), runFile)
var metadata *BackendMetadata
// If metadata file does not exist, we just use the directory name
// and we do not fill the other metadata (such as potential backend Aliases)
metadataFilePath := filepath.Join(basePath, potentialBackend.Name(), metadataFile)
metadataFilePath := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), metadataFile)
if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) {
metadata = &BackendMetadata{
Name: potentialBackend.Name(),
}
} else {
// Check for alias in metadata
metadata, err = readBackendMetadata(filepath.Join(basePath, potentialBackend.Name()))
metadata, err = readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name()))
if err != nil {
return nil, err
}
@@ -323,7 +359,7 @@ func ListSystemBackends(basePath string) (SystemBackends, error) {
if metadata.MetaBackendFor != "" {
backends[metadata.Name] = SystemBackend{
Name: metadata.Name,
RunFile: filepath.Join(basePath, metadata.MetaBackendFor, runFile),
RunFile: filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile),
IsMeta: true,
Metadata: metadata,
}
@@ -334,8 +370,8 @@ func ListSystemBackends(basePath string) (SystemBackends, error) {
return backends, nil
}
func RegisterBackends(basePath string, modelLoader *model.ModelLoader) error {
backends, err := ListSystemBackends(basePath)
func RegisterBackends(systemState *system.SystemState, modelLoader *model.ModelLoader) error {
backends, err := ListSystemBackends(systemState)
if err != nil {
return err
}

View File

@@ -43,13 +43,21 @@ var _ = Describe("Gallery Backends", func() {
Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() {
err := InstallBackendFromGallery(galleries, nil, "non-existent", tempDir, nil, true)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackendFromGallery(galleries, systemState, "non-existent", nil, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
})
It("should install backend from gallery", func() {
err := InstallBackendFromGallery(galleries, nil, "test-backend", tempDir, nil, true)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackendFromGallery(galleries, systemState, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
})
@@ -220,26 +228,32 @@ var _ = Describe("Gallery Backends", func() {
Expect(err).NotTo(HaveOccurred())
// Test with NVIDIA system state
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
nvidiaSystemState := &system.SystemState{
GPUVendor: "nvidia",
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
Expect(metaBackendPath).To(BeADirectory())
metaBackendPath = filepath.Join(tempDir, "meta-backend", "metadata.json")
Expect(metaBackendPath).To(BeARegularFile())
concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
Expect(concreteBackendPath).To(BeADirectory())
allBackends, err := ListSystemBackends(tempDir)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
Expect(allBackends.Exists("meta-backend")).To(BeTrue())
Expect(allBackends.Exists("nvidia-backend")).To(BeTrue())
allBackends, err := ListSystemBackends(systemState)
Expect(err).NotTo(HaveOccurred())
Expect(allBackends).To(HaveKey("meta-backend"))
Expect(allBackends).To(HaveKey("nvidia-backend"))
// Delete meta backend by name
err = DeleteBackendFromSystem(tempDir, "meta-backend")
err = DeleteBackendFromSystem(systemState, "meta-backend")
Expect(err).NotTo(HaveOccurred())
// Verify meta backend directory is deleted
@@ -294,8 +308,12 @@ var _ = Describe("Gallery Backends", func() {
Expect(err).NotTo(HaveOccurred())
// Test with NVIDIA system state
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
nvidiaSystemState := &system.SystemState{
GPUVendor: "nvidia",
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -304,19 +322,22 @@ var _ = Describe("Gallery Backends", func() {
concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
Expect(concreteBackendPath).To(BeADirectory())
allBackends, err := ListSystemBackends(tempDir)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
Expect(allBackends.Exists("meta-backend")).To(BeTrue())
Expect(allBackends.Exists("nvidia-backend")).To(BeTrue())
backend, ok := allBackends.Get("meta-backend")
Expect(ok).To(BeTrue())
Expect(backend.Metadata.MetaBackendFor).To(Equal("nvidia-backend"))
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh")))
Expect(backend.IsMeta).To(BeTrue())
allBackends, err := ListSystemBackends(systemState)
Expect(err).NotTo(HaveOccurred())
Expect(allBackends).To(HaveKey("meta-backend"))
Expect(allBackends).To(HaveKey("nvidia-backend"))
mback, exists := allBackends.Get("meta-backend")
Expect(exists).To(BeTrue())
Expect(mback.IsMeta).To(BeTrue())
Expect(mback.Metadata.MetaBackendFor).To(Equal("nvidia-backend"))
// Delete meta backend by name
err = DeleteBackendFromSystem(tempDir, "meta-backend")
err = DeleteBackendFromSystem(systemState, "meta-backend")
Expect(err).NotTo(HaveOccurred())
// Verify meta backend directory is deleted
@@ -371,8 +392,12 @@ var _ = Describe("Gallery Backends", func() {
Expect(err).NotTo(HaveOccurred())
// Test with NVIDIA system state
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
nvidiaSystemState := &system.SystemState{
GPUVendor: "nvidia",
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -381,16 +406,21 @@ var _ = Describe("Gallery Backends", func() {
concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
Expect(concreteBackendPath).To(BeADirectory())
allBackends, err := ListSystemBackends(tempDir)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
Expect(allBackends.Exists("meta-backend")).To(BeTrue())
Expect(allBackends.Exists("nvidia-backend")).To(BeTrue())
backend, ok := allBackends.Get("meta-backend")
Expect(ok).To(BeTrue())
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh")))
allBackends, err := ListSystemBackends(systemState)
Expect(err).NotTo(HaveOccurred())
Expect(allBackends).To(HaveKey("meta-backend"))
Expect(allBackends).To(HaveKey("nvidia-backend"))
mback, exists := allBackends.Get("meta-backend")
Expect(exists).To(BeTrue())
Expect(mback.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh")))
// Delete meta backend by name
err = DeleteBackendFromSystem(tempDir, "meta-backend")
err = DeleteBackendFromSystem(systemState, "meta-backend")
Expect(err).NotTo(HaveOccurred())
// Verify meta backend directory is deleted
@@ -427,25 +457,28 @@ var _ = Describe("Gallery Backends", func() {
Expect(err).NotTo(HaveOccurred())
// List system backends
backends, err := ListSystemBackends(tempDir)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
backends, err := ListSystemBackends(systemState)
Expect(err).NotTo(HaveOccurred())
metaBackend, exists := backends.Get("meta-backend")
concreteBackendRunFile := filepath.Join(tempDir, "concrete-backend", "run.sh")
// Should include both the meta backend name and concrete backend name
Expect(backends.Exists("meta-backend")).To(BeTrue())
Expect(exists).To(BeTrue())
Expect(backends.Exists("concrete-backend")).To(BeTrue())
// meta-backend should point to concrete-backend
Expect(backends.Exists("meta-backend")).To(BeTrue())
backend, ok := backends.Get("meta-backend")
Expect(ok).To(BeTrue())
Expect(backend.Metadata.MetaBackendFor).To(Equal("concrete-backend"))
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "concrete-backend", "run.sh")))
Expect(backend.IsMeta).To(BeTrue())
// meta-backend should be empty
Expect(metaBackend.IsMeta).To(BeTrue())
Expect(metaBackend.RunFile).To(Equal(concreteBackendRunFile))
// concrete-backend should point to its own run.sh
backend, ok = backends.Get("concrete-backend")
Expect(ok).To(BeTrue())
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "concrete-backend", "run.sh")))
concreteBackend, exists := backends.Get("concrete-backend")
Expect(exists).To(BeTrue())
Expect(concreteBackend.RunFile).To(Equal(concreteBackendRunFile))
})
})
@@ -459,11 +492,80 @@ var _ = Describe("Gallery Backends", func() {
URI: "test-uri",
}
err := InstallBackend(newPath, &backend, nil)
systemState, err := system.GetSystemState(
system.WithBackendPath(newPath),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
})
It("should overwrite existing backend", func() {
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
Skip("Skipping test on darwin/arm64")
}
newPath := filepath.Join(tempDir, "test-backend")
// Create a dummy backend directory
err := os.MkdirAll(newPath, 0750)
Expect(err).NotTo(HaveOccurred())
err = os.WriteFile(filepath.Join(newPath, "metadata.json"), []byte("foo"), 0644)
Expect(err).NotTo(HaveOccurred())
err = os.WriteFile(filepath.Join(newPath, "run.sh"), []byte(""), 0644)
Expect(err).NotTo(HaveOccurred())
backend := GalleryBackend{
Metadata: Metadata{
Name: "test-backend",
},
URI: "quay.io/mudler/tests:localai-backend-test",
Alias: "test-alias",
}
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
Expect(err).ToNot(HaveOccurred())
Expect(string(dat)).ToNot(Equal("foo"))
})
It("should overwrite existing backend", func() {
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
Skip("Skipping test on darwin/arm64")
}
newPath := filepath.Join(tempDir, "test-backend")
// Create a dummy backend directory
err := os.MkdirAll(newPath, 0750)
Expect(err).NotTo(HaveOccurred())
backend := GalleryBackend{
Metadata: Metadata{
Name: "test-backend",
},
URI: "quay.io/mudler/tests:localai-backend-test",
Alias: "test-alias",
}
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
It("should create alias file when specified", func() {
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
Skip("Skipping test on darwin/arm64")
@@ -476,7 +578,11 @@ var _ = Describe("Gallery Backends", func() {
Alias: "test-alias",
}
err := InstallBackend(tempDir, &backend, nil)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
@@ -492,16 +598,14 @@ var _ = Describe("Gallery Backends", func() {
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
// Check that the alias was recognized
backends, err := ListSystemBackends(tempDir)
backends, err := ListSystemBackends(systemState)
Expect(err).ToNot(HaveOccurred())
Expect(backends.Exists("test-alias")).To(BeTrue())
Expect(backends.Exists("test-backend")).To(BeTrue())
b, ok := backends.Get("test-alias")
Expect(ok).To(BeTrue())
Expect(b.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
b, ok = backends.Get("test-backend")
Expect(ok).To(BeTrue())
Expect(b.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
aliasBackend, exists := backends.Get("test-alias")
Expect(exists).To(BeTrue())
Expect(aliasBackend.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
testB, exists := backends.Get("test-backend")
Expect(exists).To(BeTrue())
Expect(testB.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
})
})
@@ -514,13 +618,26 @@ var _ = Describe("Gallery Backends", func() {
err := os.MkdirAll(backendPath, 0750)
Expect(err).NotTo(HaveOccurred())
err = DeleteBackendFromSystem(tempDir, backendName)
err = os.WriteFile(filepath.Join(backendPath, "metadata.json"), []byte("{}"), 0644)
Expect(err).NotTo(HaveOccurred())
err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0644)
Expect(err).NotTo(HaveOccurred())
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = DeleteBackendFromSystem(systemState, backendName)
Expect(err).NotTo(HaveOccurred())
Expect(backendPath).NotTo(BeADirectory())
})
It("should not error when backend doesn't exist", func() {
err := DeleteBackendFromSystem(tempDir, "non-existent")
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = DeleteBackendFromSystem(systemState, "non-existent")
Expect(err).To(HaveOccurred())
})
})
@@ -538,14 +655,17 @@ var _ = Describe("Gallery Backends", func() {
Expect(err).NotTo(HaveOccurred())
}
backends, err := ListSystemBackends(tempDir)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
Expect(backends.GetAll()).To(HaveLen(len(backendNames)))
backends, err := ListSystemBackends(systemState)
Expect(err).NotTo(HaveOccurred())
Expect(backends).To(HaveLen(len(backendNames)))
for _, name := range backendNames {
Expect(backends.Exists(name)).To(BeTrue())
backend, ok := backends.Get(name)
Expect(ok).To(BeTrue())
backend, exists := backends.Get(name)
Expect(exists).To(BeTrue())
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, name, "run.sh")))
}
})
@@ -572,16 +692,23 @@ var _ = Describe("Gallery Backends", func() {
err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0755)
Expect(err).NotTo(HaveOccurred())
backends, err := ListSystemBackends(tempDir)
systemState, err := system.GetSystemState(
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
Expect(backends.Exists(alias)).To(BeTrue())
backend, ok := backends.Get(alias)
Expect(ok).To(BeTrue())
backends, err := ListSystemBackends(systemState)
Expect(err).NotTo(HaveOccurred())
backend, exists := backends.Get(alias)
Expect(exists).To(BeTrue())
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, backendName, "run.sh")))
})
It("should return error when base path doesn't exist", func() {
_, err := ListSystemBackends(filepath.Join(tempDir, "non-existent"))
systemState, err := system.GetSystemState(
system.WithBackendPath("foobardir"),
)
Expect(err).NotTo(HaveOccurred())
_, err = ListSystemBackends(systemState)
Expect(err).To(HaveOccurred())
})
})

View File

@@ -8,6 +8,7 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
)
@@ -89,7 +90,7 @@ func (gm GalleryElements[T]) Paginate(pageNum int, itemsNum int) GalleryElements
return gm[start:end]
}
func FindGalleryElement[T GalleryElement](models []T, name string, basePath string) T {
func FindGalleryElement[T GalleryElement](models []T, name string) T {
var model T
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
@@ -116,13 +117,13 @@ func FindGalleryElement[T GalleryElement](models []T, name string, basePath stri
// List available models
// Models galleries are a list of yaml files that are hosted on a remote server (for example github).
// Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableGalleryModels(galleries []config.Gallery, basePath string) (GalleryElements[*GalleryModel], error) {
func AvailableGalleryModels(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryModel], error) {
var models []*GalleryModel
// Get models from galleries
for _, gallery := range galleries {
galleryModels, err := getGalleryElements[*GalleryModel](gallery, basePath, func(model *GalleryModel) bool {
if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil {
galleryModels, err := getGalleryElements[*GalleryModel](gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool {
if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil {
return true
}
return false
@@ -137,13 +138,13 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) (Galler
}
// List available backends
func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryElements[*GalleryBackend], error) {
var models []*GalleryBackend
func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) {
var backends []*GalleryBackend
// Get models from galleries
// Get backends from galleries
for _, gallery := range galleries {
galleryModels, err := getGalleryElements[*GalleryBackend](gallery, basePath, func(backend *GalleryBackend) bool {
backends, err := ListSystemBackends(basePath)
galleryBackends, err := getGalleryElements[*GalleryBackend](gallery, systemState.Backend.BackendsPath, func(backend *GalleryBackend) bool {
backends, err := ListSystemBackends(systemState)
if err != nil {
return false
}
@@ -152,10 +153,10 @@ func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryElem
if err != nil {
return nil, err
}
models = append(models, galleryModels...)
backends = append(backends, galleryBackends...)
}
return models, nil
return backends, nil
}
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {

View File

@@ -72,7 +72,8 @@ type PromptTemplate struct {
// Installs a model from the gallery
func InstallModelFromGallery(
modelGalleries, backendGalleries []config.Gallery,
name string, basePath, backendBasePath string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
systemState *system.SystemState,
name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
applyModel := func(model *GalleryModel) error {
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
@@ -81,7 +82,7 @@ func InstallModelFromGallery(
if len(model.URL) > 0 {
var err error
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, basePath)
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
@@ -122,19 +123,15 @@ func InstallModelFromGallery(
return err
}
installedModel, err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus, enforceScan)
installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
log.Debug().Msgf("Installed model %q", installedModel.Name)
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
systemState, err := system.GetSystemState()
if err != nil {
return err
}
if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, backendBasePath, downloadStatus, false); err != nil {
if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
@@ -142,12 +139,12 @@ func InstallModelFromGallery(
return nil
}
models, err := AvailableGalleryModels(modelGalleries, basePath)
models, err := AvailableGalleryModels(modelGalleries, systemState)
if err != nil {
return err
}
model := FindGalleryElement(models, name, basePath)
model := FindGalleryElement(models, name)
if model == nil {
return fmt.Errorf("no model found with name %q", name)
}
@@ -155,7 +152,8 @@ func InstallModelFromGallery(
return applyModel(model)
}
func InstallModel(basePath, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.BackendConfig, error) {
func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750)
if err != nil {
@@ -221,7 +219,7 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver
return nil, err
}
backendConfig := lconfig.BackendConfig{}
modelConfig := lconfig.ModelConfig{}
// write config file
if len(configOverrides) != 0 || len(config.ConfigFile) != 0 {
@@ -246,12 +244,12 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver
return nil, fmt.Errorf("failed to marshal updated config YAML: %v", err)
}
err = yaml.Unmarshal(updatedConfigYAML, &backendConfig)
err = yaml.Unmarshal(updatedConfigYAML, &modelConfig)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err)
}
if !backendConfig.Validate() {
if !modelConfig.Validate() {
return nil, fmt.Errorf("failed to validate updated config YAML")
}
@@ -272,7 +270,7 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver
log.Debug().Msgf("Written gallery file %s", modelFile)
return &backendConfig, os.WriteFile(modelFile, data, 0600)
return &modelConfig, os.WriteFile(modelFile, data, 0600)
}
func galleryFileName(name string) string {
@@ -285,21 +283,39 @@ func GetLocalModelConfiguration(basePath string, name string) (*ModelConfig, err
return ReadConfigFile[ModelConfig](galleryFile)
}
func DeleteModelFromSystem(basePath string, name string, additionalFiles []string) error {
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
additionalFiles := []string{}
configFile := filepath.Join(basePath, fmt.Sprintf("%s.yaml", name))
configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name))
if err := utils.VerifyPath(configFile, systemState.Model.ModelsPath); err != nil {
return fmt.Errorf("failed to verify path %s: %w", configFile, err)
}
// Galleryname is the name of the model in this case
dat, err := os.ReadFile(configFile)
if err == nil {
modelConfig := &config.ModelConfig{}
galleryFile := filepath.Join(basePath, galleryFileName(name))
err = yaml.Unmarshal(dat, &modelConfig)
if err != nil {
return err
}
if modelConfig.Model != "" {
additionalFiles = append(additionalFiles, modelConfig.ModelFileName())
}
for _, f := range []string{configFile, galleryFile} {
if err := utils.VerifyPath(f, basePath); err != nil {
return fmt.Errorf("failed to verify path %s: %w", f, err)
if modelConfig.MMProj != "" {
additionalFiles = append(additionalFiles, modelConfig.MMProjFileName())
}
}
var err error
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
galleryFile := filepath.Join(systemState.Model.ModelsPath, galleryFileName(name))
if err := utils.VerifyPath(galleryFile, systemState.Model.ModelsPath); err != nil {
return fmt.Errorf("failed to verify path %s: %w", galleryFile, err)
}
// Delete all the files associated to the model
// read the model config
galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile)
@@ -312,13 +328,19 @@ func DeleteModelFromSystem(basePath string, name string, additionalFiles []strin
// Remove additional files
if galleryconfig != nil {
for _, f := range galleryconfig.Files {
fullPath := filepath.Join(basePath, f.Filename)
fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename)
if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
return fmt.Errorf("failed to verify path %s: %w", fullPath, err)
}
filesToRemove = append(filesToRemove, fullPath)
}
}
for _, f := range additionalFiles {
fullPath := filepath.Join(filepath.Join(basePath, f))
fullPath := filepath.Join(filepath.Join(systemState.Model.ModelsPath, f))
if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
return fmt.Errorf("failed to verify path %s: %w", fullPath, err)
}
filesToRemove = append(filesToRemove, fullPath)
}
@@ -340,8 +362,8 @@ func DeleteModelFromSystem(basePath string, name string, additionalFiles []strin
// This is ***NEVER*** going to be perfect or finished.
// This is a BEST EFFORT function to surface known-vulnerable models to users.
func SafetyScanGalleryModels(galleries []config.Gallery, basePath string) error {
galleryModels, err := AvailableGalleryModels(galleries, basePath)
func SafetyScanGalleryModels(galleries []config.Gallery, systemState *system.SystemState) error {
galleryModels, err := AvailableGalleryModels(galleries, systemState)
if err != nil {
return err
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/system"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v3"
@@ -29,7 +30,11 @@ var _ = Describe("Model test", func() {
defer os.RemoveAll(tempdir)
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
systemState, err := system.GetSystemState(
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -71,15 +76,19 @@ var _ = Describe("Model test", func() {
URL: "file://" + galleryFilePath,
},
}
systemState, err := system.GetSystemState(
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
models, err := AvailableGalleryModels(galleries, tempdir)
models, err := AvailableGalleryModels(galleries, systemState)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Name).To(Equal("bert"))
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(galleries, []config.Gallery{}, "test@bert", tempdir, "", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@@ -90,16 +99,16 @@ var _ = Describe("Model test", func() {
Expect(err).ToNot(HaveOccurred())
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
models, err = AvailableGalleryModels(galleries, tempdir)
models, err = AvailableGalleryModels(galleries, systemState)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Installed).To(BeTrue())
// delete
err = DeleteModelFromSystem(tempdir, "bert", []string{})
err = DeleteModelFromSystem(systemState, "bert")
Expect(err).ToNot(HaveOccurred())
models, err = AvailableGalleryModels(galleries, tempdir)
models, err = AvailableGalleryModels(galleries, systemState)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Installed).To(BeFalse())
@@ -116,7 +125,11 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
systemState, err := system.GetSystemState(
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -132,7 +145,11 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
systemState, err := system.GetSystemState(
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -158,7 +175,11 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
systemState, err := system.GetSystemState(
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred())
})
})

View File

@@ -198,7 +198,7 @@ func API(application *application.Application) (*fiber.App, error) {
}
galleryService := services.NewGalleryService(application.ApplicationConfig(), application.ModelLoader())
err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader(), application.ApplicationConfig().SystemState)
if err != nil {
return nil, err
}

View File

@@ -19,6 +19,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v3"
@@ -320,12 +321,17 @@ var _ = Describe("API test", func() {
},
}
systemState, err := system.GetSystemState(
system.WithBackendPath(backendPath),
system.WithModelPath(modelDir),
)
Expect(err).ToNot(HaveOccurred())
application, err := application.New(
append(commonOpts,
config.WithContext(c),
config.WithSystemState(systemState),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
config.WithBackendsPath(backendPath),
config.WithApiKeys([]string{apiKey}),
)...)
Expect(err).ToNot(HaveOccurred())
@@ -523,13 +529,18 @@ var _ = Describe("API test", func() {
},
}
systemState, err := system.GetSystemState(
system.WithBackendPath(backendPath),
system.WithModelPath(modelDir),
)
Expect(err).ToNot(HaveOccurred())
application, err := application.New(
append(commonOpts,
config.WithContext(c),
config.WithGeneratedContentDir(tmpdir),
config.WithBackendsPath(backendPath),
config.WithSystemState(systemState),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
)...,
)
Expect(err).ToNot(HaveOccurred())
@@ -729,12 +740,17 @@ var _ = Describe("API test", func() {
var err error
systemState, err := system.GetSystemState(
system.WithBackendPath(backendPath),
system.WithModelPath(modelPath),
)
Expect(err).ToNot(HaveOccurred())
application, err := application.New(
append(commonOpts,
config.WithExternalBackend("transformers", os.Getenv("HUGGINGFACE_GRPC")),
config.WithContext(c),
config.WithBackendsPath(backendPath),
config.WithModelPath(modelPath),
config.WithSystemState(systemState),
)...)
Expect(err).ToNot(HaveOccurred())
app, err = API(application)
@@ -960,11 +976,17 @@ var _ = Describe("API test", func() {
c, cancel = context.WithCancel(context.Background())
var err error
systemState, err := system.GetSystemState(
system.WithBackendPath(backendPath),
system.WithModelPath(modelPath),
)
Expect(err).ToNot(HaveOccurred())
application, err := application.New(
append(commonOpts,
config.WithContext(c),
config.WithModelPath(modelPath),
config.WithBackendsPath(backendPath),
config.WithSystemState(systemState),
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
)
Expect(err).ToNot(HaveOccurred())

View File

@@ -15,7 +15,7 @@ import (
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
@@ -23,7 +23,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
return fiber.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}

View File

@@ -17,7 +17,7 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
voiceID := c.Params("voice-id")
@@ -27,7 +27,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
return fiber.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}

View File

@@ -17,7 +17,7 @@ import (
// @Param request body schema.JINARerankRequest true "query params"
// @Success 200 {object} schema.JINARerankResponse "Response"
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
@@ -25,7 +25,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
return fiber.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}

View File

@@ -11,24 +11,27 @@ import (
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
type BackendEndpointService struct {
galleries []config.Gallery
backendPath string
backendApplier *services.GalleryService
galleries []config.Gallery
backendPath string
backendSystemPath string
backendApplier *services.GalleryService
}
type GalleryBackend struct {
ID string `json:"id"`
}
func CreateBackendEndpointService(galleries []config.Gallery, backendPath string, backendApplier *services.GalleryService) BackendEndpointService {
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService {
return BackendEndpointService{
galleries: galleries,
backendPath: backendPath,
backendApplier: backendApplier,
galleries: galleries,
backendPath: systemState.Backend.BackendsPath,
backendSystemPath: systemState.Backend.BackendsSystemPath,
backendApplier: backendApplier,
}
}
@@ -111,9 +114,9 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
// @Summary List all Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends [get]
func (mgs *BackendEndpointService) ListBackendsEndpoint() func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backends, err := gallery.ListSystemBackends(mgs.backendPath)
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
return err
}
@@ -141,9 +144,9 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
// @Summary List all available Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends/available [get]
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint() func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backends, err := gallery.AvailableBackends(mgs.galleries, mgs.backendPath)
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ import (
// @Param request body schema.DetectionRequest true "query params"
// @Success 200 {object} schema.DetectionResponse "Response"
// @Router /v1/detection [post]
func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
@@ -24,7 +24,7 @@ func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, ap
return fiber.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
@@ -26,11 +27,11 @@ type GalleryModel struct {
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
backendGalleries: backendGalleries,
modelPath: modelPath,
modelPath: systemState.Model.ModelsPath,
galleryApplier: galleryApplier,
}
}
@@ -115,10 +116,10 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib
// @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
if err != nil {
log.Error().Err(err).Msg("could not list models from galleries")
return err

View File

@@ -21,7 +21,7 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/tokenMetrics [get]
// @Router /tokenMetrics [get]
func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TokenMetricsRequest)
@@ -37,7 +37,7 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(modelFile, appConfig)
if err != nil {
log.Err(err)

View File

@@ -14,14 +14,14 @@ import (
// @Param request body schema.TokenizeRequest true "Request"
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}

View File

@@ -22,14 +22,14 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/audio/speech [post]
// @Router /tts [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}

View File

@@ -16,14 +16,14 @@ import (
// @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}

View File

@@ -64,7 +64,7 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /video [post]
func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
if !ok || input.Model == "" {
@@ -72,7 +72,7 @@ func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
return fiber.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
log.Error().Msg("Video Endpoint - Invalid Config")
return fiber.ErrBadRequest

View File

@@ -11,12 +11,12 @@ import (
)
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.BackendConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
modelConfigs := cl.GetAllModelsConfigs()
galleryConfigs := map[string]*gallery.ModelConfig{}
for _, m := range backendConfigs {
for _, m := range modelConfigs {
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
if err != nil {
continue
@@ -34,7 +34,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"Models": modelsWithoutConfig,
"ModelsConfig": backendConfigs,
"ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs,
"ApplicationConfig": appConfig,
"ProcessingModels": processingModels,

View File

@@ -27,11 +27,11 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
var id, textContentToReturn string
var created int
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
@@ -66,7 +66,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
})
close(responses)
}
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
result := ""
_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
@@ -183,7 +183,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
extraUsage := c.Get("Extra-Usage", "") != ""
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
@@ -501,7 +501,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
}
}
func handleQuestion(config *config.BackendConfig, cl *config.BackendConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
if len(funcResults) == 0 && result != "" {
log.Debug().Msgf("nothing function results but we had a message from the LLM")

View File

@@ -27,10 +27,10 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
created := int(time.Now().Unix())
process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
@@ -73,7 +73,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
return fiber.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}

View File

@@ -23,7 +23,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/edits [post]
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
@@ -34,7 +34,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}

View File

@@ -21,14 +21,14 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}

View File

@@ -65,7 +65,7 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
@@ -73,7 +73,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
return fiber.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
log.Error().Msg("Image Endpoint - Invalid Config")
return fiber.ErrBadRequest

View File

@@ -11,8 +11,8 @@ import (
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *config.BackendConfig,
bcl *config.BackendConfigLoader,
config *config.ModelConfig,
bcl *config.ModelConfigLoader,
o *config.ApplicationConfig,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),

View File

@@ -12,7 +12,7 @@ import (
// @Summary List and describe the various models available in the API.
// @Success 200 {object} schema.ModelsDataResponse "Response"
// @Router /v1/models [get]
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
// If blank, no filter is applied.
filter := c.Query("filter")

View File

@@ -559,7 +559,7 @@ func sendNotImplemented(c *websocket.Conn, message string) {
sendError(c, "not_implemented", message, "", "event_TODO")
}
func updateTransSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
func updateTransSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
sessionLock.Lock()
defer sessionLock.Unlock()
@@ -589,7 +589,7 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi
}
// Function to update session configurations
func updateSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
func updateSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
sessionLock.Lock()
defer sessionLock.Unlock()
@@ -628,7 +628,7 @@ func updateSession(session *Session, update *types.ClientSession, cl *config.Bac
// handleVAD is a goroutine that listens for audio data from the client,
// runs VAD on the audio data, and commits utterances to the conversation
func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
func handleVAD(cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
vadContext, cancel := context.WithCancel(context.Background())
go func() {
<-done
@@ -742,7 +742,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
}
}
func commitUtterance(ctx context.Context, utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
if len(utt) == 0 {
return
}
@@ -853,7 +853,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS
// TODO: Below needed for normal mode instead of transcription only
// Function to generate a response based on the conversation
// func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
// func generateResponse(config *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
//
// log.Debug().Msg("Generating realtime response...")
//
@@ -1067,7 +1067,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS
// }
// Function to process text response and detect function calls
func processTextResponse(config *config.BackendConfig, session *Session, prompt string) (string, *FunctionCall, error) {
func processTextResponse(config *config.ModelConfig, session *Session, prompt string) (string, *FunctionCall, error) {
// Placeholder implementation
// Replace this with actual model inference logic using session.Model and prompt

View File

@@ -22,14 +22,14 @@ var (
// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
type wrappedModel struct {
TTSConfig *config.BackendConfig
TranscriptionConfig *config.BackendConfig
LLMConfig *config.BackendConfig
TTSConfig *config.ModelConfig
TranscriptionConfig *config.ModelConfig
LLMConfig *config.ModelConfig
TTSClient grpcClient.Backend
TranscriptionClient grpcClient.Backend
LLMClient grpcClient.Backend
VADConfig *config.BackendConfig
VADConfig *config.ModelConfig
VADClient grpcClient.Backend
}
@@ -37,17 +37,17 @@ type wrappedModel struct {
// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model.
// In the future there could be models that accept continous audio input only so this design will be useful for that
type anyToAnyModel struct {
LLMConfig *config.BackendConfig
LLMConfig *config.ModelConfig
LLMClient grpcClient.Backend
VADConfig *config.BackendConfig
VADConfig *config.ModelConfig
VADClient grpcClient.Backend
}
type transcriptOnlyModel struct {
TranscriptionConfig *config.BackendConfig
TranscriptionConfig *config.ModelConfig
TranscriptionClient grpcClient.Backend
VADConfig *config.BackendConfig
VADConfig *config.ModelConfig
VADClient grpcClient.Backend
}
@@ -105,8 +105,8 @@ func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOpti
return m.LLMClient.PredictStream(ctx, in, f)
}
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.BackendConfig, error) {
cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath)
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
@@ -122,7 +122,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf
return nil, nil, fmt.Errorf("failed to load tts model: %w", err)
}
cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath)
cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
@@ -139,17 +139,17 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf
}
return &transcriptOnlyModel{
VADConfig: cfgVAD,
VADClient: VADClient,
VADConfig: cfgVAD,
VADClient: VADClient,
TranscriptionConfig: cfgSST,
TranscriptionClient: transcriptionClient,
}, cfgSST, nil
}
// returns and loads either a wrapped model or a model that support audio-to-audio
func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath)
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
@@ -166,7 +166,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
}
// TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process
cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath)
cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
@@ -185,7 +185,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
// TODO: Decide when we have a real any-to-any model
if false {
cfgAnyToAny, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath)
cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
@@ -212,7 +212,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
log.Debug().Msg("Loading a wrapped model")
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
cfgLLM, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath)
cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
@@ -222,7 +222,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
return nil, fmt.Errorf("failed to validate config: %w", err)
}
cfgTTS, err := cl.LoadBackendConfigFileByName(pipeline.TTS, ml.ModelPath)
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
@@ -232,7 +232,6 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
return nil, fmt.Errorf("failed to validate config: %w", err)
}
opts = backend.ModelOptions(*cfgTTS, appConfig)
ttsClient, err := ml.Load(opts...)
if err != nil {

View File

@@ -24,14 +24,14 @@ import (
// @Param file formData file true "file"
// @Success 200 {object} map[string]string "Response"
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}

View File

@@ -26,16 +26,16 @@ type correlationIDKeyType string
const CorrelationIDKey correlationIDKeyType = "correlationID"
type RequestExtractor struct {
backendConfigLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
modelConfigLoader *config.ModelConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
}
func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
return &RequestExtractor{
backendConfigLoader: backendConfigLoader,
modelLoader: modelLoader,
applicationConfig: applicationConfig,
modelConfigLoader: modelConfigLoader,
modelLoader: modelLoader,
applicationConfig: applicationConfig,
}
}
@@ -59,7 +59,7 @@ func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
if bearer != "" {
exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if err == nil && exists {
model = bearer
}
@@ -81,7 +81,7 @@ func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModel
}
}
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler {
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
@@ -89,7 +89,7 @@ func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn con
return ctx.Next()
}
modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
return ctx.Next()
@@ -129,7 +129,7 @@ func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIR
}
}
cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
if err != nil {
log.Err(err)
@@ -152,7 +152,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
return fiber.ErrBadRequest
}
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
@@ -168,7 +168,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
input.Context = ctxWithCorrelationID
input.Cancel = cancel
err := mergeOpenAIRequestAndBackendConfig(cfg, input)
err := mergeOpenAIRequestAndModelConfig(cfg, input)
if err != nil {
return err
}
@@ -184,7 +184,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
return ctx.Next()
}
func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error {
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
if input.Echo {
config.Echo = input.Echo
}

View File

@@ -11,7 +11,7 @@ import (
func RegisterElevenLabsRoutes(app *fiber.App,
re *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {

View File

@@ -12,7 +12,7 @@ import (
func RegisterJINARoutes(app *fiber.App,
re *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {

View File

@@ -14,7 +14,7 @@ import (
func RegisterLocalAIRoutes(router *fiber.App,
requestExtractor *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService) {
@@ -23,20 +23,23 @@ func RegisterLocalAIRoutes(router *fiber.App,
// LocalAI API endpoints
if !appConfig.DisableGalleryEndpoint {
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.ModelPath, galleryService)
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
backendGalleryEndpointService := localai.CreateBackendEndpointService(appConfig.BackendGalleries, appConfig.BackendsPath, galleryService)
backendGalleryEndpointService := localai.CreateBackendEndpointService(
appConfig.BackendGalleries,
appConfig.SystemState,
galleryService)
router.Post("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint())
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint())
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
}

View File

@@ -15,7 +15,7 @@ import (
)
func RegisterUIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService) {
@@ -65,9 +65,9 @@ func RegisterUIRoutes(app *fiber.App,
}
app.Get("/talk/", func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
if len(backendConfigs) == 0 {
if len(modelConfigs) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
}
@@ -75,8 +75,8 @@ func RegisterUIRoutes(app *fiber.App,
summary := fiber.Map{
"Title": "LocalAI - Talk",
"BaseURL": utils.BaseURL(c),
"ModelsConfig": backendConfigs,
"Model": backendConfigs[0],
"ModelsConfig": modelConfigs,
"Model": modelConfigs[0],
"Version": internal.PrintableVersion(),
}
@@ -86,17 +86,17 @@ func RegisterUIRoutes(app *fiber.App,
})
app.Get("/chat/", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
}
modelThatCanBeUsed := ""
galleryConfigs := map[string]*gallery.ModelConfig{}
for _, m := range backendConfigs {
for _, m := range modelConfigs {
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
if err != nil {
continue
@@ -106,7 +106,7 @@ func RegisterUIRoutes(app *fiber.App,
title := "LocalAI - Chat"
for _, b := range backendConfigs {
for _, b := range modelConfigs {
if b.HasUsecases(config.FLAG_CHAT) {
modelThatCanBeUsed = b.Name
title = "LocalAI - Chat with " + modelThatCanBeUsed
@@ -119,7 +119,7 @@ func RegisterUIRoutes(app *fiber.App,
"BaseURL": utils.BaseURL(c),
"ModelsWithoutConfig": modelsWithoutConfig,
"GalleryConfig": galleryConfigs,
"ModelsConfig": backendConfigs,
"ModelsConfig": modelConfigs,
"Model": modelThatCanBeUsed,
"Version": internal.PrintableVersion(),
}
@@ -130,12 +130,12 @@ func RegisterUIRoutes(app *fiber.App,
// Show the Chat page
app.Get("/chat/:model", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
galleryConfigs := map[string]*gallery.ModelConfig{}
for _, m := range backendConfigs {
for _, m := range modelConfigs {
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
if err != nil {
continue
@@ -146,7 +146,7 @@ func RegisterUIRoutes(app *fiber.App,
summary := fiber.Map{
"Title": "LocalAI - Chat with " + c.Params("model"),
"BaseURL": utils.BaseURL(c),
"ModelsConfig": backendConfigs,
"ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": c.Params("model"),
@@ -158,13 +158,13 @@ func RegisterUIRoutes(app *fiber.App,
})
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + c.Params("model"),
"BaseURL": utils.BaseURL(c),
"ModelsConfig": backendConfigs,
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": c.Params("model"),
"Version": internal.PrintableVersion(),
@@ -175,10 +175,10 @@ func RegisterUIRoutes(app *fiber.App,
})
app.Get("/text2image/", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
}
@@ -186,7 +186,7 @@ func RegisterUIRoutes(app *fiber.App,
modelThatCanBeUsed := ""
title := "LocalAI - Generate images"
for _, b := range backendConfigs {
for _, b := range modelConfigs {
if b.HasUsecases(config.FLAG_IMAGE) {
modelThatCanBeUsed = b.Name
title = "LocalAI - Generate images with " + modelThatCanBeUsed
@@ -197,7 +197,7 @@ func RegisterUIRoutes(app *fiber.App,
summary := fiber.Map{
"Title": title,
"BaseURL": utils.BaseURL(c),
"ModelsConfig": backendConfigs,
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": modelThatCanBeUsed,
"Version": internal.PrintableVersion(),
@@ -208,13 +208,13 @@ func RegisterUIRoutes(app *fiber.App,
})
app.Get("/tts/:model", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + c.Params("model"),
"BaseURL": utils.BaseURL(c),
"ModelsConfig": backendConfigs,
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": c.Params("model"),
"Version": internal.PrintableVersion(),
@@ -225,10 +225,10 @@ func RegisterUIRoutes(app *fiber.App,
})
app.Get("/tts/", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
}
@@ -236,7 +236,7 @@ func RegisterUIRoutes(app *fiber.App,
modelThatCanBeUsed := ""
title := "LocalAI - Generate audio"
for _, b := range backendConfigs {
for _, b := range modelConfigs {
if b.HasUsecases(config.FLAG_TTS) {
modelThatCanBeUsed = b.Name
title = "LocalAI - Generate audio with " + modelThatCanBeUsed
@@ -246,7 +246,7 @@ func RegisterUIRoutes(app *fiber.App,
summary := fiber.Map{
"Title": title,
"BaseURL": utils.BaseURL(c),
"ModelsConfig": backendConfigs,
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": modelThatCanBeUsed,
"Version": internal.PrintableVersion(),

View File

@@ -28,7 +28,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC
page := c.Query("page")
items := c.Query("items")
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.BackendsPath)
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
if err != nil {
log.Error().Err(err).Msg("could not list backends from galleries")
return c.Status(fiber.StatusInternalServerError).Render("views/error", fiber.Map{
@@ -129,7 +129,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
}
backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.BackendsPath)
backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
if page != "" {
// return a subset of the backends

View File

@@ -20,7 +20,7 @@ import (
"github.com/rs/zerolog/log"
)
func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
// Show the Models page (all models)
app.Get("/browse", func(c *fiber.Ctx) error {
@@ -28,7 +28,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
page := c.Query("page")
items := c.Query("items")
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
if err != nil {
log.Error().Err(err).Msg("could not list models from galleries")
return c.Status(fiber.StatusInternalServerError).Render("views/error", fiber.Map{
@@ -131,7 +131,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
}
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
if page != "" {
// return a subset of the models
@@ -224,7 +224,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
}
go func() {
galleryService.ModelGalleryChannel <- op
cl.RemoveBackendConfig(galleryName)
cl.RemoveModelConfig(galleryName)
}()
return c.SendString(elements.StartModelProgressBar(uid, "0", "Deletion"))

View File

@@ -16,21 +16,21 @@ import (
)
type BackendMonitorService struct {
backendConfigLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
modelConfigLoader *config.ModelConfigLoader
modelLoader *model.ModelLoader
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService {
func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.ModelConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService {
return &BackendMonitorService{
modelLoader: modelLoader,
backendConfigLoader: configLoader,
options: appConfig,
modelLoader: modelLoader,
modelConfigLoader: configLoader,
options: appConfig,
}
}
func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bms.backendConfigLoader.GetBackendConfig(modelName)
config, exists := bms.modelConfigLoader.GetModelConfig(modelName)
var backendId string
if exists {
backendId = config.Model
@@ -47,7 +47,7 @@ func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string)
}
func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
config, exists := bms.backendConfigLoader.GetBackendConfig(model)
config, exists := bms.modelConfigLoader.GetModelConfig(model)
var backend string
if exists {
backend = config.Model

View File

@@ -20,21 +20,21 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend], s
var err error
if op.Delete {
err = gallery.DeleteBackendFromSystem(g.appConfig.BackendsPath, op.GalleryElementName)
err = gallery.DeleteBackendFromSystem(g.appConfig.SystemState, op.GalleryElementName)
g.modelLoader.DeleteExternalBackend(op.GalleryElementName)
} else {
log.Warn().Msgf("installing backend %s", op.GalleryElementName)
log.Debug().Msgf("backend galleries: %v", g.appConfig.BackendGalleries)
err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, op.GalleryElementName, g.appConfig.BackendsPath, progressCallback, true)
err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, op.GalleryElementName, progressCallback, true)
if err == nil {
err = gallery.RegisterBackends(g.appConfig.BackendsPath, g.modelLoader)
err = gallery.RegisterBackends(systemState, g.modelLoader)
}
}
if err != nil {
log.Error().Err(err).Msgf("error installing backend %s", op.GalleryElementName)
if !op.Delete {
// If we didn't install the backend, we need to make sure we don't have a leftover directory
gallery.DeleteBackendFromSystem(g.appConfig.BackendsPath, op.GalleryElementName)
gallery.DeleteBackendFromSystem(systemState, op.GalleryElementName)
}
return err
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
type GalleryService struct {
@@ -52,7 +51,7 @@ func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus {
return g.statuses
}
func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader) error {
func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error {
// updates the status with an error
var updateError func(id string, e error)
if !g.appConfig.OpaqueErrors {
@@ -65,11 +64,6 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
}
}
systemState, err := system.GetSystemState()
if err != nil {
log.Error().Err(err).Msg("failed to get system state")
}
go func() {
for {
select {
@@ -82,7 +76,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
}
case op := <-g.ModelGalleryChannel:
err := g.modelHandler(&op, cl)
err := g.modelHandler(&op, cl, systemState)
if err != nil {
updateError(op.ID, err)
}

View File

@@ -14,7 +14,7 @@ const (
ALWAYS_INCLUDE
)
func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter config.BackendConfigFilterFn, looseFilePolicy LooseFilePolicy) ([]string, error) {
func ListModels(bcl *config.ModelConfigLoader, ml *model.ModelLoader, filter config.ModelConfigFilterFn, looseFilePolicy LooseFilePolicy) ([]string, error) {
var skipMap map[string]interface{} = map[string]interface{}{}
@@ -22,7 +22,7 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c
// Start with known configurations
for _, c := range bcl.GetBackendConfigsByFilter(filter) {
for _, c := range bcl.GetModelConfigsByFilter(filter) {
// Is this better than looseFilePolicy <= SKIP_IF_CONFIGURED ? less performant but more readable?
if (looseFilePolicy == SKIP_IF_CONFIGURED) || (looseFilePolicy == LOOSE_ONLY) {
skipMap[c.Model] = nil
@@ -50,7 +50,7 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c
return dataModels, nil
}
func CheckIfModelExists(bcl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) {
func CheckIfModelExists(bcl *config.ModelConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) {
filter, err := config.BuildNameFilterFn(modelName)
if err != nil {
return false, err

View File

@@ -3,7 +3,6 @@ package services
import (
"encoding/json"
"os"
"path/filepath"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
@@ -12,7 +11,7 @@ import (
"gopkg.in/yaml.v2"
)
func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *config.BackendConfigLoader) error {
func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *config.ModelConfigLoader, systemState *system.SystemState) error {
utils.ResetDownloadTimers()
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0})
@@ -23,18 +22,18 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *c
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
err := processModelOperation(op, g.appConfig.ModelPath, g.appConfig.BackendsPath, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback)
err := processModelOperation(op, systemState, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback)
if err != nil {
return err
}
// Reload models
err = cl.LoadBackendConfigsFromPath(g.appConfig.ModelPath)
err = cl.LoadModelConfigsFromPath(systemState.Model.ModelsPath)
if err != nil {
return err
}
err = cl.Preload(g.appConfig.ModelPath)
err = cl.Preload(systemState.Model.ModelsPath)
if err != nil {
return err
}
@@ -50,26 +49,21 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *c
return nil
}
func installModelFromRemoteConfig(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery, backendBasePath string) error {
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, modelPath)
func installModelFromRemoteConfig(systemState *system.SystemState, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error {
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
installedModel, err := gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
installedModel, err := gallery.InstallModel(systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
if automaticallyInstallBackend && installedModel.Backend != "" {
systemState, err := system.GetSystemState()
if err != nil {
return err
}
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, backendBasePath, downloadStatus, false); err != nil {
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
@@ -82,22 +76,22 @@ type galleryModel struct {
ID string `json:"id"`
}
func processRequests(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error {
func processRequests(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = installModelFromRemoteConfig(modelPath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries, backendBasePath)
err = installModelFromRemoteConfig(systemState, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries)
} else {
err = gallery.InstallModelFromGallery(
galleries, backendGalleries, r.ID, modelPath, backendBasePath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
galleries, backendGalleries, systemState, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
}
}
return err
}
func ApplyGalleryFromFile(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error {
func ApplyGalleryFromFile(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
@@ -108,58 +102,35 @@ func ApplyGalleryFromFile(modelPath, backendBasePath string, enforceScan, automa
return err
}
return processRequests(modelPath, backendBasePath, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests)
return processRequests(systemState, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests)
}
func ApplyGalleryFromString(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error {
func ApplyGalleryFromString(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, backendBasePath, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests)
return processRequests(systemState, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests)
}
// processModelOperation handles the installation or deletion of a model
func processModelOperation(
op *GalleryOp[gallery.GalleryModel],
modelPath string,
backendBasePath string,
systemState *system.SystemState,
enforcePredownloadScans bool,
automaticallyInstallBackend bool,
progressCallback func(string, string, string, float64),
) error {
// delete a model
if op.Delete {
modelConfig := &config.BackendConfig{}
// Galleryname is the name of the model in this case
dat, err := os.ReadFile(filepath.Join(modelPath, op.GalleryElementName+".yaml"))
if err != nil {
return err
}
err = yaml.Unmarshal(dat, modelConfig)
if err != nil {
return err
}
files := []string{}
// Remove the model from the config
if modelConfig.Model != "" {
files = append(files, modelConfig.ModelFileName())
}
if modelConfig.MMProj != "" {
files = append(files, modelConfig.MMProjFileName())
}
return gallery.DeleteModelFromSystem(modelPath, op.GalleryElementName, files)
return gallery.DeleteModelFromSystem(systemState, op.GalleryElementName)
}
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.GalleryElementName != "" {
return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, op.GalleryElementName, modelPath, backendBasePath, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, systemState, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
// } else if op.ConfigURL != "" {
// err := startup.InstallModels(op.Galleries, modelPath, enforcePredownloadScans, progressCallback, op.ConfigURL)
// if err != nil {
@@ -167,6 +138,6 @@ func processModelOperation(
// }
// return cl.Preload(modelPath)
} else {
return installModelFromRemoteConfig(modelPath, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries, backendBasePath)
return installModelFromRemoteConfig(systemState, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries)
}
}

View File

@@ -12,11 +12,7 @@ import (
"github.com/rs/zerolog/log"
)
func InstallExternalBackends(galleries []config.Gallery, backendPath string, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
systemState, err := system.GetSystemState()
if err != nil {
return fmt.Errorf("failed to get system state: %w", err)
}
func InstallExternalBackends(galleries []config.Gallery, systemState *system.SystemState, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
uri := downloader.URI(backend)
switch {
case uri.LooksLikeDir():
@@ -24,7 +20,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow
name = filepath.Base(backend)
}
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from path")
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -38,7 +34,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow
return fmt.Errorf("specifying a name is required for OCI images")
}
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -56,7 +52,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow
name = strings.TrimSuffix(name, filepath.Ext(name))
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -69,7 +65,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow
if name != "" || alias != "" {
return fmt.Errorf("specifying a name or alias is not supported for this backend")
}
err := gallery.InstallBackendFromGallery(galleries, systemState, backend, backendPath, downloadStatus, true)
err := gallery.InstallBackendFromGallery(galleries, systemState, backend, downloadStatus, true)
if err != nil {
return fmt.Errorf("error installing backend %s: %w", backend, err)
}

View File

@@ -24,7 +24,7 @@ const (
// InstallModels will preload models from the given list of URLs and galleries
// It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration
func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, backendBasePath string, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
func InstallModels(galleries, backendGalleries []config.Gallery, systemState *system.SystemState, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
// create an error that groups all errors
var err error
@@ -36,7 +36,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
return e
}
var model config.BackendConfig
var model config.ModelConfig
if e := yaml.Unmarshal(modelYAML, &model); e != nil {
log.Error().Err(e).Str("filepath", modelPath).Msg("error unmarshalling model definition")
return e
@@ -47,12 +47,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
return nil
}
systemState, err := system.GetSystemState()
if err != nil {
return err
}
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, model.Backend, backendBasePath, downloadStatus, false); err != nil {
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, model.Backend, downloadStatus, false); err != nil {
log.Error().Err(err).Str("backend", model.Backend).Msg("error installing backend")
return err
}
@@ -77,8 +72,8 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
ociName = strings.ReplaceAll(ociName, ":", "__")
// check if file exists
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) {
modelDefinitionFilePath := filepath.Join(modelPath, ociName)
if _, e := os.Stat(filepath.Join(systemState.Model.ModelsPath, ociName)); errors.Is(e, os.ErrNotExist) {
modelDefinitionFilePath := filepath.Join(systemState.Model.ModelsPath, ociName)
e := uri.DownloadFile(modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
@@ -100,7 +95,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
continue
}
modelPath := filepath.Join(modelPath, fileName)
modelPath := filepath.Join(systemState.Model.ModelsPath, fileName)
if e := utils.VerifyPath(fileName, modelPath); e != nil {
log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path")
@@ -138,7 +133,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
continue
}
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + YAML_EXTENSION
modelDefinitionFilePath := filepath.Join(systemState.Model.ModelsPath, md5Name) + YAML_EXTENSION
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil {
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s")
err = errors.Join(err, e)
@@ -152,7 +147,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
}
} else {
// Check if it's a model gallery, or print a warning
e, found := installModel(galleries, backendGalleries, url, modelPath, backendBasePath, downloadStatus, enforceScan, autoloadBackendGalleries)
e, found := installModel(galleries, backendGalleries, url, systemState, downloadStatus, enforceScan, autoloadBackendGalleries)
if e != nil && found {
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url)
err = errors.Join(err, e)
@@ -166,13 +161,13 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
return err
}
func installModel(galleries, backendGalleries []config.Gallery, modelName, modelPath, backendBasePath string, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
models, err := gallery.AvailableGalleryModels(galleries, modelPath)
func installModel(galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
models, err := gallery.AvailableGalleryModels(galleries, systemState)
if err != nil {
return err, false
}
model := gallery.FindGalleryElement(models, modelName, modelPath)
model := gallery.FindGalleryElement(models, modelName)
if model == nil {
return err, false
}
@@ -182,7 +177,7 @@ func installModel(galleries, backendGalleries []config.Gallery, modelName, model
}
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
err = gallery.InstallModelFromGallery(galleries, backendGalleries, modelName, modelPath, backendBasePath, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
err = gallery.InstallModelFromGallery(galleries, backendGalleries, systemState, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
if err != nil {
return err, true
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/pkg/system"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -21,7 +22,10 @@ var _ = Describe("Preload test", func() {
url := "https://raw.githubusercontent.com/mudler/LocalAI-examples/main/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", "phi-2")
InstallModels([]config.Gallery{}, []config.Gallery{}, tmpdir, "", true, true, nil, url)
systemState, err := system.GetSystemState(system.WithModelPath(tmpdir))
Expect(err).ToNot(HaveOccurred())
InstallModels([]config.Gallery{}, []config.Gallery{}, systemState, true, true, nil, url)
resultFile := filepath.Join(tmpdir, fileName)
@@ -36,7 +40,10 @@ var _ = Describe("Preload test", func() {
url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K")
err = InstallModels([]config.Gallery{}, []config.Gallery{}, tmpdir, "", false, true, nil, url)
systemState, err := system.GetSystemState(system.WithModelPath(tmpdir))
Expect(err).ToNot(HaveOccurred())
err = InstallModels([]config.Gallery{}, []config.Gallery{}, systemState, true, true, nil, url)
Expect(err).ToNot(HaveOccurred())
resultFile := filepath.Join(tmpdir, fileName)

View File

@@ -55,7 +55,7 @@ func NewEvaluator(modelPath string) *Evaluator {
}
}
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.ModelConfig, in PromptTemplateData) (string, error) {
template := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
@@ -135,7 +135,7 @@ func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, te
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
}
func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.ModelConfig, funcs []functions.Function, shouldUseFn bool) string {
if config.TemplateConfig.JinjaTemplate {
var messageData []ChatMessageTemplateData

View File

@@ -53,7 +53,7 @@ Function response:
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
@@ -69,7 +69,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
},
"assistant": {
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
@@ -86,7 +86,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
"function_call": {
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
@@ -102,7 +102,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
},
"function_response": {
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
@@ -121,7 +121,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
@@ -137,7 +137,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
},
"assistant": {
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
@@ -153,7 +153,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
},
"function_call": {
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
@@ -175,7 +175,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
},
"function_response": {
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
@@ -194,7 +194,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
"config": &config.BackendConfig{
"config": &config.ModelConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: toolCallJinja,
JinjaTemplate: true,
@@ -219,7 +219,7 @@ var _ = Describe("Templates", func() {
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
@@ -232,7 +232,7 @@ var _ = Describe("Templates", func() {
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
@@ -245,7 +245,7 @@ var _ = Describe("Templates", func() {
for key := range jinjaTest {
foo := jinjaTest[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}