Skip to content

Commit

Permalink
Experimenting with doc iterators
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Dec 29, 2024
1 parent 22ba5d3 commit ea92d08
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

import org.opensearch.index.query.MatchQueryBuilder;

import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.neuralsearch.search.HybridDisiWrapper;
import org.opensearch.neuralsearch.search.SimpleDisiIterator;
import org.opensearch.neuralsearch.search.SimpleDisjunctionDISIApproximation;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -37,7 +37,8 @@ public class HybridQueryScorer extends Scorer {
@Getter
private final List<Scorer> subScorers;

private final DisiPriorityQueue subScorersPQ;
// private final DisiPriorityQueue subScorersPQ;
private final SimpleDisiIterator subScorersPQ;

private final DocIdSetIterator approximation;
private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
Expand Down Expand Up @@ -207,19 +208,20 @@ public float[] hybridScores() throws IOException {
return scores;
}

private DisiPriorityQueue initializeSubScorersPQ() {
private SimpleDisiIterator initializeSubScorersPQ() {
Objects.requireNonNull(subScorers, "should not be null");
// we need to count this way in order to include all identical sub-queries
DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numSubqueries);
// DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numSubqueries);
List<DisiWrapper> disiWrappers = new ArrayList<>();
for (int idx = 0; idx < numSubqueries; idx++) {
Scorer scorer = subScorers.get(idx);
if (scorer == null) {
continue;
}
final HybridDisiWrapper disiWrapper = new HybridDisiWrapper(scorer, idx);
subScorersPQ.add(disiWrapper);
disiWrappers.add(disiWrapper);
}
return subScorersPQ;
return new SimpleDisiIterator(disiWrappers.toArray(new DisiWrapper[0]));
}

@Override
Expand All @@ -244,10 +246,10 @@ static class TwoPhase extends TwoPhaseIterator {
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;
DisiPriorityQueue subScorers;
SimpleDisiIterator subScorers;
boolean needsScores;

private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
private TwoPhase(DocIdSetIterator approximation, float matchCost, SimpleDisiIterator subScorers, boolean needsScores) {
super(approximation);
this.matchCost = matchCost;
this.subScorers = subScorers;
Expand Down Expand Up @@ -323,10 +325,10 @@ public float matchCost() {
*/
static class HybridSubqueriesDISIApproximation extends DocIdSetIterator {
final DocIdSetIterator docIdSetIterator;
final DisiPriorityQueue subIterators;
final SimpleDisiIterator subIterators;

public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) {
docIdSetIterator = new DisjunctionDISIApproximation(subIterators);
public HybridSubqueriesDISIApproximation(final SimpleDisiIterator subIterators) {
docIdSetIterator = new SimpleDisjunctionDISIApproximation(subIterators);
this.subIterators = subIterators;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;

import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DocIdSetIterator;

import java.io.IOException;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * Array-backed holder of {@link DisiWrapper}s that answers "which sub-iterator is on the
 * smallest doc id" with a linear scan instead of a heap.
 *
 * <p>On construction every non-null wrapper whose {@code doc} is still {@code -1} is moved to
 * its first document, so {@link #top()} is meaningful immediately.
 *
 * <p>This class never advances wrappers after construction; callers are expected to update
 * {@code wrapper.doc} themselves when they move the underlying iterator.
 */
public final class SimpleDisiIterator implements Iterable<DisiWrapper> {
    private final DisiWrapper[] iterators;
    private final int size;

    /**
     * @param iterators wrappers to track; individual entries may be {@code null} and are ignored
     * @throws RuntimeException wrapping the {@link IOException} if positioning a wrapper on its
     *         first document fails
     */
    public SimpleDisiIterator(DisiWrapper... iterators) {
        // Copy so later changes to the caller's array cannot alter the tracked set.
        this.iterators = iterators.clone();
        this.size = this.iterators.length;
        try {
            for (DisiWrapper wrapper : this.iterators) {
                // -1 means "not started yet"; move such wrappers to their first document.
                if (wrapper != null && wrapper.doc == -1) {
                    wrapper.doc = wrapper.iterator.nextDoc();
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the wrapper positioned on the smallest doc id, ignoring {@code null} entries and
     * wrappers that have reached {@link DocIdSetIterator#NO_MORE_DOCS}.
     *
     * @return the first such wrapper in array order, or {@code null} if there is none
     */
    public DisiWrapper top() {
        DisiWrapper top = null;
        int minDoc = DocIdSetIterator.NO_MORE_DOCS;
        for (DisiWrapper wrapper : iterators) {
            // Strict '<' keeps the earliest array entry on ties and, because minDoc starts at
            // NO_MORE_DOCS, also filters out exhausted wrappers.
            if (wrapper != null && wrapper.doc < minDoc) {
                minDoc = wrapper.doc;
                top = wrapper;
            }
        }
        return top;
    }

    /**
     * Returns every wrapper positioned on the current smallest doc id as a singly linked list
     * chained through {@link DisiWrapper#next}.
     *
     * <p>The {@code next} field of each returned wrapper is overwritten on every call, so a
     * previously returned list must not be used after calling this method again. No scores are
     * computed and no wrapper is advanced.
     *
     * @return head of the list (the last matching wrapper in array order), or {@code null} if
     *         {@link #top()} is {@code null}
     */
    public DisiWrapper topList() {
        final DisiWrapper top = top();
        if (top == null) {
            return null;
        }

        final int minDoc = top.doc;
        DisiWrapper list = null;
        for (DisiWrapper wrapper : iterators) {
            if (wrapper != null && wrapper.doc == minDoc) {
                // Prepend, which also resets any stale link left by an earlier call.
                wrapper.next = list;
                list = wrapper;
            }
        }
        return list;
    }

    /**
     * Iterates once over every non-null wrapper in array order, regardless of the document each
     * wrapper is currently positioned on. {@code remove()} is not supported.
     */
    @Override
    public Iterator<DisiWrapper> iterator() {
        return new Iterator<>() {
            private int index = nextNonNull(0);

            // Finds the first slot at or after 'from' that holds a wrapper; returns size if none.
            private int nextNonNull(int from) {
                int i = from;
                while (i < size && iterators[i] == null) {
                    i++;
                }
                return i;
            }

            @Override
            public boolean hasNext() {
                return index < size;
            }

            @Override
            public DisiWrapper next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                final DisiWrapper result = iterators[index];
                index = nextNonNull(index + 1);
                return result;
            }
        };
    }

    /**
     * @return the number of slots passed to the constructor, including {@code null} ones
     */
    public int size() {
        return size;
    }

    /**
     * @return {@code true} if the constructor received no wrappers at all
     */
    public boolean isEmpty() {
        return size == 0;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;

import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DocIdSetIterator;

import java.io.IOException;

/**
 * {@link DocIdSetIterator} over the union of the sub-iterators held by a
 * {@link SimpleDisiIterator}: it visits, in increasing order, every doc id that at least one
 * sub-iterator's approximation is positioned on.
 *
 * <p>After {@link #nextDoc()} or {@link #advance(int)} returns a real document, every
 * sub-iterator that matches it is left positioned on that document, so
 * {@link #getMatches()} can be used to inspect them.
 */
public class SimpleDisjunctionDISIApproximation extends DocIdSetIterator {

    private final SimpleDisiIterator iterator;
    private final long cost;
    private int doc = -1;

    /**
     * @param iterator holder of the sub-iterators to merge
     */
    public SimpleDisjunctionDISIApproximation(SimpleDisiIterator iterator) {
        this.iterator = iterator;

        // NOTE(review): this sums the cost of only the wrappers returned by topList() at
        // construction time, not of every sub-iterator, so it likely under-estimates the cost
        // of the whole disjunction — confirm whether summing over all sub-iterators is intended.
        long totalCost = 0;
        for (DisiWrapper current = iterator.topList(); current != null; current = current.next) {
            totalCost += current.cost;
        }
        this.cost = totalCost;
    }

    @Override
    public int docID() {
        return doc;
    }

    /**
     * Moves to the smallest sub-iterator document that is greater than the current one.
     *
     * <p>Only sub-iterators positioned at or before the current document are advanced. This
     * matters on the first call: {@link SimpleDisiIterator}'s constructor already places
     * wrappers on their first document, and that document must be returned rather than skipped.
     *
     * @return the new current doc id, or {@link #NO_MORE_DOCS} when all sub-iterators are
     *         exhausted
     */
    @Override
    public int nextDoc() throws IOException {
        DisiWrapper top = iterator.top();
        // Re-query top() after every move instead of walking DisiWrapper.next, so that every
        // sub-iterator sitting on the old document is moved regardless of list linkage.
        while (top != null && top.doc <= doc) {
            top.doc = top.approximation.nextDoc();
            top = iterator.top();
        }
        return doc = (top == null) ? NO_MORE_DOCS : top.doc;
    }

    /**
     * Moves to the smallest sub-iterator document that is greater than or equal to
     * {@code target}.
     *
     * <p>Every sub-iterator positioned before {@code target} is advanced; sub-iterators already
     * at or beyond {@code target} are left untouched so their current document is not lost.
     *
     * @param target the doc id to advance to
     * @return the new current doc id, or {@link #NO_MORE_DOCS} when all sub-iterators are
     *         exhausted
     */
    @Override
    public int advance(int target) throws IOException {
        DisiWrapper top = iterator.top();
        while (top != null && top.doc < target) {
            top.doc = top.approximation.advance(target);
            top = iterator.top();
        }
        return doc = (top == null) ? NO_MORE_DOCS : top.doc;
    }

    @Override
    public long cost() {
        return cost;
    }

    /**
     * Returns the number of sub-iterators
     */
    public int getSubIteratorCount() {
        return iterator.size();
    }

    /**
     * Returns list of matching sub-iterators at current position, as produced by
     * {@link SimpleDisiIterator#topList()}; {@code null} when no sub-iterator has documents left.
     */
    public DisiWrapper getMatches() {
        return iterator.topList();
    }
}

0 comments on commit ea92d08

Please sign in to comment.