Skip to content

Commit

Permalink
Cardinality aggregation dynamic pruning changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhmaurya committed Feb 14, 2024
1 parent 84750b3 commit ea3e08c
Show file tree
Hide file tree
Showing 4 changed files with 435 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.support.FieldContext;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
Expand All @@ -71,6 +72,8 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue
private final int precision;
private final ValuesSource valuesSource;

private final FieldContext fieldContext;

// Expensive to initialize, so we only initialize it when we have an actual value source
@Nullable
private HyperLogLogPlusPlus counts;
Expand All @@ -95,6 +98,7 @@ public CardinalityAggregator(
// TODO: Stop using nulls here
this.valuesSource = valuesSourceConfig.hasValues() ? valuesSourceConfig.getValuesSource() : null;
this.precision = precision;
this.fieldContext = valuesSourceConfig.fieldContext();
this.counts = valuesSource == null ? null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1);
}

Expand Down Expand Up @@ -132,11 +136,11 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
// only use ordinals if they don't increase memory usage by more than 25%
if (ordinalsMemoryUsage < countsMemoryUsage / 4) {
ordinalsCollectorsUsed++;
return new OrdinalsCollector(counts, ordinalValues, context.bigArrays());
return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
context, ctx, fieldContext, source);
}
ordinalsCollectorsOverheadTooHigh++;
}

stringHashingCollectorsUsed++;
return new DirectCollector(counts, MurmurHash3Values.hash(valuesSource.bytesValues(ctx)));
}
Expand Down Expand Up @@ -206,7 +210,7 @@ public void collectDebugInfo(BiConsumer<String, Object> add) {
*
* @opensearch.internal
*/
private abstract static class Collector extends LeafBucketCollector implements Releasable {
abstract static class Collector extends LeafBucketCollector implements Releasable {

public abstract void postCollect() throws IOException;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations.metrics;

import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.PriorityQueue;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;


/**
 * Clone of {@link org.apache.lucene.search} {@code DisjunctionScorer.java} in lucene with following modifications -
 * <p>
 * 1. {@link #removeAllDISIsOnCurrentDoc()} - it removes all the DISIs for sub-scorers pointing to the current doc. This is
 * helpful in dynamic pruning for Cardinality aggregation, where once a term is found, it becomes irrelevant for
 * the rest of the search space, so this term's sub-scorer DISI can be safely removed from the list of sub-scorers to process.
 * <p>
 * 2. {@link #removeAllDISIsOnCurrentDoc()} breaks the invariant of the disjunction DISI that ties the docID reported by the
 * iterator to the top of the sub-scorer priority queue. Removing elements from the priority queue results in a heapify
 * action, which modifies the top of the priority queue, and the top represents the current docID for the sub-scorers here.
 * To address this, the iterator is wrapped with {@link SlowDocIdPropagatorDISI}, which keeps pointing to the last docID seen
 * before {@link #removeAllDISIsOnCurrentDoc()} was called and updates this docID only when next() or advance() is called.
 * <p>
 * 3. Once every sub-scorer has been removed the scorer is treated as exhausted: {@link #docID()} and the wrapped iterator
 * report {@link DocIdSetIterator#NO_MORE_DOCS} instead of dereferencing the (now empty) priority queue.
 */
public class DisjunctionWithDynamicPruningScorer extends Scorer {

    // Always false in this clone (see constructor); kept so that TwoPhase#matches mirrors the Lucene original.
    private final boolean needsScores;
    // Priority queue of the sub-scorers that are still relevant; shrinks as removeAllDISIsOnCurrentDoc() prunes terms.
    private final DisiPriorityQueue subScorers;
    private final DocIdSetIterator approximation;
    // null when no sub-scorer exposes a two-phase view.
    private final TwoPhase twoPhase;

    /**
     * Builds a disjunction over the given sub-scorers.
     *
     * @param weight     the weight passed to the {@link Scorer} super constructor
     * @param subScorers the scorers to combine; at least two are required
     * @throws IllegalArgumentException if fewer than two sub-scorers are supplied
     * @throws IOException              declared for callers; kept from the original signature
     */
    public DisjunctionWithDynamicPruningScorer(Weight weight, List<Scorer> subScorers) throws IOException {
        super(weight);
        if (subScorers.size() <= 1) {
            throw new IllegalArgumentException("There must be at least 2 subScorers");
        }
        this.subScorers = new DisiPriorityQueue(subScorers.size());
        for (Scorer scorer : subScorers) {
            final DisiWrapper w = new DisiWrapper(scorer);
            this.subScorers.add(w);
        }
        this.needsScores = false;
        this.approximation = new DisjunctionDISIApproximation(this.subScorers);

        boolean hasApproximation = false;
        float sumMatchCost = 0;
        long sumApproxCost = 0;
        // Compute matchCost as the average over the matchCost of the subScorers.
        // This is weighted by the cost, which is an expected number of matching documents.
        for (DisiWrapper w : this.subScorers) {
            long costWeight = (w.cost <= 1) ? 1 : w.cost;
            sumApproxCost += costWeight;
            if (w.twoPhaseView != null) {
                hasApproximation = true;
                sumMatchCost += w.matchCost * costWeight;
            }
        }

        if (hasApproximation == false) { // no sub scorer supports approximations
            twoPhase = null;
        } else {
            final float matchCost = sumMatchCost / sumApproxCost;
            twoPhase = new TwoPhase(approximation, matchCost);
        }
    }

    /**
     * Removes from the priority queue every sub-scorer positioned on the document currently at the top of the queue.
     * Does nothing when the queue is already empty.
     */
    public void removeAllDISIsOnCurrentDoc() {
        if (subScorers.size() == 0) {
            // everything was pruned already; top() would be null here
            return;
        }
        final int currentDoc = subScorers.top().doc;
        while (subScorers.size() > 0 && subScorers.top().doc == currentDoc) {
            subScorers.pop();
        }
    }

    /**
     * Returns an iterator that tolerates {@link #removeAllDISIsOnCurrentDoc()} being called between advances.
     */
    @Override
    public DocIdSetIterator iterator() {
        // Build the underlying iterator once and reuse it, rather than creating a throw-away one just to read its docID.
        final DocIdSetIterator disi = getIterator();
        final int startDoc = subScorers.size() == 0 ? DocIdSetIterator.NO_MORE_DOCS : disi.docID();
        return new SlowDocIdPropagatorDISI(disi, subScorers, startDoc);
    }

    /**
     * Iterator wrapper that remembers the last docID it returned, so that the reported position does not jump when the
     * underlying priority queue is re-heapified by {@link #removeAllDISIsOnCurrentDoc()}.
     */
    private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
        private final DocIdSetIterator disi;
        // Shared with the enclosing scorer; consulted only to detect that every sub-scorer has been removed.
        private final DisiPriorityQueue subScorers;
        private int curDocId;

        SlowDocIdPropagatorDISI(DocIdSetIterator disi, DisiPriorityQueue subScorers, int curDocId) {
            this.disi = disi;
            this.subScorers = subScorers;
            this.curDocId = curDocId;
        }

        @Override
        public int docID() {
            // the underlying iterator cannot be queried once the queue is empty
            assert subScorers.size() == 0 || curDocId <= disi.docID();
            return curDocId;
        }

        @Override
        public int nextDoc() throws IOException {
            return advance(curDocId + 1);
        }

        @Override
        public int advance(int i) throws IOException {
            if (subScorers.size() == 0) {
                // all sub-scorers were pruned, nothing is left to match
                curDocId = NO_MORE_DOCS;
                return curDocId;
            }
            final int underlyingDoc = disi.docID();
            if (i <= underlyingDoc) {
                // since we are slow propagating docIDs, it may happen the disi is already advanced to a higher docID than i
                // in such scenarios we can simply return the docID where disi is pointing to and update the curDocId
                // NOTE(review): in the two-phase case this position has not been confirmed via matches() - verify with callers
                curDocId = underlyingDoc;
            } else {
                curDocId = disi.advance(i);
            }
            return curDocId;
        }

        @Override
        public long cost() {
            return disi.cost();
        }
    }

    /**
     * Returns the two-phase iterator as a plain {@link DocIdSetIterator} when one exists, otherwise the approximation.
     */
    private DocIdSetIterator getIterator() {
        if (twoPhase != null) {
            return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
        } else {
            return approximation;
        }
    }

    @Override
    public TwoPhaseIterator twoPhaseIterator() {
        return twoPhase;
    }

    @Override
    public float getMaxScore(int i) throws IOException {
        return 0;
    }

    /**
     * Two-phase view of the disjunction: the approximation matches when at least one sub-scorer on the current doc is
     * either implicitly verified (no two-phase view) or confirms the match through its own two-phase view.
     */
    private class TwoPhase extends TwoPhaseIterator {

        private final float matchCost;
        // list of verified matches on the current doc
        DisiWrapper verifiedMatches;
        // priority queue of approximations on the current doc that have not been verified yet
        final PriorityQueue<DisiWrapper> unverifiedMatches;

        private TwoPhase(DocIdSetIterator approximation, float matchCost) {
            super(approximation);
            this.matchCost = matchCost;
            unverifiedMatches = new PriorityQueue<DisiWrapper>(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) {
                @Override
                protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
                    return a.matchCost < b.matchCost;
                }
            };
        }

        /**
         * Verifies the remaining unverified candidates and returns the linked list of all verified matches.
         */
        DisiWrapper getSubMatches() throws IOException {
            // iteration order does not matter
            for (DisiWrapper w : unverifiedMatches) {
                if (w.twoPhaseView.matches()) {
                    w.next = verifiedMatches;
                    verifiedMatches = w;
                }
            }
            unverifiedMatches.clear();
            return verifiedMatches;
        }

        @Override
        public boolean matches() throws IOException {
            verifiedMatches = null;
            unverifiedMatches.clear();

            for (DisiWrapper w = subScorers.topList(); w != null;) {
                DisiWrapper next = w.next;

                if (w.twoPhaseView == null) {
                    // implicitly verified, move it to verifiedMatches
                    w.next = verifiedMatches;
                    verifiedMatches = w;

                    if (needsScores == false) {
                        // we can stop here
                        return true;
                    }
                } else {
                    unverifiedMatches.add(w);
                }
                w = next;
            }

            if (verifiedMatches != null) {
                return true;
            }

            // verify subs that have an two-phase iterator
            // least-costly ones first
            while (unverifiedMatches.size() > 0) {
                DisiWrapper w = unverifiedMatches.pop();
                if (w.twoPhaseView.matches()) {
                    w.next = null;
                    verifiedMatches = w;
                    return true;
                }
            }

            return false;
        }

        @Override
        public float matchCost() {
            return matchCost;
        }
    }

    /**
     * Returns the doc of the sub-scorer at the top of the priority queue, or {@link DocIdSetIterator#NO_MORE_DOCS}
     * once every sub-scorer has been removed.
     */
    @Override
    public final int docID() {
        if (subScorers.size() == 0) {
            return DocIdSetIterator.NO_MORE_DOCS;
        }
        return subScorers.top().doc;
    }

    /**
     * Returns the linked list of sub-scorers matching the current position, or {@code null} when there is no two-phase
     * view and no sub-scorer is left in the queue.
     */
    DisiWrapper getSubMatches() throws IOException {
        if (twoPhase == null) {
            // topList() dereferences the head of the queue, so it must not be called on an empty queue
            return subScorers.size() == 0 ? null : subScorers.topList();
        } else {
            return twoPhase.getSubMatches();
        }
    }

    @Override
    public final float score() throws IOException {
        return score(getSubMatches());
    }

    /**
     * Scores are not needed by this scorer; every match gets the same constant score.
     */
    protected float score(DisiWrapper topList) throws IOException {
        return 1f;
    }

    @Override
    public final Collection<ChildScorable> getChildren() throws IOException {
        ArrayList<ChildScorable> children = new ArrayList<>();
        for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
            children.add(new ChildScorable(scorer.scorer, "SHOULD"));
        }
        return children;
    }
}
Loading

0 comments on commit ea3e08c

Please sign in to comment.