From ea92d0822d640a8a7feec5fa1f26f41c1e3ad4a8 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Sat, 28 Dec 2024 20:19:45 -0800
Subject: [PATCH] Experimenting with doc iterators

Signed-off-by: Martin Gaievski
---
 .../neuralsearch/bwc/HybridSearchIT.java      |   2 -
 .../neuralsearch/query/HybridQueryScorer.java |  26 ++--
 .../search/SimpleDisiIterator.java            | 116 ++++++++++++++++++
 .../SimpleDisjunctionDISIApproximation.java   |  89 ++++++++++++++
 4 files changed, 219 insertions(+), 14 deletions(-)
 create mode 100644 src/main/java/org/opensearch/neuralsearch/search/SimpleDisiIterator.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/search/SimpleDisjunctionDISIApproximation.java

diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
index d4ae88a3f..d2657555d 100644
--- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
+++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
@@ -13,8 +13,6 @@
 
 import org.opensearch.index.query.MatchQueryBuilder;
 
-import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
-import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
 import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
 import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
 import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
index 48c69b618..c484f7c6a 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
@@ -7,9 +7,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import lombok.Getter;
 import lombok.extern.log4j.Log4j2;
-import org.apache.lucene.search.DisiPriorityQueue;
 import org.apache.lucene.search.DisiWrapper;
-import org.apache.lucene.search.DisjunctionDISIApproximation;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Scorer;
@@ -17,6 +15,8 @@
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.PriorityQueue;
 import org.opensearch.neuralsearch.search.HybridDisiWrapper;
+import org.opensearch.neuralsearch.search.SimpleDisiIterator;
+import org.opensearch.neuralsearch.search.SimpleDisjunctionDISIApproximation;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -37,7 +37,8 @@ public class HybridQueryScorer extends Scorer {
     @Getter
     private final List<Scorer> subScorers;
 
-    private final DisiPriorityQueue subScorersPQ;
+    // private final DisiPriorityQueue subScorersPQ;
+    private final SimpleDisiIterator subScorersPQ;
 
     private final DocIdSetIterator approximation;
     private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
@@ -207,19 +208,20 @@ public float[] hybridScores() throws IOException {
         return scores;
     }
 
-    private DisiPriorityQueue initializeSubScorersPQ() {
+    private SimpleDisiIterator initializeSubScorersPQ() {
         Objects.requireNonNull(subScorers, "should not be null");
         // we need to count this way in order to include all identical sub-queries
-        DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numSubqueries);
+        // DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numSubqueries);
+        List<DisiWrapper> disiWrappers = new ArrayList<>();
         for (int idx = 0; idx < numSubqueries; idx++) {
             Scorer scorer = subScorers.get(idx);
             if (scorer == null) {
                 continue;
             }
             final HybridDisiWrapper disiWrapper = new HybridDisiWrapper(scorer, idx);
-            subScorersPQ.add(disiWrapper);
+            disiWrappers.add(disiWrapper);
         }
-        return subScorersPQ;
+        return new SimpleDisiIterator(disiWrappers.toArray(new DisiWrapper[0]));
     }
 
     @Override
@@ -244,10 +246,10 @@ static class TwoPhase extends TwoPhaseIterator {
         DisiWrapper verifiedMatches;
         // priority queue of approximations on the current doc that have not been verified yet
         final PriorityQueue<DisiWrapper> unverifiedMatches;
-        DisiPriorityQueue subScorers;
+        SimpleDisiIterator subScorers;
         boolean needsScores;
 
-        private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
+        private TwoPhase(DocIdSetIterator approximation, float matchCost, SimpleDisiIterator subScorers, boolean needsScores) {
             super(approximation);
             this.matchCost = matchCost;
             this.subScorers = subScorers;
@@ -323,10 +325,10 @@ public float matchCost() {
      */
     static class HybridSubqueriesDISIApproximation extends DocIdSetIterator {
         final DocIdSetIterator docIdSetIterator;
-        final DisiPriorityQueue subIterators;
+        final SimpleDisiIterator subIterators;
 
-        public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) {
-            docIdSetIterator = new DisjunctionDISIApproximation(subIterators);
+        public HybridSubqueriesDISIApproximation(final SimpleDisiIterator subIterators) {
+            docIdSetIterator = new SimpleDisjunctionDISIApproximation(subIterators);
             this.subIterators = subIterators;
         }
 
diff --git a/src/main/java/org/opensearch/neuralsearch/search/SimpleDisiIterator.java b/src/main/java/org/opensearch/neuralsearch/search/SimpleDisiIterator.java
new file mode 100644
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/SimpleDisiIterator.java
@@ -0,0 +1,116 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search;
+
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DocIdSetIterator;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+/**
+ * Array-backed replacement for Lucene's {@code DisiPriorityQueue}: keeps the sub-query iterators in a plain
+ * array and finds the ones positioned on the smallest doc id with a linear scan.
+ * This class never advances the wrapped iterators, it only reads {@link DisiWrapper#doc}.
+ */
+public final class SimpleDisiIterator implements Iterable<DisiWrapper> {
+    private final DisiWrapper[] iterators;
+    private final int size;
+
+    /**
+     * Wraps the given sub-iterators as they are. Nothing is advanced here, so the first matching document
+     * of a sub-iterator is not consumed before the caller starts iterating.
+     *
+     * @param iterators wrappers of the sub-query scorers; {@code null} entries are skipped by all methods
+     */
+    public SimpleDisiIterator(DisiWrapper... iterators) {
+        this.iterators = iterators;
+        this.size = iterators.length;
+    }
+
+    /**
+     * Finds the wrapper positioned on the smallest doc id.
+     *
+     * @return wrapper with the lowest doc id, or {@code null} if there are no wrappers or all are exhausted
+     */
+    public DisiWrapper top() {
+        DisiWrapper top = null;
+        for (int i = 0; i < size; i++) {
+            DisiWrapper wrapper = iterators[i];
+            if (wrapper == null || wrapper.doc == DocIdSetIterator.NO_MORE_DOCS) {
+                continue;
+            }
+            if (top == null || wrapper.doc < top.doc) {
+                top = wrapper;
+            }
+        }
+        return top;
+    }
+
+    /**
+     * Collects every wrapper positioned on the same doc id as {@link #top()}.
+     *
+     * @return head of a list chained through {@link DisiWrapper#next}, or {@code null} if {@link #top()} is null
+     */
+    public DisiWrapper topList() {
+        DisiWrapper top = top();
+        if (top == null) {
+            return null;
+        }
+        final int minDoc = top.doc;
+        DisiWrapper list = null;
+        // walk backwards and prepend: keeps array order and overwrites every stale 'next' link
+        for (int i = size - 1; i >= 0; i--) {
+            DisiWrapper current = iterators[i];
+            if (current != null && current.doc == minDoc) {
+                current.next = list;
+                list = current;
+            }
+        }
+        return list;
+    }
+
+    /**
+     * Iterates over all wrapped sub-iterators regardless of their current position.
+     */
+    @Override
+    public Iterator<DisiWrapper> iterator() {
+        return new Iterator<>() {
+            private int position = nextNonNull(0);
+
+            @Override
+            public boolean hasNext() {
+                return position < size;
+            }
+
+            @Override
+            public DisiWrapper next() {
+                if (!hasNext()) {
+                    throw new NoSuchElementException();
+                }
+                DisiWrapper result = iterators[position];
+                position = nextNonNull(position + 1);
+                return result;
+            }
+        };
+    }
+
+    public int size() {
+        return size;
+    }
+
+    public boolean isEmpty() {
+        return size == 0;
+    }
+
+    // index of the first non-null wrapper at or after 'from', or size if there is none
+    private int nextNonNull(int from) {
+        int index = from;
+        while (index < size && iterators[index] == null) {
+            index++;
+        }
+        return index;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/search/SimpleDisjunctionDISIApproximation.java b/src/main/java/org/opensearch/neuralsearch/search/SimpleDisjunctionDISIApproximation.java
new file mode 100644
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/SimpleDisjunctionDISIApproximation.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search;
+
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DocIdSetIterator;
+
+import java.io.IOException;
+
+/**
+ * Disjunction over the approximations of the sub-iterators held by a {@link SimpleDisiIterator}.
+ * After a call to {@link #nextDoc()} or {@link #advance(int)} it is positioned on the smallest doc id
+ * of all sub-iterators that are not exhausted.
+ */
+public class SimpleDisjunctionDISIApproximation extends DocIdSetIterator {
+
+    private final SimpleDisiIterator iterator;
+    private final long cost;
+    private int doc = -1;
+
+    /**
+     * @param iterator holder of the sub-iterators to combine
+     */
+    public SimpleDisjunctionDISIApproximation(SimpleDisiIterator iterator) {
+        this.iterator = iterator;
+
+        // the cost estimate sums all sub-iterators, not only the ones on the lowest doc
+        long totalCost = 0;
+        for (DisiWrapper wrapper : iterator) {
+            totalCost += wrapper.cost;
+        }
+        this.cost = totalCost;
+    }
+
+    @Override
+    public int docID() {
+        return doc;
+    }
+
+    @Override
+    public int nextDoc() throws IOException {
+        DisiWrapper top = iterator.top();
+        // sub-iterators that are already ahead of the current doc stay where they are
+        if (top != null && top.doc <= doc) {
+            for (DisiWrapper wrapper = iterator.topList(); wrapper != null; wrapper = wrapper.next) {
+                wrapper.doc = wrapper.approximation.nextDoc();
+            }
+        }
+        return doc = lowestDoc();
+    }
+
+    @Override
+    public int advance(int target) throws IOException {
+        // any sub-iterator may be behind target, not only the ones on the lowest doc
+        for (DisiWrapper wrapper : iterator) {
+            if (wrapper.doc < target) {
+                wrapper.doc = wrapper.approximation.advance(target);
+            }
+        }
+        return doc = lowestDoc();
+    }
+
+    @Override
+    public long cost() {
+        return cost;
+    }
+
+    /**
+     * Returns the number of sub-iterators
+     */
+    public int getSubIteratorCount() {
+        return iterator.size();
+    }
+
+    /**
+     * Returns list of matching sub-iterators at current position
+     */
+    public DisiWrapper getMatches() {
+        return iterator.topList();
+    }
+
+    // smallest doc id of the sub-iterators that are not exhausted, NO_MORE_DOCS if there is none
+    private int lowestDoc() {
+        DisiWrapper top = iterator.top();
+        return top != null ? top.doc : NO_MORE_DOCS;
+    }
+}