diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 6a629b056..e656beca3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -43,7 +43,7 @@ public float combine(final float[] scores) { float sumOfWeights = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { float score = scores[indexOfSubQuery]; - if (score != 0.0) { + if (score >= 0.0) { float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); score = score * weight; combinedScore += score; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryZScoreIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryZScoreIT.java index ea14244d7..d197fc710 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryZScoreIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryZScoreIT.java @@ -146,10 +146,10 @@ public void testComplexQuery_withZScoreNormalization() { } assertEquals(2, scores.size()); - // since it's z-score normalized we would expect 1 , -1 to be the corresponding score, by design when there are only two results with z score - // furthermore the combination logic with weights should make it doc1Score: (1 * w1 + 0.98 * w2)/(w1 + w2), -1 * w2/w2 + // by design when there are only two results with z score since it's z-score normalized we would expect 1 , -1 to be the corresponding score, + // furthermore the combination logic with weights should make it doc1Score: (1 * w1 + 0.98 * w2)/(w1 + w2), doc2Score: -1 ~ 0 assertEquals(0.9999, scores.get(0).floatValue(), DELTA_FOR_SCORE_ASSERTION); - assertEquals(-1 , scores.get(1).floatValue(), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0 , scores.get(1).floatValue(), DELTA_FOR_SCORE_ASSERTION); // verify that scores are in desc order assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));