feat: migrate to echo and enable cancellation of non-streaming requests (#7270)

* WIP: migrate to echo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-11-14 22:57:53 +01:00
committed by GitHub
parent 03e9f4b140
commit 1cdcaf0152
59 changed files with 2350 additions and 2011 deletions

View File

@@ -48,10 +48,12 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
appHTTP := http.Explorer(db)
signals.RegisterGracefulTerminationHandler(func() {
if err := appHTTP.Shutdown(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := appHTTP.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("error during shutdown")
}
})
return appHTTP.Listen(e.Address)
return appHTTP.Start(e.Address)
}

View File

@@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
}
})
return appHTTP.Listen(r.Address)
return appHTTP.Start(r.Address)
}

View File

@@ -4,30 +4,23 @@ import (
"embed"
"errors"
"fmt"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/websocket/v2"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/csrf"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
// swagger handler
"github.com/rs/zerolog/log"
)
@@ -49,85 +42,85 @@ var embedDirStatic embed.FS
// @in header
// @name Authorization
func API(application *application.Application) (*fiber.App, error) {
func API(application *application.Application) (*echo.Echo, error) {
e := echo.New()
fiberCfg := fiber.Config{
Views: renderEngine(),
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: true,
// Override default error handler
// Set body limit
if application.ApplicationConfig().UploadLimitMB > 0 {
e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
}
// Set error handler
if !application.ApplicationConfig().OpaqueErrors {
// Normally, return errors as JSON responses
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
e.HTTPErrorHandler = func(err error, c echo.Context) {
code := http.StatusInternalServerError
var he *echo.HTTPError
if errors.As(err, &he) {
code = he.Code
}
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
// Handle 404 errors with HTML rendering when appropriate
if code == http.StatusNotFound {
notFoundHandler(c)
return
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
c.JSON(code, schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
})
}
} else {
// If OpaqueErrors are required, replace everything with a blank 500.
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
return ctx.Status(500).SendString("")
e.HTTPErrorHandler = func(err error, c echo.Context) {
code := http.StatusInternalServerError
var he *echo.HTTPError
if errors.As(err, &he) {
code = he.Code
}
c.NoContent(code)
}
}
router := fiber.New(fiberCfg)
// Set renderer
e.Renderer = renderEngine()
router.Use(middleware.StripPathPrefix())
// Hide banner
e.HideBanner = true
// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
e.Pre(httpMiddleware.StripPathPrefix())
if application.ApplicationConfig().MachineTag != "" {
router.Use(func(c *fiber.Ctx) error {
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
return c.Next()
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
return next(c)
}
})
}
router.Use("/v1/realtime", func(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
// Returns true if the client requested upgrade to the WebSocket protocol
return c.Next()
// Custom logger middleware using zerolog
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
res := c.Response()
start := log.Logger.Info()
err := next(c)
start.
Str("method", req.Method).
Str("path", req.URL.Path).
Int("status", res.Status).
Msg("HTTP request")
return err
}
return nil
})
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
scheme = "https"
}
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
return nil
})
// Have Fiber use zerolog like the rest of the application rather than its built-in logger
logger := log.Logger
router.Use(fiberzerolog.New(fiberzerolog.Config{
Logger: &logger,
}))
// Default middleware config
// Recover middleware
if !application.ApplicationConfig().Debug {
router.Use(recover.New())
e.Use(middleware.Recover())
}
// Metrics middleware
if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
@@ -135,34 +128,40 @@ func API(application *application.Application) (*fiber.App, error) {
}
if metricsService != nil {
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
router.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
e.Server.RegisterOnShutdown(func() {
metricsService.Shutdown()
})
}
}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(router)
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil || kaConfig == nil {
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(e)
// Get key auth middleware
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
}
httpFS := http.FS(embedDirStatic)
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
}
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})
router.Use(favicon.New(favicon.Config{
URL: "/favicon.svg",
FileSystem: httpFS,
File: "static/favicon.svg",
}))
router.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Static files - use fs.Sub to create a filesystem rooted at "static"
staticFS, err := fs.Sub(embedDirStatic, "static")
if err != nil {
return nil, fmt.Errorf("failed to create static filesystem: %w", err)
}
e.StaticFS("/static", staticFS)
// Generated content directories
if application.ApplicationConfig().GeneratedContentDir != "" {
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
@@ -173,51 +172,53 @@ func API(application *application.Application) (*fiber.App, error) {
os.MkdirAll(imagePath, 0750)
os.MkdirAll(videoPath, 0750)
router.Static("/generated-audio", audioPath)
router.Static("/generated-images", imagePath)
router.Static("/generated-videos", videoPath)
e.Static("/generated-audio", audioPath)
e.Static("/generated-images", imagePath)
e.Static("/generated-videos", videoPath)
}
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
router.Use(v2keyauth.New(*kaConfig))
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
e.Use(keyAuthMiddleware)
// CORS middleware
if application.ApplicationConfig().CORS {
var c func(ctx *fiber.Ctx) error
if application.ApplicationConfig().CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
corsConfig := middleware.CORSConfig{}
if application.ApplicationConfig().CORSAllowOrigins != "" {
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
}
router.Use(c)
e.Use(middleware.CORSWithConfig(corsConfig))
}
// CSRF middleware
if application.ApplicationConfig().CSRF {
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
router.Use(csrf.New())
e.Use(middleware.CSRF())
}
requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
var opcache *services.OpCache
if !application.ApplicationConfig().DisableWebUI {
opcache = services.NewOpCache(application.GalleryService())
}
routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
}
routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Define a custom 404 handler
// Note: keep this at the bottom!
router.Use(notFoundHandler)
// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route
return router, nil
// Log a message when the server shuts down
e.Server.RegisterOnShutdown(func() {
log.Info().Msg("LocalAI API server shutting down")
})
return e, nil
}

View File

@@ -10,13 +10,14 @@ import (
"os"
"path/filepath"
"runtime"
"time"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system"
@@ -25,6 +26,7 @@ import (
"gopkg.in/yaml.v3"
openaigo "github.com/otiai10/openaigo"
"github.com/rs/zerolog/log"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)
@@ -266,7 +268,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8
var _ = Describe("API test", func() {
var app *fiber.App
var app *echo.Echo
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
@@ -339,7 +341,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -358,7 +364,9 @@ var _ = Describe("API test", func() {
AfterEach(func(sc SpecContext) {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
@@ -547,7 +555,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -566,7 +578,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
@@ -755,7 +769,11 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred())
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -773,7 +791,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
})
@@ -1006,7 +1026,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -1022,7 +1046,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
})

View File

@@ -1,7 +1,9 @@
package elevenlabs
import (
"github.com/gofiber/fiber/v2"
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -15,17 +17,17 @@ import (
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
@@ -35,7 +37,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
if err != nil {
return err
}
return c.Download(filePath)
return c.Attachment(filePath, filepath.Base(filePath))
}
}

View File

@@ -1,13 +1,14 @@
package elevenlabs
import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -17,19 +18,19 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
voiceID := c.Params("voice-id")
voiceID := c.Param("voice-id")
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received")
@@ -38,6 +39,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
if err != nil {
return err
}
return c.Download(filePath)
return c.Attachment(filePath, filepath.Base(filePath))
}
}

View File

@@ -2,28 +2,32 @@ package explorer
import (
"encoding/base64"
"net/http"
"sort"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/internal"
)
func Dashboard() func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
summary := fiber.Map{
func Dashboard() echo.HandlerFunc {
return func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
}
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
return c.JSON(http.StatusOK, summary)
} else {
// Render index
return c.Render("views/explorer", summary)
return c.Render(http.StatusOK, "views/explorer", summary)
}
}
}
@@ -39,8 +43,8 @@ type Network struct {
Token string `json:"token"`
}
func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
results := []Network{}
for _, token := range db.TokenList() {
networkData, exists := db.Get(token) // get the token data
@@ -61,44 +65,44 @@ func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
return len(results[i].Clusters) > len(results[j].Clusters)
})
return c.JSON(results)
return c.JSON(http.StatusOK, results)
}
}
func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func AddNetwork(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
request := new(AddNetworkRequest)
if err := c.BodyParser(request); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
if err := c.Bind(request); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
}
if request.Token == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
}
if request.Name == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
}
if request.Description == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
}
// TODO: check if token is valid, otherwise reject
// try to decode the token from base64
_, err := base64.StdEncoding.DecodeString(request.Token)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
}
if _, exists := db.Get(request.Token); exists {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
}
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
}
return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"})
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
}
}

View File

@@ -1,11 +1,12 @@
package jina
import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
@@ -17,17 +18,17 @@ import (
// @Param request body schema.JINARerankRequest true "query params"
// @Success 200 {object} schema.JINARerankResponse "Response"
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
@@ -58,6 +59,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
response.Usage.TotalTokens = int(results.Usage.TotalTokens)
response.Usage.PromptTokens = int(results.Usage.PromptTokens)
return c.Status(fiber.StatusOK).JSON(response)
return c.JSON(http.StatusOK, response)
}
}

View File

@@ -4,11 +4,11 @@ import (
"encoding/json"
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
@@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /backends/jobs/{uuid} [get]
func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.backendApplier.GetStatus(c.Params("uuid"))
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
status := mgs.backendApplier.GetStatus(c.Param("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
return c.JSON(200, status)
}
}
@@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /backends/jobs [get]
func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.backendApplier.GetAllStatus())
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
return c.JSON(200, mgs.backendApplier.GetAllStatus())
}
}
@@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err
// @Param request body GalleryBackend true "query params"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/apply [post]
func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
input := new(GalleryBackend)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -82,7 +82,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
Galleries: mgs.galleries,
}
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}
@@ -91,9 +91,9 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
// @Param name path string true "Backend name"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/delete/{name} [post]
func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendName := c.Params("name")
func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
backendName := c.Param("name")
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
Delete: true,
@@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
return err
}
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}
@@ -114,13 +114,13 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
// @Summary List all Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends [get]
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
return err
}
return c.JSON(backends.GetAll())
return c.JSON(200, backends.GetAll())
}
}
@@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.Syst
// @Success 200 {object} []config.Gallery "Response"
// @Router /backends/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
return c.Blob(200, "application/json", dat)
}
}
@@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
// @Summary List all available Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends/available [get]
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
if err != nil {
return err
}
return c.JSON(backends)
return c.JSON(200, backends)
}
}

View File

@@ -1,45 +1,45 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		// Decode the request payload; a parse failure is handed back to the
		// framework unchanged.
		var req schema.BackendMonitorRequest
		if parseErr := c.BodyParser(&req); parseErr != nil {
			return parseErr
		}

		// Ask the monitor service for a sample of the requested model.
		status, sampleErr := bm.CheckAndSample(req.Model)
		if sampleErr != nil {
			return sampleErr
		}

		return c.JSON(status)
	}
}
// BackendShutdownEndpoint builds the handler that shuts down the backend of
// the model named in the request body.
// NOTE(review): the @Summary below repeats the monitor endpoint's text; confirm and update.
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		var req schema.BackendMonitorRequest

		// Decode the body to learn which model should be stopped.
		if err := c.BodyParser(&req); err != nil {
			return err
		}

		// The service error (if any) is handed straight back to the framework.
		return bm.ShutdownModel(req.Model)
	}
}
package localai
import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// BackendMonitorEndpoint builds the handler reporting the status of a backend.
// The model name is bound from the request and the value returned by the
// monitor service is written back as JSON with status 200.
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
	return func(c echo.Context) error {
		var req schema.BackendMonitorRequest

		// Bind the request to learn which model is being asked about.
		if err := c.Bind(&req); err != nil {
			return err
		}

		sample, err := bm.CheckAndSample(req.Model)
		if err != nil {
			return err
		}
		return c.JSON(200, sample)
	}
}
// BackendShutdownEndpoint builds the handler that shuts down the backend of
// the model named in the request.
// NOTE(review): the @Summary below repeats the monitor endpoint's text; confirm and update.
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
	return func(c echo.Context) error {
		var req schema.BackendMonitorRequest

		// Bind the request to learn which model should be stopped.
		if err := c.Bind(&req); err != nil {
			return err
		}

		// The service error (if any) is handed straight back to the framework.
		return bm.ShutdownModel(req.Model)
	}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -16,17 +16,17 @@ import (
// @Param request body schema.DetectionRequest true "query params"
// @Success 200 {object} schema.DetectionResponse "Response"
// @Router /v1/detection [post]
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
@@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
}
}
return c.JSON(response)
return c.JSON(200, response)
}
}

View File

@@ -2,11 +2,13 @@ package localai
import (
"fmt"
"io"
"net/http"
"os"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/utils"
@@ -14,15 +16,15 @@ import (
)
// GetEditModelPage renders the edit model page with current configuration
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
if modelName == "" {
response := ModelResponse{
Success: false,
Error: "Model name is required",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
modelConfig, exists := cl.GetModelConfig(modelName)
@@ -31,7 +33,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Model configuration not found",
}
return c.Status(404).JSON(response)
return c.JSON(http.StatusNotFound, response)
}
modelConfigFile := modelConfig.GetModelConfigFile()
@@ -40,7 +42,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Model configuration file not found",
}
return c.Status(404).JSON(response)
return c.JSON(http.StatusNotFound, response)
}
configData, err := os.ReadFile(modelConfigFile)
if err != nil {
@@ -48,7 +50,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Failed to read configuration file: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Render the edit page with the current configuration
@@ -69,20 +71,20 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Version: internal.PrintableVersion(),
}
return c.Render("views/model-editor", templateData)
return c.Render(http.StatusOK, "views/model-editor", templateData)
}
}
// EditModelEndpoint handles updating existing model configurations
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
if modelName == "" {
response := ModelResponse{
Success: false,
Error: "Model name is required",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
modelConfig, exists := cl.GetModelConfig(modelName)
@@ -91,17 +93,24 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Existing model configuration not found",
}
return c.Status(404).JSON(response)
return c.JSON(http.StatusNotFound, response)
}
// Get the raw body
body := c.Body()
body, err := io.ReadAll(c.Request().Body)
if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to read request body: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
}
if len(body) == 0 {
response := ModelResponse{
Success: false,
Error: "Request body is empty",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Check content to see if it's a valid model config
@@ -113,7 +122,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Validate required fields
@@ -122,7 +131,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Name is required",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Validate the configuration
@@ -132,7 +141,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Error: "Validation failed",
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Load the existing configuration
@@ -142,7 +151,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Model configuration not trusted: " + err.Error(),
}
return c.Status(404).JSON(response)
return c.JSON(http.StatusNotFound, response)
}
// Write new content to file
@@ -151,7 +160,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to write configuration file: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Reload configurations
@@ -160,7 +169,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Preload the model
@@ -169,7 +178,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to preload model: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Return success response
@@ -179,20 +188,20 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Filename: configPath,
Config: req,
}
return c.JSON(response)
return c.JSON(200, response)
}
}
// ReloadModelsEndpoint handles reloading model configurations from disk
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Preload the models
@@ -201,7 +210,7 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
Success: false,
Error: "Failed to preload models: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Return success response
@@ -209,6 +218,6 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
Success: true,
Message: "Model configurations reloaded successfully",
}
return c.Status(fiber.StatusOK).JSON(response)
return c.JSON(http.StatusOK, response)
}
}

View File

@@ -2,12 +2,14 @@ package localai_test
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/pkg/system"
@@ -15,6 +17,14 @@ import (
. "github.com/onsi/gomega"
)
// testRenderer is a minimal renderer for tests: instead of executing a
// template it serialises the template data as JSON.
type testRenderer struct{}

// Render ignores the template name and the context and JSON-encodes data into w.
func (*testRenderer) Render(w io.Writer, _ string, data interface{}, _ echo.Context) error {
	encoder := json.NewEncoder(w)
	return encoder.Encode(data)
}
var _ = Describe("Edit Model test", func() {
var tempDir string
@@ -40,33 +50,35 @@ var _ = Describe("Edit Model test", func() {
//modelLoader := model.NewModelLoader(systemState, true)
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
// Define Fiber app.
app := fiber.New()
app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
// Define Echo app and register all routes upfront
app := echo.New()
// Set up a simple renderer for the test
app.Renderer = &testRenderer{}
app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig))
requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
req := httptest.NewRequest("PUT", "/import-model", requestBody)
resp, err := app.Test(req, 5000)
Expect(err).ToNot(HaveOccurred())
req := httptest.NewRequest("POST", "/import-model", requestBody)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
body, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
body, err := io.ReadAll(rec.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
Expect(rec.Code).To(Equal(http.StatusOK))
app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
req = httptest.NewRequest("GET", "/edit-model/foo", nil)
rec = httptest.NewRecorder()
app.ServeHTTP(rec, req)
req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
resp, _ = app.Test(req, 1)
body, err = io.ReadAll(resp.Body)
defer resp.Body.Close()
body, err = io.ReadAll(rec.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
// The response contains the model configuration with backend field
Expect(string(body)).To(ContainSubstring(`"backend":"foo"`))
Expect(string(body)).To(ContainSubstring(`"name":"foo"`))
Expect(rec.Code).To(Equal(http.StatusOK))
})
})
})

View File

@@ -1,160 +1,160 @@
package localai
import (
"encoding/json"
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
// ModelGalleryEndpointService bundles the state shared by the model gallery
// HTTP handlers defined on it.
type ModelGalleryEndpointService struct {
	// galleries is the list of model galleries queried and forwarded by the handlers.
	galleries []config.Gallery
	// backendGalleries is forwarded with install operations as BackendGalleries.
	backendGalleries []config.Gallery
	// modelPath is filled from systemState.Model.ModelsPath by the constructor.
	modelPath string
	// galleryApplier receives queued gallery operations and exposes their status.
	galleryApplier *services.GalleryService
}
// GalleryModel is the request payload of the apply endpoint: an "id" plus the
// embedded gallery.GalleryModel fields.
type GalleryModel struct {
	// ID is forwarded as the gallery element name of the install operation.
	ID string `json:"id"`
	gallery.GalleryModel
}
// CreateModelGalleryEndpointService assembles a ModelGalleryEndpointService
// from the configured galleries, the system state (source of the models path)
// and the gallery service that executes operations.
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
	var svc ModelGalleryEndpointService
	svc.galleries = galleries
	svc.backendGalleries = backendGalleries
	svc.modelPath = systemState.Model.ModelsPath
	svc.galleryApplier = galleryApplier
	return svc
}
// GetOpStatusEndpoint builds the handler that looks up one job by the "uuid"
// route parameter and writes its status as JSON. An error is returned when the
// gallery service has no status for that ID.
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /models/jobs/{uuid} [get]
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		jobID := c.Params("uuid")
		if status := mgs.galleryApplier.GetStatus(jobID); status != nil {
			return c.JSON(status)
		}
		return fmt.Errorf("could not find any status for ID")
	}
}
// GetAllStatusEndpoint builds the handler that writes the status of every job
// known to the gallery service as JSON.
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /models/jobs [get]
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		allStatuses := mgs.galleryApplier.GetAllStatus()
		return c.JSON(allStatuses)
	}
}
// ApplyModelGalleryEndpoint builds the handler that queues a model install.
// It parses the request body, allocates a job ID, pushes the operation on the
// gallery service channel and answers with the job ID and its status URL.
// @Summary Install models to LocalAI.
// @Param request body GalleryModel true "query params"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/apply [post]
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		var req GalleryModel
		// Decode the install request from the body.
		if err := c.BodyParser(&req); err != nil {
			return err
		}

		jobUUID, err := uuid.NewUUID()
		if err != nil {
			return err
		}
		jobID := jobUUID.String()

		op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
			Req:                req.GalleryModel,
			ID:                 jobID,
			GalleryElementName: req.ID,
			Galleries:          mgs.galleries,
			BackendGalleries:   mgs.backendGalleries,
		}
		mgs.galleryApplier.ModelGalleryChannel <- op

		statusURL := fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), jobID)
		return c.JSON(schema.GalleryResponse{ID: jobID, StatusURL: statusURL})
	}
}
// DeleteModelGalleryEndpoint builds the handler that queues the deletion of
// the model named by the "name" route parameter and answers with a job ID and
// its status URL.
// @Summary delete models to LocalAI.
// @Param name path string true "Model name"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/delete/{name} [post]
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		modelName := c.Params("name")

		// Allocate the job ID before queuing so the operation carries the same
		// ID that is handed back to the caller (the install handler sets ID the
		// same way). Previously the operation was queued without any ID.
		jobUUID, err := uuid.NewUUID()
		if err != nil {
			return err
		}
		jobID := jobUUID.String()

		mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
			ID:                 jobID,
			Delete:             true,
			GalleryElementName: modelName,
		}

		statusURL := fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), jobID)
		return c.JSON(schema.GalleryResponse{ID: jobID, StatusURL: statusURL})
	}
}
// ListModelFromGalleryEndpoint builds the handler that lists the models
// offered by the configured galleries. Only the Metadata of each model is
// marshalled and sent to the client.
// @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		available, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
		if err != nil {
			log.Error().Err(err).Msg("could not list models from galleries")
			return err
		}
		log.Debug().Msgf("Available %d models from %d galleries\n", len(available), len(mgs.galleries))

		// Start from an empty, non-nil slice so an empty result marshals as [].
		metadata := make([]gallery.Metadata, 0, len(available))
		for _, entry := range available {
			metadata = append(metadata, entry.Metadata)
		}
		log.Debug().Msgf("Models %#v", metadata)

		payload, err := json.Marshal(metadata)
		if err != nil {
			return fmt.Errorf("could not marshal models: %w", err)
		}
		return c.Send(payload)
	}
}
// ListModelGalleriesEndpoint builds the handler that sends the configured
// gallery list, marshalled as JSON.
// @Summary List all Galleries
// @Success 200 {object} []config.Gallery "Response"
// @Router /models/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		configured := mgs.galleries
		log.Debug().Msgf("Listing model galleries %+v", configured)

		payload, err := json.Marshal(configured)
		if err != nil {
			return err
		}
		return c.Send(payload)
	}
}
package localai
import (
"encoding/json"
"fmt"
"github.com/labstack/echo/v4"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
// ModelGalleryEndpointService bundles the state shared by the model gallery
// HTTP handlers defined on it.
type ModelGalleryEndpointService struct {
	// galleries is the list of model galleries queried and forwarded by the handlers.
	galleries []config.Gallery
	// backendGalleries is forwarded with install operations as BackendGalleries.
	backendGalleries []config.Gallery
	// modelPath is filled from systemState.Model.ModelsPath by the constructor.
	modelPath string
	// galleryApplier receives queued gallery operations and exposes their status.
	galleryApplier *services.GalleryService
}
// GalleryModel is the request payload of the apply endpoint: an "id" plus the
// embedded gallery.GalleryModel fields.
type GalleryModel struct {
	// ID is forwarded as the gallery element name of the install operation.
	ID string `json:"id"`
	gallery.GalleryModel
}
// CreateModelGalleryEndpointService assembles a ModelGalleryEndpointService
// from the configured galleries, the system state (source of the models path)
// and the gallery service that executes operations.
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
	var svc ModelGalleryEndpointService
	svc.galleries = galleries
	svc.backendGalleries = backendGalleries
	svc.modelPath = systemState.Model.ModelsPath
	svc.galleryApplier = galleryApplier
	return svc
}
// GetOpStatusEndpoint builds the handler that looks up one job by the "uuid"
// route parameter and writes its status as JSON. An error is returned when the
// gallery service has no status for that ID.
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /models/jobs/{uuid} [get]
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		jobID := c.Param("uuid")
		if status := mgs.galleryApplier.GetStatus(jobID); status != nil {
			return c.JSON(200, status)
		}
		return fmt.Errorf("could not find any status for ID")
	}
}
// GetAllStatusEndpoint builds the handler that writes the status of every job
// known to the gallery service as JSON.
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /models/jobs [get]
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		allStatuses := mgs.galleryApplier.GetAllStatus()
		return c.JSON(200, allStatuses)
	}
}
// ApplyModelGalleryEndpoint builds the handler that queues a model install.
// It binds the request, allocates a job ID, pushes the operation on the
// gallery service channel and answers with the job ID and its status URL.
// @Summary Install models to LocalAI.
// @Param request body GalleryModel true "query params"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/apply [post]
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		var req GalleryModel
		// Bind the install request.
		if err := c.Bind(&req); err != nil {
			return err
		}

		jobUUID, err := uuid.NewUUID()
		if err != nil {
			return err
		}
		jobID := jobUUID.String()

		op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
			Req:                req.GalleryModel,
			ID:                 jobID,
			GalleryElementName: req.ID,
			Galleries:          mgs.galleries,
			BackendGalleries:   mgs.backendGalleries,
		}
		mgs.galleryApplier.ModelGalleryChannel <- op

		statusURL := fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), jobID)
		return c.JSON(200, schema.GalleryResponse{ID: jobID, StatusURL: statusURL})
	}
}
// DeleteModelGalleryEndpoint builds the handler that queues the deletion of
// the model named by the "name" route parameter and answers with a job ID and
// its status URL.
// @Summary delete models to LocalAI.
// @Param name path string true "Model name"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/delete/{name} [post]
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		modelName := c.Param("name")

		// Allocate the job ID before queuing so the operation carries the same
		// ID that is handed back to the caller (the install handler sets ID the
		// same way). Previously the operation was queued without any ID.
		jobUUID, err := uuid.NewUUID()
		if err != nil {
			return err
		}
		jobID := jobUUID.String()

		mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
			ID:                 jobID,
			Delete:             true,
			GalleryElementName: modelName,
		}

		statusURL := fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), jobID)
		return c.JSON(200, schema.GalleryResponse{ID: jobID, StatusURL: statusURL})
	}
}
// ListModelFromGalleryEndpoint builds the handler that lists the models
// offered by the configured galleries. Only the Metadata of each model is
// marshalled and returned as an application/json blob.
// @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
	return func(c echo.Context) error {
		available, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
		if err != nil {
			log.Error().Err(err).Msg("could not list models from galleries")
			return err
		}
		log.Debug().Msgf("Available %d models from %d galleries\n", len(available), len(mgs.galleries))

		// Start from an empty, non-nil slice so an empty result marshals as [].
		metadata := make([]gallery.Metadata, 0, len(available))
		for _, entry := range available {
			metadata = append(metadata, entry.Metadata)
		}
		log.Debug().Msgf("Models %#v", metadata)

		payload, err := json.Marshal(metadata)
		if err != nil {
			return fmt.Errorf("could not marshal models: %w", err)
		}
		return c.Blob(200, "application/json", payload)
	}
}
// ListModelGalleriesEndpoint builds the handler that returns the configured
// gallery list, marshalled as an application/json blob.
// @Summary List all Galleries
// @Success 200 {object} []config.Gallery "Response"
// @Router /models/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		configured := mgs.galleries
		log.Debug().Msgf("Listing model galleries %+v", configured)

		payload, err := json.Marshal(configured)
		if err != nil {
			return err
		}
		return c.Blob(200, "application/json", payload)
	}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -21,17 +21,17 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/tokenMetrics [get]
// @Router /tokenMetrics [get]
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.TokenMetricsRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || modelFile != "" {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
@@ -52,6 +52,6 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
if err != nil {
return err
}
return c.JSON(response)
return c.JSON(200, response)
}
}

View File

@@ -3,16 +3,18 @@ package localai
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/gallery/importers"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/utils"
@@ -21,12 +23,12 @@ import (
)
// ImportModelURIEndpoint handles creating new model configurations from a URI
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) fiber.Handler {
return func(c *fiber.Ctx) error {
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.ImportModelRequest)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -61,7 +63,7 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
BackendGalleries: appConfig.BackendGalleries,
}
return c.JSON(schema.GalleryResponse{
return c.JSON(200, schema.GalleryResponse{
ID: uuid.String(),
StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
})
@@ -69,22 +71,28 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
}
// ImportModelEndpoint handles creating new model configurations
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
// Get the raw body
body := c.Body()
body, err := io.ReadAll(c.Request().Body)
if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to read request body: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
}
if len(body) == 0 {
response := ModelResponse{
Success: false,
Error: "Request body is empty",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Check content type to determine how to parse
contentType := string(c.Context().Request.Header.ContentType())
contentType := c.Request().Header.Get("Content-Type")
var modelConfig config.ModelConfig
var err error
if strings.Contains(contentType, "application/json") {
// Parse JSON
@@ -93,7 +101,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse JSON: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
// Parse YAML
@@ -102,18 +110,18 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
} else {
// Try to auto-detect format
if strings.TrimSpace(string(body))[0] == '{' {
if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
// Looks like JSON
if err := json.Unmarshal(body, &modelConfig); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to parse JSON: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
} else {
// Assume YAML
@@ -122,7 +130,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
}
}
@@ -133,7 +141,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Name is required",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Set defaults
@@ -145,7 +153,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Invalid configuration",
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Create the configuration file
@@ -155,7 +163,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Model path not trusted: " + err.Error(),
}
return c.Status(400).JSON(response)
return c.JSON(http.StatusBadRequest, response)
}
// Marshal to YAML for storage
@@ -165,7 +173,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to marshal configuration: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Write the file
@@ -174,7 +182,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to write configuration file: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
@@ -182,7 +190,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Preload the model
@@ -191,7 +199,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to preload model: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.JSON(http.StatusInternalServerError, response)
}
// Return success response
response := ModelResponse{
@@ -199,6 +207,6 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Message: "Model configuration created successfully",
Filename: filepath.Base(configPath),
}
return c.JSON(response)
return c.JSON(200, response)
}
}

View File

@@ -1,46 +1,47 @@
package localai
import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/mudler/LocalAI/core/services"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// LocalAIMetricsEndpoint adapts the Prometheus HTTP handler into a fiber handler.
// NOTE(review): the @Param below describes a config.Gallery body, which looks copied from another endpoint; confirm.
// @Summary Prometheus metrics endpoint
// @Param request body config.Gallery true "Gallery details"
// @Router /metrics [get]
func LocalAIMetricsEndpoint() fiber.Handler {
	promHandler := promhttp.Handler()
	return adaptor.HTTPHandler(promHandler)
}
// apiMiddlewareConfig holds the settings used by LocalAIMetricsAPIMiddleware.
type apiMiddlewareConfig struct {
// Filter, when non-nil and returning true for a request, makes the
// middleware forward the request without recording an observation.
Filter func(c *fiber.Ctx) bool
// metricsService receives one ObserveAPICall per recorded request.
metricsService *services.LocalAIMetricsService
}
// LocalAIMetricsAPIMiddleware builds a fiber handler that times every request
// and reports it to the given metrics service.
//
// Requests whose path is exactly "/metrics" are passed straight to the next
// handler and are not recorded. For every other request the path and method
// are captured before the rest of the chain runs, and the elapsed time (in
// seconds) is reported through ObserveAPICall once the chain returns. The
// error produced by the chain is returned unchanged.
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
	skipMetricsRoute := func(c *fiber.Ctx) bool {
		return c.Path() == "/metrics"
	}
	cfg := apiMiddlewareConfig{
		Filter:         skipMetricsRoute,
		metricsService: metrics,
	}

	return func(c *fiber.Ctx) error {
		if skip := cfg.Filter; skip != nil && skip(c) {
			return c.Next()
		}

		// Read the request identity before handing over to the next handler.
		route, verb := c.Path(), c.Method()

		begin := time.Now()
		chainErr := c.Next()
		seconds := float64(time.Since(begin)) / float64(time.Second)

		cfg.metricsService.ObserveAPICall(verb, route, seconds)
		return chainErr
	}
}
package localai
import (
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/services"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI.
// The returned handler is the Prometheus promhttp handler wrapped as an
// echo.HandlerFunc.
// @Summary Prometheus metrics endpoint
// @Router /metrics [get]
func LocalAIMetricsEndpoint() echo.HandlerFunc {
	return echo.WrapHandler(promhttp.Handler())
}
// apiMiddlewareConfig holds the settings used by LocalAIMetricsAPIMiddleware.
type apiMiddlewareConfig struct {
// Filter, when non-nil and returning true for a request, makes the
// middleware forward the request without recording an observation.
Filter func(c echo.Context) bool
// metricsService receives one ObserveAPICall per recorded request.
metricsService *services.LocalAIMetricsService
}
// LocalAIMetricsAPIMiddleware builds an echo middleware that times every
// request and reports it to the given metrics service.
//
// Requests for which c.Path() is exactly "/metrics" are passed straight to the
// next handler and are not recorded. For every other request the path and HTTP
// method are captured before the next handler runs, and the elapsed time (in
// seconds) is reported through ObserveAPICall once it returns. The error
// produced by the next handler is returned unchanged.
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc {
	skipMetricsRoute := func(c echo.Context) bool {
		return c.Path() == "/metrics"
	}
	cfg := apiMiddlewareConfig{
		Filter:         skipMetricsRoute,
		metricsService: metrics,
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if skip := cfg.Filter; skip != nil && skip(c) {
				return next(c)
			}

			// Read the request identity before handing over to the next handler.
			route, verb := c.Path(), c.Request().Method

			begin := time.Now()
			handlerErr := next(c)
			seconds := float64(time.Since(begin)) / float64(time.Second)

			cfg.metricsService.ObserveAPICall(verb, route, seconds)
			return handlerErr
		}
	}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/schema"
@@ -11,10 +11,10 @@ import (
// @Summary Returns available P2P nodes
// @Success 200 {object} []schema.P2PNodesResponse "Response"
// @Router /api/p2p [get]
func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
// Render index
return func(c *fiber.Ctx) error {
return c.JSON(schema.P2PNodesResponse{
return func(c echo.Context) error {
return c.JSON(200, schema.P2PNodesResponse{
Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
})
@@ -25,6 +25,6 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
// @Summary Show the P2P token
// @Success 200 {string} string "Response"
// @Router /api/p2p/token [get]
func ShowP2PToken(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error { return c.Send([]byte(appConfig.P2PToken)) }
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) }
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
@@ -9,11 +9,11 @@ import (
"github.com/mudler/LocalAI/pkg/store"
)
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.StoresSet)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -28,20 +28,20 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
vals[i] = []byte(v)
}
err = store.SetCols(c.Context(), sb, input.Keys, vals)
err = store.SetCols(c.Request().Context(), sb, input.Keys, vals)
if err != nil {
return err
}
return c.Send(nil)
return c.NoContent(200)
}
}
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.StoresDelete)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -51,19 +51,19 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
}
defer sl.Close()
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil {
return err
}
return c.Send(nil)
return c.NoContent(200)
}
}
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.StoresGet)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -73,7 +73,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
}
defer sl.Close()
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys)
if err != nil {
return err
}
@@ -87,15 +87,15 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
res.Values[i] = string(v)
}
return c.JSON(res)
return c.JSON(200, res)
}
}
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.StoresFind)
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}
@@ -105,7 +105,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
}
defer sl.Close()
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk)
if err != nil {
return err
}
@@ -120,6 +120,6 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
res.Values[i] = string(v)
}
return c.JSON(res)
return c.JSON(200, res)
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
@@ -11,8 +11,8 @@ import (
// @Summary Show the LocalAI instance information
// @Success 200 {object} schema.SystemInformationResponse "Response"
// @Router /system [get]
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
availableBackends := []string{}
loadedModels := ml.ListLoadedModels()
for b := range appConfig.ExternalGRPCBackends {
@@ -26,7 +26,7 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf
for _, m := range loadedModels {
sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID})
}
return c.JSON(
return c.JSON(200,
schema.SystemInformationResponse{
Backends: availableBackends,
Models: sysmodels,

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -14,22 +14,22 @@ import (
// @Param request body schema.TokenizeRequest true "Request"
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
return ctx.JSON(tokenResponse)
return c.JSON(200, tokenResponse)
}
}

View File

@@ -1,12 +1,14 @@
package localai
import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
@@ -22,16 +24,16 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/audio/speech [post]
// @Router /tts [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received")
@@ -59,6 +61,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
return err
}
return c.Download(filePath)
return c.Attachment(filePath, filepath.Base(filePath))
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -16,26 +16,26 @@ import (
// @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post]
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received")
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg)
if err != nil {
return err
}
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -7,19 +7,20 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
"github.com/gofiber/fiber/v2"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -64,18 +65,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.VideoRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /video [post]
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
if !ok || input.Model == "" {
log.Error().Msg("Video Endpoint - Invalid Input")
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
log.Error().Msg("Video Endpoint - Invalid Config")
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
src := ""
@@ -164,7 +165,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
return err
}
baseURL := c.BaseURL()
baseURL := middleware.BaseURL(c)
fn, err := backend.VideoGeneration(
height,
@@ -201,7 +202,10 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = baseURL + "/generated-videos/" + base
item.URL, err = url.JoinPath(baseURL, "generated-videos", base)
if err != nil {
return err
}
}
id := uuid.New().String()
@@ -216,6 +220,6 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -1,18 +1,20 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
)
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc {
return func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
galleryConfigs := map[string]*gallery.ModelConfig{}
@@ -40,10 +42,10 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
// Get model statuses to display in the UI the operation in progress
processingModels, taskTypes := opcache.GetStatus()
summary := fiber.Map{
summary := map[string]interface{}{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Models": modelsWithoutConfig,
"ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs,
@@ -54,12 +56,16 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
"InstalledBackends": installedBackends,
}
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
// Default to HTML if Accept header is empty (browser behavior)
// Only return JSON if explicitly requested or Content-Type is application/json
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
return c.JSON(200, summary)
} else {
// Render index
return c.Render("views/index", summary)
return c.Render(200, "views/index", summary)
}
}
}

View File

@@ -1,15 +1,12 @@
package openai
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -20,68 +17,14 @@ import (
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// NOTE: this is a bad WORKAROUND! We should find a better way to handle this.
// Fasthttp doesn't support context cancellation from the caller
// for non-streaming requests, so we need to monitor the connection directly.
// Monitor connection for client disconnection during non-streaming requests
// We access the connection directly via c.Context().Conn() to monitor it
// during ComputeChoices execution, not after the response is sent
// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906
func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) {
var conn net.Conn = c.Context().Conn()
if conn == nil {
return
}
go func() {
defer func() {
// Clear read deadline when goroutine exits
conn.SetReadDeadline(time.Time{})
}()
buf := make([]byte, 1)
// Use a short read deadline to periodically check if connection is closed
// Without a deadline, Read() would block indefinitely waiting for data
// that will never come (client is waiting for response, not sending more data)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-requestCtx.Done():
// Request completed or was cancelled - exit goroutine
return
case <-ticker.C:
// Set a short deadline - if connection is closed, read will fail immediately
// If connection is open but no data, it will timeout and we check again
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
_, err := conn.Read(buf)
if err != nil {
// Check if it's a timeout (connection still open, just no data)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout is expected - connection is still open, just no data to read
// Continue the loop to check again
continue
}
// Connection closed or other error - cancel the context to stop gRPC call
log.Debug().Msgf("Calling cancellation function")
cancelFunc()
return
}
}
}
}()
}
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
// @Summary Generate a chat completions for a given prompt and model.
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc {
var id, textContentToReturn string
var created int
@@ -235,21 +178,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
return err
}
return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
extraUsage := c.Get("Extra-Usage", "") != ""
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
@@ -392,13 +335,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
case toStream:
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Set("X-Correlation-ID", id)
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().Header().Set("X-Correlation-ID", id)
responses := make(chan schema.OpenAIResponse)
ended := make(chan error, 1)
@@ -411,103 +351,101 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
}()
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
toolsCalled := false
usage := &schema.OpenAIUsage{}
toolsCalled := false
LOOP:
for {
select {
case <-input.Context.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
input.Cancel()
break LOOP
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
input.Cancel()
continue
}
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(w, "data: %s\n\n", string(respData))
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
}
w.Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &stopReason,
Index: 0,
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, marshalErr := json.Marshal(resp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
w.WriteString("data: {\"error\":\"Internal error\"}\n\n")
} else {
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
}
w.WriteString("data: [DONE]\n\n")
w.Flush()
return
LOOP:
for {
select {
case <-input.Context.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
input.Cancel()
break LOOP
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
input.Cancel()
continue
}
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &stopReason,
Index: 0,
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, marshalErr := json.Marshal(resp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
}
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
}
}
finishReason := FinishReasonStop
if toolsCalled && len(input.Tools) > 0 {
finishReason = FinishReasonToolCalls
} else if toolsCalled {
finishReason = FinishReasonFunctionCall
}
finishReason := FinishReasonStop
if toolsCalled && len(input.Tools) > 0 {
finishReason = FinishReasonToolCalls
} else if toolsCalled {
finishReason = FinishReasonFunctionCall
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &finishReason,
Index: 0,
Delta: &schema.Message{},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
log.Debug().Msgf("Stream ended")
}))
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &finishReason,
Index: 0,
Delta: &schema.Message{},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
log.Debug().Msgf("Stream ended")
return nil
// no streaming mode
@@ -589,9 +527,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
// NOTE: this is a workaround as fasthttp
// context cancellation does not fire in non-streaming requests
// handleConnectionCancellation(c, input.Cancel, input.Context)
// Echo properly supports context cancellation via c.Request().Context()
// No workaround needed!
result, tokenUsage, err := ComputeChoices(
input,
@@ -628,7 +565,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}
}

View File

@@ -1,24 +1,22 @@
package openai
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
@@ -26,7 +24,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
created := int(time.Now().Unix())
@@ -64,22 +62,25 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
return err
}
return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
created := int(time.Now().Unix())
// Handle Correlation
id := c.Get("X-Correlation-ID", uuid.New().String())
extraUsage := c.Get("Extra-Usage", "") != ""
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = uuid.New().String()
}
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
if config.ResponseFormatMap != nil {
@@ -97,15 +98,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
@@ -130,78 +126,78 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
}()
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
continue
}
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
continue
}
log.Debug().Msgf("Sending chunk: %s", string(respData))
fmt.Fprintf(w, "data: %s\n\n", string(respData))
w.Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
errorResp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model,
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
Text: "Internal error: " + err.Error(),
},
},
Object: "text_completion",
}
errorData, marshalErr := json.Marshal(errorResp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(w, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(w, "data: %s\n\n", string(errorData))
}
w.Flush()
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil {
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
break LOOP
}
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
stopReason := FinishReasonStop
errorResp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model,
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
Text: "Internal error: " + err.Error(),
},
},
},
Object: "text_completion",
Object: "text_completion",
}
errorData, marshalErr := json.Marshal(errorResp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
}
c.Response().Flush()
return nil
}
respData, _ := json.Marshal(resp)
}
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return <-ended
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
},
},
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
}
var result []schema.Choice
@@ -257,6 +253,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -4,11 +4,11 @@ import (
"encoding/json"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
@@ -23,20 +23,20 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/edits [post]
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
@@ -98,6 +98,6 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -12,7 +13,6 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@@ -21,16 +21,16 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
log.Debug().Msgf("Parameter Config: %+v", config)
@@ -78,6 +78,6 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
@@ -14,13 +15,13 @@ import (
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
"github.com/gofiber/fiber/v2"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -65,18 +66,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
log.Error().Msg("Image Endpoint - Invalid Input")
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
log.Error().Msg("Image Endpoint - Invalid Config")
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
// Process input images (for img2img/inpainting)
@@ -188,7 +189,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
return err
}
baseURL := c.BaseURL()
baseURL := middleware.BaseURL(c)
// Use the first input image as src if available, otherwise use the original src
inputSrc := src
@@ -215,7 +216,10 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = baseURL + "/generated-images/" + base
item.URL, err = url.JoinPath(baseURL, "generated-images", base)
if err != nil {
return err
}
}
result = append(result, *item)
@@ -234,7 +238,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -1,7 +1,7 @@
package openai
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
@@ -12,14 +12,15 @@ import (
// @Summary List and describe the various models available in the API.
// @Success 200 {object} schema.ModelsDataResponse "Response"
// @Router /v1/models [get]
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
// If blank, no filter is applied.
filter := c.Query("filter")
filter := c.QueryParam("filter")
// By default, exclude any loose files that are already referenced by a configuration file.
var policy services.LooseFilePolicy
if c.QueryBool("excludeConfigured", true) {
excludeConfigured := c.QueryParam("excludeConfigured")
if excludeConfigured == "" || excludeConfigured == "true" {
policy = services.SKIP_IF_CONFIGURED
} else {
policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
@@ -41,7 +42,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
return c.JSON(schema.ModelsDataResponse{
return c.JSON(200, schema.ModelsDataResponse{
Object: "list",
Data: dataModels,
})

View File

@@ -8,11 +8,11 @@ import (
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
@@ -26,24 +26,27 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /mcp/v1/completions [post]
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
// We do not support streaming mode (Yet?)
return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
created := int(time.Now().Unix())
ctx := c.Context()
ctx := c.Request().Context()
// Handle Correlation
id := c.Get("X-Correlation-ID", uuid.New().String())
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = uuid.New().String()
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
@@ -80,7 +83,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
//handleConnectionCancellation(c, cancel, ctxWithCancellation)
// TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go.
// We can do this as cogito expects an interface and we can create one that
@@ -147,6 +150,6 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(200, resp)
}
}

View File

@@ -10,9 +10,11 @@ import (
"sync"
"time"
"net/http"
"github.com/go-audio/audio"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
@@ -167,32 +169,50 @@ type Model interface {
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
}
// upgrader converts the HTTP request handled by Realtime into a websocket
// connection. Its origin check accepts every request unconditionally, so
// connections are not restricted by the Origin header.
var upgrader = websocket.Upgrader{
	// Allow all origins
	CheckOrigin: func(*http.Request) bool { return true },
}
// TODO: Implement ephemeral keys to allow these endpoints to be used
func RealtimeSessions(application *application.Application) fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.SendStatus(501)
func RealtimeSessions(application *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
return c.NoContent(501)
}
}
func RealtimeTranscriptionSession(application *application.Application) fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.SendStatus(501)
func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
return c.NoContent(501)
}
}
func Realtime(application *application.Application) fiber.Handler {
return websocket.New(registerRealtime(application))
// Realtime returns the Echo handler for the realtime endpoint. It upgrades the
// request to a websocket connection, reads the `model` and `intent` query
// parameters, and runs the realtime session on that connection until the
// session handler returns. The connection is closed when the handler exits.
func Realtime(application *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		conn, upgradeErr := upgrader.Upgrade(c.Response(), c.Request(), nil)
		if upgradeErr != nil {
			return upgradeErr
		}
		defer conn.Close()

		// The websocket handler only receives the connection, so the query
		// parameters it needs are read from the Echo context here.
		intent := c.QueryParam("intent")
		modelName := c.QueryParam("model")
		if modelName == "" {
			// No model requested: fall back to the default name.
			modelName = "gpt-4o"
		}

		session := registerRealtime(application, modelName, intent)
		session(conn)
		return nil
	}
}
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) {
return func(c *websocket.Conn) {
evaluator := application.TemplatesEvaluator()
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
model := c.Query("model", "gpt-4o")
intent := c.Query("intent")
if intent != "transcription" {
sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
}

View File

@@ -7,13 +7,13 @@ import (
"path"
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@@ -24,19 +24,19 @@ import (
// @Param file formData file true "file"
// @Success 200 {object} map[string]string "Response"
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
diarize := c.FormValue("diarize", "false") != "false"
diarize := c.FormValue("diarize") != "false"
// retrieve the file data from the request
file, err := c.FormFile("file")
@@ -76,6 +76,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(tr)
return c.JSON(http.StatusOK, tr)
}
}

View File

@@ -6,7 +6,7 @@ import (
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -14,20 +14,24 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}
var raw map[string]interface{}
if body := c.Body(); len(body) > 0 {
body := make([]byte, 0)
if c.Request().Body != nil {
c.Request().Body.Read(body)
}
if len(body) > 0 {
_ = json.Unmarshal(body, &raw)
}
// Build VideoRequest using shared mapper
vr := MapOpenAIToVideo(input, raw)
// Place VideoRequest into locals so localai.VideoEndpoint can consume it
c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
// Place VideoRequest into context so localai.VideoEndpoint can consume it
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
// Delegate to existing localai handler
return localai.VideoEndpoint(cl, ml, appConfig)(c)
}

View File

@@ -1,48 +1,50 @@
package http
import (
"io/fs"
"net/http"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/rs/zerolog/log"
)
func Explorer(db *explorer.Database) *fiber.App {
func Explorer(db *explorer.Database) *echo.Echo {
e := echo.New()
fiberCfg := fiber.Config{
Views: renderEngine(),
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: false,
// Override default error handler
// Set renderer
e.Renderer = renderEngine()
// Hide banner
e.HideBanner = true
e.Pre(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(e, db)
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
}
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})
// Static files - use fs.Sub to create a filesystem rooted at "static"
staticFS, err := fs.Sub(embedDirStatic, "static")
if err != nil {
// Log error but continue - static files might not work
log.Error().Err(err).Msg("failed to create static filesystem")
} else {
e.StaticFS("/static", staticFS)
}
app := fiber.New(fiberCfg)
app.Use(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(app, db)
httpFS := http.FS(embedDirStatic)
app.Use(favicon.New(favicon.Config{
URL: "/favicon.svg",
FileSystem: httpFS,
File: "static/favicon.svg",
}))
app.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Define a custom 404 handler
// Note: keep this at the bottom!
app.Use(notFoundHandler)
e.GET("/*", notFoundHandler)
return app
return e
}

View File

@@ -3,50 +3,108 @@ package middleware
import (
"crypto/subtle"
"errors"
"net/http"
"strings"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
)
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key", "cookie:token"}, keyauth.ConfigDefault.AuthScheme)
if err != nil {
return nil, err
}
// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
// Create validator function
validator := getApiKeyValidationFunction(applicationConfig)
return &v2keyauth.Config{
CustomKeyLookup: customLookup,
Next: getApiKeyRequiredFilterFunction(applicationConfig),
Validator: getApiKeyValidationFunction(applicationConfig),
ErrorHandler: getApiKeyErrorHandler(applicationConfig),
AuthScheme: "Bearer",
// Create error handler
errorHandler := getApiKeyErrorHandler(applicationConfig)
// Create Next function (skip middleware for certain requests)
skipper := getApiKeyRequiredFilterFunction(applicationConfig)
// Wrap it with our custom key lookup that checks multiple sources
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if len(applicationConfig.ApiKeys) == 0 {
return next(c)
}
// Skip if skipper says so
if skipper != nil && skipper(c) {
return next(c)
}
// Try to extract key from multiple sources
key, err := extractKeyFromMultipleSources(c)
if err != nil {
return errorHandler(err, c)
}
// Validate the key
valid, err := validator(key, c)
if err != nil || !valid {
return errorHandler(ErrMissingOrMalformedAPIKey, c)
}
// Store key in context for later use
c.Set("api_key", key)
return next(c)
}
}, nil
}
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
return func(ctx *fiber.Ctx, err error) error {
if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
// extractKeyFromMultipleSources returns the API key carried by the request.
// Sources are consulted in this order, and the first non-empty one wins:
//
//  1. the Authorization header (a leading "Bearer " is removed; any other
//     value is returned unchanged for backward compatibility)
//  2. the x-api-key header
//  3. the xi-api-key header
//  4. the "token" cookie
//
// ErrMissingOrMalformedAPIKey is returned when none of them yields a value.
func extractKeyFromMultipleSources(c echo.Context) (string, error) {
	header := c.Request().Header

	// A non-empty Authorization header always takes precedence, with or
	// without the Bearer scheme.
	if authorization := header.Get("Authorization"); authorization != "" {
		return strings.TrimPrefix(authorization, "Bearer "), nil
	}

	// Dedicated API key headers, in priority order.
	for _, name := range []string{"x-api-key", "xi-api-key"} {
		if value := header.Get(name); value != "" {
			return value, nil
		}
	}

	// Last resort: the token cookie.
	if tokenCookie, err := c.Cookie("token"); err == nil && tokenCookie != nil && tokenCookie.Value != "" {
		return tokenCookie.Value, nil
	}

	return "", ErrMissingOrMalformedAPIKey
}
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
return func(err error, c echo.Context) error {
if errors.Is(err, ErrMissingOrMalformedAPIKey) {
if len(applicationConfig.ApiKeys) == 0 {
return ctx.Next() // if no keys are set up, any error we get here is not an error.
return nil // if no keys are set up, any error we get here is not an error.
}
ctx.Set("WWW-Authenticate", "Bearer")
c.Response().Header().Set("WWW-Authenticate", "Bearer")
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(401)
return c.NoContent(http.StatusUnauthorized)
}
// Check if the request content type is JSON
contentType := string(ctx.Context().Request.Header.ContentType())
contentType := c.Request().Header.Get("Content-Type")
if strings.Contains(contentType, "application/json") {
return ctx.Status(401).JSON(schema.ErrorResponse{
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
Error: &schema.APIError{
Message: "An authentication key is required",
Code: 401,
@@ -55,50 +113,69 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
})
}
return ctx.Status(401).Render("views/login", fiber.Map{
"BaseURL": utils.BaseURL(ctx),
return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
"BaseURL": BaseURL(c),
})
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500)
return c.NoContent(http.StatusInternalServerError)
}
return err
}
}
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
if applicationConfig.UseSubtleKeyComparison {
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
return func(key string, c echo.Context) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
return true, nil
}
}
return false, v2keyauth.ErrMissingOrMalformedAPIKey
return false, ErrMissingOrMalformedAPIKey
}
}
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
return func(key string, c echo.Context) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if apiKey == validKey {
if key == validKey {
return true, nil
}
}
return false, v2keyauth.ErrMissingOrMalformedAPIKey
return false, ErrMissingOrMalformedAPIKey
}
}
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
if applicationConfig.DisableApiKeyRequirementForHttpGet {
return func(c *fiber.Ctx) bool {
if c.Method() != "GET" {
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
return func(c echo.Context) bool {
path := c.Request().URL.Path
// Always skip authentication for static files
if strings.HasPrefix(path, "/static/") {
return true
}
// Always skip authentication for generated content
if strings.HasPrefix(path, "/generated-audio/") ||
strings.HasPrefix(path, "/generated-images/") ||
strings.HasPrefix(path, "/generated-videos/") {
return true
}
// Skip authentication for favicon
if path == "/favicon.svg" {
return true
}
// Handle GET request exemptions if enabled
if applicationConfig.DisableApiKeyRequirementForHttpGet {
if c.Request().Method != http.MethodGet {
return false
}
for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
@@ -106,8 +183,8 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
return true
}
}
return false
}
return false
}
return func(c *fiber.Ctx) bool { return false }
}

View File

@@ -0,0 +1,48 @@
package middleware
import (
"strings"
"github.com/labstack/echo/v4"
)
// BaseURL returns the base URL for the given HTTP request context.
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
// The returned URL is guaranteed to end with `/`.
// The method should be used in conjunction with the StripPathPrefix middleware.
//
// The scheme is "https" when the X-Forwarded-Proto header equals "https" or the
// request arrived over TLS, and "http" otherwise. The host is taken from the
// X-Forwarded-Host header when present, falling back to the request's Host.
// The path prefix is whatever precedes the handled path inside the original
// request path; when no prefix can be determined the URL points at "/".
func BaseURL(c echo.Context) string {
	req := c.Request()

	// The original (unstripped) path is the one stored by the StripPathPrefix
	// middleware when available, otherwise the request path itself.
	origPath := req.URL.Path
	if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" {
		origPath = storedPath
	}

	// Check X-Forwarded-Proto for scheme
	scheme := "http"
	if req.Header.Get("X-Forwarded-Proto") == "https" || req.TLS != nil {
		scheme = "https"
	}

	// Check X-Forwarded-Host for host
	host := req.Host
	if forwardedHost := req.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
		host = forwardedHost
	}

	// c.Path() is the registered route pattern, so for routes with parameters
	// or wildcards (or when no route matched) it is not a literal suffix of
	// the original path. The concrete request path is therefore tried as a
	// second candidate so the prefix is not lost for those requests.
	for _, path := range []string{c.Path(), req.URL.Path} {
		if path == "" || path == origPath || !strings.HasSuffix(origPath, path) {
			continue
		}
		pathPrefix := strings.TrimSuffix(origPath, path)
		if !strings.HasSuffix(pathPrefix, "/") {
			pathPrefix += "/"
		}
		return scheme + "://" + host + pathPrefix
	}

	return scheme + "://" + host + "/"
}

View File

@@ -0,0 +1,58 @@
package middleware
import (
"net/http/httptest"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Specs for BaseURL: one request served without any stored original path, and
// one where the handler simulates the StripPathPrefix middleware.
var _ = Describe("BaseURL", func() {
	// serve registers a GET /hello/world route whose handler optionally runs
	// prepare and then records BaseURL(c). It issues a matching request and
	// returns the response status code together with the recorded URL.
	serve := func(prepare func(c echo.Context)) (int, string) {
		e := echo.New()
		recorded := ""

		e.GET("/hello/world", func(c echo.Context) error {
			if prepare != nil {
				prepare(c)
			}
			recorded = BaseURL(c)
			return nil
		})

		rec := httptest.NewRecorder()
		e.ServeHTTP(rec, httptest.NewRequest("GET", "/hello/world", nil))
		return rec.Code, recorded
	}

	Context("without prefix", func() {
		It("should return base URL without prefix", func() {
			// The route is registered on the exact request path so routing works.
			status, actualURL := serve(nil)

			Expect(status).To(Equal(200), "response status code")
			Expect(actualURL).To(Equal("http://example.com/"), "base URL")
		})
	})

	Context("with prefix", func() {
		It("should return base URL with prefix", func() {
			// The request is made with the already-stripped path; the handler
			// then mimics the middleware by storing the original path and
			// setting the stripped request path.
			status, actualURL := serve(func(c echo.Context) {
				c.Set("_original_path", "/myprefix/hello/world")
				c.Request().URL.Path = "/hello/world"
			})

			Expect(status).To(Equal(200), "response status code")
			Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
		})
	})
})

View File

@@ -0,0 +1,13 @@
package middleware_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestMiddleware is the `go test` entry point for this package's specs: it
// registers Fail as the failure handler and then runs the suite under the
// name "Middleware test suite".
func TestMiddleware(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Middleware test suite")
}

View File

@@ -1,470 +1,482 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
type RequestExtractor struct {
modelConfigLoader *config.ModelConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
}
// NewRequestExtractor returns a RequestExtractor wired to the given model
// configuration loader, model loader and application configuration.
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
	extractor := new(RequestExtractor)
	extractor.modelConfigLoader = modelConfigLoader
	extractor.modelLoader = modelLoader
	extractor.applicationConfig = applicationConfig
	return extractor
}
// CONTEXT_LOCALS_KEY_MODEL_NAME is the Fiber locals key holding the model name
// (a string) resolved for the current request.
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"

// CONTEXT_LOCALS_KEY_LOCALAI_REQUEST is the Fiber locals key holding the
// request object parsed from the body by SetModelAndConfig.
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"

// CONTEXT_LOCALS_KEY_MODEL_CONFIG is the Fiber locals key holding the model
// configuration loaded for the current request.
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// setModelNameFromRequest resolves the model name for the current request and
// stores it in the Fiber locals under CONTEXT_LOCALS_KEY_MODEL_NAME.
//
// A non-empty name already present in the locals is left untouched. Otherwise
// the name is taken, in order of precedence, from the "model" route parameter,
// the "model" query parameter, or the Authorization header token when that
// token names a model known to services.CheckIfModelExists. When nothing
// matches, an empty string is stored.
//
// NOTE(review): an earlier TODO here read "Refactor to not return error if
// unchanged"; the function has no return value, so that note looks stale.
func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
	modelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
	if ok && modelName != "" {
		return
	}

	modelName = ctx.Params("model")
	if modelName == "" {
		modelName = ctx.Query("model")
	}

	if modelName == "" {
		// Set model from bearer token, if available.
		// TrimPrefix removes only the literal "Bearer " scheme. TrimLeft takes
		// a character set, so it would additionally strip any leading 'B',
		// 'e', 'a', 'r' or ' ' from the token itself and corrupt model names
		// such as "rag-model".
		bearer := strings.TrimPrefix(ctx.Get("authorization"), "Bearer ")
		if bearer != "" {
			exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
			if err == nil && exists {
				modelName = bearer
			}
		}
	}

	ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelName)
}
// BuildConstantDefaultModelNameMiddleware returns a Fiber handler that
// resolves the model name for the request and, when none could be determined,
// falls back to defaultModelName before invoking the next handler.
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
	return func(ctx *fiber.Ctx) error {
		re.setModelNameFromRequest(ctx)

		// A usable name was found in the request: nothing to default.
		if resolved, isString := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string); isString && resolved != "" {
			return ctx.Next()
		}

		ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
		log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
		return ctx.Next()
	}
}
// BuildFilteredFirstAvailableDefaultModel returns a Fiber handler that
// resolves the model name for the request and, when none could be determined,
// defaults it to the first model returned by services.ListModels for filterFn.
// Listing failures and an empty model list are logged and otherwise ignored;
// the next handler is always invoked.
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
	return func(ctx *fiber.Ctx) error {
		re.setModelNameFromRequest(ctx)

		// Don't overwrite existing values.
		if current := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string); current != "" {
			return ctx.Next()
		}

		candidates, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
		switch {
		case err != nil:
			log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
		case len(candidates) == 0:
			// This is non-fatal - making it fatal was breaking the case of
			// direct installation of raw models.
			log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
		default:
			ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, candidates[0])
			log.Debug().Str("first model name", candidates[0]).Msg("context local model name not found, setting to the first model")
		}
		return ctx.Next()
	}
}
// SetModelAndConfig returns a Fiber handler that parses the request body into
// a fresh value produced by initializer, fills in a missing model name from
// the value stored earlier in the middleware chain, loads the matching model
// configuration, and stores both the request and the configuration in the
// Fiber locals before invoking the next handler.
//
// The handler returns an error when initializer yields nil or the body cannot
// be parsed. A failure to load the model configuration is only logged; the
// configuration value returned by the loader is stored as-is in that case.
//
// TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
	return func(ctx *fiber.Ctx) error {
		input := initializer()
		if input == nil {
			return fmt.Errorf("unable to initialize body")
		}
		if err := ctx.BodyParser(input); err != nil {
			return fmt.Errorf("failed parsing request body: %w", err)
		}

		// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
		if input.ModelName(nil) == "" {
			localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
			if ok && localModelName != "" {
				log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
				input.ModelName(&localModelName)
			}
		}

		cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
		switch {
		case err != nil:
			// Attach the error to the event that is actually sent: a bare
			// log.Err(err) without Msg/Send never writes anything.
			log.Warn().Err(err).Msgf("Model Configuration File not found for %q", input.ModelName(nil))
		case cfg != nil && cfg.Model == "" && input.ModelName(nil) != "":
			log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
			cfg.Model = input.ModelName(nil)
		}

		ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
		ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
		return ctx.Next()
	}
}
// SetOpenAIRequest finalises an OpenAI-style request that SetModelAndConfig
// stored in the Fiber locals: it assigns a correlation ID, attaches a
// cancellable context to the request, merges the request values into the
// model configuration and stores both back before invoking the next handler.
//
// It returns fiber.ErrBadRequest when the stored request is missing, is not a
// *schema.OpenAIRequest, has no model name, or when no model configuration is
// available. It returns the merge error when the merged configuration does
// not validate.
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
	input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
	if !ok || input == nil || input.Model == "" {
		return fiber.ErrBadRequest
	}

	cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
	if !ok || cfg == nil {
		return fiber.ErrBadRequest
	}

	// Extract or generate the correlation ID
	correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
	ctx.Set("X-Correlation-ID", correlationID)

	// Use the application context as parent to ensure cancellation on app shutdown.
	c1, cancel := context.WithCancel(re.applicationConfig.Context)

	// Propagate cancellation of the Fiber context to our context. Also stop
	// watching once our own context is done, so this goroutine does not
	// outlive the request it was started for.
	go func(fiberCtx *fasthttp.RequestCtx) {
		if fiberCtx == nil {
			return
		}
		select {
		case <-fiberCtx.Done():
			cancel()
		case <-c1.Done():
			// Already cancelled
		}
	}(ctx.Context())

	// Add the correlation ID to the new context
	input.Context = context.WithValue(c1, CorrelationIDKey, correlationID)
	input.Cancel = cancel

	if err := mergeOpenAIRequestAndModelConfig(cfg, input); err != nil {
		// No handler will run for this request: release the context and the
		// watcher goroutine right away.
		cancel()
		return err
	}

	if cfg.Model == "" {
		log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
		cfg.Model = input.Model
	}

	ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
	ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)

	return ctx.Next()
}
// mergeOpenAIRequestAndModelConfig overlays the values carried by input onto
// config and normalises parts of input in place.
//
// Request fields that are set (non-zero / non-nil) override the matching
// config fields; stop words, prompts and embedding inputs are appended to the
// lists already present in config. On input, tools are flattened into
// Functions, a tool choice is translated into FunctionCall, and every message
// gets its StringContent / StringImages / StringVideos / StringAudios
// populated from its Content.
//
// It returns an error when config.Validate() rejects the merged configuration.
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
	if input.Echo {
		config.Echo = input.Echo
	}
	if input.TopK != nil {
		config.TopK = input.TopK
	}
	if input.TopP != nil {
		config.TopP = input.TopP
	}

	if input.Backend != "" {
		config.Backend = input.Backend
	}

	if input.ClipSkip != 0 {
		config.Diffusers.ClipSkip = input.ClipSkip
	}

	if input.NegativePromptScale != 0 {
		config.NegativePromptScale = input.NegativePromptScale
	}

	if input.NegativePrompt != "" {
		config.NegativePrompt = input.NegativePrompt
	}

	if input.RopeFreqBase != 0 {
		config.RopeFreqBase = input.RopeFreqBase
	}

	if input.RopeFreqScale != 0 {
		config.RopeFreqScale = input.RopeFreqScale
	}

	if input.Grammar != "" {
		config.Grammar = input.Grammar
	}

	if input.Temperature != nil {
		config.Temperature = input.Temperature
	}

	if input.Maxtokens != nil {
		config.Maxtokens = input.Maxtokens
	}

	// The response format arrives either as a plain string or as an object.
	if input.ResponseFormat != nil {
		switch responseFormat := input.ResponseFormat.(type) {
		case string:
			config.ResponseFormat = responseFormat
		case map[string]interface{}:
			config.ResponseFormatMap = responseFormat
		}
	}

	// Stop words arrive either as a single string or as a list; non-string
	// list entries are skipped.
	switch stop := input.Stop.(type) {
	case string:
		if stop != "" {
			config.StopWords = append(config.StopWords, stop)
		}
	case []interface{}:
		for _, pp := range stop {
			if s, ok := pp.(string); ok {
				config.StopWords = append(config.StopWords, s)
			}
		}
	}

	// Flatten every tool's function definition into input.Functions.
	for _, tool := range input.Tools {
		input.Functions = append(input.Functions, tool.Function)
	}

	if input.ToolsChoice != nil {
		var toolChoice functions.Tool

		// Best effort: decoding errors are deliberately ignored, in which
		// case the function name below stays empty.
		switch content := input.ToolsChoice.(type) {
		case string:
			_ = json.Unmarshal([]byte(content), &toolChoice)
		case map[string]interface{}:
			dat, _ := json.Marshal(content)
			_ = json.Unmarshal(dat, &toolChoice)
		}
		input.FunctionCall = map[string]interface{}{
			"name": toolChoice.Function.Name,
		}
	}

	// Decode each request's message content.
	// The three totals run across all messages; the per-message counters are
	// reset for every message.
	imgIndex, vidIndex, audioIndex := 0, 0, 0
	for i, m := range input.Messages {
		nrOfImgsInMessage := 0
		nrOfVideosInMessage := 0
		nrOfAudiosInMessage := 0

		switch content := m.Content.(type) {
		case string:
			input.Messages[i].StringContent = content
		case []interface{}:
			// Round-trip through JSON to turn the generic parts into typed
			// schema.Content values.
			dat, _ := json.Marshal(content)
			c := []schema.Content{}
			if err := json.Unmarshal(dat, &c); err != nil {
				// Keep going with whatever was decoded, but do not hide the
				// failure.
				log.Error().Err(err).Msg("Failed decoding message content parts")
			}

			textContent := ""
			// we will template this at the end

		CONTENT:
			for _, pp := range c {
				switch pp.Type {
				case "text":
					textContent += pp.Text
				case "video", "video_url":
					// Decode content as base64 either if it's an URL or base64 text
					encoded, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding video: %s", err)
						continue CONTENT
					}
					input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, encoded) // TODO: make sure that we only return base64 stuff
					vidIndex++
					nrOfVideosInMessage++
				case "audio_url", "audio":
					// Decode content as base64 either if it's an URL or base64 text
					encoded, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding audio: %s", err)
						continue CONTENT
					}
					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, encoded) // TODO: make sure that we only return base64 stuff
					audioIndex++
					nrOfAudiosInMessage++
				case "input_audio":
					// TODO: make sure that we only return base64 stuff
					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
					audioIndex++
					nrOfAudiosInMessage++
				case "image_url", "image":
					// Decode content as base64 either if it's an URL or base64 text
					encoded, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding image: %s", err)
						continue CONTENT
					}

					input.Messages[i].StringImages = append(input.Messages[i].StringImages, encoded) // TODO: make sure that we only return base64 stuff

					imgIndex++
					nrOfImgsInMessage++
				}
			}

			// The templating error is ignored on purpose here; the returned
			// string is used as the message content either way.
			input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
				TotalImages:     imgIndex,
				TotalVideos:     vidIndex,
				TotalAudios:     audioIndex,
				ImagesInMessage: nrOfImgsInMessage,
				VideosInMessage: nrOfVideosInMessage,
				AudiosInMessage: nrOfAudiosInMessage,
			}, textContent)
		}
	}

	if input.RepeatPenalty != 0 {
		config.RepeatPenalty = input.RepeatPenalty
	}

	if input.FrequencyPenalty != 0 {
		config.FrequencyPenalty = input.FrequencyPenalty
	}

	if input.PresencePenalty != 0 {
		config.PresencePenalty = input.PresencePenalty
	}

	if input.Keep != 0 {
		config.Keep = input.Keep
	}

	if input.Batch != 0 {
		config.Batch = input.Batch
	}

	if input.IgnoreEOS {
		config.IgnoreEOS = input.IgnoreEOS
	}

	if input.Seed != nil {
		config.Seed = input.Seed
	}

	if input.TypicalP != nil {
		config.TypicalP = input.TypicalP
	}

	// A zerolog event is only written once Msg/Send is called on it.
	log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input)).Msg("request input")

	// Input is either a single string or a list whose entries are strings or
	// nested lists mixing token numbers and strings.
	switch inputs := input.Input.(type) {
	case string:
		if inputs != "" {
			config.InputStrings = append(config.InputStrings, inputs)
		}
	case []any:
		for _, pp := range inputs {
			switch i := pp.(type) {
			case string:
				config.InputStrings = append(config.InputStrings, i)
			case []any:
				tokens := []int{}
				inputStrings := []string{}
				for _, ii := range i {
					switch ii := ii.(type) {
					case int:
						tokens = append(tokens, ii)
					case float64:
						// Numbers decoded from JSON into an interface value
						// arrive as float64.
						tokens = append(tokens, int(ii))
					case string:
						inputStrings = append(inputStrings, ii)
					default:
						log.Error().Msgf("Unknown input type: %T", ii)
					}
				}
				config.InputToken = append(config.InputToken, tokens)
				config.InputStrings = append(config.InputStrings, inputStrings...)
			}
		}
	}

	// Can be either a string or an object
	switch fnc := input.FunctionCall.(type) {
	case string:
		if fnc != "" {
			config.SetFunctionCallString(fnc)
		}
	case map[string]interface{}:
		var name string
		if n, exists := fnc["name"]; exists {
			if nn, isString := n.(string); isString {
				name = nn
			}
		}
		config.SetFunctionCallNameString(name)
	}

	// Prompt is either a single string or a list; non-string list entries
	// are skipped.
	switch p := input.Prompt.(type) {
	case string:
		config.PromptStrings = append(config.PromptStrings, p)
	case []interface{}:
		for _, pp := range p {
			if s, ok := pp.(string); ok {
				config.PromptStrings = append(config.PromptStrings, s)
			}
		}
	}

	// If a quality was defined as number, convert it to step
	if input.Quality != "" {
		if q, err := strconv.Atoi(input.Quality); err == nil {
			config.Step = q
		}
	}

	if config.Validate() {
		return nil
	}
	return fmt.Errorf("unable to validate configuration after merging")
}
package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
// correlationIDKeyType is the dedicated type of the context key declared
// below; using a named type keeps the key distinct from plain string keys.
type correlationIDKeyType string

// CorrelationIDKey is the context key under which the request correlation ID
// is stored, to track a request across process boundaries.
const CorrelationIDKey correlationIDKeyType = "correlationID"
// RequestExtractor bundles the loaders and application configuration that the
// request middleware methods of this package operate on.
type RequestExtractor struct {
	// modelConfigLoader loads model configurations by name and is handed to
	// the services model lookups.
	modelConfigLoader *config.ModelConfigLoader
	// modelLoader is handed to the services model lookups together with
	// modelConfigLoader.
	modelLoader *model.ModelLoader
	// applicationConfig provides the parent Context for per-request contexts
	// and is passed to the model configuration loader.
	applicationConfig *config.ApplicationConfig
}
// NewRequestExtractor returns a RequestExtractor wired to the given model
// configuration loader, model loader and application configuration.
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
	extractor := new(RequestExtractor)
	extractor.modelConfigLoader = modelConfigLoader
	extractor.modelLoader = modelLoader
	extractor.applicationConfig = applicationConfig
	return extractor
}
// CONTEXT_LOCALS_KEY_MODEL_NAME is the Echo context key holding the model name
// (a string) resolved for the current request.
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"

// CONTEXT_LOCALS_KEY_LOCALAI_REQUEST is the Echo context key holding the
// request object bound from the body by SetModelAndConfig.
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"

// CONTEXT_LOCALS_KEY_MODEL_CONFIG is the Echo context key holding the model
// configuration loaded for the current request.
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// setModelNameFromRequest resolves the model name for the current request and
// stores it in the Echo context under CONTEXT_LOCALS_KEY_MODEL_NAME.
//
// A non-empty name already stored in the context wins. Otherwise the name
// comes from the "model" path parameter, then the "model" query parameter,
// then the token of a "Bearer " Authorization header when that token names a
// model known to services.CheckIfModelExists. An empty string is stored when
// nothing matches.
//
// TODO: Refactor to not return error if unchanged
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
	if existing, isString := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string); isString && existing != "" {
		return
	}

	name := c.Param("model")
	if name == "" {
		name = c.QueryParam("model")
	}

	if name == "" {
		// Set model from bearer token, if available
		header := c.Request().Header.Get("Authorization")
		token := strings.TrimPrefix(header, "Bearer ")
		// TrimPrefix returns its input unchanged when the prefix is absent,
		// so a differing result means the header carried the Bearer scheme.
		hadBearerScheme := token != header
		if token != "" && hadBearerScheme {
			found, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, token, services.ALWAYS_INCLUDE)
			if err == nil && found {
				name = token
			}
		}
	}

	c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, name)
}
// BuildConstantDefaultModelNameMiddleware returns Echo middleware that
// resolves the model name for the request and, when none could be determined,
// falls back to defaultModelName before calling the next handler.
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			re.setModelNameFromRequest(c)

			// A usable name was found in the request: nothing to default.
			if resolved, isString := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string); isString && resolved != "" {
				return next(c)
			}

			c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
			log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
			return next(c)
		}
	}
}
// BuildFilteredFirstAvailableDefaultModel returns Echo middleware that
// resolves the model name for the request and, when none could be determined,
// defaults it to the first model returned by services.ListModels for filterFn.
// Listing failures and an empty model list are logged and otherwise ignored;
// the next handler is always called.
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			re.setModelNameFromRequest(c)

			// Don't overwrite existing values.
			if current := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string); current != "" {
				return next(c)
			}

			candidates, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
			switch {
			case err != nil:
				log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
			case len(candidates) == 0:
				// This is non-fatal - making it fatal was breaking the case
				// of direct installation of raw models.
				log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
			default:
				c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, candidates[0])
				log.Debug().Str("first model name", candidates[0]).Msg("context local model name not found, setting to the first model")
			}
			return next(c)
		}
	}
}
// SetModelAndConfig returns Echo middleware that binds the request body into
// a fresh value produced by initializer, fills in a missing model name from
// the value stored earlier in the middleware chain, loads the matching model
// configuration, and stores both the request and the configuration in the
// Echo context before calling the next handler.
//
// The middleware answers with HTTP 400 when initializer yields nil or the body
// cannot be bound. A failure to load the model configuration is only logged;
// the configuration value returned by the loader is stored as-is in that case.
//
// TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			input := initializer()
			if input == nil {
				return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
			}
			if err := c.Bind(input); err != nil {
				return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
			}

			// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
			if input.ModelName(nil) == "" {
				localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
				if ok && localModelName != "" {
					log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
					input.ModelName(&localModelName)
				}
			}

			cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
			switch {
			case err != nil:
				// Attach the error to the event that is actually sent: a bare
				// log.Err(err) without Msg/Send never writes anything.
				log.Warn().Err(err).Msgf("Model Configuration File not found for %q", input.ModelName(nil))
			case cfg != nil && cfg.Model == "" && input.ModelName(nil) != "":
				log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
				cfg.Model = input.ModelName(nil)
			}

			c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
			c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)

			return next(c)
		}
	}
}
// SetOpenAIRequest finalises an OpenAI-style request that SetModelAndConfig
// stored in the Echo context: it assigns a correlation ID, attaches a
// cancellable context to the request, merges the request values into the
// model configuration and stores both back in the Echo context.
//
// It returns echo.ErrBadRequest when the stored request is missing, is not a
// *schema.OpenAIRequest, has no model name, or when no model configuration is
// available. It returns the merge error when the merged configuration does
// not validate, and nil otherwise; it does not invoke a next handler itself.
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
	input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
	if !ok || input == nil || input.Model == "" {
		return echo.ErrBadRequest
	}

	cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
	if !ok || cfg == nil {
		return echo.ErrBadRequest
	}

	// Extract or generate the correlation ID
	correlationID := c.Request().Header.Get("X-Correlation-ID")
	if correlationID == "" {
		correlationID = uuid.New().String()
	}
	c.Response().Header().Set("X-Correlation-ID", correlationID)

	// The new context is a child of the application context, so it is
	// cancelled on application shutdown; cancellation of the HTTP request
	// context is forwarded to it by the goroutine below.
	reqCtx := c.Request().Context()
	c1, cancel := context.WithCancel(re.applicationConfig.Context)

	// Cancel when request context is cancelled (client disconnects)
	go func() {
		select {
		case <-reqCtx.Done():
			cancel()
		case <-c1.Done():
			// Already cancelled
		}
	}()

	// Add the correlation ID to the new context
	input.Context = context.WithValue(c1, CorrelationIDKey, correlationID)
	input.Cancel = cancel

	if err := mergeOpenAIRequestAndModelConfig(cfg, input); err != nil {
		// The request is rejected here: release the context and the watcher
		// goroutine right away instead of waiting for the request context.
		cancel()
		return err
	}

	if cfg.Model == "" {
		log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
		cfg.Model = input.Model
	}

	c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
	c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)

	return nil
}
// mergeOpenAIRequestAndModelConfig overlays the values carried by input onto
// config and normalises parts of input in place.
//
// Request fields that are set (non-zero / non-nil) override the matching
// config fields; stop words, prompts and embedding inputs are appended to the
// lists already present in config. On input, tools are flattened into
// Functions, a tool choice is translated into FunctionCall, and every message
// gets its StringContent / StringImages / StringVideos / StringAudios
// populated from its Content.
//
// It returns an error when config.Validate() rejects the merged configuration.
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
	if input.Echo {
		config.Echo = input.Echo
	}
	if input.TopK != nil {
		config.TopK = input.TopK
	}
	if input.TopP != nil {
		config.TopP = input.TopP
	}

	if input.Backend != "" {
		config.Backend = input.Backend
	}

	if input.ClipSkip != 0 {
		config.Diffusers.ClipSkip = input.ClipSkip
	}

	if input.NegativePromptScale != 0 {
		config.NegativePromptScale = input.NegativePromptScale
	}

	if input.NegativePrompt != "" {
		config.NegativePrompt = input.NegativePrompt
	}

	if input.RopeFreqBase != 0 {
		config.RopeFreqBase = input.RopeFreqBase
	}

	if input.RopeFreqScale != 0 {
		config.RopeFreqScale = input.RopeFreqScale
	}

	if input.Grammar != "" {
		config.Grammar = input.Grammar
	}

	if input.Temperature != nil {
		config.Temperature = input.Temperature
	}

	if input.Maxtokens != nil {
		config.Maxtokens = input.Maxtokens
	}

	// The response format arrives either as a plain string or as an object.
	if input.ResponseFormat != nil {
		switch responseFormat := input.ResponseFormat.(type) {
		case string:
			config.ResponseFormat = responseFormat
		case map[string]interface{}:
			config.ResponseFormatMap = responseFormat
		}
	}

	// Stop words arrive either as a single string or as a list; non-string
	// list entries are skipped.
	switch stop := input.Stop.(type) {
	case string:
		if stop != "" {
			config.StopWords = append(config.StopWords, stop)
		}
	case []interface{}:
		for _, pp := range stop {
			if s, ok := pp.(string); ok {
				config.StopWords = append(config.StopWords, s)
			}
		}
	}

	// Flatten every tool's function definition into input.Functions.
	for _, tool := range input.Tools {
		input.Functions = append(input.Functions, tool.Function)
	}

	if input.ToolsChoice != nil {
		var toolChoice functions.Tool

		// Best effort: decoding errors are deliberately ignored, in which
		// case the function name below stays empty.
		switch content := input.ToolsChoice.(type) {
		case string:
			_ = json.Unmarshal([]byte(content), &toolChoice)
		case map[string]interface{}:
			dat, _ := json.Marshal(content)
			_ = json.Unmarshal(dat, &toolChoice)
		}
		input.FunctionCall = map[string]interface{}{
			"name": toolChoice.Function.Name,
		}
	}

	// Decode each request's message content.
	// The three totals run across all messages; the per-message counters are
	// reset for every message.
	imgIndex, vidIndex, audioIndex := 0, 0, 0
	for i, m := range input.Messages {
		nrOfImgsInMessage := 0
		nrOfVideosInMessage := 0
		nrOfAudiosInMessage := 0

		switch content := m.Content.(type) {
		case string:
			input.Messages[i].StringContent = content
		case []interface{}:
			// Round-trip through JSON to turn the generic parts into typed
			// schema.Content values.
			dat, _ := json.Marshal(content)
			c := []schema.Content{}
			if err := json.Unmarshal(dat, &c); err != nil {
				// Keep going with whatever was decoded, but do not hide the
				// failure.
				log.Error().Err(err).Msg("Failed decoding message content parts")
			}

			textContent := ""
			// we will template this at the end

		CONTENT:
			for _, pp := range c {
				switch pp.Type {
				case "text":
					textContent += pp.Text
				case "video", "video_url":
					// Decode content as base64 either if it's an URL or base64 text
					encoded, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding video: %s", err)
						continue CONTENT
					}
					input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, encoded) // TODO: make sure that we only return base64 stuff
					vidIndex++
					nrOfVideosInMessage++
				case "audio_url", "audio":
					// Decode content as base64 either if it's an URL or base64 text
					encoded, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding audio: %s", err)
						continue CONTENT
					}
					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, encoded) // TODO: make sure that we only return base64 stuff
					audioIndex++
					nrOfAudiosInMessage++
				case "input_audio":
					// TODO: make sure that we only return base64 stuff
					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
					audioIndex++
					nrOfAudiosInMessage++
				case "image_url", "image":
					// Decode content as base64 either if it's an URL or base64 text
					encoded, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding image: %s", err)
						continue CONTENT
					}

					input.Messages[i].StringImages = append(input.Messages[i].StringImages, encoded) // TODO: make sure that we only return base64 stuff

					imgIndex++
					nrOfImgsInMessage++
				}
			}

			// The templating error is ignored on purpose here; the returned
			// string is used as the message content either way.
			input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
				TotalImages:     imgIndex,
				TotalVideos:     vidIndex,
				TotalAudios:     audioIndex,
				ImagesInMessage: nrOfImgsInMessage,
				VideosInMessage: nrOfVideosInMessage,
				AudiosInMessage: nrOfAudiosInMessage,
			}, textContent)
		}
	}

	if input.RepeatPenalty != 0 {
		config.RepeatPenalty = input.RepeatPenalty
	}

	if input.FrequencyPenalty != 0 {
		config.FrequencyPenalty = input.FrequencyPenalty
	}

	if input.PresencePenalty != 0 {
		config.PresencePenalty = input.PresencePenalty
	}

	if input.Keep != 0 {
		config.Keep = input.Keep
	}

	if input.Batch != 0 {
		config.Batch = input.Batch
	}

	if input.IgnoreEOS {
		config.IgnoreEOS = input.IgnoreEOS
	}

	if input.Seed != nil {
		config.Seed = input.Seed
	}

	if input.TypicalP != nil {
		config.TypicalP = input.TypicalP
	}

	// A zerolog event is only written once Msg/Send is called on it.
	log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input)).Msg("request input")

	// Input is either a single string or a list whose entries are strings or
	// nested lists mixing token numbers and strings.
	switch inputs := input.Input.(type) {
	case string:
		if inputs != "" {
			config.InputStrings = append(config.InputStrings, inputs)
		}
	case []any:
		for _, pp := range inputs {
			switch i := pp.(type) {
			case string:
				config.InputStrings = append(config.InputStrings, i)
			case []any:
				tokens := []int{}
				inputStrings := []string{}
				for _, ii := range i {
					switch ii := ii.(type) {
					case int:
						tokens = append(tokens, ii)
					case float64:
						// Numbers decoded from JSON into an interface value
						// arrive as float64.
						tokens = append(tokens, int(ii))
					case string:
						inputStrings = append(inputStrings, ii)
					default:
						log.Error().Msgf("Unknown input type: %T", ii)
					}
				}
				config.InputToken = append(config.InputToken, tokens)
				config.InputStrings = append(config.InputStrings, inputStrings...)
			}
		}
	}

	// Can be either a string or an object
	switch fnc := input.FunctionCall.(type) {
	case string:
		if fnc != "" {
			config.SetFunctionCallString(fnc)
		}
	case map[string]interface{}:
		var name string
		if n, exists := fnc["name"]; exists {
			if nn, isString := n.(string); isString {
				name = nn
			}
		}
		config.SetFunctionCallNameString(name)
	}

	// Prompt is either a single string or a list; non-string list entries
	// are skipped.
	switch p := input.Prompt.(type) {
	case string:
		config.PromptStrings = append(config.PromptStrings, p)
	case []interface{}:
		for _, pp := range p {
			if s, ok := pp.(string); ok {
				config.PromptStrings = append(config.PromptStrings, s)
			}
		}
	}

	// If a quality was defined as number, convert it to step
	if input.Quality != "" {
		if q, err := strconv.Atoi(input.Quality); err == nil {
			config.Step = q
		}
	}

	if config.Validate() {
		return nil
	}
	return fmt.Errorf("unable to validate configuration after merging")
}

View File

@@ -3,34 +3,55 @@ package middleware
import (
"strings"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
)
// StripPathPrefix returns a middleware that strips a path prefix from the request path.
// StripPathPrefix returns middleware that strips a path prefix from the request path.
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
func StripPathPrefix() fiber.Handler {
return func(c *fiber.Ctx) error {
for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
if prefix != "" {
path := c.Path()
pos := len(prefix)
// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing.
func StripPathPrefix() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
prefixes := c.Request().Header.Values("X-Forwarded-Prefix")
originalPath := c.Request().URL.Path
if prefix[pos-1] == '/' {
pos--
} else {
prefix += "/"
}
for _, prefix := range prefixes {
if prefix != "" {
normalizedPrefix := prefix
if !strings.HasSuffix(prefix, "/") {
normalizedPrefix = prefix + "/"
}
if strings.HasPrefix(path, prefix) {
c.Path(path[pos:])
break
} else if prefix[:pos] == path {
c.Redirect(prefix)
return nil
if strings.HasPrefix(originalPath, normalizedPrefix) {
// Update the request path by stripping the normalized prefix
newPath := originalPath[len(normalizedPrefix):]
if newPath == "" {
newPath = "/"
}
// Ensure path starts with / for proper routing
if !strings.HasPrefix(newPath, "/") {
newPath = "/" + newPath
}
// Update the URL path - Echo's router uses URL.Path for routing
c.Request().URL.Path = newPath
c.Request().URL.RawPath = ""
// Update RequestURI to match the new path (needed for proper routing)
if c.Request().URL.RawQuery != "" {
c.Request().RequestURI = newPath + "?" + c.Request().URL.RawQuery
} else {
c.Request().RequestURI = newPath
}
// Store original path for BaseURL utility
c.Set("_original_path", originalPath)
break
} else if originalPath == prefix || originalPath == prefix+"/" {
// Redirect to prefix with trailing slash (use 302 to match test expectations)
return c.Redirect(302, normalizedPrefix)
}
}
}
}
return c.Next()
return next(c)
}
}
}

View File

@@ -2,120 +2,133 @@ package middleware
import (
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestStripPathPrefix(t *testing.T) {
var _ = Describe("StripPathPrefix", func() {
var app *echo.Echo
var actualPath string
var appInitialized bool
app := fiber.New()
BeforeEach(func() {
actualPath = ""
if !appInitialized {
app = echo.New()
app.Pre(StripPathPrefix())
app.Use(StripPathPrefix())
app.GET("/hello/world", func(c echo.Context) error {
actualPath = c.Request().URL.Path
return nil
})
app.Get("/hello/world", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
app.GET("/", func(c echo.Context) error {
actualPath = c.Request().URL.Path
return nil
})
appInitialized = true
}
})
app.Get("/", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})
Context("without prefix", func() {
It("should not modify path when no header is present", func() {
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
for _, tc := range []struct {
name string
path string
prefixHeader []string
expectStatus int
expectPath string
}{
{
name: "without prefix and header",
path: "/hello/world",
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "without prefix and headers on root path",
path: "/",
expectStatus: 200,
expectPath: "/",
},
{
name: "without prefix but header",
path: "/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix but non-matching header",
path: "/prefix/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 404,
},
{
name: "with prefix and matching header",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 1st header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 2nd header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and header not ending with slash",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and non-matching header not ending with slash",
path: "/myprefix-suffix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 404,
},
{
name: "redirect when prefix does not end with a slash",
path: "/myprefix",
prefixHeader: []string{"/myprefix"},
expectStatus: 302,
expectPath: "/myprefix/",
},
} {
t.Run(tc.name, func(t *testing.T) {
actualPath = ""
req := httptest.NewRequest("GET", tc.path, nil)
if tc.prefixHeader != nil {
req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
}
resp, err := app.Test(req, -1)
require.NoError(t, err)
require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")
if tc.expectStatus == 200 {
require.Equal(t, tc.expectPath, actualPath, "rewritten path")
} else if tc.expectStatus == 302 {
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
}
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
}
}
It("should not modify root path when no header is present", func() {
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/"), "rewritten path")
})
It("should not modify path when header does not match", func() {
req := httptest.NewRequest("GET", "/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
})
Context("with prefix", func() {
It("should return 404 when prefix does not match header", func() {
req := httptest.NewRequest("GET", "/prefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(404), "response status code")
})
It("should strip matching prefix from path", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when it matches the first header value", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when it matches the second header value", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when header does not end with slash", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should return 404 when prefix does not match header without trailing slash", func() {
req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(404), "response status code")
})
It("should redirect when prefix does not end with a slash", func() {
req := httptest.NewRequest("GET", "/myprefix", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(302), "response status code")
Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location")
})
})
})

View File

@@ -17,7 +17,7 @@ import (
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"fmt"
. "github.com/mudler/LocalAI/core/http"
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -62,7 +62,7 @@ func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResp
var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
var tmpdir string
var appServer *application.Application
var app *fiber.App
var app *echo.Echo
var ctx context.Context
var cancel context.CancelFunc
@@ -97,7 +97,9 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
AfterEach(func() {
cancel()
if app != nil {
_ = app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = app.Shutdown(ctx)
}
_ = os.RemoveAll(tmpdir)
})
@@ -106,7 +108,11 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
var err error
app, err = API(appServer)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9091")
go func() {
if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
// Log error if needed
}
}()
// wait for server
client := &http.Client{Timeout: 5 * time.Second}

View File

@@ -4,13 +4,15 @@ import (
"embed"
"fmt"
"html/template"
"io"
"io/fs"
"net/http"
"strings"
"github.com/Masterminds/sprig/v3"
"github.com/gofiber/fiber/v2"
fiberhtml "github.com/gofiber/template/html/v2"
"github.com/labstack/echo/v4"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/russross/blackfriday"
)
@@ -18,26 +20,67 @@ import (
//go:embed views/*
var viewsfs embed.FS
func notFoundHandler(c *fiber.Ctx) error {
// TemplateRenderer is a custom template renderer for Echo. It wraps a parsed
// html/template set; its Render method executes one of those templates by name.
type TemplateRenderer struct {
// templates is the template set that Render looks names up in.
templates *template.Template
}
// Render executes the template registered under name with data as its input
// and writes the result to w. It returns whatever error ExecuteTemplate
// reports. The echo.Context argument c is not used here.
func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
return t.templates.ExecuteTemplate(w, name, data)
}
func notFoundHandler(c echo.Context) error {
// Check if the request accepts JSON
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") {
// The client expects a JSON response
return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound},
return c.JSON(http.StatusNotFound, schema.ErrorResponse{
Error: &schema.APIError{Message: "Resource not found", Code: http.StatusNotFound},
})
} else {
// The client expects an HTML response
return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{
"BaseURL": utils.BaseURL(c),
return c.Render(http.StatusNotFound, "views/404", map[string]interface{}{
"BaseURL": middleware.BaseURL(c),
})
}
}
func renderEngine() *fiberhtml.Engine {
engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html")
engine.AddFuncMap(sprig.FuncMap())
engine.AddFunc("MDToHTML", markDowner)
return engine
// renderEngine builds the TemplateRenderer backed by the embedded views.
//
// Every file ending in ".html" under "views" in viewsfs is parsed into one
// shared template set and registered under its path without the extension
// (e.g. "views/index.html" -> "views/index"). The sprig function map and
// MDToHTML are registered on the set before any template is parsed.
//
// Loading is best-effort: a file that cannot be read or parsed is reported
// and skipped, so the remaining views are still loaded. The function never
// fails; it always returns a renderer.
func renderEngine() *TemplateRenderer {
	tmpl := template.New("").Funcs(sprig.FuncMap())
	tmpl = tmpl.Funcs(template.FuncMap{
		"MDToHTML": markDowner,
	})

	err := fs.WalkDir(viewsfs, "views", func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		// Only regular *.html entries are templates.
		if d.IsDir() || !strings.HasSuffix(path, ".html") {
			return nil
		}

		data, readErr := viewsfs.ReadFile(path)
		if readErr != nil {
			// Report instead of silently dropping the view; keep walking.
			fmt.Printf("Error reading template %s: %v\n", path, readErr)
			return nil
		}

		// Remove .html extension to get template name (e.g., "views/index.html" -> "views/index")
		templateName := strings.TrimSuffix(path, ".html")
		content := string(data)
		if _, parseErr := tmpl.New(templateName).Parse(content); parseErr != nil {
			// Retry on the root template, as before.
			// NOTE(review): parsing should not depend on the template name, so
			// this retry is expected to fail the same way — confirm and remove.
			if _, retryErr := tmpl.Parse(content); retryErr != nil {
				fmt.Printf("Error parsing template %s: %v\n", path, parseErr)
			}
		}
		return nil
	})
	if err != nil {
		// Log error but continue - templates might still work
		fmt.Printf("Error walking views directory: %v\n", err)
	}

	return &TemplateRenderer{
		templates: tmpl,
	}
}
func markDowner(args ...interface{}) template.HTML {

View File

@@ -1,7 +1,7 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -9,21 +9,23 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterElevenLabsRoutes(app *fiber.App,
func RegisterElevenLabsRoutes(app *echo.Echo,
re *middleware.RequestExtractor,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id",
ttsHandler := elevenlabs.TTSEndpoint(cl, ml, appConfig)
app.POST("/v1/text-to-speech/:voice-id",
ttsHandler,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
elevenlabs.TTSEndpoint(cl, ml, appConfig))
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }))
app.Post("/v1/sound-generation",
soundGenHandler := elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)
app.POST("/v1/sound-generation",
soundGenHandler,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }))
}

View File

@@ -1,13 +1,13 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
coreExplorer "github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/endpoints/explorer"
)
func RegisterExplorerRoutes(app *fiber.App, db *coreExplorer.Database) {
app.Get("/", explorer.Dashboard())
app.Post("/network/add", explorer.AddNetwork(db))
app.Get("/networks", explorer.ShowNetworks(db))
// RegisterExplorerRoutes wires the explorer handlers onto app:
//
//	GET  /             explorer.Dashboard
//	POST /network/add  explorer.AddNetwork
//	GET  /networks     explorer.ShowNetworks
//
// db is passed to the AddNetwork and ShowNetworks handlers.
func RegisterExplorerRoutes(app *echo.Echo, db *coreExplorer.Database) {
	dashboard := explorer.Dashboard()
	app.GET("/", dashboard)

	addNetwork := explorer.AddNetwork(db)
	app.POST("/network/add", addNetwork)

	showNetworks := explorer.ShowNetworks(db)
	app.GET("/networks", showNetworks)
}

View File

@@ -1,13 +1,15 @@
package routes
import "github.com/gofiber/fiber/v2"
import (
"github.com/labstack/echo/v4"
)
func HealthRoutes(app *fiber.App) {
func HealthRoutes(app *echo.Echo) {
// Service health checks
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
ok := func(c echo.Context) error {
return c.NoContent(200)
}
app.Get("/healthz", ok)
app.Get("/readyz", ok)
app.GET("/healthz", ok)
app.GET("/readyz", ok)
}

View File

@@ -1,24 +1,25 @@
package routes
import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/jina"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterJINARoutes(app *fiber.App,
func RegisterJINARoutes(app *echo.Echo,
re *middleware.RequestExtractor,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
// POST endpoint to mimic the reranking
app.Post("/v1/rerank",
rerankHandler := jina.JINARerankEndpoint(cl, ml, appConfig)
app.POST("/v1/rerank",
rerankHandler,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
jina.JINARerankEndpoint(cl, ml, appConfig))
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }))
}

View File

@@ -1,19 +1,18 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/swagger"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
echoswagger "github.com/swaggo/echo-swagger"
)
func RegisterLocalAIRoutes(router *fiber.App,
func RegisterLocalAIRoutes(router *echo.Echo,
requestExtractor *middleware.RequestExtractor,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
@@ -21,111 +20,117 @@ func RegisterLocalAIRoutes(router *fiber.App,
galleryService *services.GalleryService,
opcache *services.OpCache) {
router.Get("/swagger/*", swagger.HandlerDefault) // default
router.GET("/swagger/*", echoswagger.WrapHandler) // default
// LocalAI API endpoints
if !appConfig.DisableGalleryEndpoint {
// Import model page
router.Get("/import-model", func(c *fiber.Ctx) error {
return c.Render("views/model-editor", fiber.Map{
router.GET("/import-model", func(c echo.Context) error {
return c.Render(200, "views/model-editor", map[string]interface{}{
"Title": "LocalAI - Import Model",
"BaseURL": httpUtils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Version": internal.PrintableVersion(),
})
})
// Edit model page
router.Get("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
router.GET("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
router.POST("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
router.POST("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
router.GET("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
router.GET("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
router.GET("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
router.GET("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
backendGalleryEndpointService := localai.CreateBackendEndpointService(
appConfig.BackendGalleries,
appConfig.SystemState,
galleryService)
router.Post("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
router.POST("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
router.POST("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
router.GET("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
router.GET("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
router.GET("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
router.GET("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
// Custom model import endpoint
router.Post("/models/import", localai.ImportModelEndpoint(cl, appConfig))
router.POST("/models/import", localai.ImportModelEndpoint(cl, appConfig))
// URI model import endpoint
router.Post("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
router.POST("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
// Custom model edit endpoint
router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
router.POST("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
// Reload models endpoint
router.Post("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
router.POST("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
}
router.Post("/v1/detection",
detectionHandler := localai.DetectionEndpoint(cl, ml, appConfig)
router.POST("/v1/detection",
detectionHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
localai.DetectionEndpoint(cl, ml, appConfig))
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }))
router.Post("/tts",
ttsHandler := localai.TTSEndpoint(cl, ml, appConfig)
router.POST("/tts",
ttsHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(cl, ml, appConfig))
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }))
vadChain := []fiber.Handler{
vadHandler := localai.VADEndpoint(cl, ml, appConfig)
router.POST("/vad",
vadHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
localai.VADEndpoint(cl, ml, appConfig),
}
router.Post("/vad", vadChain...)
router.Post("/v1/vad", vadChain...)
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
router.POST("/v1/vad",
vadHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
// Stores
router.Post("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
router.Post("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
router.Post("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
router.Post("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
router.POST("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
router.POST("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
router.POST("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
router.POST("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
if !appConfig.DisableMetrics {
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
router.GET("/metrics", localai.LocalAIMetricsEndpoint())
}
router.Post("/video",
videoHandler := localai.VideoEndpoint(cl, ml, appConfig)
router.POST("/video",
videoHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }),
localai.VideoEndpoint(cl, ml, appConfig))
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }))
// Backend Statistics Module
// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// p2p
router.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
router.GET("/api/p2p", localai.ShowP2PNodes(appConfig))
router.GET("/api/p2p/token", localai.ShowP2PToken(appConfig))
router.Get("/version", func(c *fiber.Ctx) error {
return c.JSON(struct {
router.GET("/version", func(c echo.Context) error {
return c.JSON(200, struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
router.Get("/system", localai.SystemInformations(ml, appConfig))
router.GET("/system", localai.SystemInformations(ml, appConfig))
// misc
router.Post("/v1/tokenize",
tokenizeHandler := localai.TokenizeEndpoint(cl, ml, appConfig)
router.POST("/v1/tokenize",
tokenizeHandler,
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
localai.TokenizeEndpoint(cl, ml, appConfig))
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }))
}

View File

@@ -1,7 +1,7 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
@@ -10,118 +10,172 @@ import (
"github.com/mudler/LocalAI/core/schema"
)
func RegisterOpenAIRoutes(app *fiber.App,
func RegisterOpenAIRoutes(app *echo.Echo,
re *middleware.RequestExtractor,
application *application.Application) {
// openAI compatible API endpoint
// realtime
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
app.Get("/v1/realtime", openai.Realtime(application))
app.Post("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
app.Post("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
app.GET("/v1/realtime", openai.Realtime(application))
app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
// chat
chatChain := []fiber.Handler{
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
chatMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/chat/completions", chatChain...)
app.Post("/chat/completions", chatChain...)
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
app.POST("/chat/completions", chatHandler, chatMiddleware...)
// edit
editChain := []fiber.Handler{
editHandler := openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
editMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/edits", editChain...)
app.Post("/edits", editChain...)
app.POST("/v1/edits", editHandler, editMiddleware...)
app.POST("/edits", editHandler, editMiddleware...)
// completion
completionChain := []fiber.Handler{
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
completionMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/completions", completionChain...)
app.Post("/completions", completionChain...)
app.Post("/v1/engines/:model/completions", completionChain...)
app.POST("/v1/completions", completionHandler, completionMiddleware...)
app.POST("/completions", completionHandler, completionMiddleware...)
app.POST("/v1/engines/:model/completions", completionHandler, completionMiddleware...)
// MCPcompletion
mcpCompletionChain := []fiber.Handler{
mcpCompletionHandler := openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
mcpCompletionMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/mcp/v1/chat/completions", mcpCompletionChain...)
app.Post("/mcp/chat/completions", mcpCompletionChain...)
app.POST("/mcp/v1/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...)
app.POST("/mcp/chat/completions", mcpCompletionHandler, mcpCompletionMiddleware...)
// embeddings
embeddingChain := []fiber.Handler{
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
embeddingMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/embeddings", embeddingChain...)
app.Post("/embeddings", embeddingChain...)
app.Post("/v1/engines/:model/embeddings", embeddingChain...)
app.POST("/v1/embeddings", embeddingHandler, embeddingMiddleware...)
app.POST("/embeddings", embeddingHandler, embeddingMiddleware...)
app.POST("/v1/engines/:model/embeddings", embeddingHandler, embeddingMiddleware...)
audioChain := []fiber.Handler{
audioHandler := openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
audioMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
// audio
app.Post("/v1/audio/transcriptions", audioChain...)
app.Post("/audio/transcriptions", audioChain...)
app.POST("/v1/audio/transcriptions", audioHandler, audioMiddleware...)
app.POST("/audio/transcriptions", audioHandler, audioMiddleware...)
audioSpeechChain := []fiber.Handler{
audioSpeechHandler := localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
audioSpeechMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
}
app.Post("/v1/audio/speech",
audioSpeechChain...)
app.Post("/audio/speech", audioSpeechChain...)
app.POST("/v1/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
app.POST("/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
// images
imageChain := []fiber.Handler{
imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
imageMiddleware := []echo.MiddlewareFunc{
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
app.Post("/v1/images/generations",
imageChain...)
app.Post("/images/generations", imageChain...)
app.POST("/v1/images/generations", imageHandler, imageMiddleware...)
app.POST("/images/generations", imageHandler, imageMiddleware...)
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
videoChain := []fiber.Handler{
videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
videoMiddleware := []echo.MiddlewareFunc{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()),
func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := re.SetOpenAIRequest(c); err != nil {
return err
}
return next(c)
}
},
}
// OpenAI-style create video endpoint
app.Post("/v1/videos", videoChain...)
app.Post("/v1/videos/generations", videoChain...)
app.Post("/videos", videoChain...)
app.POST("/v1/videos", videoHandler, videoMiddleware...)
app.POST("/v1/videos/generations", videoHandler, videoMiddleware...)
app.POST("/videos", videoHandler, videoMiddleware...)
// List models
app.Get("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
}

View File

@@ -1,18 +1,17 @@
package routes
import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func RegisterUIRoutes(app *fiber.App,
func RegisterUIRoutes(app *echo.Echo,
cl *config.ModelConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
@@ -21,13 +20,13 @@ func RegisterUIRoutes(app *fiber.App,
// keeps the state of ops that are started from the UI
var processingOps = services.NewOpCache(galleryService)
app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
app.GET("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
// P2P
app.Get("/p2p", func(c *fiber.Ctx) error {
summary := fiber.Map{
app.GET("/p2p/", func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI - P2P dashboard",
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Version": internal.PrintableVersion(),
//"Nodes": p2p.GetAvailableNodes(""),
//"FederatedNodes": p2p.GetAvailableNodes(p2p.FederatedID),
@@ -37,7 +36,7 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/p2p", summary)
return c.Render(200, "views/p2p", summary)
})
// Note: P2P UI fragment routes (/p2p/ui/*) were removed
@@ -50,17 +49,17 @@ func RegisterUIRoutes(app *fiber.App,
registerBackendGalleryRoutes(app, appConfig, galleryService, processingOps)
}
app.Get("/talk/", func(c *fiber.Ctx) error {
app.GET("/talk/", func(c echo.Context) error {
modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
if len(modelConfigs) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
return c.Redirect(302, middleware.BaseURL(c))
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": "LocalAI - Talk",
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"Model": modelConfigs[0],
@@ -68,16 +67,16 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/talk", summary)
return c.Render(200, "views/talk", summary)
})
app.Get("/chat/", func(c *fiber.Ctx) error {
app.GET("/chat/", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
return c.Redirect(302, middleware.BaseURL(c))
}
modelThatCanBeUsed := ""
galleryConfigs := map[string]*gallery.ModelConfig{}
@@ -104,9 +103,9 @@ func RegisterUIRoutes(app *fiber.App,
}
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": title,
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsWithoutConfig": modelsWithoutConfig,
"GalleryConfig": galleryConfigs,
"ModelsConfig": modelConfigs,
@@ -116,16 +115,16 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/chat", summary)
return c.Render(200, "views/chat", summary)
})
// Show the Chat page
app.Get("/chat/:model", func(c *fiber.Ctx) error {
app.GET("/chat/:model", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
galleryConfigs := map[string]*gallery.ModelConfig{}
modelName := c.Params("model")
modelName := c.Param("model")
var modelContextSize *int
for _, m := range modelConfigs {
@@ -139,9 +138,9 @@ func RegisterUIRoutes(app *fiber.App,
}
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": "LocalAI - Chat with " + modelName,
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
@@ -151,33 +150,33 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/chat", summary)
return c.Render(200, "views/chat", summary)
})
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
app.GET("/text2image/:model", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + c.Params("model"),
"BaseURL": utils.BaseURL(c),
summary := map[string]interface{}{
"Title": "LocalAI - Generate images with " + c.Param("model"),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": c.Params("model"),
"Model": c.Param("model"),
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/text2image", summary)
return c.Render(200, "views/text2image", summary)
})
app.Get("/text2image/", func(c *fiber.Ctx) error {
app.GET("/text2image/", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
return c.Redirect(302, middleware.BaseURL(c))
}
modelThatCanBeUsed := ""
@@ -191,9 +190,9 @@ func RegisterUIRoutes(app *fiber.App,
}
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": title,
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": modelThatCanBeUsed,
@@ -201,33 +200,33 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/text2image", summary)
return c.Render(200, "views/text2image", summary)
})
app.Get("/tts/:model", func(c *fiber.Ctx) error {
app.GET("/tts/:model", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + c.Params("model"),
"BaseURL": utils.BaseURL(c),
summary := map[string]interface{}{
"Title": "LocalAI - Generate images with " + c.Param("model"),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": c.Params("model"),
"Model": c.Param("model"),
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/tts", summary)
return c.Render(200, "views/tts", summary)
})
app.Get("/tts/", func(c *fiber.Ctx) error {
app.GET("/tts/", func(c echo.Context) error {
modelConfigs := cl.GetAllModelsConfigs()
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
// If no model is available redirect to the index which suggests how to install models
return c.Redirect(utils.BaseURL(c))
return c.Redirect(302, middleware.BaseURL(c))
}
modelThatCanBeUsed := ""
@@ -240,9 +239,9 @@ func RegisterUIRoutes(app *fiber.App,
break
}
}
summary := fiber.Map{
summary := map[string]interface{}{
"Title": title,
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"ModelsConfig": modelConfigs,
"ModelsWithoutConfig": modelsWithoutConfig,
"Model": modelThatCanBeUsed,
@@ -250,6 +249,6 @@ func RegisterUIRoutes(app *fiber.App,
}
// Render index
return c.Render("views/tts", summary)
return c.Render(200, "views/tts", summary)
})
}

View File

@@ -4,13 +4,14 @@ import (
"context"
"fmt"
"math"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/p2p"
@@ -19,13 +20,13 @@ import (
)
// RegisterUIAPIRoutes registers JSON API routes for the web UI
func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
// Operations API - Get all current operations (models + backends)
app.Get("/api/operations", func(c *fiber.Ctx) error {
app.GET("/api/operations", func(c echo.Context) error {
processingData, taskTypes := opcache.GetStatus()
operations := []fiber.Map{}
operations := []map[string]interface{}{}
for galleryID, jobID := range processingData {
taskType := "installation"
if tt, ok := taskTypes[galleryID]; ok {
@@ -88,7 +89,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
}
operations = append(operations, fiber.Map{
operations = append(operations, map[string]interface{}{
"id": galleryID,
"name": displayName,
"fullName": galleryID,
@@ -118,20 +119,20 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
return operations[i]["id"].(string) < operations[j]["id"].(string)
})
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"operations": operations,
})
})
// Cancel operation endpoint
app.Post("/api/operations/:jobID/cancel", func(c *fiber.Ctx) error {
jobID := strings.Clone(c.Params("jobID"))
app.POST("/api/operations/:jobID/cancel", func(c echo.Context) error {
jobID := c.Param("jobID")
log.Debug().Msgf("API request to cancel operation: %s", jobID)
err := galleryService.CancelOperation(jobID)
if err != nil {
log.Error().Err(err).Msgf("Failed to cancel operation: %s", jobID)
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": err.Error(),
})
}
@@ -139,22 +140,28 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
// Clean up opcache for cancelled operation
opcache.DeleteUUID(jobID)
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"success": true,
"message": "Operation cancelled",
})
})
// Model Gallery APIs
app.Get("/api/models", func(c *fiber.Ctx) error {
term := c.Query("term")
page := c.Query("page", "1")
items := c.Query("items", "21")
app.GET("/api/models", func(c echo.Context) error {
term := c.QueryParam("term")
page := c.QueryParam("page")
if page == "" {
page = "1"
}
items := c.QueryParam("items")
if items == "" {
items = "21"
}
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
if err != nil {
log.Error().Err(err).Msg("could not list models from galleries")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -197,7 +204,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
// Convert models to JSON-friendly format and deduplicate by ID
modelsJSON := make([]fiber.Map, 0, len(models))
modelsJSON := make([]map[string]interface{}, 0, len(models))
seenIDs := make(map[string]bool)
for _, m := range models {
@@ -223,7 +230,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
_, trustRemoteCodeExists := m.Overrides["trust_remote_code"]
modelsJSON = append(modelsJSON, fiber.Map{
modelsJSON = append(modelsJSON, map[string]interface{}{
"id": modelID,
"name": m.Name,
"description": m.Description,
@@ -250,7 +257,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
nextPage = totalPages
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"models": modelsJSON,
"repositories": appConfig.Galleries,
"allTags": tags,
@@ -264,12 +271,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
})
app.Post("/api/models/install/:id", func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id"))
app.POST("/api/models/install/:id", func(c echo.Context) error {
galleryID := c.Param("id")
// URL decode the gallery ID (e.g., "localai%40model" -> "localai@model")
galleryID, err := url.QueryUnescape(galleryID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid model ID",
})
}
@@ -277,7 +284,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
id, err := uuid.NewUUID()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -300,18 +307,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
galleryService.ModelGalleryChannel <- op
}()
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"jobID": uid,
"message": "Installation started",
})
})
app.Post("/api/models/delete/:id", func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id"))
app.POST("/api/models/delete/:id", func(c echo.Context) error {
galleryID := c.Param("id")
// URL decode the gallery ID
galleryID, err := url.QueryUnescape(galleryID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid model ID",
})
}
@@ -324,7 +331,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
id, err := uuid.NewUUID()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -350,18 +357,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
cl.RemoveModelConfig(galleryName)
}()
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"jobID": uid,
"message": "Deletion started",
})
})
app.Post("/api/models/config/:id", func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id"))
app.POST("/api/models/config/:id", func(c echo.Context) error {
galleryID := c.Param("id")
// URL decode the gallery ID
galleryID, err := url.QueryUnescape(galleryID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid model ID",
})
}
@@ -369,44 +376,44 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
model := gallery.FindGalleryElement(models, galleryID)
if model == nil {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
return c.JSON(http.StatusNotFound, map[string]interface{}{
"error": "model not found",
})
}
config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](model.URL, appConfig.SystemState.Model.ModelsPath)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
_, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"message": "Configuration file saved",
})
})
app.Get("/api/models/job/:uid", func(c *fiber.Ctx) error {
jobUID := strings.Clone(c.Params("uid"))
app.GET("/api/models/job/:uid", func(c echo.Context) error {
jobUID := c.Param("uid")
status := galleryService.GetStatus(jobUID)
if status == nil {
// Job is queued but hasn't started processing yet
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"progress": 0,
"message": "Operation queued",
"galleryElementName": "",
@@ -416,7 +423,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
response := fiber.Map{
response := map[string]interface{}{
"progress": status.Progress,
"message": status.Message,
"galleryElementName": status.GalleryElementName,
@@ -434,19 +441,25 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
response["completed"] = true
}
return c.JSON(response)
return c.JSON(200, response)
})
// Backend Gallery APIs
app.Get("/api/backends", func(c *fiber.Ctx) error {
term := c.Query("term")
page := c.Query("page", "1")
items := c.Query("items", "21")
app.GET("/api/backends", func(c echo.Context) error {
term := c.QueryParam("term")
page := c.QueryParam("page")
if page == "" {
page = "1"
}
items := c.QueryParam("items")
if items == "" {
items = "21"
}
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
if err != nil {
log.Error().Err(err).Msg("could not list backends from galleries")
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -489,7 +502,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
// Convert backends to JSON-friendly format and deduplicate by ID
backendsJSON := make([]fiber.Map, 0, len(backends))
backendsJSON := make([]map[string]interface{}, 0, len(backends))
seenBackendIDs := make(map[string]bool)
for _, b := range backends {
@@ -513,7 +526,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
}
backendsJSON = append(backendsJSON, fiber.Map{
backendsJSON = append(backendsJSON, map[string]interface{}{
"id": backendID,
"name": b.Name,
"description": b.Description,
@@ -538,7 +551,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
nextPage = totalPages
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"backends": backendsJSON,
"repositories": appConfig.BackendGalleries,
"allTags": tags,
@@ -552,12 +565,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
})
app.Post("/api/backends/install/:id", func(c *fiber.Ctx) error {
backendID := strings.Clone(c.Params("id"))
app.POST("/api/backends/install/:id", func(c echo.Context) error {
backendID := c.Param("id")
// URL decode the backend ID
backendID, err := url.QueryUnescape(backendID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid backend ID",
})
}
@@ -565,7 +578,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
id, err := uuid.NewUUID()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -587,18 +600,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
galleryService.BackendGalleryChannel <- op
}()
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"jobID": uid,
"message": "Backend installation started",
})
})
app.Post("/api/backends/delete/:id", func(c *fiber.Ctx) error {
backendID := strings.Clone(c.Params("id"))
app.POST("/api/backends/delete/:id", func(c echo.Context) error {
backendID := c.Param("id")
// URL decode the backend ID
backendID, err := url.QueryUnescape(backendID)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid backend ID",
})
}
@@ -611,7 +624,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
id, err := uuid.NewUUID()
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
@@ -635,19 +648,19 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
galleryService.BackendGalleryChannel <- op
}()
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"jobID": uid,
"message": "Backend deletion started",
})
})
app.Get("/api/backends/job/:uid", func(c *fiber.Ctx) error {
jobUID := strings.Clone(c.Params("uid"))
app.GET("/api/backends/job/:uid", func(c echo.Context) error {
jobUID := c.Param("uid")
status := galleryService.GetStatus(jobUID)
if status == nil {
// Job is queued but hasn't started processing yet
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"progress": 0,
"message": "Operation queued",
"galleryElementName": "",
@@ -657,7 +670,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
response := fiber.Map{
response := map[string]interface{}{
"progress": status.Progress,
"message": status.Message,
"galleryElementName": status.GalleryElementName,
@@ -675,16 +688,16 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
response["completed"] = true
}
return c.JSON(response)
return c.JSON(200, response)
})
// System Backend Deletion API (for installed backends on index page)
app.Post("/api/backends/system/delete/:name", func(c *fiber.Ctx) error {
backendName := strings.Clone(c.Params("name"))
app.POST("/api/backends/system/delete/:name", func(c echo.Context) error {
backendName := c.Param("name")
// URL decode the backend name
backendName, err := url.QueryUnescape(backendName)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"error": "invalid backend name",
})
}
@@ -693,24 +706,24 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
// Use the gallery package to delete the backend
if err := gallery.DeleteBackendFromSystem(appConfig.SystemState, backendName); err != nil {
log.Error().Err(err).Msgf("Failed to delete backend: %s", backendName)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": err.Error(),
})
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"success": true,
"message": "Backend deleted successfully",
})
})
// P2P APIs
app.Get("/api/p2p/workers", func(c *fiber.Ctx) error {
app.GET("/api/p2p/workers", func(c echo.Context) error {
nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
nodesJSON := make([]fiber.Map, 0, len(nodes))
nodesJSON := make([]map[string]interface{}, 0, len(nodes))
for _, n := range nodes {
nodesJSON = append(nodesJSON, fiber.Map{
nodesJSON = append(nodesJSON, map[string]interface{}{
"name": n.Name,
"id": n.ID,
"tunnelAddress": n.TunnelAddress,
@@ -720,17 +733,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"nodes": nodesJSON,
})
})
app.Get("/api/p2p/federation", func(c *fiber.Ctx) error {
app.GET("/api/p2p/federation", func(c echo.Context) error {
nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
nodesJSON := make([]fiber.Map, 0, len(nodes))
nodesJSON := make([]map[string]interface{}, 0, len(nodes))
for _, n := range nodes {
nodesJSON = append(nodesJSON, fiber.Map{
nodesJSON = append(nodesJSON, map[string]interface{}{
"name": n.Name,
"id": n.ID,
"tunnelAddress": n.TunnelAddress,
@@ -740,12 +753,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
return c.JSON(fiber.Map{
return c.JSON(200, map[string]interface{}{
"nodes": nodesJSON,
})
})
app.Get("/api/p2p/stats", func(c *fiber.Ctx) error {
app.GET("/api/p2p/stats", func(c echo.Context) error {
workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
@@ -763,12 +776,12 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
}
return c.JSON(fiber.Map{
"workers": fiber.Map{
return c.JSON(200, map[string]interface{}{
"workers": map[string]interface{}{
"online": workersOnline,
"total": len(workerNodes),
},
"federated": fiber.Map{
"federated": map[string]interface{}{
"online": federatedOnline,
"total": len(federatedNodes),
},

View File

@@ -1,24 +1,24 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
)
func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
func registerBackendGalleryRoutes(app *echo.Echo, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
// Show the Backends page (all backends are loaded client-side via Alpine.js)
app.Get("/browse/backends", func(c *fiber.Ctx) error {
summary := fiber.Map{
app.GET("/browse/backends", func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI - Backends",
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Version": internal.PrintableVersion(),
"Repositories": appConfig.BackendGalleries,
}
// Render index - backends are now loaded via Alpine.js from /api/backends
return c.Render("views/backends", summary)
return c.Render(200, "views/backends", summary)
})
}

View File

@@ -1,24 +1,24 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
)
func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
func registerGalleryRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
app.Get("/browse", func(c *fiber.Ctx) error {
summary := fiber.Map{
app.GET("/browse/", func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI - Models",
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
"Version": internal.PrintableVersion(),
"Repositories": appConfig.Galleries,
}
// Render index - models are now loaded via Alpine.js from /api/models
return c.Render("views/models", summary)
return c.Render(200, "views/models", summary)
})
}

View File

@@ -1,24 +0,0 @@
package utils
import (
"strings"
"github.com/gofiber/fiber/v2"
)
// BaseURL returns the base URL for the given HTTP request context.
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
// The returned URL is guaranteed to end with `/`.
// The method should be used in conjunction with the StripPathPrefix middleware.
//
// The prefix is recovered by comparing the current request path with the
// original request URI: when the original path ends with the current path,
// whatever precedes it is treated as the externally visible path prefix.
func BaseURL(c *fiber.Ctx) string {
	path := c.Path()
	// The original request URI includes the query string while the path does
	// not; drop the query first so "/prefix/page?x=1" still matches "/page"
	// and the prefix is not lost on requests that carry query parameters.
	origPath, _, _ := strings.Cut(c.OriginalURL(), "?")

	if path != origPath && strings.HasSuffix(origPath, path) {
		// The +1 keeps the leading "/" of path, so the prefix ends with a slash.
		pathPrefix := origPath[:len(origPath)-len(path)+1]

		return c.BaseURL() + pathPrefix
	}

	return c.BaseURL() + "/"
}

View File

@@ -1,48 +0,0 @@
package utils
import (
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
// TestBaseURL checks the value BaseURL reports for a request served from the
// root and for a request whose path was overridden inside the handler so that
// it no longer carries the prefix present in the original URL.
func TestBaseURL(t *testing.T) {
	type testCase struct {
		name      string
		prefix    string
		expectURL string
	}

	cases := []testCase{
		{name: "without prefix", prefix: "/", expectURL: "http://example.com/"},
		{name: "with prefix", prefix: "/myprefix/", expectURL: "http://example.com/myprefix/"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			route := tc.prefix + "hello/world"
			var got string

			app := fiber.New()
			app.Get(route, func(c *fiber.Ctx) error {
				// Override the path for the prefixed case — presumably this
				// mimics what the StripPathPrefix middleware does.
				if tc.prefix != "/" {
					c.Path("/hello/world")
				}
				got = BaseURL(c)
				return nil
			})

			resp, err := app.Test(httptest.NewRequest("GET", route, nil), -1)

			require.NoError(t, err)
			require.Equal(t, 200, resp.StatusCode, "response status code")
			require.Equal(t, tc.expectURL, got, "base URL")
		})
	}
}