diff --git a/cmd/launcher/main.go b/cmd/launcher/main.go index 6441dcf48..220fbb612 100644 --- a/cmd/launcher/main.go +++ b/cmd/launcher/main.go @@ -2,14 +2,12 @@ package main import ( "log" - "os" - "os/signal" - "syscall" "fyne.io/fyne/v2" "fyne.io/fyne/v2/app" "fyne.io/fyne/v2/driver/desktop" coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal" + "github.com/mudler/LocalAI/pkg/signals" ) func main() { @@ -42,7 +40,12 @@ func main() { } // Setup signal handling for graceful shutdown - setupSignalHandling(launcher) + signals.RegisterGracefulTerminationHandler(func() { + // Perform cleanup + if err := launcher.Shutdown(); err != nil { + log.Printf("Error during shutdown: %v", err) + } + }) // Initialize the launcher state go func() { @@ -67,26 +70,3 @@ func main() { // Run the application in background (window only shown when "Settings" is clicked) myApp.Run() } - -// setupSignalHandling sets up signal handlers for graceful shutdown -func setupSignalHandling(launcher *coreLauncher.Launcher) { - // Create a channel to receive OS signals - sigChan := make(chan os.Signal, 1) - - // Register for interrupt and terminate signals - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - // Handle signals in a separate goroutine - go func() { - sig := <-sigChan - log.Printf("Received signal %v, shutting down gracefully...", sig) - - // Perform cleanup - if err := launcher.Shutdown(); err != nil { - log.Printf("Error during shutdown: %v", err) - } - - // Exit the application - os.Exit(0) - }() -} diff --git a/core/cli/explorer.go b/core/cli/explorer.go index 4e35657ef..b12735c73 100644 --- a/core/cli/explorer.go +++ b/core/cli/explorer.go @@ -5,9 +5,10 @@ import ( "time" cliContext "github.com/mudler/LocalAI/core/cli/context" - "github.com/mudler/LocalAI/core/cli/signals" "github.com/mudler/LocalAI/core/explorer" "github.com/mudler/LocalAI/core/http" + "github.com/mudler/LocalAI/pkg/signals" + "github.com/rs/zerolog/log" ) type ExplorerCMD struct { @@ -46,7 
+47,11 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error { appHTTP := http.Explorer(db) - signals.Handler(nil) + signals.RegisterGracefulTerminationHandler(func() { + if err := appHTTP.Shutdown(); err != nil { + log.Error().Err(err).Msg("error during shutdown") + } + }) return appHTTP.Listen(e.Address) } diff --git a/core/cli/federated.go b/core/cli/federated.go index 3ad20304d..ceea5a9e4 100644 --- a/core/cli/federated.go +++ b/core/cli/federated.go @@ -4,8 +4,8 @@ import ( "context" cliContext "github.com/mudler/LocalAI/core/cli/context" - "github.com/mudler/LocalAI/core/cli/signals" "github.com/mudler/LocalAI/core/p2p" + "github.com/mudler/LocalAI/pkg/signals" ) type FederatedCLI struct { @@ -20,7 +20,11 @@ func (f *FederatedCLI) Run(ctx *cliContext.Context) error { fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker) - signals.Handler(nil) + c, cancel := context.WithCancel(context.Background()) - return fs.Start(context.Background()) + signals.RegisterGracefulTerminationHandler(func() { + cancel() + }) + + return fs.Start(c) } diff --git a/core/cli/run.go b/core/cli/run.go index 473440041..560b2d8f2 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -10,11 +10,11 @@ import ( "github.com/mudler/LocalAI/core/application" cli_api "github.com/mudler/LocalAI/core/cli/api" cliContext "github.com/mudler/LocalAI/core/cli/context" - "github.com/mudler/LocalAI/core/cli/signals" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/internal" + "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/LocalAI/pkg/system" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -226,8 +226,11 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { return err } - // Catch signals from the OS requesting us to exit, and stop all backends - signals.Handler(app.ModelLoader()) + 
signals.RegisterGracefulTerminationHandler(func() { + if err := app.ModelLoader().StopAllGRPC(); err != nil { + log.Error().Err(err).Msg("error while stopping all grpc backends") + } + }) return appHTTP.Listen(r.Address) } diff --git a/core/cli/signals/signals.go b/core/cli/signals/signals.go deleted file mode 100644 index af218c91b..000000000 --- a/core/cli/signals/signals.go +++ /dev/null @@ -1,25 +0,0 @@ -package signals - -import ( - "os" - "os/signal" - "syscall" - - "github.com/mudler/LocalAI/pkg/model" - "github.com/rs/zerolog/log" -) - -func Handler(m *model.ModelLoader) { - // Catch signals from the OS requesting us to exit, and stop all backends - go func(m *model.ModelLoader) { - c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked - signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) - <-c - if m != nil { - if err := m.StopAllGRPC(); err != nil { - log.Error().Err(err).Msg("error while stopping all grpc backends") - } - } - os.Exit(1) - }(m) -} diff --git a/core/cli/worker/worker_llamacpp.go b/core/cli/worker/worker_llamacpp.go index b829f5cbb..8a55f2345 100644 --- a/core/cli/worker/worker_llamacpp.go +++ b/core/cli/worker/worker_llamacpp.go @@ -11,7 +11,6 @@ import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/cli/signals" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" @@ -85,8 +84,6 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error { args = append([]string{grpcProcess}, args...) 
- signals.Handler(nil) - return syscall.Exec( grpcProcess, args, diff --git a/core/cli/worker/worker_p2p.go b/core/cli/worker/worker_p2p.go index 3ff98efe4..6263502d9 100644 --- a/core/cli/worker/worker_p2p.go +++ b/core/cli/worker/worker_p2p.go @@ -9,8 +9,8 @@ import ( "time" cliContext "github.com/mudler/LocalAI/core/cli/context" - "github.com/mudler/LocalAI/core/cli/signals" "github.com/mudler/LocalAI/core/p2p" + "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/LocalAI/pkg/system" "github.com/phayes/freeport" "github.com/rs/zerolog/log" @@ -48,6 +48,9 @@ func (r *P2P) Run(ctx *cliContext.Context) error { address := "127.0.0.1" + c, cancel := context.WithCancel(context.Background()) + defer cancel() + if r.NoRunner { // Let override which port and address to bind if the user // configure the llama-cpp service on its own @@ -59,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error { p = r.RunnerPort } - _, err = p2p.ExposeService(context.Background(), address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) + _, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) if err != nil { return err } @@ -101,13 +104,15 @@ func (r *P2P) Run(ctx *cliContext.Context) error { } }() - _, err = p2p.ExposeService(context.Background(), address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) + _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) if err != nil { return err } } - signals.Handler(nil) + signals.RegisterGracefulTerminationHandler(func() { + cancel() + }) for { time.Sleep(1 * time.Second) diff --git a/core/http/endpoints/mcp/tools.go b/core/http/endpoints/mcp/tools.go index 83b4ad570..e9d597771 100644 --- a/core/http/endpoints/mcp/tools.go +++ b/core/http/endpoints/mcp/tools.go @@ -2,43 +2,66 @@ package mcp import ( "context" - "encoding/json" - "errors" "net/http" "os" "os/exec" - "os/signal" - 
"syscall" + "sync" "time" "github.com/mudler/LocalAI/core/config" - "github.com/sashabaranov/go-openai" - "github.com/tmc/langchaingo/jsonschema" + "github.com/mudler/LocalAI/pkg/signals" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/rs/zerolog/log" ) -func ToolsFromMCPConfig(ctx context.Context, remote config.MCPGenericConfig[config.MCPRemoteServers], stdio config.MCPGenericConfig[config.MCPSTDIOServers]) ([]*MCPTool, error) { - allTools := []*MCPTool{} +type sessionCache struct { + mu sync.Mutex + cache map[string][]*mcp.ClientSession +} + +var ( + cache = sessionCache{ + cache: make(map[string][]*mcp.ClientSession), + } + + client = mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil) +) + +func SessionsFromMCPConfig( + name string, + remote config.MCPGenericConfig[config.MCPRemoteServers], + stdio config.MCPGenericConfig[config.MCPSTDIOServers], +) ([]*mcp.ClientSession, error) { + cache.mu.Lock() + defer cache.mu.Unlock() + + sessions, exists := cache.cache[name] + if exists { + return sessions, nil + } + + allSessions := []*mcp.ClientSession{} + + ctx, cancel := context.WithCancel(context.Background()) // Get the list of all the tools that the Agent will be esposed to for _, server := range remote.Servers { - + log.Debug().Msgf("[MCP remote server] Configuration : %+v", server) // Create HTTP client with custom roundtripper for bearer token injection - client := &http.Client{ + httpClient := &http.Client{ Timeout: 360 * time.Second, Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport), } - tools, err := mcpToolsFromTransport(ctx, - &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: client}, - ) + transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient} + mcpSession, err := client.Connect(ctx, transport, nil) if err != nil { - return nil, err + log.Error().Err(err).Msgf("Failed to connect to MCP server %s", server.URL) + continue } - - allTools = 
append(allTools, tools...) + log.Debug().Msgf("[MCP remote server] Connected to MCP server %s", server.URL) + cache.cache[name] = append(cache.cache[name], mcpSession) } for _, server := range stdio.Servers { @@ -48,18 +71,24 @@ func ToolsFromMCPConfig(ctx context.Context, remote config.MCPGenericConfig[conf for key, value := range server.Env { command.Env = append(command.Env, key+"="+value) } - tools, err := mcpToolsFromTransport(ctx, - &mcp.CommandTransport{ - Command: command}, - ) + transport := &mcp.CommandTransport{Command: command} + mcpSession, err := client.Connect(ctx, transport, nil) if err != nil { - return nil, err + log.Error().Err(err).Msgf("Failed to start MCP server %s", command) + continue } - - allTools = append(allTools, tools...) + log.Debug().Msgf("[MCP stdio server] Connected to MCP server %s", command) + cache.cache[name] = append(cache.cache[name], mcpSession) } - return allTools, nil + signals.RegisterGracefulTerminationHandler(func() { + for _, session := range allSessions { + session.Close() + } + cancel() + }) + + return allSessions, nil } // bearerTokenRoundTripper is a custom roundtripper that injects a bearer token @@ -87,146 +116,3 @@ func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.Round base: base, } } - -type MCPTool struct { - name, description string - inputSchema ToolInputSchema - session *mcp.ClientSession - ctx context.Context - props map[string]jsonschema.Definition -} - -func (t *MCPTool) Run(args map[string]any) (string, error) { - - // Call a tool on the server. 
- params := &mcp.CallToolParams{ - Name: t.name, - Arguments: args, - } - res, err := t.session.CallTool(t.ctx, params) - if err != nil { - log.Error().Msgf("CallTool failed: %v", err) - return "", err - } - if res.IsError { - log.Error().Msgf("tool failed") - return "", errors.New("tool failed") - } - - result := "" - for _, c := range res.Content { - result += c.(*mcp.TextContent).Text - } - - return result, nil -} - -func (t *MCPTool) Tool() openai.Tool { - - return openai.Tool{ - Type: openai.ToolTypeFunction, - Function: &openai.FunctionDefinition{ - Name: t.name, - Description: t.description, - Parameters: jsonschema.Definition{ - Type: jsonschema.Object, - Properties: t.props, - Required: t.inputSchema.Required, - }, - }, - } -} - -func (t *MCPTool) Close() { - t.session.Close() -} - -type ToolInputSchema struct { - Type string `json:"type"` - Properties map[string]interface{} `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` -} - -// probe the MCP remote and generate tools that are compliant with cogito -// TODO: Maybe move this to cogito? -func mcpToolsFromTransport(ctx context.Context, transport mcp.Transport) ([]*MCPTool, error) { - allTools := []*MCPTool{} - - // Create a new client, with no features. 
- client := mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil) - session, err := client.Connect(ctx, transport, nil) - if err != nil { - log.Error().Msgf("Error connecting to MCP server: %v", err) - return nil, err - } - - tools, err := session.ListTools(ctx, nil) - if err != nil { - log.Error().Msgf("Error listing tools: %v", err) - return nil, err - } - - for _, tool := range tools.Tools { - dat, err := json.Marshal(tool.InputSchema) - if err != nil { - log.Error().Msgf("Error marshalling input schema: %v", err) - continue - } - - // XXX: This is a wild guess, to verify (data types might be incompatible) - var inputSchema ToolInputSchema - err = json.Unmarshal(dat, &inputSchema) - if err != nil { - log.Error().Msgf("Error unmarshalling input schema: %v", err) - continue - } - - props := map[string]jsonschema.Definition{} - dat, err = json.Marshal(inputSchema.Properties) - if err != nil { - log.Error().Msgf("Error marshalling input schema: %v", err) - continue - } - err = json.Unmarshal(dat, &props) - if err != nil { - log.Error().Msgf("Error unmarshalling input schema properties: %v", err) - continue - } - - allTools = append(allTools, &MCPTool{ - name: tool.Name, - description: tool.Description, - session: session, - ctx: ctx, - props: props, - inputSchema: inputSchema, - }) - } - - // We make sure we run Close on signal - handleSignal(allTools) - - return allTools, nil -} - -func handleSignal(tools []*MCPTool) { - - // Create a channel to receive OS signals - sigChan := make(chan os.Signal, 1) - - // Register for interrupt and terminate signals - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - // Handle signals in a separate goroutine - go func() { - sig := <-sigChan - log.Printf("Received signal %v, shutting down gracefully...", sig) - - for _, t := range tools { - t.Close() - } - - // Exit the application - os.Exit(0) - }() -} diff --git a/core/http/endpoints/openai/mcp.go b/core/http/endpoints/openai/mcp.go index 
5ef17fc9f..ed0d0f843 100644 --- a/core/http/endpoints/openai/mcp.go +++ b/core/http/endpoints/openai/mcp.go @@ -5,11 +5,10 @@ import ( "errors" "fmt" "strings" - "sync" "time" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/http/endpoints/mcp" + mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/http/middleware" "github.com/gofiber/fiber/v2" @@ -27,10 +26,6 @@ import ( // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /mcp/v1/completions [post] func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - - toolsCache := map[string][]*mcp.MCPTool{} - mu := sync.Mutex{} - // We do not support streaming mode (Yet?) return func(c *fiber.Ctx) error { created := int(time.Now().Unix()) @@ -54,37 +49,17 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, return fmt.Errorf("no MCP servers configured") } - allTools := []*mcp.MCPTool{} - // Get MCP config from model config remote, stdio := config.MCP.MCPConfigFromYAML() // Check if we have tools in cache, or we have to have an initial connection - mu.Lock() - tools, exists := toolsCache[config.Name] - if exists { - allTools = append(allTools, tools...) - } else { - tools, err := mcp.ToolsFromMCPConfig(ctx, remote, stdio) - if err != nil { - mu.Unlock() - return err - } - - toolsCache[config.Name] = tools - - allTools = append(allTools, tools...) 
- } - mu.Unlock() - - cogitoTools := []cogito.Tool{} - for _, tool := range allTools { - cogitoTools = append(cogitoTools, tool) - // defer tool.Close() + sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio) + if err != nil { + return err } - if len(cogitoTools) == 0 { - return fmt.Errorf("no tools found in the specified MCP servers") + if len(sessions) == 0 { + return fmt.Errorf("no working MCP servers found") } fragment := cogito.NewEmptyFragment() @@ -109,7 +84,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s) }), cogito.WithContext(ctx), - cogito.WithTools(cogitoTools...), + cogito.WithMCPs(sessions...), cogito.WithIterations(3), // default to 3 iterations cogito.WithMaxAttempts(3), // default to 3 attempts } diff --git a/go.mod b/go.mod index df4d8266e..d2fbddabd 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/mholt/archiver/v3 v3.5.1 github.com/microcosm-cc/bluemonday v1.0.27 github.com/modelcontextprotocol/go-sdk v1.0.0 - github.com/mudler/cogito v0.1.0 + github.com/mudler/cogito v0.2.0 github.com/mudler/edgevpn v0.31.0 github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82 github.com/nikolalohinski/gonja/v2 v2.4.1 @@ -60,6 +60,7 @@ require ( go.opentelemetry.io/otel/metric v1.38.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 google.golang.org/grpc v1.67.1 + google.golang.org/protobuf v1.36.8 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 oras.land/oras-go/v2 v2.6.0 @@ -149,7 +150,6 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/image v0.25.0 // indirect golang.org/x/time v0.12.0 // indirect - google.golang.org/protobuf v1.36.8 // indirect ) require ( diff --git a/go.sum b/go.sum index 6146b4690..acbc16bfd 100644 --- a/go.sum +++ b/go.sum @@ -510,8 +510,8 @@ github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7P github.com/mr-tron/base58 
v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= -github.com/mudler/cogito v0.1.0 h1:RybskLSPuLkBlR9Z+y4LJgIU5wVscYoHuF9+ubXsHgM= -github.com/mudler/cogito v0.1.0/go.mod h1:MiipcWbTr+fcW3HiirQRrYYjEIamZFCLkpqvdgk/Nfw= +github.com/mudler/cogito v0.2.0 h1:UzowMlP6kiDLnuwQikac9yUOhI6Qe2tW1jZP5gHQvaY= +github.com/mudler/cogito v0.2.0/go.mod h1:abMwl+CUjCp87IufA2quZdZt0bbLaHHN79o17HbUKxU= github.com/mudler/edgevpn v0.31.0 h1:CXwxQ2ZygzE7iKGl1J+vq9pL5PvsW2uc3qI/zgpNpp4= github.com/mudler/edgevpn v0.31.0/go.mod h1:DKgh9Wu/NM3UbZoPyheMXFvpu1dSLkXrqAOy3oKJN3I= github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc h1:RxwneJl1VgvikiX28EkpdAyL4yQVnJMrbquKospjHyA= diff --git a/pkg/model/process.go b/pkg/model/process.go index 5bef6d4d7..bf40a7894 100644 --- a/pkg/model/process.go +++ b/pkg/model/process.go @@ -4,14 +4,13 @@ import ( "errors" "fmt" "os" - "os/signal" "path/filepath" "strconv" "strings" - "syscall" "time" "github.com/hpcloud/tail" + "github.com/mudler/LocalAI/pkg/signals" process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" ) @@ -130,16 +129,13 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string } log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir()) - // clean up process - go func() { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c + + signals.RegisterGracefulTerminationHandler(func() { err := grpcControlProcess.Stop() if err != nil { log.Error().Err(err).Msg("error while shutting down grpc process") } - }() + }) go func() { t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) diff --git a/pkg/signals/handler.go b/pkg/signals/handler.go new file mode 100644 index 000000000..22b8e8455 --- /dev/null +++ b/pkg/signals/handler.go @@ 
var (
	// signalHandlers holds the registered cleanup callbacks in registration order.
	signalHandlers []func()
	// signalHandlersMutex guards every access to signalHandlers.
	signalHandlersMutex sync.Mutex
)

// RegisterGracefulTerminationHandler registers fn to be invoked when the
// process receives SIGINT or SIGTERM.
//
// Callbacks run one after another, in registration order, on the goroutine
// started by init; once they have all returned the process exits with
// status 0. A nil fn is ignored. It is safe to call this function from
// multiple goroutines, including from inside a running callback.
func RegisterGracefulTerminationHandler(fn func()) {
	if fn == nil {
		// Calling a nil func would panic in the middle of shutdown.
		return
	}

	signalHandlersMutex.Lock()
	defer signalHandlersMutex.Unlock()
	signalHandlers = append(signalHandlers, fn)
}

// init subscribes to SIGINT/SIGTERM and starts the goroutine that waits for
// them. init is executed exactly once per program, so no extra sync.Once
// guard is needed.
func init() {
	// Buffer of one so the runtime never drops the first signal while the
	// handler goroutine is still being scheduled.
	c := make(chan os.Signal, 1)
	signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
	go signalHandler(c)
}

// signalHandler blocks until a signal arrives on c, runs every registered
// termination callback and then exits the process with status 0.
func signalHandler(c chan os.Signal) {
	<-c

	// Stop relaying signals to c before running callbacks: if no other
	// Notify subscription exists, a second SIGINT/SIGTERM then gets the
	// default behavior and terminates the process, so a stuck callback
	// cannot make the program unkillable by Ctrl-C.
	signal.Stop(c)

	runTerminationHandlers()

	os.Exit(0)
}

// runTerminationHandlers invokes the registered callbacks in registration
// order. The mutex is held only while reading the slice, never while a
// callback runs, so a callback may itself call
// RegisterGracefulTerminationHandler without deadlocking; callbacks
// registered while this function is running are executed as well.
func runTerminationHandlers() {
	for i := 0; ; i++ {
		signalHandlersMutex.Lock()
		if i >= len(signalHandlers) {
			signalHandlersMutex.Unlock()
			return
		}
		fn := signalHandlers[i]
		signalHandlersMutex.Unlock()

		fn()
	}
}