feat: respect context and add request cancellation (#7187)

* feat: respect context

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* workaround fasthttp

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(ui): allow aborting the call

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Refactor

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: improve error messages

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Respect context also with MCP

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Tie to both contexts

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Make detection more robust

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Author: Ettore Di Giacinto
Date: 2025-11-09 18:19:19 +01:00
Committed by: GitHub
Parent: 4730b52461
Commit: 679d43c2f5
8 changed files with 240 additions and 42 deletions
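
Taken together, the diffs below implement one pattern end to end: derive a cancellable per-request context, cancel it when the client disconnects (or the app shuts down), and check it at every blocking point in the backend. A minimal Go sketch of that shape; the names (predict, backendPredict, clientGone) are illustrative, not the actual API:

package main

import (
	"context"
	"errors"
	"fmt"
)

// predict runs one backend call under a per-request context that is
// cancelled either by app shutdown (via appCtx) or client disconnect.
func predict(appCtx context.Context, clientGone <-chan struct{}) error {
	ctx, cancel := context.WithCancel(appCtx)
	defer cancel()

	go func() {
		select {
		case <-clientGone: // client went away: stop the work
			cancel()
		case <-ctx.Done(): // work finished first: exit cleanly
		}
	}()

	if err := backendPredict(ctx); err != nil {
		if errors.Is(err, context.Canceled) {
			return fmt.Errorf("request cancelled by client: %w", err)
		}
		return err
	}
	return nil
}

// backendPredict stands in for the gRPC call; a real one streams results.
func backendPredict(ctx context.Context) error {
	<-ctx.Done()
	return ctx.Err()
}

func main() {
	gone := make(chan struct{})
	close(gone) // simulate an immediate disconnect
	fmt.Println(predict(context.Background(), gone))
}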

View File

@@ -822,6 +822,12 @@ public:
}
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
// Check if context is cancelled before processing result
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
return false;
}
json res_json = result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
@@ -875,13 +881,18 @@ public:
reply.set_message(error_data.value("content", ""));
writer->Write(reply);
return true;
}, [&]() {
// NOTE: we should try to check when the writer is closed here
return false;
}, [&context]() {
// Check if the gRPC context is cancelled
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
return grpc::Status::OK;
}
@@ -1145,6 +1156,14 @@ public:
std::cout << "[DEBUG] Waiting for results..." << std::endl;
// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl;
if (results.size() == 1) {
@@ -1176,13 +1195,20 @@ public:
}, [&](const json & error_data) {
std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl;
reply->set_message(error_data.value("content", ""));
}, [&]() {
return false;
}, [&context]() {
// Check if the gRPC context is cancelled
// This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
return grpc::Status::OK;
}
@@ -1234,6 +1260,13 @@ public:
ctx_server.queue_tasks.post(std::move(tasks));
}
// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
// get the result
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
@@ -1242,12 +1275,18 @@ public:
}
}, [&](const json & error_data) {
error = true;
}, [&]() {
return false;
}, [&context]() {
// Check if the gRPC context is cancelled
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
if (error) {
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
}
@@ -1325,6 +1364,13 @@ public:
ctx_server.queue_tasks.post(std::move(tasks));
}
// Check cancellation before waiting for results
if (context->IsCancelled()) {
ctx_server.cancel_tasks(task_ids);
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
// Get the results
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
@@ -1333,12 +1379,18 @@ public:
}
}, [&](const json & error_data) {
error = true;
}, [&]() {
return false;
}, [&context]() {
// Check if the gRPC context is cancelled
return context->IsCancelled();
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
// Check if context was cancelled during processing
if (context->IsCancelled()) {
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
if (error) {
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
}
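
The hunks above place the same checkpoint three times: before waiting for results, inside the polling callback (evaluated roughly once per HTTP_POLLING_SECONDS), and again after the results are drained. For contrast, the same shape on a gRPC-Go server stream; a sketch only, where the pb types, generateTokens, and cancelTasks are hypothetical and status/codes come from google.golang.org/grpc:

// Sketch: the three cancellation checkpoints on a gRPC-Go stream handler.
func (s *server) PredictStream(req *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
	ctx := stream.Context()
	if ctx.Err() != nil { // 1: before doing any work
		cancelTasks()
		return status.Error(codes.Canceled, "request cancelled by client")
	}
	for tok := range generateTokens(ctx, req) {
		if ctx.Err() != nil { // 2: between results, while streaming
			cancelTasks()
			return status.Error(codes.Canceled, "request cancelled by client")
		}
		if err := stream.Send(&pb.Reply{Message: []byte(tok)}); err != nil {
			return err
		}
	}
	if ctx.Err() != nil { // 3: after the result loop
		return status.Error(codes.Canceled, "request cancelled by client")
	}
	return nil
}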

View File

@@ -93,19 +93,18 @@ type AgentConfig struct {
EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
}
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
var remote MCPGenericConfig[MCPRemoteServers]
var stdio MCPGenericConfig[MCPSTDIOServers]
if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
return remote, stdio
return remote, stdio, err
}
if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
return remote, stdio
return remote, stdio, err
}
return remote, stdio
return remote, stdio, nil
}
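
With the error now propagated, a malformed servers block fails loudly instead of silently yielding empty configs; a quick sketch of the difference at a call site (the field values are illustrative):

cfg := MCPConfig{Servers: "not: [valid yaml", Stdio: ""}
remote, stdio, err := cfg.MCPConfigFromYAML()
if err != nil {
	// Previously this parse failure was swallowed and empty
	// configs were returned; now the caller can surface it.
	return fmt.Errorf("failed to get MCP config: %w", err)
}
_, _ = remote, stdio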
type MCPGenericConfig[T any] struct {

View File

@@ -3,8 +3,10 @@ package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"time"
"github.com/gofiber/fiber/v2"
@@ -22,6 +24,59 @@ import (
"github.com/valyala/fasthttp"
)
// NOTE: this is a bad WORKAROUND! We should find a better way to handle this.
// Fasthttp doesn't support context cancellation from the caller
// for non-streaming requests, so we need to monitor the connection directly.
// Monitor connection for client disconnection during non-streaming requests
// We access the connection directly via c.Context().Conn() to monitor it
// during ComputeChoices execution, not after the response is sent
// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906
func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) {
var conn net.Conn = c.Context().Conn()
if conn == nil {
return
}
go func() {
defer func() {
// Clear read deadline when goroutine exits
conn.SetReadDeadline(time.Time{})
}()
buf := make([]byte, 1)
// Use a short read deadline to periodically check if connection is closed
// Without a deadline, Read() would block indefinitely waiting for data
// that will never come (client is waiting for response, not sending more data)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-requestCtx.Done():
// Request completed or was cancelled - exit goroutine
return
case <-ticker.C:
// Set a short deadline - if connection is closed, read will fail immediately
// If connection is open but no data, it will timeout and we check again
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
_, err := conn.Read(buf)
if err != nil {
// Check if it's a timeout (connection still open, just no data)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout is expected - connection is still open, just no data to read
// Continue the loop to check again
continue
}
// Connection closed or other error - cancel the context to stop gRPC call
log.Debug().Msgf("Calling cancellation function")
cancelFunc()
return
}
}
}
}()
}
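
The monitor is easy to exercise in isolation. A self-contained sketch of the same polling loop, with net.Pipe standing in for the fasthttp connection (assumption: the real Conn supports read deadlines the same way):

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

func main() {
	clientEnd, serverEnd := net.Pipe()
	ctx, cancel := context.WithCancel(context.Background())

	// Same polling loop as handleConnectionCancellation, inlined.
	go func() {
		defer serverEnd.SetReadDeadline(time.Time{})
		buf := make([]byte, 1)
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				serverEnd.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
				if _, err := serverEnd.Read(buf); err != nil {
					if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
						continue // still connected, just idle
					}
					cancel() // closed: stop the in-flight work
					return
				}
			}
		}
	}()

	time.Sleep(250 * time.Millisecond) // a few idle polls
	clientEnd.Close()                  // simulate the client disconnecting
	<-ctx.Done()
	fmt.Println("request context cancelled:", ctx.Err())
}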
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
// @Summary Generate a chat completions for a given prompt and model.
// @Param request body schema.OpenAIRequest true "query params"
@@ -358,6 +413,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
LOOP:
for {
select {
case <-input.Context.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
input.Cancel()
break LOOP
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
@@ -511,6 +571,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
// NOTE: this is a workaround as fasthttp
// context cancellation does not fire in non-streaming requests
handleConnectionCancellation(c, input.Cancel, input.Context)
result, tokenUsage, err := ComputeChoices(
input,
predInput,

View File

@@ -1,6 +1,7 @@
package openai
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -50,12 +51,15 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
}
// Get MCP config from model config
remote, stdio := config.MCP.MCPConfigFromYAML()
remote, stdio, err := config.MCP.MCPConfigFromYAML()
if err != nil {
return fmt.Errorf("failed to get MCP config: %w", err)
}
// Check if we have tools in cache, or we have to have an initial connection
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
if err != nil {
return err
return fmt.Errorf("failed to get MCP sessions: %w", err)
}
if len(sessions) == 0 {
@@ -73,6 +77,10 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
if appConfig.ApiKeys != nil {
apiKey = appConfig.ApiKeys[0]
}
ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
handleConnectionCancellation(c, cancel, ctxWithCancellation)
// TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go.
// We can do this as cogito expects an interface and we can create one that
@@ -83,7 +91,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
cogito.WithStatusCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
}),
cogito.WithContext(ctx),
cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
cogito.WithIterations(3), // default to 3 iterations
cogito.WithMaxAttempts(3), // default to 3 attempts

View File

@@ -161,7 +161,17 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
ctx.Set("X-Correlation-ID", correlationID)
//c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Use the application context as parent to ensure cancellation on app shutdown
// We'll monitor the Fiber context separately and cancel our context when the request is canceled
c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Monitor the Fiber context and cancel our context when it's canceled
// This ensures we respect request cancellation without causing panics
go func() {
<-ctx.Context().Done()
// Fiber context was canceled (request completed or client disconnected)
cancel()
}()
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
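
The goroutine above ties the request context to both lifetimes: app shutdown via the parent, client disconnect via the Fiber context. The same idea as a reusable helper; a sketch, not part of the commit:

// tieContexts derives a context from parent that is additionally
// cancelled when done closes. Hypothetical helper, shown only to
// illustrate the pattern used in SetOpenAIRequest.
func tieContexts(parent context.Context, done <-chan struct{}) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(parent)
	go func() {
		select {
		case <-done: // request finished or client disconnected
			cancel()
		case <-ctx.Done(): // parent cancelled first: exit cleanly
		}
	}()
	return ctx, cancel
}

Selecting on both channels also lets the goroutine exit when the parent is cancelled first, instead of blocking on the Fiber context forever.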

View File

@@ -27,21 +27,43 @@ SOFTWARE.
*/
// Global variable to store the current AbortController
let currentAbortController = null;
let currentReader = null;
function toggleLoader(show) {
const loader = document.getElementById('loader');
const sendButton = document.getElementById('send-button');
const stopButton = document.getElementById('stop-button');
if (show) {
loader.style.display = 'block';
sendButton.style.display = 'none';
stopButton.style.display = 'block';
document.getElementById("input").disabled = true;
} else {
document.getElementById("input").disabled = false;
loader.style.display = 'none';
sendButton.style.display = 'block';
stopButton.style.display = 'none';
currentAbortController = null;
currentReader = null;
}
}
function stopRequest() {
if (currentAbortController) {
currentAbortController.abort();
currentAbortController = null;
}
if (currentReader) {
currentReader.cancel();
currentReader = null;
}
toggleLoader(false);
Alpine.store("chat").add(
"assistant",
`<span class='error'>Request cancelled by user</span>`,
);
}
function processThinkingTags(content) {
const thinkingRegex = /<thinking>(.*?)<\/thinking>|<think>(.*?)<\/think>/gs;
const parts = content.split(thinkingRegex);
@@ -295,8 +317,9 @@ async function promptGPT(systemPrompt, input) {
let response;
try {
// Create AbortController for timeout handling
// Create AbortController for timeout handling and stop button
const controller = new AbortController();
currentAbortController = controller; // Store globally so stop button can abort it
const timeoutId = setTimeout(() => controller.abort(), mcpMode ? 300000 : 30000); // 5 minutes for MCP, 30 seconds for regular
response = await fetch(endpoint, {
@@ -311,11 +334,20 @@ async function promptGPT(systemPrompt, input) {
clearTimeout(timeoutId);
} catch (error) {
// Don't show error if request was aborted by user (stop button)
if (error.name === 'AbortError') {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Request timeout: MCP processing is taking longer than expected. Please try again.</span>`,
);
// Check if this was a user-initiated abort (stop button was clicked)
// If currentAbortController is null, it means stopRequest() was called and already handled the UI
if (!currentAbortController) {
// User clicked stop button - error message already shown by stopRequest()
return;
} else {
// Timeout error (controller was aborted by timeout, not user)
Alpine.store("chat").add(
"assistant",
`<span class='error'>Request timeout: MCP processing is taking longer than expected. Please try again.</span>`,
);
}
} else {
Alpine.store("chat").add(
"assistant",
@@ -323,6 +355,7 @@ async function promptGPT(systemPrompt, input) {
);
}
toggleLoader(false);
currentAbortController = null;
return;
}
@@ -332,6 +365,7 @@ async function promptGPT(systemPrompt, input) {
`<span class='error'>Error: POST ${endpoint} ${response.status}</span>`,
);
toggleLoader(false);
currentAbortController = null;
return;
}
@@ -360,10 +394,15 @@ async function promptGPT(systemPrompt, input) {
// Highlight all code blocks
hljs.highlightAll();
} catch (error) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: Failed to parse MCP response</span>`,
);
// Don't show error if request was aborted by user
if (error.name !== 'AbortError' || currentAbortController) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: Failed to parse MCP response</span>`,
);
}
} finally {
currentAbortController = null;
}
} else {
// Handle regular streaming response
@@ -376,9 +415,13 @@ async function promptGPT(systemPrompt, input) {
"assistant",
`<span class='error'>Error: Failed to decode API response</span>`,
);
toggleLoader(false);
return;
}
// Store reader globally so stop button can cancel it
currentReader = reader;
// Function to add content to the chat and handle DOM updates efficiently
const addToChat = (token) => {
const chatStore = Alpine.store("chat");
@@ -479,13 +522,20 @@ async function promptGPT(systemPrompt, input) {
// Highlight all code blocks once at the end
hljs.highlightAll();
} catch (error) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: Failed to process stream</span>`,
);
// Don't show error if request was aborted by user
if (error.name !== 'AbortError' || !currentAbortController) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: Failed to process stream</span>`,
);
}
} finally {
// Perform any cleanup if necessary
reader.releaseLock();
if (reader) {
reader.releaseLock();
}
currentReader = null;
currentAbortController = null;
}
}

View File

@@ -402,15 +402,19 @@ SOFTWARE.
title="Upload text, markdown or PDF file"
></button>
<!-- Send button and loader in the same position -->
<!-- Send button and stop button in the same position -->
<div class="absolute right-3 top-4">
<!-- Loader (hidden by default) -->
<div id="loader" class="text-lg p-2" style="display: none;">
<svg class="animate-spin h-5 w-5 text-blue-500" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
</div>
<!-- Stop button (hidden by default, shown when request is in progress) -->
<button
id="stop-button"
type="button"
onclick="stopRequest()"
class="text-lg p-2 text-red-400 hover:text-red-500 transition-colors duration-200"
style="display: none;"
title="Stop request"
>
<i class="fa-solid fa-stop"></i>
</button>
<!-- Send button -->
<button

View File

@@ -178,11 +178,22 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
}
for {
// Check if context is cancelled before receiving
select {
case <-ctx.Done():
return ctx.Err()
default:
}
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
// Check if error is due to context cancellation
if ctx.Err() != nil {
return ctx.Err()
}
fmt.Println("Error", err)
return err
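
From the caller's side, the effect of these checks is that cancelling the context now ends the stream promptly and surfaces ctx.Err(); a usage sketch with hypothetical client and option values (the callback shape follows the truncated signature above and is an assumption):

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err := client.PredictStream(ctx, opts, func(reply *pb.Reply) {
	fmt.Print(string(reply.GetMessage())) // stream tokens as they arrive
})
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
	fmt.Println("stream stopped: context done")
}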