fix(reranker): reproduce ignoring top_n (#7025)

* fix(reranker): reproduce ignoring top_n

Signed-off-by: Mikhail Khludnev <mkhl@apache.org>

* fix(reranker): ignoring top_n

Signed-off-by: Mikhail Khludnev <mkhl@apache.org>

---------

Signed-off-by: Mikhail Khludnev <mkhl@apache.org>
This commit is contained in:
Mikhail Khludnev
2025-11-06 13:03:05 +03:00
committed by GitHub
parent 2573102317
commit 122e4c7094
2 changed files with 30 additions and 2 deletions

View File

@@ -61,7 +61,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.PipelineType != "": # Reuse the PipelineType field for language
kwargs['lang'] = request.PipelineType
self.model_name = model_name
self.model = Reranker(model_name, **kwargs)
self.model = Reranker(model_name, **kwargs)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
@@ -80,7 +80,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
index=res.doc_id,
text=res.text,
relevance_score=res.score
) for res in ranked_results.results
) for res in ranked_results.top_k(request.top_n)
]
# Calculate the usage and total tokens

View File

@@ -86,5 +86,33 @@ class TestBackendServicer(unittest.TestCase):
except Exception as err:
print(err)
self.fail("Reranker service failed")
finally:
self.tearDown()
def test_rerank_crop(self):
    """
    This method tests that the rerank response is cropped to the requested top_n

    Three documents are submitted with top_n=2; exactly two results are
    expected back, most relevant first.
    """
    try:
        # Explicit setUp/tearDown calls follow the pattern of the other tests
        # in this file (the preceding test ends with the same finally block).
        self.setUp()
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            request = backend_pb2.RerankRequest(
                query="I love you",
                documents=["I hate you", "I really like you", "I hate ignoring top_n"],
                top_n=2
            )
            response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
            self.assertTrue(response.success)
            rerank_response = stub.Rerank(request)
            self.assertIsNotNone(rerank_response.results)
            # Check the count before indexing so an empty response fails on the
            # real expectation (2 results) rather than on an IndexError.
            self.assertEqual(len(rerank_response.results), 2)
            print(rerank_response.results[0])
            self.assertEqual(rerank_response.results[0].text, "I really like you")
            self.assertEqual(rerank_response.results[1].text, "I hate you")
    except Exception as err:
        # NOTE: assertion failures raised above are AssertionError, which is an
        # Exception subclass, so they are also printed here and then reported
        # under the generic message below.
        print(err)
        self.fail("Reranker service failed")
    finally:
        self.tearDown()