From c2e88ecd01804f261a06a9335c4280af198be1ee Mon Sep 17 00:00:00 2001
From: Sandesh Kumar
Date: Fri, 8 Mar 2024 15:48:01 -0800
Subject: [PATCH] Test cases improvement

Signed-off-by: Sandesh Kumar
---
 .../GlobalOrdinalsStringTermsAggregator.java  |  10 +-
 .../terms/KeywordTermsAggregatorTests.java    |  75 ++----
 .../bucket/terms/TermsAggregatorTests.java    | 213 ++++++++++--------
 .../aggregations/AggregatorTestCase.java      | 131 +++++++----
 4 files changed, 243 insertions(+), 186 deletions(-)

diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
index 2a83710117e74..15e538f01e632 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
@@ -40,6 +40,7 @@
 import org.apache.lucene.index.SortedSetDocValues;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
@@ -162,7 +163,7 @@ public void setWeight(Weight weight) {
      * @param ctx The LeafReaderContext to collect terms from
      * @param globalOrds The SortedSetDocValues for the field's ordinals
      * @param ordCountConsumer A consumer to accept collected term frequencies
-     * @return A no-operation LeafBucketCollector implementation, since collection is complete
+     * @return A LeafBucketCollector that throws CollectionTerminatedException on collect, since collection is already complete
      * @throws IOException If an I/O error occurs during reading
      */
     LeafBucketCollector termDocFreqCollector(
@@ -217,7 +218,12 @@ LeafBucketCollector termDocFreqCollector(
                 ordinalTerm = globalOrdinalTermsEnum.next();
             }
         }
-        return LeafBucketCollector.NO_OP_COLLECTOR;
+        return new LeafBucketCollector() {
+            @Override
+            public void collect(int doc, long owningBucketOrd) throws IOException {
+                throw new CollectionTerminatedException();
+            }
+        };
     }
 
     @Override
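Note: the change above is the heart of this patch. Returning LeafBucketCollector.NO_OP_COLLECTOR after termDocFreqCollector had already gathered the bucket counts silently swallowed every later collect() call; throwing CollectionTerminatedException instead uses Lucene's sanctioned signal for "this segment is done", which IndexSearcher catches per leaf and treats as control flow rather than an error. A minimal, self-contained sketch of that mechanism (plain Lucene; the class name, field name, and in-memory directory are illustrative assumptions, not code from this patch):

    import java.io.IOException;

    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.Field;
    import org.apache.lucene.document.StringField;
    import org.apache.lucene.index.DirectoryReader;
    import org.apache.lucene.index.IndexWriter;
    import org.apache.lucene.index.IndexWriterConfig;
    import org.apache.lucene.search.CollectionTerminatedException;
    import org.apache.lucene.search.IndexSearcher;
    import org.apache.lucene.search.MatchAllDocsQuery;
    import org.apache.lucene.search.ScoreMode;
    import org.apache.lucene.search.SimpleCollector;
    import org.apache.lucene.store.ByteBuffersDirectory;
    import org.apache.lucene.store.Directory;

    public class EarlyTerminationSketch {
        public static void main(String[] args) throws IOException {
            try (Directory dir = new ByteBuffersDirectory(); IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
                for (int i = 0; i < 10; i++) {
                    Document doc = new Document();
                    doc.add(new StringField("id", "doc" + i, Field.Store.NO));
                    writer.addDocument(doc);
                }
                writer.commit();
                try (DirectoryReader reader = DirectoryReader.open(dir)) {
                    final int[] visited = new int[1];
                    new IndexSearcher(reader).search(new MatchAllDocsQuery(), new SimpleCollector() {
                        @Override
                        public void collect(int doc) {
                            visited[0]++;
                            // IndexSearcher catches this per leaf and simply moves on to the
                            // next segment; the search still completes normally.
                            throw new CollectionTerminatedException();
                        }

                        @Override
                        public ScoreMode scoreMode() {
                            return ScoreMode.COMPLETE_NO_SCORES;
                        }
                    });
                    System.out.println("docs visited: " + visited[0]); // at most one per segment
                }
            }
        }
    }

Because the search completes normally, aggregator.postCollection() can always be invoked afterwards — which is exactly what lets the DEFAULT/NOOP post-collection test hooks be deleted in the test files below.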
diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/KeywordTermsAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/KeywordTermsAggregatorTests.java
index 5d1e02116f189..753644dce81d5 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/KeywordTermsAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/KeywordTermsAggregatorTests.java
@@ -32,7 +32,6 @@
 package org.opensearch.search.aggregations.bucket.terms;
 
 import org.apache.lucene.document.Document;
-import org.apache.lucene.document.SortedSetDocValuesField;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.search.IndexSearcher;
@@ -41,11 +40,9 @@
 import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
-import org.apache.lucene.util.BytesRef;
 import org.opensearch.common.TriConsumer;
 import org.opensearch.index.mapper.KeywordFieldMapper;
 import org.opensearch.index.mapper.MappedFieldType;
-import org.opensearch.search.aggregations.Aggregator;
 import org.opensearch.search.aggregations.AggregatorTestCase;
 import org.opensearch.search.aggregations.support.ValueType;
 
@@ -59,6 +56,8 @@ public class KeywordTermsAggregatorTests extends AggregatorTestCase {
 
     private static final String KEYWORD_FIELD = "keyword";
 
+    private static final Consumer<TermsAggregationBuilder> CONFIGURE_KEYWORD_FIELD = agg -> agg.field(KEYWORD_FIELD);
+
     private static final List<String> dataset;
     static {
         List<String> d = new ArrayList<>(45);
@@ -70,7 +69,7 @@ public class KeywordTermsAggregatorTests extends AggregatorTestCase {
         dataset = d;
     }
 
-    private static Consumer<InternalMappedTerms> VERIFY_MATCH_ALL_DOCS = agg -> {
+    private static final Consumer<InternalMappedTerms> VERIFY_MATCH_ALL_DOCS = agg -> {
         assertEquals(9, agg.getBuckets().size());
         for (int i = 0; i < 9; i++) {
             StringTerms.Bucket bucket = (StringTerms.Bucket) agg.getBuckets().get(i);
@@ -79,77 +78,49 @@ public class KeywordTermsAggregatorTests extends AggregatorTestCase {
         }
     };
 
-    private static Query MATCH_ALL_DOCS_QUERY = new MatchAllDocsQuery();
+    private static final Consumer<InternalMappedTerms> VERIFY_MATCH_NO_DOCS = agg -> { assertEquals(0, agg.getBuckets().size()); };
+
+    private static final Query MATCH_ALL_DOCS_QUERY = new MatchAllDocsQuery();
 
-    private static Query MATCH_NO_DOCS_QUERY = new MatchNoDocsQuery();
+    private static final Query MATCH_NO_DOCS_QUERY = new MatchNoDocsQuery();
 
     public void testMatchNoDocs() throws IOException {
         testSearchCase(
-            ADD_SORTED_FIELD_NO_STORE,
+            ADD_SORTED_SET_FIELD_NOT_INDEXED,
             MATCH_NO_DOCS_QUERY,
             dataset,
-            aggregation -> aggregation.field(KEYWORD_FIELD),
-            agg -> assertEquals(0, agg.getBuckets().size()),
-            null, // without type hint
-            DEFAULT_POST_COLLECTION
+            CONFIGURE_KEYWORD_FIELD,
+            VERIFY_MATCH_NO_DOCS,
+            null // without type hint
         );
 
         testSearchCase(
-            ADD_SORTED_FIELD_NO_STORE,
+            ADD_SORTED_SET_FIELD_NOT_INDEXED,
             MATCH_NO_DOCS_QUERY,
             dataset,
-            aggregation -> aggregation.field(KEYWORD_FIELD),
-            agg -> assertEquals(0, agg.getBuckets().size()),
-            ValueType.STRING, // with type hint
-            DEFAULT_POST_COLLECTION
+            CONFIGURE_KEYWORD_FIELD,
+            VERIFY_MATCH_NO_DOCS,
+            ValueType.STRING // with type hint
         );
     }
 
     public void testMatchAllDocs() throws IOException {
         testSearchCase(
-            ADD_SORTED_FIELD_NO_STORE,
-            MATCH_ALL_DOCS_QUERY,
-            dataset,
-            aggregation -> aggregation.field(KEYWORD_FIELD),
-            VERIFY_MATCH_ALL_DOCS,
-            null, // without type hint
-            DEFAULT_POST_COLLECTION
-        );
-
-        testSearchCase(
-            ADD_SORTED_FIELD_NO_STORE,
-            MATCH_ALL_DOCS_QUERY,
-            dataset,
-            aggregation -> aggregation.field(KEYWORD_FIELD),
-            VERIFY_MATCH_ALL_DOCS,
-            ValueType.STRING, // with type hint
-            DEFAULT_POST_COLLECTION
-        );
-    }
-
-    public void testMatchAllDocsWithStoredValues() throws IOException {
-        // aggregator.postCollection() is not required when LeafBucketCollector#termDocFreqCollector optimization is used,
-        // therefore using NOOP_POST_COLLECTION
-        // This also verifies that the bucket count is completed without running postCollection()
-
-        testSearchCase(
-            ADD_SORTED_FIELD_STORE,
+            ADD_SORTED_SET_FIELD_NOT_INDEXED,
             MATCH_ALL_DOCS_QUERY,
             dataset,
-            aggregation -> aggregation.field(KEYWORD_FIELD),
+            CONFIGURE_KEYWORD_FIELD,
             VERIFY_MATCH_ALL_DOCS,
-            null, // without type hint
-            NOOP_POST_COLLECTION
+            null // without type hint
         );
 
         testSearchCase(
-            ADD_SORTED_FIELD_STORE,
+            ADD_SORTED_SET_FIELD_NOT_INDEXED,
             MATCH_ALL_DOCS_QUERY,
             dataset,
-            aggregation -> aggregation.field(KEYWORD_FIELD),
+            CONFIGURE_KEYWORD_FIELD,
             VERIFY_MATCH_ALL_DOCS,
-            ValueType.STRING, // with type hint
-            NOOP_POST_COLLECTION
+            ValueType.STRING // with type hint
         );
     }
@@ -159,15 +130,13 @@ private void testSearchCase(
         List<String> dataset,
         Consumer<TermsAggregationBuilder> configure,
         Consumer<InternalMappedTerms> verify,
-        ValueType valueType,
-        Consumer<Aggregator> postCollectionConsumer
+        ValueType valueType
     ) throws IOException {
         try (Directory directory = newDirectory()) {
             try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
                 Document document = new Document();
                 for (String value : dataset) {
                     addField.apply(document, KEYWORD_FIELD, value);
-                    document.add(new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef(value)));
                     indexWriter.addDocument(document);
                     document.clear();
                 }
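Note: with the stored-values variants folded away, every keyword test indexes the dataset through a TriConsumer and verifies buckets through shared Consumer constants. For orientation, the unoptimized path these tests exercise amounts to visiting every document and walking its SortedSetDocValues ordinals — roughly as in this self-contained sketch (plain Lucene 9.x API; the class name, field name, and in-memory directory are illustrative assumptions):

    import java.io.IOException;
    import java.util.HashMap;
    import java.util.Map;

    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.SortedSetDocValuesField;
    import org.apache.lucene.index.DirectoryReader;
    import org.apache.lucene.index.IndexWriter;
    import org.apache.lucene.index.IndexWriterConfig;
    import org.apache.lucene.index.LeafReaderContext;
    import org.apache.lucene.index.SortedSetDocValues;
    import org.apache.lucene.store.ByteBuffersDirectory;
    import org.apache.lucene.store.Directory;
    import org.apache.lucene.util.BytesRef;

    public class DocValuesTermCountSketch {
        public static void main(String[] args) throws IOException {
            try (Directory dir = new ByteBuffersDirectory(); IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
                for (String value : new String[] { "a", "b", "a", "c" }) {
                    Document doc = new Document();
                    doc.add(new SortedSetDocValuesField("keyword", new BytesRef(value)));
                    writer.addDocument(doc);
                }
                writer.commit();
                try (DirectoryReader reader = DirectoryReader.open(dir)) {
                    Map<String, Long> counts = new HashMap<>();
                    for (LeafReaderContext leaf : reader.leaves()) {
                        SortedSetDocValues dv = leaf.reader().getSortedSetDocValues("keyword");
                        if (dv == null) continue;
                        // The slow path a terms aggregation falls back to: advance the
                        // doc-values iterator for every document and record each ordinal.
                        for (int doc = 0; doc < leaf.reader().maxDoc(); doc++) {
                            if (dv.advanceExact(doc) == false) continue;
                            long ord;
                            while ((ord = dv.nextOrd()) != SortedSetDocValues.NO_MORE_ORDS) {
                                counts.merge(dv.lookupOrd(ord).utf8ToString(), 1L, Long::sum);
                            }
                        }
                    }
                    System.out.println(counts); // {a=2, b=1, c=1}
                }
            }
        }
    }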
diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregatorTests.java
index 93939657b6981..b8a08068f76a3 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregatorTests.java
@@ -44,6 +44,7 @@
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexableField;
+import org.apache.lucene.index.Term;
 import org.apache.lucene.search.DocValuesFieldExistsQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
@@ -121,6 +122,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -137,9 +139,6 @@
 import static org.mockito.Mockito.when;
 
 public class TermsAggregatorTests extends AggregatorTestCase {
-
-    private boolean randomizeAggregatorImpl = true;
-
     // Constants for a script that returns a string
     private static final String STRING_SCRIPT_NAME = "string_script";
     private static final String STRING_SCRIPT_OUTPUT = "Orange";
@@ -172,9 +171,22 @@ protected ScriptService getMockScriptService() {
         return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
     }
 
+    protected CountingAggregator createCountingAggregator(
+        AggregationBuilder aggregationBuilder,
+        IndexSearcher indexSearcher,
+        boolean randomizeAggregatorImpl,
+        MappedFieldType... fieldTypes
+    ) throws IOException {
+        return new CountingAggregator(
+            new AtomicInteger(),
+            createAggregator(aggregationBuilder, indexSearcher, randomizeAggregatorImpl, fieldTypes)
+        );
+    }
+
     protected <A extends Aggregator> A createAggregator(
         AggregationBuilder aggregationBuilder,
         IndexSearcher indexSearcher,
+        boolean randomizeAggregatorImpl,
         MappedFieldType... fieldTypes
     ) throws IOException {
         try {
@@ -189,6 +201,14 @@ protected <A extends Aggregator> A createAggregator(
         }
     }
 
+    protected <A extends Aggregator> A createAggregator(
+        AggregationBuilder aggregationBuilder,
+        IndexSearcher indexSearcher,
+        MappedFieldType... fieldTypes
+    ) throws IOException {
+        return createAggregator(aggregationBuilder, indexSearcher, true, fieldTypes);
+    }
+
     @Override
     protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) {
         return new TermsAggregationBuilder("foo").field(fieldName);
@@ -208,8 +228,6 @@ protected List<ValuesSourceType> getSupportedValuesSourceTypes() {
     }
 
     public void testUsesGlobalOrdinalsByDefault() throws Exception {
-        randomizeAggregatorImpl = false;
-
         Directory directory = newDirectory();
         RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
         indexWriter.close();
@@ -221,7 +239,7 @@ public void testUsesGlobalOrdinalsByDefault() throws Exception {
             .field("string");
         MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string");
 
-        TermsAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType);
+        TermsAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, false, fieldType);
         assertThat(aggregator, instanceOf(GlobalOrdinalsStringTermsAggregator.class));
         GlobalOrdinalsStringTermsAggregator globalAgg = (GlobalOrdinalsStringTermsAggregator) aggregator;
         assertThat(globalAgg.descriptCollectionStrategy(), equalTo("dense"));
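Note: createCountingAggregator in the hunks above decorates the real aggregator so tests can assert exactly how many documents collect() touches, replacing the old mutable randomizeAggregatorImpl field with an explicit parameter. The same decorator idea in miniature, over a plain Lucene Collector (illustrative sketch, not the framework class):

    import java.io.IOException;
    import java.util.concurrent.atomic.AtomicInteger;

    import org.apache.lucene.index.LeafReaderContext;
    import org.apache.lucene.search.Collector;
    import org.apache.lucene.search.LeafCollector;
    import org.apache.lucene.search.Scorable;
    import org.apache.lucene.search.ScoreMode;

    /** Decorator that counts collect() calls while delegating everything else. */
    public class CountingCollector implements Collector {
        private final Collector delegate;
        private final AtomicInteger collectCount = new AtomicInteger();

        public CountingCollector(Collector delegate) {
            this.delegate = delegate;
        }

        public int getCollectCount() {
            return collectCount.get();
        }

        @Override
        public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
            LeafCollector leaf = delegate.getLeafCollector(context);
            return new LeafCollector() {
                @Override
                public void setScorer(Scorable scorer) throws IOException {
                    leaf.setScorer(scorer);
                }

                @Override
                public void collect(int doc) throws IOException {
                    leaf.collect(doc);              // may throw CollectionTerminatedException first,
                    collectCount.incrementAndGet(); // in which case this doc is not counted
                }
            };
        }

        @Override
        public ScoreMode scoreMode() {
            return delegate.scoreMode();
        }
    }

Unlike this sketch, the framework's CountingAggregator (added in AggregatorTestCase further down) resolves the delegate's leaf collector lazily inside collect(), so a CollectionTerminatedException from the optimized path fires before the counter increments and such documents are never counted.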
@@ -259,30 +277,55 @@ public void testUsesGlobalOrdinalsByDefault() throws Exception {
     }
 
     /**
-     * This test case utilizes the low cardinality implementation of GlobalOrdinalsStringTermsAggregator.
-     * In this case, the segment terms will not get initialized and will run without LeafBucketCollector#termDocFreqCollector optimization
+     * This test case utilizes the default implementation of GlobalOrdinalsStringTermsAggregator.
      */
     public void testSimpleAggregation() throws Exception {
-        testSimple(ADD_SORTED_FIELD_NO_STORE, DEFAULT_POST_COLLECTION);
+        // Fields not indexed: cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited
+        testSimple(ADD_SORTED_SET_FIELD_NOT_INDEXED, false, false, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4);
+
+        // Fields indexed, deleted documents in segment: cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited
+        testSimple(ADD_SORTED_SET_FIELD_INDEXED, true, false, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4);
+
+        // Fields indexed, no deleted documents in segment: will use LeafBucketCollector#termDocFreqCollector - no documents are visited
+        testSimple(ADD_SORTED_SET_FIELD_INDEXED, false, false, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 0);
+    }
+
+    /**
+     * This test case utilizes the LowCardinality implementation of GlobalOrdinalsStringTermsAggregator.
+     */
+    public void testSimpleAggregationLowCardinality() throws Exception {
+        // Fields not indexed: cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited
+        testSimple(ADD_SORTED_SET_FIELD_NOT_INDEXED, false, true, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4);
+
+        // Fields indexed, deleted documents in segment: cannot use LeafBucketCollector#termDocFreqCollector - all documents are visited
+        testSimple(ADD_SORTED_SET_FIELD_INDEXED, true, true, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 4);
+
+        // Fields indexed, no deleted documents in segment: will use LeafBucketCollector#termDocFreqCollector - no documents are visited
+        testSimple(ADD_SORTED_SET_FIELD_INDEXED, false, true, TermsAggregatorFactory.ExecutionMode.GLOBAL_ORDINALS, 0);
     }
 
     /**
-     * This test case utilizes the low cardinality implementation of GlobalOrdinalsStringTermsAggregator.
-     * In this case, the segment terms will get initialized and will use LeafBucketCollector#termDocFreqCollector optimization
+     * This test case utilizes the MapStringTermsAggregator.
      */
-    public void testSimpleAggregationWithStoredValues() throws Exception {
-        // aggregator.postCollection() is not required when LeafBucketCollector#termDocFreqCollector optimization is used,
-        // therefore using NOOP_POST_COLLECTION
-        // This also verifies that the bucket count is completed without running postCollection()
-        testSimple(ADD_SORTED_FIELD_STORE, NOOP_POST_COLLECTION);
+    public void testSimpleMapStringAggregation() throws Exception {
+        testSimple(ADD_SORTED_SET_FIELD_INDEXED, randomBoolean(), randomBoolean(), TermsAggregatorFactory.ExecutionMode.MAP, 4);
     }
 
     /**
      * This is a utility method to test out string terms aggregation
     * @param addFieldConsumer a function that determines how a field is added to the document
+     * @param includeDeletedDocumentsInSegment whether to include deleted documents in the segment
+     * @param collectSegmentOrds whether to collect segment ordinals; set to true to exercise the LowCardinality implementation of GlobalOrdinalsStringTermsAggregator
+     * @param executionMode the execution mode, MAP or GLOBAL_ORDINALS
+     * @param expectedCollectCount the expected number of documents visited by collect() invocations
      */
-    private void testSimple(TriConsumer<Document, String, String> addFieldConsumer, Consumer<Aggregator> postCollectionConsumer)
-        throws Exception {
+    private void testSimple(
+        TriConsumer<Document, String, String> addFieldConsumer,
+        final boolean includeDeletedDocumentsInSegment,
+        boolean collectSegmentOrds,
+        TermsAggregatorFactory.ExecutionMode executionMode,
+        final int expectedCollectCount
+    ) throws Exception {
         try (Directory directory = newDirectory()) {
             try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
                 Document document = new Document();
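Note: the expectedCollectCount values above encode when the shortcut applies: with an indexed field and no deletions, per-term bucket counts can be read straight off the terms dictionary, so no document is ever visited. A self-contained sketch of reading docFreq() that way (plain Lucene; class and field names are illustrative assumptions):

    import java.io.IOException;

    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.Field;
    import org.apache.lucene.document.StringField;
    import org.apache.lucene.index.DirectoryReader;
    import org.apache.lucene.index.IndexWriter;
    import org.apache.lucene.index.IndexWriterConfig;
    import org.apache.lucene.index.LeafReaderContext;
    import org.apache.lucene.index.Terms;
    import org.apache.lucene.index.TermsEnum;
    import org.apache.lucene.store.ByteBuffersDirectory;
    import org.apache.lucene.store.Directory;
    import org.apache.lucene.util.BytesRef;

    public class DocFreqSketch {
        public static void main(String[] args) throws IOException {
            try (Directory dir = new ByteBuffersDirectory(); IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
                for (String value : new String[] { "a", "b", "a", "c" }) {
                    Document doc = new Document();
                    doc.add(new StringField("keyword", value, Field.Store.NO)); // indexed -> terms dict available
                    writer.addDocument(doc);
                }
                writer.commit();
                try (DirectoryReader reader = DirectoryReader.open(dir)) {
                    for (LeafReaderContext leaf : reader.leaves()) {
                        Terms terms = leaf.reader().terms("keyword"); // null if the field is not indexed
                        if (terms == null) continue;
                        TermsEnum te = terms.iterator();
                        BytesRef term;
                        // Bucket counts come from docFreq(): no per-document work at all.
                        while ((term = te.next()) != null) {
                            System.out.println(term.utf8ToString() + " -> " + te.docFreq());
                        }
                    }
                }
            }
        }
    }

docFreq() still counts deleted documents, which is why the cases above with deletions in the segment must fall back to visiting all live documents.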
@@ -301,94 +344,84 @@ private void testSimple(TriConsumer<Document, String, String> addFieldConsumer,
                 document = new Document();
                 addFieldConsumer.apply(document, "string", "");
                 indexWriter.addDocument(document);
+
+                if (includeDeletedDocumentsInSegment) {
+                    document = new Document();
+                    ADD_SORTED_SET_FIELD_INDEXED.apply(document, "string", "e");
+                    indexWriter.addDocument(document);
+                    indexWriter.deleteDocuments(new Term("string", "e"));
+                    assertEquals(5, indexWriter.getDocStats().maxDoc);
+                }
+                assertEquals(4, indexWriter.getDocStats().numDocs);
+
                 try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) {
                     IndexSearcher indexSearcher = newIndexSearcher(indexReader);
-                    for (TermsAggregatorFactory.ExecutionMode executionMode : TermsAggregatorFactory.ExecutionMode.values()) {
-                        TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("_name").userValueTypeHint(
-                            ValueType.STRING
-                        ).executionHint(executionMode.toString()).field("string").order(BucketOrder.key(true));
-                        MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string");
-                        TermsAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType);
-                        aggregator.preCollection();
-                        indexSearcher.search(new MatchAllDocsQuery(), aggregator);
-                        postCollectionConsumer.accept(aggregator);
-                        Terms result = reduce(aggregator);
-                        assertEquals(5, result.getBuckets().size());
-                        assertEquals("", result.getBuckets().get(0).getKeyAsString());
-                        assertEquals(2L, result.getBuckets().get(0).getDocCount());
-                        assertEquals("a", result.getBuckets().get(1).getKeyAsString());
-                        assertEquals(2L, result.getBuckets().get(1).getDocCount());
-                        assertEquals("b", result.getBuckets().get(2).getKeyAsString());
-                        assertEquals(2L, result.getBuckets().get(2).getDocCount());
-                        assertEquals("c", result.getBuckets().get(3).getKeyAsString());
-                        assertEquals(1L, result.getBuckets().get(3).getDocCount());
-                        assertEquals("d", result.getBuckets().get(4).getKeyAsString());
-                        assertEquals(1L, result.getBuckets().get(4).getDocCount());
-                        assertTrue(AggregationInspectionHelper.hasValue((InternalTerms<?, ?>) result));
-                    }
+                    TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("_name").userValueTypeHint(ValueType.STRING)
+                        .executionHint(executionMode.toString())
+                        .field("string")
+                        .order(BucketOrder.key(true));
+                    MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string");
+
+                    TermsAggregatorFactory.COLLECT_SEGMENT_ORDS = collectSegmentOrds;
+                    TermsAggregatorFactory.REMAP_GLOBAL_ORDS = false;
+                    CountingAggregator aggregator = createCountingAggregator(aggregationBuilder, indexSearcher, false, fieldType);
+
+                    aggregator.preCollection();
+                    indexSearcher.search(new MatchAllDocsQuery(), aggregator);
+                    aggregator.postCollection();
+                    Terms result = reduce(aggregator);
+                    assertEquals(5, result.getBuckets().size());
+                    assertEquals("", result.getBuckets().get(0).getKeyAsString());
+                    assertEquals(2L, result.getBuckets().get(0).getDocCount());
+                    assertEquals("a", result.getBuckets().get(1).getKeyAsString());
+                    assertEquals(2L, result.getBuckets().get(1).getDocCount());
+                    assertEquals("b", result.getBuckets().get(2).getKeyAsString());
+                    assertEquals(2L, result.getBuckets().get(2).getDocCount());
+                    assertEquals("c", result.getBuckets().get(3).getKeyAsString());
+                    assertEquals(1L, result.getBuckets().get(3).getDocCount());
+                    assertEquals("d", result.getBuckets().get(4).getKeyAsString());
+                    assertEquals(1L, result.getBuckets().get(4).getDocCount());
+                    assertTrue(AggregationInspectionHelper.hasValue((InternalTerms<?, ?>) result));
+
+                    assertEquals(expectedCollectCount, aggregator.getCollectCount().get());
                 }
             }
         }
     }
 
-    /**
-     * This test case utilizes the default implementation of GlobalOrdinalsStringTermsAggregator.
-     * In this case, the segment terms will not get initialized and will run without LeafBucketCollector#termDocFreqCollector optimization
-     */
     public void testStringIncludeExclude() throws Exception {
-        testStringIncludeExclude(
-            (document, field, value) -> document.add(new SortedSetDocValuesField(field, new BytesRef(value))),
-            DEFAULT_POST_COLLECTION
-        );
-    }
-
-    /**
-     * This test case utilizes the default implementation of GlobalOrdinalsStringTermsAggregator.
-     * In this case, the segment terms will get initialized and will use LeafBucketCollector#termDocFreqCollector optimization
-     */
-    public void testStringIncludeExcludeWithStoredValues() throws Exception {
-        // aggregator.postCollection() is not required when LeafBucketCollector#termDocFreqCollector optimization is used
-        // This also verifies that the bucket count is completed without running postCollection()
-        testStringIncludeExclude((document, field, value) -> {
-            document.add(new SortedSetDocValuesField(field, new BytesRef(value)));
-            document.add(new StringField(field, value, Field.Store.NO));
-        }, NOOP_POST_COLLECTION);
-    }
-
-    private void testStringIncludeExclude(TriConsumer<Document, String, String> addField, Consumer<Aggregator> postCollectionConsumer)
-        throws Exception {
         try (Directory directory = newDirectory()) {
             try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
                 Document document = new Document();
-                addField.apply(document, "mv_field", "val000");
-                addField.apply(document, "mv_field", "val001");
-                addField.apply(document, "sv_field", "val001");
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val000")));
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val001")));
+                document.add(new SortedDocValuesField("sv_field", new BytesRef("val001")));
                 indexWriter.addDocument(document);
                 document = new Document();
-                addField.apply(document, "mv_field", "val002");
-                addField.apply(document, "mv_field", "val003");
-                addField.apply(document, "sv_field", "val003");
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val002")));
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val003")));
+                document.add(new SortedDocValuesField("sv_field", new BytesRef("val003")));
                 indexWriter.addDocument(document);
                 document = new Document();
-                addField.apply(document, "mv_field", "val004");
-                addField.apply(document, "mv_field", "val005");
-                addField.apply(document, "sv_field", "val005");
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val004")));
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val005")));
+                document.add(new SortedDocValuesField("sv_field", new BytesRef("val005")));
                 indexWriter.addDocument(document);
                 document = new Document();
-                addField.apply(document, "mv_field", "val006");
-                addField.apply(document, "mv_field", "val007");
-                addField.apply(document, "sv_field", "val007");
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val006")));
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val007")));
+                document.add(new SortedDocValuesField("sv_field", new BytesRef("val007")));
                 indexWriter.addDocument(document);
                 document = new Document();
-                addField.apply(document, "mv_field", "val008");
-                addField.apply(document, "mv_field", "val009");
-                addField.apply(document, "sv_field", "val009");
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val008")));
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val009")));
+                document.add(new SortedDocValuesField("sv_field", new BytesRef("val009")));
                 indexWriter.addDocument(document);
                 document = new Document();
-                addField.apply(document, "mv_field", "val010");
-                addField.apply(document, "mv_field", "val011");
-                addField.apply(document, "sv_field", "val011");
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val010")));
+                document.add(new SortedSetDocValuesField("mv_field", new BytesRef("val011")));
+                document.add(new SortedDocValuesField("sv_field", new BytesRef("val011")));
BytesRef("val011"))); indexWriter.addDocument(document); try (IndexReader indexReader = maybeWrapReaderEs(indexWriter.getReader())) { IndexSearcher indexSearcher = newIndexSearcher(indexReader); @@ -405,7 +438,7 @@ private void testStringIncludeExclude(TriConsumer addF TermsAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); aggregator.preCollection(); indexSearcher.search(new MatchAllDocsQuery(), aggregator); - postCollectionConsumer.accept(aggregator); + aggregator.postCollection(); Terms result = reduce(aggregator); assertEquals(10, result.getBuckets().size()); assertEquals("val000", result.getBuckets().get(0).getKeyAsString()); @@ -440,7 +473,7 @@ private void testStringIncludeExclude(TriConsumer addF aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType2); aggregator.preCollection(); indexSearcher.search(new MatchAllDocsQuery(), aggregator); - postCollectionConsumer.accept(aggregator); + aggregator.postCollection(); result = reduce(aggregator); assertEquals(5, result.getBuckets().size()); assertEquals("val001", result.getBuckets().get(0).getKeyAsString()); @@ -464,7 +497,7 @@ private void testStringIncludeExclude(TriConsumer addF aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); aggregator.preCollection(); indexSearcher.search(new MatchAllDocsQuery(), aggregator); - postCollectionConsumer.accept(aggregator); + aggregator.postCollection(); result = reduce(aggregator); assertEquals(8, result.getBuckets().size()); assertEquals("val002", result.getBuckets().get(0).getKeyAsString()); @@ -493,7 +526,7 @@ private void testStringIncludeExclude(TriConsumer addF aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); aggregator.preCollection(); indexSearcher.search(new MatchAllDocsQuery(), aggregator); - postCollectionConsumer.accept(aggregator); + aggregator.postCollection(); result = reduce(aggregator); assertEquals(2, result.getBuckets().size()); assertEquals("val010", result.getBuckets().get(0).getKeyAsString()); @@ -510,7 +543,7 @@ private void testStringIncludeExclude(TriConsumer addF aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); aggregator.preCollection(); indexSearcher.search(new MatchAllDocsQuery(), aggregator); - postCollectionConsumer.accept(aggregator); + aggregator.postCollection(); result = reduce(aggregator); assertEquals(2, result.getBuckets().size()); assertEquals("val000", result.getBuckets().get(0).getKeyAsString()); @@ -542,7 +575,7 @@ private void testStringIncludeExclude(TriConsumer addF aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType); aggregator.preCollection(); indexSearcher.search(new MatchAllDocsQuery(), aggregator); - postCollectionConsumer.accept(aggregator); + aggregator.postCollection(); result = reduce(aggregator); assertEquals(2, result.getBuckets().size()); assertEquals("val000", result.getBuckets().get(0).getKeyAsString()); diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index 9fdae80bd1ada..4eb49ebb42241 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -124,6 +124,7 @@ import org.opensearch.search.aggregations.AggregatorFactories.Builder; import 
diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java
index 9fdae80bd1ada..4eb49ebb42241 100644
--- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java
+++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java
@@ -124,6 +124,7 @@
 import org.opensearch.search.aggregations.AggregatorFactories.Builder;
 import org.opensearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer;
 import org.opensearch.search.aggregations.bucket.nested.NestedAggregationBuilder;
+import org.opensearch.search.aggregations.bucket.terms.TermsAggregator;
 import org.opensearch.search.aggregations.metrics.MetricsAggregator;
 import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
 import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
@@ -150,6 +151,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
@@ -181,23 +183,10 @@ public abstract class AggregatorTestCase extends OpenSearchTestCase {
     // A list of field types that should not be tested, or are not currently supported
     private static List<String> TYPE_TEST_DENYLIST;
 
-    protected static final Consumer<Aggregator> DEFAULT_POST_COLLECTION = termsAggregator -> {
-        try {
-            termsAggregator.postCollection();
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-    };
-
-    // aggregator.postCollection() is not required when LeafBucketCollector#termDocFreqCollector optimization is used.
-    // using NOOP_POST_COLLECTION_CONSUMER ensures that the bucket count in aggregation is completed before/without running postCollection()
-    protected static final Consumer<Aggregator> NOOP_POST_COLLECTION = termsAggregator -> {};
+    protected static final TriConsumer<Document, String, String> ADD_SORTED_SET_FIELD_NOT_INDEXED = (document, field, value) -> document
+        .add(new SortedSetDocValuesField(field, new BytesRef(value)));
 
-    protected static final TriConsumer<Document, String, String> ADD_SORTED_FIELD_NO_STORE = (document, field, value) -> document.add(
-        new SortedSetDocValuesField(field, new BytesRef(value))
-    );
-
-    protected static final TriConsumer<Document, String, String> ADD_SORTED_FIELD_STORE = (document, field, value) -> {
+    protected static final TriConsumer<Document, String, String> ADD_SORTED_SET_FIELD_INDEXED = (document, field, value) -> {
         document.add(new SortedSetDocValuesField(field, new BytesRef(value)));
         document.add(new StringField(field, value, Field.Store.NO));
     };
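Note: the two renamed TriConsumer constants above encode the distinction the whole patch turns on: ADD_SORTED_SET_FIELD_NOT_INDEXED writes doc values only, while ADD_SORTED_SET_FIELD_INDEXED also populates the inverted index via a StringField. Whether LeafReader.terms(field) returns null is precisely how a collector can tell the two apart (self-contained sketch; class and field names are illustrative assumptions):

    import java.io.IOException;

    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.Field;
    import org.apache.lucene.document.SortedSetDocValuesField;
    import org.apache.lucene.document.StringField;
    import org.apache.lucene.index.DirectoryReader;
    import org.apache.lucene.index.IndexWriter;
    import org.apache.lucene.index.IndexWriterConfig;
    import org.apache.lucene.store.ByteBuffersDirectory;
    import org.apache.lucene.store.Directory;
    import org.apache.lucene.util.BytesRef;

    public class IndexedVsDocValuesSketch {
        public static void main(String[] args) throws IOException {
            System.out.println(hasTermsDict(false)); // false: doc values only, no terms dictionary
            System.out.println(hasTermsDict(true));  // true: StringField populates the inverted index
        }

        static boolean hasTermsDict(boolean alsoIndex) throws IOException {
            try (Directory dir = new ByteBuffersDirectory(); IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
                Document doc = new Document();
                doc.add(new SortedSetDocValuesField("keyword", new BytesRef("a")));
                if (alsoIndex) {
                    doc.add(new StringField("keyword", "a", Field.Store.NO));
                }
                writer.addDocument(doc);
                writer.commit();
                try (DirectoryReader reader = DirectoryReader.open(dir)) {
                    return reader.leaves().get(0).reader().terms("keyword") != null;
                }
            }
        }
    }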
@@ -457,7 +446,6 @@ protected QueryShardContext queryShardContextMock(
         CircuitBreakerService circuitBreakerService,
         BigArrays bigArrays
     ) {
-
         return new QueryShardContext(
             0,
             indexSettings,
@@ -508,16 +496,6 @@ protected <A extends InternalAggregation> A searchAndReduc
         return searchAndReduce(createIndexSettings(), searcher, query, builder, DEFAULT_MAX_BUCKETS, fieldTypes);
     }
 
-    protected <A extends InternalAggregation> A searchAndReduce(
-        IndexSearcher searcher,
-        Query query,
-        AggregationBuilder builder,
-        Consumer<Aggregator> postCollectionConsumer,
-        MappedFieldType... fieldTypes
-    ) throws IOException {
-        return searchAndReduce(createIndexSettings(), searcher, query, builder, DEFAULT_MAX_BUCKETS, postCollectionConsumer, fieldTypes);
-    }
-
     protected <A extends InternalAggregation> A searchAndReduce(
         IndexSettings indexSettings,
         IndexSearcher searcher,
@@ -538,17 +516,6 @@ protected <A extends InternalAggregation> A searchAndReduc
         return searchAndReduce(createIndexSettings(), searcher, query, builder, maxBucket, fieldTypes);
     }
 
-    protected <A extends InternalAggregation> A searchAndReduce(
-        IndexSettings indexSettings,
-        IndexSearcher searcher,
-        Query query,
-        AggregationBuilder builder,
-        int maxBucket,
-        MappedFieldType... fieldTypes
-    ) throws IOException {
-        return searchAndReduce(indexSettings, searcher, query, builder, maxBucket, DEFAULT_POST_COLLECTION, fieldTypes);
-    }
-
     /**
      * Collects all documents that match the provided query {@link Query} and
      * returns the reduced {@link InternalAggregation}.
@@ -563,7 +530,6 @@ protected <A extends InternalAggregation> A searchAndReduc
         Query query,
         AggregationBuilder builder,
         int maxBucket,
-        Consumer<Aggregator> postCollectionConsumer,
         MappedFieldType... fieldTypes
     ) throws IOException {
         final IndexReaderContext ctx = searcher.getTopReaderContext();
@@ -594,13 +560,13 @@ protected <A extends InternalAggregation> A searchAndReduc
                 a.preCollection();
                 Weight weight = subSearcher.createWeight(rewritten, ScoreMode.COMPLETE, 1f);
                 subSearcher.search(weight, a);
-                postCollectionConsumer.accept(a);
+                a.postCollection();
                 aggs.add(a.buildTopLevel());
             }
         } else {
             root.preCollection();
             searcher.search(rewritten, root);
-            postCollectionConsumer.accept(root);
+            root.postCollection();
             aggs.add(root.buildTopLevel());
         }
 
@@ -1142,6 +1108,89 @@ protected void doWriteTo(StreamOutput out) throws IOException {
         }
     }
 
+    /**
+     * A delegating wrapper around an Aggregator that counts how many times collect() is invoked,
+     * i.e. the number of documents visited during collection.
+     */
+    protected static class CountingAggregator extends Aggregator {
+        private final AtomicInteger collectCounter;
+        public final Aggregator delegate;
+
+        public CountingAggregator(AtomicInteger collectCounter, TermsAggregator delegate) {
+            this.collectCounter = collectCounter;
+            this.delegate = delegate;
+        }
+
+        public AtomicInteger getCollectCount() {
+            return collectCounter;
+        }
+
+        @Override
+        public void close() {
+            delegate.close();
+        }
+
+        @Override
+        public String name() {
+            return delegate.name();
+        }
+
+        @Override
+        public SearchContext context() {
+            return delegate.context();
+        }
+
+        @Override
+        public Aggregator parent() {
+            return delegate.parent();
+        }
+
+        @Override
+        public Aggregator subAggregator(String name) {
+            return delegate.subAggregator(name);
+        }
+
+        @Override
+        public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
+            return delegate.buildAggregations(owningBucketOrds);
+        }
+
+        @Override
+        public InternalAggregation buildEmptyAggregation() {
+            return delegate.buildEmptyAggregation();
+        }
+
+        @Override
+        public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
+            return new LeafBucketCollector() {
+                @Override
+                public void collect(int doc, long bucket) throws IOException {
+                    // The delegate's leaf collector is resolved lazily so that a
+                    // CollectionTerminatedException thrown by the optimized path
+                    // propagates before the counter is incremented.
+                    delegate.getLeafCollector(ctx).collect(doc, bucket);
+                    collectCounter.incrementAndGet();
+                }
+            };
+        }
+
+        @Override
+        public ScoreMode scoreMode() {
+            return delegate.scoreMode();
+        }
+
+        @Override
+        public void preCollection() throws IOException {
+            delegate.preCollection();
+        }
+
+        @Override
+        public void postCollection() throws IOException {
+            delegate.postCollection();
+        }
+
+        public void setWeight(Weight weight) {
+            this.delegate.setWeight(weight);
+        }
+    }
+
     public static class InternalAggCardinality extends InternalAggregation {
 
         private final CardinalityUpperBound cardinality;