[DRAFT] Cardinality aggregation dynamic pruning changes (to be used only for prototype and reference purposes, not intended to merge to main) #12323

Closed
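
This draft wires dynamic pruning into the cardinality aggregation. CardinalityAggregator now keeps the FieldContext of its value source and, when the ordinals-based collector is picked, wraps it in a DynamicPruningCollectorWrapper (the wrapper's source is not part of this diff excerpt); the Collector base class is widened from private to package-private, presumably so the wrapper can decorate it. The new DisjunctionWithDynamicPruningScorer iterates a disjunction of sub-scorers, one per candidate term, and lets a term's sub-scorer be dropped as soon as that term has been seen, shrinking the remaining search space as collection progresses.

A minimal sketch of how a collector might drive the scorer; the driver method and class name below are assumptions for illustration, not code from this change:

import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;

class PruningDriverSketch {
// Hypothetical driver: collect each matching doc once, then prune the sub-scorer(s)
// positioned on it, since an already-seen term cannot contribute new cardinality.
static void collectWithPruning(DisjunctionWithDynamicPruningScorer scorer, LeafCollector collector) throws IOException {
DocIdSetIterator it = scorer.iterator();
for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
collector.collect(doc); // count the term(s) matching this doc
scorer.removeAllDISIsOnCurrentDoc(); // drop their DISIs from the disjunction
// a real driver would also stop once every sub-scorer has been pruned
}
}
}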
CardinalityAggregator.java
@@ -53,6 +53,7 @@
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.LeafBucketCollector;
+import org.opensearch.search.aggregations.support.FieldContext;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
@@ -71,6 +72,8 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue
private final int precision;
private final ValuesSource valuesSource;

+private final FieldContext fieldContext;
+
// Expensive to initialize, so we only initialize it when we have an actual value source
@Nullable
private HyperLogLogPlusPlus counts;
@@ -95,6 +98,7 @@ public CardinalityAggregator(
// TODO: Stop using nulls here
this.valuesSource = valuesSourceConfig.hasValues() ? valuesSourceConfig.getValuesSource() : null;
this.precision = precision;
+this.fieldContext = valuesSourceConfig.fieldContext();
this.counts = valuesSource == null ? null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1);
}

@@ -132,11 +136,11 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
// only use ordinals if they don't increase memory usage by more than 25%
if (ordinalsMemoryUsage < countsMemoryUsage / 4) {
ordinalsCollectorsUsed++;
-return new OrdinalsCollector(counts, ordinalValues, context.bigArrays());
+return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
+context, ctx, fieldContext, source);
}
ordinalsCollectorsOverheadTooHigh++;
}
-
stringHashingCollectorsUsed++;
return new DirectCollector(counts, MurmurHash3Values.hash(valuesSource.bytesValues(ctx)));
}
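
(Context on the unchanged heuristic above: if countsMemoryUsage is, say, 1 MB, the ordinals collector, now wrapped for pruning, is chosen only when its estimated overhead stays below 256 KB.)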
@@ -206,7 +210,7 @@ public void collectDebugInfo(BiConsumer<String, Object> add) {
*
* @opensearch.internal
*/
-private abstract static class Collector extends LeafBucketCollector implements Releasable {
+abstract static class Collector extends LeafBucketCollector implements Releasable {

public abstract void postCollect() throws IOException;

DisjunctionWithDynamicPruningScorer.java (new file)
@@ -0,0 +1,264 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations.metrics;

import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.PriorityQueue;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;


/**
* Clone of Lucene's {@code DisjunctionScorer} (in {@code org.apache.lucene.search}) with the following modifications -
* 1. {@link #removeAllDISIsOnCurrentDoc()} - removes the DISIs of all sub-scorers positioned on the current doc. This is
* helpful in dynamic pruning for the cardinality aggregation, where once a term is found it becomes irrelevant for the
* rest of the search space, so that term's sub-scorer DISI can safely be removed from the set of sub-scorers to process.
* <p>
* 2. {@link #removeAllDISIsOnCurrentDoc()} breaks the invariant of the disjunction DISI that the docIDs of all sub-scorers are
* greater than or equal to the docID the iterator is currently pointing to. Removing elements from the priority queue triggers
* a heapify, which changes the top of the priority queue, and the top represents the current docID of the sub-scorers here.
* To address this, we wrap the iterator in {@link SlowDocIdPropagatorDISI}, which keeps reporting the last docID seen before
* {@link #removeAllDISIsOnCurrentDoc()} was called and only moves this docID forward when next() or advance() is called.
*/
public class DisjunctionWithDynamicPruningScorer extends Scorer {

private final boolean needsScores;
private final DisiPriorityQueue subScorers;
private final DocIdSetIterator approximation;
private final TwoPhase twoPhase;

private Integer docID;

public DisjunctionWithDynamicPruningScorer(Weight weight, List<Scorer> subScorers)
throws IOException {
super(weight);
if (subScorers.size() <= 1) {
throw new IllegalArgumentException("There must be at least 2 subScorers");
}
this.subScorers = new DisiPriorityQueue(subScorers.size());
for (Scorer scorer : subScorers) {
final DisiWrapper w = new DisiWrapper(scorer);
this.subScorers.add(w);
}
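// Scores are never needed for cardinality pruning, so needsScores is hard-wired to false;
// matches() can then return as soon as any sub-scorer is verified on the current doc.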
this.needsScores = false;
this.approximation = new DisjunctionDISIApproximation(this.subScorers);

boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : this.subScorers) {
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;
}
}

if (hasApproximation == false) { // no sub scorer supports approximations
twoPhase = null;
} else {
final float matchCost = sumMatchCost / sumApproxCost;
twoPhase = new TwoPhase(approximation, matchCost);
}
}

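/**
* Pops every sub-scorer whose DISI is positioned on the current doc. Intended to be called
* by the dynamic pruning collector once the current doc has been collected, so that
* already-seen terms stop participating in the disjunction.
*/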
public void removeAllDISIsOnCurrentDoc() {
docID = this.docID();
while (subScorers.size() > 0 && subScorers.top().doc == docID) {
subScorers.pop();
}
}

@Override
public DocIdSetIterator iterator() {
DocIdSetIterator disi = getIterator();
docID = disi.docID();
return new SlowDocIdPropagatorDISI(disi, docID); // reuse the iterator instance queried above
}

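/**
* Keeps reporting the last externally observed docID even after the wrapped iterator has
* silently moved ahead (because sub-scorers were pruned and the queue re-heapified), moving
* forward only on explicit nextDoc()/advance() calls.
*/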
private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
DocIdSetIterator disi;

Integer curDocId;

SlowDocIdPropagatorDISI(DocIdSetIterator disi, Integer curDocId) {
this.disi = disi;
this.curDocId = curDocId;
}

@Override
public int docID() {
assert curDocId <= disi.docID();
return curDocId;
}

@Override
public int nextDoc() throws IOException {
return advance(curDocId + 1);
}

@Override
public int advance(int i) throws IOException {
if (i <= disi.docID()) {
// since docIDs are propagated lazily, the disi may already be ahead of i;
// in that case simply report the docID the disi is pointing at and update curDocId
curDocId = disi.docID();
return disi.docID();
}
curDocId = disi.advance(i);
return curDocId;
}

@Override
public long cost() {
return disi.cost();
}
}

private DocIdSetIterator getIterator() {
if (twoPhase != null) {
return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
} else {
return approximation;
}
}

@Override
public TwoPhaseIterator twoPhaseIterator() {
return twoPhase;
}

@Override
public float getMaxScore(int i) throws IOException {
return 0;
}

private class TwoPhase extends TwoPhaseIterator {

private final float matchCost;
// list of verified matches on the current doc
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;

private TwoPhase(DocIdSetIterator approximation, float matchCost) {
super(approximation);
this.matchCost = matchCost;
unverifiedMatches =
new PriorityQueue<DisiWrapper>(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
}

DisiWrapper getSubMatches() throws IOException {
// iteration order does not matter
for (DisiWrapper w : unverifiedMatches) {
if (w.twoPhaseView.matches()) {
w.next = verifiedMatches;
verifiedMatches = w;
}
}
unverifiedMatches.clear();
return verifiedMatches;
}

@Override
public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

for (DisiWrapper w = subScorers.topList(); w != null; ) {
DisiWrapper next = w.next;

if (w.twoPhaseView == null) {
// implicitly verified, move it to verifiedMatches
w.next = verifiedMatches;
verifiedMatches = w;

if (needsScores == false) {
// we can stop here
return true;
}
} else {
unverifiedMatches.add(w);
}
w = next;
}

if (verifiedMatches != null) {
return true;
}

// verify subs that have a two-phase iterator
// least-costly ones first
while (unverifiedMatches.size() > 0) {
DisiWrapper w = unverifiedMatches.pop();
if (w.twoPhaseView.matches()) {
w.next = null;
verifiedMatches = w;
return true;
}
}

return false;
}

@Override
public float matchCost() {
return matchCost;
}
}


@Override
public final int docID() {
// Guard: removeAllDISIsOnCurrentDoc() may have emptied the queue entirely.
return subScorers.size() == 0 ? DocIdSetIterator.NO_MORE_DOCS : subScorers.top().doc;
}

DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorers.topList();
} else {
return twoPhase.getSubMatches();
}
}

@Override
public final float score() throws IOException {
return score(getSubMatches());
}

protected float score(DisiWrapper topList) throws IOException {
return 1f; // constant score: this scorer only drives iteration for pruning, not ranking
}

@Override
public final Collection<ChildScorable> getChildren() throws IOException {
ArrayList<ChildScorable> children = new ArrayList<>();
for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
children.add(new ChildScorable(scorer.scorer, "SHOULD"));
}
return children;
}
}