Mirror of https://github.com/mudler/LocalAI.git (synced 2026-01-04 09:40:32 -06:00)
feat: add support to logitbias and logprobs (#7283)
* feat: add support to logprobs in results

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add support to logitbias

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Committed by GitHub. Parent commit: cd7d384500. This commit: d7f9f3ac93.
@@ -166,6 +166,34 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
             SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
         }
     }
+
+    // Extract logprobs and top_logprobs from proto and add to JSON data
+    // Following server.cpp pattern: logprobs maps to n_probs when provided
+    if (predict->logprobs() > 0) {
+        data["logprobs"] = predict->logprobs();
+        // Map logprobs to n_probs (following server.cpp line 369 pattern)
+        // n_probs will be set by params_from_json_cmpl if logprobs is provided
+        data["n_probs"] = predict->logprobs();
+        SRV_INF("Using logprobs: %d\n", predict->logprobs());
+    }
+    if (predict->toplogprobs() > 0) {
+        data["top_logprobs"] = predict->toplogprobs();
+        SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
+    }
+
+    // Extract logit_bias from proto and add to JSON data
+    if (!predict->logitbias().empty()) {
+        try {
+            // Parse logit_bias JSON string from proto
+            json logit_bias_json = json::parse(predict->logitbias());
+            // Add to data - llama.cpp server expects it as an object (map)
+            data["logit_bias"] = logit_bias_json;
+            SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
+        } catch (const json::parse_error& e) {
+            SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
+        }
+    }
+
     data["ignore_eos"] = predict->ignoreeos();
     data["embeddings"] = predict->embeddings();
 
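Two details of the mapping above are easy to miss: the logprobs value is written under both "logprobs" and "n_probs" (the key llama.cpp's completion params actually read), and logit_bias travels through the proto as a raw JSON string that must parse to an object. Below is a minimal standalone sketch of the same mapping, not part of the patch, with an invented bias payload (token id -> bias):

    // Sketch only: mirrors the proto -> JSON mapping above with nlohmann::json.
    #include <nlohmann/json.hpp>
    #include <iostream>
    #include <string>

    using json = nlohmann::json;

    int main() {
        json data;
        int logprobs = 5;             // stand-in for predict->logprobs()
        data["logprobs"] = logprobs;
        data["n_probs"]  = logprobs;  // alias consumed on the llama.cpp side

        // Hypothetical payload: suppress token 15043, boost token 1024.
        std::string raw = R"({"15043": -100.0, "1024": 2.0})";
        try {
            data["logit_bias"] = json::parse(raw);  // must parse to a JSON object
        } catch (const json::parse_error& e) {
            std::cerr << "invalid logit_bias: " << e.what() << '\n';
        }
        std::cout << data.dump(2) << '\n';
        return 0;
    }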
@@ -568,6 +596,28 @@ public:
         return Status::OK;
     }
 
+    // Helper function to extract logprobs from JSON response
+    static json extract_logprobs_from_json(const json& res_json) {
+        json logprobs_json = json::object();
+
+        // Check for OAI-compatible format: choices[0].logprobs
+        if (res_json.contains("choices") && res_json["choices"].is_array() &&
+            res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
+            logprobs_json = res_json["choices"][0]["logprobs"];
+        }
+        // Check for non-OAI format: completion_probabilities
+        else if (res_json.contains("completion_probabilities")) {
+            // Convert completion_probabilities to OAI format
+            logprobs_json["content"] = res_json["completion_probabilities"];
+        }
+        // Check for direct logprobs field
+        else if (res_json.contains("logprobs")) {
+            logprobs_json = res_json["logprobs"];
+        }
+
+        return logprobs_json;
+    }
+
     grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
         json data = parse_options(true, request, ctx_server);
 
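The helper normalizes three response shapes into one OAI-style object. A rough sketch of the three inputs it accepts (a fragment, assuming #include <cassert>, nlohmann::json, and the helper above in scope; the token payloads are invented):

    // Illustrative inputs for the three branches of extract_logprobs_from_json().
    json oai    = json::parse(R"({"choices":[{"logprobs":{"content":[{"token":"Hi","logprob":-0.1}]}}]})");
    json native = json::parse(R"({"completion_probabilities":[{"token":"Hi","logprob":-0.1}]})");
    json flat   = json::parse(R"({"logprobs":{"content":[{"token":"Hi","logprob":-0.1}]}})");

    // All three normalize to the same object carrying a "content" array:
    assert(extract_logprobs_from_json(oai)    == extract_logprobs_from_json(native));
    assert(extract_logprobs_from_json(native) == extract_logprobs_from_json(flat));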
@@ -915,6 +965,13 @@ public:
                 reply.set_timing_token_generation(timing_token_generation);
             }
 
+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(res);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply.set_logprobs(logprobs_str);
+            }
+
             writer->Write(reply);
         }
     } else {
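Since reply.set_logprobs() stores the dumped JSON string, a gRPC client recovers the structure by parsing the field back. A hypothetical client-side fragment (assumes a received backend::Reply named reply, nlohmann::json, and #include <iostream>; the same decoding applies to the three analogous sites below):

    // Hypothetical consumer of the logprobs field set above.
    if (!reply.logprobs().empty()) {
        json logprobs = json::parse(reply.logprobs());
        // OAI-style: each "content" entry pairs a token with its log-probability.
        for (const auto& tok : logprobs.value("content", json::array())) {
            std::cout << tok.value("token", "") << " -> "
                      << tok.value("logprob", 0.0) << '\n';
        }
    }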
@@ -934,6 +991,13 @@ public:
                 reply.set_timing_token_generation(timing_token_generation);
             }
 
+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(first_res_json);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply.set_logprobs(logprobs_str);
+            }
+
             writer->Write(reply);
         }
 
@@ -969,6 +1033,13 @@ public:
                 reply.set_timing_token_generation(timing_token_generation);
             }
 
+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(res);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply.set_logprobs(logprobs_str);
+            }
+
             writer->Write(reply);
         }
     } else {
@@ -988,6 +1059,13 @@ public:
                 reply.set_timing_token_generation(timing_token_generation);
             }
 
+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(res_json);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply.set_logprobs(logprobs_str);
+            }
+
             writer->Write(reply);
         }
     }
@@ -1335,28 +1413,54 @@ public:
         if (all_results.results.size() == 1) {
             // single result
             GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
-            reply->set_message(all_results.results[0]->to_json().value("content", ""));
+            json result_json = all_results.results[0]->to_json();
+            reply->set_message(result_json.value("content", ""));
 
-            int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0);
+            int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
             reply->set_tokens(tokens_predicted);
-            int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0);
+            int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
             reply->set_prompt_tokens(tokens_evaluated);
 
-            if (all_results.results[0]->to_json().contains("timings")) {
-                double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0);
+            if (result_json.contains("timings")) {
+                double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
                 reply->set_timing_prompt_processing(timing_prompt_processing);
-                double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0);
+                double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
                 reply->set_timing_token_generation(timing_token_generation);
             }
 
+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(result_json);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply->set_logprobs(logprobs_str);
+            }
+
         } else {
             // multiple results (multitask)
             json arr = json::array();
+            json logprobs_arr = json::array();
+            bool has_logprobs = false;
            for (auto & res : all_results.results) {
                 GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
-                arr.push_back(res->to_json().value("content", ""));
+                json res_json = res->to_json();
+                arr.push_back(res_json.value("content", ""));
+
+                // Extract logprobs for each result
+                json logprobs_json = extract_logprobs_from_json(res_json);
+                if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                    has_logprobs = true;
+                    logprobs_arr.push_back(logprobs_json);
+                } else {
+                    logprobs_arr.push_back(json::object());
+                }
             }
             reply->set_message(arr);
+
+            // Set logprobs if any result has them
+            if (has_logprobs) {
+                std::string logprobs_str = logprobs_arr.dump();
+                reply->set_logprobs(logprobs_str);
+            }
         }
     }
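In the multitask branch, the logprobs field ends up as a JSON array aligned index-for-index with reply->message(), with an empty object standing in where a result produced no logprobs. A hypothetical reply->logprobs() value for three results where only the first two carried logprobs (values invented):

    // [
    //   {"content":[{"token":"Hello","logprob":-0.02}]},
    //   {"content":[{"token":"Hi","logprob":-0.35}]},
    //   {}
    // ]
    // The empty-object placeholder keeps positions aligned with the messages.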