From ea3e08c9e04128466bd18ccaa633ce2c82def2bc Mon Sep 17 00:00:00 2001
From: Rishabh Maurya
Date: Wed, 14 Feb 2024 11:40:43 -0800
Subject: [PATCH] Cardinality aggregation dynamic pruning changes

---
 .../metrics/CardinalityAggregator.java        |  10 +-
 .../DisjunctionWithDynamicPruningScorer.java  | 264 ++++++++++++++++++
 .../DynamicPruningCollectorWrapper.java       | 106 +++++++
 .../metrics/CardinalityAggregatorTests.java   |  58 ++++
 4 files changed, 435 insertions(+), 3 deletions(-)
 create mode 100644 server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java
 create mode 100644 server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java

diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
index 99c4eaac4b777..91887e2e4a202 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
@@ -53,6 +53,7 @@
 import org.opensearch.search.aggregations.Aggregator;
 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.LeafBucketCollector;
+import org.opensearch.search.aggregations.support.FieldContext;
 import org.opensearch.search.aggregations.support.ValuesSource;
 import org.opensearch.search.aggregations.support.ValuesSourceConfig;
 import org.opensearch.search.internal.SearchContext;
@@ -71,6 +72,8 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue
     private final int precision;
     private final ValuesSource valuesSource;
 
+    private final FieldContext fieldContext;
+
     // Expensive to initialize, so we only initialize it when we have an actual value source
     @Nullable
     private HyperLogLogPlusPlus counts;
@@ -95,6 +98,7 @@ public CardinalityAggregator(
         // TODO: Stop using nulls here
         this.valuesSource = valuesSourceConfig.hasValues() ? valuesSourceConfig.getValuesSource() : null;
         this.precision = precision;
+        this.fieldContext = valuesSourceConfig.fieldContext();
         this.counts = valuesSource == null ? null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1);
     }
 
@@ -132,11 +136,11 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
             // only use ordinals if they don't increase memory usage by more than 25%
             if (ordinalsMemoryUsage < countsMemoryUsage / 4) {
                 ordinalsCollectorsUsed++;
-                return new OrdinalsCollector(counts, ordinalValues, context.bigArrays());
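+                // Note: in this draft, pruning is only attempted for the ordinals collector; the wrapper
+                // falls back to plain delegation when the segment does not qualify (see DynamicPruningCollectorWrapper).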
+                return new DynamicPruningCollectorWrapper(
+                    new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
+                    context, ctx, fieldContext, source);
             }
             ordinalsCollectorsOverheadTooHigh++;
         }
-        stringHashingCollectorsUsed++;
         return new DirectCollector(counts, MurmurHash3Values.hash(valuesSource.bytesValues(ctx)));
     }
 
@@ -206,7 +210,7 @@ public void collectDebugInfo(BiConsumer<String, Object> add) {
      *
      * @opensearch.internal
      */
-    private abstract static class Collector extends LeafBucketCollector implements Releasable {
+    abstract static class Collector extends LeafBucketCollector implements Releasable {
 
         public abstract void postCollect() throws IOException;
 
diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java
new file mode 100644
index 0000000000000..6a7e66e8be2f0
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java
@@ -0,0 +1,264 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.metrics;
+
+import org.apache.lucene.search.DisiPriorityQueue;
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DisjunctionDISIApproximation;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TwoPhaseIterator;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.util.PriorityQueue;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Clone of {@code DisjunctionScorer} from {@link org.apache.lucene.search}, with the following modifications:
+ * 1. {@link #removeAllDISIsOnCurrentDoc()} removes the DISIs of all sub-scorers positioned on the current doc.
+ * This is helpful for dynamic pruning in the cardinality aggregation: once a term has been seen, it is
+ * irrelevant for the rest of the search space, so that term's sub-scorer DISI can be safely removed from the
+ * set of sub-scorers left to process.
+ * <p>
+ * 2. {@link #removeAllDISIsOnCurrentDoc()} breaks an invariant of the conjunction DISI, namely that the docIDs
+ * of all sub-scorers must be less than or equal to the docID the iterator is currently pointing to. Removing
+ * elements from the priority queue triggers a heapify, which changes the top of the priority queue, and the top
+ * represents the current docID of the sub-scorers here. To address this, the iterator is wrapped in a
+ * {@link SlowDocIdPropagatorDISI}, which keeps reporting the last docID seen before
+ * {@link #removeAllDISIsOnCurrentDoc()} was called and updates it only when next() or advance() is called.
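+ * <p>
+ * A minimal usage sketch (illustrative only; {@code subScorers} would hold one term scorer per distinct value,
+ * and {@code collectOrdinal} is a hypothetical callback, not part of this class):
+ * <pre>{@code
+ * DisjunctionWithDynamicPruningScorer disjunction = new DisjunctionWithDynamicPruningScorer(weight, subScorers);
+ * DocIdSetIterator disi = disjunction.iterator();
+ * for (int doc = disi.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = disi.nextDoc()) {
+ *     collectOrdinal(doc);                      // consume the first doc matching the term(s) on this doc
+ *     disjunction.removeAllDISIsOnCurrentDoc(); // prune: these terms never need to be matched again
+ * }
+ * }</pre>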
+ */
+public class DisjunctionWithDynamicPruningScorer extends Scorer {
+
+    private final boolean needsScores;
+    private final DisiPriorityQueue subScorers;
+    private final DocIdSetIterator approximation;
+    private final TwoPhase twoPhase;
+
+    private Integer docID;
+
+    public DisjunctionWithDynamicPruningScorer(Weight weight, List<Scorer> subScorers) throws IOException {
+        super(weight);
+        if (subScorers.size() <= 1) {
+            throw new IllegalArgumentException("There must be at least 2 subScorers");
+        }
+        this.subScorers = new DisiPriorityQueue(subScorers.size());
+        for (Scorer scorer : subScorers) {
+            final DisiWrapper w = new DisiWrapper(scorer);
+            this.subScorers.add(w);
+        }
+        this.needsScores = false;
+        this.approximation = new DisjunctionDISIApproximation(this.subScorers);
+
+        boolean hasApproximation = false;
+        float sumMatchCost = 0;
+        long sumApproxCost = 0;
+        // Compute matchCost as the average over the matchCost of the subScorers.
+        // This is weighted by the cost, which is an expected number of matching documents.
+        for (DisiWrapper w : this.subScorers) {
+            long costWeight = (w.cost <= 1) ? 1 : w.cost;
+            sumApproxCost += costWeight;
+            if (w.twoPhaseView != null) {
+                hasApproximation = true;
+                sumMatchCost += w.matchCost * costWeight;
+            }
+        }
+
+        if (hasApproximation == false) { // no sub scorer supports approximations
+            twoPhase = null;
+        } else {
+            final float matchCost = sumMatchCost / sumApproxCost;
+            twoPhase = new TwoPhase(approximation, matchCost);
+        }
+    }
+
+    public void removeAllDISIsOnCurrentDoc() {
+        docID = this.docID();
+        while (subScorers.size() > 0 && subScorers.top().doc == docID) {
+            subScorers.pop();
+        }
+    }
+
+    @Override
+    public DocIdSetIterator iterator() {
+        DocIdSetIterator disi = getIterator();
+        docID = disi.docID();
+        return new SlowDocIdPropagatorDISI(disi, docID);
+    }
+
+    private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
+        DocIdSetIterator disi;
+
+        Integer curDocId;
+
+        SlowDocIdPropagatorDISI(DocIdSetIterator disi, Integer curDocId) {
+            this.disi = disi;
+            this.curDocId = curDocId;
+        }
+
+        @Override
+        public int docID() {
+            assert curDocId <= disi.docID();
+            return curDocId;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            return advance(curDocId + 1);
+        }
+
+        @Override
+        public int advance(int i) throws IOException {
+            if (i <= disi.docID()) {
+                // since docIDs are propagated slowly, the underlying disi may already be ahead of i;
+                // in that case, simply return the docID the disi is pointing to and update curDocId
+                curDocId = disi.docID();
+                return disi.docID();
+            }
+            curDocId = disi.advance(i);
+            return curDocId;
+        }
+
+        @Override
+        public long cost() {
+            return disi.cost();
+        }
+    }
+
+    private DocIdSetIterator getIterator() {
+        if (twoPhase != null) {
+            return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
+        } else {
+            return approximation;
+        }
+    }
+
+    @Override
+    public TwoPhaseIterator twoPhaseIterator() {
+        return twoPhase;
+    }
+
+    @Override
+    public float getMaxScore(int i) throws IOException {
+        return 0;
+    }
+
+    private class TwoPhase extends TwoPhaseIterator {
+
+        private final float matchCost;
+        // list of verified matches on the current doc
+        DisiWrapper verifiedMatches;
+        // priority queue of approximations on the current doc that have not been verified yet
+        final PriorityQueue<DisiWrapper> unverifiedMatches;
+
+        private TwoPhase(DocIdSetIterator approximation, float matchCost) {
+            super(approximation);
+            this.matchCost = matchCost;
+            unverifiedMatches = new PriorityQueue<DisiWrapper>(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) {
+                @Override
+                protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
+                    return a.matchCost < b.matchCost;
+                }
+            };
+        }
+
+        DisiWrapper getSubMatches() throws IOException {
+            // iteration order does not matter
+            for (DisiWrapper w : unverifiedMatches) {
+                if (w.twoPhaseView.matches()) {
+                    w.next = verifiedMatches;
+                    verifiedMatches = w;
+                }
+            }
+            unverifiedMatches.clear();
+            return verifiedMatches;
+        }
+
+        @Override
+        public boolean matches() throws IOException {
+            verifiedMatches = null;
+            unverifiedMatches.clear();
+
+            for (DisiWrapper w = subScorers.topList(); w != null;) {
+                DisiWrapper next = w.next;
+
+                if (w.twoPhaseView == null) {
+                    // implicitly verified, move it to verifiedMatches
+                    w.next = verifiedMatches;
+                    verifiedMatches = w;
+
+                    if (needsScores == false) {
+                        // we can stop here
+                        return true;
+                    }
+                } else {
+                    unverifiedMatches.add(w);
+                }
+                w = next;
+            }
+
+            if (verifiedMatches != null) {
+                return true;
+            }
+
+            // verify subs that have a two-phase iterator, least-costly ones first
+            while (unverifiedMatches.size() > 0) {
+                DisiWrapper w = unverifiedMatches.pop();
+                if (w.twoPhaseView.matches()) {
+                    w.next = null;
+                    verifiedMatches = w;
+                    return true;
+                }
+            }
+
+            return false;
+        }
+
+        @Override
+        public float matchCost() {
+            return matchCost;
+        }
+    }
+
+    @Override
+    public final int docID() {
+        return subScorers.top().doc;
+    }
+
+    DisiWrapper getSubMatches() throws IOException {
+        if (twoPhase == null) {
+            return subScorers.topList();
+        } else {
+            return twoPhase.getSubMatches();
+        }
+    }
+
+    @Override
+    public final float score() throws IOException {
+        return score(getSubMatches());
+    }
+
+    protected float score(DisiWrapper topList) throws IOException {
+        return 1f;
+    }
+
+    @Override
+    public final Collection<ChildScorable> getChildren() throws IOException {
+        ArrayList<ChildScorable> children = new ArrayList<>();
+        for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
+            children.add(new ChildScorable(scorer.scorer, "SHOULD"));
+        }
+        return children;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java
new file mode 100644
index 0000000000000..f4c3d59a3833f
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java
@@ -0,0 +1,106 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.metrics;
+
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.SortedSetDocValues;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.CollectionTerminatedException;
+import org.apache.lucene.search.ConjunctionUtils;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.util.Bits;
+import org.opensearch.search.aggregations.support.FieldContext;
+import org.opensearch.search.aggregations.support.ValuesSource;
+import org.opensearch.search.internal.SearchContext;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
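+/**
+ * Wraps a {@link CardinalityAggregator.Collector} and, when the segment qualifies (low cardinality in this
+ * draft), replaces per-document collection with a single pruned pass: it intersects the query scorer with a
+ * disjunction over one term scorer per distinct value, drops each term's DISI once that term has been
+ * collected, and then terminates collection for the leaf via {@link CollectionTerminatedException}.
+ * Otherwise it simply delegates to the wrapped collector.
+ */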
+class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector {
+
+    private final LeafReaderContext ctx;
+    private final DisjunctionWithDynamicPruningScorer disjunctionScorer;
+    private final DocIdSetIterator disi;
+    private final CardinalityAggregator.Collector delegateCollector;
+
+    DynamicPruningCollectorWrapper(
+        CardinalityAggregator.Collector delegateCollector,
+        SearchContext context,
+        LeafReaderContext ctx,
+        FieldContext fieldContext,
+        ValuesSource.Bytes.WithOrdinals source
+    ) throws IOException {
+        this.ctx = ctx;
+        this.delegateCollector = delegateCollector;
+        final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
+        boolean isCardinalityLow = ordinalValues.getValueCount() < 10;
+        boolean isCardinalityAggregationOnlyAggregation = true;
+        boolean isFieldSupportedForDynamicPruning = true;
+        if (isCardinalityLow && isCardinalityAggregationOnlyAggregation && isFieldSupportedForDynamicPruning) {
+            // create disjunctions from terms
+            // this logic should be pluggable depending on the type of leaf bucket collector used by CardinalityAggregator
+            TermsEnum terms = ordinalValues.termsEnum();
+            Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f);
+            Map<Long, Boolean> found = new HashMap<>();
+            List<Scorer> subScorers = new ArrayList<>();
+            while (terms.next() != null && !found.containsKey(terms.ord())) {
+                // TODO can we get rid of terms previously encountered in other segments?
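+                // Build one sub-scorer per distinct term of the target field in this segment; terms with no
+                // matching docs in this leaf (null scorer) are skipped.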
+                TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), terms.term()));
+                Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f);
+                Scorer scorer = subWeight.scorer(ctx);
+                if (scorer != null) {
+                    subScorers.add(scorer);
+                }
+                found.put(terms.ord(), true);
+            }
+            disjunctionScorer = new DisjunctionWithDynamicPruningScorer(weight, subScorers);
+            disi = ConjunctionUtils.intersectScorers(List.of(disjunctionScorer, weight.scorer(ctx)));
+        } else {
+            disjunctionScorer = null;
+            disi = null;
+        }
+    }
+
+    @Override
+    public void collect(int doc, long bucketOrd) throws IOException {
+        if (disi == null || disjunctionScorer == null) {
+            delegateCollector.collect(doc, bucketOrd);
+        } else {
+            // perform the full iteration using dynamic pruning of DISIs and return right away
+            disi.advance(doc);
+            int currDoc = disi.docID();
+            assert currDoc == doc;
+            final Bits liveDocs = ctx.reader().getLiveDocs();
+            assert liveDocs == null || liveDocs.get(currDoc);
+            do {
+                if (liveDocs == null || liveDocs.get(currDoc)) {
+                    delegateCollector.collect(currDoc, bucketOrd);
+                    disjunctionScorer.removeAllDISIsOnCurrentDoc();
+                }
+                currDoc = disi.nextDoc();
+            } while (currDoc != DocIdSetIterator.NO_MORE_DOCS);
+            throw new CollectionTerminatedException();
+        }
+    }
+
+    @Override
+    public void close() {
+        delegateCollector.close();
+    }
+
+    @Override
+    public void postCollect() throws IOException {
+        delegateCollector.postCollect();
+    }
+}
diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
index cdd17e2fa7dd6..a9966c9e70e76 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
@@ -33,15 +33,22 @@ package org.opensearch.search.aggregations.metrics;
 
 import org.apache.lucene.document.BinaryDocValuesField;
+import org.apache.lucene.document.Field;
 import org.apache.lucene.document.IntPoint;
+import org.apache.lucene.document.KeywordField;
 import org.apache.lucene.document.NumericDocValuesField;
 import org.apache.lucene.document.SortedNumericDocValuesField;
+import org.apache.lucene.document.SortedSetDocValuesField;
+import org.apache.lucene.index.Term;
 import org.apache.lucene.search.DocValuesFieldExistsQuery;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.util.BytesRef;
 import org.opensearch.common.CheckedConsumer;
 import org.opensearch.common.geo.GeoPoint;
+import org.opensearch.index.mapper.KeywordFieldMapper;
 import org.opensearch.index.mapper.MappedFieldType;
 import org.opensearch.index.mapper.NumberFieldMapper;
 import org.opensearch.index.mapper.RangeFieldMapper;
@@ -56,6 +63,7 @@
 import java.util.Set;
 import java.util.function.Consumer;
 
+import static java.util.Arrays.asList;
 import static java.util.Collections.singleton;
 
 public class CardinalityAggregatorTests extends AggregatorTestCase {
@@ -90,6 +98,56 @@ public void testRangeFieldValues() throws IOException {
         }, fieldType);
     }
 
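+    // Exercises the dynamic pruning path of the ordinals collector: only documents matching filterField:foo
+    // are collected, and they cover the values 1, 2 and 3 of testField, so the expected cardinality is 3.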
+    public void testDynamicPruningOrdinalCollector() throws IOException {
+        final String fieldName = "testField";
+        final String filterFieldName = "filterField";
+
+        MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName);
+        final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName);
+        testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
+            iw.addDocument(asList(
+                new KeywordField(fieldName, "1", Field.Store.NO),
+                new KeywordField(fieldName, "2", Field.Store.NO),
+                new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                new SortedSetDocValuesField(fieldName, new BytesRef("1")),
+                new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+            ));
+            iw.addDocument(asList(
+                new KeywordField(fieldName, "2", Field.Store.NO),
+                new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+            ));
+            iw.addDocument(asList(
+                new KeywordField(fieldName, "1", Field.Store.NO),
+                new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                new SortedSetDocValuesField(fieldName, new BytesRef("1"))
+            ));
+            iw.addDocument(asList(
+                new KeywordField(fieldName, "2", Field.Store.NO),
+                new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+            ));
+            iw.addDocument(asList(
+                new KeywordField(fieldName, "3", Field.Store.NO),
+                new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                new SortedSetDocValuesField(fieldName, new BytesRef("3"))
+            ));
+            iw.addDocument(asList(
+                new KeywordField(fieldName, "4", Field.Store.NO),
+                new KeywordField(filterFieldName, "bar", Field.Store.NO),
+                new SortedSetDocValuesField(fieldName, new BytesRef("4"))
+            ));
+            iw.addDocument(asList(
+                new KeywordField(fieldName, "5", Field.Store.NO),
+                new KeywordField(filterFieldName, "bar", Field.Store.NO),
+                new SortedSetDocValuesField(fieldName, new BytesRef("5"))
+            ));
+        }, card -> {
+            assertEquals(3.0, card.getValue(), 0);
+            assertTrue(AggregationInspectionHelper.hasValue(card));
+        }, fieldType);
+    }
+
     public void testNoMatchingField() throws IOException {
         testAggregation(new MatchAllDocsQuery(), iw -> {
             iw.addDocument(singleton(new SortedNumericDocValuesField("wrong_number", 7)));