mirror of
https://github.com/mudler/LocalAI.git
synced 2025-12-30 22:20:20 -06:00
feat(backends): add system backend, refactor (#6059)
- Add a system backend path - Refactor and consolidate system information in system state - Use system state in all the components to figure out the system paths to use whenever needed - Refactor BackendConfig -> ModelConfig. The old name was otherwise misleading, as we now have a backend configuration which is not the model config. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
253b7537dc
commit
089efe05fd
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type Application struct {
|
||||
backendLoader *config.BackendConfigLoader
|
||||
backendLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
templatesEvaluator *templates.Evaluator
|
||||
@@ -15,14 +15,14 @@ type Application struct {
|
||||
|
||||
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
return &Application{
|
||||
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
|
||||
modelLoader: model.NewModelLoader(appConfig.ModelPath, appConfig.SingleBackend),
|
||||
backendLoader: config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath),
|
||||
modelLoader: model.NewModelLoader(appConfig.SystemState, appConfig.SingleBackend),
|
||||
applicationConfig: appConfig,
|
||||
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
|
||||
templatesEvaluator: templates.NewEvaluator(appConfig.SystemState.Model.ModelsPath),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Application) BackendLoader() *config.BackendConfigLoader {
|
||||
func (a *Application) BackendLoader() *config.ModelConfigLoader {
|
||||
return a.backendLoader
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
options := config.NewApplicationConfig(opts...)
|
||||
application := newApplication(options)
|
||||
|
||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
|
||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.SystemState.Model.ModelsPath)
|
||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
||||
caps, err := xsysinfo.CPUCapabilities()
|
||||
if err == nil {
|
||||
@@ -35,10 +35,11 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
// Make sure directories exists
|
||||
if options.ModelPath == "" {
|
||||
return nil, fmt.Errorf("options.ModelPath cannot be empty")
|
||||
if options.SystemState.Model.ModelsPath == "" {
|
||||
return nil, fmt.Errorf("models path cannot be empty")
|
||||
}
|
||||
err = os.MkdirAll(options.ModelPath, 0750)
|
||||
|
||||
err = os.MkdirAll(options.SystemState.Model.ModelsPath, 0750)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create ModelPath: %q", err)
|
||||
}
|
||||
@@ -55,50 +56,50 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
log.Error().Err(err).Msg("error installing models")
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.BackendsPath, nil, backend, "", ""); err != nil {
|
||||
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, nil, backend, "", ""); err != nil {
|
||||
log.Error().Err(err).Msg("error installing external backend")
|
||||
}
|
||||
}
|
||||
|
||||
configLoaderOpts := options.ToConfigLoaderOptions()
|
||||
|
||||
if err := application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
|
||||
if err := application.BackendLoader().LoadModelConfigsFromPath(options.SystemState.Model.ModelsPath, configLoaderOpts...); err != nil {
|
||||
log.Error().Err(err).Msg("error loading config files")
|
||||
}
|
||||
|
||||
if err := gallery.RegisterBackends(options.BackendsPath, application.ModelLoader()); err != nil {
|
||||
if err := gallery.RegisterBackends(options.SystemState, application.ModelLoader()); err != nil {
|
||||
log.Error().Err(err).Msg("error registering external backends")
|
||||
}
|
||||
|
||||
if options.ConfigFile != "" {
|
||||
if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||
if err := application.BackendLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||
log.Error().Err(err).Msg("error loading config file")
|
||||
}
|
||||
}
|
||||
|
||||
if err := application.BackendLoader().Preload(options.ModelPath); err != nil {
|
||||
if err := application.BackendLoader().Preload(options.SystemState.Model.ModelsPath); err != nil {
|
||||
log.Error().Err(err).Msg("error downloading models")
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := services.ApplyGalleryFromString(options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||
if err := services.ApplyGalleryFromString(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := services.ApplyGalleryFromFile(options.ModelPath, options.BackendsPath, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||
if err := services.ApplyGalleryFromFile(options.SystemState, options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.Debug {
|
||||
for _, v := range application.BackendLoader().GetAllBackendConfigs() {
|
||||
for _, v := range application.BackendLoader().GetAllModelsConfigs() {
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
|
||||
}
|
||||
}
|
||||
@@ -131,7 +132,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
|
||||
if options.LoadToMemory != nil && !options.SingleBackend {
|
||||
for _, m := range options.LoadToMemory {
|
||||
cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
|
||||
cfg, err := application.BackendLoader().LoadModelConfigFileByNameDefaultOptions(m, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@ func Detection(
|
||||
sourceFile string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
backendConfig config.BackendConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.DetectResponse, error) {
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
detectionModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
@@ -23,7 +23,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
|
||||
switch model := inferenceModel.(type) {
|
||||
case grpc.Backend:
|
||||
fn = func() ([]float32, error) {
|
||||
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
|
||||
predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath)
|
||||
if len(tokens) > 0 {
|
||||
embeds := []int32{}
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(
|
||||
opts...,
|
||||
)
|
||||
@@ -27,12 +27,12 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
||||
Mode: int32(mode),
|
||||
Step: int32(step),
|
||||
Seed: int32(seed),
|
||||
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
|
||||
CLIPSkip: int32(modelConfig.Diffusers.ClipSkip),
|
||||
PositivePrompt: positive_prompt,
|
||||
NegativePrompt: negative_prompt,
|
||||
Dst: dst,
|
||||
Src: src,
|
||||
EnableParameters: backendConfig.Diffusers.EnableParameters,
|
||||
EnableParameters: modelConfig.Diffusers.EnableParameters,
|
||||
RefImages: refImages,
|
||||
})
|
||||
return err
|
||||
|
||||
@@ -35,7 +35,7 @@ type TokenUsage struct {
|
||||
TimingTokenGeneration float64
|
||||
}
|
||||
|
||||
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, cl *config.BackendConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
||||
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
||||
modelFile := c.Model
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
@@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
||||
if !slices.Contains(modelNames, c.Name) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, c.Name, loader.ModelPath, o.BackendsPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
|
||||
//return nil, err
|
||||
@@ -201,7 +201,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
||||
var mu sync.Mutex = sync.Mutex{}
|
||||
|
||||
func Finetune(config config.BackendConfig, input, prediction string) string {
|
||||
func Finetune(config config.ModelConfig, input, prediction string) string {
|
||||
if config.Echo {
|
||||
prediction = input + prediction
|
||||
}
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
var _ = Describe("LLM tests", func() {
|
||||
Context("Finetune LLM output", func() {
|
||||
var (
|
||||
testConfig config.BackendConfig
|
||||
testConfig config.ModelConfig
|
||||
input string
|
||||
prediction string
|
||||
result string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
testConfig = config.BackendConfig{
|
||||
testConfig = config.ModelConfig{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
Echo: false,
|
||||
},
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
|
||||
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
|
||||
name := c.Name
|
||||
if name == "" {
|
||||
name = c.Model
|
||||
@@ -58,7 +58,7 @@ func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...
|
||||
return append(defOpts, opts...)
|
||||
}
|
||||
|
||||
func getSeed(c config.BackendConfig) int32 {
|
||||
func getSeed(c config.ModelConfig) int32 {
|
||||
var seed int32 = config.RAND_SEED
|
||||
|
||||
if c.Seed != nil {
|
||||
@@ -72,7 +72,7 @@ func getSeed(c config.BackendConfig) int32 {
|
||||
return seed
|
||||
}
|
||||
|
||||
func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
||||
func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
|
||||
b := 512
|
||||
if c.Batch != 0 {
|
||||
b = c.Batch
|
||||
@@ -195,7 +195,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
|
||||
func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions {
|
||||
promptCachePath := ""
|
||||
if c.PromptCachePath != "" {
|
||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
rerankModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -21,10 +21,10 @@ func SoundGeneration(
|
||||
sourceDivisor *int32,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
backendConfig config.BackendConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (string, *proto.Result, error) {
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
soundGenModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
@@ -49,7 +49,7 @@ func SoundGeneration(
|
||||
|
||||
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
|
||||
Text: text,
|
||||
Model: backendConfig.Model,
|
||||
Model: modelConfig.Model,
|
||||
Dst: filePath,
|
||||
Sample: doSample,
|
||||
Duration: duration,
|
||||
|
||||
@@ -13,9 +13,9 @@ func TokenMetrics(
|
||||
modelFile string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {
|
||||
modelConfig config.ModelConfig) (*proto.MetricsResponse, error) {
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithModel(modelFile))
|
||||
model, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -7,19 +7,19 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
|
||||
func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
|
||||
|
||||
var inferenceModel grpc.Backend
|
||||
var err error
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err = loader.Load(opts...)
|
||||
if err != nil {
|
||||
return schema.TokenizeResponse{}, err
|
||||
}
|
||||
defer loader.Close()
|
||||
|
||||
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
|
||||
predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath)
|
||||
predictOptions.Prompt = s
|
||||
|
||||
// tokenize the string
|
||||
|
||||
@@ -12,13 +12,13 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
|
||||
if backendConfig.Backend == "" {
|
||||
backendConfig.Backend = model.WhisperBackend
|
||||
if modelConfig.Backend == "" {
|
||||
modelConfig.Backend = model.WhisperBackend
|
||||
}
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
|
||||
transcriptionModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
@@ -34,7 +34,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
|
||||
Dst: audio,
|
||||
Language: language,
|
||||
Translate: translate,
|
||||
Threads: uint32(*backendConfig.Threads),
|
||||
Threads: uint32(*modelConfig.Threads),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -19,9 +19,9 @@ func ModelTTS(
|
||||
language string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
backendConfig config.BackendConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (string, *proto.Result, error) {
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
@@ -29,7 +29,7 @@ func ModelTTS(
|
||||
defer loader.Close()
|
||||
|
||||
if ttsModel == nil {
|
||||
return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
|
||||
return "", nil, fmt.Errorf("could not load tts model %q", modelConfig.Model)
|
||||
}
|
||||
|
||||
audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
|
||||
@@ -47,14 +47,14 @@ func ModelTTS(
|
||||
// Checking first that it exists and is not outside ModelPath
|
||||
// TODO: we should actually first check if the modelFile is looking like
|
||||
// a FS path
|
||||
mp := filepath.Join(loader.ModelPath, backendConfig.Model)
|
||||
mp := filepath.Join(loader.ModelPath, modelConfig.Model)
|
||||
if _, err := os.Stat(mp); err == nil {
|
||||
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
|
||||
if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
modelPath = mp
|
||||
} else {
|
||||
modelPath = backendConfig.Model // skip this step if it fails?????
|
||||
modelPath = modelConfig.Model // skip this step if it fails?????
|
||||
}
|
||||
|
||||
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
||||
|
||||
@@ -13,8 +13,8 @@ func VAD(request *schema.VADRequest,
|
||||
ctx context.Context,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
backendConfig config.BackendConfig) (*schema.VADResponse, error) {
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
|
||||
func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(
|
||||
opts...,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/startup"
|
||||
@@ -14,8 +15,9 @@ import (
|
||||
)
|
||||
|
||||
type BackendsCMDFlags struct {
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
}
|
||||
|
||||
type BackendsList struct {
|
||||
@@ -48,7 +50,15 @@ func (bl *BackendsList) Run(ctx *cliContext.Context) error {
|
||||
log.Error().Err(err).Msg("unable to load galleries")
|
||||
}
|
||||
|
||||
backends, err := gallery.AvailableBackends(galleries, bl.BackendsPath)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendSystemPath(bl.BackendsSystemPath),
|
||||
system.WithBackendPath(bl.BackendsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backends, err := gallery.AvailableBackends(galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -68,6 +78,14 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
log.Error().Err(err).Msg("unable to load galleries")
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendSystemPath(bi.BackendsSystemPath),
|
||||
system.WithBackendPath(bi.BackendsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
progressBar := progressbar.NewOptions(
|
||||
1000,
|
||||
progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", bi.BackendArgs)),
|
||||
@@ -82,7 +100,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
err := startup.InstallExternalBackends(galleries, bi.BackendsPath, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
err = startup.InstallExternalBackends(galleries, systemState, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,7 +112,15 @@ func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error {
|
||||
for _, backendName := range bu.BackendArgs {
|
||||
log.Info().Str("backend", backendName).Msg("uninstalling backend")
|
||||
|
||||
err := gallery.DeleteBackendFromSystem(bu.BackendsPath, backendName)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendSystemPath(bu.BackendsSystemPath),
|
||||
system.WithBackendPath(bu.BackendsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = gallery.DeleteBackendFromSystem(systemState, backendName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
)
|
||||
@@ -45,7 +46,14 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error {
|
||||
log.Error().Err(err).Msg("unable to load galleries")
|
||||
}
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(galleries, ml.ModelsPath)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(ml.ModelsPath),
|
||||
system.WithBackendPath(ml.BackendsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
models, err := gallery.AvailableGalleryModels(galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -60,6 +68,15 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(mi.ModelsPath),
|
||||
system.WithBackendPath(mi.BackendsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var galleries []config.Gallery
|
||||
if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil {
|
||||
log.Error().Err(err).Msg("unable to load galleries")
|
||||
@@ -86,7 +103,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}
|
||||
//startup.InstallModels()
|
||||
models, err := gallery.AvailableGalleryModels(galleries, mi.ModelsPath)
|
||||
models, err := gallery.AvailableGalleryModels(galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,7 +111,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
modelURI := downloader.URI(modelName)
|
||||
|
||||
if !modelURI.LooksLikeOCI() {
|
||||
model := gallery.FindGalleryElement(models, modelName, mi.ModelsPath)
|
||||
model := gallery.FindGalleryElement(models, modelName)
|
||||
if model == nil {
|
||||
log.Error().Str("model", modelName).Msg("model not found")
|
||||
return err
|
||||
@@ -108,7 +125,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
|
||||
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
|
||||
}
|
||||
|
||||
err = startup.InstallModels(galleries, backendGalleries, mi.ModelsPath, mi.BackendsPath, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
|
||||
err = startup.InstallModels(galleries, backendGalleries, systemState, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -22,6 +23,7 @@ type RunCMD struct {
|
||||
|
||||
ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
||||
GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"`
|
||||
UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
|
||||
@@ -77,12 +79,20 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
os.MkdirAll(r.BackendsPath, 0750)
|
||||
os.MkdirAll(r.ModelsPath, 0750)
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendSystemPath(r.BackendsSystemPath),
|
||||
system.WithModelPath(r.ModelsPath),
|
||||
system.WithBackendPath(r.BackendsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts := []config.AppOption{
|
||||
config.WithConfigFile(r.ModelsConfigFile),
|
||||
config.WithJSONStringPreload(r.PreloadModels),
|
||||
config.WithYAMLConfigPreload(r.PreloadModelsConfig),
|
||||
config.WithModelPath(r.ModelsPath),
|
||||
config.WithBackendsPath(r.BackendsPath),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithContextSize(r.ContextSize),
|
||||
config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel),
|
||||
config.WithGeneratedContentDir(r.GeneratedContentPath),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -56,6 +57,13 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
text := strings.Join(t.Text, " ")
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(t.ModelsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
externalBackends := make(map[string]string)
|
||||
// split ":" to get backend name and the uri
|
||||
for _, v := range t.ExternalGRPCBackends {
|
||||
@@ -66,12 +74,12 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
opts := &config.ApplicationConfig{
|
||||
ModelPath: t.ModelsPath,
|
||||
SystemState: systemState,
|
||||
Context: context.Background(),
|
||||
GeneratedContentDir: outputDir,
|
||||
ExternalGRPCBackends: externalBackends,
|
||||
}
|
||||
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
|
||||
ml := model.NewModelLoader(systemState, opts.SingleBackend)
|
||||
|
||||
defer func() {
|
||||
err := ml.StopAllGRPC()
|
||||
@@ -80,7 +88,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
options := config.BackendConfig{}
|
||||
options := config.ModelConfig{}
|
||||
options.SetDefaults()
|
||||
options.Backend = t.Backend
|
||||
options.Model = t.Model
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -24,18 +25,24 @@ type TranscriptCMD struct {
|
||||
}
|
||||
|
||||
func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(t.ModelsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts := &config.ApplicationConfig{
|
||||
ModelPath: t.ModelsPath,
|
||||
Context: context.Background(),
|
||||
SystemState: systemState,
|
||||
Context: context.Background(),
|
||||
}
|
||||
|
||||
cl := config.NewBackendConfigLoader(t.ModelsPath)
|
||||
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
|
||||
if err := cl.LoadBackendConfigsFromPath(t.ModelsPath); err != nil {
|
||||
cl := config.NewModelConfigLoader(t.ModelsPath)
|
||||
ml := model.NewModelLoader(systemState, opts.SingleBackend)
|
||||
if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c, exists := cl.GetBackendConfig(t.Model)
|
||||
c, exists := cl.GetModelConfig(t.Model)
|
||||
if !exists {
|
||||
return errors.New("model not found")
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -34,12 +35,20 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
text := strings.Join(t.Text, " ")
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(t.ModelsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts := &config.ApplicationConfig{
|
||||
ModelPath: t.ModelsPath,
|
||||
SystemState: systemState,
|
||||
Context: context.Background(),
|
||||
GeneratedContentDir: outputDir,
|
||||
}
|
||||
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
|
||||
|
||||
ml := model.NewModelLoader(systemState, opts.SingleBackend)
|
||||
|
||||
defer func() {
|
||||
err := ml.StopAllGRPC()
|
||||
@@ -48,7 +57,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
options := config.BackendConfig{}
|
||||
options := config.ModelConfig{}
|
||||
options.SetDefaults()
|
||||
options.Backend = t.Backend
|
||||
options.Model = t.Model
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/oci"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
type UtilCMD struct {
|
||||
@@ -108,6 +109,14 @@ func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(hfscmd.ModelsPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Msg("LocalAI Security Scanner - This is BEST EFFORT functionality! Currently limited to huggingface models!")
|
||||
if len(hfscmd.ToScan) == 0 {
|
||||
log.Info().Msg("Checking all installed models against galleries")
|
||||
@@ -116,7 +125,7 @@ func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error {
|
||||
log.Error().Err(err).Msg("unable to load galleries")
|
||||
}
|
||||
|
||||
err := gallery.SafetyScanGalleryModels(galleries, hfscmd.ModelsPath)
|
||||
err := gallery.SafetyScanGalleryModels(galleries, systemState)
|
||||
if err == nil {
|
||||
log.Info().Msg("No security warnings were detected for your installed models. Please note that this is a BEST EFFORT tool, and all issues may not be detected.")
|
||||
} else {
|
||||
@@ -150,17 +159,17 @@ func (uhcmd *UsecaseHeuristicCMD) Run(ctx *cliContext.Context) error {
|
||||
log.Error().Msg("ModelsPath is a required parameter")
|
||||
return fmt.Errorf("model path is a required parameter")
|
||||
}
|
||||
bcl := config.NewBackendConfigLoader(uhcmd.ModelsPath)
|
||||
err := bcl.LoadBackendConfig(uhcmd.ConfigName)
|
||||
bcl := config.NewModelConfigLoader(uhcmd.ModelsPath)
|
||||
err := bcl.ReadModelConfig(uhcmd.ConfigName)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("ConfigName", uhcmd.ConfigName).Msg("error while loading backend")
|
||||
return err
|
||||
}
|
||||
bc, exists := bcl.GetBackendConfig(uhcmd.ConfigName)
|
||||
bc, exists := bcl.GetModelConfig(uhcmd.ConfigName)
|
||||
if !exists {
|
||||
log.Error().Str("ConfigName", uhcmd.ConfigName).Msg("ConfigName not found")
|
||||
}
|
||||
for name, uc := range config.GetAllBackendConfigUsecases() {
|
||||
for name, uc := range config.GetAllModelConfigUsecases() {
|
||||
if bc.HasUsecases(uc) {
|
||||
log.Info().Str("Usecase", name)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package worker
|
||||
|
||||
type WorkerFlags struct {
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
||||
}
|
||||
|
||||
type Worker struct {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -21,20 +22,19 @@ const (
|
||||
llamaCPPRPCBinaryName = "llama-cpp-rpc-server"
|
||||
)
|
||||
|
||||
func findLLamaCPPBackend(backendSystemPath string) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(backendSystemPath)
|
||||
func findLLamaCPPBackend(systemState *system.SystemState) (string, error) {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
log.Warn().Msgf("Failed listing system backends: %s", err)
|
||||
return "", err
|
||||
}
|
||||
log.Debug().Msgf("System backends: %v", backends)
|
||||
|
||||
backendPath := ""
|
||||
backend, ok := backends.Get("llama-cpp")
|
||||
if !ok {
|
||||
return "", errors.New("llama-cpp backend not found, install it first")
|
||||
}
|
||||
backendPath = filepath.Dir(backend.RunFile)
|
||||
backendPath := filepath.Dir(backend.RunFile)
|
||||
|
||||
if backendPath == "" {
|
||||
return "", errors.New("llama-cpp backend not found, install it first")
|
||||
@@ -54,7 +54,14 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
|
||||
return fmt.Errorf("usage: local-ai worker llama-cpp-rpc -- <llama-rpc-server-args>")
|
||||
}
|
||||
|
||||
grpcProcess, err := findLLamaCPPBackend(r.BackendsPath)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(r.BackendsPath),
|
||||
system.WithBackendSystemPath(r.BackendsSystemPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
grpcProcess, err := findLLamaCPPBackend(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/phayes/freeport"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -25,6 +26,14 @@ type P2P struct {
|
||||
|
||||
func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(r.BackendsPath),
|
||||
system.WithBackendSystemPath(r.BackendsSystemPath),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the token is set
|
||||
// as we always need it.
|
||||
if r.Token == "" {
|
||||
@@ -60,7 +69,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||
for {
|
||||
log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port)
|
||||
|
||||
grpcProcess, err := findLLamaCPPBackend(r.BackendsPath)
|
||||
grpcProcess, err := findLLamaCPPBackend(systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to find llama-cpp-rpc-server")
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -13,8 +14,7 @@ import (
|
||||
type ApplicationConfig struct {
|
||||
Context context.Context
|
||||
ConfigFile string
|
||||
ModelPath string
|
||||
BackendsPath string
|
||||
SystemState *system.SystemState
|
||||
ExternalBackends []string
|
||||
UploadLimitMB, Threads, ContextSize int
|
||||
F16 bool
|
||||
@@ -86,15 +86,9 @@ func WithModelsURL(urls ...string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithModelPath(path string) AppOption {
|
||||
func WithSystemState(state *system.SystemState) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ModelPath = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendsPath(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.BackendsPath = path
|
||||
o.SystemState = state
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,7 +373,7 @@ func (o *ApplicationConfig) ToConfigLoaderOptions() []ConfigLoaderOption {
|
||||
LoadOptionDebug(o.Debug),
|
||||
LoadOptionF16(o.F16),
|
||||
LoadOptionThreads(o.Threads),
|
||||
ModelPath(o.ModelPath),
|
||||
ModelPath(o.SystemState.Model.ModelsPath),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,20 +24,20 @@ type TTSConfig struct {
|
||||
AudioPath string `yaml:"audio_path"`
|
||||
}
|
||||
|
||||
type BackendConfig struct {
|
||||
type ModelConfig struct {
|
||||
schema.PredictionOptions `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
|
||||
F16 *bool `yaml:"f16"`
|
||||
Threads *int `yaml:"threads"`
|
||||
Debug *bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings *bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
KnownUsecaseStrings []string `yaml:"known_usecases"`
|
||||
KnownUsecases *BackendConfigUsecases `yaml:"-"`
|
||||
Pipeline Pipeline `yaml:"pipeline"`
|
||||
F16 *bool `yaml:"f16"`
|
||||
Threads *int `yaml:"threads"`
|
||||
Debug *bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings *bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
KnownUsecaseStrings []string `yaml:"known_usecases"`
|
||||
KnownUsecases *ModelConfigUsecases `yaml:"-"`
|
||||
Pipeline Pipeline `yaml:"pipeline"`
|
||||
|
||||
PromptStrings, InputStrings []string `yaml:"-"`
|
||||
InputToken [][]int `yaml:"-"`
|
||||
@@ -217,18 +217,18 @@ type TemplateConfig struct {
|
||||
ReplyPrefix string `yaml:"reply_prefix"`
|
||||
}
|
||||
|
||||
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
type BCAlias BackendConfig
|
||||
func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
type BCAlias ModelConfig
|
||||
var aux BCAlias
|
||||
if err := value.Decode(&aux); err != nil {
|
||||
return err
|
||||
}
|
||||
*c = BackendConfig(aux)
|
||||
*c = ModelConfig(aux)
|
||||
|
||||
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
|
||||
// Make sure the usecases are valid, we rewrite with what we identified
|
||||
c.KnownUsecaseStrings = []string{}
|
||||
for k, usecase := range GetAllBackendConfigUsecases() {
|
||||
for k, usecase := range GetAllModelConfigUsecases() {
|
||||
if c.HasUsecases(usecase) {
|
||||
c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
|
||||
}
|
||||
@@ -236,25 +236,25 @@ func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackendConfig) SetFunctionCallString(s string) {
|
||||
func (c *ModelConfig) SetFunctionCallString(s string) {
|
||||
c.functionCallString = s
|
||||
}
|
||||
|
||||
func (c *BackendConfig) SetFunctionCallNameString(s string) {
|
||||
func (c *ModelConfig) SetFunctionCallNameString(s string) {
|
||||
c.functionCallNameString = s
|
||||
}
|
||||
|
||||
func (c *BackendConfig) ShouldUseFunctions() bool {
|
||||
func (c *ModelConfig) ShouldUseFunctions() bool {
|
||||
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
|
||||
}
|
||||
|
||||
func (c *BackendConfig) ShouldCallSpecificFunction() bool {
|
||||
func (c *ModelConfig) ShouldCallSpecificFunction() bool {
|
||||
return len(c.functionCallNameString) > 0
|
||||
}
|
||||
|
||||
// MMProjFileName returns the filename of the MMProj file
|
||||
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
|
||||
func (c *BackendConfig) MMProjFileName() string {
|
||||
func (c *ModelConfig) MMProjFileName() string {
|
||||
uri := downloader.URI(c.MMProj)
|
||||
if uri.LooksLikeURL() {
|
||||
f, _ := uri.FilenameFromUrl()
|
||||
@@ -264,19 +264,19 @@ func (c *BackendConfig) MMProjFileName() string {
|
||||
return c.MMProj
|
||||
}
|
||||
|
||||
func (c *BackendConfig) IsMMProjURL() bool {
|
||||
func (c *ModelConfig) IsMMProjURL() bool {
|
||||
uri := downloader.URI(c.MMProj)
|
||||
return uri.LooksLikeURL()
|
||||
}
|
||||
|
||||
func (c *BackendConfig) IsModelURL() bool {
|
||||
func (c *ModelConfig) IsModelURL() bool {
|
||||
uri := downloader.URI(c.Model)
|
||||
return uri.LooksLikeURL()
|
||||
}
|
||||
|
||||
// ModelFileName returns the filename of the model
|
||||
// If the model is a URL, it will return the MD5 of the URL which is the filename
|
||||
func (c *BackendConfig) ModelFileName() string {
|
||||
func (c *ModelConfig) ModelFileName() string {
|
||||
uri := downloader.URI(c.Model)
|
||||
if uri.LooksLikeURL() {
|
||||
f, _ := uri.FilenameFromUrl()
|
||||
@@ -286,7 +286,7 @@ func (c *BackendConfig) ModelFileName() string {
|
||||
return c.Model
|
||||
}
|
||||
|
||||
func (c *BackendConfig) FunctionToCall() string {
|
||||
func (c *ModelConfig) FunctionToCall() string {
|
||||
if c.functionCallNameString != "" &&
|
||||
c.functionCallNameString != "none" && c.functionCallNameString != "auto" {
|
||||
return c.functionCallNameString
|
||||
@@ -295,7 +295,7 @@ func (c *BackendConfig) FunctionToCall() string {
|
||||
return c.functionCallString
|
||||
}
|
||||
|
||||
func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
lo := &LoadOptions{}
|
||||
lo.Apply(opts...)
|
||||
|
||||
@@ -411,7 +411,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
guessDefaultsFromFile(cfg, lo.modelPath, ctx)
|
||||
}
|
||||
|
||||
func (c *BackendConfig) Validate() bool {
|
||||
func (c *ModelConfig) Validate() bool {
|
||||
downloadedFileNames := []string{}
|
||||
for _, f := range c.DownloadFiles {
|
||||
downloadedFileNames = append(downloadedFileNames, f.Filename)
|
||||
@@ -438,34 +438,34 @@ func (c *BackendConfig) Validate() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *BackendConfig) HasTemplate() bool {
|
||||
func (c *ModelConfig) HasTemplate() bool {
|
||||
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
|
||||
}
|
||||
|
||||
type BackendConfigUsecases int
|
||||
type ModelConfigUsecases int
|
||||
|
||||
const (
|
||||
FLAG_ANY BackendConfigUsecases = 0b000000000000
|
||||
FLAG_CHAT BackendConfigUsecases = 0b000000000001
|
||||
FLAG_COMPLETION BackendConfigUsecases = 0b000000000010
|
||||
FLAG_EDIT BackendConfigUsecases = 0b000000000100
|
||||
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000000001000
|
||||
FLAG_RERANK BackendConfigUsecases = 0b000000010000
|
||||
FLAG_IMAGE BackendConfigUsecases = 0b000000100000
|
||||
FLAG_TRANSCRIPT BackendConfigUsecases = 0b000001000000
|
||||
FLAG_TTS BackendConfigUsecases = 0b000010000000
|
||||
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b000100000000
|
||||
FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000
|
||||
FLAG_VAD BackendConfigUsecases = 0b010000000000
|
||||
FLAG_VIDEO BackendConfigUsecases = 0b100000000000
|
||||
FLAG_DETECTION BackendConfigUsecases = 0b1000000000000
|
||||
FLAG_ANY ModelConfigUsecases = 0b000000000000
|
||||
FLAG_CHAT ModelConfigUsecases = 0b000000000001
|
||||
FLAG_COMPLETION ModelConfigUsecases = 0b000000000010
|
||||
FLAG_EDIT ModelConfigUsecases = 0b000000000100
|
||||
FLAG_EMBEDDINGS ModelConfigUsecases = 0b000000001000
|
||||
FLAG_RERANK ModelConfigUsecases = 0b000000010000
|
||||
FLAG_IMAGE ModelConfigUsecases = 0b000000100000
|
||||
FLAG_TRANSCRIPT ModelConfigUsecases = 0b000001000000
|
||||
FLAG_TTS ModelConfigUsecases = 0b000010000000
|
||||
FLAG_SOUND_GENERATION ModelConfigUsecases = 0b000100000000
|
||||
FLAG_TOKENIZE ModelConfigUsecases = 0b001000000000
|
||||
FLAG_VAD ModelConfigUsecases = 0b010000000000
|
||||
FLAG_VIDEO ModelConfigUsecases = 0b100000000000
|
||||
FLAG_DETECTION ModelConfigUsecases = 0b1000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
FLAG_LLM ModelConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
)
|
||||
|
||||
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
|
||||
return map[string]BackendConfigUsecases{
|
||||
func GetAllModelConfigUsecases() map[string]ModelConfigUsecases {
|
||||
return map[string]ModelConfigUsecases{
|
||||
"FLAG_ANY": FLAG_ANY,
|
||||
"FLAG_CHAT": FLAG_CHAT,
|
||||
"FLAG_COMPLETION": FLAG_COMPLETION,
|
||||
@@ -488,12 +488,12 @@ func stringToFlag(s string) string {
|
||||
return "FLAG_" + strings.ToUpper(s)
|
||||
}
|
||||
|
||||
func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
|
||||
func GetUsecasesFromYAML(input []string) *ModelConfigUsecases {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := FLAG_ANY
|
||||
flags := GetAllBackendConfigUsecases()
|
||||
flags := GetAllModelConfigUsecases()
|
||||
for _, str := range input {
|
||||
flag, exists := flags[stringToFlag(str)]
|
||||
if exists {
|
||||
@@ -503,8 +503,8 @@ func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
|
||||
return &result
|
||||
}
|
||||
|
||||
// HasUsecases examines a BackendConfig and determines which endpoints have a chance of success.
|
||||
func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
|
||||
// HasUsecases examines a ModelConfig and determines which endpoints have a chance of success.
|
||||
func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool {
|
||||
if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
|
||||
return true
|
||||
}
|
||||
@@ -514,7 +514,7 @@ func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
|
||||
// GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at.
|
||||
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
|
||||
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
|
||||
func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
|
||||
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
|
||||
if (u & FLAG_CHAT) == FLAG_CHAT {
|
||||
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
|
||||
return false
|
||||
|
||||
@@ -2,11 +2,11 @@ package config
|
||||
|
||||
import "regexp"
|
||||
|
||||
type BackendConfigFilterFn func(string, *BackendConfig) bool
|
||||
type ModelConfigFilterFn func(string, *ModelConfig) bool
|
||||
|
||||
func NoFilterFn(_ string, _ *BackendConfig) bool { return true }
|
||||
// NoFilterFn is a ModelConfigFilterFn that accepts every entry: it ignores both
// arguments and always returns true.
func NoFilterFn(_ string, _ *ModelConfig) bool {
	return true
}
|
||||
|
||||
func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
|
||||
func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) {
|
||||
if filter == "" {
|
||||
return NoFilterFn, nil
|
||||
}
|
||||
@@ -14,7 +14,7 @@ func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(name string, config *BackendConfig) bool {
|
||||
return func(name string, config *ModelConfig) bool {
|
||||
if config != nil {
|
||||
return rxp.MatchString(config.Name)
|
||||
}
|
||||
@@ -22,11 +22,11 @@ func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func BuildUsecaseFilterFn(usecases BackendConfigUsecases) BackendConfigFilterFn {
|
||||
func BuildUsecaseFilterFn(usecases ModelConfigUsecases) ModelConfigFilterFn {
|
||||
if usecases == FLAG_ANY {
|
||||
return NoFilterFn
|
||||
}
|
||||
return func(name string, config *BackendConfig) bool {
|
||||
return func(name string, config *ModelConfig) bool {
|
||||
if config == nil {
|
||||
return false // TODO: Potentially make this a param, for now, no known usecase to include
|
||||
}
|
||||
|
||||
@@ -18,15 +18,15 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type BackendConfigLoader struct {
|
||||
configs map[string]BackendConfig
|
||||
type ModelConfigLoader struct {
|
||||
configs map[string]ModelConfig
|
||||
modelPath string
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewBackendConfigLoader(modelPath string) *BackendConfigLoader {
|
||||
return &BackendConfigLoader{
|
||||
configs: make(map[string]BackendConfig),
|
||||
func NewModelConfigLoader(modelPath string) *ModelConfigLoader {
|
||||
return &ModelConfigLoader{
|
||||
configs: make(map[string]ModelConfig),
|
||||
modelPath: modelPath,
|
||||
}
|
||||
}
|
||||
@@ -77,14 +77,14 @@ func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
|
||||
}
|
||||
|
||||
// TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []ModelConfig or ModelConfig
|
||||
func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
|
||||
c := &[]*BackendConfig{}
|
||||
func readMultipleModelConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*ModelConfig, error) {
|
||||
c := &[]*ModelConfig{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err)
|
||||
return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot read config file %q: %w", file, err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err)
|
||||
return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot unmarshal config file %q: %w", file, err)
|
||||
}
|
||||
|
||||
for _, cc := range *c {
|
||||
@@ -94,17 +94,17 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption)
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||
func readModelConfigFromFile(file string, opts ...ConfigLoaderOption) (*ModelConfig, error) {
|
||||
lo := &LoadOptions{}
|
||||
lo.Apply(opts...)
|
||||
|
||||
c := &BackendConfig{}
|
||||
c := &ModelConfig{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err)
|
||||
return nil, fmt.Errorf("readModelConfigFromFile cannot read config file %q: %w", file, err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err)
|
||||
return nil, fmt.Errorf("readModelConfigFromFile cannot unmarshal config file %q: %w", file, err)
|
||||
}
|
||||
|
||||
c.SetDefaults(opts...)
|
||||
@@ -112,10 +112,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen
|
||||
}
|
||||
|
||||
// Load a config file for a model
|
||||
func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||
func (bcl *ModelConfigLoader) LoadModelConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*ModelConfig, error) {
|
||||
|
||||
// Load a config file if present after the model name
|
||||
cfg := &BackendConfig{
|
||||
cfg := &ModelConfig{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: modelName,
|
||||
@@ -123,19 +123,19 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
|
||||
},
|
||||
}
|
||||
|
||||
cfgExisting, exists := bcl.GetBackendConfig(modelName)
|
||||
cfgExisting, exists := bcl.GetModelConfig(modelName)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
} else {
|
||||
// Try loading a model config file
|
||||
modelConfig := filepath.Join(modelPath, modelName+".yaml")
|
||||
if _, err := os.Stat(modelConfig); err == nil {
|
||||
if err := bcl.LoadBackendConfig(
|
||||
if err := bcl.ReadModelConfig(
|
||||
modelConfig, opts...,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||
}
|
||||
cfgExisting, exists = bcl.GetBackendConfig(modelName)
|
||||
cfgExisting, exists = bcl.GetModelConfig(modelName)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
}
|
||||
@@ -147,20 +147,20 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
|
||||
return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
|
||||
func (bcl *ModelConfigLoader) LoadModelConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*ModelConfig, error) {
|
||||
return bcl.LoadModelConfigFileByName(modelName, appConfig.SystemState.Model.ModelsPath,
|
||||
LoadOptionDebug(appConfig.Debug),
|
||||
LoadOptionThreads(appConfig.Threads),
|
||||
LoadOptionContextSize(appConfig.ContextSize),
|
||||
LoadOptionF16(appConfig.F16),
|
||||
ModelPath(appConfig.ModelPath))
|
||||
ModelPath(appConfig.SystemState.Model.ModelsPath))
|
||||
}
|
||||
|
||||
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
|
||||
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
|
||||
func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
c, err := readMultipleBackendConfigsFromFile(file, opts...)
|
||||
c, err := readMultipleModelConfigsFromFile(file, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot load config file: %w", err)
|
||||
}
|
||||
@@ -173,12 +173,12 @@ func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
|
||||
func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
c, err := readBackendConfigFromFile(file, opts...)
|
||||
c, err := readModelConfigFromFile(file, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err)
|
||||
return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err)
|
||||
}
|
||||
|
||||
if c.Validate() {
|
||||
@@ -190,17 +190,17 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
|
||||
func (bcl *ModelConfigLoader) GetModelConfig(m string) (ModelConfig, bool) {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
v, exists := bcl.configs[m]
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
|
||||
func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
var res []BackendConfig
|
||||
var res []ModelConfig
|
||||
for _, v := range bcl.configs {
|
||||
res = append(res, v)
|
||||
}
|
||||
@@ -212,10 +212,10 @@ func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
|
||||
return res
|
||||
}
|
||||
|
||||
func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFilterFn) []BackendConfig {
|
||||
func (bcl *ModelConfigLoader) GetModelConfigsByFilter(filter ModelConfigFilterFn) []ModelConfig {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
var res []BackendConfig
|
||||
var res []ModelConfig
|
||||
|
||||
if filter == nil {
|
||||
filter = NoFilterFn
|
||||
@@ -232,14 +232,14 @@ func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFi
|
||||
return res
|
||||
}
|
||||
|
||||
func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) {
|
||||
func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
delete(bcl.configs, m)
|
||||
}
|
||||
|
||||
// Preload prepares models that are not local but are referenced by URL or Hugging Face repository
|
||||
func (bcl *BackendConfigLoader) Preload(modelPath string) error {
|
||||
func (bcl *ModelConfigLoader) Preload(modelPath string) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
|
||||
@@ -330,15 +330,15 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
|
||||
// LoadModelConfigsFromPath reads all the configurations of the models from a path
|
||||
// (non-recursive)
|
||||
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
|
||||
func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err)
|
||||
return fmt.Errorf("LoadModelConfigsFromPath cannot read directory '%s': %w", path, err)
|
||||
}
|
||||
files := make([]fs.FileInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -354,9 +354,9 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...
|
||||
strings.HasPrefix(file.Name(), ".") {
|
||||
continue
|
||||
}
|
||||
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
|
||||
c, err := readModelConfigFromFile(filepath.Join(path, file.Name()), opts...)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file")
|
||||
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadModelConfigsFromPath cannot read config file")
|
||||
continue
|
||||
}
|
||||
if c.Validate() {
|
||||
|
||||
@@ -25,7 +25,7 @@ known_usecases:
|
||||
- COMPLETION
|
||||
`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := readBackendConfigFromFile(tmp.Name())
|
||||
config, err := readModelConfigFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
Expect(config.Validate()).To(BeFalse())
|
||||
@@ -41,7 +41,7 @@ backend: "foo-bar"
|
||||
parameters:
|
||||
model: "foo-bar"`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := readBackendConfigFromFile(tmp.Name())
|
||||
config, err := readModelConfigFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
@@ -58,7 +58,7 @@ parameters:
|
||||
defer os.Remove(tmp.Name())
|
||||
_, err = io.Copy(tmp, resp.Body)
|
||||
Expect(err).To(BeNil())
|
||||
config, err = readBackendConfigFromFile(tmp.Name())
|
||||
config, err = readModelConfigFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
@@ -68,12 +68,12 @@ parameters:
|
||||
})
|
||||
It("Properly handles backend usecase matching", func() {
|
||||
|
||||
a := BackendConfig{
|
||||
a := ModelConfig{
|
||||
Name: "a",
|
||||
}
|
||||
Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially.
|
||||
|
||||
b := BackendConfig{
|
||||
b := ModelConfig{
|
||||
Name: "b",
|
||||
Backend: "stablediffusion",
|
||||
}
|
||||
@@ -81,7 +81,7 @@ parameters:
|
||||
Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue())
|
||||
Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
|
||||
c := BackendConfig{
|
||||
c := ModelConfig{
|
||||
Name: "c",
|
||||
Backend: "llama-cpp",
|
||||
TemplateConfig: TemplateConfig{
|
||||
@@ -93,7 +93,7 @@ parameters:
|
||||
Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
|
||||
Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
||||
|
||||
d := BackendConfig{
|
||||
d := ModelConfig{
|
||||
Name: "d",
|
||||
Backend: "llama-cpp",
|
||||
TemplateConfig: TemplateConfig{
|
||||
@@ -107,7 +107,7 @@ parameters:
|
||||
Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
||||
|
||||
trueValue := true
|
||||
e := BackendConfig{
|
||||
e := ModelConfig{
|
||||
Name: "e",
|
||||
Backend: "llama-cpp",
|
||||
TemplateConfig: TemplateConfig{
|
||||
@@ -122,7 +122,7 @@ parameters:
|
||||
Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue())
|
||||
|
||||
f := BackendConfig{
|
||||
f := ModelConfig{
|
||||
Name: "f",
|
||||
Backend: "piper",
|
||||
}
|
||||
@@ -130,7 +130,7 @@ parameters:
|
||||
Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue())
|
||||
Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
|
||||
g := BackendConfig{
|
||||
g := ModelConfig{
|
||||
Name: "g",
|
||||
Backend: "whisper",
|
||||
}
|
||||
@@ -138,7 +138,7 @@ parameters:
|
||||
Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
|
||||
Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse())
|
||||
|
||||
h := BackendConfig{
|
||||
h := ModelConfig{
|
||||
Name: "h",
|
||||
Backend: "transformers-musicgen",
|
||||
}
|
||||
@@ -148,7 +148,7 @@ parameters:
|
||||
Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue())
|
||||
|
||||
knownUsecases := FLAG_CHAT | FLAG_COMPLETION
|
||||
i := BackendConfig{
|
||||
i := ModelConfig{
|
||||
Name: "i",
|
||||
Backend: "whisper",
|
||||
// Earlier test checks parsing, this just needs to set final values
|
||||
|
||||
@@ -16,7 +16,7 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
Context("Test Read configuration functions", func() {
|
||||
configFile = os.Getenv("CONFIG_FILE")
|
||||
It("Test readConfigFile", func() {
|
||||
config, err := readMultipleBackendConfigsFromFile(configFile)
|
||||
config, err := readMultipleModelConfigsFromFile(configFile)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
@@ -26,11 +26,11 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
|
||||
It("Test LoadConfigs", func() {
|
||||
|
||||
bcl := NewBackendConfigLoader(os.Getenv("MODELS_PATH"))
|
||||
err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
|
||||
bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH"))
|
||||
err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH"))
|
||||
|
||||
Expect(err).To(BeNil())
|
||||
configs := bcl.GetAllBackendConfigs()
|
||||
configs := bcl.GetAllModelsConfigs()
|
||||
loadedModelNames := []string{}
|
||||
for _, v := range configs {
|
||||
loadedModelNames = append(loadedModelNames, v.Name)
|
||||
@@ -51,10 +51,10 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
|
||||
It("Test new loadconfig", func() {
|
||||
|
||||
bcl := NewBackendConfigLoader(os.Getenv("MODELS_PATH"))
|
||||
err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
|
||||
bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH"))
|
||||
err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH"))
|
||||
Expect(err).To(BeNil())
|
||||
configs := bcl.GetAllBackendConfigs()
|
||||
configs := bcl.GetAllModelsConfigs()
|
||||
loadedModelNames := []string{}
|
||||
for _, v := range configs {
|
||||
loadedModelNames = append(loadedModelNames, v.Name)
|
||||
@@ -90,14 +90,14 @@ options:
|
||||
err = os.WriteFile(modelFile, []byte(model), 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = bcl.LoadBackendConfigsFromPath(tmpdir)
|
||||
err = bcl.LoadModelConfigsFromPath(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
configs = bcl.GetAllBackendConfigs()
|
||||
configs = bcl.GetAllModelsConfigs()
|
||||
Expect(len(configs)).ToNot(Equal(totalModels))
|
||||
|
||||
loadedModelNames = []string{}
|
||||
var testModel BackendConfig
|
||||
var testModel ModelConfig
|
||||
for _, v := range configs {
|
||||
loadedModelNames = append(loadedModelNames, v.Name)
|
||||
if v.Name == "test-model" {
|
||||
|
||||
@@ -146,7 +146,7 @@ var knownTemplates = map[string]familyType{
|
||||
`{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}`: Mistral03,
|
||||
}
|
||||
|
||||
func guessGGUFFromFile(cfg *BackendConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
|
||||
if defaultCtx == 0 && cfg.ContextSize == nil {
|
||||
ctxSize := f.EstimateLLaMACppRun().ContextSize
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func guessDefaultsFromFile(cfg *BackendConfig, modelPath string, defaultCtx int) {
|
||||
func guessDefaultsFromFile(cfg *ModelConfig, modelPath string, defaultCtx int) {
|
||||
if os.Getenv("LOCALAI_DISABLE_GUESSING") == "true" {
|
||||
log.Debug().Msgf("guessDefaultsFromFile: %s", "guessing disabled with LOCALAI_DISABLE_GUESSING")
|
||||
return
|
||||
|
||||
@@ -59,10 +59,10 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
|
||||
}
|
||||
|
||||
// InstallBackendFromGallery installs a backend from the gallery
|
||||
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, basePath string, downloadStatus func(string, string, string, float64), force bool) error {
|
||||
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, name string, downloadStatus func(string, string, string, float64), force bool) error {
|
||||
if !force {
|
||||
// check if we already have the backend installed
|
||||
backends, err := ListSystemBackends(basePath)
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -77,12 +77,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
|
||||
|
||||
log.Debug().Interface("galleries", galleries).Str("name", name).Msg("Installing backend from gallery")
|
||||
|
||||
backends, err := AvailableBackends(galleries, basePath)
|
||||
backends, err := AvailableBackends(galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backend := FindGalleryElement(backends, name, basePath)
|
||||
backend := FindGalleryElement(backends, name)
|
||||
if backend == nil {
|
||||
return fmt.Errorf("no backend found with name %q", name)
|
||||
}
|
||||
@@ -99,12 +99,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
|
||||
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
|
||||
|
||||
// Then, let's install the best backend
|
||||
if err := InstallBackend(basePath, bestBackend, downloadStatus); err != nil {
|
||||
if err := InstallBackend(systemState, bestBackend, downloadStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we now need to create a path for the meta backend, with the alias to the installed one, so it can be used to remove it
|
||||
metaBackendPath := filepath.Join(basePath, name)
|
||||
metaBackendPath := filepath.Join(systemState.Backend.BackendsPath, name)
|
||||
if err := os.MkdirAll(metaBackendPath, 0750); err != nil {
|
||||
return fmt.Errorf("failed to create meta backend path %q: %v", metaBackendPath, err)
|
||||
}
|
||||
@@ -124,12 +124,12 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
|
||||
return nil
|
||||
}
|
||||
|
||||
return InstallBackend(basePath, backend, downloadStatus)
|
||||
return InstallBackend(systemState, backend, downloadStatus)
|
||||
}
|
||||
|
||||
func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
|
||||
func InstallBackend(systemState *system.SystemState, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(basePath, 0750)
|
||||
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create base path: %v", err)
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func
|
||||
}
|
||||
|
||||
name := config.Name
|
||||
backendPath := filepath.Join(basePath, name)
|
||||
backendPath := filepath.Join(systemState.Backend.BackendsPath, name)
|
||||
err = os.MkdirAll(backendPath, 0750)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create base path: %v", err)
|
||||
@@ -188,14 +188,28 @@ func InstallBackend(basePath string, config *GalleryBackend, downloadStatus func
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteBackendFromSystem(basePath string, name string) error {
|
||||
backendDirectory := filepath.Join(basePath, name)
|
||||
func DeleteBackendFromSystem(systemState *system.SystemState, name string) error {
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backend, ok := backends.Get(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("backend %q not found", name)
|
||||
}
|
||||
|
||||
if backend.IsSystem {
|
||||
return fmt.Errorf("system backend %q cannot be deleted", name)
|
||||
}
|
||||
|
||||
backendDirectory := filepath.Join(systemState.Backend.BackendsPath, name)
|
||||
|
||||
// check if the backend dir exists
|
||||
if _, err := os.Stat(backendDirectory); os.IsNotExist(err) {
|
||||
// if it doesn't exist, it might be an alias, so we need to check if we have a matching alias in
|
||||
// all the backends in the backends path
|
||||
backends, err := os.ReadDir(basePath)
|
||||
backends, err := os.ReadDir(systemState.Backend.BackendsPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -203,12 +217,12 @@ func DeleteBackendFromSystem(basePath string, name string) error {
|
||||
|
||||
for _, backend := range backends {
|
||||
if backend.IsDir() {
|
||||
metadata, err := readBackendMetadata(filepath.Join(basePath, backend.Name()))
|
||||
metadata, err := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, backend.Name()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if metadata != nil && metadata.Alias == name {
|
||||
backendDirectory = filepath.Join(basePath, backend.Name())
|
||||
backendDirectory = filepath.Join(systemState.Backend.BackendsPath, backend.Name())
|
||||
foundBackend = true
|
||||
break
|
||||
}
|
||||
@@ -228,7 +242,7 @@ func DeleteBackendFromSystem(basePath string, name string) error {
|
||||
}
|
||||
|
||||
if metadata != nil && metadata.MetaBackendFor != "" {
|
||||
metaBackendDirectory := filepath.Join(basePath, metadata.MetaBackendFor)
|
||||
metaBackendDirectory := filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor)
|
||||
log.Debug().Str("backendDirectory", metaBackendDirectory).Msg("Deleting meta backend")
|
||||
if _, err := os.Stat(metaBackendDirectory); os.IsNotExist(err) {
|
||||
return fmt.Errorf("meta backend %q not found", metadata.MetaBackendFor)
|
||||
@@ -243,6 +257,7 @@ type SystemBackend struct {
|
||||
Name string
|
||||
RunFile string
|
||||
IsMeta bool
|
||||
IsSystem bool
|
||||
Metadata *BackendMetadata
|
||||
}
|
||||
|
||||
@@ -266,30 +281,51 @@ func (b SystemBackends) GetAll() []SystemBackend {
|
||||
return backends
|
||||
}
|
||||
|
||||
func ListSystemBackends(basePath string) (SystemBackends, error) {
|
||||
potentialBackends, err := os.ReadDir(basePath)
|
||||
func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error) {
|
||||
potentialBackends, err := os.ReadDir(systemState.Backend.BackendsPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backends := make(SystemBackends)
|
||||
|
||||
systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath)
|
||||
if err == nil {
|
||||
// system backends are special, they are provided by the system and not managed by LocalAI
|
||||
for _, systemBackend := range systemBackends {
|
||||
if systemBackend.IsDir() {
|
||||
systemBackendRunFile := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile)
|
||||
if _, err := os.Stat(systemBackendRunFile); err == nil {
|
||||
backends[systemBackend.Name()] = SystemBackend{
|
||||
Name: systemBackend.Name(),
|
||||
RunFile: filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile),
|
||||
IsMeta: false,
|
||||
IsSystem: true,
|
||||
Metadata: nil,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Warn().Err(err).Msg("Failed to read system backends, but that's ok, we will just use the backends managed by LocalAI")
|
||||
}
|
||||
|
||||
for _, potentialBackend := range potentialBackends {
|
||||
if potentialBackend.IsDir() {
|
||||
potentialBackendRunFile := filepath.Join(basePath, potentialBackend.Name(), runFile)
|
||||
potentialBackendRunFile := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), runFile)
|
||||
|
||||
var metadata *BackendMetadata
|
||||
|
||||
// If metadata file does not exist, we just use the directory name
|
||||
// and we do not fill the other metadata (such as potential backend Aliases)
|
||||
metadataFilePath := filepath.Join(basePath, potentialBackend.Name(), metadataFile)
|
||||
metadataFilePath := filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name(), metadataFile)
|
||||
if _, err := os.Stat(metadataFilePath); os.IsNotExist(err) {
|
||||
metadata = &BackendMetadata{
|
||||
Name: potentialBackend.Name(),
|
||||
}
|
||||
} else {
|
||||
// Check for alias in metadata
|
||||
metadata, err = readBackendMetadata(filepath.Join(basePath, potentialBackend.Name()))
|
||||
metadata, err = readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, potentialBackend.Name()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -323,7 +359,7 @@ func ListSystemBackends(basePath string) (SystemBackends, error) {
|
||||
if metadata.MetaBackendFor != "" {
|
||||
backends[metadata.Name] = SystemBackend{
|
||||
Name: metadata.Name,
|
||||
RunFile: filepath.Join(basePath, metadata.MetaBackendFor, runFile),
|
||||
RunFile: filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile),
|
||||
IsMeta: true,
|
||||
Metadata: metadata,
|
||||
}
|
||||
@@ -334,8 +370,8 @@ func ListSystemBackends(basePath string) (SystemBackends, error) {
|
||||
return backends, nil
|
||||
}
|
||||
|
||||
func RegisterBackends(basePath string, modelLoader *model.ModelLoader) error {
|
||||
backends, err := ListSystemBackends(basePath)
|
||||
func RegisterBackends(systemState *system.SystemState, modelLoader *model.ModelLoader) error {
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -43,13 +43,21 @@ var _ = Describe("Gallery Backends", func() {
|
||||
|
||||
Describe("InstallBackendFromGallery", func() {
|
||||
It("should return error when backend is not found", func() {
|
||||
err := InstallBackendFromGallery(galleries, nil, "non-existent", tempDir, nil, true)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackendFromGallery(galleries, systemState, "non-existent", nil, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
|
||||
})
|
||||
|
||||
It("should install backend from gallery", func() {
|
||||
err := InstallBackendFromGallery(galleries, nil, "test-backend", tempDir, nil, true)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackendFromGallery(galleries, systemState, "test-backend", nil, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
|
||||
})
|
||||
@@ -220,26 +228,32 @@ var _ = Describe("Gallery Backends", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Test with NVIDIA system state
|
||||
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
|
||||
nvidiaSystemState := &system.SystemState{
|
||||
GPUVendor: "nvidia",
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
Expect(metaBackendPath).To(BeADirectory())
|
||||
|
||||
metaBackendPath = filepath.Join(tempDir, "meta-backend", "metadata.json")
|
||||
Expect(metaBackendPath).To(BeARegularFile())
|
||||
|
||||
concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
|
||||
Expect(concreteBackendPath).To(BeADirectory())
|
||||
|
||||
allBackends, err := ListSystemBackends(tempDir)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(allBackends.Exists("meta-backend")).To(BeTrue())
|
||||
Expect(allBackends.Exists("nvidia-backend")).To(BeTrue())
|
||||
|
||||
allBackends, err := ListSystemBackends(systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(allBackends).To(HaveKey("meta-backend"))
|
||||
Expect(allBackends).To(HaveKey("nvidia-backend"))
|
||||
|
||||
// Delete meta backend by name
|
||||
err = DeleteBackendFromSystem(tempDir, "meta-backend")
|
||||
err = DeleteBackendFromSystem(systemState, "meta-backend")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Verify meta backend directory is deleted
|
||||
@@ -294,8 +308,12 @@ var _ = Describe("Gallery Backends", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Test with NVIDIA system state
|
||||
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
|
||||
nvidiaSystemState := &system.SystemState{
|
||||
GPUVendor: "nvidia",
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
@@ -304,19 +322,22 @@ var _ = Describe("Gallery Backends", func() {
|
||||
concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
|
||||
Expect(concreteBackendPath).To(BeADirectory())
|
||||
|
||||
allBackends, err := ListSystemBackends(tempDir)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(allBackends.Exists("meta-backend")).To(BeTrue())
|
||||
Expect(allBackends.Exists("nvidia-backend")).To(BeTrue())
|
||||
|
||||
backend, ok := allBackends.Get("meta-backend")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(backend.Metadata.MetaBackendFor).To(Equal("nvidia-backend"))
|
||||
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh")))
|
||||
Expect(backend.IsMeta).To(BeTrue())
|
||||
allBackends, err := ListSystemBackends(systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(allBackends).To(HaveKey("meta-backend"))
|
||||
Expect(allBackends).To(HaveKey("nvidia-backend"))
|
||||
mback, exists := allBackends.Get("meta-backend")
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(mback.IsMeta).To(BeTrue())
|
||||
Expect(mback.Metadata.MetaBackendFor).To(Equal("nvidia-backend"))
|
||||
|
||||
// Delete meta backend by name
|
||||
err = DeleteBackendFromSystem(tempDir, "meta-backend")
|
||||
err = DeleteBackendFromSystem(systemState, "meta-backend")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Verify meta backend directory is deleted
|
||||
@@ -371,8 +392,12 @@ var _ = Describe("Gallery Backends", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Test with NVIDIA system state
|
||||
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
|
||||
nvidiaSystemState := &system.SystemState{
|
||||
GPUVendor: "nvidia",
|
||||
VRAM: 1000000000000,
|
||||
Backend: system.Backend{BackendsPath: tempDir},
|
||||
}
|
||||
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackendPath := filepath.Join(tempDir, "meta-backend")
|
||||
@@ -381,16 +406,21 @@ var _ = Describe("Gallery Backends", func() {
|
||||
concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
|
||||
Expect(concreteBackendPath).To(BeADirectory())
|
||||
|
||||
allBackends, err := ListSystemBackends(tempDir)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(allBackends.Exists("meta-backend")).To(BeTrue())
|
||||
Expect(allBackends.Exists("nvidia-backend")).To(BeTrue())
|
||||
backend, ok := allBackends.Get("meta-backend")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh")))
|
||||
|
||||
allBackends, err := ListSystemBackends(systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(allBackends).To(HaveKey("meta-backend"))
|
||||
Expect(allBackends).To(HaveKey("nvidia-backend"))
|
||||
mback, exists := allBackends.Get("meta-backend")
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(mback.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh")))
|
||||
|
||||
// Delete meta backend by name
|
||||
err = DeleteBackendFromSystem(tempDir, "meta-backend")
|
||||
err = DeleteBackendFromSystem(systemState, "meta-backend")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Verify meta backend directory is deleted
|
||||
@@ -427,25 +457,28 @@ var _ = Describe("Gallery Backends", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// List system backends
|
||||
backends, err := ListSystemBackends(tempDir)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
metaBackend, exists := backends.Get("meta-backend")
|
||||
concreteBackendRunFile := filepath.Join(tempDir, "concrete-backend", "run.sh")
|
||||
|
||||
// Should include both the meta backend name and concrete backend name
|
||||
Expect(backends.Exists("meta-backend")).To(BeTrue())
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(backends.Exists("concrete-backend")).To(BeTrue())
|
||||
|
||||
// meta-backend should point to concrete-backend
|
||||
Expect(backends.Exists("meta-backend")).To(BeTrue())
|
||||
backend, ok := backends.Get("meta-backend")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(backend.Metadata.MetaBackendFor).To(Equal("concrete-backend"))
|
||||
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "concrete-backend", "run.sh")))
|
||||
Expect(backend.IsMeta).To(BeTrue())
|
||||
|
||||
// meta-backend should be flagged as meta and point to the concrete backend's run.sh
|
||||
Expect(metaBackend.IsMeta).To(BeTrue())
|
||||
Expect(metaBackend.RunFile).To(Equal(concreteBackendRunFile))
|
||||
// concrete-backend should point to its own run.sh
|
||||
backend, ok = backends.Get("concrete-backend")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, "concrete-backend", "run.sh")))
|
||||
concreteBackend, exists := backends.Get("concrete-backend")
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(concreteBackend.RunFile).To(Equal(concreteBackendRunFile))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -459,11 +492,80 @@ var _ = Describe("Gallery Backends", func() {
|
||||
URI: "test-uri",
|
||||
}
|
||||
|
||||
err := InstallBackend(newPath, &backend, nil)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(newPath),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, &backend, nil)
|
||||
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
|
||||
Expect(newPath).To(BeADirectory())
|
||||
})
|
||||
|
||||
It("should overwrite existing backend", func() {
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||
Skip("Skipping test on darwin/arm64")
|
||||
}
|
||||
newPath := filepath.Join(tempDir, "test-backend")
|
||||
|
||||
// Create a dummy backend directory
|
||||
err := os.MkdirAll(newPath, 0750)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = os.WriteFile(filepath.Join(newPath, "metadata.json"), []byte("foo"), 0644)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = os.WriteFile(filepath.Join(newPath, "run.sh"), []byte(""), 0644)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
backend := GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "test-backend",
|
||||
},
|
||||
URI: "quay.io/mudler/tests:localai-backend-test",
|
||||
Alias: "test-alias",
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(dat)).ToNot(Equal("foo"))
|
||||
})
|
||||
|
||||
It("should overwrite existing backend", func() {
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||
Skip("Skipping test on darwin/arm64")
|
||||
}
|
||||
newPath := filepath.Join(tempDir, "test-backend")
|
||||
|
||||
// Create a dummy backend directory
|
||||
err := os.MkdirAll(newPath, 0750)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
backend := GalleryBackend{
|
||||
Metadata: Metadata{
|
||||
Name: "test-backend",
|
||||
},
|
||||
URI: "quay.io/mudler/tests:localai-backend-test",
|
||||
Alias: "test-alias",
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
|
||||
|
||||
err = InstallBackend(systemState, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
})
|
||||
|
||||
It("should create alias file when specified", func() {
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||
Skip("Skipping test on darwin/arm64")
|
||||
@@ -476,7 +578,11 @@ var _ = Describe("Gallery Backends", func() {
|
||||
Alias: "test-alias",
|
||||
}
|
||||
|
||||
err := InstallBackend(tempDir, &backend, nil)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = InstallBackend(systemState, &backend, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
|
||||
|
||||
@@ -492,16 +598,14 @@ var _ = Describe("Gallery Backends", func() {
|
||||
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
|
||||
|
||||
// Check that the alias was recognized
|
||||
backends, err := ListSystemBackends(tempDir)
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(backends.Exists("test-alias")).To(BeTrue())
|
||||
Expect(backends.Exists("test-backend")).To(BeTrue())
|
||||
b, ok := backends.Get("test-alias")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(b.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
|
||||
b, ok = backends.Get("test-backend")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(b.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
|
||||
aliasBackend, exists := backends.Get("test-alias")
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(aliasBackend.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
|
||||
testB, exists := backends.Get("test-backend")
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(testB.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -514,13 +618,26 @@ var _ = Describe("Gallery Backends", func() {
|
||||
err := os.MkdirAll(backendPath, 0750)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = DeleteBackendFromSystem(tempDir, backendName)
|
||||
err = os.WriteFile(filepath.Join(backendPath, "metadata.json"), []byte("{}"), 0644)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0644)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = DeleteBackendFromSystem(systemState, backendName)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(backendPath).NotTo(BeADirectory())
|
||||
})
|
||||
|
||||
It("should not error when backend doesn't exist", func() {
|
||||
err := DeleteBackendFromSystem(tempDir, "non-existent")
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = DeleteBackendFromSystem(systemState, "non-existent")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
@@ -538,14 +655,17 @@ var _ = Describe("Gallery Backends", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
|
||||
backends, err := ListSystemBackends(tempDir)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(backends.GetAll()).To(HaveLen(len(backendNames)))
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(backends).To(HaveLen(len(backendNames)))
|
||||
|
||||
for _, name := range backendNames {
|
||||
Expect(backends.Exists(name)).To(BeTrue())
|
||||
backend, ok := backends.Get(name)
|
||||
Expect(ok).To(BeTrue())
|
||||
backend, exists := backends.Get(name)
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, name, "run.sh")))
|
||||
}
|
||||
})
|
||||
@@ -572,16 +692,23 @@ var _ = Describe("Gallery Backends", func() {
|
||||
err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0755)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
backends, err := ListSystemBackends(tempDir)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(tempDir),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(backends.Exists(alias)).To(BeTrue())
|
||||
backend, ok := backends.Get(alias)
|
||||
Expect(ok).To(BeTrue())
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
backend, exists := backends.Get(alias)
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, backendName, "run.sh")))
|
||||
})
|
||||
|
||||
It("should return error when base path doesn't exist", func() {
|
||||
_, err := ListSystemBackends(filepath.Join(tempDir, "non-existent"))
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath("foobardir"),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = ListSystemBackends(systemState)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
@@ -89,7 +90,7 @@ func (gm GalleryElements[T]) Paginate(pageNum int, itemsNum int) GalleryElements
|
||||
return gm[start:end]
|
||||
}
|
||||
|
||||
func FindGalleryElement[T GalleryElement](models []T, name string, basePath string) T {
|
||||
func FindGalleryElement[T GalleryElement](models []T, name string) T {
|
||||
var model T
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
@@ -116,13 +117,13 @@ func FindGalleryElement[T GalleryElement](models []T, name string, basePath stri
|
||||
// List available models
|
||||
// Models galleries are a list of yaml files that are hosted on a remote server (for example github).
|
||||
// Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
|
||||
func AvailableGalleryModels(galleries []config.Gallery, basePath string) (GalleryElements[*GalleryModel], error) {
|
||||
func AvailableGalleryModels(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryModel], error) {
|
||||
var models []*GalleryModel
|
||||
|
||||
// Get models from galleries
|
||||
for _, gallery := range galleries {
|
||||
galleryModels, err := getGalleryElements[*GalleryModel](gallery, basePath, func(model *GalleryModel) bool {
|
||||
if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil {
|
||||
galleryModels, err := getGalleryElements[*GalleryModel](gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool {
|
||||
if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -137,13 +138,13 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) (Galler
|
||||
}
|
||||
|
||||
// List available backends
|
||||
func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryElements[*GalleryBackend], error) {
|
||||
var models []*GalleryBackend
|
||||
func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) {
|
||||
var backends []*GalleryBackend
|
||||
|
||||
// Get models from galleries
|
||||
// Get backends from galleries
|
||||
for _, gallery := range galleries {
|
||||
galleryModels, err := getGalleryElements[*GalleryBackend](gallery, basePath, func(backend *GalleryBackend) bool {
|
||||
backends, err := ListSystemBackends(basePath)
|
||||
galleryBackends, err := getGalleryElements[*GalleryBackend](gallery, systemState.Backend.BackendsPath, func(backend *GalleryBackend) bool {
|
||||
backends, err := ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -152,10 +153,10 @@ func AvailableBackends(galleries []config.Gallery, basePath string) (GalleryElem
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
models = append(models, galleryModels...)
|
||||
backends = append(backends, galleryBackends...)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
return backends, nil
|
||||
}
|
||||
|
||||
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
|
||||
|
||||
@@ -72,7 +72,8 @@ type PromptTemplate struct {
|
||||
// Installs a model from the gallery
|
||||
func InstallModelFromGallery(
|
||||
modelGalleries, backendGalleries []config.Gallery,
|
||||
name string, basePath, backendBasePath string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
|
||||
systemState *system.SystemState,
|
||||
name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
|
||||
|
||||
applyModel := func(model *GalleryModel) error {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
@@ -81,7 +82,7 @@ func InstallModelFromGallery(
|
||||
|
||||
if len(model.URL) > 0 {
|
||||
var err error
|
||||
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, basePath)
|
||||
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -122,19 +123,15 @@ func InstallModelFromGallery(
|
||||
return err
|
||||
}
|
||||
|
||||
installedModel, err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus, enforceScan)
|
||||
installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("Installed model %q", installedModel.Name)
|
||||
if automaticallyInstallBackend && installedModel.Backend != "" {
|
||||
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
|
||||
systemState, err := system.GetSystemState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, backendBasePath, downloadStatus, false); err != nil {
|
||||
if err := InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, downloadStatus, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -142,12 +139,12 @@ func InstallModelFromGallery(
|
||||
return nil
|
||||
}
|
||||
|
||||
models, err := AvailableGalleryModels(modelGalleries, basePath)
|
||||
models, err := AvailableGalleryModels(modelGalleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
model := FindGalleryElement(models, name, basePath)
|
||||
model := FindGalleryElement(models, name)
|
||||
if model == nil {
|
||||
return fmt.Errorf("no model found with name %q", name)
|
||||
}
|
||||
@@ -155,7 +152,8 @@ func InstallModelFromGallery(
|
||||
return applyModel(model)
|
||||
}
|
||||
|
||||
func InstallModel(basePath, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.BackendConfig, error) {
|
||||
func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
|
||||
basePath := systemState.Model.ModelsPath
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(basePath, 0750)
|
||||
if err != nil {
|
||||
@@ -221,7 +219,7 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backendConfig := lconfig.BackendConfig{}
|
||||
modelConfig := lconfig.ModelConfig{}
|
||||
|
||||
// write config file
|
||||
if len(configOverrides) != 0 || len(config.ConfigFile) != 0 {
|
||||
@@ -246,12 +244,12 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver
|
||||
return nil, fmt.Errorf("failed to marshal updated config YAML: %v", err)
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(updatedConfigYAML, &backendConfig)
|
||||
err = yaml.Unmarshal(updatedConfigYAML, &modelConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err)
|
||||
}
|
||||
|
||||
if !backendConfig.Validate() {
|
||||
if !modelConfig.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate updated config YAML")
|
||||
}
|
||||
|
||||
@@ -272,7 +270,7 @@ func InstallModel(basePath, nameOverride string, config *ModelConfig, configOver
|
||||
|
||||
log.Debug().Msgf("Written gallery file %s", modelFile)
|
||||
|
||||
return &backendConfig, os.WriteFile(modelFile, data, 0600)
|
||||
return &modelConfig, os.WriteFile(modelFile, data, 0600)
|
||||
}
|
||||
|
||||
func galleryFileName(name string) string {
|
||||
@@ -285,21 +283,39 @@ func GetLocalModelConfiguration(basePath string, name string) (*ModelConfig, err
|
||||
return ReadConfigFile[ModelConfig](galleryFile)
|
||||
}
|
||||
|
||||
func DeleteModelFromSystem(basePath string, name string, additionalFiles []string) error {
|
||||
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
|
||||
additionalFiles := []string{}
|
||||
|
||||
configFile := filepath.Join(basePath, fmt.Sprintf("%s.yaml", name))
|
||||
configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name))
|
||||
if err := utils.VerifyPath(configFile, systemState.Model.ModelsPath); err != nil {
|
||||
return fmt.Errorf("failed to verify path %s: %w", configFile, err)
|
||||
}
|
||||
// Galleryname is the name of the model in this case
|
||||
dat, err := os.ReadFile(configFile)
|
||||
if err == nil {
|
||||
modelConfig := &config.ModelConfig{}
|
||||
|
||||
galleryFile := filepath.Join(basePath, galleryFileName(name))
|
||||
err = yaml.Unmarshal(dat, &modelConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if modelConfig.Model != "" {
|
||||
additionalFiles = append(additionalFiles, modelConfig.ModelFileName())
|
||||
}
|
||||
|
||||
for _, f := range []string{configFile, galleryFile} {
|
||||
if err := utils.VerifyPath(f, basePath); err != nil {
|
||||
return fmt.Errorf("failed to verify path %s: %w", f, err)
|
||||
if modelConfig.MMProj != "" {
|
||||
additionalFiles = append(additionalFiles, modelConfig.MMProjFileName())
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
galleryFile := filepath.Join(systemState.Model.ModelsPath, galleryFileName(name))
|
||||
if err := utils.VerifyPath(galleryFile, systemState.Model.ModelsPath); err != nil {
|
||||
return fmt.Errorf("failed to verify path %s: %w", galleryFile, err)
|
||||
}
|
||||
|
||||
// Delete all the files associated to the model
|
||||
// read the model config
|
||||
galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile)
|
||||
@@ -312,13 +328,19 @@ func DeleteModelFromSystem(basePath string, name string, additionalFiles []strin
|
||||
// Remove additional files
|
||||
if galleryconfig != nil {
|
||||
for _, f := range galleryconfig.Files {
|
||||
fullPath := filepath.Join(basePath, f.Filename)
|
||||
fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename)
|
||||
if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
|
||||
return fmt.Errorf("failed to verify path %s: %w", fullPath, err)
|
||||
}
|
||||
filesToRemove = append(filesToRemove, fullPath)
|
||||
}
|
||||
}
|
||||
|
||||
for _, f := range additionalFiles {
|
||||
fullPath := filepath.Join(filepath.Join(basePath, f))
|
||||
fullPath := filepath.Join(filepath.Join(systemState.Model.ModelsPath, f))
|
||||
if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
|
||||
return fmt.Errorf("failed to verify path %s: %w", fullPath, err)
|
||||
}
|
||||
filesToRemove = append(filesToRemove, fullPath)
|
||||
}
|
||||
|
||||
@@ -340,8 +362,8 @@ func DeleteModelFromSystem(basePath string, name string, additionalFiles []strin
|
||||
|
||||
// This is ***NEVER*** going to be perfect or finished.
|
||||
// This is a BEST EFFORT function to surface known-vulnerable models to users.
|
||||
func SafetyScanGalleryModels(galleries []config.Gallery, basePath string) error {
|
||||
galleryModels, err := AvailableGalleryModels(galleries, basePath)
|
||||
func SafetyScanGalleryModels(galleries []config.Gallery, systemState *system.SystemState) error {
|
||||
galleryModels, err := AvailableGalleryModels(galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -29,7 +30,11 @@ var _ = Describe("Model test", func() {
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
||||
@@ -71,15 +76,19 @@ var _ = Describe("Model test", func() {
|
||||
URL: "file://" + galleryFilePath,
|
||||
},
|
||||
}
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
models, err := AvailableGalleryModels(galleries, tempdir)
|
||||
models, err := AvailableGalleryModels(galleries, systemState)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models)).To(Equal(1))
|
||||
Expect(models[0].Name).To(Equal("bert"))
|
||||
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
|
||||
Expect(models[0].Installed).To(BeFalse())
|
||||
|
||||
err = InstallModelFromGallery(galleries, []config.Gallery{}, "test@bert", tempdir, "", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
|
||||
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
||||
@@ -90,16 +99,16 @@ var _ = Describe("Model test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
|
||||
|
||||
models, err = AvailableGalleryModels(galleries, tempdir)
|
||||
models, err = AvailableGalleryModels(galleries, systemState)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models)).To(Equal(1))
|
||||
Expect(models[0].Installed).To(BeTrue())
|
||||
|
||||
// delete
|
||||
err = DeleteModelFromSystem(tempdir, "bert", []string{})
|
||||
err = DeleteModelFromSystem(systemState, "bert")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
models, err = AvailableGalleryModels(galleries, tempdir)
|
||||
models, err = AvailableGalleryModels(galleries, systemState)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models)).To(Equal(1))
|
||||
Expect(models[0].Installed).To(BeFalse())
|
||||
@@ -116,7 +125,11 @@ var _ = Describe("Model test", func() {
|
||||
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||
@@ -132,7 +145,11 @@ var _ = Describe("Model test", func() {
|
||||
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||
@@ -158,7 +175,11 @@ var _ = Describe("Model test", func() {
|
||||
c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(tempdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -198,7 +198,7 @@ func API(application *application.Application) (*fiber.App, error) {
|
||||
}
|
||||
|
||||
galleryService := services.NewGalleryService(application.ApplicationConfig(), application.ModelLoader())
|
||||
err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
|
||||
err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader(), application.ApplicationConfig().SystemState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -320,12 +321,17 @@ var _ = Describe("API test", func() {
|
||||
},
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelDir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithGalleries(galleries),
|
||||
config.WithModelPath(modelDir),
|
||||
config.WithBackendsPath(backendPath),
|
||||
config.WithApiKeys([]string{apiKey}),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -523,13 +529,18 @@ var _ = Describe("API test", func() {
|
||||
},
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelDir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithGeneratedContentDir(tmpdir),
|
||||
config.WithBackendsPath(backendPath),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithGalleries(galleries),
|
||||
config.WithModelPath(modelDir),
|
||||
)...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -729,12 +740,17 @@ var _ = Describe("API test", func() {
|
||||
|
||||
var err error
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelPath),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithExternalBackend("transformers", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
config.WithContext(c),
|
||||
config.WithBackendsPath(backendPath),
|
||||
config.WithModelPath(modelPath),
|
||||
config.WithSystemState(systemState),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = API(application)
|
||||
@@ -960,11 +976,17 @@ var _ = Describe("API test", func() {
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
var err error
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelPath),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithModelPath(modelPath),
|
||||
config.WithBackendsPath(backendPath),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/sound-generation [post]
|
||||
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
|
||||
@@ -23,7 +23,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
// @Param request body schema.TTSRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/text-to-speech/{voice-id} [post]
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
voiceID := c.Params("voice-id")
|
||||
@@ -27,7 +27,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
// @Param request body schema.JINARerankRequest true "query params"
|
||||
// @Success 200 {object} schema.JINARerankResponse "Response"
|
||||
// @Router /v1/rerank [post]
|
||||
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
|
||||
@@ -25,7 +25,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -11,24 +11,27 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type BackendEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendPath string
|
||||
backendApplier *services.GalleryService
|
||||
galleries []config.Gallery
|
||||
backendPath string
|
||||
backendSystemPath string
|
||||
backendApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryBackend struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, backendPath string, backendApplier *services.GalleryService) BackendEndpointService {
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService {
|
||||
return BackendEndpointService{
|
||||
galleries: galleries,
|
||||
backendPath: backendPath,
|
||||
backendApplier: backendApplier,
|
||||
galleries: galleries,
|
||||
backendPath: systemState.Backend.BackendsPath,
|
||||
backendSystemPath: systemState.Backend.BackendsSystemPath,
|
||||
backendApplier: backendApplier,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,9 +114,9 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
|
||||
// @Summary List all Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends [get]
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backends, err := gallery.ListSystemBackends(mgs.backendPath)
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -141,9 +144,9 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
|
||||
// @Summary List all available Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends/available [get]
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backends, err := gallery.AvailableBackends(mgs.galleries, mgs.backendPath)
|
||||
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
// @Param request body schema.DetectionRequest true "query params"
|
||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||
// @Router /v1/detection [post]
|
||||
func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
@@ -24,7 +24,7 @@ func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, ap
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -26,11 +27,11 @@ type GalleryModel struct {
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: modelPath,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
@@ -115,10 +116,10 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return err
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/tokenMetrics [get]
|
||||
// @Router /tokenMetrics [get]
|
||||
func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.TokenMetricsRequest)
|
||||
@@ -37,7 +37,7 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
|
||||
cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(modelFile, appConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
|
||||
@@ -14,14 +14,14 @@ import (
|
||||
// @Param request body schema.TokenizeRequest true "Request"
|
||||
// @Success 200 {object} schema.TokenizeResponse "Response"
|
||||
// @Router /v1/tokenize [post]
|
||||
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -22,14 +22,14 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/audio/speech [post]
|
||||
// @Router /tts [post]
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -16,14 +16,14 @@ import (
|
||||
// @Param request body schema.VADRequest true "query params"
|
||||
// @Success 200 {object} proto.VADResponse "Response"
|
||||
// @Router /vad [post]
|
||||
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /video [post]
|
||||
func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
|
||||
if !ok || input.Model == "" {
|
||||
@@ -72,7 +72,7 @@ func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Video Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
)
|
||||
|
||||
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
cl *config.BackendConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
|
||||
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
for _, m := range backendConfigs {
|
||||
for _, m := range modelConfigs {
|
||||
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -34,7 +34,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
"Version": internal.PrintableVersion(),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"Models": modelsWithoutConfig,
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ApplicationConfig": appConfig,
|
||||
"ProcessingModels": processingModels,
|
||||
|
||||
@@ -27,11 +27,11 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
var id, textContentToReturn string
|
||||
var created int
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
@@ -66,7 +66,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
result := ""
|
||||
_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
@@ -183,7 +183,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
@@ -501,7 +501,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
}
|
||||
}
|
||||
|
||||
func handleQuestion(config *config.BackendConfig, cl *config.BackendConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
|
||||
func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
|
||||
|
||||
if len(funcResults) == 0 && result != "" {
|
||||
log.Debug().Msgf("nothing function results but we had a message from the LLM")
|
||||
|
||||
@@ -27,10 +27,10 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
@@ -73,7 +73,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/edits [post]
|
||||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
@@ -34,7 +34,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -21,14 +21,14 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/embeddings [post]
|
||||
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/images/generations [post]
|
||||
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
@@ -73,7 +73,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Image Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
func ComputeChoices(
|
||||
req *schema.OpenAIRequest,
|
||||
predInput string,
|
||||
config *config.BackendConfig,
|
||||
bcl *config.BackendConfigLoader,
|
||||
config *config.ModelConfig,
|
||||
bcl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// @Summary List and describe the various models available in the API.
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// If blank, no filter is applied.
|
||||
filter := c.Query("filter")
|
||||
|
||||
@@ -559,7 +559,7 @@ func sendNotImplemented(c *websocket.Conn, message string) {
|
||||
sendError(c, "not_implemented", message, "", "event_TODO")
|
||||
}
|
||||
|
||||
func updateTransSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
func updateTransSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
sessionLock.Lock()
|
||||
defer sessionLock.Unlock()
|
||||
|
||||
@@ -589,7 +589,7 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi
|
||||
}
|
||||
|
||||
// Function to update session configurations
|
||||
func updateSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
func updateSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
sessionLock.Lock()
|
||||
defer sessionLock.Unlock()
|
||||
|
||||
@@ -628,7 +628,7 @@ func updateSession(session *Session, update *types.ClientSession, cl *config.Bac
|
||||
|
||||
// handleVAD is a goroutine that listens for audio data from the client,
|
||||
// runs VAD on the audio data, and commits utterances to the conversation
|
||||
func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||
func handleVAD(cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||
vadContext, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-done
|
||||
@@ -742,7 +742,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
|
||||
}
|
||||
}
|
||||
|
||||
func commitUtterance(ctx context.Context, utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||
func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||
if len(utt) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -853,7 +853,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS
|
||||
|
||||
// TODO: Below needed for normal mode instead of transcription only
|
||||
// Function to generate a response based on the conversation
|
||||
// func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
|
||||
// func generateResponse(config *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
|
||||
//
|
||||
// log.Debug().Msg("Generating realtime response...")
|
||||
//
|
||||
@@ -1067,7 +1067,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS
|
||||
// }
|
||||
|
||||
// Function to process text response and detect function calls
|
||||
func processTextResponse(config *config.BackendConfig, session *Session, prompt string) (string, *FunctionCall, error) {
|
||||
func processTextResponse(config *config.ModelConfig, session *Session, prompt string) (string, *FunctionCall, error) {
|
||||
|
||||
// Placeholder implementation
|
||||
// Replace this with actual model inference logic using session.Model and prompt
|
||||
|
||||
@@ -22,14 +22,14 @@ var (
|
||||
// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
|
||||
// which are for Any-To-Any models, but instead we will call a pipeline (e.g. STT->LLM->TTS)
|
||||
type wrappedModel struct {
|
||||
TTSConfig *config.BackendConfig
|
||||
TranscriptionConfig *config.BackendConfig
|
||||
LLMConfig *config.BackendConfig
|
||||
TTSConfig *config.ModelConfig
|
||||
TranscriptionConfig *config.ModelConfig
|
||||
LLMConfig *config.ModelConfig
|
||||
TTSClient grpcClient.Backend
|
||||
TranscriptionClient grpcClient.Backend
|
||||
LLMClient grpcClient.Backend
|
||||
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
@@ -37,17 +37,17 @@ type wrappedModel struct {
|
||||
// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model.
|
||||
// In the future there could be models that accept continuous audio input only so this design will be useful for that
|
||||
type anyToAnyModel struct {
|
||||
LLMConfig *config.BackendConfig
|
||||
LLMConfig *config.ModelConfig
|
||||
LLMClient grpcClient.Backend
|
||||
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
type transcriptOnlyModel struct {
|
||||
TranscriptionConfig *config.BackendConfig
|
||||
TranscriptionConfig *config.ModelConfig
|
||||
TranscriptionClient grpcClient.Backend
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
@@ -105,8 +105,8 @@ func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOpti
|
||||
return m.LLMClient.PredictStream(ctx, in, f)
|
||||
}
|
||||
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.BackendConfig, error) {
|
||||
cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -122,7 +122,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf
|
||||
return nil, nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||
}
|
||||
|
||||
cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -139,17 +139,17 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf
|
||||
}
|
||||
|
||||
return &transcriptOnlyModel{
|
||||
VADConfig: cfgVAD,
|
||||
VADClient: VADClient,
|
||||
VADConfig: cfgVAD,
|
||||
VADClient: VADClient,
|
||||
TranscriptionConfig: cfgSST,
|
||||
TranscriptionClient: transcriptionClient,
|
||||
}, cfgSST, nil
|
||||
}
|
||||
|
||||
// returns and loads either a wrapped model or a model that supports audio-to-audio
|
||||
func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
|
||||
func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
|
||||
|
||||
cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -166,7 +166,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
}
|
||||
|
||||
// TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process
|
||||
cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -185,7 +185,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
// TODO: Decide when we have a real any-to-any model
|
||||
if false {
|
||||
|
||||
cfgAnyToAny, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -212,7 +212,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
log.Debug().Msg("Loading a wrapped model")
|
||||
|
||||
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||
cfgLLM, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -222,7 +222,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
cfgTTS, err := cl.LoadBackendConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -232,7 +232,6 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
|
||||
opts = backend.ModelOptions(*cfgTTS, appConfig)
|
||||
ttsClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -24,14 +24,14 @@ import (
|
||||
// @Param file formData file true "file"
|
||||
// @Success 200 {object} map[string]string "Response"
|
||||
// @Router /v1/audio/transcriptions [post]
|
||||
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -26,16 +26,16 @@ type correlationIDKeyType string
|
||||
const CorrelationIDKey correlationIDKeyType = "correlationID"
|
||||
|
||||
type RequestExtractor struct {
|
||||
backendConfigLoader *config.BackendConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
modelConfigLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
return &RequestExtractor{
|
||||
backendConfigLoader: backendConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
modelConfigLoader: modelConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
|
||||
if bearer != "" {
|
||||
exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
if err == nil && exists {
|
||||
model = bearer
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModel
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler {
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
re.setModelNameFromRequest(ctx)
|
||||
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
@@ -89,7 +89,7 @@ func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn con
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
|
||||
return ctx.Next()
|
||||
@@ -129,7 +129,7 @@ func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIR
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
@@ -152,7 +152,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
err := mergeOpenAIRequestAndBackendConfig(cfg, input)
|
||||
err := mergeOpenAIRequestAndModelConfig(cfg, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error {
|
||||
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
func RegisterElevenLabsRoutes(app *fiber.App,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
func RegisterJINARoutes(app *fiber.App,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
func RegisterLocalAIRoutes(router *fiber.App,
|
||||
requestExtractor *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
galleryService *services.GalleryService) {
|
||||
@@ -23,20 +23,23 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||
|
||||
// LocalAI API endpoints
|
||||
if !appConfig.DisableGalleryEndpoint {
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.ModelPath, galleryService)
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
|
||||
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
|
||||
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
|
||||
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
|
||||
backendGalleryEndpointService := localai.CreateBackendEndpointService(appConfig.BackendGalleries, appConfig.BackendsPath, galleryService)
|
||||
backendGalleryEndpointService := localai.CreateBackendEndpointService(
|
||||
appConfig.BackendGalleries,
|
||||
appConfig.SystemState,
|
||||
galleryService)
|
||||
router.Post("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
|
||||
router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
|
||||
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint())
|
||||
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint())
|
||||
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
|
||||
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func RegisterUIRoutes(app *fiber.App,
|
||||
cl *config.BackendConfigLoader,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
galleryService *services.GalleryService) {
|
||||
@@ -65,9 +65,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
app.Get("/talk/", func(c *fiber.Ctx) error {
|
||||
backendConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||
modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||
|
||||
if len(backendConfigs) == 0 {
|
||||
if len(modelConfigs) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
}
|
||||
@@ -75,8 +75,8 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Talk",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"Model": backendConfigs[0],
|
||||
"ModelsConfig": modelConfigs,
|
||||
"Model": modelConfigs[0],
|
||||
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
@@ -86,17 +86,17 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/chat/", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
}
|
||||
modelThatCanBeUsed := ""
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
for _, m := range backendConfigs {
|
||||
for _, m := range modelConfigs {
|
||||
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -106,7 +106,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
|
||||
title := "LocalAI - Chat"
|
||||
|
||||
for _, b := range backendConfigs {
|
||||
for _, b := range modelConfigs {
|
||||
if b.HasUsecases(config.FLAG_CHAT) {
|
||||
modelThatCanBeUsed = b.Name
|
||||
title = "LocalAI - Chat with " + modelThatCanBeUsed
|
||||
@@ -119,7 +119,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"Model": modelThatCanBeUsed,
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
@@ -130,12 +130,12 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
|
||||
// Show the Chat page
|
||||
app.Get("/chat/:model", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
for _, m := range backendConfigs {
|
||||
for _, m := range modelConfigs {
|
||||
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -146,7 +146,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Chat with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
@@ -158,13 +158,13 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Generate images with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
"Version": internal.PrintableVersion(),
|
||||
@@ -175,10 +175,10 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/text2image/", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
}
|
||||
@@ -186,7 +186,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
modelThatCanBeUsed := ""
|
||||
title := "LocalAI - Generate images"
|
||||
|
||||
for _, b := range backendConfigs {
|
||||
for _, b := range modelConfigs {
|
||||
if b.HasUsecases(config.FLAG_IMAGE) {
|
||||
modelThatCanBeUsed = b.Name
|
||||
title = "LocalAI - Generate images with " + modelThatCanBeUsed
|
||||
@@ -197,7 +197,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": modelThatCanBeUsed,
|
||||
"Version": internal.PrintableVersion(),
|
||||
@@ -208,13 +208,13 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/tts/:model", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Generate images with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
"Version": internal.PrintableVersion(),
|
||||
@@ -225,10 +225,10 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/tts/", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
}
|
||||
@@ -236,7 +236,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
modelThatCanBeUsed := ""
|
||||
title := "LocalAI - Generate audio"
|
||||
|
||||
for _, b := range backendConfigs {
|
||||
for _, b := range modelConfigs {
|
||||
if b.HasUsecases(config.FLAG_TTS) {
|
||||
modelThatCanBeUsed = b.Name
|
||||
title = "LocalAI - Generate audio with " + modelThatCanBeUsed
|
||||
@@ -246,7 +246,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": modelThatCanBeUsed,
|
||||
"Version": internal.PrintableVersion(),
|
||||
|
||||
@@ -28,7 +28,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC
|
||||
page := c.Query("page")
|
||||
items := c.Query("items")
|
||||
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.BackendsPath)
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list backends from galleries")
|
||||
return c.Status(fiber.StatusInternalServerError).Render("views/error", fiber.Map{
|
||||
@@ -129,7 +129,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC
|
||||
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.BackendsPath)
|
||||
backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
|
||||
if page != "" {
|
||||
// return a subset of the backends
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
|
||||
// Show the Models page (all models)
|
||||
app.Get("/browse", func(c *fiber.Ctx) error {
|
||||
@@ -28,7 +28,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
|
||||
page := c.Query("page")
|
||||
items := c.Query("items")
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
|
||||
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return c.Status(fiber.StatusInternalServerError).Render("views/error", fiber.Map{
|
||||
@@ -131,7 +131,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
|
||||
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
|
||||
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
|
||||
|
||||
if page != "" {
|
||||
// return a subset of the models
|
||||
@@ -224,7 +224,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
|
||||
}
|
||||
go func() {
|
||||
galleryService.ModelGalleryChannel <- op
|
||||
cl.RemoveBackendConfig(galleryName)
|
||||
cl.RemoveModelConfig(galleryName)
|
||||
}()
|
||||
|
||||
return c.SendString(elements.StartModelProgressBar(uid, "0", "Deletion"))
|
||||
|
||||
@@ -16,21 +16,21 @@ import (
|
||||
)
|
||||
|
||||
type BackendMonitorService struct {
|
||||
backendConfigLoader *config.BackendConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
||||
modelConfigLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
||||
}
|
||||
|
||||
func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService {
|
||||
func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.ModelConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService {
|
||||
return &BackendMonitorService{
|
||||
modelLoader: modelLoader,
|
||||
backendConfigLoader: configLoader,
|
||||
options: appConfig,
|
||||
modelLoader: modelLoader,
|
||||
modelConfigLoader: configLoader,
|
||||
options: appConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) {
|
||||
config, exists := bms.backendConfigLoader.GetBackendConfig(modelName)
|
||||
config, exists := bms.modelConfigLoader.GetModelConfig(modelName)
|
||||
var backendId string
|
||||
if exists {
|
||||
backendId = config.Model
|
||||
@@ -47,7 +47,7 @@ func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string)
|
||||
}
|
||||
|
||||
func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
|
||||
config, exists := bms.backendConfigLoader.GetBackendConfig(model)
|
||||
config, exists := bms.modelConfigLoader.GetModelConfig(model)
|
||||
var backend string
|
||||
if exists {
|
||||
backend = config.Model
|
||||
|
||||
@@ -20,21 +20,21 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend], s
|
||||
|
||||
var err error
|
||||
if op.Delete {
|
||||
err = gallery.DeleteBackendFromSystem(g.appConfig.BackendsPath, op.GalleryElementName)
|
||||
err = gallery.DeleteBackendFromSystem(g.appConfig.SystemState, op.GalleryElementName)
|
||||
g.modelLoader.DeleteExternalBackend(op.GalleryElementName)
|
||||
} else {
|
||||
log.Warn().Msgf("installing backend %s", op.GalleryElementName)
|
||||
log.Debug().Msgf("backend galleries: %v", g.appConfig.BackendGalleries)
|
||||
err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, op.GalleryElementName, g.appConfig.BackendsPath, progressCallback, true)
|
||||
err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, op.GalleryElementName, progressCallback, true)
|
||||
if err == nil {
|
||||
err = gallery.RegisterBackends(g.appConfig.BackendsPath, g.modelLoader)
|
||||
err = gallery.RegisterBackends(systemState, g.modelLoader)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("error installing backend %s", op.GalleryElementName)
|
||||
if !op.Delete {
|
||||
// If we didn't install the backend, we need to make sure we don't have a leftover directory
|
||||
gallery.DeleteBackendFromSystem(g.appConfig.BackendsPath, op.GalleryElementName)
|
||||
gallery.DeleteBackendFromSystem(systemState, op.GalleryElementName)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type GalleryService struct {
|
||||
@@ -52,7 +51,7 @@ func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus {
|
||||
return g.statuses
|
||||
}
|
||||
|
||||
func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader) error {
|
||||
func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error {
|
||||
// updates the status with an error
|
||||
var updateError func(id string, e error)
|
||||
if !g.appConfig.OpaqueErrors {
|
||||
@@ -65,11 +64,6 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
|
||||
}
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to get system state")
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
@@ -82,7 +76,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
|
||||
}
|
||||
|
||||
case op := <-g.ModelGalleryChannel:
|
||||
err := g.modelHandler(&op, cl)
|
||||
err := g.modelHandler(&op, cl, systemState)
|
||||
if err != nil {
|
||||
updateError(op.ID, err)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ const (
|
||||
ALWAYS_INCLUDE
|
||||
)
|
||||
|
||||
func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter config.BackendConfigFilterFn, looseFilePolicy LooseFilePolicy) ([]string, error) {
|
||||
func ListModels(bcl *config.ModelConfigLoader, ml *model.ModelLoader, filter config.ModelConfigFilterFn, looseFilePolicy LooseFilePolicy) ([]string, error) {
|
||||
|
||||
var skipMap map[string]interface{} = map[string]interface{}{}
|
||||
|
||||
@@ -22,7 +22,7 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c
|
||||
|
||||
// Start with known configurations
|
||||
|
||||
for _, c := range bcl.GetBackendConfigsByFilter(filter) {
|
||||
for _, c := range bcl.GetModelConfigsByFilter(filter) {
|
||||
// Is this better than looseFilePolicy <= SKIP_IF_CONFIGURED ? less performant but more readable?
|
||||
if (looseFilePolicy == SKIP_IF_CONFIGURED) || (looseFilePolicy == LOOSE_ONLY) {
|
||||
skipMap[c.Model] = nil
|
||||
@@ -50,7 +50,7 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c
|
||||
return dataModels, nil
|
||||
}
|
||||
|
||||
func CheckIfModelExists(bcl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) {
|
||||
func CheckIfModelExists(bcl *config.ModelConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) {
|
||||
filter, err := config.BuildNameFilterFn(modelName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -3,7 +3,6 @@ package services
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
@@ -12,7 +11,7 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *config.BackendConfigLoader) error {
|
||||
func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *config.ModelConfigLoader, systemState *system.SystemState) error {
|
||||
utils.ResetDownloadTimers()
|
||||
|
||||
g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0})
|
||||
@@ -23,18 +22,18 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *c
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||
}
|
||||
|
||||
err := processModelOperation(op, g.appConfig.ModelPath, g.appConfig.BackendsPath, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback)
|
||||
err := processModelOperation(op, systemState, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Reload models
|
||||
err = cl.LoadBackendConfigsFromPath(g.appConfig.ModelPath)
|
||||
err = cl.LoadModelConfigsFromPath(systemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = cl.Preload(g.appConfig.ModelPath)
|
||||
err = cl.Preload(systemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -50,26 +49,21 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel], cl *c
|
||||
return nil
|
||||
}
|
||||
|
||||
func installModelFromRemoteConfig(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery, backendBasePath string) error {
|
||||
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, modelPath)
|
||||
func installModelFromRemoteConfig(systemState *system.SystemState, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error {
|
||||
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, systemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
|
||||
installedModel, err := gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
|
||||
installedModel, err := gallery.InstallModel(systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if automaticallyInstallBackend && installedModel.Backend != "" {
|
||||
systemState, err := system.GetSystemState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, backendBasePath, downloadStatus, false); err != nil {
|
||||
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, installedModel.Backend, downloadStatus, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -82,22 +76,22 @@ type galleryModel struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func processRequests(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error {
|
||||
func processRequests(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error {
|
||||
var err error
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = installModelFromRemoteConfig(modelPath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries, backendBasePath)
|
||||
err = installModelFromRemoteConfig(systemState, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries)
|
||||
|
||||
} else {
|
||||
err = gallery.InstallModelFromGallery(
|
||||
galleries, backendGalleries, r.ID, modelPath, backendBasePath, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
|
||||
galleries, backendGalleries, systemState, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ApplyGalleryFromFile(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error {
|
||||
func ApplyGalleryFromFile(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error {
|
||||
dat, err := os.ReadFile(s)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -108,58 +102,35 @@ func ApplyGalleryFromFile(modelPath, backendBasePath string, enforceScan, automa
|
||||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, backendBasePath, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests)
|
||||
return processRequests(systemState, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests)
|
||||
}
|
||||
|
||||
func ApplyGalleryFromString(modelPath, backendBasePath string, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error {
|
||||
func ApplyGalleryFromString(systemState *system.SystemState, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, s string) error {
|
||||
var requests []galleryModel
|
||||
err := json.Unmarshal([]byte(s), &requests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, backendBasePath, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests)
|
||||
return processRequests(systemState, enforceScan, automaticallyInstallBackend, galleries, backendGalleries, requests)
|
||||
}
|
||||
|
||||
// processModelOperation handles the installation or deletion of a model
|
||||
func processModelOperation(
|
||||
op *GalleryOp[gallery.GalleryModel],
|
||||
modelPath string,
|
||||
backendBasePath string,
|
||||
systemState *system.SystemState,
|
||||
enforcePredownloadScans bool,
|
||||
automaticallyInstallBackend bool,
|
||||
progressCallback func(string, string, string, float64),
|
||||
) error {
|
||||
// delete a model
|
||||
if op.Delete {
|
||||
modelConfig := &config.BackendConfig{}
|
||||
|
||||
// Galleryname is the name of the model in this case
|
||||
dat, err := os.ReadFile(filepath.Join(modelPath, op.GalleryElementName+".yaml"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = yaml.Unmarshal(dat, modelConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
files := []string{}
|
||||
// Remove the model from the config
|
||||
if modelConfig.Model != "" {
|
||||
files = append(files, modelConfig.ModelFileName())
|
||||
}
|
||||
|
||||
if modelConfig.MMProj != "" {
|
||||
files = append(files, modelConfig.MMProjFileName())
|
||||
}
|
||||
|
||||
return gallery.DeleteModelFromSystem(modelPath, op.GalleryElementName, files)
|
||||
return gallery.DeleteModelFromSystem(systemState, op.GalleryElementName)
|
||||
}
|
||||
|
||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||
if op.GalleryElementName != "" {
|
||||
return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, op.GalleryElementName, modelPath, backendBasePath, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
|
||||
return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, systemState, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
|
||||
// } else if op.ConfigURL != "" {
|
||||
// err := startup.InstallModels(op.Galleries, modelPath, enforcePredownloadScans, progressCallback, op.ConfigURL)
|
||||
// if err != nil {
|
||||
@@ -167,6 +138,6 @@ func processModelOperation(
|
||||
// }
|
||||
// return cl.Preload(modelPath)
|
||||
} else {
|
||||
return installModelFromRemoteConfig(modelPath, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries, backendBasePath)
|
||||
return installModelFromRemoteConfig(systemState, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,11 +12,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func InstallExternalBackends(galleries []config.Gallery, backendPath string, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
|
||||
systemState, err := system.GetSystemState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get system state: %w", err)
|
||||
}
|
||||
func InstallExternalBackends(galleries []config.Gallery, systemState *system.SystemState, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
|
||||
uri := downloader.URI(backend)
|
||||
switch {
|
||||
case uri.LooksLikeDir():
|
||||
@@ -24,7 +20,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow
|
||||
name = filepath.Base(backend)
|
||||
}
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from path")
|
||||
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
|
||||
if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
@@ -38,7 +34,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow
|
||||
return fmt.Errorf("specifying a name is required for OCI images")
|
||||
}
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
|
||||
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
|
||||
if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
@@ -56,7 +52,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow
|
||||
name = strings.TrimSuffix(name, filepath.Ext(name))
|
||||
|
||||
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
|
||||
if err := gallery.InstallBackend(backendPath, &gallery.GalleryBackend{
|
||||
if err := gallery.InstallBackend(systemState, &gallery.GalleryBackend{
|
||||
Metadata: gallery.Metadata{
|
||||
Name: name,
|
||||
},
|
||||
@@ -69,7 +65,7 @@ func InstallExternalBackends(galleries []config.Gallery, backendPath string, dow
|
||||
if name != "" || alias != "" {
|
||||
return fmt.Errorf("specifying a name or alias is not supported for this backend")
|
||||
}
|
||||
err := gallery.InstallBackendFromGallery(galleries, systemState, backend, backendPath, downloadStatus, true)
|
||||
err := gallery.InstallBackendFromGallery(galleries, systemState, backend, downloadStatus, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error installing backend %s: %w", backend, err)
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ const (
|
||||
// InstallModels will preload models from the given list of URLs and galleries
|
||||
// It will download the model if it is not already present in the model path
|
||||
// It will also try to resolve if the model is an embedded model YAML configuration
|
||||
func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, backendBasePath string, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
|
||||
func InstallModels(galleries, backendGalleries []config.Gallery, systemState *system.SystemState, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
|
||||
// create an error that groups all errors
|
||||
var err error
|
||||
|
||||
@@ -36,7 +36,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
|
||||
return e
|
||||
}
|
||||
|
||||
var model config.BackendConfig
|
||||
var model config.ModelConfig
|
||||
if e := yaml.Unmarshal(modelYAML, &model); e != nil {
|
||||
log.Error().Err(e).Str("filepath", modelPath).Msg("error unmarshalling model definition")
|
||||
return e
|
||||
@@ -47,12 +47,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
|
||||
return nil
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, model.Backend, backendBasePath, downloadStatus, false); err != nil {
|
||||
if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, model.Backend, downloadStatus, false); err != nil {
|
||||
log.Error().Err(err).Str("backend", model.Backend).Msg("error installing backend")
|
||||
return err
|
||||
}
|
||||
@@ -77,8 +72,8 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
|
||||
ociName = strings.ReplaceAll(ociName, ":", "__")
|
||||
|
||||
// check if file exists
|
||||
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) {
|
||||
modelDefinitionFilePath := filepath.Join(modelPath, ociName)
|
||||
if _, e := os.Stat(filepath.Join(systemState.Model.ModelsPath, ociName)); errors.Is(e, os.ErrNotExist) {
|
||||
modelDefinitionFilePath := filepath.Join(systemState.Model.ModelsPath, ociName)
|
||||
e := uri.DownloadFile(modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||
})
|
||||
@@ -100,7 +95,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
|
||||
continue
|
||||
}
|
||||
|
||||
modelPath := filepath.Join(modelPath, fileName)
|
||||
modelPath := filepath.Join(systemState.Model.ModelsPath, fileName)
|
||||
|
||||
if e := utils.VerifyPath(fileName, modelPath); e != nil {
|
||||
log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path")
|
||||
@@ -138,7 +133,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
|
||||
continue
|
||||
}
|
||||
|
||||
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + YAML_EXTENSION
|
||||
modelDefinitionFilePath := filepath.Join(systemState.Model.ModelsPath, md5Name) + YAML_EXTENSION
|
||||
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil {
|
||||
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s")
|
||||
err = errors.Join(err, e)
|
||||
@@ -152,7 +147,7 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
|
||||
}
|
||||
} else {
|
||||
// Check if it's a model gallery, or print a warning
|
||||
e, found := installModel(galleries, backendGalleries, url, modelPath, backendBasePath, downloadStatus, enforceScan, autoloadBackendGalleries)
|
||||
e, found := installModel(galleries, backendGalleries, url, systemState, downloadStatus, enforceScan, autoloadBackendGalleries)
|
||||
if e != nil && found {
|
||||
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url)
|
||||
err = errors.Join(err, e)
|
||||
@@ -166,13 +161,13 @@ func InstallModels(galleries, backendGalleries []config.Gallery, modelPath, back
|
||||
return err
|
||||
}
|
||||
|
||||
func installModel(galleries, backendGalleries []config.Gallery, modelName, modelPath, backendBasePath string, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
|
||||
models, err := gallery.AvailableGalleryModels(galleries, modelPath)
|
||||
func installModel(galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
|
||||
models, err := gallery.AvailableGalleryModels(galleries, systemState)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
model := gallery.FindGalleryElement(models, modelName, modelPath)
|
||||
model := gallery.FindGalleryElement(models, modelName)
|
||||
if model == nil {
|
||||
return err, false
|
||||
}
|
||||
@@ -182,7 +177,7 @@ func installModel(galleries, backendGalleries []config.Gallery, modelName, model
|
||||
}
|
||||
|
||||
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
|
||||
err = gallery.InstallModelFromGallery(galleries, backendGalleries, modelName, modelPath, backendBasePath, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
|
||||
err = gallery.InstallModelFromGallery(galleries, backendGalleries, systemState, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
|
||||
if err != nil {
|
||||
return err, true
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -21,7 +22,10 @@ var _ = Describe("Preload test", func() {
|
||||
url := "https://raw.githubusercontent.com/mudler/LocalAI-examples/main/configurations/phi-2.yaml"
|
||||
fileName := fmt.Sprintf("%s.yaml", "phi-2")
|
||||
|
||||
InstallModels([]config.Gallery{}, []config.Gallery{}, tmpdir, "", true, true, nil, url)
|
||||
systemState, err := system.GetSystemState(system.WithModelPath(tmpdir))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
InstallModels([]config.Gallery{}, []config.Gallery{}, systemState, true, true, nil, url)
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
@@ -36,7 +40,10 @@ var _ = Describe("Preload test", func() {
|
||||
url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
|
||||
fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K")
|
||||
|
||||
err = InstallModels([]config.Gallery{}, []config.Gallery{}, tmpdir, "", false, true, nil, url)
|
||||
systemState, err := system.GetSystemState(system.WithModelPath(tmpdir))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModels([]config.Gallery{}, []config.Gallery{}, systemState, true, true, nil, url)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
@@ -55,7 +55,7 @@ func NewEvaluator(modelPath string) *Evaluator {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
|
||||
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.ModelConfig, in PromptTemplateData) (string, error) {
|
||||
template := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
@@ -135,7 +135,7 @@ func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, te
|
||||
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
|
||||
}
|
||||
|
||||
func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||||
func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.ModelConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||||
|
||||
if config.TemplateConfig.JinjaTemplate {
|
||||
var messageData []ChatMessageTemplateData
|
||||
|
||||
@@ -53,7 +53,7 @@ Function response:
|
||||
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
@@ -69,7 +69,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
|
||||
},
|
||||
"assistant": {
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
@@ -86,7 +86,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
|
||||
"function_call": {
|
||||
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
@@ -102,7 +102,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
|
||||
},
|
||||
"function_response": {
|
||||
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
@@ -121,7 +121,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
|
||||
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
@@ -137,7 +137,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
||||
},
|
||||
"assistant": {
|
||||
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
@@ -153,7 +153,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
||||
},
|
||||
"function_call": {
|
||||
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
@@ -175,7 +175,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
||||
},
|
||||
"function_response": {
|
||||
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
@@ -194,7 +194,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
||||
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"config": &config.BackendConfig{
|
||||
"config": &config.ModelConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: toolCallJinja,
|
||||
JinjaTemplate: true,
|
||||
@@ -219,7 +219,7 @@ var _ = Describe("Templates", func() {
|
||||
for key := range chatMLTestMatch {
|
||||
foo := chatMLTestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
@@ -232,7 +232,7 @@ var _ = Describe("Templates", func() {
|
||||
for key := range llama3TestMatch {
|
||||
foo := llama3TestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
@@ -245,7 +245,7 @@ var _ = Describe("Templates", func() {
|
||||
for key := range jinjaTest {
|
||||
foo := jinjaTest[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user