fix(llama-cpp): correctly calculate embeddings (#6259)

* chore(tests): check embeddings differ in llama.cpp

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(llama.cpp): use the correct field for embedding

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(llama.cpp): use embedding type none

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(tests): add test-cases in aio-e2e suite

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Ettore Di Giacinto
2025-09-13 23:11:54 +02:00
committed by GitHub
parent 55766d269b
commit 6410c99bf2
5 changed files with 77 additions and 17 deletions

View File

@@ -117,8 +117,8 @@ run: ## run local-ai
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
test-models/testmodel.ggml:
-mkdir test-models
-mkdir test-dir
+mkdir -p test-models
+mkdir -p test-dir
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert

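The mkdir -p change makes these download targets idempotent: re-running make no longer fails when test-models or test-dir is left over from a previous run.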
View File

@@ -701,7 +701,7 @@ public:
*/
// for the shape of input/content, see tokenize_input_prompts()
json prompt = body.at("prompt");
json prompt = body.at("embeddings");
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
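This is the core field fix: the request JSON that reaches this gRPC handler carries the text to embed under "embeddings", not "prompt", so the old body.at("prompt") lookup either threw (with nlohmann::json, at() throws on a missing key) or picked up the wrong content.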
@@ -712,6 +712,7 @@ public:
}
}
+int embd_normalize = 2; // default to Euclidean/L2 norm
// create and queue the task
json responses = json::array();
bool error = false;
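The embd_normalize value introduced here is forwarded to each queued embedding task in the next hunk. In llama.cpp a value of 2 selects the Euclidean (L2) norm, which rescales every embedding to unit length so that dot products between embeddings behave like cosine similarities. A minimal sketch of the operation it selects (the helper name is illustrative, not code from this commit):

#include <cmath>
#include <vector>

// Illustrative helper: what embd_normalize == 2 (Euclidean/L2) computes.
// Rescales v in place so that sqrt(sum(v[i]^2)) == 1, unless v is all zeros.
static void l2_normalize(std::vector<float> & v) {
    double sum_sq = 0.0;
    for (float x : v) sum_sq += static_cast<double>(x) * x;
    const double norm = std::sqrt(sum_sq);
    if (norm > 0.0) {
        for (float & x : v) x = static_cast<float>(x / norm);
    }
}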
@@ -725,9 +726,8 @@ public:
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
// OAI-compat
-task.params.oaicompat = OAICOMPAT_TYPE_EMBEDDING;
+task.params.oaicompat = OAICOMPAT_TYPE_NONE;
+task.params.embd_normalize = embd_normalize;
tasks.push_back(std::move(task));
}
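Switching oaicompat from OAICOMPAT_TYPE_EMBEDDING to OAICOMPAT_TYPE_NONE changes the shape of each task result: the OpenAI-compatible type wraps the embedding in an OpenAI-style response object, while NONE returns the task's raw JSON with a plain "embedding" field, which is the shape the extraction loop further below expects.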
@@ -743,9 +743,8 @@ public:
responses.push_back(res->to_json());
}
}, [&](const json & error_data) {
-return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", ""));
error = true;
}, [&]() {
// NOTE: we should try to check when the writer is closed here
return false;
});
@@ -755,12 +754,36 @@ public:
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
}
std::vector<float> embeddings = responses[0].value("embedding", std::vector<float>());
// loop the vector and set the embeddings results
for (int i = 0; i < embeddings.size(); i++) {
embeddingResult->add_embeddings(embeddings[i]);
std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;
// Process the responses and extract embeddings
for (const auto & response_elem : responses) {
// Check if the response has an "embedding" field
if (response_elem.contains("embedding")) {
json embedding_data = json_value(response_elem, "embedding", json::array());
if (embedding_data.is_array() && !embedding_data.empty()) {
for (const auto & embedding_vector : embedding_data) {
if (embedding_vector.is_array()) {
for (const auto & embedding_value : embedding_vector) {
embeddingResult->add_embeddings(embedding_value.get<float>());
}
}
}
}
} else {
// Check if the response itself contains the embedding data directly
if (response_elem.is_array()) {
for (const auto & embedding_value : response_elem) {
embeddingResult->add_embeddings(embedding_value.get<float>());
}
}
}
}
return grpc::Status::OK;
}

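The rewritten extraction accepts both result shapes that can reach this point: a raw task result whose "embedding" field holds an array of vectors, and a response element that is itself the embedding array. A self-contained sketch of the same traversal (illustrative only; it assumes nlohmann::json, which the json type in the server code refers to):

#include <iostream>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Shape 1: raw task result, "embedding" holds an array of vectors.
    json wrapped = {{"embedding", json::array({json::array({0.1, 0.2, 0.3})})}};
    // Shape 2: the response element is the embedding vector itself.
    json bare = json::array({0.4, 0.5, 0.6});

    std::vector<float> out;
    for (const json & resp : {wrapped, bare}) {
        if (resp.contains("embedding")) {
            // Flatten every vector found under "embedding".
            for (const auto & vec : resp.at("embedding")) {
                if (vec.is_array()) {
                    for (const auto & val : vec) out.push_back(val.get<float>());
                }
            }
        } else if (resp.is_array()) {
            // The element is already a flat array of floats.
            for (const auto & val : resp) out.push_back(val.get<float>());
        }
    }
    std::cout << "extracted " << out.size() << " floats" << std::endl; // prints 6
    return 0;
}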
View File

@@ -836,27 +836,40 @@ var _ = Describe("API test", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
+embeddingModel := openai.AdaEmbeddingV2
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
-Model: openai.AdaEmbeddingV2,
+Model: embeddingModel,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred(), err)
-Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 2048))
-Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 2048))
+Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 4096))
+Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 4096))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
-Model: openai.AdaEmbeddingV2,
+Model: embeddingModel,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
+Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
+resp3, err := client.CreateEmbeddings(
+context.Background(),
+openai.EmbeddingRequest{
+Model: embeddingModel,
+Input: []string{"cat"},
+},
+)
+Expect(err).ToNot(HaveOccurred())
+Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[1].Embedding))
+Expect(resp3.Data[0].Embedding).ToNot(Equal(sunEmbedding))
})
Context("External gRPC calls", func() {

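The strengthened HTTP test pins the expected embedding dimension at 4096 and now asserts both determinism (requesting "sun" or "cat" again returns an identical vector) and input sensitivity (the "sun" and "cat" vectors must differ), guarding against the regression fixed here, where distinct inputs could come back with non-distinct embeddings.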
View File

@@ -9,5 +9,5 @@ import (
func TestLocalAI(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "LocalAI test suite")
RunSpecs(t, "LocalAI HTTP test suite")
}

View File

@@ -169,6 +169,30 @@ var _ = Describe("E2E test", func() {
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Data[0].Embedding).ToNot(BeEmpty())
+resp2, err := client.CreateEmbeddings(context.TODO(),
+openai.EmbeddingRequestStrings{
+Input: []string{"cat"},
+Model: openai.AdaEmbeddingV2,
+},
+)
+Expect(err).ToNot(HaveOccurred())
+Expect(len(resp2.Data)).To(Equal(1), fmt.Sprint(resp))
+Expect(resp2.Data[0].Embedding).ToNot(BeEmpty())
+Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[0].Embedding))
+resp3, err := client.CreateEmbeddings(context.TODO(),
+openai.EmbeddingRequestStrings{
+Input: []string{"doc", "cat"},
+Model: openai.AdaEmbeddingV2,
+},
+)
+Expect(err).ToNot(HaveOccurred())
+Expect(len(resp3.Data)).To(Equal(2), fmt.Sprint(resp))
+Expect(resp3.Data[0].Embedding).ToNot(BeEmpty())
+Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[0].Embedding))
+Expect(resp3.Data[1].Embedding).To(Equal(resp2.Data[0].Embedding))
+Expect(resp3.Data[0].Embedding).ToNot(Equal(resp3.Data[1].Embedding))
})
})
Context("vision", func() {