Mirror of https://github.com/mudler/LocalAI.git (synced 2025-12-30 22:20:20 -06:00)
feat: add support to logitbias and logprobs (#7283)
* feat: add support to logprobs in results

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add support to logitbias

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Commit d7f9f3ac93 (parent cd7d384500), committed via GitHub.
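
For context, here is a minimal sketch of what this change enables at the OpenAI-compatible HTTP API, using the same go-openai client that the test suite further down in this diff uses. The base URL, model name, and the biased token ID are placeholders for illustration, not values mandated by the commit.

package main

import (
    "context"
    "fmt"

    openai "github.com/sashabaranov/go-openai"
)

func main() {
    // Point the client at a local LocalAI instance (placeholder URL and model name).
    cfg := openai.DefaultConfig("not-needed")
    cfg.BaseURL = "http://localhost:8080/v1"
    client := openai.NewClientWithConfig(cfg)

    resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
        Model:       "testmodel.ggml",
        LogProbs:    true,                       // logprobs: return log probabilities per output token
        TopLogProbs: 3,                          // top_logprobs: up to 3 alternatives per position
        LogitBias:   map[string]int{"15043": 1}, // token ID (string) -> bias value; example ID only
        Messages:    []openai.ChatCompletionMessage{{Role: "user", Content: "Say hello"}},
    })
    if err != nil {
        panic(err)
    }

    // Each returned token carries its logprob plus the requested top alternatives.
    if resp.Choices[0].LogProbs != nil {
        for _, lp := range resp.Choices[0].LogProbs.Content {
            fmt.Printf("token=%q logprob=%f alternatives=%d\n", lp.Token, lp.LogProb, len(lp.TopLogProbs))
        }
    }
}
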
@@ -156,6 +156,8 @@ message PredictOptions {
string CorrelationId = 47;
string Tools = 48; // JSON array of available tools/functions for tool calling
string ToolChoice = 49; // JSON string or object specifying tool choice behavior
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
}

// The response message containing the result

@@ -166,6 +168,7 @@ message Reply {
double timing_prompt_processing = 4;
double timing_token_generation = 5;
bytes audio = 6;
bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
}

message GrammarTrigger {

@@ -166,6 +166,34 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
}
}

// Extract logprobs and top_logprobs from proto and add to JSON data
// Following server.cpp pattern: logprobs maps to n_probs when provided
if (predict->logprobs() > 0) {
data["logprobs"] = predict->logprobs();
// Map logprobs to n_probs (following server.cpp line 369 pattern)
// n_probs will be set by params_from_json_cmpl if logprobs is provided
data["n_probs"] = predict->logprobs();
SRV_INF("Using logprobs: %d\n", predict->logprobs());
}
if (predict->toplogprobs() > 0) {
data["top_logprobs"] = predict->toplogprobs();
SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
}

// Extract logit_bias from proto and add to JSON data
if (!predict->logitbias().empty()) {
try {
// Parse logit_bias JSON string from proto
json logit_bias_json = json::parse(predict->logitbias());
// Add to data - llama.cpp server expects it as an object (map)
data["logit_bias"] = logit_bias_json;
SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
} catch (const json::parse_error& e) {
SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
}
}

data["ignore_eos"] = predict->ignoreeos();
data["embeddings"] = predict->embeddings();
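
To make the mapping concrete, here is a small illustrative sketch (values are made up, not taken from the commit) of the request JSON handed to the llama.cpp server after parse_options() has copied the proto fields; note how the OpenAI-style logprobs count is mirrored into llama.cpp's n_probs field.

package main

import (
    "encoding/json"
    "fmt"
)

func main() {
    // Hypothetical slice of the parsed request body built by parse_options()
    // for PredictOptions{Logprobs: 4, TopLogprobs: 2, LogitBias: `{"15043": 5}`}.
    data := map[string]any{
        "logprobs":     4,                          // from PredictOptions.Logprobs
        "n_probs":      4,                          // mirrored from logprobs (server.cpp pattern)
        "top_logprobs": 2,                          // from PredictOptions.TopLogprobs
        "logit_bias":   map[string]int{"15043": 5}, // parsed from the LogitBias JSON string
        "ignore_eos":   false,
        "embeddings":   false,
    }
    out, _ := json.MarshalIndent(data, "", "  ")
    fmt.Println(string(out))
}
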
@@ -568,6 +596,28 @@ public:
return Status::OK;
}

// Helper function to extract logprobs from JSON response
static json extract_logprobs_from_json(const json& res_json) {
json logprobs_json = json::object();

// Check for OAI-compatible format: choices[0].logprobs
if (res_json.contains("choices") && res_json["choices"].is_array() &&
res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
logprobs_json = res_json["choices"][0]["logprobs"];
}
// Check for non-OAI format: completion_probabilities
else if (res_json.contains("completion_probabilities")) {
// Convert completion_probabilities to OAI format
logprobs_json["content"] = res_json["completion_probabilities"];
}
// Check for direct logprobs field
else if (res_json.contains("logprobs")) {
logprobs_json = res_json["logprobs"];
}

return logprobs_json;
}
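
The helper above accepts three different response shapes and normalizes them to the OpenAI logprobs object. As an illustration only (not code from this commit), here is the same normalization expressed in Go, with hypothetical payloads:

package main

import (
    "encoding/json"
    "fmt"
)

// normalizeLogprobs mirrors the C++ extract_logprobs_from_json helper: it accepts
// an OAI-style response (choices[0].logprobs), a raw completion_probabilities
// response, or a bare logprobs object, and returns the OpenAI-format logprobs object.
func normalizeLogprobs(res map[string]json.RawMessage) json.RawMessage {
    // OAI-compatible format: choices[0].logprobs
    if raw, ok := res["choices"]; ok {
        var choices []map[string]json.RawMessage
        if err := json.Unmarshal(raw, &choices); err == nil && len(choices) > 0 {
            if lp, ok := choices[0]["logprobs"]; ok {
                return lp
            }
        }
    }
    // Non-OAI format: wrap completion_probabilities as {"content": [...]}
    if cp, ok := res["completion_probabilities"]; ok {
        wrapped, _ := json.Marshal(map[string]json.RawMessage{"content": cp})
        return wrapped
    }
    // Direct logprobs field
    if lp, ok := res["logprobs"]; ok {
        return lp
    }
    return nil
}

func main() {
    // Hypothetical backend responses (token strings and values are made up).
    samples := []string{
        `{"choices":[{"logprobs":{"content":[{"token":"Hello","logprob":-0.1}]}}]}`,
        `{"completion_probabilities":[{"token":"Hello","logprob":-0.1}]}`,
        `{"logprobs":{"content":[{"token":"Hello","logprob":-0.1}]}}`,
    }
    for _, s := range samples {
        var res map[string]json.RawMessage
        _ = json.Unmarshal([]byte(s), &res)
        fmt.Println(string(normalizeLogprobs(res)))
    }
}
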
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
json data = parse_options(true, request, ctx_server);

@@ -915,6 +965,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}

writer->Write(reply);
}
} else {

@@ -934,6 +991,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(first_res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}

writer->Write(reply);
}

@@ -969,6 +1033,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}

writer->Write(reply);
}
} else {

@@ -988,6 +1059,13 @@ public:
reply.set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}

writer->Write(reply);
}
}

@@ -1335,28 +1413,54 @@ public:
if (all_results.results.size() == 1) {
// single result
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
reply->set_message(all_results.results[0]->to_json().value("content", ""));
json result_json = all_results.results[0]->to_json();
reply->set_message(result_json.value("content", ""));

int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0);
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
reply->set_tokens(tokens_predicted);
int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0);
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
reply->set_prompt_tokens(tokens_evaluated);

if (all_results.results[0]->to_json().contains("timings")) {
double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0);
if (result_json.contains("timings")) {
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
reply->set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0);
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
reply->set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(result_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply->set_logprobs(logprobs_str);
}

} else {
// multiple results (multitask)
json arr = json::array();
json logprobs_arr = json::array();
bool has_logprobs = false;
for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
arr.push_back(res->to_json().value("content", ""));
json res_json = res->to_json();
arr.push_back(res_json.value("content", ""));

// Extract logprobs for each result
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
has_logprobs = true;
logprobs_arr.push_back(logprobs_json);
} else {
logprobs_arr.push_back(json::object());
}
}
reply->set_message(arr);

// Set logprobs if any result has them
if (has_logprobs) {
std::string logprobs_str = logprobs_arr.dump();
reply->set_logprobs(logprobs_str);
}
}
}
@@ -2,6 +2,7 @@ package backend

import (
"context"
"encoding/json"
"regexp"
"slices"
"strings"

@@ -24,6 +25,7 @@ type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
AudioOutput string
Logprobs *schema.Logprobs // Logprobs from the backend response
}

type TokenUsage struct {

@@ -33,7 +35,7 @@ type TokenUsage struct {
TimingTokenGeneration float64
}

func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
modelFile := c.Model

// Check if the modelFile exists, if it doesn't try to load it from the gallery

@@ -78,6 +80,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
opts.Audios = audios
opts.Tools = tools
opts.ToolChoice = toolChoice
if logprobs != nil {
opts.Logprobs = int32(*logprobs)
}
if topLogprobs != nil {
opts.TopLogprobs = int32(*topLogprobs)
}
if len(logitBias) > 0 {
// Serialize logit_bias map to JSON string for proto
logitBiasJSON, err := json.Marshal(logitBias)
if err == nil {
opts.LogitBias = string(logitBiasJSON)
}
}

tokenUsage := TokenUsage{}

@@ -109,6 +124,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
}

ss := ""
var logprobs *schema.Logprobs

var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {

@@ -120,6 +136,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

// Parse logprobs from reply if present (collect from last chunk that has them)
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}

// Process complete runes and accumulate them
var completeRunes []byte
for len(partialRune) > 0 {

@@ -145,6 +169,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
return LLMResponse{
Response: ss,
Usage: tokenUsage,
Logprobs: logprobs,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?

@@ -167,9 +192,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
response = c.TemplateConfig.ReplyPrefix + response
}

// Parse logprobs from reply if present
var logprobs *schema.Logprobs
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}

return LLMResponse{
Response: response,
Usage: tokenUsage,
Logprobs: logprobs,
}, err
}
}

@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
}
}

return &pb.PredictOptions{
pbOpts := &pb.PredictOptions{
Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP),
NDraft: c.NDraft,

@@ -249,4 +249,6 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
TailFreeSamplingZ: float32(*c.TFZ),
TypicalP: float32(*c.TypicalP),
}
// Logprobs and TopLogprobs are set by the caller if provided
return pbOpts
}
@@ -816,6 +816,83 @@ var _ = Describe("API test", func() {
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})

It("returns logprobs in chat completions when requested", func() {
topLogprobsVal := 3
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
LogProbs: true,
TopLogProbs: topLogprobsVal,
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())

Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())

// Verify logprobs are present and have correct structure
Expect(response.Choices[0].LogProbs).ToNot(BeNil())
Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())

Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))

foundatLeastToken := ""
foundAtLeastBytes := []byte{}
foundAtLeastTopLogprobBytes := []byte{}
foundatLeastTopLogprob := ""
// Verify logprobs content structure matches OpenAI format
for _, logprobContent := range response.Choices[0].LogProbs.Content {
// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
if len(logprobContent.Bytes) > 0 {
foundAtLeastBytes = logprobContent.Bytes
}
if len(logprobContent.Token) > 0 {
foundatLeastToken = logprobContent.Token
}
Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))

// If top_logprobs is requested, verify top_logprobs array respects the limit
if len(logprobContent.TopLogProbs) > 0 {
// Should respect top_logprobs limit (3 in this test)
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
for _, topLogprob := range logprobContent.TopLogProbs {
if len(topLogprob.Bytes) > 0 {
foundAtLeastTopLogprobBytes = topLogprob.Bytes
}
if len(topLogprob.Token) > 0 {
foundatLeastTopLogprob = topLogprob.Token
}
Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
}
}
}

Expect(foundAtLeastBytes).ToNot(BeEmpty())
Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
Expect(foundatLeastToken).ToNot(BeEmpty())
Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
})

It("applies logit_bias to chat completions when requested", func() {
// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
logitBias := map[string]int{
"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
}
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
LogitBias: logitBias,
})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// If logit_bias is applied, the response should be generated successfully
// We can't easily verify the bias effect without knowing the actual token IDs for the model,
// but the fact that the request succeeds confirms the API accepts and processes logit_bias
})

It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
Expect(err).To(HaveOccurred())
@@ -635,7 +635,32 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in
}
}

predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON)
// Extract logprobs from request
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
var logprobs *int
var topLogprobs *int
if input.Logprobs.IsEnabled() {
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
if input.TopLogprobs != nil {
topLogprobs = input.TopLogprobs
// For backend compatibility, set logprobs to the top_logprobs value
logprobs = input.TopLogprobs
} else {
// Default to 1 if logprobs is true but top_logprobs not specified
val := 1
logprobs = &val
topLogprobs = &val
}
}

// Extract logit_bias from request
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
var logitBias map[string]float64
if len(input.LogitBias) > 0 {
logitBias = input.LogitBias
}

predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err

@@ -55,9 +55,34 @@ func ComputeChoices(
}
}

// Extract logprobs from request
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
var logprobs *int
var topLogprobs *int
if req.Logprobs.IsEnabled() {
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
if req.TopLogprobs != nil {
topLogprobs = req.TopLogprobs
// For backend compatibility, set logprobs to the top_logprobs value
logprobs = req.TopLogprobs
} else {
// Default to 1 if logprobs is true but top_logprobs not specified
val := 1
logprobs = &val
topLogprobs = &val
}
}

// Extract logit_bias from request
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
var logitBias map[string]float64
if len(req.LogitBias) > 0 {
logitBias = req.LogitBias
}

// get the model function to call for the result
predFunc, err := backend.ModelInference(
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON)
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
if err != nil {
return result, backend.TokenUsage{}, err
}

@@ -78,6 +103,11 @@ func ComputeChoices(
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)

// Add logprobs to the last choice if present
if prediction.Logprobs != nil && len(result) > 0 {
result[len(result)-1].Logprobs = prediction.Logprobs
}

//result = append(result, Choice{Text: prediction})

}

@@ -1087,7 +1087,7 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
// For example, the model might return a special token or JSON indicating a function call

/*
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil, "", "", nil, nil, nil)

result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !shouldUseFn {
@@ -54,6 +54,19 @@ type Choice struct {
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
Text string `json:"text,omitempty"`
Logprobs *Logprobs `json:"logprobs,omitempty"`
}

type Logprobs struct {
Content []LogprobContent `json:"content,omitempty"`
}

type LogprobContent struct {
ID int32 `json:"id"`
Token string `json:"token"`
Bytes []int `json:"bytes,omitempty"`
Logprob float64 `json:"logprob"`
TopLogprobs []LogprobContent `json:"top_logprobs,omitempty"`
}
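
As an illustration of the wire format these structs map to, here is a self-contained sketch with made-up token values; the struct definitions are local copies of the schema types above so the example compiles on its own.

package main

import (
    "encoding/json"
    "fmt"
)

// Local mirrors of the Logprobs / LogprobContent schema structs above
// (field tags copied verbatim) so the example is self-contained.
type Logprobs struct {
    Content []LogprobContent `json:"content,omitempty"`
}

type LogprobContent struct {
    ID          int32            `json:"id"`
    Token       string           `json:"token"`
    Bytes       []int            `json:"bytes,omitempty"`
    Logprob     float64          `json:"logprob"`
    TopLogprobs []LogprobContent `json:"top_logprobs,omitempty"`
}

func main() {
    // Hypothetical OpenAI-format logprobs payload (tokens and values are made up).
    payload := `{
      "content": [
        {
          "token": "Hello",
          "bytes": [72, 101, 108, 108, 111],
          "logprob": -0.12,
          "top_logprobs": [
            {"token": "Hello", "bytes": [72, 101, 108, 108, 111], "logprob": -0.12},
            {"token": "Hi", "bytes": [72, 105], "logprob": -2.3}
          ]
        }
      ]
    }`

    var lp Logprobs
    if err := json.Unmarshal([]byte(payload), &lp); err != nil {
        panic(err)
    }
    fmt.Printf("%d token(s); first token %q has %d alternatives\n",
        len(lp.Content), lp.Content[0].Token, len(lp.Content[0].TopLogprobs))
}
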
type Content struct {

@@ -1,5 +1,82 @@
package schema

import (
"encoding/json"

"gopkg.in/yaml.v3"
)

// LogprobsValue represents the logprobs parameter which is a boolean.
// According to OpenAI API: true means return log probabilities, false/null means don't return them.
// The actual number of top logprobs per token is controlled by top_logprobs (0-5).
type LogprobsValue struct {
Enabled bool // true if logprobs should be returned
}

// UnmarshalJSON implements json.Unmarshaler to handle boolean
func (l *LogprobsValue) UnmarshalJSON(data []byte) error {
// Try to unmarshal as boolean
var b bool
if err := json.Unmarshal(data, &b); err == nil {
l.Enabled = b
return nil
}

// If it's null, set to false
var n *bool
if err := json.Unmarshal(data, &n); err == nil {
l.Enabled = false
return nil
}

// Try as integer for backward compatibility (treat > 0 as true)
var i int
if err := json.Unmarshal(data, &i); err == nil {
l.Enabled = i > 0
return nil
}

return json.Unmarshal(data, &l.Enabled)
}

// MarshalJSON implements json.Marshaler
func (l LogprobsValue) MarshalJSON() ([]byte, error) {
return json.Marshal(l.Enabled)
}

// UnmarshalYAML implements yaml.Unmarshaler to handle boolean
func (l *LogprobsValue) UnmarshalYAML(value *yaml.Node) error {
switch value.Kind {
case yaml.ScalarNode:
switch value.Tag {
case "!!bool":
var b bool
if err := value.Decode(&b); err != nil {
return err
}
l.Enabled = b
return nil
case "!!int":
// For backward compatibility, treat integer > 0 as true
var i int
if err := value.Decode(&i); err != nil {
return err
}
l.Enabled = i > 0
return nil
case "!!null":
l.Enabled = false
return nil
}
}
return value.Decode(&l.Enabled)
}

// IsEnabled returns true if logprobs should be returned
func (l *LogprobsValue) IsEnabled() bool {
return l.Enabled
}
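
A short sketch of the inputs this type accepts. The stand-in below copies only the JSON path of the unmarshalling logic above (the real schema.LogprobsValue additionally handles YAML), so the example is self-contained and the request struct is purely illustrative.

package main

import (
    "encoding/json"
    "fmt"
)

// Minimal stand-in for schema.LogprobsValue, JSON handling only.
type LogprobsValue struct{ Enabled bool }

func (l *LogprobsValue) UnmarshalJSON(data []byte) error {
    var b bool
    if err := json.Unmarshal(data, &b); err == nil { // "logprobs": true / false / null
        l.Enabled = b
        return nil
    }
    var i int
    if err := json.Unmarshal(data, &i); err == nil { // legacy integer form, e.g. "logprobs": 3
        l.Enabled = i > 0
        return nil
    }
    l.Enabled = false // anything else: treat as disabled
    return nil
}

// Hypothetical request shape mirroring the logprobs-related fields.
type request struct {
    Logprobs    LogprobsValue `json:"logprobs,omitempty"`
    TopLogprobs *int          `json:"top_logprobs,omitempty"`
}

func main() {
    for _, body := range []string{
        `{"logprobs": true, "top_logprobs": 3}`, // OpenAI-style boolean
        `{"logprobs": 2}`,                       // legacy integer form, treated as enabled
        `{"logprobs": null}`,                    // explicitly disabled
        `{}`,                                    // omitted, zero value => disabled
    } {
        var r request
        _ = json.Unmarshal([]byte(body), &r)
        fmt.Printf("%-40s enabled=%v\n", body, r.Logprobs.Enabled)
    }
}
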
// @Description PredictionOptions contains prediction parameters for model inference
type PredictionOptions struct {

@@ -38,6 +115,13 @@ type PredictionOptions struct {
TypicalP *float64 `json:"typical_p,omitempty" yaml:"typical_p,omitempty"`
Seed *int `json:"seed,omitempty" yaml:"seed,omitempty"`

// OpenAI API logprobs parameters
// logprobs: boolean - if true, returns log probabilities of each output token
// top_logprobs: integer 0-20 - number of most likely tokens to return at each token position
Logprobs LogprobsValue `json:"logprobs,omitempty" yaml:"logprobs,omitempty"` // Whether to return log probabilities (true/false)
TopLogprobs *int `json:"top_logprobs,omitempty" yaml:"top_logprobs,omitempty"` // Number of top logprobs per token (0-20)
LogitBias map[string]float64 `json:"logit_bias,omitempty" yaml:"logit_bias,omitempty"` // Map of token IDs to bias values (-100 to 100)

NegativePrompt string `json:"negative_prompt,omitempty" yaml:"negative_prompt,omitempty"`
RopeFreqBase float32 `json:"rope_freq_base,omitempty" yaml:"rope_freq_base,omitempty"`
RopeFreqScale float32 `json:"rope_freq_scale,omitempty" yaml:"rope_freq_scale,omitempty"`