Skip to content

Commit

Permalink
fix: dbsf zero variance (#875)
Browse files Browse the repository at this point in the history
* fix: dbsf zero variance

solves #871

* review suggestions

* fix: fix dbsf single response score

---------

Co-authored-by: George Panchuk <[email protected]>
  • Loading branch information
pavelm10 and joein authored Jan 6, 2025
1 parent 47ff758 commit dde9cb7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
10 changes: 8 additions & 2 deletions qdrant_client/hybrid/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,20 @@ def distribution_based_score_fusion(
responses: list[list[models.ScoredPoint]], limit: int
) -> list[models.ScoredPoint]:
def normalize(response: list[models.ScoredPoint]) -> list[models.ScoredPoint]:
if len(response) <= 1:
if len(response) == 1:
response[0].score = 0.5
return response

total = sum([point.score for point in response])
mean = total / len(response)
variance = sum([(point.score - mean) ** 2 for point in response]) / (len(response) - 1)
std_dev = variance**0.5

if variance == 0:
for point in response:
point.score = 0.5
return response

std_dev = variance**0.5
low = mean - 3 * std_dev
high = mean + 3 * std_dev

Expand Down
24 changes: 21 additions & 3 deletions qdrant_client/hybrid/test_reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ def test_distribution_based_score_fusion() -> None:
assert fused[1].id == 0
assert fused[2].id == 4

fused = distribution_based_score_fusion([[responses[0][0]]], limit=3)
assert fused[0].id == 1


def test_reciprocal_rank_fusion_empty_responses() -> None:
responses: list[list[models.ScoredPoint]] = [[]]
Expand Down Expand Up @@ -97,3 +94,24 @@ def test_distribution_based_score_fusion_empty_response() -> None:
assert fused[0].id == 1
assert fused[1].id == 0
assert fused[2].id == 5


def test_distribution_based_score_fusion_zero_variance() -> None:
score = 85.0
responses = [
[
models.ScoredPoint(id=1, version=0, score=score),
models.ScoredPoint(id=0, version=0, score=score),
models.ScoredPoint(id=5, version=0, score=score),
],
[],
]
fused = distribution_based_score_fusion(
[[models.ScoredPoint(id=1, version=0, score=score)]], limit=3
)
assert fused[0].id == 1
assert fused[0].score == 0.5

fused = distribution_based_score_fusion(responses, limit=3)
assert len(fused) == 3
assert all([p.score == 0.5 for p in fused])

0 comments on commit dde9cb7

Please sign in to comment.