mirror of
https://github.com/mudler/LocalAI.git
synced 2025-12-30 06:00:15 -06:00
fix(llama-cpp): correctly calculate embeddings (#6259)
* chore(tests): check embeddings differs in llama.cpp Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(llama.cpp): use the correct field for embedding Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(llama.cpp): use embedding type none Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(tests): add test-cases in aio-e2e suite Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
55766d269b
commit
6410c99bf2
4
Makefile
4
Makefile
@@ -117,8 +117,8 @@ run: ## run local-ai
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
||||
|
||||
test-models/testmodel.ggml:
|
||||
mkdir test-models
|
||||
mkdir test-dir
|
||||
mkdir -p test-models
|
||||
mkdir -p test-dir
|
||||
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
|
||||
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
||||
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
|
||||
|
||||
@@ -701,7 +701,7 @@ public:
|
||||
*/
|
||||
|
||||
// for the shape of input/content, see tokenize_input_prompts()
|
||||
json prompt = body.at("prompt");
|
||||
json prompt = body.at("embeddings");
|
||||
|
||||
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
@@ -712,6 +712,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
int embd_normalize = 2; // default to Euclidean/L2 norm
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
@@ -725,9 +726,8 @@ public:
|
||||
task.index = i;
|
||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = OAICOMPAT_TYPE_EMBEDDING;
|
||||
|
||||
task.params.oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
task.params.embd_normalize = embd_normalize;
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
@@ -743,9 +743,8 @@ public:
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", ""));
|
||||
error = true;
|
||||
}, [&]() {
|
||||
// NOTE: we should try to check when the writer is closed here
|
||||
return false;
|
||||
});
|
||||
|
||||
@@ -755,12 +754,36 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
|
||||
}
|
||||
|
||||
std::vector<float> embeddings = responses[0].value("embedding", std::vector<float>());
|
||||
// loop the vector and set the embeddings results
|
||||
for (int i = 0; i < embeddings.size(); i++) {
|
||||
embeddingResult->add_embeddings(embeddings[i]);
|
||||
std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;
|
||||
|
||||
// Process the responses and extract embeddings
|
||||
for (const auto & response_elem : responses) {
|
||||
// Check if the response has an "embedding" field
|
||||
if (response_elem.contains("embedding")) {
|
||||
json embedding_data = json_value(response_elem, "embedding", json::array());
|
||||
|
||||
if (embedding_data.is_array() && !embedding_data.empty()) {
|
||||
for (const auto & embedding_vector : embedding_data) {
|
||||
if (embedding_vector.is_array()) {
|
||||
for (const auto & embedding_value : embedding_vector) {
|
||||
embeddingResult->add_embeddings(embedding_value.get<float>());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Check if the response itself contains the embedding data directly
|
||||
if (response_elem.is_array()) {
|
||||
for (const auto & embedding_value : response_elem) {
|
||||
embeddingResult->add_embeddings(embedding_value.get<float>());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
|
||||
@@ -836,27 +836,40 @@ var _ = Describe("API test", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
embeddingModel := openai.AdaEmbeddingV2
|
||||
resp, err := client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
openai.EmbeddingRequest{
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
Model: embeddingModel,
|
||||
Input: []string{"sun", "cat"},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred(), err)
|
||||
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 2048))
|
||||
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 2048))
|
||||
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 4096))
|
||||
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 4096))
|
||||
|
||||
sunEmbedding := resp.Data[0].Embedding
|
||||
resp2, err := client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
openai.EmbeddingRequest{
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
Model: embeddingModel,
|
||||
Input: []string{"sun"},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
|
||||
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
|
||||
|
||||
resp3, err := client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
openai.EmbeddingRequest{
|
||||
Model: embeddingModel,
|
||||
Input: []string{"cat"},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[1].Embedding))
|
||||
Expect(resp3.Data[0].Embedding).ToNot(Equal(sunEmbedding))
|
||||
})
|
||||
|
||||
Context("External gRPC calls", func() {
|
||||
|
||||
@@ -9,5 +9,5 @@ import (
|
||||
|
||||
func TestLocalAI(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "LocalAI test suite")
|
||||
RunSpecs(t, "LocalAI HTTP test suite")
|
||||
}
|
||||
|
||||
@@ -169,6 +169,30 @@ var _ = Describe("E2E test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||
Expect(resp.Data[0].Embedding).ToNot(BeEmpty())
|
||||
|
||||
resp2, err := client.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{"cat"},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp2.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||
Expect(resp2.Data[0].Embedding).ToNot(BeEmpty())
|
||||
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[0].Embedding))
|
||||
|
||||
resp3, err := client.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{"doc", "cat"},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp3.Data)).To(Equal(2), fmt.Sprint(resp))
|
||||
Expect(resp3.Data[0].Embedding).ToNot(BeEmpty())
|
||||
Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[0].Embedding))
|
||||
Expect(resp3.Data[1].Embedding).To(Equal(resp2.Data[0].Embedding))
|
||||
Expect(resp3.Data[0].Embedding).ToNot(Equal(resp3.Data[1].Embedding))
|
||||
})
|
||||
})
|
||||
Context("vision", func() {
|
||||
|
||||
Reference in New Issue
Block a user