From ced2b87d57fa821f0fd8d5148c1e75c7a7fb7acd Mon Sep 17 00:00:00 2001 From: Sandesh Kumar Date: Tue, 12 Mar 2024 16:45:02 -0700 Subject: [PATCH] Quickly compute terms aggregations when the top-level query is functionally match-all for a segment (#11643) --------- Signed-off-by: Sandesh Kumar --- CHANGELOG.md | 1 + .../GlobalOrdinalsStringTermsAggregator.java | 113 ++++++++++- .../aggregations/support/ValuesSource.java | 4 + .../search/internal/ContextIndexSearcher.java | 5 + .../terms/KeywordTermsAggregatorTests.java | 76 ++++--- .../bucket/terms/TermsAggregatorTests.java | 189 ++++++++++++++---- .../search/query/QueryPhaseTests.java | 28 ++- .../aggregations/AggregatorTestCase.java | 97 ++++++++- 8 files changed, 431 insertions(+), 82 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad4f27f080d1b..1454c83ddfe22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -133,6 +133,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Changed - Allow composite aggregation to run under a parent filter aggregation ([#11499](https://github.com/opensearch-project/OpenSearch/pull/11499)) +- Quickly compute terms aggregations when the top-level query is functionally match-all for a segment ([#11643](https://github.com/opensearch-project/OpenSearch/pull/11643)) ### Deprecated diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java index 5ed899408ab40..69fda2f3f6133 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java @@ -35,8 +35,13 @@ import org.apache.lucene.index.DocValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.Weight; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.PriorityQueue; @@ -46,6 +51,7 @@ import org.opensearch.common.util.LongHash; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.mapper.DocCountFieldMapper; import org.opensearch.search.DocValueFormat; import org.opensearch.search.aggregations.AggregationExecutionException; import org.opensearch.search.aggregations.Aggregator; @@ -73,6 +79,7 @@ import static org.opensearch.search.aggregations.InternalOrder.isKeyOrder; import static org.apache.lucene.index.SortedSetDocValues.NO_MORE_ORDS; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** * An aggregator of string values that relies on global ordinals in order to build buckets. @@ -85,6 +92,8 @@ public class GlobalOrdinalsStringTermsAggregator extends AbstractStringTermsAggr private final LongPredicate acceptedGlobalOrdinals; private final long valueCount; + private final String fieldName; + private Weight weight; private final GlobalOrdLookupFunction lookupGlobalOrd; protected final CollectionStrategy collectionStrategy; protected int segmentsWithSingleValuedOrds = 0; @@ -136,16 +145,105 @@ public GlobalOrdinalsStringTermsAggregator( return new DenseGlobalOrds(); }); } + this.fieldName = (valuesSource instanceof ValuesSource.Bytes.WithOrdinals.FieldData) + ? ((ValuesSource.Bytes.WithOrdinals.FieldData) valuesSource).getIndexFieldName() + : null; } String descriptCollectionStrategy() { return collectionStrategy.describe(); } + public void setWeight(Weight weight) { + this.weight = weight; + } + + /** + Read doc frequencies directly from indexed terms in the segment to skip iterating through individual documents + @param ctx The LeafReaderContext to collect terms from + @param globalOrds The SortedSetDocValues for the field's ordinals + @param ordCountConsumer A consumer to accept collected term frequencies + @return A LeafBucketCollector implementation with collection termination, since collection is complete + @throws IOException If an I/O error occurs during reading + */ + LeafBucketCollector termDocFreqCollector( + LeafReaderContext ctx, + SortedSetDocValues globalOrds, + BiConsumer ordCountConsumer + ) throws IOException { + if (weight == null) { + // Weight not assigned - cannot use this optimization + return null; + } else { + if (weight.count(ctx) == 0) { + // No documents matches top level query on this segment, we can skip the segment entirely + return LeafBucketCollector.NO_OP_COLLECTOR; + } else if (weight.count(ctx) != ctx.reader().maxDoc()) { + // weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and + // top-level query matches all docs in the segment + return null; + } + } + + Terms segmentTerms = ctx.reader().terms(this.fieldName); + if (segmentTerms == null) { + // Field is not indexed. + return null; + } + + NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME); + if (docCountValues.nextDoc() != NO_MORE_DOCS) { + // This segment has at least one document with the _doc_count field. + return null; + } + + TermsEnum indexTermsEnum = segmentTerms.iterator(); + BytesRef indexTerm = indexTermsEnum.next(); + TermsEnum globalOrdinalTermsEnum = globalOrds.termsEnum(); + BytesRef ordinalTerm = globalOrdinalTermsEnum.next(); + + // Iterate over the terms in the segment, look for matches in the global ordinal terms, + // and increment bucket count when segment terms match global ordinal terms. + while (indexTerm != null && ordinalTerm != null) { + int compare = indexTerm.compareTo(ordinalTerm); + if (compare == 0) { + if (acceptedGlobalOrdinals.test(globalOrdinalTermsEnum.ord())) { + ordCountConsumer.accept(globalOrdinalTermsEnum.ord(), indexTermsEnum.docFreq()); + } + indexTerm = indexTermsEnum.next(); + ordinalTerm = globalOrdinalTermsEnum.next(); + } else if (compare < 0) { + indexTerm = indexTermsEnum.next(); + } else { + ordinalTerm = globalOrdinalTermsEnum.next(); + } + } + return new LeafBucketCollector() { + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + throw new CollectionTerminatedException(); + } + }; + } + @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx); collectionStrategy.globalOrdsReady(globalOrds); + + if (collectionStrategy instanceof DenseGlobalOrds + && this.resultStrategy instanceof StandardTermsResults + && sub == LeafBucketCollector.NO_OP_COLLECTOR) { + LeafBucketCollector termDocFreqCollector = termDocFreqCollector( + ctx, + globalOrds, + (ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount) + ); + if (termDocFreqCollector != null) { + return termDocFreqCollector; + } + } + SortedDocValues singleValues = DocValues.unwrapSingleton(globalOrds); if (singleValues != null) { segmentsWithSingleValuedOrds++; @@ -343,9 +441,20 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol final SortedSetDocValues segmentOrds = valuesSource.ordinalsValues(ctx); segmentDocCounts = context.bigArrays().grow(segmentDocCounts, 1 + segmentOrds.getValueCount()); assert sub == LeafBucketCollector.NO_OP_COLLECTOR; - final SortedDocValues singleValues = DocValues.unwrapSingleton(segmentOrds); mapping = valuesSource.globalOrdinalsMapping(ctx); - // Dense mode doesn't support include/exclude so we don't have to check it here. + + if (this.resultStrategy instanceof StandardTermsResults) { + LeafBucketCollector termDocFreqCollector = this.termDocFreqCollector( + ctx, + segmentOrds, + (ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount) + ); + if (termDocFreqCollector != null) { + return termDocFreqCollector; + } + } + + final SortedDocValues singleValues = DocValues.unwrapSingleton(segmentOrds); if (singleValues != null) { segmentsWithSingleValuedOrds++; return resultStrategy.wrapCollector(new LeafBucketCollectorBase(sub, segmentOrds) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java index 3ce1f0447dfcc..1f4dd429e094e 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java @@ -244,6 +244,10 @@ public FieldData(IndexOrdinalsFieldData indexFieldData) { this.indexFieldData = indexFieldData; } + public String getIndexFieldName() { + return this.indexFieldData.getFieldName(); + } + @Override public SortedBinaryDocValues bytesValues(LeafReaderContext context) { final LeafOrdinalsFieldData atomicFieldData = indexFieldData.load(context); diff --git a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java index 403b0b545c113..ec3ed2332d0b8 100644 --- a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java @@ -387,6 +387,11 @@ public BulkScorer bulkScorer(LeafReaderContext context) throws IOException { return null; } } + + @Override + public int count(LeafReaderContext context) throws IOException { + return weight.count(context); + } }; } else { return weight; diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/KeywordTermsAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/KeywordTermsAggregatorTests.java index 4229361aa7f46..753644dce81d5 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/KeywordTermsAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/KeywordTermsAggregatorTests.java @@ -32,7 +32,6 @@ package org.opensearch.search.aggregations.bucket.terms; import org.apache.lucene.document.Document; -import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.IndexSearcher; @@ -41,7 +40,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; -import org.apache.lucene.util.BytesRef; +import org.opensearch.common.TriConsumer; import org.opensearch.index.mapper.KeywordFieldMapper; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.search.aggregations.AggregatorTestCase; @@ -57,6 +56,8 @@ public class KeywordTermsAggregatorTests extends AggregatorTestCase { private static final String KEYWORD_FIELD = "keyword"; + private static final Consumer CONFIGURE_KEYWORD_FIELD = agg -> agg.field(KEYWORD_FIELD); + private static final List dataset; static { List d = new ArrayList<>(45); @@ -68,51 +69,63 @@ public class KeywordTermsAggregatorTests extends AggregatorTestCase { dataset = d; } + private static final Consumer VERIFY_MATCH_ALL_DOCS = agg -> { + assertEquals(9, agg.getBuckets().size()); + for (int i = 0; i < 9; i++) { + StringTerms.Bucket bucket = (StringTerms.Bucket) agg.getBuckets().get(i); + assertThat(bucket.getKey(), equalTo(String.valueOf(9L - i))); + assertThat(bucket.getDocCount(), equalTo(9L - i)); + } + }; + + private static final Consumer VERIFY_MATCH_NO_DOCS = agg -> { assertEquals(0, agg.getBuckets().size()); }; + + private static final Query MATCH_ALL_DOCS_QUERY = new MatchAllDocsQuery(); + + private static final Query MATCH_NO_DOCS_QUERY = new MatchNoDocsQuery(); + public void testMatchNoDocs() throws IOException { testSearchCase( - new MatchNoDocsQuery(), + ADD_SORTED_SET_FIELD_NOT_INDEXED, + MATCH_NO_DOCS_QUERY, dataset, - aggregation -> aggregation.field(KEYWORD_FIELD), - agg -> assertEquals(0, agg.getBuckets().size()), - null // without type hint + CONFIGURE_KEYWORD_FIELD, + VERIFY_MATCH_NO_DOCS, + null // without type hint ); testSearchCase( - new MatchNoDocsQuery(), + ADD_SORTED_SET_FIELD_NOT_INDEXED, + MATCH_NO_DOCS_QUERY, dataset, - aggregation -> aggregation.field(KEYWORD_FIELD), - agg -> assertEquals(0, agg.getBuckets().size()), - ValueType.STRING // with type hint + CONFIGURE_KEYWORD_FIELD, + VERIFY_MATCH_NO_DOCS, + ValueType.STRING // with type hint ); } public void testMatchAllDocs() throws IOException { - Query query = new MatchAllDocsQuery(); - - testSearchCase(query, dataset, aggregation -> aggregation.field(KEYWORD_FIELD), agg -> { - assertEquals(9, agg.getBuckets().size()); - for (int i = 0; i < 9; i++) { - StringTerms.Bucket bucket = (StringTerms.Bucket) agg.getBuckets().get(i); - assertThat(bucket.getKey(), equalTo(String.valueOf(9L - i))); - assertThat(bucket.getDocCount(), equalTo(9L - i)); - } - }, - null // without type hint + testSearchCase( + ADD_SORTED_SET_FIELD_NOT_INDEXED, + MATCH_ALL_DOCS_QUERY, + dataset, + CONFIGURE_KEYWORD_FIELD, + VERIFY_MATCH_ALL_DOCS, + null // without type hint ); - testSearchCase(query, dataset, aggregation -> aggregation.field(KEYWORD_FIELD), agg -> { - assertEquals(9, agg.getBuckets().size()); - for (int i = 0; i < 9; i++) { - StringTerms.Bucket bucket = (StringTerms.Bucket) agg.getBuckets().get(i); - assertThat(bucket.getKey(), equalTo(String.valueOf(9L - i))); - assertThat(bucket.getDocCount(), equalTo(9L - i)); - } - }, - ValueType.STRING // with type hint + testSearchCase( + ADD_SORTED_SET_FIELD_NOT_INDEXED, + MATCH_ALL_DOCS_QUERY, + dataset, + CONFIGURE_KEYWORD_FIELD, + VERIFY_MATCH_ALL_DOCS, + ValueType.STRING // with type hint ); } private void testSearchCase( + TriConsumer addField, Query query, List dataset, Consumer configure, @@ -123,7 +136,7 @@ private void testSearchCase( try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { Document document = new Document(); for (String value : dataset) { - document.add(new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef(value))); + addField.apply(document, KEYWORD_FIELD, value); indexWriter.addDocument(document); document.clear(); } @@ -147,5 +160,4 @@ private void testSearchCase( } } } - } diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregatorTests.java index 80744ecde4d69..cfb04d2aa1d19 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregatorTests.java @@ -44,6 +44,7 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.Term; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; @@ -52,6 +53,7 @@ import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.NumericUtils; +import org.opensearch.common.TriConsumer; import org.opensearch.common.geo.GeoPoint; import org.opensearch.common.network.InetAddresses; import org.opensearch.common.settings.Settings; @@ -120,6 +122,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; @@ -136,9 +139,6 @@ import static org.mockito.Mockito.when; public class TermsAggregatorTests extends AggregatorTestCase { - - private boolean randomizeAggregatorImpl = true; - // Constants for a script that returns a string private static final String STRING_SCRIPT_NAME = "string_script"; private static final String STRING_SCRIPT_OUTPUT = "Orange"; @@ -171,9 +171,22 @@ protected ScriptService getMockScriptService() { return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS); } + protected CountingAggregator createCountingAggregator( + AggregationBuilder aggregationBuilder, + IndexSearcher indexSearcher, + boolean randomizeAggregatorImpl, + MappedFieldType... fieldTypes + ) throws IOException { + return new CountingAggregator( + new AtomicInteger(), + createAggregator(aggregationBuilder, indexSearcher, randomizeAggregatorImpl, fieldTypes) + ); + } + protected A createAggregator( AggregationBuilder aggregationBuilder, IndexSearcher indexSearcher, + boolean randomizeAggregatorImpl, MappedFieldType... fieldTypes ) throws IOException { try { @@ -188,6 +201,14 @@ protected A createAggregator( } } + protected A createAggregator( + AggregationBuilder aggregationBuilder, + IndexSearcher indexSearcher, + MappedFieldType... fieldTypes + ) throws IOException { + return createAggregator(aggregationBuilder, indexSearcher, true, fieldTypes); + } + @Override protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) { return new TermsAggregationBuilder("foo").field(fieldName); @@ -207,8 +228,7 @@ protected List getSupportedValuesSourceTypes() { } public void testUsesGlobalOrdinalsByDefault() throws Exception { - randomizeAggregatorImpl = false; - + boolean randomizeAggregatorImpl = false; Directory directory = newDirectory(); RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); indexWriter.close(); @@ -220,35 +240,35 @@ public void testUsesGlobalOrdinalsByDefault() throws Exception { .field("string"); MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string"); - TermsAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); + TermsAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, randomizeAggregatorImpl, fieldType); assertThat(aggregator, instanceOf(GlobalOrdinalsStringTermsAggregator.class)); GlobalOrdinalsStringTermsAggregator globalAgg = (GlobalOrdinalsStringTermsAggregator) aggregator; assertThat(globalAgg.descriptCollectionStrategy(), equalTo("dense")); // Infers depth_first because the maxOrd is 0 which is less than the size aggregationBuilder.subAggregation(AggregationBuilders.cardinality("card").field("string")); - aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); + aggregator = createAggregator(aggregationBuilder, indexSearcher, randomizeAggregatorImpl, fieldType); assertThat(aggregator, instanceOf(GlobalOrdinalsStringTermsAggregator.class)); globalAgg = (GlobalOrdinalsStringTermsAggregator) aggregator; assertThat(globalAgg.collectMode, equalTo(Aggregator.SubAggCollectionMode.DEPTH_FIRST)); assertThat(globalAgg.descriptCollectionStrategy(), equalTo("remap")); aggregationBuilder.collectMode(Aggregator.SubAggCollectionMode.DEPTH_FIRST); - aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); + aggregator = createAggregator(aggregationBuilder, indexSearcher, randomizeAggregatorImpl, fieldType); assertThat(aggregator, instanceOf(GlobalOrdinalsStringTermsAggregator.class)); globalAgg = (GlobalOrdinalsStringTermsAggregator) aggregator; assertThat(globalAgg.collectMode, equalTo(Aggregator.SubAggCollectionMode.DEPTH_FIRST)); assertThat(globalAgg.descriptCollectionStrategy(), equalTo("remap")); aggregationBuilder.collectMode(Aggregator.SubAggCollectionMode.BREADTH_FIRST); - aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); + aggregator = createAggregator(aggregationBuilder, indexSearcher, randomizeAggregatorImpl, fieldType); assertThat(aggregator, instanceOf(GlobalOrdinalsStringTermsAggregator.class)); globalAgg = (GlobalOrdinalsStringTermsAggregator) aggregator; assertThat(globalAgg.collectMode, equalTo(Aggregator.SubAggCollectionMode.BREADTH_FIRST)); assertThat(globalAgg.descriptCollectionStrategy(), equalTo("dense")); aggregationBuilder.order(BucketOrder.aggregation("card", true)); - aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); + aggregator = createAggregator(aggregationBuilder, indexSearcher, randomizeAggregatorImpl, fieldType); assertThat(aggregator, instanceOf(GlobalOrdinalsStringTermsAggregator.class)); globalAgg = (GlobalOrdinalsStringTermsAggregator) aggregator; assertThat(globalAgg.descriptCollectionStrategy(), equalTo("remap")); @@ -257,51 +277,139 @@ public void testUsesGlobalOrdinalsByDefault() throws Exception { directory.close(); } - public void testSimple() throws Exception { + /** + * This test case utilizes the default implementation of GlobalOrdinalsStringTermsAggregator since collectSegmentOrds is false + */ + public void testSimpleAggregation() throws Exception { + // Fields not indexed: cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited + testSimple(ADD_SORTED_SET_FIELD_NOT_INDEXED, false, false, false, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4); + + // Fields indexed, deleted documents in segment: cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited + testSimple(ADD_SORTED_SET_FIELD_INDEXED, true, false, false, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4); + + // Fields indexed, no deleted documents in segment: will use LeafBucketCollector#termDocFreqCollector - no documents are visited + testSimple(ADD_SORTED_SET_FIELD_INDEXED, false, false, false, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 0); + + // Fields indexed, no deleted documents, but _doc_field value present in document: + // cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited + testSimple(ADD_SORTED_SET_FIELD_INDEXED, false, true, false, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4); + + } + + /** + * This test case utilizes the LowCardinality implementation of GlobalOrdinalsStringTermsAggregator since collectSegmentOrds is true + */ + public void testSimpleAggregationLowCardinality() throws Exception { + // Fields not indexed: cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited + testSimple(ADD_SORTED_SET_FIELD_NOT_INDEXED, false, false, true, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4); + + // Fields indexed, deleted documents in segment: cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited + testSimple(ADD_SORTED_SET_FIELD_INDEXED, true, false, true, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4); + + // Fields indexed, no deleted documents in segment: will use LeafBucketCollector#termDocFreqCollector - no documents are visited + testSimple(ADD_SORTED_SET_FIELD_INDEXED, false, false, true, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 0); + + // Fields indexed, no deleted documents, but _doc_field value present in document: + // cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited + testSimple(ADD_SORTED_SET_FIELD_INDEXED, false, true, true, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4); + } + + /** + * This test case utilizes the MapStringTermsAggregator. + */ + public void testSimpleMapStringAggregation() throws Exception { + testSimple( + ADD_SORTED_SET_FIELD_INDEXED, + randomBoolean(), + randomBoolean(), + randomBoolean(), + TermsAggregatorFactory.ExecutionMode.MAP, + 4 + ); + } + + /** + * This is a utility method to test out string terms aggregation + * @param addFieldConsumer a function that determines how a field is added to the document + * @param includeDeletedDocumentsInSegment to include deleted documents in the segment or not + * @param collectSegmentOrds collect segment ords or not - set true to utilize LowCardinality implementation for GlobalOrdinalsStringTermsAggregator + * @param executionMode execution mode MAP or GLOBAL_ORDINALS + * @param expectedCollectCount expected number of documents visited as part of collect() invocation + */ + private void testSimple( + TriConsumer addFieldConsumer, + final boolean includeDeletedDocumentsInSegment, + final boolean includeDocCountField, + boolean collectSegmentOrds, + TermsAggregatorFactory.ExecutionMode executionMode, + final int expectedCollectCount + ) throws Exception { try (Directory directory = newDirectory()) { try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { Document document = new Document(); - document.add(new SortedSetDocValuesField("string", new BytesRef("a"))); - document.add(new SortedSetDocValuesField("string", new BytesRef("b"))); + addFieldConsumer.apply(document, "string", "a"); + addFieldConsumer.apply(document, "string", "b"); indexWriter.addDocument(document); document = new Document(); - document.add(new SortedSetDocValuesField("string", new BytesRef(""))); - document.add(new SortedSetDocValuesField("string", new BytesRef("c"))); - document.add(new SortedSetDocValuesField("string", new BytesRef("a"))); + addFieldConsumer.apply(document, "string", ""); + addFieldConsumer.apply(document, "string", "c"); + addFieldConsumer.apply(document, "string", "a"); indexWriter.addDocument(document); document = new Document(); - document.add(new SortedSetDocValuesField("string", new BytesRef("b"))); - document.add(new SortedSetDocValuesField("string", new BytesRef("d"))); + addFieldConsumer.apply(document, "string", "b"); + addFieldConsumer.apply(document, "string", "d"); indexWriter.addDocument(document); document = new Document(); - document.add(new SortedSetDocValuesField("string", new BytesRef(""))); + addFieldConsumer.apply(document, "string", ""); + if (includeDocCountField) { + // Adding _doc_count to one document + document.add(new NumericDocValuesField("_doc_count", 10)); + } indexWriter.addDocument(document); + + if (includeDeletedDocumentsInSegment) { + document = new Document(); + ADD_SORTED_SET_FIELD_INDEXED.apply(document, "string", "e"); + indexWriter.addDocument(document); + indexWriter.deleteDocuments(new Term("string", "e")); + assertEquals(5, indexWriter.getDocStats().maxDoc); // deleted document still in segment + } + try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) { IndexSearcher indexSearcher = newIndexSearcher(indexReader); - for (TermsAggregatorFactory.ExecutionMode executionMode : TermsAggregatorFactory.ExecutionMode.values()) { - TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("_name").userValueTypeHint( - ValueType.STRING - ).executionHint(executionMode.toString()).field("string").order(BucketOrder.key(true)); - MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string"); - TermsAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); - aggregator.preCollection(); - indexSearcher.search(new MatchAllDocsQuery(), aggregator); - aggregator.postCollection(); - Terms result = reduce(aggregator); - assertEquals(5, result.getBuckets().size()); - assertEquals("", result.getBuckets().get(0).getKeyAsString()); + TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("_name").userValueTypeHint(ValueType.STRING) + .executionHint(executionMode.toString()) + .field("string") + .order(BucketOrder.key(true)); + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string"); + + TermsAggregatorFactory.COLLECT_SEGMENT_ORDS = collectSegmentOrds; + TermsAggregatorFactory.REMAP_GLOBAL_ORDS = false; + CountingAggregator aggregator = createCountingAggregator(aggregationBuilder, indexSearcher, false, fieldType); + + aggregator.preCollection(); + indexSearcher.search(new MatchAllDocsQuery(), aggregator); + aggregator.postCollection(); + Terms result = reduce(aggregator); + assertEquals(5, result.getBuckets().size()); + assertEquals("", result.getBuckets().get(0).getKeyAsString()); + if (includeDocCountField) { + assertEquals(11L, result.getBuckets().get(0).getDocCount()); + } else { assertEquals(2L, result.getBuckets().get(0).getDocCount()); - assertEquals("a", result.getBuckets().get(1).getKeyAsString()); - assertEquals(2L, result.getBuckets().get(1).getDocCount()); - assertEquals("b", result.getBuckets().get(2).getKeyAsString()); - assertEquals(2L, result.getBuckets().get(2).getDocCount()); - assertEquals("c", result.getBuckets().get(3).getKeyAsString()); - assertEquals(1L, result.getBuckets().get(3).getDocCount()); - assertEquals("d", result.getBuckets().get(4).getKeyAsString()); - assertEquals(1L, result.getBuckets().get(4).getDocCount()); - assertTrue(AggregationInspectionHelper.hasValue((InternalTerms) result)); } + assertEquals("a", result.getBuckets().get(1).getKeyAsString()); + assertEquals(2L, result.getBuckets().get(1).getDocCount()); + assertEquals("b", result.getBuckets().get(2).getKeyAsString()); + assertEquals(2L, result.getBuckets().get(2).getDocCount()); + assertEquals("c", result.getBuckets().get(3).getKeyAsString()); + assertEquals(1L, result.getBuckets().get(3).getDocCount()); + assertEquals("d", result.getBuckets().get(4).getKeyAsString()); + assertEquals(1L, result.getBuckets().get(4).getDocCount()); + assertTrue(AggregationInspectionHelper.hasValue((InternalTerms) result)); + + assertEquals(expectedCollectCount, aggregator.getCollectCount().get()); } } } @@ -1543,5 +1651,4 @@ private T reduce(Aggregator agg) throws IOExcept doAssertReducedMultiBucketConsumer(result, reduceBucketConsumer); return result; } - } diff --git a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java index d0e01c5461c79..4bd4d406e4391 100644 --- a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java @@ -122,6 +122,7 @@ import java.util.concurrent.TimeUnit; import static org.opensearch.search.query.TopDocsCollectorContext.hasInfMaxScore; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -437,10 +438,16 @@ public void testTerminateAfterEarlyTermination() throws Exception { assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); + // Do not expect an exact match when terminate_after is used in conjunction to size = 0 as an optimization introduced by + // https://issues.apache.org/jira/browse/LUCENE-10620 can produce a total hit count >= terminated_after, because + // TotalHitCountCollector is used in this case as part of Weight#count() optimization context.setSize(0); QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); assertTrue(context.queryResult().terminatedEarly()); - assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat( + context.queryResult().topDocs().topDocs.totalHits.value, + allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo((long) numDocs)) + ); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); } @@ -466,7 +473,10 @@ public void testTerminateAfterEarlyTermination() throws Exception { context.parsedQuery(new ParsedQuery(bq)); QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); assertTrue(context.queryResult().terminatedEarly()); - assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat( + context.queryResult().topDocs().topDocs.totalHits.value, + allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo((long) numDocs)) + ); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); } { @@ -486,9 +496,12 @@ public void testTerminateAfterEarlyTermination() throws Exception { context.queryCollectorManagers().put(TotalHitCountCollector.class, manager); QueryPhase.executeInternal(context.withCleanQueryResult(), queryPhaseSearcher); assertTrue(context.queryResult().terminatedEarly()); - assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L)); + assertThat( + context.queryResult().topDocs().topDocs.totalHits.value, + allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo((long) numDocs)) + ); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); - assertThat(manager.getTotalHits(), equalTo(1)); + assertThat(manager.getTotalHits(), allOf(greaterThanOrEqualTo(1), lessThanOrEqualTo(numDocs))); } // tests with trackTotalHits and terminateAfter @@ -503,7 +516,10 @@ public void testTerminateAfterEarlyTermination() throws Exception { if (trackTotalHits == -1) { assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); } else { - assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) Math.min(trackTotalHits, 10))); + assertThat( + context.queryResult().topDocs().topDocs.totalHits.value, + allOf(greaterThanOrEqualTo(Math.min(trackTotalHits, 10L)), lessThanOrEqualTo((long) numDocs)) + ); } assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); // The concurrent search terminates the collection when the number of hits is reached by each @@ -511,7 +527,7 @@ public void testTerminateAfterEarlyTermination() throws Exception { // slices (as the unit of concurrency). To address that, we have to use the shared global state, // much as HitsThresholdChecker does. if (executor == null) { - assertThat(manager.getTotalHits(), equalTo(10)); + assertThat(manager.getTotalHits(), allOf(greaterThanOrEqualTo(Math.min(trackTotalHits, 10)), lessThanOrEqualTo(numDocs))); } } diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index ac0447dbebf7e..4eb49ebb42241 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -34,11 +34,13 @@ import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.document.LatLonDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.StoredField; +import org.apache.lucene.document.StringField; import org.apache.lucene.index.CompositeReaderContext; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; @@ -62,6 +64,7 @@ import org.opensearch.Version; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.TriConsumer; import org.opensearch.common.TriFunction; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; @@ -121,6 +124,7 @@ import org.opensearch.search.aggregations.AggregatorFactories.Builder; import org.opensearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer; import org.opensearch.search.aggregations.bucket.nested.NestedAggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregator; import org.opensearch.search.aggregations.metrics.MetricsAggregator; import org.opensearch.search.aggregations.pipeline.PipelineAggregator; import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree; @@ -147,6 +151,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -178,6 +183,14 @@ public abstract class AggregatorTestCase extends OpenSearchTestCase { // A list of field types that should not be tested, or are not currently supported private static List TYPE_TEST_DENYLIST; + protected static final TriConsumer ADD_SORTED_SET_FIELD_NOT_INDEXED = (document, field, value) -> document + .add(new SortedSetDocValuesField(field, new BytesRef(value))); + + protected static final TriConsumer ADD_SORTED_SET_FIELD_INDEXED = (document, field, value) -> { + document.add(new SortedSetDocValuesField(field, new BytesRef(value))); + document.add(new StringField(field, value, Field.Store.NO)); + }; + static { List denylist = new ArrayList<>(); denylist.add(ObjectMapper.CONTENT_TYPE); // Cannot aggregate objects @@ -433,7 +446,6 @@ protected QueryShardContext queryShardContextMock( CircuitBreakerService circuitBreakerService, BigArrays bigArrays ) { - return new QueryShardContext( 0, indexSettings, @@ -1096,6 +1108,89 @@ protected void doWriteTo(StreamOutput out) throws IOException { } } + /** + * Wrapper around Aggregator class + * Maintains a count for times collect() is invoked - number of documents visited + */ + protected static class CountingAggregator extends Aggregator { + private final AtomicInteger collectCounter; + public final Aggregator delegate; + + public CountingAggregator(AtomicInteger collectCounter, TermsAggregator delegate) { + this.collectCounter = collectCounter; + this.delegate = delegate; + } + + public AtomicInteger getCollectCount() { + return collectCounter; + } + + @Override + public void close() { + delegate.close(); + } + + @Override + public String name() { + return delegate.name(); + } + + @Override + public SearchContext context() { + return delegate.context(); + } + + @Override + public Aggregator parent() { + return delegate.parent(); + } + + @Override + public Aggregator subAggregator(String name) { + return delegate.subAggregator(name); + } + + @Override + public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { + return delegate.buildAggregations(owningBucketOrds); + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return delegate.buildEmptyAggregation(); + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException { + return new LeafBucketCollector() { + @Override + public void collect(int doc, long bucket) throws IOException { + delegate.getLeafCollector(ctx).collect(doc, bucket); + collectCounter.incrementAndGet(); + } + }; + } + + @Override + public ScoreMode scoreMode() { + return delegate.scoreMode(); + } + + @Override + public void preCollection() throws IOException { + delegate.preCollection(); + } + + @Override + public void postCollection() throws IOException { + delegate.postCollection(); + } + + public void setWeight(Weight weight) { + this.delegate.setWeight(weight); + } + } + public static class InternalAggCardinality extends InternalAggregation { private final CardinalityUpperBound cardinality;