Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-11-13 17:17:07 +01:00
parent 6a15419ced
commit 2501ca3ff2

View File

@@ -569,18 +569,21 @@ public:
}
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
std::cout << "[PredictStream] Starting PredictStream request" << std::endl;
json data = parse_options(true, request, ctx_server);
std::cout << "[PredictStream] Parsed options, stream=true" << std::endl;
//Raise error if embeddings is set to true
if (ctx_server.params_base.embedding) {
std::cout << "[PredictStream] ERROR: Embedding is not supported in streaming mode" << std::endl;
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode");
}
auto completion_id = gen_chatcmplid();
std::cout << "[PredictStream] Generated completion_id: " << completion_id << std::endl;
// need to store the reader as a pointer, so that it won't be destroyed when the handle returns
const auto rd = std::make_shared<server_response_reader>(ctx_server);
std::cout << "[PredictStream] Created server_response_reader" << std::endl;
try {
std::vector<server_task> tasks;
@@ -873,25 +876,44 @@ public:
}
rd->post_tasks(std::move(tasks));
std::cout << "[PredictStream] Posted " << tasks.size() << " tasks to queue" << std::endl;
} catch (const std::exception & e) {
std::cout << "[PredictStream] EXCEPTION during task creation: " << e.what() << std::endl;
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
}
std::cout << "[PredictStream] Waiting for first result..." << std::endl;
// Get first result for error checking (following server.cpp pattern)
server_task_result_ptr first_result = rd->next([&context]() { return context->IsCancelled(); });
std::cout << "[PredictStream] Received first result, is_null=" << (first_result == nullptr) << std::endl;
if (first_result == nullptr) {
// connection is closed
std::cout << "[PredictStream] First result is nullptr, connection closed" << std::endl;
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
} else if (first_result->is_error()) {
std::cout << "[PredictStream] First result is an ERROR" << std::endl;
json error_json = first_result->to_json();
std::cout << "[PredictStream] Error JSON: " << error_json.dump() << std::endl;
backend::Reply reply;
reply.set_message(error_json.value("message", ""));
std::cout << "[PredictStream] Writing error reply to stream" << std::endl;
writer->Write(reply);
std::cout << "[PredictStream] Returning INTERNAL error status" << std::endl;
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
} else {
// Ensure first result is a completion result (partial or final)
std::cout << "[PredictStream] First result is valid, checking type..." << std::endl;
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
);
std::cout << "[PredictStream] First result type check passed" << std::endl;
}
// Process first result
std::cout << "[PredictStream] Processing first result..." << std::endl;
json first_res_json = first_result->to_json();
std::cout << "[PredictStream] First result JSON: " << first_res_json.dump(2) << std::endl;
if (first_res_json.is_array()) {
for (const auto & res : first_res_json) {
std::string completion_text = res.value("content", "");
@@ -910,7 +932,9 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
writer->Write(reply);
std::cout << "[PredictStream] Writing first result array element, message length=" << completion_text.length() << std::endl;
bool write_ok = writer->Write(reply);
std::cout << "[PredictStream] Write result: " << (write_ok ? "OK" : "FAILED") << std::endl;
}
} else {
std::string completion_text = first_res_json.value("content", "");
@@ -929,23 +953,55 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
writer->Write(reply);
std::cout << "[PredictStream] Writing first result (non-array), message length=" << completion_text.length() << std::endl;
bool write_ok = writer->Write(reply);
std::cout << "[PredictStream] Write result: " << (write_ok ? "OK" : "FAILED") << std::endl;
}
// Process subsequent results
std::cout << "[PredictStream] Starting to process subsequent results, has_next=" << rd->has_next() << std::endl;
int result_count = 0;
while (rd->has_next()) {
result_count++;
std::cout << "[PredictStream] Processing result #" << result_count << std::endl;
// Check if context is cancelled before processing result
if (context->IsCancelled()) {
std::cout << "[PredictStream] Context cancelled, breaking loop" << std::endl;
break;
}
std::cout << "[PredictStream] Calling rd->next()..." << std::endl;
auto result = rd->next([&context]() { return context->IsCancelled(); });
std::cout << "[PredictStream] Received result, is_null=" << (result == nullptr) << std::endl;
if (result == nullptr) {
// connection is closed
std::cout << "[PredictStream] Result is nullptr, connection closed, breaking" << std::endl;
break;
}
// Check for errors in subsequent results
if (result->is_error()) {
std::cout << "[PredictStream] Result #" << result_count << " is an ERROR" << std::endl;
json error_json = result->to_json();
std::cout << "[PredictStream] Error JSON: " << error_json.dump() << std::endl;
backend::Reply reply;
reply.set_message(error_json.value("message", ""));
std::cout << "[PredictStream] Writing error reply to stream" << std::endl;
writer->Write(reply);
std::cout << "[PredictStream] Returning INTERNAL error status" << std::endl;
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
} else {
// Ensure result is a completion result (partial or final)
std::cout << "[PredictStream] Result #" << result_count << " is valid, checking type..." << std::endl;
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
std::cout << "[PredictStream] Result #" << result_count << " type check passed" << std::endl;
}
json res_json = result->to_json();
std::cout << "[PredictStream] Result #" << result_count << " JSON: " << res_json.dump(2) << std::endl;
if (res_json.is_array()) {
for (const auto & res : res_json) {
std::string completion_text = res.value("content", "");
@@ -964,7 +1020,9 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
writer->Write(reply);
std::cout << "[PredictStream] Writing result #" << result_count << " array element, message length=" << completion_text.length() << std::endl;
bool write_ok = writer->Write(reply);
std::cout << "[PredictStream] Write result: " << (write_ok ? "OK" : "FAILED") << std::endl;
}
} else {
std::string completion_text = res_json.value("content", "");
@@ -983,15 +1041,20 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
writer->Write(reply);
std::cout << "[PredictStream] Writing result #" << result_count << " (non-array), message length=" << completion_text.length() << std::endl;
bool write_ok = writer->Write(reply);
std::cout << "[PredictStream] Write result: " << (write_ok ? "OK" : "FAILED") << std::endl;
}
}
std::cout << "[PredictStream] Finished processing all results, processed " << result_count << " subsequent results" << std::endl;
// Check if context was cancelled during processing
if (context->IsCancelled()) {
std::cout << "[PredictStream] Context was cancelled, returning CANCELLED status" << std::endl;
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
std::cout << "[PredictStream] Returning OK status" << std::endl;
return grpc::Status::OK;
}
@@ -1003,9 +1066,12 @@ public:
if (ctx_server.params_base.embedding) {
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in Predict mode");
}
std::cout << "[PREDICT] Starting Predict request" << std::endl;
std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl;
auto completion_id = gen_chatcmplid();
std::cout << "[PREDICT] Generated completion_id: " << completion_id << std::endl;
const auto rd = std::make_shared<server_response_reader>(ctx_server);
std::cout << "[PREDICT] Created server_response_reader" << std::endl;
try {
std::vector<server_task> tasks;
@@ -1304,24 +1370,32 @@ public:
}
rd->post_tasks(std::move(tasks));
std::cout << "[PREDICT] Posted " << tasks.size() << " tasks to queue" << std::endl;
} catch (const std::exception & e) {
std::cout << "[PREDICT] EXCEPTION during task creation: " << e.what() << std::endl;
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
}
std::cout << "[DEBUG] Waiting for results..." << std::endl;
std::cout << "[PREDICT] Waiting for all results..." << std::endl;
// Wait for all results
auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); });
std::cout << "[PREDICT] wait_for_all returned, is_terminated=" << all_results.is_terminated
<< ", has_error=" << (all_results.error != nullptr)
<< ", results_count=" << all_results.results.size() << std::endl;
if (all_results.is_terminated) {
std::cout << "[PREDICT] Request was terminated, returning CANCELLED status" << std::endl;
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
} else if (all_results.error) {
std::cout << "[DEBUG] Error in results: " << all_results.error->to_json().value("message", "") << std::endl;
reply->set_message(all_results.error->to_json().value("message", ""));
return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error occurred"));
std::cout << "[PREDICT] Error in results: " << all_results.error->to_json().value("message", "") << std::endl;
json error_json = all_results.error->to_json();
std::cout << "[PREDICT] Error JSON: " << error_json.dump() << std::endl;
reply->set_message(error_json.value("message", ""));
std::cout << "[PREDICT] Returning INTERNAL error status" << std::endl;
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
} else {
std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
std::cout << "[PREDICT] Received " << all_results.results.size() << " results" << std::endl;
if (all_results.results.size() == 1) {
// single result
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
@@ -1350,13 +1424,15 @@ public:
}
}
std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
std::cout << "[PREDICT] Predict request completed successfully" << std::endl;
// Check if context was cancelled during processing
if (context->IsCancelled()) {
std::cout << "[PREDICT] Context was cancelled, returning CANCELLED status" << std::endl;
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
}
std::cout << "[PREDICT] Returning OK status" << std::endl;
return grpc::Status::OK;
}