mirror of
https://github.com/mudler/LocalAI.git
synced 2026-01-06 10:39:55 -06:00
feat: migrate to echo and enable cancellation of non-streaming requests (#7270)
* WIP: migrate to echo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
03e9f4b140
commit
1cdcaf0152
@@ -48,10 +48,12 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
|
||||
appHTTP := http.Explorer(db)
|
||||
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
if err := appHTTP.Shutdown(); err != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := appHTTP.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("error during shutdown")
|
||||
}
|
||||
})
|
||||
|
||||
return appHTTP.Listen(e.Address)
|
||||
return appHTTP.Start(e.Address)
|
||||
}
|
||||
|
||||
@@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
})
|
||||
|
||||
return appHTTP.Listen(r.Address)
|
||||
return appHTTP.Start(r.Address)
|
||||
}
|
||||
|
||||
223
core/http/app.go
223
core/http/app.go
@@ -4,30 +4,23 @@ import (
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/dave-gray101/v2keyauth"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
|
||||
"github.com/gofiber/contrib/fiberzerolog"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/csrf"
|
||||
"github.com/gofiber/fiber/v2/middleware/favicon"
|
||||
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
|
||||
// swagger handler
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -49,85 +42,85 @@ var embedDirStatic embed.FS
|
||||
// @in header
|
||||
// @name Authorization
|
||||
|
||||
func API(application *application.Application) (*fiber.App, error) {
|
||||
func API(application *application.Application) (*echo.Echo, error) {
|
||||
e := echo.New()
|
||||
|
||||
fiberCfg := fiber.Config{
|
||||
Views: renderEngine(),
|
||||
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
// We disable the Fiber startup message as it does not conform to structured logging.
|
||||
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
|
||||
DisableStartupMessage: true,
|
||||
// Override default error handler
|
||||
// Set body limit
|
||||
if application.ApplicationConfig().UploadLimitMB > 0 {
|
||||
e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
|
||||
}
|
||||
|
||||
// Set error handler
|
||||
if !application.ApplicationConfig().OpaqueErrors {
|
||||
// Normally, return errors as JSON responses
|
||||
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
code := fiber.StatusInternalServerError
|
||||
e.HTTPErrorHandler = func(err error, c echo.Context) {
|
||||
code := http.StatusInternalServerError
|
||||
var he *echo.HTTPError
|
||||
if errors.As(err, &he) {
|
||||
code = he.Code
|
||||
}
|
||||
|
||||
// Retrieve the custom status code if it's a *fiber.Error
|
||||
var e *fiber.Error
|
||||
if errors.As(err, &e) {
|
||||
code = e.Code
|
||||
// Handle 404 errors with HTML rendering when appropriate
|
||||
if code == http.StatusNotFound {
|
||||
notFoundHandler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Send custom error page
|
||||
return ctx.Status(code).JSON(
|
||||
schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||
},
|
||||
)
|
||||
c.JSON(code, schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// If OpaqueErrors are required, replace everything with a blank 500.
|
||||
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
|
||||
return ctx.Status(500).SendString("")
|
||||
e.HTTPErrorHandler = func(err error, c echo.Context) {
|
||||
code := http.StatusInternalServerError
|
||||
var he *echo.HTTPError
|
||||
if errors.As(err, &he) {
|
||||
code = he.Code
|
||||
}
|
||||
c.NoContent(code)
|
||||
}
|
||||
}
|
||||
|
||||
router := fiber.New(fiberCfg)
|
||||
// Set renderer
|
||||
e.Renderer = renderEngine()
|
||||
|
||||
router.Use(middleware.StripPathPrefix())
|
||||
// Hide banner
|
||||
e.HideBanner = true
|
||||
|
||||
// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
|
||||
e.Pre(httpMiddleware.StripPathPrefix())
|
||||
|
||||
if application.ApplicationConfig().MachineTag != "" {
|
||||
router.Use(func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
|
||||
|
||||
return c.Next()
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
router.Use("/v1/realtime", func(c *fiber.Ctx) error {
|
||||
if websocket.IsWebSocketUpgrade(c) {
|
||||
// Returns true if the client requested upgrade to the WebSocket protocol
|
||||
return c.Next()
|
||||
// Custom logger middleware using zerolog
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
start := log.Logger.Info()
|
||||
err := next(c)
|
||||
start.
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Int("status", res.Status).
|
||||
Msg("HTTP request")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
|
||||
scheme := "http"
|
||||
if listenData.TLS {
|
||||
scheme = "https"
|
||||
}
|
||||
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
|
||||
return nil
|
||||
})
|
||||
|
||||
// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
|
||||
logger := log.Logger
|
||||
router.Use(fiberzerolog.New(fiberzerolog.Config{
|
||||
Logger: &logger,
|
||||
}))
|
||||
|
||||
// Default middleware config
|
||||
|
||||
// Recover middleware
|
||||
if !application.ApplicationConfig().Debug {
|
||||
router.Use(recover.New())
|
||||
e.Use(middleware.Recover())
|
||||
}
|
||||
|
||||
// Metrics middleware
|
||||
if !application.ApplicationConfig().DisableMetrics {
|
||||
metricsService, err := services.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
@@ -135,34 +128,40 @@ func API(application *application.Application) (*fiber.App, error) {
|
||||
}
|
||||
|
||||
if metricsService != nil {
|
||||
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
router.Hooks().OnShutdown(func() error {
|
||||
return metricsService.Shutdown()
|
||||
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
}
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
routes.HealthRoutes(router)
|
||||
|
||||
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||
if err != nil || kaConfig == nil {
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
routes.HealthRoutes(e)
|
||||
|
||||
// Get key auth middleware
|
||||
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create key auth config: %w", err)
|
||||
}
|
||||
|
||||
httpFS := http.FS(embedDirStatic)
|
||||
// Favicon handler
|
||||
e.GET("/favicon.svg", func(c echo.Context) error {
|
||||
data, err := embedDirStatic.ReadFile("static/favicon.svg")
|
||||
if err != nil {
|
||||
return c.NoContent(http.StatusNotFound)
|
||||
}
|
||||
c.Response().Header().Set("Content-Type", "image/svg+xml")
|
||||
return c.Blob(http.StatusOK, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
router.Use(favicon.New(favicon.Config{
|
||||
URL: "/favicon.svg",
|
||||
FileSystem: httpFS,
|
||||
File: "static/favicon.svg",
|
||||
}))
|
||||
|
||||
router.Use("/static", filesystem.New(filesystem.Config{
|
||||
Root: httpFS,
|
||||
PathPrefix: "static",
|
||||
Browse: true,
|
||||
}))
|
||||
// Static files - use fs.Sub to create a filesystem rooted at "static"
|
||||
staticFS, err := fs.Sub(embedDirStatic, "static")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create static filesystem: %w", err)
|
||||
}
|
||||
e.StaticFS("/static", staticFS)
|
||||
|
||||
// Generated content directories
|
||||
if application.ApplicationConfig().GeneratedContentDir != "" {
|
||||
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
|
||||
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
|
||||
@@ -173,51 +172,53 @@ func API(application *application.Application) (*fiber.App, error) {
|
||||
os.MkdirAll(imagePath, 0750)
|
||||
os.MkdirAll(videoPath, 0750)
|
||||
|
||||
router.Static("/generated-audio", audioPath)
|
||||
router.Static("/generated-images", imagePath)
|
||||
router.Static("/generated-videos", videoPath)
|
||||
e.Static("/generated-audio", audioPath)
|
||||
e.Static("/generated-images", imagePath)
|
||||
e.Static("/generated-videos", videoPath)
|
||||
}
|
||||
|
||||
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
|
||||
router.Use(v2keyauth.New(*kaConfig))
|
||||
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
|
||||
e.Use(keyAuthMiddleware)
|
||||
|
||||
// CORS middleware
|
||||
if application.ApplicationConfig().CORS {
|
||||
var c func(ctx *fiber.Ctx) error
|
||||
if application.ApplicationConfig().CORSAllowOrigins == "" {
|
||||
c = cors.New()
|
||||
} else {
|
||||
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
|
||||
corsConfig := middleware.CORSConfig{}
|
||||
if application.ApplicationConfig().CORSAllowOrigins != "" {
|
||||
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
|
||||
}
|
||||
|
||||
router.Use(c)
|
||||
e.Use(middleware.CORSWithConfig(corsConfig))
|
||||
}
|
||||
|
||||
// CSRF middleware
|
||||
if application.ApplicationConfig().CSRF {
|
||||
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
|
||||
router.Use(csrf.New())
|
||||
e.Use(middleware.CSRF())
|
||||
}
|
||||
|
||||
requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
|
||||
var opcache *services.OpCache
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
opcache = services.NewOpCache(application.GalleryService())
|
||||
}
|
||||
|
||||
routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
|
||||
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
|
||||
|
||||
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
|
||||
routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
|
||||
}
|
||||
routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Define a custom 404 handler
|
||||
// Note: keep this at the bottom!
|
||||
router.Use(notFoundHandler)
|
||||
// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route
|
||||
|
||||
return router, nil
|
||||
// Log startup message
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
log.Info().Msg("LocalAI API server shutting down")
|
||||
})
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
@@ -10,13 +10,14 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
openaigo "github.com/otiai10/openaigo"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
@@ -266,7 +268,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8
|
||||
|
||||
var _ = Describe("API test", func() {
|
||||
|
||||
var app *fiber.App
|
||||
var app *echo.Echo
|
||||
var client *openai.Client
|
||||
var client2 *openaigo.Client
|
||||
var c context.Context
|
||||
@@ -339,7 +341,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig(apiKey)
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -358,7 +364,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func(sc SpecContext) {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := os.RemoveAll(tmpdir)
|
||||
@@ -547,7 +555,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -566,7 +578,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := os.RemoveAll(tmpdir)
|
||||
@@ -755,7 +769,11 @@ var _ = Describe("API test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -773,7 +791,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
@@ -1006,7 +1026,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -1022,7 +1046,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -15,17 +17,17 @@ import (
|
||||
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/sound-generation [post]
|
||||
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
|
||||
if !ok || input.ModelID == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
|
||||
@@ -35,7 +37,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Download(filePath)
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -17,19 +18,19 @@ import (
|
||||
// @Param request body schema.TTSRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/text-to-speech/{voice-id} [post]
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
voiceID := c.Params("voice-id")
|
||||
voiceID := c.Param("voice-id")
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
|
||||
if !ok || input.ModelID == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received")
|
||||
@@ -38,6 +39,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Download(filePath)
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,28 +2,32 @@ package explorer
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/explorer"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
)
|
||||
|
||||
func Dashboard() func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
summary := fiber.Map{
|
||||
func Dashboard() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI API - " + internal.PrintableVersion(),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
}
|
||||
|
||||
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
accept := c.Request().Header.Get("Accept")
|
||||
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
|
||||
// The client expects a JSON response
|
||||
return c.Status(fiber.StatusOK).JSON(summary)
|
||||
return c.JSON(http.StatusOK, summary)
|
||||
} else {
|
||||
// Render index
|
||||
return c.Render("views/explorer", summary)
|
||||
return c.Render(http.StatusOK, "views/explorer", summary)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -39,8 +43,8 @@ type Network struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
results := []Network{}
|
||||
for _, token := range db.TokenList() {
|
||||
networkData, exists := db.Get(token) // get the token data
|
||||
@@ -61,44 +65,44 @@ func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
|
||||
return len(results[i].Clusters) > len(results[j].Clusters)
|
||||
})
|
||||
|
||||
return c.JSON(results)
|
||||
return c.JSON(http.StatusOK, results)
|
||||
}
|
||||
}
|
||||
|
||||
func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func AddNetwork(db *explorer.Database) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
request := new(AddNetworkRequest)
|
||||
if err := c.BodyParser(request); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||
if err := c.Bind(request); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
|
||||
}
|
||||
|
||||
if request.Token == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
|
||||
}
|
||||
|
||||
if request.Name == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
|
||||
}
|
||||
|
||||
if request.Description == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
|
||||
}
|
||||
|
||||
// TODO: check if token is valid, otherwise reject
|
||||
// try to decode the token from base64
|
||||
_, err := base64.StdEncoding.DecodeString(request.Token)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
|
||||
}
|
||||
|
||||
if _, exists := db.Get(request.Token); exists {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"})
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
|
||||
}
|
||||
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"})
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"})
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package jina
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -17,17 +18,17 @@ import (
|
||||
// @Param request body schema.JINARerankRequest true "query params"
|
||||
// @Success 200 {object} schema.JINARerankResponse "Response"
|
||||
// @Router /v1/rerank [post]
|
||||
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
|
||||
@@ -58,6 +59,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
response.Usage.TotalTokens = int(results.Usage.TotalTokens)
|
||||
response.Usage.PromptTokens = int(results.Usage.PromptTokens)
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
@@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
|
||||
// @Summary Returns the job status
|
||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
||||
// @Router /backends/jobs/{uuid} [get]
|
||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
status := mgs.backendApplier.GetStatus(c.Params("uuid"))
|
||||
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
status := mgs.backendApplier.GetStatus(c.Param("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(status)
|
||||
return c.JSON(200, status)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
||||
// @Router /backends/jobs [get]
|
||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(mgs.backendApplier.GetAllStatus())
|
||||
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, mgs.backendApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err
|
||||
// @Param request body GalleryBackend true "query params"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/apply [post]
|
||||
func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(GalleryBackend)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
|
||||
Galleries: mgs.galleries,
|
||||
}
|
||||
|
||||
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,9 +91,9 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
|
||||
// @Param name path string true "Backend name"
|
||||
// @Success 200 {object} schema.BackendResponse "Response"
|
||||
// @Router /backends/delete/{name} [post]
|
||||
func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backendName := c.Params("name")
|
||||
func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backendName := c.Param("name")
|
||||
|
||||
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
|
||||
Delete: true,
|
||||
@@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,13 +114,13 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
|
||||
// @Summary List all Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends [get]
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(backends.GetAll())
|
||||
return c.JSON(200, backends.GetAll())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.Syst
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /backends/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
return c.Blob(200, "application/json", dat)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
|
||||
// @Summary List all available Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends/available [get]
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(backends)
|
||||
return c.JSON(200, backends)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,45 +1,45 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// BackendMonitorEndpoint returns the status of the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Success 200 {object} proto.StatusResponse "Response"
|
||||
// @Router /backend/monitor [get]
|
||||
func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := bm.CheckAndSample(input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// BackendShutdownEndpoint shuts down the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Router /backend/shutdown [post]
|
||||
func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bm.ShutdownModel(input.Model)
|
||||
}
|
||||
}
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// BackendMonitorEndpoint returns the status of the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Success 200 {object} proto.StatusResponse "Response"
|
||||
// @Router /backend/monitor [get]
|
||||
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := bm.CheckAndSample(input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// BackendShutdownEndpoint shuts down the specified backend
|
||||
// @Summary Backend monitor endpoint
|
||||
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
|
||||
// @Router /backend/shutdown [post]
|
||||
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bm.ShutdownModel(input.Model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -16,17 +16,17 @@ import (
|
||||
// @Param request body schema.DetectionRequest true "query params"
|
||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||
// @Router /v1/detection [post]
|
||||
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
|
||||
@@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,13 @@ package localai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
@@ -14,15 +16,15 @@ import (
|
||||
)
|
||||
|
||||
// GetEditModelPage renders the edit model page with current configuration
|
||||
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
@@ -31,7 +33,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Model configuration not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
modelConfigFile := modelConfig.GetModelConfigFile()
|
||||
@@ -40,7 +42,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Model configuration file not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
configData, err := os.ReadFile(modelConfigFile)
|
||||
if err != nil {
|
||||
@@ -48,7 +50,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Failed to read configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Render the edit page with the current configuration
|
||||
@@ -69,20 +71,20 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Version: internal.PrintableVersion(),
|
||||
}
|
||||
|
||||
return c.Render("views/model-editor", templateData)
|
||||
return c.Render(http.StatusOK, "views/model-editor", templateData)
|
||||
}
|
||||
}
|
||||
|
||||
// EditModelEndpoint handles updating existing model configurations
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
@@ -91,17 +93,24 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Existing model configuration not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Get the raw body
|
||||
body := c.Body()
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to read request body: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
if len(body) == 0 {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Request body is empty",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Check content to see if it's a valid model config
|
||||
@@ -113,7 +122,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
@@ -122,7 +131,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Validate the configuration
|
||||
@@ -132,7 +141,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Error: "Validation failed",
|
||||
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Load the existing configuration
|
||||
@@ -142,7 +151,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Model configuration not trusted: " + err.Error(),
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Write new content to file
|
||||
@@ -151,7 +160,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Reload configurations
|
||||
@@ -160,7 +169,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Preload the model
|
||||
@@ -169,7 +178,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to preload model: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Return success response
|
||||
@@ -179,20 +188,20 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Filename: configPath,
|
||||
Config: req,
|
||||
}
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadModelsEndpoint handles reloading model configurations from disk
|
||||
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Reload configurations
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Preload the models
|
||||
@@ -201,7 +210,7 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
|
||||
Success: false,
|
||||
Error: "Failed to preload models: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Return success response
|
||||
@@ -209,6 +218,6 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
|
||||
Success: true,
|
||||
Message: "Model configurations reloaded successfully",
|
||||
}
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package localai_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
@@ -15,6 +17,14 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// testRenderer is a simple renderer for tests that returns JSON
|
||||
type testRenderer struct{}
|
||||
|
||||
func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
|
||||
// For tests, just return the data as JSON
|
||||
return json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
var _ = Describe("Edit Model test", func() {
|
||||
|
||||
var tempDir string
|
||||
@@ -40,33 +50,35 @@ var _ = Describe("Edit Model test", func() {
|
||||
//modelLoader := model.NewModelLoader(systemState, true)
|
||||
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
|
||||
|
||||
// Define Fiber app.
|
||||
app := fiber.New()
|
||||
app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
|
||||
// Define Echo app and register all routes upfront
|
||||
app := echo.New()
|
||||
// Set up a simple renderer for the test
|
||||
app.Renderer = &testRenderer{}
|
||||
app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
|
||||
app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig))
|
||||
|
||||
requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
|
||||
|
||||
req := httptest.NewRequest("PUT", "/import-model", requestBody)
|
||||
resp, err := app.Test(req, 5000)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req := httptest.NewRequest("POST", "/import-model", requestBody)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(rec.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
|
||||
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
|
||||
requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
|
||||
req = httptest.NewRequest("GET", "/edit-model/foo", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
|
||||
resp, _ = app.Test(req, 1)
|
||||
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
body, err = io.ReadAll(rec.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
|
||||
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
|
||||
// The response contains the model configuration with backend field
|
||||
Expect(string(body)).To(ContainSubstring(`"backend":"foo"`))
|
||||
Expect(string(body)).To(ContainSubstring(`"name":"foo"`))
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,160 +1,160 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ModelGalleryEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendGalleries []config.Gallery
|
||||
modelPath string
|
||||
galleryApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
ID string `json:"id"`
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs/{uuid} [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(status)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllStatusEndpoint returns all the jobs status progress
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(mgs.galleryApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
|
||||
// @Summary Install models to LocalAI.
|
||||
// @Param request body GalleryModel true "query params"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/apply [post]
|
||||
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
Req: input.GalleryModel,
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
BackendGalleries: mgs.backendGalleries,
|
||||
}
|
||||
|
||||
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
|
||||
// @Summary delete models to LocalAI.
|
||||
// @Param name path string true "Model name"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/delete/{name} [post]
|
||||
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
Delete: true,
|
||||
GalleryElementName: modelName,
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
|
||||
|
||||
m := []gallery.Metadata{}
|
||||
|
||||
for _, mm := range models {
|
||||
m = append(m, mm.Metadata)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Models %#v", m)
|
||||
|
||||
dat, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal models: %w", err)
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||
// @Summary List all Galleries
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /models/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ModelGalleryEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendGalleries []config.Gallery
|
||||
modelPath string
|
||||
galleryApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
ID string `json:"id"`
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs/{uuid} [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
status := mgs.galleryApplier.GetStatus(c.Param("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(200, status)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllStatusEndpoint returns all the jobs status progress
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, mgs.galleryApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
|
||||
// @Summary Install models to LocalAI.
|
||||
// @Param request body GalleryModel true "query params"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/apply [post]
|
||||
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
Req: input.GalleryModel,
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
BackendGalleries: mgs.backendGalleries,
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
|
||||
// @Summary delete models to LocalAI.
|
||||
// @Param name path string true "Model name"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/delete/{name} [post]
|
||||
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
|
||||
Delete: true,
|
||||
GalleryElementName: modelName,
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
|
||||
|
||||
m := []gallery.Metadata{}
|
||||
|
||||
for _, mm := range models {
|
||||
m = append(m, mm.Metadata)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Models %#v", m)
|
||||
|
||||
dat, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal models: %w", err)
|
||||
}
|
||||
return c.Blob(200, "application/json", dat)
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||
// @Summary List all Galleries
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /models/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Blob(200, "application/json", dat)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -21,17 +21,17 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/tokenMetrics [get]
|
||||
// @Router /tokenMetrics [get]
|
||||
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input := new(schema.TokenMetricsRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if !ok || modelFile != "" {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
@@ -52,6 +52,6 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,18 @@ package localai
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -21,12 +23,12 @@ import (
|
||||
)
|
||||
|
||||
// ImportModelURIEndpoint handles creating new model configurations from a URI
|
||||
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input := new(schema.ImportModelRequest)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -61,7 +63,7 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
|
||||
BackendGalleries: appConfig.BackendGalleries,
|
||||
}
|
||||
|
||||
return c.JSON(schema.GalleryResponse{
|
||||
return c.JSON(200, schema.GalleryResponse{
|
||||
ID: uuid.String(),
|
||||
StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
|
||||
})
|
||||
@@ -69,22 +71,28 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
|
||||
}
|
||||
|
||||
// ImportModelEndpoint handles creating new model configurations
|
||||
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Get the raw body
|
||||
body := c.Body()
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to read request body: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
if len(body) == 0 {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Request body is empty",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Check content type to determine how to parse
|
||||
contentType := string(c.Context().Request.Header.ContentType())
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
var modelConfig config.ModelConfig
|
||||
var err error
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// Parse JSON
|
||||
@@ -93,7 +101,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
|
||||
// Parse YAML
|
||||
@@ -102,18 +110,18 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
} else {
|
||||
// Try to auto-detect format
|
||||
if strings.TrimSpace(string(body))[0] == '{' {
|
||||
if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
|
||||
// Looks like JSON
|
||||
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
} else {
|
||||
// Assume YAML
|
||||
@@ -122,7 +130,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,7 +141,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
@@ -145,7 +153,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Invalid configuration",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Create the configuration file
|
||||
@@ -155,7 +163,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Model path not trusted: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Marshal to YAML for storage
|
||||
@@ -165,7 +173,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to marshal configuration: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Write the file
|
||||
@@ -174,7 +182,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
// Reload configurations
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
@@ -182,7 +190,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Preload the model
|
||||
@@ -191,7 +199,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Success: false,
|
||||
Error: "Failed to preload model: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
// Return success response
|
||||
response := ModelResponse{
|
||||
@@ -199,6 +207,6 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
|
||||
Message: "Model configuration created successfully",
|
||||
Filename: filepath.Base(configPath),
|
||||
}
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,46 +1,47 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
|
||||
// @Summary Prometheus metrics endpoint
|
||||
// @Param request body config.Gallery true "Gallery details"
|
||||
// @Router /metrics [get]
|
||||
func LocalAIMetricsEndpoint() fiber.Handler {
|
||||
return adaptor.HTTPHandler(promhttp.Handler())
|
||||
}
|
||||
|
||||
type apiMiddlewareConfig struct {
|
||||
Filter func(c *fiber.Ctx) bool
|
||||
metricsService *services.LocalAIMetricsService
|
||||
}
|
||||
|
||||
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
|
||||
cfg := apiMiddlewareConfig{
|
||||
metricsService: metrics,
|
||||
Filter: func(c *fiber.Ctx) bool {
|
||||
return c.Path() == "/metrics"
|
||||
},
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if cfg.Filter != nil && cfg.Filter(c) {
|
||||
return c.Next()
|
||||
}
|
||||
path := c.Path()
|
||||
method := c.Method()
|
||||
|
||||
start := time.Now()
|
||||
err := c.Next()
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
cfg.metricsService.ObserveAPICall(method, path, elapsed)
|
||||
return err
|
||||
}
|
||||
}
|
||||
package localai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
|
||||
// @Summary Prometheus metrics endpoint
|
||||
// @Param request body config.Gallery true "Gallery details"
|
||||
// @Router /metrics [get]
|
||||
func LocalAIMetricsEndpoint() echo.HandlerFunc {
|
||||
return echo.WrapHandler(promhttp.Handler())
|
||||
}
|
||||
|
||||
type apiMiddlewareConfig struct {
|
||||
Filter func(c echo.Context) bool
|
||||
metricsService *services.LocalAIMetricsService
|
||||
}
|
||||
|
||||
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc {
|
||||
cfg := apiMiddlewareConfig{
|
||||
metricsService: metrics,
|
||||
Filter: func(c echo.Context) bool {
|
||||
return c.Path() == "/metrics"
|
||||
},
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if cfg.Filter != nil && cfg.Filter(c) {
|
||||
return next(c)
|
||||
}
|
||||
path := c.Path()
|
||||
method := c.Request().Method
|
||||
|
||||
start := time.Now()
|
||||
err := next(c)
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
cfg.metricsService.ObserveAPICall(method, path, elapsed)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
// @Summary Returns available P2P nodes
|
||||
// @Success 200 {object} []schema.P2PNodesResponse "Response"
|
||||
// @Router /api/p2p [get]
|
||||
func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
// Render index
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(schema.P2PNodesResponse{
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, schema.P2PNodesResponse{
|
||||
Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
|
||||
FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
|
||||
})
|
||||
@@ -25,6 +25,6 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||
// @Summary Show the P2P token
|
||||
// @Success 200 {string} string "Response"
|
||||
// @Router /api/p2p/token [get]
|
||||
func ShowP2PToken(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error { return c.Send([]byte(appConfig.P2PToken)) }
|
||||
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) }
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.StoresSet)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -28,20 +28,20 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
|
||||
vals[i] = []byte(v)
|
||||
}
|
||||
|
||||
err = store.SetCols(c.Context(), sb, input.Keys, vals)
|
||||
err = store.SetCols(c.Request().Context(), sb, input.Keys, vals)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Send(nil)
|
||||
return c.NoContent(200)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.StoresDelete)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -51,19 +51,19 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
|
||||
}
|
||||
defer sl.Close()
|
||||
|
||||
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
|
||||
if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Send(nil)
|
||||
return c.NoContent(200)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.StoresGet)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
|
||||
}
|
||||
defer sl.Close()
|
||||
|
||||
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
|
||||
keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -87,15 +87,15 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
|
||||
res.Values[i] = string(v)
|
||||
}
|
||||
|
||||
return c.JSON(res)
|
||||
return c.JSON(200, res)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := new(schema.StoresFind)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
if err := c.Bind(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
|
||||
}
|
||||
defer sl.Close()
|
||||
|
||||
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
|
||||
keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -120,6 +120,6 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
|
||||
res.Values[i] = string(v)
|
||||
}
|
||||
|
||||
return c.JSON(res)
|
||||
return c.JSON(200, res)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
// @Summary Show the LocalAI instance information
|
||||
// @Success 200 {object} schema.SystemInformationResponse "Response"
|
||||
// @Router /system [get]
|
||||
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
availableBackends := []string{}
|
||||
loadedModels := ml.ListLoadedModels()
|
||||
for b := range appConfig.ExternalGRPCBackends {
|
||||
@@ -26,7 +26,7 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf
|
||||
for _, m := range loadedModels {
|
||||
sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID})
|
||||
}
|
||||
return c.JSON(
|
||||
return c.JSON(200,
|
||||
schema.SystemInformationResponse{
|
||||
Backends: availableBackends,
|
||||
Models: sysmodels,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -14,22 +14,22 @@ import (
|
||||
// @Param request body schema.TokenizeRequest true "Request"
|
||||
// @Success 200 {object} schema.TokenizeResponse "Response"
|
||||
// @Router /v1/tokenize [post]
|
||||
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
|
||||
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ctx.JSON(tokenResponse)
|
||||
return c.JSON(200, tokenResponse)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
@@ -22,16 +24,16 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/audio/speech [post]
|
||||
// @Router /tts [post]
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received")
|
||||
@@ -59,6 +61,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Download(filePath)
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -16,26 +16,26 @@ import (
|
||||
// @Param request body schema.VADRequest true "query params"
|
||||
// @Success 200 {object} proto.VADResponse "Response"
|
||||
// @Router /vad [post]
|
||||
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
|
||||
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received")
|
||||
|
||||
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
|
||||
resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,19 +7,20 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -64,18 +65,18 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.VideoRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /video [post]
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
|
||||
if !ok || input.Model == "" {
|
||||
log.Error().Msg("Video Endpoint - Invalid Input")
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Video Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
src := ""
|
||||
@@ -164,7 +165,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
baseURL := middleware.BaseURL(c)
|
||||
|
||||
fn, err := backend.VideoGeneration(
|
||||
height,
|
||||
@@ -201,7 +202,10 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = baseURL + "/generated-videos/" + base
|
||||
item.URL, err = url.JoinPath(baseURL, "generated-videos", base)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
@@ -216,6 +220,6 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
@@ -40,10 +42,10 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
// Get model statuses to display in the UI the operation in progress
|
||||
processingModels, taskTypes := opcache.GetStatus()
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI API - " + internal.PrintableVersion(),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Models": modelsWithoutConfig,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
@@ -54,12 +56,16 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
"InstalledBackends": installedBackends,
|
||||
}
|
||||
|
||||
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
accept := c.Request().Header.Get("Accept")
|
||||
// Default to HTML if Accept header is empty (browser behavior)
|
||||
// Only return JSON if explicitly requested or Content-Type is application/json
|
||||
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) {
|
||||
// The client expects a JSON response
|
||||
return c.Status(fiber.StatusOK).JSON(summary)
|
||||
return c.JSON(200, summary)
|
||||
} else {
|
||||
// Render index
|
||||
return c.Render("views/index", summary)
|
||||
return c.Render(200, "views/index", summary)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -20,68 +17,14 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// NOTE: this is a bad WORKAROUND! We should find a better way to handle this.
|
||||
// Fasthttp doesn't support context cancellation from the caller
|
||||
// for non-streaming requests, so we need to monitor the connection directly.
|
||||
// Monitor connection for client disconnection during non-streaming requests
|
||||
// We access the connection directly via c.Context().Conn() to monitor it
|
||||
// during ComputeChoices execution, not after the response is sent
|
||||
// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906
|
||||
func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) {
|
||||
var conn net.Conn = c.Context().Conn()
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
// Clear read deadline when goroutine exits
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1)
|
||||
// Use a short read deadline to periodically check if connection is closed
|
||||
// Without a deadline, Read() would block indefinitely waiting for data
|
||||
// that will never come (client is waiting for response, not sending more data)
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
// Request completed or was cancelled - exit goroutine
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Set a short deadline - if connection is closed, read will fail immediately
|
||||
// If connection is open but no data, it will timeout and we check again
|
||||
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
|
||||
_, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
// Check if it's a timeout (connection still open, just no data)
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Timeout is expected - connection is still open, just no data to read
|
||||
// Continue the loop to check again
|
||||
continue
|
||||
}
|
||||
// Connection closed or other error - cancel the context to stop gRPC call
|
||||
log.Debug().Msgf("Calling cancellation function")
|
||||
cancelFunc()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
|
||||
// @Summary Generate a chat completions for a given prompt and model.
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc {
|
||||
var id, textContentToReturn string
|
||||
var created int
|
||||
|
||||
@@ -235,21 +178,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return err
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
return func(c echo.Context) error {
|
||||
textContentToReturn = ""
|
||||
id = uuid.New().String()
|
||||
created = int(time.Now().Unix())
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
|
||||
@@ -392,13 +335,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
case toStream:
|
||||
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
c.Set("X-Correlation-ID", id)
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().Header().Set("X-Correlation-ID", id)
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
ended := make(chan error, 1)
|
||||
@@ -411,103 +351,101 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}()
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
usage := &schema.OpenAIUsage{}
|
||||
toolsCalled := false
|
||||
usage := &schema.OpenAIUsage{}
|
||||
toolsCalled := false
|
||||
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case <-input.Context.Done():
|
||||
// Context was cancelled (client disconnected or request cancelled)
|
||||
log.Debug().Msgf("Request context cancelled, stopping stream")
|
||||
input.Cancel()
|
||||
break LOOP
|
||||
case ev := <-responses:
|
||||
if len(ev.Choices) == 0 {
|
||||
log.Debug().Msgf("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolsCalled = true
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal response: %v", err)
|
||||
input.Cancel()
|
||||
continue
|
||||
}
|
||||
log.Debug().Msgf("Sending chunk: %s", string(respData))
|
||||
_, err = fmt.Fprintf(w, "data: %s\n\n", string(respData))
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||
input.Cancel()
|
||||
}
|
||||
w.Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
break LOOP
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &stopReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
w.WriteString("data: {\"error\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
}
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
|
||||
return
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case <-input.Context.Done():
|
||||
// Context was cancelled (client disconnected or request cancelled)
|
||||
log.Debug().Msgf("Request context cancelled, stopping stream")
|
||||
input.Cancel()
|
||||
break LOOP
|
||||
case ev := <-responses:
|
||||
if len(ev.Choices) == 0 {
|
||||
log.Debug().Msgf("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolsCalled = true
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal response: %v", err)
|
||||
input.Cancel()
|
||||
continue
|
||||
}
|
||||
log.Debug().Msgf("Sending chunk: %s", string(respData))
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||
input.Cancel()
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
break LOOP
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &stopReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
}
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := FinishReasonStop
|
||||
if toolsCalled && len(input.Tools) > 0 {
|
||||
finishReason = FinishReasonToolCalls
|
||||
} else if toolsCalled {
|
||||
finishReason = FinishReasonFunctionCall
|
||||
}
|
||||
finishReason := FinishReasonStop
|
||||
if toolsCalled && len(input.Tools) > 0 {
|
||||
finishReason = FinishReasonToolCalls
|
||||
} else if toolsCalled {
|
||||
finishReason = FinishReasonFunctionCall
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &finishReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
log.Debug().Msgf("Stream ended")
|
||||
}))
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &finishReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
log.Debug().Msgf("Stream ended")
|
||||
return nil
|
||||
|
||||
// no streaming mode
|
||||
@@ -589,9 +527,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
}
|
||||
|
||||
// NOTE: this is a workaround as fasthttp
|
||||
// context cancellation does not fire in non-streaming requests
|
||||
// handleConnectionCancellation(c, input.Cancel, input.Context)
|
||||
// Echo properly supports context cancellation via c.Request().Context()
|
||||
// No workaround needed!
|
||||
|
||||
result, tokenUsage, err := ComputeChoices(
|
||||
input,
|
||||
@@ -628,7 +565,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
|
||||
@@ -26,7 +24,7 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
created := int(time.Now().Unix())
|
||||
@@ -64,22 +62,25 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
return err
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
// Handle Correlation
|
||||
id := c.Get("X-Correlation-ID", uuid.New().String())
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
id := c.Request().Header.Get("X-Correlation-ID")
|
||||
if id == "" {
|
||||
id = uuid.New().String()
|
||||
}
|
||||
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
if config.ResponseFormatMap != nil {
|
||||
@@ -97,15 +98,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
if input.Stream {
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
}
|
||||
@@ -130,78 +126,78 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
|
||||
}()
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case ev := <-responses:
|
||||
if len(ev.Choices) == 0 {
|
||||
log.Debug().Msgf("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal response: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case ev := <-responses:
|
||||
if len(ev.Choices) == 0 {
|
||||
log.Debug().Msgf("No choices in the response, skipping")
|
||||
continue
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Failed to marshal response: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", string(respData))
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(respData))
|
||||
w.Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
break LOOP
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
errorResp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &stopReason,
|
||||
Text: "Internal error: " + err.Error(),
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
errorData, marshalErr := json.Marshal(errorResp)
|
||||
if marshalErr != nil {
|
||||
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
fmt.Fprintf(w, "data: {\"error\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(errorData))
|
||||
}
|
||||
w.Flush()
|
||||
log.Debug().Msgf("Sending chunk: %s", string(respData))
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
case err := <-ended:
|
||||
if err == nil {
|
||||
break LOOP
|
||||
}
|
||||
}
|
||||
log.Error().Msgf("Stream ended with error: %v", err)
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &stopReason,
|
||||
stopReason := FinishReasonStop
|
||||
errorResp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &stopReason,
|
||||
Text: "Internal error: " + err.Error(),
|
||||
},
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
Object: "text_completion",
|
||||
}
|
||||
errorData, marshalErr := json.Marshal(errorResp)
|
||||
if marshalErr != nil {
|
||||
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
|
||||
// Send a simple error message as fallback
|
||||
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
|
||||
} else {
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
|
||||
}
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
}
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return <-ended
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &stopReason,
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
@@ -257,6 +253,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
@@ -23,20 +23,20 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/edits [post]
|
||||
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
|
||||
@@ -98,6 +98,6 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -21,16 +21,16 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/embeddings [post]
|
||||
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
@@ -78,6 +78,6 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -14,13 +15,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -65,18 +66,18 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/images/generations [post]
|
||||
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
log.Error().Msg("Image Endpoint - Invalid Input")
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Image Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
// Process input images (for img2img/inpainting)
|
||||
@@ -188,7 +189,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
baseURL := middleware.BaseURL(c)
|
||||
|
||||
// Use the first input image as src if available, otherwise use the original src
|
||||
inputSrc := src
|
||||
@@ -215,7 +216,10 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = baseURL + "/generated-images/" + base
|
||||
item.URL, err = url.JoinPath(baseURL, "generated-images", base)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, *item)
|
||||
@@ -234,7 +238,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
@@ -12,14 +12,15 @@ import (
|
||||
// @Summary List and describe the various models available in the API.
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// If blank, no filter is applied.
|
||||
filter := c.Query("filter")
|
||||
filter := c.QueryParam("filter")
|
||||
|
||||
// By default, exclude any loose files that are already referenced by a configuration file.
|
||||
var policy services.LooseFilePolicy
|
||||
if c.QueryBool("excludeConfigured", true) {
|
||||
excludeConfigured := c.QueryParam("excludeConfigured")
|
||||
if excludeConfigured == "" || excludeConfigured == "true" {
|
||||
policy = services.SKIP_IF_CONFIGURED
|
||||
} else {
|
||||
policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
|
||||
@@ -41,7 +42,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
|
||||
return c.JSON(schema.ModelsDataResponse{
|
||||
return c.JSON(200, schema.ModelsDataResponse{
|
||||
Object: "list",
|
||||
Data: dataModels,
|
||||
})
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
@@ -26,24 +26,27 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /mcp/v1/completions [post]
|
||||
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
// We do not support streaming mode (Yet?)
|
||||
return func(c *fiber.Ctx) error {
|
||||
return func(c echo.Context) error {
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
ctx := c.Context()
|
||||
ctx := c.Request().Context()
|
||||
|
||||
// Handle Correlation
|
||||
id := c.Get("X-Correlation-ID", uuid.New().String())
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
id := c.Request().Header.Get("X-Correlation-ID")
|
||||
if id == "" {
|
||||
id = uuid.New().String()
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
|
||||
@@ -80,7 +83,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
|
||||
ctxWithCancellation, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
//handleConnectionCancellation(c, cancel, ctxWithCancellation)
|
||||
|
||||
// TODO: instead of connecting to the API, we should just wire this internally
|
||||
// and act like completion.go.
|
||||
// We can do this as cogito expects an interface and we can create one that
|
||||
@@ -147,6 +150,6 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
@@ -167,32 +169,50 @@ type Model interface {
|
||||
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow all origins
|
||||
},
|
||||
}
|
||||
|
||||
// TODO: Implement ephemeral keys to allow these endpoints to be used
|
||||
func RealtimeSessions(application *application.Application) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.SendStatus(501)
|
||||
func RealtimeSessions(application *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.NoContent(501)
|
||||
}
|
||||
}
|
||||
|
||||
func RealtimeTranscriptionSession(application *application.Application) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.SendStatus(501)
|
||||
func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.NoContent(501)
|
||||
}
|
||||
}
|
||||
|
||||
func Realtime(application *application.Application) fiber.Handler {
|
||||
return websocket.New(registerRealtime(application))
|
||||
func Realtime(application *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// Extract query parameters from Echo context before passing to websocket handler
|
||||
model := c.QueryParam("model")
|
||||
if model == "" {
|
||||
model = "gpt-4o"
|
||||
}
|
||||
intent := c.QueryParam("intent")
|
||||
|
||||
registerRealtime(application, model, intent)(ws)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
|
||||
func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) {
|
||||
return func(c *websocket.Conn) {
|
||||
|
||||
evaluator := application.TemplatesEvaluator()
|
||||
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
|
||||
|
||||
model := c.Query("model", "gpt-4o")
|
||||
|
||||
intent := c.Query("intent")
|
||||
if intent != "transcription" {
|
||||
sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
|
||||
}
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -24,19 +24,19 @@ import (
|
||||
// @Param file formData file true "file"
|
||||
// @Success 200 {object} map[string]string "Response"
|
||||
// @Router /v1/audio/transcriptions [post]
|
||||
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
diarize := c.FormValue("diarize", "false") != "false"
|
||||
diarize := c.FormValue("diarize") != "false"
|
||||
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
@@ -76,6 +76,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr)
|
||||
return c.JSON(http.StatusOK, tr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -14,20 +14,24 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input == nil {
|
||||
return fiber.ErrBadRequest
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
var raw map[string]interface{}
|
||||
if body := c.Body(); len(body) > 0 {
|
||||
body := make([]byte, 0)
|
||||
if c.Request().Body != nil {
|
||||
c.Request().Body.Read(body)
|
||||
}
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
}
|
||||
// Build VideoRequest using shared mapper
|
||||
vr := MapOpenAIToVideo(input, raw)
|
||||
// Place VideoRequest into locals so localai.VideoEndpoint can consume it
|
||||
c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
|
||||
// Place VideoRequest into context so localai.VideoEndpoint can consume it
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
|
||||
// Delegate to existing localai handler
|
||||
return localai.VideoEndpoint(cl, ml, appConfig)(c)
|
||||
}
|
||||
|
||||
@@ -1,48 +1,50 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/favicon"
|
||||
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/explorer"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Explorer(db *explorer.Database) *fiber.App {
|
||||
func Explorer(db *explorer.Database) *echo.Echo {
|
||||
e := echo.New()
|
||||
|
||||
fiberCfg := fiber.Config{
|
||||
Views: renderEngine(),
|
||||
// We disable the Fiber startup message as it does not conform to structured logging.
|
||||
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
|
||||
DisableStartupMessage: false,
|
||||
// Override default error handler
|
||||
// Set renderer
|
||||
e.Renderer = renderEngine()
|
||||
|
||||
// Hide banner
|
||||
e.HideBanner = true
|
||||
|
||||
e.Pre(middleware.StripPathPrefix())
|
||||
routes.RegisterExplorerRoutes(e, db)
|
||||
|
||||
// Favicon handler
|
||||
e.GET("/favicon.svg", func(c echo.Context) error {
|
||||
data, err := embedDirStatic.ReadFile("static/favicon.svg")
|
||||
if err != nil {
|
||||
return c.NoContent(http.StatusNotFound)
|
||||
}
|
||||
c.Response().Header().Set("Content-Type", "image/svg+xml")
|
||||
return c.Blob(http.StatusOK, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
// Static files - use fs.Sub to create a filesystem rooted at "static"
|
||||
staticFS, err := fs.Sub(embedDirStatic, "static")
|
||||
if err != nil {
|
||||
// Log error but continue - static files might not work
|
||||
log.Error().Err(err).Msg("failed to create static filesystem")
|
||||
} else {
|
||||
e.StaticFS("/static", staticFS)
|
||||
}
|
||||
|
||||
app := fiber.New(fiberCfg)
|
||||
|
||||
app.Use(middleware.StripPathPrefix())
|
||||
routes.RegisterExplorerRoutes(app, db)
|
||||
|
||||
httpFS := http.FS(embedDirStatic)
|
||||
|
||||
app.Use(favicon.New(favicon.Config{
|
||||
URL: "/favicon.svg",
|
||||
FileSystem: httpFS,
|
||||
File: "static/favicon.svg",
|
||||
}))
|
||||
|
||||
app.Use("/static", filesystem.New(filesystem.Config{
|
||||
Root: httpFS,
|
||||
PathPrefix: "static",
|
||||
Browse: true,
|
||||
}))
|
||||
|
||||
// Define a custom 404 handler
|
||||
// Note: keep this at the bottom!
|
||||
app.Use(notFoundHandler)
|
||||
e.GET("/*", notFoundHandler)
|
||||
|
||||
return app
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -3,50 +3,108 @@ package middleware
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/dave-gray101/v2keyauth"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/keyauth"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
|
||||
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
|
||||
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
|
||||
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
|
||||
|
||||
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
|
||||
customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key", "cookie:token"}, keyauth.ConfigDefault.AuthScheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration
|
||||
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
|
||||
// Create validator function
|
||||
validator := getApiKeyValidationFunction(applicationConfig)
|
||||
|
||||
return &v2keyauth.Config{
|
||||
CustomKeyLookup: customLookup,
|
||||
Next: getApiKeyRequiredFilterFunction(applicationConfig),
|
||||
Validator: getApiKeyValidationFunction(applicationConfig),
|
||||
ErrorHandler: getApiKeyErrorHandler(applicationConfig),
|
||||
AuthScheme: "Bearer",
|
||||
// Create error handler
|
||||
errorHandler := getApiKeyErrorHandler(applicationConfig)
|
||||
|
||||
// Create Next function (skip middleware for certain requests)
|
||||
skipper := getApiKeyRequiredFilterFunction(applicationConfig)
|
||||
|
||||
// Wrap it with our custom key lookup that checks multiple sources
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if len(applicationConfig.ApiKeys) == 0 {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Skip if skipper says so
|
||||
if skipper != nil && skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Try to extract key from multiple sources
|
||||
key, err := extractKeyFromMultipleSources(c)
|
||||
if err != nil {
|
||||
return errorHandler(err, c)
|
||||
}
|
||||
|
||||
// Validate the key
|
||||
valid, err := validator(key, c)
|
||||
if err != nil || !valid {
|
||||
return errorHandler(ErrMissingOrMalformedAPIKey, c)
|
||||
}
|
||||
|
||||
// Store key in context for later use
|
||||
c.Set("api_key", key)
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
|
||||
return func(ctx *fiber.Ctx, err error) error {
|
||||
if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
|
||||
// extractKeyFromMultipleSources checks multiple sources for the API key
|
||||
// in order: Authorization header, x-api-key header, xi-api-key header, token cookie
|
||||
func extractKeyFromMultipleSources(c echo.Context) (string, error) {
|
||||
// Check Authorization header first
|
||||
auth := c.Request().Header.Get("Authorization")
|
||||
if auth != "" {
|
||||
// Check for Bearer scheme
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
return strings.TrimPrefix(auth, "Bearer "), nil
|
||||
}
|
||||
// If no Bearer prefix, return as-is (for backward compatibility)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// Check x-api-key header
|
||||
if key := c.Request().Header.Get("x-api-key"); key != "" {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Check xi-api-key header
|
||||
if key := c.Request().Header.Get("xi-api-key"); key != "" {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Check token cookie
|
||||
cookie, err := c.Cookie("token")
|
||||
if err == nil && cookie != nil && cookie.Value != "" {
|
||||
return cookie.Value, nil
|
||||
}
|
||||
|
||||
return "", ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
|
||||
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
|
||||
return func(err error, c echo.Context) error {
|
||||
if errors.Is(err, ErrMissingOrMalformedAPIKey) {
|
||||
if len(applicationConfig.ApiKeys) == 0 {
|
||||
return ctx.Next() // if no keys are set up, any error we get here is not an error.
|
||||
return nil // if no keys are set up, any error we get here is not an error.
|
||||
}
|
||||
ctx.Set("WWW-Authenticate", "Bearer")
|
||||
c.Response().Header().Set("WWW-Authenticate", "Bearer")
|
||||
if applicationConfig.OpaqueErrors {
|
||||
return ctx.SendStatus(401)
|
||||
return c.NoContent(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Check if the request content type is JSON
|
||||
contentType := string(ctx.Context().Request.Header.ContentType())
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return ctx.Status(401).JSON(schema.ErrorResponse{
|
||||
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "An authentication key is required",
|
||||
Code: 401,
|
||||
@@ -55,50 +113,69 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
|
||||
})
|
||||
}
|
||||
|
||||
return ctx.Status(401).Render("views/login", fiber.Map{
|
||||
"BaseURL": utils.BaseURL(ctx),
|
||||
return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
|
||||
"BaseURL": BaseURL(c),
|
||||
})
|
||||
}
|
||||
if applicationConfig.OpaqueErrors {
|
||||
return ctx.SendStatus(500)
|
||||
return c.NoContent(http.StatusInternalServerError)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
|
||||
|
||||
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
|
||||
if applicationConfig.UseSubtleKeyComparison {
|
||||
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
|
||||
return func(key string, c echo.Context) (bool, error) {
|
||||
if len(applicationConfig.ApiKeys) == 0 {
|
||||
return true, nil // If no keys are setup, accept everything
|
||||
}
|
||||
for _, validKey := range applicationConfig.ApiKeys {
|
||||
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
|
||||
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, v2keyauth.ErrMissingOrMalformedAPIKey
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
}
|
||||
|
||||
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
|
||||
return func(key string, c echo.Context) (bool, error) {
|
||||
if len(applicationConfig.ApiKeys) == 0 {
|
||||
return true, nil // If no keys are setup, accept everything
|
||||
}
|
||||
for _, validKey := range applicationConfig.ApiKeys {
|
||||
if apiKey == validKey {
|
||||
if key == validKey {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, v2keyauth.ErrMissingOrMalformedAPIKey
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
}
|
||||
|
||||
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
|
||||
if applicationConfig.DisableApiKeyRequirementForHttpGet {
|
||||
return func(c *fiber.Ctx) bool {
|
||||
if c.Method() != "GET" {
|
||||
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
|
||||
return func(c echo.Context) bool {
|
||||
path := c.Request().URL.Path
|
||||
|
||||
// Always skip authentication for static files
|
||||
if strings.HasPrefix(path, "/static/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Always skip authentication for generated content
|
||||
if strings.HasPrefix(path, "/generated-audio/") ||
|
||||
strings.HasPrefix(path, "/generated-images/") ||
|
||||
strings.HasPrefix(path, "/generated-videos/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip authentication for favicon
|
||||
if path == "/favicon.svg" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle GET request exemptions if enabled
|
||||
if applicationConfig.DisableApiKeyRequirementForHttpGet {
|
||||
if c.Request().Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
|
||||
@@ -106,8 +183,8 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
return func(c *fiber.Ctx) bool { return false }
|
||||
}
|
||||
|
||||
48
core/http/middleware/baseurl.go
Normal file
48
core/http/middleware/baseurl.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// BaseURL returns the base URL for the given HTTP request context.
|
||||
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
|
||||
// The returned URL is guaranteed to end with `/`.
|
||||
// The method should be used in conjunction with the StripPathPrefix middleware.
|
||||
func BaseURL(c echo.Context) string {
|
||||
path := c.Path()
|
||||
origPath := c.Request().URL.Path
|
||||
|
||||
// Check if StripPathPrefix middleware stored the original path
|
||||
if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" {
|
||||
origPath = storedPath
|
||||
}
|
||||
|
||||
// Check X-Forwarded-Proto for scheme
|
||||
scheme := "http"
|
||||
if c.Request().Header.Get("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
} else if c.Request().TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
// Check X-Forwarded-Host for host
|
||||
host := c.Request().Host
|
||||
if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||
host = forwardedHost
|
||||
}
|
||||
|
||||
if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 {
|
||||
prefixLen := len(origPath) - len(path)
|
||||
if prefixLen > 0 && prefixLen <= len(origPath) {
|
||||
pathPrefix := origPath[:prefixLen]
|
||||
if !strings.HasSuffix(pathPrefix, "/") {
|
||||
pathPrefix += "/"
|
||||
}
|
||||
return scheme + "://" + host + pathPrefix
|
||||
}
|
||||
}
|
||||
|
||||
return scheme + "://" + host + "/"
|
||||
}
|
||||
58
core/http/middleware/baseurl_test.go
Normal file
58
core/http/middleware/baseurl_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("BaseURL", func() {
|
||||
Context("without prefix", func() {
|
||||
It("should return base URL without prefix", func() {
|
||||
app := echo.New()
|
||||
actualURL := ""
|
||||
|
||||
// Register route - use the actual request path so routing works
|
||||
routePath := "/hello/world"
|
||||
app.GET(routePath, func(c echo.Context) error {
|
||||
actualURL = BaseURL(c)
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualURL).To(Equal("http://example.com/"), "base URL")
|
||||
})
|
||||
})
|
||||
|
||||
Context("with prefix", func() {
|
||||
It("should return base URL with prefix", func() {
|
||||
app := echo.New()
|
||||
actualURL := ""
|
||||
|
||||
// Register route with the stripped path (after middleware removes prefix)
|
||||
routePath := "/hello/world"
|
||||
app.GET(routePath, func(c echo.Context) error {
|
||||
// Simulate what StripPathPrefix middleware does - store original path
|
||||
c.Set("_original_path", "/myprefix/hello/world")
|
||||
// Modify the request path to simulate prefix stripping
|
||||
c.Request().URL.Path = "/hello/world"
|
||||
actualURL = BaseURL(c)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Make request with stripped path (middleware would have already processed it)
|
||||
req := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
|
||||
})
|
||||
})
|
||||
})
|
||||
13
core/http/middleware/middleware_suite_test.go
Normal file
13
core/http/middleware/middleware_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Middleware test suite")
|
||||
}
|
||||
@@ -1,470 +1,482 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type correlationIDKeyType string
|
||||
|
||||
// CorrelationIDKey to track request across process boundary
|
||||
const CorrelationIDKey correlationIDKeyType = "correlationID"
|
||||
|
||||
type RequestExtractor struct {
|
||||
modelConfigLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
return &RequestExtractor{
|
||||
modelConfigLoader: modelConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
}
|
||||
}
|
||||
|
||||
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
|
||||
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
|
||||
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
|
||||
|
||||
// TODO: Refactor to not return error if unchanged
|
||||
func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
|
||||
model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && model != "" {
|
||||
return
|
||||
}
|
||||
model = ctx.Params("model")
|
||||
|
||||
if (model == "") && ctx.Query("model") != "" {
|
||||
model = ctx.Query("model")
|
||||
}
|
||||
|
||||
if model == "" {
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
|
||||
if bearer != "" {
|
||||
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
if err == nil && exists {
|
||||
model = bearer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
re.setModelNameFromRequest(ctx)
|
||||
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if !ok || localModelName == "" {
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
|
||||
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
|
||||
}
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
re.setModelNameFromRequest(ctx)
|
||||
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if localModelName != "" { // Don't overwrite existing values
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
if len(modelNames) == 0 {
|
||||
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
|
||||
// This is non-fatal - making it so was breaking the case of direct installation of raw models
|
||||
// return errors.New("this endpoint requires at least one model to be installed")
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
|
||||
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: If context and cancel above belong on all methods, move that part of above into here!
|
||||
// Otherwise, it's in its own method below for now
|
||||
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input := initializer()
|
||||
if input == nil {
|
||||
return fmt.Errorf("unable to initialize body")
|
||||
}
|
||||
if err := ctx.BodyParser(input); err != nil {
|
||||
return fmt.Errorf("failed parsing request body: %w", err)
|
||||
}
|
||||
|
||||
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
|
||||
if input.ModelName(nil) == "" {
|
||||
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && localModelName != "" {
|
||||
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
|
||||
input.ModelName(&localModelName)
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
|
||||
} else if cfg.Model == "" && input.ModelName(nil) != "" {
|
||||
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
|
||||
cfg.Model = input.ModelName(nil)
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
// Extract or generate the correlation ID
|
||||
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
|
||||
ctx.Set("X-Correlation-ID", correlationID)
|
||||
|
||||
//c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
||||
// Use the application context as parent to ensure cancellation on app shutdown
|
||||
// We'll monitor the Fiber context separately and cancel our context when the request is canceled
|
||||
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
||||
// Monitor the Fiber context and cancel our context when it's canceled
|
||||
// This ensures we respect request cancellation without causing panics
|
||||
go func(fiberCtx *fasthttp.RequestCtx) {
|
||||
if fiberCtx != nil {
|
||||
<-fiberCtx.Done()
|
||||
cancel()
|
||||
}
|
||||
}(ctx.Context())
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
err := mergeOpenAIRequestAndModelConfig(cfg, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Model == "" {
|
||||
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != nil {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != nil {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != nil {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
if input.ResponseFormat != nil {
|
||||
switch responseFormat := input.ResponseFormat.(type) {
|
||||
case string:
|
||||
config.ResponseFormat = responseFormat
|
||||
case map[string]interface{}:
|
||||
config.ResponseFormatMap = responseFormat
|
||||
}
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
for _, tool := range input.Tools {
|
||||
input.Functions = append(input.Functions, tool.Function)
|
||||
}
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice functions.Tool
|
||||
|
||||
switch content := input.ToolsChoice.(type) {
|
||||
case string:
|
||||
_ = json.Unmarshal([]byte(content), &toolChoice)
|
||||
case map[string]interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
_ = json.Unmarshal(dat, &toolChoice)
|
||||
}
|
||||
input.FunctionCall = map[string]interface{}{
|
||||
"name": toolChoice.Function.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
||||
for i, m := range input.Messages {
|
||||
nrOfImgsInMessage := 0
|
||||
nrOfVideosInMessage := 0
|
||||
nrOfAudiosInMessage := 0
|
||||
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
|
||||
textContent := ""
|
||||
// we will template this at the end
|
||||
|
||||
CONTENT:
|
||||
for _, pp := range c {
|
||||
switch pp.Type {
|
||||
case "text":
|
||||
textContent += pp.Text
|
||||
//input.Messages[i].StringContent = pp.Text
|
||||
case "video", "video_url":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding video: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
||||
vidIndex++
|
||||
nrOfVideosInMessage++
|
||||
case "audio_url", "audio":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding audio: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "input_audio":
|
||||
// TODO: make sure that we only return base64 stuff
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "image_url", "image":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
|
||||
imgIndex++
|
||||
nrOfImgsInMessage++
|
||||
}
|
||||
}
|
||||
|
||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||
TotalImages: imgIndex,
|
||||
TotalVideos: vidIndex,
|
||||
TotalAudios: audioIndex,
|
||||
ImagesInMessage: nrOfImgsInMessage,
|
||||
VideosInMessage: nrOfVideosInMessage,
|
||||
AudiosInMessage: nrOfAudiosInMessage,
|
||||
}, textContent)
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.FrequencyPenalty != 0 {
|
||||
config.FrequencyPenalty = input.FrequencyPenalty
|
||||
}
|
||||
|
||||
if input.PresencePenalty != 0 {
|
||||
config.PresencePenalty = input.PresencePenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != nil {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.TypicalP != nil {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []any:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []any:
|
||||
tokens := []int{}
|
||||
inputStrings := []string{}
|
||||
for _, ii := range i {
|
||||
switch ii := ii.(type) {
|
||||
case int:
|
||||
tokens = append(tokens, ii)
|
||||
case float64:
|
||||
tokens = append(tokens, int(ii))
|
||||
case string:
|
||||
inputStrings = append(inputStrings, ii)
|
||||
default:
|
||||
log.Error().Msgf("Unknown input type: %T", ii)
|
||||
}
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
config.InputStrings = append(config.InputStrings, inputStrings...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a quality was defined as number, convert it to step
|
||||
if input.Quality != "" {
|
||||
q, err := strconv.Atoi(input.Quality)
|
||||
if err == nil {
|
||||
config.Step = q
|
||||
}
|
||||
}
|
||||
|
||||
if config.Validate() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unable to validate configuration after merging")
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type correlationIDKeyType string
|
||||
|
||||
// CorrelationIDKey to track request across process boundary
|
||||
const CorrelationIDKey correlationIDKeyType = "correlationID"
|
||||
|
||||
type RequestExtractor struct {
|
||||
modelConfigLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
return &RequestExtractor{
|
||||
modelConfigLoader: modelConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
}
|
||||
}
|
||||
|
||||
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
|
||||
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
|
||||
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
|
||||
|
||||
// TODO: Refactor to not return error if unchanged
|
||||
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
|
||||
model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && model != "" {
|
||||
return
|
||||
}
|
||||
model = c.Param("model")
|
||||
|
||||
if model == "" {
|
||||
model = c.QueryParam("model")
|
||||
}
|
||||
|
||||
if model == "" {
|
||||
// Set model from bearer token, if available
|
||||
auth := c.Request().Header.Get("Authorization")
|
||||
bearer := strings.TrimPrefix(auth, "Bearer ")
|
||||
if bearer != "" && bearer != auth {
|
||||
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
if err == nil && exists {
|
||||
model = bearer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
re.setModelNameFromRequest(c)
|
||||
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if !ok || localModelName == "" {
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
|
||||
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
re.setModelNameFromRequest(c)
|
||||
localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if localModelName != "" { // Don't overwrite existing values
|
||||
return next(c)
|
||||
}
|
||||
|
||||
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
|
||||
return next(c)
|
||||
}
|
||||
|
||||
if len(modelNames) == 0 {
|
||||
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
|
||||
// This is non-fatal - making it so was breaking the case of direct installation of raw models
|
||||
// return errors.New("this endpoint requires at least one model to be installed")
|
||||
return next(c)
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
|
||||
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: If context and cancel above belong on all methods, move that part of above into here!
|
||||
// Otherwise, it's in its own method below for now
|
||||
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input := initializer()
|
||||
if input == nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
|
||||
}
|
||||
if err := c.Bind(input); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
|
||||
}
|
||||
|
||||
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
|
||||
if input.ModelName(nil) == "" {
|
||||
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && localModelName != "" {
|
||||
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
|
||||
input.ModelName(&localModelName)
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
|
||||
} else if cfg.Model == "" && input.ModelName(nil) != "" {
|
||||
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
|
||||
cfg.Model = input.ModelName(nil)
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
|
||||
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
// Extract or generate the correlation ID
|
||||
correlationID := c.Request().Header.Get("X-Correlation-ID")
|
||||
if correlationID == "" {
|
||||
correlationID = uuid.New().String()
|
||||
}
|
||||
c.Response().Header().Set("X-Correlation-ID", correlationID)
|
||||
|
||||
// Use the request context directly - Echo properly supports context cancellation!
|
||||
// No need for workarounds like handleConnectionCancellation
|
||||
reqCtx := c.Request().Context()
|
||||
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
||||
|
||||
// Cancel when request context is cancelled (client disconnects)
|
||||
go func() {
|
||||
select {
|
||||
case <-reqCtx.Done():
|
||||
cancel()
|
||||
case <-c1.Done():
|
||||
// Already cancelled
|
||||
}
|
||||
}()
|
||||
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
err := mergeOpenAIRequestAndModelConfig(cfg, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Model == "" {
|
||||
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != nil {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != nil {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != nil {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
if input.ResponseFormat != nil {
|
||||
switch responseFormat := input.ResponseFormat.(type) {
|
||||
case string:
|
||||
config.ResponseFormat = responseFormat
|
||||
case map[string]interface{}:
|
||||
config.ResponseFormatMap = responseFormat
|
||||
}
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
for _, tool := range input.Tools {
|
||||
input.Functions = append(input.Functions, tool.Function)
|
||||
}
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice functions.Tool
|
||||
|
||||
switch content := input.ToolsChoice.(type) {
|
||||
case string:
|
||||
_ = json.Unmarshal([]byte(content), &toolChoice)
|
||||
case map[string]interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
_ = json.Unmarshal(dat, &toolChoice)
|
||||
}
|
||||
input.FunctionCall = map[string]interface{}{
|
||||
"name": toolChoice.Function.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
||||
for i, m := range input.Messages {
|
||||
nrOfImgsInMessage := 0
|
||||
nrOfVideosInMessage := 0
|
||||
nrOfAudiosInMessage := 0
|
||||
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
|
||||
textContent := ""
|
||||
// we will template this at the end
|
||||
|
||||
CONTENT:
|
||||
for _, pp := range c {
|
||||
switch pp.Type {
|
||||
case "text":
|
||||
textContent += pp.Text
|
||||
//input.Messages[i].StringContent = pp.Text
|
||||
case "video", "video_url":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding video: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
||||
vidIndex++
|
||||
nrOfVideosInMessage++
|
||||
case "audio_url", "audio":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding audio: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "input_audio":
|
||||
// TODO: make sure that we only return base64 stuff
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "image_url", "image":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
|
||||
imgIndex++
|
||||
nrOfImgsInMessage++
|
||||
}
|
||||
}
|
||||
|
||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||
TotalImages: imgIndex,
|
||||
TotalVideos: vidIndex,
|
||||
TotalAudios: audioIndex,
|
||||
ImagesInMessage: nrOfImgsInMessage,
|
||||
VideosInMessage: nrOfVideosInMessage,
|
||||
AudiosInMessage: nrOfAudiosInMessage,
|
||||
}, textContent)
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.FrequencyPenalty != 0 {
|
||||
config.FrequencyPenalty = input.FrequencyPenalty
|
||||
}
|
||||
|
||||
if input.PresencePenalty != 0 {
|
||||
config.PresencePenalty = input.PresencePenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != nil {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.TypicalP != nil {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []any:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []any:
|
||||
tokens := []int{}
|
||||
inputStrings := []string{}
|
||||
for _, ii := range i {
|
||||
switch ii := ii.(type) {
|
||||
case int:
|
||||
tokens = append(tokens, ii)
|
||||
case float64:
|
||||
tokens = append(tokens, int(ii))
|
||||
case string:
|
||||
inputStrings = append(inputStrings, ii)
|
||||
default:
|
||||
log.Error().Msgf("Unknown input type: %T", ii)
|
||||
}
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
config.InputStrings = append(config.InputStrings, inputStrings...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a quality was defined as number, convert it to step
|
||||
if input.Quality != "" {
|
||||
q, err := strconv.Atoi(input.Quality)
|
||||
if err == nil {
|
||||
config.Step = q
|
||||
}
|
||||
}
|
||||
|
||||
if config.Validate() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unable to validate configuration after merging")
|
||||
}
|
||||
|
||||
@@ -3,34 +3,55 @@ package middleware
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// StripPathPrefix returns a middleware that strips a path prefix from the request path.
|
||||
// StripPathPrefix returns middleware that strips a path prefix from the request path.
|
||||
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
|
||||
func StripPathPrefix() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
|
||||
if prefix != "" {
|
||||
path := c.Path()
|
||||
pos := len(prefix)
|
||||
// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing.
|
||||
func StripPathPrefix() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
prefixes := c.Request().Header.Values("X-Forwarded-Prefix")
|
||||
originalPath := c.Request().URL.Path
|
||||
|
||||
if prefix[pos-1] == '/' {
|
||||
pos--
|
||||
} else {
|
||||
prefix += "/"
|
||||
}
|
||||
for _, prefix := range prefixes {
|
||||
if prefix != "" {
|
||||
normalizedPrefix := prefix
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
normalizedPrefix = prefix + "/"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
c.Path(path[pos:])
|
||||
break
|
||||
} else if prefix[:pos] == path {
|
||||
c.Redirect(prefix)
|
||||
return nil
|
||||
if strings.HasPrefix(originalPath, normalizedPrefix) {
|
||||
// Update the request path by stripping the normalized prefix
|
||||
newPath := originalPath[len(normalizedPrefix):]
|
||||
if newPath == "" {
|
||||
newPath = "/"
|
||||
}
|
||||
// Ensure path starts with / for proper routing
|
||||
if !strings.HasPrefix(newPath, "/") {
|
||||
newPath = "/" + newPath
|
||||
}
|
||||
// Update the URL path - Echo's router uses URL.Path for routing
|
||||
c.Request().URL.Path = newPath
|
||||
c.Request().URL.RawPath = ""
|
||||
// Update RequestURI to match the new path (needed for proper routing)
|
||||
if c.Request().URL.RawQuery != "" {
|
||||
c.Request().RequestURI = newPath + "?" + c.Request().URL.RawQuery
|
||||
} else {
|
||||
c.Request().RequestURI = newPath
|
||||
}
|
||||
// Store original path for BaseURL utility
|
||||
c.Set("_original_path", originalPath)
|
||||
break
|
||||
} else if originalPath == prefix || originalPath == prefix+"/" {
|
||||
// Redirect to prefix with trailing slash (use 302 to match test expectations)
|
||||
return c.Redirect(302, normalizedPrefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,120 +2,133 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestStripPathPrefix(t *testing.T) {
|
||||
var _ = Describe("StripPathPrefix", func() {
|
||||
var app *echo.Echo
|
||||
var actualPath string
|
||||
var appInitialized bool
|
||||
|
||||
app := fiber.New()
|
||||
BeforeEach(func() {
|
||||
actualPath = ""
|
||||
if !appInitialized {
|
||||
app = echo.New()
|
||||
app.Pre(StripPathPrefix())
|
||||
|
||||
app.Use(StripPathPrefix())
|
||||
app.GET("/hello/world", func(c echo.Context) error {
|
||||
actualPath = c.Request().URL.Path
|
||||
return nil
|
||||
})
|
||||
|
||||
app.Get("/hello/world", func(c *fiber.Ctx) error {
|
||||
actualPath = c.Path()
|
||||
return nil
|
||||
app.GET("/", func(c echo.Context) error {
|
||||
actualPath = c.Request().URL.Path
|
||||
return nil
|
||||
})
|
||||
appInitialized = true
|
||||
}
|
||||
})
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
actualPath = c.Path()
|
||||
return nil
|
||||
})
|
||||
Context("without prefix", func() {
|
||||
It("should not modify path when no header is present", func() {
|
||||
req := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
path string
|
||||
prefixHeader []string
|
||||
expectStatus int
|
||||
expectPath string
|
||||
}{
|
||||
{
|
||||
name: "without prefix and header",
|
||||
path: "/hello/world",
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "without prefix and headers on root path",
|
||||
path: "/",
|
||||
expectStatus: 200,
|
||||
expectPath: "/",
|
||||
},
|
||||
{
|
||||
name: "without prefix but header",
|
||||
path: "/hello/world",
|
||||
prefixHeader: []string{"/otherprefix/"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix but non-matching header",
|
||||
path: "/prefix/hello/world",
|
||||
prefixHeader: []string{"/otherprefix/"},
|
||||
expectStatus: 404,
|
||||
},
|
||||
{
|
||||
name: "with prefix and matching header",
|
||||
path: "/myprefix/hello/world",
|
||||
prefixHeader: []string{"/myprefix/"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix and 1st header matching",
|
||||
path: "/myprefix/hello/world",
|
||||
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix and 2nd header matching",
|
||||
path: "/myprefix/hello/world",
|
||||
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix and header not ending with slash",
|
||||
path: "/myprefix/hello/world",
|
||||
prefixHeader: []string{"/myprefix"},
|
||||
expectStatus: 200,
|
||||
expectPath: "/hello/world",
|
||||
},
|
||||
{
|
||||
name: "with prefix and non-matching header not ending with slash",
|
||||
path: "/myprefix-suffix/hello/world",
|
||||
prefixHeader: []string{"/myprefix"},
|
||||
expectStatus: 404,
|
||||
},
|
||||
{
|
||||
name: "redirect when prefix does not end with a slash",
|
||||
path: "/myprefix",
|
||||
prefixHeader: []string{"/myprefix"},
|
||||
expectStatus: 302,
|
||||
expectPath: "/myprefix/",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actualPath = ""
|
||||
req := httptest.NewRequest("GET", tc.path, nil)
|
||||
if tc.prefixHeader != nil {
|
||||
req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
|
||||
}
|
||||
|
||||
resp, err := app.Test(req, -1)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")
|
||||
|
||||
if tc.expectStatus == 200 {
|
||||
require.Equal(t, tc.expectPath, actualPath, "rewritten path")
|
||||
} else if tc.expectStatus == 302 {
|
||||
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
|
||||
}
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
It("should not modify root path when no header is present", func() {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should not modify path when header does not match", func() {
|
||||
req := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
})
|
||||
|
||||
Context("with prefix", func() {
|
||||
It("should return 404 when prefix does not match header", func() {
|
||||
req := httptest.NewRequest("GET", "/prefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(404), "response status code")
|
||||
})
|
||||
|
||||
It("should strip matching prefix from path", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should strip prefix when it matches the first header value", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should strip prefix when it matches the second header value", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should strip prefix when header does not end with slash", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(200), "response status code")
|
||||
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
|
||||
})
|
||||
|
||||
It("should return 404 when prefix does not match header without trailing slash", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(404), "response status code")
|
||||
})
|
||||
|
||||
It("should redirect when prefix does not end with a slash", func() {
|
||||
req := httptest.NewRequest("GET", "/myprefix", nil)
|
||||
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
Expect(rec.Code).To(Equal(302), "response status code")
|
||||
Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"fmt"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -62,7 +62,7 @@ func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResp
|
||||
var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
|
||||
var tmpdir string
|
||||
var appServer *application.Application
|
||||
var app *fiber.App
|
||||
var app *echo.Echo
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
|
||||
@@ -97,7 +97,9 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
_ = app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = app.Shutdown(ctx)
|
||||
}
|
||||
_ = os.RemoveAll(tmpdir)
|
||||
})
|
||||
@@ -106,7 +108,11 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
|
||||
var err error
|
||||
app, err = API(appServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9091")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
|
||||
// Log error if needed
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
@@ -4,13 +4,15 @@ import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
fiberhtml "github.com/gofiber/template/html/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/russross/blackfriday"
|
||||
)
|
||||
@@ -18,26 +20,67 @@ import (
|
||||
//go:embed views/*
|
||||
var viewsfs embed.FS
|
||||
|
||||
func notFoundHandler(c *fiber.Ctx) error {
|
||||
// TemplateRenderer is a custom template renderer for Echo
|
||||
type TemplateRenderer struct {
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
// Render renders a template document
|
||||
func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
|
||||
return t.templates.ExecuteTemplate(w, name, data)
|
||||
}
|
||||
|
||||
func notFoundHandler(c echo.Context) error {
|
||||
// Check if the request accepts JSON
|
||||
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
accept := c.Request().Header.Get("Accept")
|
||||
if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") {
|
||||
// The client expects a JSON response
|
||||
return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound},
|
||||
return c.JSON(http.StatusNotFound, schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: "Resource not found", Code: http.StatusNotFound},
|
||||
})
|
||||
} else {
|
||||
// The client expects an HTML response
|
||||
return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
return c.Render(http.StatusNotFound, "views/404", map[string]interface{}{
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func renderEngine() *fiberhtml.Engine {
|
||||
engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html")
|
||||
engine.AddFuncMap(sprig.FuncMap())
|
||||
engine.AddFunc("MDToHTML", markDowner)
|
||||
return engine
|
||||
func renderEngine() *TemplateRenderer {
|
||||
// Parse all templates from embedded filesystem
|
||||
tmpl := template.New("").Funcs(sprig.FuncMap())
|
||||
tmpl = tmpl.Funcs(template.FuncMap{
|
||||
"MDToHTML": markDowner,
|
||||
})
|
||||
|
||||
// Recursively walk through embedded filesystem and parse all HTML templates
|
||||
err := fs.WalkDir(viewsfs, "views", func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !d.IsDir() && strings.HasSuffix(path, ".html") {
|
||||
data, err := viewsfs.ReadFile(path)
|
||||
if err == nil {
|
||||
// Remove .html extension to get template name (e.g., "views/index.html" -> "views/index")
|
||||
templateName := strings.TrimSuffix(path, ".html")
|
||||
_, err := tmpl.New(templateName).Parse(string(data))
|
||||
if err != nil {
|
||||
// If parsing fails, try parsing without explicit name (for templates with {{define}})
|
||||
tmpl.Parse(string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but continue - templates might still work
|
||||
fmt.Printf("Error walking views directory: %v\n", err)
|
||||
}
|
||||
|
||||
return &TemplateRenderer{
|
||||
templates: tmpl,
|
||||
}
|
||||
}
|
||||
|
||||
func markDowner(args ...interface{}) template.HTML {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
@@ -9,21 +9,23 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterElevenLabsRoutes(app *fiber.App,
|
||||
func RegisterElevenLabsRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
// Elevenlabs
|
||||
app.Post("/v1/text-to-speech/:voice-id",
|
||||
ttsHandler := elevenlabs.TTSEndpoint(cl, ml, appConfig)
|
||||
app.POST("/v1/text-to-speech/:voice-id",
|
||||
ttsHandler,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
|
||||
elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }))
|
||||
|
||||
app.Post("/v1/sound-generation",
|
||||
soundGenHandler := elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)
|
||||
app.POST("/v1/sound-generation",
|
||||
soundGenHandler,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
|
||||
elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }))
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
coreExplorer "github.com/mudler/LocalAI/core/explorer"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/explorer"
|
||||
)
|
||||
|
||||
func RegisterExplorerRoutes(app *fiber.App, db *coreExplorer.Database) {
|
||||
app.Get("/", explorer.Dashboard())
|
||||
app.Post("/network/add", explorer.AddNetwork(db))
|
||||
app.Get("/networks", explorer.ShowNetworks(db))
|
||||
func RegisterExplorerRoutes(app *echo.Echo, db *coreExplorer.Database) {
|
||||
app.GET("/", explorer.Dashboard())
|
||||
app.POST("/network/add", explorer.AddNetwork(db))
|
||||
app.GET("/networks", explorer.ShowNetworks(db))
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package routes
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func HealthRoutes(app *fiber.App) {
|
||||
func HealthRoutes(app *echo.Echo) {
|
||||
// Service health checks
|
||||
ok := func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(200)
|
||||
ok := func(c echo.Context) error {
|
||||
return c.NoContent(200)
|
||||
}
|
||||
|
||||
app.Get("/healthz", ok)
|
||||
app.Get("/readyz", ok)
|
||||
app.GET("/healthz", ok)
|
||||
app.GET("/readyz", ok)
|
||||
}
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/jina"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterJINARoutes(app *fiber.App,
|
||||
func RegisterJINARoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
// POST endpoint to mimic the reranking
|
||||
app.Post("/v1/rerank",
|
||||
rerankHandler := jina.JINARerankEndpoint(cl, ml, appConfig)
|
||||
app.POST("/v1/rerank",
|
||||
rerankHandler,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
|
||||
jina.JINARerankEndpoint(cl, ml, appConfig))
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }))
|
||||
}
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/swagger"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
echoswagger "github.com/swaggo/echo-swagger"
|
||||
)
|
||||
|
||||
func RegisterLocalAIRoutes(router *fiber.App,
|
||||
func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
requestExtractor *middleware.RequestExtractor,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
@@ -21,111 +20,117 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||
galleryService *services.GalleryService,
|
||||
opcache *services.OpCache) {
|
||||
|
||||
router.Get("/swagger/*", swagger.HandlerDefault) // default
|
||||
router.GET("/swagger/*", echoswagger.WrapHandler) // default
|
||||
|
||||
// LocalAI API endpoints
|
||||
if !appConfig.DisableGalleryEndpoint {
|
||||
// Import model page
|
||||
router.Get("/import-model", func(c *fiber.Ctx) error {
|
||||
return c.Render("views/model-editor", fiber.Map{
|
||||
router.GET("/import-model", func(c echo.Context) error {
|
||||
return c.Render(200, "views/model-editor", map[string]interface{}{
|
||||
"Title": "LocalAI - Import Model",
|
||||
"BaseURL": httpUtils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
})
|
||||
})
|
||||
|
||||
// Edit model page
|
||||
router.Get("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
|
||||
router.GET("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
|
||||
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
router.POST("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
router.POST("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
|
||||
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
|
||||
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
router.GET("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
|
||||
router.GET("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
router.GET("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.GET("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
|
||||
backendGalleryEndpointService := localai.CreateBackendEndpointService(
|
||||
appConfig.BackendGalleries,
|
||||
appConfig.SystemState,
|
||||
galleryService)
|
||||
router.Post("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
|
||||
router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
|
||||
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
|
||||
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.POST("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
|
||||
router.POST("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
|
||||
router.GET("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
|
||||
router.GET("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
|
||||
router.GET("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
|
||||
router.GET("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
// Custom model import endpoint
|
||||
router.Post("/models/import", localai.ImportModelEndpoint(cl, appConfig))
|
||||
router.POST("/models/import", localai.ImportModelEndpoint(cl, appConfig))
|
||||
|
||||
// URI model import endpoint
|
||||
router.Post("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
|
||||
router.POST("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
|
||||
|
||||
// Custom model edit endpoint
|
||||
router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
|
||||
router.POST("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
|
||||
|
||||
// Reload models endpoint
|
||||
router.Post("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
|
||||
router.POST("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
|
||||
}
|
||||
|
||||
router.Post("/v1/detection",
|
||||
detectionHandler := localai.DetectionEndpoint(cl, ml, appConfig)
|
||||
router.POST("/v1/detection",
|
||||
detectionHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
|
||||
localai.DetectionEndpoint(cl, ml, appConfig))
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }))
|
||||
|
||||
router.Post("/tts",
|
||||
ttsHandler := localai.TTSEndpoint(cl, ml, appConfig)
|
||||
router.POST("/tts",
|
||||
ttsHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
|
||||
localai.TTSEndpoint(cl, ml, appConfig))
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }))
|
||||
|
||||
vadChain := []fiber.Handler{
|
||||
vadHandler := localai.VADEndpoint(cl, ml, appConfig)
|
||||
router.POST("/vad",
|
||||
vadHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
|
||||
localai.VADEndpoint(cl, ml, appConfig),
|
||||
}
|
||||
router.Post("/vad", vadChain...)
|
||||
router.Post("/v1/vad", vadChain...)
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
|
||||
router.POST("/v1/vad",
|
||||
vadHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
|
||||
|
||||
// Stores
|
||||
router.Post("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
|
||||
router.Post("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
|
||||
router.Post("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
|
||||
router.Post("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
|
||||
router.POST("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
|
||||
router.POST("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
|
||||
router.POST("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
|
||||
router.POST("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
|
||||
|
||||
if !appConfig.DisableMetrics {
|
||||
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||
router.GET("/metrics", localai.LocalAIMetricsEndpoint())
|
||||
}
|
||||
|
||||
router.Post("/video",
|
||||
videoHandler := localai.VideoEndpoint(cl, ml, appConfig)
|
||||
router.POST("/video",
|
||||
videoHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }),
|
||||
localai.VideoEndpoint(cl, ml, appConfig))
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }))
|
||||
|
||||
// Backend Statistics Module
|
||||
// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
|
||||
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
||||
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
|
||||
router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
|
||||
// p2p
|
||||
router.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
|
||||
router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
||||
router.GET("/api/p2p", localai.ShowP2PNodes(appConfig))
|
||||
router.GET("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
||||
|
||||
router.Get("/version", func(c *fiber.Ctx) error {
|
||||
return c.JSON(struct {
|
||||
router.GET("/version", func(c echo.Context) error {
|
||||
return c.JSON(200, struct {
|
||||
Version string `json:"version"`
|
||||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
router.Get("/system", localai.SystemInformations(ml, appConfig))
|
||||
router.GET("/system", localai.SystemInformations(ml, appConfig))
|
||||
|
||||
// misc
|
||||
router.Post("/v1/tokenize",
|
||||
tokenizeHandler := localai.TokenizeEndpoint(cl, ml, appConfig)
|
||||
router.POST("/v1/tokenize",
|
||||
tokenizeHandler,
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
|
||||
localai.TokenizeEndpoint(cl, ml, appConfig))
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }))
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
@@ -10,118 +10,172 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
func RegisterOpenAIRoutes(app *fiber.App,
|
||||
func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
application *application.Application) {
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// realtime
|
||||
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
|
||||
app.Get("/v1/realtime", openai.Realtime(application))
|
||||
app.Post("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
|
||||
app.Post("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
|
||||
app.GET("/v1/realtime", openai.Realtime(application))
|
||||
app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
|
||||
app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
|
||||
|
||||
// chat
|
||||
chatChain := []fiber.Handler{
|
||||
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/v1/chat/completions", chatChain...)
|
||||
app.Post("/chat/completions", chatChain...)
|
||||
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
|
||||
app.POST("/chat/completions", chatHandler, chatMiddleware...)
|
||||
|
||||
// edit
|
||||
editChain := []fiber.Handler{
|
||||
editHandler := openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
editMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/v1/edits", editChain...)
|
||||
app.Post("/edits", editChain...)
|
||||
app.POST("/v1/edits", editHandler, editMiddleware...)
|
||||
app.POST("/edits", editHandler, editMiddleware...)
|
||||
|
||||
// completion
|
||||
completionChain := []fiber.Handler{
|
||||
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
completionMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/v1/completions", completionChain...)
|
||||
app.Post("/completions", completionChain...)
|
||||
app.Post("/v1/engines/:model/completions", completionChain...)
|
||||
app.POST("/v1/completions", completionHandler, completionMiddleware...)
|
||||
app.POST("/completions", completionHandler, completionMiddleware...)
|
||||
app.POST("/v1/engines/:model/completions", completionHandler, completionMiddleware...)
|
||||
|
||||
// MCPcompletion
|
||||
mcpCompletionChain := []fiber.Handler{
|
||||
mcpCompletionHandler := openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
mcpCompletionMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/mcp/v1/chat/completions", mcpCompletionChain...)
|
||||
app.Post("/mcp/chat/completions", mcpCompletionChain...)
|
||||
app.POST("/mcp/v1/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...)
|
||||
app.POST("/mcp/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...)
|
||||
|
||||
// embeddings
|
||||
embeddingChain := []fiber.Handler{
|
||||
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
embeddingMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
app.Post("/v1/embeddings", embeddingChain...)
|
||||
app.Post("/embeddings", embeddingChain...)
|
||||
app.Post("/v1/engines/:model/embeddings", embeddingChain...)
|
||||
app.POST("/v1/embeddings", embeddingHandler, embeddingMiddleware...)
|
||||
app.POST("/embeddings", embeddingHandler, embeddingMiddleware...)
|
||||
app.POST("/v1/engines/:model/embeddings", embeddingHandler, embeddingMiddleware...)
|
||||
|
||||
audioChain := []fiber.Handler{
|
||||
audioHandler := openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
audioMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", audioChain...)
|
||||
app.Post("/audio/transcriptions", audioChain...)
|
||||
app.POST("/v1/audio/transcriptions", audioHandler, audioMiddleware...)
|
||||
app.POST("/audio/transcriptions", audioHandler, audioMiddleware...)
|
||||
|
||||
audioSpeechChain := []fiber.Handler{
|
||||
audioSpeechHandler := localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
audioSpeechMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
|
||||
localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
}
|
||||
|
||||
app.Post("/v1/audio/speech",
|
||||
audioSpeechChain...)
|
||||
app.Post("/audio/speech", audioSpeechChain...)
|
||||
app.POST("/v1/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
|
||||
app.POST("/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
|
||||
|
||||
// images
|
||||
imageChain := []fiber.Handler{
|
||||
imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
imageMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
app.Post("/v1/images/generations",
|
||||
imageChain...)
|
||||
app.Post("/images/generations", imageChain...)
|
||||
app.POST("/v1/images/generations", imageHandler, imageMiddleware...)
|
||||
app.POST("/images/generations", imageHandler, imageMiddleware...)
|
||||
|
||||
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
|
||||
videoChain := []fiber.Handler{
|
||||
videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
videoMiddleware := []echo.MiddlewareFunc{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := re.SetOpenAIRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// OpenAI-style create video endpoint
|
||||
app.Post("/v1/videos", videoChain...)
|
||||
app.Post("/v1/videos/generations", videoChain...)
|
||||
app.Post("/videos", videoChain...)
|
||||
app.POST("/v1/videos", videoHandler, videoMiddleware...)
|
||||
app.POST("/v1/videos/generations", videoHandler, videoMiddleware...)
|
||||
app.POST("/videos", videoHandler, videoMiddleware...)
|
||||
|
||||
// List models
|
||||
app.Get("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
}
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func RegisterUIRoutes(app *fiber.App,
|
||||
func RegisterUIRoutes(app *echo.Echo,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
@@ -21,13 +20,13 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
// keeps the state of ops that are started from the UI
|
||||
var processingOps = services.NewOpCache(galleryService)
|
||||
|
||||
app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
|
||||
app.GET("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
|
||||
|
||||
// P2P
|
||||
app.Get("/p2p", func(c *fiber.Ctx) error {
|
||||
summary := fiber.Map{
|
||||
app.GET("/p2p/", func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - P2P dashboard",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
//"Nodes": p2p.GetAvailableNodes(""),
|
||||
//"FederatedNodes": p2p.GetAvailableNodes(p2p.FederatedID),
|
||||
@@ -37,7 +36,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/p2p", summary)
|
||||
return c.Render(200, "views/p2p", summary)
|
||||
})
|
||||
|
||||
// Note: P2P UI fragment routes (/p2p/ui/*) were removed
|
||||
@@ -50,17 +49,17 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
registerBackendGalleryRoutes(app, appConfig, galleryService, processingOps)
|
||||
}
|
||||
|
||||
app.Get("/talk/", func(c *fiber.Ctx) error {
|
||||
app.GET("/talk/", func(c echo.Context) error {
|
||||
modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||
|
||||
if len(modelConfigs) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
return c.Redirect(302, middleware.BaseURL(c))
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Talk",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"Model": modelConfigs[0],
|
||||
|
||||
@@ -68,16 +67,16 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/talk", summary)
|
||||
return c.Render(200, "views/talk", summary)
|
||||
})
|
||||
|
||||
app.Get("/chat/", func(c *fiber.Ctx) error {
|
||||
app.GET("/chat/", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
return c.Redirect(302, middleware.BaseURL(c))
|
||||
}
|
||||
modelThatCanBeUsed := ""
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
@@ -104,9 +103,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
@@ -116,16 +115,16 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/chat", summary)
|
||||
return c.Render(200, "views/chat", summary)
|
||||
})
|
||||
|
||||
// Show the Chat page
|
||||
app.Get("/chat/:model", func(c *fiber.Ctx) error {
|
||||
app.GET("/chat/:model", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
modelName := c.Params("model")
|
||||
modelName := c.Param("model")
|
||||
var modelContextSize *int
|
||||
|
||||
for _, m := range modelConfigs {
|
||||
@@ -139,9 +138,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Chat with " + modelName,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
@@ -151,33 +150,33 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/chat", summary)
|
||||
return c.Render(200, "views/chat", summary)
|
||||
})
|
||||
|
||||
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
|
||||
app.GET("/text2image/:model", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Generate images with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Generate images with " + c.Param("model"),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
"Model": c.Param("model"),
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/text2image", summary)
|
||||
return c.Render(200, "views/text2image", summary)
|
||||
})
|
||||
|
||||
app.Get("/text2image/", func(c *fiber.Ctx) error {
|
||||
app.GET("/text2image/", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
return c.Redirect(302, middleware.BaseURL(c))
|
||||
}
|
||||
|
||||
modelThatCanBeUsed := ""
|
||||
@@ -191,9 +190,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": modelThatCanBeUsed,
|
||||
@@ -201,33 +200,33 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/text2image", summary)
|
||||
return c.Render(200, "views/text2image", summary)
|
||||
})
|
||||
|
||||
app.Get("/tts/:model", func(c *fiber.Ctx) error {
|
||||
app.GET("/tts/:model", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Generate images with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Generate images with " + c.Param("model"),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
"Model": c.Param("model"),
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/tts", summary)
|
||||
return c.Render(200, "views/tts", summary)
|
||||
})
|
||||
|
||||
app.Get("/tts/", func(c *fiber.Ctx) error {
|
||||
app.GET("/tts/", func(c echo.Context) error {
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
return c.Redirect(302, middleware.BaseURL(c))
|
||||
}
|
||||
|
||||
modelThatCanBeUsed := ""
|
||||
@@ -240,9 +239,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
break
|
||||
}
|
||||
}
|
||||
summary := fiber.Map{
|
||||
summary := map[string]interface{}{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": modelThatCanBeUsed,
|
||||
@@ -250,6 +249,6 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
// Render index
|
||||
return c.Render("views/tts", summary)
|
||||
return c.Render(200, "views/tts", summary)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
@@ -19,13 +20,13 @@ import (
|
||||
)
|
||||
|
||||
// RegisterUIAPIRoutes registers JSON API routes for the web UI
|
||||
func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
|
||||
// Operations API - Get all current operations (models + backends)
|
||||
app.Get("/api/operations", func(c *fiber.Ctx) error {
|
||||
app.GET("/api/operations", func(c echo.Context) error {
|
||||
processingData, taskTypes := opcache.GetStatus()
|
||||
|
||||
operations := []fiber.Map{}
|
||||
operations := []map[string]interface{}{}
|
||||
for galleryID, jobID := range processingData {
|
||||
taskType := "installation"
|
||||
if tt, ok := taskTypes[galleryID]; ok {
|
||||
@@ -88,7 +89,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
}
|
||||
|
||||
operations = append(operations, fiber.Map{
|
||||
operations = append(operations, map[string]interface{}{
|
||||
"id": galleryID,
|
||||
"name": displayName,
|
||||
"fullName": galleryID,
|
||||
@@ -118,20 +119,20 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
return operations[i]["id"].(string) < operations[j]["id"].(string)
|
||||
})
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"operations": operations,
|
||||
})
|
||||
})
|
||||
|
||||
// Cancel operation endpoint
|
||||
app.Post("/api/operations/:jobID/cancel", func(c *fiber.Ctx) error {
|
||||
jobID := strings.Clone(c.Params("jobID"))
|
||||
app.POST("/api/operations/:jobID/cancel", func(c echo.Context) error {
|
||||
jobID := c.Param("jobID")
|
||||
log.Debug().Msgf("API request to cancel operation: %s", jobID)
|
||||
|
||||
err := galleryService.CancelOperation(jobID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to cancel operation: %s", jobID)
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -139,22 +140,28 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
// Clean up opcache for cancelled operation
|
||||
opcache.DeleteUUID(jobID)
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Operation cancelled",
|
||||
})
|
||||
})
|
||||
|
||||
// Model Gallery APIs
|
||||
app.Get("/api/models", func(c *fiber.Ctx) error {
|
||||
term := c.Query("term")
|
||||
page := c.Query("page", "1")
|
||||
items := c.Query("items", "21")
|
||||
app.GET("/api/models", func(c echo.Context) error {
|
||||
term := c.QueryParam("term")
|
||||
page := c.QueryParam("page")
|
||||
if page == "" {
|
||||
page = "1"
|
||||
}
|
||||
items := c.QueryParam("items")
|
||||
if items == "" {
|
||||
items = "21"
|
||||
}
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -197,7 +204,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
|
||||
// Convert models to JSON-friendly format and deduplicate by ID
|
||||
modelsJSON := make([]fiber.Map, 0, len(models))
|
||||
modelsJSON := make([]map[string]interface{}, 0, len(models))
|
||||
seenIDs := make(map[string]bool)
|
||||
|
||||
for _, m := range models {
|
||||
@@ -223,7 +230,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
_, trustRemoteCodeExists := m.Overrides["trust_remote_code"]
|
||||
|
||||
modelsJSON = append(modelsJSON, fiber.Map{
|
||||
modelsJSON = append(modelsJSON, map[string]interface{}{
|
||||
"id": modelID,
|
||||
"name": m.Name,
|
||||
"description": m.Description,
|
||||
@@ -250,7 +257,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
nextPage = totalPages
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"models": modelsJSON,
|
||||
"repositories": appConfig.Galleries,
|
||||
"allTags": tags,
|
||||
@@ -264,12 +271,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/models/install/:id", func(c *fiber.Ctx) error {
|
||||
galleryID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/models/install/:id", func(c echo.Context) error {
|
||||
galleryID := c.Param("id")
|
||||
// URL decode the gallery ID (e.g., "localai%40model" -> "localai@model")
|
||||
galleryID, err := url.QueryUnescape(galleryID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid model ID",
|
||||
})
|
||||
}
|
||||
@@ -277,7 +284,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -300,18 +307,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
galleryService.ModelGalleryChannel <- op
|
||||
}()
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"jobID": uid,
|
||||
"message": "Installation started",
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/models/delete/:id", func(c *fiber.Ctx) error {
|
||||
galleryID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/models/delete/:id", func(c echo.Context) error {
|
||||
galleryID := c.Param("id")
|
||||
// URL decode the gallery ID
|
||||
galleryID, err := url.QueryUnescape(galleryID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid model ID",
|
||||
})
|
||||
}
|
||||
@@ -324,7 +331,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -350,18 +357,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
cl.RemoveModelConfig(galleryName)
|
||||
}()
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"jobID": uid,
|
||||
"message": "Deletion started",
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/models/config/:id", func(c *fiber.Ctx) error {
|
||||
galleryID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/models/config/:id", func(c echo.Context) error {
|
||||
galleryID := c.Param("id")
|
||||
// URL decode the gallery ID
|
||||
galleryID, err := url.QueryUnescape(galleryID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid model ID",
|
||||
})
|
||||
}
|
||||
@@ -369,44 +376,44 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
model := gallery.FindGalleryElement(models, galleryID)
|
||||
if model == nil {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{
|
||||
"error": "model not found",
|
||||
})
|
||||
}
|
||||
|
||||
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](model.URL, appConfig.SystemState.Model.ModelsPath)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
_, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"message": "Configuration file saved",
|
||||
})
|
||||
})
|
||||
|
||||
app.Get("/api/models/job/:uid", func(c *fiber.Ctx) error {
|
||||
jobUID := strings.Clone(c.Params("uid"))
|
||||
app.GET("/api/models/job/:uid", func(c echo.Context) error {
|
||||
jobUID := c.Param("uid")
|
||||
|
||||
status := galleryService.GetStatus(jobUID)
|
||||
if status == nil {
|
||||
// Job is queued but hasn't started processing yet
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"progress": 0,
|
||||
"message": "Operation queued",
|
||||
"galleryElementName": "",
|
||||
@@ -416,7 +423,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
response := fiber.Map{
|
||||
response := map[string]interface{}{
|
||||
"progress": status.Progress,
|
||||
"message": status.Message,
|
||||
"galleryElementName": status.GalleryElementName,
|
||||
@@ -434,19 +441,25 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
response["completed"] = true
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
})
|
||||
|
||||
// Backend Gallery APIs
|
||||
app.Get("/api/backends", func(c *fiber.Ctx) error {
|
||||
term := c.Query("term")
|
||||
page := c.Query("page", "1")
|
||||
items := c.Query("items", "21")
|
||||
app.GET("/api/backends", func(c echo.Context) error {
|
||||
term := c.QueryParam("term")
|
||||
page := c.QueryParam("page")
|
||||
if page == "" {
|
||||
page = "1"
|
||||
}
|
||||
items := c.QueryParam("items")
|
||||
if items == "" {
|
||||
items = "21"
|
||||
}
|
||||
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list backends from galleries")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -489,7 +502,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
|
||||
// Convert backends to JSON-friendly format and deduplicate by ID
|
||||
backendsJSON := make([]fiber.Map, 0, len(backends))
|
||||
backendsJSON := make([]map[string]interface{}, 0, len(backends))
|
||||
seenBackendIDs := make(map[string]bool)
|
||||
|
||||
for _, b := range backends {
|
||||
@@ -513,7 +526,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
}
|
||||
|
||||
backendsJSON = append(backendsJSON, fiber.Map{
|
||||
backendsJSON = append(backendsJSON, map[string]interface{}{
|
||||
"id": backendID,
|
||||
"name": b.Name,
|
||||
"description": b.Description,
|
||||
@@ -538,7 +551,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
nextPage = totalPages
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"backends": backendsJSON,
|
||||
"repositories": appConfig.BackendGalleries,
|
||||
"allTags": tags,
|
||||
@@ -552,12 +565,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/backends/install/:id", func(c *fiber.Ctx) error {
|
||||
backendID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/backends/install/:id", func(c echo.Context) error {
|
||||
backendID := c.Param("id")
|
||||
// URL decode the backend ID
|
||||
backendID, err := url.QueryUnescape(backendID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid backend ID",
|
||||
})
|
||||
}
|
||||
@@ -565,7 +578,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -587,18 +600,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
galleryService.BackendGalleryChannel <- op
|
||||
}()
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"jobID": uid,
|
||||
"message": "Backend installation started",
|
||||
})
|
||||
})
|
||||
|
||||
app.Post("/api/backends/delete/:id", func(c *fiber.Ctx) error {
|
||||
backendID := strings.Clone(c.Params("id"))
|
||||
app.POST("/api/backends/delete/:id", func(c echo.Context) error {
|
||||
backendID := c.Param("id")
|
||||
// URL decode the backend ID
|
||||
backendID, err := url.QueryUnescape(backendID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid backend ID",
|
||||
})
|
||||
}
|
||||
@@ -611,7 +624,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -635,19 +648,19 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
galleryService.BackendGalleryChannel <- op
|
||||
}()
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"jobID": uid,
|
||||
"message": "Backend deletion started",
|
||||
})
|
||||
})
|
||||
|
||||
app.Get("/api/backends/job/:uid", func(c *fiber.Ctx) error {
|
||||
jobUID := strings.Clone(c.Params("uid"))
|
||||
app.GET("/api/backends/job/:uid", func(c echo.Context) error {
|
||||
jobUID := c.Param("uid")
|
||||
|
||||
status := galleryService.GetStatus(jobUID)
|
||||
if status == nil {
|
||||
// Job is queued but hasn't started processing yet
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"progress": 0,
|
||||
"message": "Operation queued",
|
||||
"galleryElementName": "",
|
||||
@@ -657,7 +670,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
response := fiber.Map{
|
||||
response := map[string]interface{}{
|
||||
"progress": status.Progress,
|
||||
"message": status.Message,
|
||||
"galleryElementName": status.GalleryElementName,
|
||||
@@ -675,16 +688,16 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
response["completed"] = true
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
})
|
||||
|
||||
// System Backend Deletion API (for installed backends on index page)
|
||||
app.Post("/api/backends/system/delete/:name", func(c *fiber.Ctx) error {
|
||||
backendName := strings.Clone(c.Params("name"))
|
||||
app.POST("/api/backends/system/delete/:name", func(c echo.Context) error {
|
||||
backendName := c.Param("name")
|
||||
// URL decode the backend name
|
||||
backendName, err := url.QueryUnescape(backendName)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||
"error": "invalid backend name",
|
||||
})
|
||||
}
|
||||
@@ -693,24 +706,24 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
// Use the gallery package to delete the backend
|
||||
if err := gallery.DeleteBackendFromSystem(appConfig.SystemState, backendName); err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to delete backend: %s", backendName)
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Backend deleted successfully",
|
||||
})
|
||||
})
|
||||
|
||||
// P2P APIs
|
||||
app.Get("/api/p2p/workers", func(c *fiber.Ctx) error {
|
||||
app.GET("/api/p2p/workers", func(c echo.Context) error {
|
||||
nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
|
||||
|
||||
nodesJSON := make([]fiber.Map, 0, len(nodes))
|
||||
nodesJSON := make([]map[string]interface{}, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
nodesJSON = append(nodesJSON, fiber.Map{
|
||||
nodesJSON = append(nodesJSON, map[string]interface{}{
|
||||
"name": n.Name,
|
||||
"id": n.ID,
|
||||
"tunnelAddress": n.TunnelAddress,
|
||||
@@ -720,17 +733,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"nodes": nodesJSON,
|
||||
})
|
||||
})
|
||||
|
||||
app.Get("/api/p2p/federation", func(c *fiber.Ctx) error {
|
||||
app.GET("/api/p2p/federation", func(c echo.Context) error {
|
||||
nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
|
||||
|
||||
nodesJSON := make([]fiber.Map, 0, len(nodes))
|
||||
nodesJSON := make([]map[string]interface{}, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
nodesJSON = append(nodesJSON, fiber.Map{
|
||||
nodesJSON = append(nodesJSON, map[string]interface{}{
|
||||
"name": n.Name,
|
||||
"id": n.ID,
|
||||
"tunnelAddress": n.TunnelAddress,
|
||||
@@ -740,12 +753,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"nodes": nodesJSON,
|
||||
})
|
||||
})
|
||||
|
||||
app.Get("/api/p2p/stats", func(c *fiber.Ctx) error {
|
||||
app.GET("/api/p2p/stats", func(c echo.Context) error {
|
||||
workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
|
||||
federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
|
||||
|
||||
@@ -763,12 +776,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"workers": fiber.Map{
|
||||
return c.JSON(200, map[string]interface{}{
|
||||
"workers": map[string]interface{}{
|
||||
"online": workersOnline,
|
||||
"total": len(workerNodes),
|
||||
},
|
||||
"federated": fiber.Map{
|
||||
"federated": map[string]interface{}{
|
||||
"online": federatedOnline,
|
||||
"total": len(federatedNodes),
|
||||
},
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
)
|
||||
|
||||
func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
func registerBackendGalleryRoutes(app *echo.Echo, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
// Show the Backends page (all backends are loaded client-side via Alpine.js)
|
||||
app.Get("/browse/backends", func(c *fiber.Ctx) error {
|
||||
summary := fiber.Map{
|
||||
app.GET("/browse/backends", func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Backends",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"Repositories": appConfig.BackendGalleries,
|
||||
}
|
||||
|
||||
// Render index - backends are now loaded via Alpine.js from /api/backends
|
||||
return c.Render("views/backends", summary)
|
||||
return c.Render(200, "views/backends", summary)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
)
|
||||
|
||||
func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
func registerGalleryRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
|
||||
app.Get("/browse", func(c *fiber.Ctx) error {
|
||||
summary := fiber.Map{
|
||||
app.GET("/browse/", func(c echo.Context) error {
|
||||
summary := map[string]interface{}{
|
||||
"Title": "LocalAI - Models",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"BaseURL": middleware.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
"Repositories": appConfig.Galleries,
|
||||
}
|
||||
|
||||
// Render index - models are now loaded via Alpine.js from /api/models
|
||||
return c.Render("views/models", summary)
|
||||
return c.Render(200, "views/models", summary)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// BaseURL returns the base URL for the given HTTP request context.
|
||||
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
|
||||
// The returned URL is guaranteed to end with `/`.
|
||||
// The method should be used in conjunction with the StripPathPrefix middleware.
|
||||
func BaseURL(c *fiber.Ctx) string {
|
||||
path := c.Path()
|
||||
origPath := c.OriginalURL()
|
||||
|
||||
if path != origPath && strings.HasSuffix(origPath, path) {
|
||||
pathPrefix := origPath[:len(origPath)-len(path)+1]
|
||||
|
||||
return c.BaseURL() + pathPrefix
|
||||
}
|
||||
|
||||
return c.BaseURL() + "/"
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBaseURL(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
prefix string
|
||||
expectURL string
|
||||
}{
|
||||
{
|
||||
name: "without prefix",
|
||||
prefix: "/",
|
||||
expectURL: "http://example.com/",
|
||||
},
|
||||
{
|
||||
name: "with prefix",
|
||||
prefix: "/myprefix/",
|
||||
expectURL: "http://example.com/myprefix/",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
actualURL := ""
|
||||
|
||||
app.Get(tc.prefix+"hello/world", func(c *fiber.Ctx) error {
|
||||
if tc.prefix != "/" {
|
||||
c.Path("/hello/world")
|
||||
}
|
||||
actualURL = BaseURL(c)
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", tc.prefix+"hello/world", nil)
|
||||
resp, err := app.Test(req, -1)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode, "response status code")
|
||||
require.Equal(t, tc.expectURL, actualURL, "base URL")
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user