Skip to content

Commit

Permalink
Step 2 - fixed custom scenario, base scenarious work too
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Feb 29, 2024
1 parent 7c84340 commit f858be0
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -142,7 +143,7 @@ public int hashCode() {
}

public Collection<Query> getSubQueries() {
return subQueries;
return Collections.unmodifiableCollection(subQueries);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
List<Query> queries = queryBuilders.stream().map(qb -> {
try {
// return Rewriteable.rewrite(qb, context).toQuery(context);
return qb.rewrite(context).toQuery(context);
} catch (IOException e) {
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public final class HybridQueryScorer extends Scorer {

private final DocIdSetIterator approximation;
HybridScorePropagator disjunctionBlockPropagator;
// private final TwoPhase twoPhase;
private final TwoPhase twoPhase;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
Expand All @@ -64,7 +64,7 @@ public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode score
// base
this.subScorersPQ = initializeSubScorersPQ();
// base
// this.approximation = new DisjunctionDISIApproximation(this.subScorersPQ);
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ);
// max
if (scoreMode == ScoreMode.TOP_SCORES) {
Expand All @@ -86,12 +86,12 @@ public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode score
sumMatchCost += w.matchCost * costWeight;
}
}
/*if (hasApproximation == false) { // no sub scorer supports approximations
if (!hasApproximation) { // no sub scorer supports approximations
twoPhase = null;
} else {
final float matchCost = sumMatchCost / sumApproxCost;
twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ);
}*/
twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores);
}
}

@Override
Expand Down Expand Up @@ -147,12 +147,11 @@ private float score(DisiWrapper topList) throws IOException {
}

DisiWrapper getSubMatches() throws IOException {
// if (twoPhase == null) {
return subScorersPQ.topList();
// } else {
// return twoPhase.getSubMatches();
// }
// return subScorersPQ.topList();
if (twoPhase == null) {
return subScorersPQ.topList();
} else {
return twoPhase.getSubMatches();
}
}

/**
Expand All @@ -161,19 +160,17 @@ DisiWrapper getSubMatches() throws IOException {
*/
@Override
public DocIdSetIterator iterator() {
/*if (twoPhase != null) {
if (twoPhase != null) {
return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
} else {*/
return approximation;
// }
// return new DisjunctionDISIApproximation(this.subScorersPQ);
// return new HybridDisjunctionDISIApproximation(this.subScorersPQ);
} else {
return approximation;
}
}

/*@Override
@Override
public TwoPhaseIterator twoPhaseIterator() {
return twoPhase;
}*/
}

/**
* Return the maximum score that documents between the last target that this iterator was shallow-advanced to included and upTo included.
Expand Down Expand Up @@ -295,24 +292,25 @@ public Collection<ChildScorable> getChildren() throws IOException {
}

static class TwoPhase extends TwoPhaseIterator {

private final float matchCost;
// list of verified matches on the current doc
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;
DisiPriorityQueue subScorers;
boolean needsScores;

private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers) {
private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
super(approximation);
this.matchCost = matchCost;
this.subScorers = subScorers;
unverifiedMatches = new PriorityQueue<DisiWrapper>(subScorers.size()) {
unverifiedMatches = new PriorityQueue<>(subScorers.size()) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
this.needsScores = needsScores;
}

DisiWrapper getSubMatches() throws IOException {
Expand Down Expand Up @@ -340,10 +338,10 @@ public boolean matches() throws IOException {
w.next = verifiedMatches;
verifiedMatches = w;

// if (needsScores == false) {
// we can stop here
// return true;
// }
if (!needsScores) {
// we can stop here
return true;
}
} else {
unverifiedMatches.add(w);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
*/
public final class HybridQueryWeight extends Weight {

private final HybridQuery queries;
// The Weights for our subqueries, in 1-1 correspondence
private final List<Weight> weights;

Expand All @@ -38,7 +37,6 @@ public final class HybridQueryWeight extends Weight {
*/
public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
super(hybridQuery);
this.queries = hybridQuery;
weights = hybridQuery.getSubQueries().stream().map(q -> {
try {
return searcher.createWeight(q, scoreMode, boost);
Expand Down Expand Up @@ -72,75 +70,18 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException {
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
// critical section
/*List<ScorerSupplier> scorerSuppliers = new ArrayList<>();
for (Weight w : weights) {
ScorerSupplier ss = w.scorerSupplier(context);
if (ss != null) {
scorerSuppliers.add(ss);
}
}
if (scorerSuppliers.isEmpty()) {
return null;
// } else if (scorerSuppliers.size() == 1) {
// return scorerSuppliers.get(0);
} else {
final Weight thisWeight = this;
return new ScorerSupplier() {
private long cost = -1;
@Override
public Scorer get(long leadCost) throws IOException {
List<Scorer> scorers = new ArrayList<>();
for (ScorerSupplier ss : scorerSuppliers) {
scorers.add(ss.get(leadCost));
}
return new HybridQueryScorer(thisWeight, scorers, scoreMode);
}
@Override
public long cost() {
if (cost == -1) {
long cost = 0;
for (ScorerSupplier ss : scorerSuppliers) {
cost += ss.cost();
}
this.cost = cost;
}
return cost;
}
@Override
public void setTopLevelScoringClause() throws IOException {
for (ScorerSupplier ss : scorerSuppliers) {
// sub scorers need to be able to skip too as calls to setMinCompetitiveScore get
// propagated
ss.setTopLevelScoringClause();
}
}
};
}*/
// return super.scorerSupplier(context);
List<ScorerSupplier> scorerSuppliers = new ArrayList<>();
for (Weight w : weights) {
ScorerSupplier ss = w.scorerSupplier(context);
scorerSuppliers.add(ss);
}
List<Scorer> scorers = weights.stream().map(w -> {
try {
return w.scorer(context);
} catch (IOException e) {
throw new RuntimeException(e);
}
}).collect(Collectors.toList());

if (scorerSuppliers.isEmpty()) {
return null;
} else {
final Weight thisWeight = this;
return new ScorerSupplier() {

private long cost = -1;

@Override
Expand All @@ -161,14 +102,10 @@ public Scorer get(long leadCost) throws IOException {
public long cost() {
if (cost == -1) {
long cost = 0;
for (int i = 0; i < scorerSuppliers.size(); i++) {
ScorerSupplier ss = scorerSuppliers.get(i);
for (ScorerSupplier ss : scorerSuppliers) {
if (Objects.nonNull(ss)) {
cost += ss.cost();
} /*else {
cost += scorers.get(i).iterator().cost();
}*/
// cost += ss.cost();
}
}
this.cost = cost;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import org.apache.lucene.index.IndexReader;
Expand Down Expand Up @@ -77,16 +74,10 @@ public boolean searchWith(
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
/*if (isHybridQuery(query, searchContext)) {
Query hybridQuery = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
}
validateQuery(searchContext, query);*/
if (!isHybridQuery(query, searchContext)) {
validateQuery(searchContext, query);
} else {
Query hybridQuery = extractHybridQuery(searchContext, query);
// return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
}
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
Expand Down Expand Up @@ -255,15 +246,6 @@ private void setTopDocsInQueryResult(
final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs);
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
// update total size
updateQueryResult(queryResult, searchContext);
}

private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
boolean isSingleShard = searchContext.numberOfShards() == 1;
if (isSingleShard) {
searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length);
}
}

private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
Expand Down Expand Up @@ -317,30 +299,20 @@ private TotalHits getTotalHits(final SearchContext searchContext, final List<Top
if (topDocs == null || topDocs.isEmpty()) {
return new TotalHits(0, relation);
}
/*long maxTotalHits = topDocs.get(0).totalHits.value;
long maxTotalHits = topDocs.get(0).totalHits.value;
int totalSize = 0;
for (TopDocs topDoc : topDocs) {
maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
if (isSingleShard) {
totalSize += topDoc.totalHits.value + 1;
}
}*/
List<ScoreDoc[]> scoreDocs = topDocs.stream()
.map(topdDoc -> topdDoc.scoreDocs)
.filter(Objects::nonNull)
.collect(Collectors.toList());
Set<Integer> uniqueDocIds = new HashSet<>();
for (ScoreDoc[] scoreDocsArray : scoreDocs) {
uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
}
long maxTotalHits = uniqueDocIds.size();

// add 1 qty per each sub-query and + 2 for start and stop delimiters
/*totalSize += 2;
totalSize += 2;
if (isSingleShard) {
// for single shard we need to update total size as this is how many docs are fetched in Fetch phase
searchContext.size(totalSize);
}*/
}

return new TotalHits(maxTotalHits, relation);
}
Expand Down

0 comments on commit f858be0

Please sign in to comment.