Optimize 2 keyword multi-terms aggregation #13929

Draft: wants to merge 1 commit into base: main
@@ -146,6 +146,7 @@ protected Aggregator createInternal(
configs.stream()
.map(config -> queryShardContext.getValuesSourceRegistry().getAggregator(REGISTRY_KEY, config.v1()).build(config))
.collect(Collectors.toList()),
configs.stream().map(config -> config.v1().getValuesSource()).collect(Collectors.toList()),
configs.stream().map(c -> c.v1().format()).collect(Collectors.toList()),
order,
collectMode,
@@ -8,8 +8,15 @@

package org.opensearch.search.aggregations.bucket.terms;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NumericUtils;
import org.apache.lucene.util.PriorityQueue;
@@ -62,19 +69,23 @@ public class MultiTermsAggregator extends DeferableBucketAggregator {

private final BytesKeyedBucketOrds bucketOrds;
private final MultiTermsValuesSource multiTermsValue;
private final List<ValuesSource> valuesSources;
private final boolean showTermDocCountError;
private final List<DocValueFormat> formats;
private final TermsAggregator.BucketCountThresholds bucketCountThresholds;
private final BucketOrder order;
private final Comparator<InternalMultiTerms.Bucket> partiallyBuiltBucketComparator;
private final SubAggCollectionMode collectMode;
private final Set<Aggregator> aggsUsedForSorting = new HashSet<>();
private Weight weight;
private static final Logger logger = LogManager.getLogger(MultiTermsAggregator.class);

public MultiTermsAggregator(
String name,
AggregatorFactories factories,
boolean showTermDocCountError,
List<InternalValuesSource> internalValuesSources,
List<ValuesSource> valuesSources,
List<DocValueFormat> formats,
BucketOrder order,
SubAggCollectionMode collectMode,
@@ -87,6 +98,7 @@ public MultiTermsAggregator(
super(name, factories, context, parent, metadata);
this.bucketOrds = BytesKeyedBucketOrds.build(context.bigArrays(), cardinality);
this.multiTermsValue = new MultiTermsValuesSource(internalValuesSources);
this.valuesSources = valuesSources;
this.showTermDocCountError = showTermDocCountError;
this.formats = formats;
this.bucketCountThresholds = bucketCountThresholds;
@@ -173,6 +185,10 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
return result;
}

public void setWeight(Weight weight) {
this.weight = weight;
}

InternalMultiTerms buildResult(long owningBucketOrd, long otherDocCount, InternalMultiTerms.Bucket[] topBuckets) {
BucketOrder reduceOrder;
if (isKeyOrder(order) == false) {
@@ -213,8 +229,114 @@ public InternalAggregation buildEmptyAggregation() {
);
}

private LeafBucketCollector getTermFrequencies(LeafReaderContext ctx) throws IOException {
// Instead of visiting doc values for each document, utilize posting data directly to get each composite bucket intersection
// For example, if we have a composite key of (a, b) where a is from field1 & b is from field2
// We can find all the composite buckets by visiting both posting lists
// and counting all the documents that intersect for each composite bucket.
// This is much faster than visiting the doc values for each document.
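// For example, with three docs valued (a, n), (a, n), (a, m) across (field1, field2), walking the postings of "a"
// against those of "n" and "m" yields bucket (a, n) with count 2 and bucket (a, m) with count 1, without reading doc values.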

if (weight == null || weight.count(ctx) != ctx.reader().maxDoc()) {
// Weight not assigned - cannot use this optimization
// weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and
// top-level query matches all docs in the segment
return null;
}

String field1, field2;
// Restricting the number of fields to 2 and only keyword fields with FieldData available
if (this.valuesSources.size() == 2
&& this.valuesSources.get(0) instanceof ValuesSource.Bytes.WithOrdinals.FieldData
&& this.valuesSources.get(1) instanceof ValuesSource.Bytes.WithOrdinals.FieldData) {
field1 = ((ValuesSource.Bytes.WithOrdinals.FieldData) valuesSources.get(0)).getIndexFieldName();
field2 = ((ValuesSource.Bytes.WithOrdinals.FieldData) valuesSources.get(1)).getIndexFieldName();

} else {
return null;
}

Terms segmentTerms1 = ctx.reader().terms(field1);
Terms segmentTerms2 = ctx.reader().terms(field2);

// TODO in this PR itself, in coming commits:
// 1. Add a check on field cardinality - this optimization may be ineffective for very high cardinality fields.
// 2. Check whether a filter is applied, since the default implementation may resolve it as part of the aggregation.

TermsEnum segmentTermsEnum1 = segmentTerms1.iterator();

while (segmentTermsEnum1.next() != null) {
TermsEnum segmentTermsEnum2 = segmentTerms2.iterator();

while (segmentTermsEnum2.next() != null) {

PostingsEnum postings1 = segmentTermsEnum1.postings(null);
postings1.nextDoc();

PostingsEnum postings2 = segmentTermsEnum2.postings(null);
postings2.nextDoc();

int bucketCount = 0;

while (postings1.docID() != PostingsEnum.NO_MORE_DOCS && postings2.docID() != PostingsEnum.NO_MORE_DOCS) {

// Count of intersecting docs to get number of docs in each bucket
if (postings1.docID() == postings2.docID()) {
bucketCount++;
postings1.nextDoc();
postings2.nextDoc();
} else if (postings1.docID() < postings2.docID()) {
postings1.advance(postings2.docID());
} else {
postings2.advance(postings1.docID());
}
}
Comment on lines +280 to +292

@rishabhmaurya (Contributor), Jul 8, 2024

A Member replied:
Agreed. The complexity of intersection logic is highly dependent on the documents in the posting lists. With larger datasets and higher cardinality, the leapfrogging method for intersection evaluation would require more frequent iterations over these lists, which can be expensive.
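
As a rough illustration of the cardinality check mentioned in the TODO above, a guard along these lines could bound the nested terms loops before choosing this path. This is a minimal sketch only, not part of this commit: the helper name withinCardinalityLimit and the MAX_TERM_COUNT threshold are hypothetical, and the Terms handles are the ones already obtained from the segment reader.

// A possible shape for the cardinality guard: skip the postings-based path when either field has too many unique
// terms, since the nested TermsEnum loops walk roughly |terms(field1)| * |terms(field2)| postings lists.
private static final long MAX_TERM_COUNT = 1024; // hypothetical cutoff, would need benchmarking

private static boolean withinCardinalityLimit(Terms terms1, Terms terms2) throws IOException {
long size1 = terms1.size(); // Lucene returns -1 when the codec does not store the exact term count
long size2 = terms2.size();
return size1 >= 0 && size2 >= 0 && size1 <= MAX_TERM_COUNT && size2 <= MAX_TERM_COUNT;
}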


// For a key formed by a value of t1 and a value of t2, create a composite key, convert it to a BytesRef and then update
// the ordinal data with the count computed above.
// The ordinal data is used to collect the sub-aggregations for each composite key.
// The composite key is used to collect the buckets for each composite key.
BytesRef v1 = segmentTermsEnum1.term();
BytesRef v2 = segmentTermsEnum2.term();

TermValue<BytesRef> termValue1 = new TermValue<>(v1, TermValue.BYTES_REF_WRITER);
TermValue<BytesRef> termValue2 = new TermValue<>(v2, TermValue.BYTES_REF_WRITER);

final BytesStreamOutput scratch = new BytesStreamOutput();
scratch.writeVInt(2); // number of fields per composite key
termValue1.writeTo(scratch);
termValue2.writeTo(scratch);
BytesRef compositeKeyBytesRef = scratch.bytes().toBytesRef(); // composite key formed
scratch.close();

long bucketOrd = bucketOrds.add(0, compositeKeyBytesRef);
if (bucketOrd < 0) {
bucketOrd = -1 - bucketOrd;
}
incrementBucketDocCount(bucketOrd, bucketCount);
}
}

return new LeafBucketCollector() {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
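// Bucket counts for this segment were already derived from the postings above,
// so signal that per-document collection can be terminated for this leaf.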
throw new CollectionTerminatedException();
}
};

}

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {

LeafBucketCollector optimizedCollector = this.getTermFrequencies(ctx);

if (optimizedCollector != null) {
logger.info("optimization used");
return optimizedCollector;
}

logger.info("optimization not not used");

MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
return new LeafBucketCollector() {
@Override
@@ -256,7 +378,7 @@ private boolean subAggsNeedScore() {

@Override
protected boolean shouldDefer(Aggregator aggregator) {
return collectMode == Aggregator.SubAggCollectionMode.BREADTH_FIRST && !aggsUsedForSorting.contains(aggregator);
return collectMode == SubAggCollectionMode.BREADTH_FIRST && !aggsUsedForSorting.contains(aggregator);
}

private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOException {
@@ -60,6 +60,7 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceType;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.lookup.LeafDocLookup;
@@ -102,6 +103,7 @@ public class MultiTermsAggregatorTests extends AggregatorTestCase {
private static final String FLOAT_FIELD = "float";
private static final String DOUBLE_FIELD = "double";
private static final String KEYWORD_FIELD = "keyword";
private static final String KEYWORD_FIELD2 = "keyword2";
private static final String DATE_FIELD = "date";
private static final String IP_FIELD = "ip";
private static final String GEO_POINT_FIELD = "geopoint";
@@ -116,6 +118,7 @@ public class MultiTermsAggregatorTests extends AggregatorTestCase {
put(DOUBLE_FIELD, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD, NumberFieldMapper.NumberType.DOUBLE));
put(DATE_FIELD, dateFieldType(DATE_FIELD));
put(KEYWORD_FIELD, new KeywordFieldMapper.KeywordFieldType(KEYWORD_FIELD));
put(KEYWORD_FIELD2, new KeywordFieldMapper.KeywordFieldType(KEYWORD_FIELD2));
put(IP_FIELD, new IpFieldMapper.IpFieldType(IP_FIELD));
put(FIELD_NAME, new NumberFieldMapper.NumberFieldType(FIELD_NAME, NumberFieldMapper.NumberType.INTEGER));
put(UNRELATED_KEYWORD_FIELD, new KeywordFieldMapper.KeywordFieldType(UNRELATED_KEYWORD_FIELD));
@@ -306,6 +309,41 @@ public void testMixNumberAndKeyword() throws IOException {
});
}

public void testKeywordAndKeywordField() throws IOException {
testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, KEYWORD_FIELD2)), NONE_DECORATOR, iw -> {
iw.addDocument(
asList(
new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")),
new StringField(KEYWORD_FIELD, new BytesRef("a"), Field.Store.NO),
new SortedSetDocValuesField(KEYWORD_FIELD2, new BytesRef("n")),
new StringField(KEYWORD_FIELD2, new BytesRef("n"), Field.Store.NO)
)
);
iw.addDocument(
asList(
new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")),
new StringField(KEYWORD_FIELD, new BytesRef("a"), Field.Store.NO),
new SortedSetDocValuesField(KEYWORD_FIELD2, new BytesRef("n")),
new StringField(KEYWORD_FIELD2, new BytesRef("n"), Field.Store.NO)
)
);
iw.addDocument(
asList(
new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")),
new StringField(KEYWORD_FIELD, new BytesRef("a"), Field.Store.NO),
new SortedSetDocValuesField(KEYWORD_FIELD2, new BytesRef("m")),
new StringField(KEYWORD_FIELD2, new BytesRef("m"), Field.Store.NO)
)
);
}, h -> {
MatcherAssert.assertThat(h.getBuckets(), hasSize(2));
MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo("n")));
MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L));
MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo("m")));
MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L));
});
}

public void testMultiValuesField() throws IOException {
testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD)), NONE_DECORATOR, iw -> {
iw.addDocument(
@@ -885,6 +923,7 @@ public void testEmptyAggregations() throws IOException {
AggregatorFactories factories = AggregatorFactories.EMPTY;
boolean showTermDocCountError = true;
MultiTermsAggregator.InternalValuesSource internalValuesSources = mock(MultiTermsAggregator.InternalValuesSource.class);
ValuesSource valuesSource = mock(ValuesSource.class);
DocValueFormat format = mock(DocValueFormat.class);
BucketOrder order = mock(BucketOrder.class);
Aggregator.SubAggCollectionMode collectMode = Aggregator.SubAggCollectionMode.BREADTH_FIRST;
Expand All @@ -901,6 +940,7 @@ public void testEmptyAggregations() throws IOException {
factories,
showTermDocCountError,
List.of(internalValuesSources),
List.of(valuesSource),
List.of(format),
order,
collectMode,