diff --git a/Makefile b/Makefile index 714d1c1c4..ec94ef1d4 100644 --- a/Makefile +++ b/Makefile @@ -117,8 +117,8 @@ run: ## run local-ai CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./ test-models/testmodel.ggml: - mkdir test-models - mkdir test-dir + mkdir -p test-models + mkdir -p test-dir wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp index 2c7c2e52d..d0de49946 100644 --- a/backend/cpp/llama-cpp/grpc-server.cpp +++ b/backend/cpp/llama-cpp/grpc-server.cpp @@ -701,7 +701,7 @@ public: */ // for the shape of input/content, see tokenize_input_prompts() - json prompt = body.at("prompt"); + json prompt = body.at("embeddings"); auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); @@ -712,6 +712,7 @@ public: } } + int embd_normalize = 2; // default to Euclidean/L2 norm // create and queue the task json responses = json::array(); bool error = false; @@ -725,9 +726,8 @@ public: task.index = i; task.prompt_tokens = std::move(tokenized_prompts[i]); - // OAI-compat - task.params.oaicompat = OAICOMPAT_TYPE_EMBEDDING; - + task.params.oaicompat = OAICOMPAT_TYPE_NONE; + task.params.embd_normalize = embd_normalize; tasks.push_back(std::move(task)); } @@ -743,9 +743,8 @@ public: responses.push_back(res->to_json()); } }, [&](const json & error_data) { - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", "")); + error = true; }, [&]() { - // NOTE: we should try to check when the writer is closed here return false; }); @@ -755,12 +754,36 @@ public: return grpc::Status(grpc::StatusCode::INTERNAL, "Error in 
receiving results"); } - std::vector<float> embeddings = responses[0].value("embedding", std::vector<float>()); - // loop the vector and set the embeddings results - for (int i = 0; i < embeddings.size(); i++) { - embeddingResult->add_embeddings(embeddings[i]); + std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl; + + // Process the responses and extract embeddings + for (const auto & response_elem : responses) { + // Check if the response has an "embedding" field + if (response_elem.contains("embedding")) { + json embedding_data = json_value(response_elem, "embedding", json::array()); + + if (embedding_data.is_array() && !embedding_data.empty()) { + for (const auto & embedding_vector : embedding_data) { + if (embedding_vector.is_array()) { + for (const auto & embedding_value : embedding_vector) { + embeddingResult->add_embeddings(embedding_value.get<float>()); + } + } + } + } + } else { + // Check if the response itself contains the embedding data directly + if (response_elem.is_array()) { + for (const auto & embedding_value : response_elem) { + embeddingResult->add_embeddings(embedding_value.get<float>()); + } + } + } } + + + return grpc::Status::OK; } diff --git a/core/http/app_test.go b/core/http/app_test.go index b1f43d332..09726c19b 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -836,27 +836,40 @@ var _ = Describe("API test", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } + embeddingModel := openai.AdaEmbeddingV2 resp, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ - Model: openai.AdaEmbeddingV2, + Model: embeddingModel, Input: []string{"sun", "cat"}, }, ) Expect(err).ToNot(HaveOccurred(), err) - Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 2048)) - Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 2048)) + Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 4096)) + Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 4096)) sunEmbedding := 
resp.Data[0].Embedding resp2, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ - Model: openai.AdaEmbeddingV2, + Model: embeddingModel, Input: []string{"sun"}, }, ) Expect(err).ToNot(HaveOccurred()) Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) + Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) + + resp3, err := client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Model: embeddingModel, + Input: []string{"cat"}, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[1].Embedding)) + Expect(resp3.Data[0].Embedding).ToNot(Equal(sunEmbedding)) }) Context("External gRPC calls", func() { diff --git a/core/http/http_suite_test.go b/core/http/http_suite_test.go index 0269a9732..94467437f 100644 --- a/core/http/http_suite_test.go +++ b/core/http/http_suite_test.go @@ -9,5 +9,5 @@ import ( func TestLocalAI(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "LocalAI test suite") + RunSpecs(t, "LocalAI HTTP test suite") } diff --git a/tests/e2e-aio/e2e_test.go b/tests/e2e-aio/e2e_test.go index f503f4952..371f2bedb 100644 --- a/tests/e2e-aio/e2e_test.go +++ b/tests/e2e-aio/e2e_test.go @@ -169,6 +169,30 @@ var _ = Describe("E2E test", func() { Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) Expect(resp.Data[0].Embedding).ToNot(BeEmpty()) + + resp2, err := client.CreateEmbeddings(context.TODO(), + openai.EmbeddingRequestStrings{ + Input: []string{"cat"}, + Model: openai.AdaEmbeddingV2, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp2.Data)).To(Equal(1), fmt.Sprint(resp)) + Expect(resp2.Data[0].Embedding).ToNot(BeEmpty()) + Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[0].Embedding)) + + resp3, err := client.CreateEmbeddings(context.TODO(), + openai.EmbeddingRequestStrings{ + Input: []string{"doc", "cat"}, + Model: openai.AdaEmbeddingV2, + }, + ) + Expect(err).ToNot(HaveOccurred()) + 
Expect(len(resp3.Data)).To(Equal(2), fmt.Sprint(resp3)) + Expect(resp3.Data[0].Embedding).ToNot(BeEmpty()) + Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[0].Embedding)) + Expect(resp3.Data[1].Embedding).To(Equal(resp2.Data[0].Embedding)) + Expect(resp3.Data[0].Embedding).ToNot(Equal(resp3.Data[1].Embedding)) }) }) Context("vision", func() {