Skip to content

Commit

Permalink
Fix KNNScorer to apply boost (#1403) (#1405)
Browse files Browse the repository at this point in the history
* apply boost

Signed-off-by: panguixin <[email protected]>

* add change log

Signed-off-by: panguixin <[email protected]>

---------

Signed-off-by: panguixin <[email protected]>
(cherry picked from commit fcbfef1)

Co-authored-by: panguixin <[email protected]>
  • Loading branch information
1 parent 4d9da8d commit 8513952
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)
* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367)
* Fix KNNScorer to apply boost [#1403](https://github.com/opensearch-project/k-NN/pull/1403)
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public float score() {
assert docID() != DocIdSetIterator.NO_MORE_DOCS;
Float score = scores.get(docID());
if (score == null) throw new RuntimeException("Null score for the docID: " + docID());
return score;
return score * boost;
}

@Override
Expand Down
35 changes: 21 additions & 14 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ public void testQueryScoreForFaissWithModel() {
when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata);

KNNWeight.initialize(modelDao);
final KNNWeight knnWeight = new KNNWeight(query, 0.0f);
final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
Expand Down Expand Up @@ -214,7 +215,7 @@ public void testQueryScoreForFaissWithModel() {
final Map<Integer, Float> translatedScores = getTranslatedScores(scoreTranslator);
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f);
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
Expand Down Expand Up @@ -364,7 +365,8 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
// Just to make sure that we are not hitting the exact search condition
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1));

final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);

final FSDirectory directory = mock(FSDirectory.class);
when(reader.directory()).thenReturn(directory);
Expand Down Expand Up @@ -408,7 +410,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f);
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
Expand All @@ -433,7 +435,8 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() {
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
when(liveDocsBits.get(filterDocId)).thenReturn(true);

final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
final Map<String, String> attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name());
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
Expand All @@ -457,7 +460,7 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() {
final List<Integer> actualDocIds = new ArrayList<>();
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f);
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
Expand All @@ -483,7 +486,8 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
when(liveDocsBits.get(filterDocId)).thenReturn(true);

final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
final Map<String, String> attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name());
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
Expand All @@ -507,7 +511,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
final List<Integer> actualDocIds = new ArrayList<>();
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f);
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
Expand Down Expand Up @@ -543,7 +547,8 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null);

final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
final Map<String, String> attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name());
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
Expand All @@ -567,7 +572,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces
final List<Integer> actualDocIds = new ArrayList<>();
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f);
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
Expand Down Expand Up @@ -631,7 +636,8 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, parentFilter);
final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);

// Execute
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
Expand All @@ -642,7 +648,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
.collect(Collectors.toList());
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertEquals(1, docIdSetIterator.nextDoc());
assertEquals(expectedScores.get(1), knnScorer.score(), 0.01f);
assertEquals(expectedScores.get(1) * boost, knnScorer.score(), 0.01f);
assertEquals(NO_MORE_DOCS, docIdSetIterator.nextDoc());
}

Expand Down Expand Up @@ -733,7 +739,8 @@ private void testQueryScore(
.thenReturn(getKNNQueryResults());

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null);
final KNNWeight knnWeight = new KNNWeight(query, 0.0f);
final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
Expand Down Expand Up @@ -777,7 +784,7 @@ private void testQueryScore(
final Map<Integer, Float> translatedScores = getTranslatedScores(scoreTranslator);
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f);
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
Expand Down

0 comments on commit 8513952

Please sign in to comment.