mirror of
https://github.com/mudler/LocalAI.git
synced 2026-01-05 10:10:08 -06:00
fix(reranker): reproduce ignoring top_n (#7025)
* fix(reranker): reproduce ignoring top_n Signed-off-by: Mikhail Khludnev <mkhl@apache.org> * fix(reranker): ignoring top_n Signed-off-by: Mikhail Khludnev <mkhl@apache.org> --------- Signed-off-by: Mikhail Khludnev <mkhl@apache.org>
This commit is contained in:
@@ -61,7 +61,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if request.PipelineType != "": # Reuse the PipelineType field for language
|
||||
kwargs['lang'] = request.PipelineType
|
||||
self.model_name = model_name
|
||||
self.model = Reranker(model_name, **kwargs)
|
||||
self.model = Reranker(model_name, **kwargs)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
@@ -80,7 +80,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
index=res.doc_id,
|
||||
text=res.text,
|
||||
relevance_score=res.score
|
||||
) for res in ranked_results.results
|
||||
) for res in ranked_results.top_k(request.top_n)
|
||||
]
|
||||
|
||||
# Calculate the usage and total tokens
|
||||
|
||||
@@ -86,5 +86,33 @@ class TestBackendServicer(unittest.TestCase):
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Reranker service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_rerank_crop(self):
    """
    This method tests if the rerank response is cropped to the requested top_n
    """
    try:
        # setUp/tearDown are invoked explicitly here, following the same
        # pattern as the other tests in this class.
        self.setUp()
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            # Three documents are submitted but only the two best are requested.
            request = backend_pb2.RerankRequest(
                query="I love you",
                documents=["I hate you", "I really like you", "I hate ignoring top_n"],
                top_n=2
            )
            response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
            self.assertTrue(response.success)

            rerank_response = stub.Rerank(request)
            self.assertIsNotNone(rerank_response.results)
            # Validate the count before indexing so that an empty or uncropped
            # response is reported as a length mismatch, not an IndexError.
            self.assertEqual(len(rerank_response.results), 2)
            print(rerank_response.results[0])
            # Results are expected in relevance order for the query.
            self.assertEqual(rerank_response.results[0].text, "I really like you")
            self.assertEqual(rerank_response.results[1].text, "I hate you")
    except Exception as err:
        print(err)
        self.fail("Reranker service failed")
    finally:
        self.tearDown()
|
||||
Reference in New Issue
Block a user