feat: add support to logitbias and logprobs (#7283)

* feat: add support to logprobs in results

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add support to logitbias

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-11-16 13:27:36 +01:00
committed by GitHub
parent cd7d384500
commit d7f9f3ac93
10 changed files with 385 additions and 12 deletions

View File

@@ -166,6 +166,34 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
}
}
// Extract logprobs and top_logprobs from proto and add to JSON data
// Following server.cpp pattern: logprobs maps to n_probs when provided
if (predict->logprobs() > 0) {
data["logprobs"] = predict->logprobs();
// Map logprobs to n_probs (following server.cpp line 369 pattern)
// n_probs will be set by params_from_json_cmpl if logprobs is provided
data["n_probs"] = predict->logprobs();
SRV_INF("Using logprobs: %d\n", predict->logprobs());
}
if (predict->toplogprobs() > 0) {
data["top_logprobs"] = predict->toplogprobs();
SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
}
// Extract logit_bias from proto and add to JSON data
if (!predict->logitbias().empty()) {
try {
// Parse logit_bias JSON string from proto
json logit_bias_json = json::parse(predict->logitbias());
// Add to data - llama.cpp server expects it as an object (map)
data["logit_bias"] = logit_bias_json;
SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
} catch (const json::parse_error& e) {
SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
}
}
data["ignore_eos"] = predict->ignoreeos();
data["embeddings"] = predict->embeddings();
@@ -568,6 +596,28 @@ public:
return Status::OK;
}
// Helper function to extract logprobs from JSON response
static json extract_logprobs_from_json(const json& res_json) {
json logprobs_json = json::object();
// Check for OAI-compatible format: choices[0].logprobs
if (res_json.contains("choices") && res_json["choices"].is_array() &&
res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
logprobs_json = res_json["choices"][0]["logprobs"];
}
// Check for non-OAI format: completion_probabilities
else if (res_json.contains("completion_probabilities")) {
// Convert completion_probabilities to OAI format
logprobs_json["content"] = res_json["completion_probabilities"];
}
// Check for direct logprobs field
else if (res_json.contains("logprobs")) {
logprobs_json = res_json["logprobs"];
}
return logprobs_json;
}
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
json data = parse_options(true, request, ctx_server);
@@ -915,6 +965,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
writer->Write(reply);
}
} else {
@@ -934,6 +991,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(first_res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
writer->Write(reply);
}
@@ -969,6 +1033,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
writer->Write(reply);
}
} else {
@@ -988,6 +1059,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
writer->Write(reply);
}
}
@@ -1335,28 +1413,54 @@ public:
if (all_results.results.size() == 1) {
// single result
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
reply->set_message(all_results.results[0]->to_json().value("content", ""));
json result_json = all_results.results[0]->to_json();
reply->set_message(result_json.value("content", ""));
int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0);
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
reply->set_tokens(tokens_predicted);
int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0);
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
reply->set_prompt_tokens(tokens_evaluated);
if (all_results.results[0]->to_json().contains("timings")) {
double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0);
if (result_json.contains("timings")) {
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
reply->set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0);
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
reply->set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(result_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply->set_logprobs(logprobs_str);
}
} else {
// multiple results (multitask)
json arr = json::array();
json logprobs_arr = json::array();
bool has_logprobs = false;
for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
arr.push_back(res->to_json().value("content", ""));
json res_json = res->to_json();
arr.push_back(res_json.value("content", ""));
// Extract logprobs for each result
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
has_logprobs = true;
logprobs_arr.push_back(logprobs_json);
} else {
logprobs_arr.push_back(json::object());
}
}
reply->set_message(arr);
// Set logprobs if any result has them
if (has_logprobs) {
std::string logprobs_str = logprobs_arr.dump();
reply->set_logprobs(logprobs_str);
}
}
}