fix(llama.cpp): correctly set grammar triggers (#6432)

* fix(llama.cpp): correctly set grammar triggers

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Do not enable lazy by default

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Author:    Ettore Di Giacinto
Date:      2025-10-10 19:50:17 +02:00
Committed: by GitHub
Parent:    81b31b4283
Commit:    cd1e1124ea

3 changed files with 68 additions and 17 deletions

File 1 of 3: llama.cpp backend gRPC server (C++)

@@ -92,7 +92,7 @@ static void start_llama_server(server_context& ctx_server) {
     ctx_server.queue_tasks.start_loop();
 }
 
-json parse_options(bool streaming, const backend::PredictOptions* predict)
+json parse_options(bool streaming, const backend::PredictOptions* predict, const server_context& ctx_server)
 {
     // Create now a json data from the prediction options instead
@@ -147,6 +147,28 @@ json parse_options(bool streaming, const backend::PredictOptions* predict)
     // data["n_probs"] = predict->nprobs();
     //TODO: images,
 
+    // Serialize grammar triggers from server context to JSON array
+    if (!ctx_server.params_base.sampling.grammar_triggers.empty()) {
+        json grammar_triggers = json::array();
+        for (const auto& trigger : ctx_server.params_base.sampling.grammar_triggers) {
+            json trigger_json;
+            trigger_json["value"] = trigger.value;
+            // Always serialize as WORD type since upstream converts WORD to TOKEN internally
+            trigger_json["type"] = static_cast<int>(COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
+            grammar_triggers.push_back(trigger_json);
+        }
+        data["grammar_triggers"] = grammar_triggers;
+    }
+
+    // Serialize preserved tokens from server context to JSON array
+    if (!ctx_server.params_base.sampling.preserved_tokens.empty()) {
+        json preserved_tokens = json::array();
+        for (const auto& token : ctx_server.params_base.sampling.preserved_tokens) {
+            preserved_tokens.push_back(common_token_to_piece(ctx_server.ctx, token));
+        }
+        data["preserved_tokens"] = preserved_tokens;
+    }
+
     return data;
 }
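
For illustration only (not part of the commit): assuming a single registered trigger word "<tool_call>" (a hypothetical example value), the fields that parse_options() now injects into each request would look roughly like the sketch below. The numeric "type" value assumes the upstream enum order in which COMMON_GRAMMAR_TRIGGER_TYPE_WORD is 1.

// Sketch of the JSON shape produced by the new serialization code.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json data;
    data["grammar_triggers"] = nlohmann::json::array({
        // "type" is always WORD here; the upstream server promotes
        // single-token WORD triggers to TOKEN triggers internally.
        { { "value", "<tool_call>" }, { "type", 1 } } // 1 == assumed value of COMMON_GRAMMAR_TRIGGER_TYPE_WORD
    });
    // preserved tokens travel as text pieces, not token ids
    data["preserved_tokens"] = nlohmann::json::array({ "<tool_call>" });
    std::cout << data.dump(2) << "\n";
    return 0;
}
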
@@ -207,7 +229,7 @@ static void add_rpc_devices(std::string servers) {
     }
 }
 
-static void params_parse(const backend::ModelOptions* request,
+static void params_parse(server_context& ctx_server, const backend::ModelOptions* request,
                          common_params & params) {
 
     // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
@@ -347,14 +369,14 @@ static void params_parse(const backend::ModelOptions* request,
     }
 
     if (request->grammartriggers_size() > 0) {
-        params.sampling.grammar_lazy = true;
+        //params.sampling.grammar_lazy = true;
+        // Store grammar trigger words for processing after model is loaded
         for (int i = 0; i < request->grammartriggers_size(); i++) {
+            const auto & word = request->grammartriggers(i).word();
             common_grammar_trigger trigger;
-            trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
-            trigger.value = request->grammartriggers(i).word();
-            // trigger.at_start = request->grammartriggers(i).at_start();
-            params.sampling.grammar_triggers.push_back(trigger);
+            trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
+            trigger.value = word;
+            params.sampling.grammar_triggers.push_back(std::move(trigger));
         }
     }
 }
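
Two things are worth noting in this hunk. First, grammar_lazy is no longer forced on whenever trigger words are present (the "Do not enable lazy by default" commit in this squash); lazy grammar activation is left at its default unless requested elsewhere. Second, triggers are stored as WORD type at this point because the vocabulary is not loaded yet, so single-token promotion cannot happen here. For reference, a paraphrase of the upstream llama.cpp types being filled (from common/common.h; the exact member set may differ across versions):

// Paraphrased from llama.cpp common/common.h; may differ by version.
enum common_grammar_trigger_type {
    COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, // match a single token id
    COMMON_GRAMMAR_TRIGGER_TYPE_WORD,  // match a literal string
    // ... pattern-based trigger types elided
};

struct common_grammar_trigger {
    common_grammar_trigger_type type;
    std::string value;                    // trigger word (or pattern)
    llama_token token = LLAMA_TOKEN_NULL; // only meaningful for TOKEN triggers
};
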
@@ -377,7 +399,7 @@ public:
     grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
         // Implement LoadModel RPC
         common_params params;
-        params_parse(request, params);
+        params_parse(ctx_server, request, params);
 
         common_init();
@@ -396,6 +418,39 @@ public:
             return Status::CANCELLED;
         }
 
+        // Process grammar triggers now that vocab is available
+        if (!params.sampling.grammar_triggers.empty()) {
+            std::vector<common_grammar_trigger> processed_triggers;
+            for (const auto& trigger : params.sampling.grammar_triggers) {
+                if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+                    auto ids = common_tokenize(ctx_server.vocab, trigger.value, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        auto token = ids[0];
+                        // Add the token to preserved_tokens if not already present
+                        if (params.sampling.preserved_tokens.find(token) == params.sampling.preserved_tokens.end()) {
+                            params.sampling.preserved_tokens.insert(token);
+                            LOG_INF("Added grammar trigger token to preserved tokens: %d (`%s`)\n", token, trigger.value.c_str());
+                        }
+                        LOG_INF("Grammar trigger token: %d (`%s`)\n", token, trigger.value.c_str());
+                        common_grammar_trigger processed_trigger;
+                        processed_trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
+                        processed_trigger.value = trigger.value;
+                        processed_trigger.token = token;
+                        processed_triggers.push_back(std::move(processed_trigger));
+                    } else {
+                        LOG_INF("Grammar trigger word: `%s`\n", trigger.value.c_str());
+                        processed_triggers.push_back(trigger);
+                    }
+                } else {
+                    processed_triggers.push_back(trigger);
+                }
+            }
+            // Update the grammar triggers in params_base
+            ctx_server.params_base.sampling.grammar_triggers = std::move(processed_triggers);
+            // Also update preserved_tokens in params_base
+            ctx_server.params_base.sampling.preserved_tokens = params.sampling.preserved_tokens;
+        }
+
         //ctx_server.init();
         result->set_message("Loading succeeded");
         result->set_success(true);
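
The promotion logic above can be read in isolation: a WORD trigger whose value tokenizes to exactly one id is upgraded to a TOKEN trigger, and that id is added to the preserved tokens so the tokenizer never splits or merges it away; multi-token words remain WORD triggers. A minimal standalone sketch, assuming llama.cpp's common.h helpers:

// Sketch only (not part of the commit); promote_trigger is a hypothetical
// helper name. Relies on common_grammar_trigger and common_tokenize from
// llama.cpp's common.h.
#include <set>
#include <string>
#include "common.h"

static common_grammar_trigger promote_trigger(const llama_vocab * vocab,
                                              const common_grammar_trigger & in,
                                              std::set<llama_token> & preserved) {
    if (in.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
        auto ids = common_tokenize(vocab, in.value, /* add_special */ false, /* parse_special */ true);
        if (ids.size() == 1) {
            preserved.insert(ids[0]); // protect the id from token filtering
            common_grammar_trigger out;
            out.type  = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
            out.value = in.value;
            out.token = ids[0];
            return out;
        }
    }
    return in; // multi-token words and non-WORD triggers pass through unchanged
}

The round trip then works because parse_options() always re-serializes triggers as WORD type, letting the upstream server repeat this same single-token promotion on its own side.
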
@@ -406,7 +461,7 @@ public:
     }
 
     grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
-        json data = parse_options(true, request);
+        json data = parse_options(true, request, ctx_server);
 
         //Raise error if embeddings is set to true
@@ -556,7 +611,7 @@ public:
     }
 
     grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
-        json data = parse_options(true, request);
+        json data = parse_options(true, request, ctx_server);
         data["stream"] = false;
 
         //Raise error if embeddings is set to true
@@ -691,7 +746,7 @@ public:
 
     grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) {
-        json body = parse_options(false, request);
+        json body = parse_options(false, request, ctx_server);
         body["stream"] = false;
@@ -872,7 +927,7 @@ public:
     }
 
     grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
-        json body = parse_options(false, request);
+        json body = parse_options(false, request, ctx_server);
         body["stream"] = false;
 
         json tokens_response = json::array();

File 2 of 3: gRPC model options builder (Go)

@@ -129,7 +129,6 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
 		triggers = append(triggers, &pb.GrammarTrigger{
 			Word: t.Word,
 		})
 	}
 
 	return &pb.ModelOptions{

File 3 of 3: model gallery card template (HTML)

@@ -147,9 +147,6 @@
 				<div class="flex-1 min-w-0">
 					<div class="flex items-center justify-between">
 						<h3 class="font-bold text-xl text-[#E5E7EB] truncate group-hover:text-[#38BDF8] transition-colors">{{.Name}}</h3>
-						<a href="browse?term={{.Name}}" class="text-[#94A3B8] hover:text-[#38BDF8] transition-colors p-1 rounded-lg hover:bg-[#38BDF8]/10" title="Search for similar models">
-							<i class="fas fa-search text-sm"></i>
-						</a>
 					</div>
 					<div class="mt-2 flex flex-wrap gap-2">