From cdc3cae32c5581756c674e4ed1d2cf1c1aea0f51 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Fri, 19 Jan 2024 13:07:44 -0800 Subject: [PATCH] [Bugfix] Pass TwoPhaseIterator from subQueryScorer Signed-off-by: Louis Chu --- .../search/function/ScriptScoreQuery.java | 6 +++ .../search/query/ScriptScoreQueryTests.java | 40 ++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreQuery.java b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreQuery.java index 90350c0a21a42..5aff09d715622 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreQuery.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreQuery.java @@ -45,6 +45,7 @@ import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import org.apache.lucene.util.Bits; import org.opensearch.Version; @@ -302,6 +303,11 @@ public DocIdSetIterator iterator() { return subQueryScorer.iterator(); } + @Override + public TwoPhaseIterator twoPhaseIterator() { + return subQueryScorer.twoPhaseIterator(); + } + @Override public float getMaxScore(int upTo) { return Float.MAX_VALUE; // TODO: what would be a good upper bound? diff --git a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java index ca4b7dc49f6f0..55c50b8cf854d 100644 --- a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java +++ b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java @@ -39,9 +39,14 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.opensearch.Version; @@ -49,6 +54,7 @@ import org.opensearch.common.lucene.search.function.ScriptScoreQuery; import org.opensearch.script.ScoreScript; import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import org.opensearch.search.lookup.LeafSearchLookup; import org.opensearch.search.lookup.SearchLookup; import org.opensearch.test.OpenSearchTestCase; @@ -56,6 +62,8 @@ import org.junit.Before; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.function.Function; import static org.hamcrest.CoreMatchers.containsString; @@ -177,6 +185,37 @@ public void testScriptScoreErrorOnNegativeScore() { assertTrue(e.getMessage().contains("Must be a non-negative score!")); } + public void testTwoPhaseIteratorDelegation() throws IOException { + Map params = new HashMap<>(); + String scriptSource = "doc['field'].value != null ? 2.0 : 0.0"; // Adjust based on actual field and logic + Script script = new Script(ScriptType.INLINE, "painless", scriptSource, params); + float minScore = 1.0f; // This should be below the score produced by the script for all docs + ScoreScript.LeafFactory factory = newFactory(script, false, explanation -> 2.0); + + Query subQuery = new MatchAllDocsQuery(); + ScriptScoreQuery scriptScoreQuery = new ScriptScoreQuery(subQuery, script, factory, minScore, "index", 0, Version.CURRENT); + + Weight weight = searcher.createWeight(searcher.rewrite(scriptScoreQuery), ScoreMode.COMPLETE, 1f); + + boolean foundMatchingDoc = false; + for (LeafReaderContext leafContext : searcher.getIndexReader().leaves()) { + Scorer scorer = weight.scorer(leafContext); + if (scorer != null) { + TwoPhaseIterator twoPhaseIterator = scorer.twoPhaseIterator(); + assertNotNull("TwoPhaseIterator should not be null", twoPhaseIterator); + DocIdSetIterator docIdSetIterator = twoPhaseIterator.approximation(); + int docId; + while ((docId = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + if (twoPhaseIterator.matches()) { + foundMatchingDoc = true; + break; + } + } + } + } + assertTrue("Expected to find at least one matching document", foundMatchingDoc); + } + private ScoreScript.LeafFactory newFactory( Script script, boolean needsScore, @@ -203,5 +242,4 @@ public double execute(ExplanationHolder explanation) { } }; } - }