mirror of
https://github.com/mudler/LocalAI.git
synced 2025-12-30 22:20:20 -06:00
fix(reranker): support omitting top_n (#7199)
* fix(reranker): support omitting top_n Signed-off-by: Mikhail Khludnev <mkhl@apache.org> * fix(reranker): support omitting top_n Signed-off-by: Mikhail Khludnev <mkhl@apache.org> * pass 0 explicitly Signed-off-by: Mikhail Khludnev <mkhludnev@users.noreply.github.com> --------- Signed-off-by: Mikhail Khludnev <mkhl@apache.org> Signed-off-by: Mikhail Khludnev <mkhludnev@users.noreply.github.com>
This commit is contained in:
@@ -75,12 +75,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
documents.append(doc)
|
||||
ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
|
||||
# Prepare results to return
|
||||
cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results
|
||||
results = [
|
||||
backend_pb2.DocumentResult(
|
||||
index=res.doc_id,
|
||||
text=res.text,
|
||||
relevance_score=res.score
|
||||
) for res in ranked_results.top_k(request.top_n)
|
||||
) for res in (cropped_results)
|
||||
]
|
||||
|
||||
# Calculate the usage and total tokens
|
||||
|
||||
@@ -76,7 +76,35 @@ class TestBackendServicer(unittest.TestCase):
|
||||
)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
|
||||
self.assertTrue(response.success)
|
||||
|
||||
|
||||
rerank_response = stub.Rerank(request)
|
||||
print(rerank_response.results[0])
|
||||
self.assertIsNotNone(rerank_response.results)
|
||||
self.assertEqual(len(rerank_response.results), 2)
|
||||
self.assertEqual(rerank_response.results[0].text, "I really like you")
|
||||
self.assertEqual(rerank_response.results[1].text, "I hate you")
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Reranker service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_rerank_omit_top_n(self):
|
||||
"""
|
||||
This method tests if the embeddings are generated successfully even top_n is omitted
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
request = backend_pb2.RerankRequest(
|
||||
query="I love you",
|
||||
documents=["I hate you", "I really like you"],
|
||||
top_n=0 #
|
||||
)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
|
||||
self.assertTrue(response.success)
|
||||
|
||||
rerank_response = stub.Rerank(request)
|
||||
print(rerank_response.results[0])
|
||||
self.assertIsNotNone(rerank_response.results)
|
||||
@@ -91,7 +119,7 @@ class TestBackendServicer(unittest.TestCase):
|
||||
|
||||
def test_rerank_crop(self):
|
||||
"""
|
||||
This method tests if the embeddings are generated successfully
|
||||
This method tests top_n cropping
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
@@ -104,7 +132,7 @@ class TestBackendServicer(unittest.TestCase):
|
||||
)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
|
||||
self.assertTrue(response.success)
|
||||
|
||||
|
||||
rerank_response = stub.Rerank(request)
|
||||
print(rerank_response.results[0])
|
||||
self.assertIsNotNone(rerank_response.results)
|
||||
@@ -115,4 +143,4 @@ class TestBackendServicer(unittest.TestCase):
|
||||
print(err)
|
||||
self.fail("Reranker service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
self.tearDown()
|
||||
|
||||
Reference in New Issue
Block a user