mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-05 17:59:44 -05:00
feat: Add Agentic MCP support with a new chat/completion endpoint (#6381)
* WIP - add endpoint Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Rename Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Wire the Completion API Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Try to make it functional Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Almost functional Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Bump golang versions used in tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add description of the tool Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Make it working Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Small optimizations Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Cleanup/refactor Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Update docs Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
6b2c8277c2
commit
60b6472fa0
@@ -799,7 +799,7 @@ var _ = Describe("API test", func() {
|
||||
It("returns errors", func() {
|
||||
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error:"))
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, status: 500 Internal Server Error, message: could not load model - all backends returned error:"))
|
||||
})
|
||||
|
||||
It("shows the external backend", func() {
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/tmc/langchaingo/jsonschema"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func ToolsFromMCPConfig(ctx context.Context, remote config.MCPGenericConfig[config.MCPRemoteServers], stdio config.MCPGenericConfig[config.MCPSTDIOServers]) ([]*MCPTool, error) {
|
||||
allTools := []*MCPTool{}
|
||||
|
||||
// Get the list of all the tools that the Agent will be esposed to
|
||||
for _, server := range remote.Servers {
|
||||
|
||||
// Create HTTP client with custom roundtripper for bearer token injection
|
||||
client := &http.Client{
|
||||
Timeout: 360 * time.Second,
|
||||
Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
|
||||
}
|
||||
|
||||
tools, err := mcpToolsFromTransport(ctx,
|
||||
&mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: client},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allTools = append(allTools, tools...)
|
||||
}
|
||||
|
||||
for _, server := range stdio.Servers {
|
||||
log.Debug().Msgf("[MCP stdio server] Configuration : %+v", server)
|
||||
command := exec.Command(server.Command, server.Args...)
|
||||
command.Env = os.Environ()
|
||||
for key, value := range server.Env {
|
||||
command.Env = append(command.Env, key+"="+value)
|
||||
}
|
||||
tools, err := mcpToolsFromTransport(ctx,
|
||||
&mcp.CommandTransport{
|
||||
Command: command},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allTools = append(allTools, tools...)
|
||||
}
|
||||
|
||||
return allTools, nil
|
||||
}
|
||||
|
||||
// bearerTokenRoundTripper is a custom roundtripper that injects a bearer token
|
||||
// into HTTP requests
|
||||
type bearerTokenRoundTripper struct {
|
||||
token string
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if rt.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+rt.token)
|
||||
}
|
||||
return rt.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// newBearerTokenRoundTripper creates a new roundtripper that injects the given token
|
||||
func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.RoundTripper {
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return &bearerTokenRoundTripper{
|
||||
token: token,
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// MCPTool wraps a single tool advertised by an MCP server so it can be
// exposed to the agent and to the OpenAI tool-calling API.
//
// NOTE(review): storing a context in a struct is discouraged in Go; here the
// discovery-time ctx is reused for every later invocation — confirm its
// lifetime matches the session's.
type MCPTool struct {
	name, description string                           // tool identity as reported by the MCP server
	inputSchema       ToolInputSchema                  // argument schema as declared by the server
	session           *mcp.ClientSession               // open session used to invoke the tool
	ctx               context.Context                  // context captured at discovery time, used for calls
	props             map[string]jsonschema.Definition // Properties re-encoded for go-openai function definitions
}
|
||||
|
||||
func (t *MCPTool) Run(args map[string]any) (string, error) {
|
||||
|
||||
// Call a tool on the server.
|
||||
params := &mcp.CallToolParams{
|
||||
Name: t.name,
|
||||
Arguments: args,
|
||||
}
|
||||
res, err := t.session.CallTool(t.ctx, params)
|
||||
if err != nil {
|
||||
log.Error().Msgf("CallTool failed: %v", err)
|
||||
return "", err
|
||||
}
|
||||
if res.IsError {
|
||||
log.Error().Msgf("tool failed")
|
||||
return "", errors.New("tool failed")
|
||||
}
|
||||
|
||||
result := ""
|
||||
for _, c := range res.Content {
|
||||
result += c.(*mcp.TextContent).Text
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *MCPTool) Tool() openai.Tool {
|
||||
|
||||
return openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &openai.FunctionDefinition{
|
||||
Name: t.name,
|
||||
Description: t.description,
|
||||
Parameters: jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
Properties: t.props,
|
||||
Required: t.inputSchema.Required,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the MCP client session backing this tool.
//
// NOTE(review): the session is shared by every MCPTool built from the same
// transport (see mcpToolsFromTransport), so closing one tool closes it for
// its siblings too; the result of session.Close, if any, is not checked.
func (t *MCPTool) Close() {
	t.session.Close()
}
|
||||
|
||||
type ToolInputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]interface{} `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
// probe the MCP remote and generate tools that are compliant with cogito
|
||||
// TODO: Maybe move this to cogito?
|
||||
func mcpToolsFromTransport(ctx context.Context, transport mcp.Transport) ([]*MCPTool, error) {
|
||||
allTools := []*MCPTool{}
|
||||
|
||||
// Create a new client, with no features.
|
||||
client := mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil)
|
||||
session, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error connecting to MCP server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tools, err := session.ListTools(ctx, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error listing tools: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, tool := range tools.Tools {
|
||||
dat, err := json.Marshal(tool.InputSchema)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error marshalling input schema: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// XXX: This is a wild guess, to verify (data types might be incompatible)
|
||||
var inputSchema ToolInputSchema
|
||||
err = json.Unmarshal(dat, &inputSchema)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error unmarshalling input schema: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
props := map[string]jsonschema.Definition{}
|
||||
dat, err = json.Marshal(inputSchema.Properties)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error marshalling input schema: %v", err)
|
||||
continue
|
||||
}
|
||||
err = json.Unmarshal(dat, &props)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Error unmarshalling input schema properties: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
allTools = append(allTools, &MCPTool{
|
||||
name: tool.Name,
|
||||
description: tool.Description,
|
||||
session: session,
|
||||
ctx: ctx,
|
||||
props: props,
|
||||
inputSchema: inputSchema,
|
||||
})
|
||||
}
|
||||
|
||||
// We make sure we run Close on signal
|
||||
handleSignal(allTools)
|
||||
|
||||
return allTools, nil
|
||||
}
|
||||
|
||||
// handleSignal installs a SIGINT/SIGTERM handler that closes every given MCP
// tool (and its underlying session/subprocess) before exiting the process.
//
// NOTE(review): signal.Notify is cumulative — each call to handleSignal adds
// another channel and goroutine, all of which will fire on the same signal.
// NOTE(review): os.Exit(0) terminates immediately; deferred functions
// elsewhere in the program will not run.
func handleSignal(tools []*MCPTool) {

	// Create a channel to receive OS signals
	sigChan := make(chan os.Signal, 1)

	// Register for interrupt and terminate signals
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	// Handle signals in a separate goroutine
	go func() {
		sig := <-sigChan
		log.Printf("Received signal %v, shutting down gracefully...", sig)

		// Best-effort cleanup of every tool's session before exiting.
		for _, t := range tools {
			t.Close()
		}

		// Exit the application
		os.Exit(0)
	}()
}
|
||||
@@ -28,10 +28,10 @@ import (
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
@@ -65,6 +65,9 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
// Handle Correlation
|
||||
id := c.Get("X-Correlation-ID", uuid.New().String())
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/cogito"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MCPCompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
|
||||
// @Summary Generate completions for a given prompt and model.
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /mcp/v1/completions [post]
|
||||
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
toolsCache := map[string][]*mcp.MCPTool{}
|
||||
mu := sync.Mutex{}
|
||||
|
||||
// We do not support streaming mode (Yet?)
|
||||
return func(c *fiber.Ctx) error {
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
ctx := c.Context()
|
||||
|
||||
// Handle Correlation
|
||||
id := c.Get("X-Correlation-ID", uuid.New().String())
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
allTools := []*mcp.MCPTool{}
|
||||
|
||||
// Get MCP config from model config
|
||||
remote, stdio := config.MCP.MCPConfigFromYAML()
|
||||
|
||||
// Check if we have tools in cache, or we have to have an initial connection
|
||||
mu.Lock()
|
||||
tools, exists := toolsCache[config.Name]
|
||||
if exists {
|
||||
allTools = append(allTools, tools...)
|
||||
} else {
|
||||
tools, err := mcp.ToolsFromMCPConfig(ctx, remote, stdio)
|
||||
if err != nil {
|
||||
mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
toolsCache[config.Name] = tools
|
||||
|
||||
allTools = append(allTools, tools...)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
cogitoTools := []cogito.Tool{}
|
||||
for _, tool := range allTools {
|
||||
cogitoTools = append(cogitoTools, tool)
|
||||
// defer tool.Close()
|
||||
}
|
||||
|
||||
fragment := cogito.NewEmptyFragment()
|
||||
|
||||
for _, message := range input.Messages {
|
||||
fragment = fragment.AddMessage(message.Role, message.StringContent)
|
||||
}
|
||||
|
||||
port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
|
||||
apiKey := ""
|
||||
if appConfig.ApiKeys != nil {
|
||||
apiKey = appConfig.ApiKeys[0]
|
||||
}
|
||||
// TODO: instead of connecting to the API, we should just wire this internally
|
||||
// and act like completion.go.
|
||||
// We can do this as cogito expects an interface and we can create one that
|
||||
// we satisfy to just call internally ComputeChoices
|
||||
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
|
||||
|
||||
f, err := cogito.ExecuteTools(
|
||||
defaultLLM, fragment,
|
||||
cogito.WithStatusCallback(func(s string) {
|
||||
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
|
||||
}),
|
||||
cogito.WithContext(ctx),
|
||||
// TODO: move these to configs
|
||||
cogito.EnableToolReEvaluator,
|
||||
cogito.WithIterations(3),
|
||||
cogito.WithMaxAttempts(3),
|
||||
cogito.WithTools(
|
||||
cogitoTools...,
|
||||
),
|
||||
)
|
||||
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err = defaultLLM.Ask(ctx, f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Text: f.LastMessage().Content}},
|
||||
Object: "text_completion",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
@@ -54,6 +54,16 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
||||
app.Post("/completions", completionChain...)
|
||||
app.Post("/v1/engines/:model/completions", completionChain...)
|
||||
|
||||
// MCPcompletion
|
||||
mcpCompletionChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
}
|
||||
app.Post("/mcp/v1/chat/completions", mcpCompletionChain...)
|
||||
app.Post("/mcp/chat/completions", mcpCompletionChain...)
|
||||
|
||||
// embeddings
|
||||
embeddingChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
|
||||
Reference in New Issue
Block a user