diff --git a/qdrant_client/hybrid/fusion.py b/qdrant_client/hybrid/fusion.py index 015f4a73..1c8f280a 100644 --- a/qdrant_client/hybrid/fusion.py +++ b/qdrant_client/hybrid/fusion.py @@ -33,14 +33,20 @@ def distribution_based_score_fusion( responses: list[list[models.ScoredPoint]], limit: int ) -> list[models.ScoredPoint]: def normalize(response: list[models.ScoredPoint]) -> list[models.ScoredPoint]: - if len(response) <= 1: + if len(response) == 1: + response[0].score = 0.5 return response total = sum([point.score for point in response]) mean = total / len(response) variance = sum([(point.score - mean) ** 2 for point in response]) / (len(response) - 1) - std_dev = variance**0.5 + if variance == 0: + for point in response: + point.score = 0.5 + return response + + std_dev = variance**0.5 low = mean - 3 * std_dev high = mean + 3 * std_dev diff --git a/qdrant_client/hybrid/test_reranking.py b/qdrant_client/hybrid/test_reranking.py index fa9aa05c..bced877f 100644 --- a/qdrant_client/hybrid/test_reranking.py +++ b/qdrant_client/hybrid/test_reranking.py @@ -50,9 +50,6 @@ def test_distribution_based_score_fusion() -> None: assert fused[1].id == 0 assert fused[2].id == 4 - fused = distribution_based_score_fusion([[responses[0][0]]], limit=3) - assert fused[0].id == 1 - def test_reciprocal_rank_fusion_empty_responses() -> None: responses: list[list[models.ScoredPoint]] = [[]] @@ -97,3 +94,24 @@ def test_distribution_based_score_fusion_empty_response() -> None: assert fused[0].id == 1 assert fused[1].id == 0 assert fused[2].id == 5 + + +def test_distribution_based_score_fusion_zero_variance() -> None: + score = 85.0 + responses = [ + [ + models.ScoredPoint(id=1, version=0, score=score), + models.ScoredPoint(id=0, version=0, score=score), + models.ScoredPoint(id=5, version=0, score=score), + ], + [], + ] + fused = distribution_based_score_fusion( + [[models.ScoredPoint(id=1, version=0, score=score)]], limit=3 + ) + assert fused[0].id == 1 + assert fused[0].score == 0.5 + + fused = distribution_based_score_fusion(responses, limit=3) + assert len(fused) == 3 + assert all([p.score == 0.5 for p in fused])