From 6cc8da561512890c437ef5c605ad99b575ff858b Mon Sep 17 00:00:00 2001 From: Jay Deng Date: Thu, 20 Jul 2023 11:46:37 -0700 Subject: [PATCH] Change InternalSignificantTerms to only sum shard level counts in final reduce (#8735) Signed-off-by: Jay Deng --- CHANGELOG.md | 1 + .../SignificantTermsSignificanceScoreIT.java | 30 +++++++++++++++ .../search/DefaultSearchContext.java | 6 ++- .../AggregationCollectorManager.java | 2 +- .../aggregations/InternalAggregation.java | 11 ++++++ .../terms/InternalSignificantTerms.java | 16 +++++++- .../internal/FilteredSearchContext.java | 4 +- .../search/internal/SearchContext.java | 2 +- .../SharedSignificantTermsTestMethods.java | 37 +++++++++++++++++++ .../opensearch/test/TestSearchContext.java | 2 +- 10 files changed, 102 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d3eda3f852a24..dceec5c62d94f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Changed - Perform aggregation postCollection in ContextIndexSearcher after searching leaves ([#8303](https://github.com/opensearch-project/OpenSearch/pull/8303)) - Make Span exporter configurable ([#8620](https://github.com/opensearch-project/OpenSearch/issues/8620)) +- Change InternalSignificantTerms to sum shard-level superset counts only in final reduce ([#8735](https://github.com/opensearch-project/OpenSearch/pull/8735)) ### Deprecated diff --git a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/SignificantTermsSignificanceScoreIT.java b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/SignificantTermsSignificanceScoreIT.java index 33a1f3a4b974a..43d49dc0bfd60 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/SignificantTermsSignificanceScoreIT.java +++ 
b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/bucket/SignificantTermsSignificanceScoreIT.java @@ -31,6 +31,7 @@ package org.opensearch.search.aggregations.bucket; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.index.IndexRequestBuilder; import org.opensearch.action.search.SearchRequestBuilder; import org.opensearch.action.search.SearchResponse; @@ -42,6 +43,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.script.MockScriptPlugin; @@ -210,6 +212,34 @@ public void testXContentResponse() throws Exception { } + public void testConsistencyWithDifferentShardCounts() throws Exception { + // The purpose of this test is to validate that the aggregation results do not change with shard count. + // bg_count for significant term agg is summed up across shards, so in this test we compare a 1 shard and 2 shard search request + String type = randomBoolean() ? "text" : "long"; + String settings = "{\"index.number_of_shards\": 1, \"index.number_of_replicas\": 0}"; + SharedSignificantTermsTestMethods.index01Docs(type, settings, this); + + SearchRequestBuilder request = client().prepareSearch(INDEX_NAME) + .setQuery(new TermQueryBuilder(CLASS_FIELD, "0")) + .addAggregation((significantTerms("sig_terms").field(TEXT_FIELD))); + + SearchResponse response1 = request.get(); + + assertAcked(client().admin().indices().delete(new DeleteIndexRequest("*")).get()); + + settings = "{\"index.number_of_shards\": 2, \"index.number_of_replicas\": 0}"; + // We use a custom routing strategy here to ensure that each shard will have at least 1 bucket. 
+ // If there are no buckets collected for a shard, then that will affect the scoring and bg_count and our assertion will not be + // valid. + SharedSignificantTermsTestMethods.index01DocsWithRouting(type, settings, this); + SearchResponse response2 = request.get(); + + assertEquals( + response1.getAggregations().asMap().get("sig_terms").toString(), + response2.getAggregations().asMap().get("sig_terms").toString() + ); + } + public void testPopularTermManyDeletedDocs() throws Exception { String settings = "{\"index.number_of_shards\": 1, \"index.number_of_replicas\": 0}"; assertAcked( diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index ee29d6bfe2b62..f377a5e315e1b 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -917,8 +917,10 @@ public ReaderContext readerContext() { } @Override - public InternalAggregation.ReduceContext partial() { - return requestToAggReduceContextBuilder.apply(request.source()).forPartialReduction(); + public InternalAggregation.ReduceContext partialOnShard() { + InternalAggregation.ReduceContext rc = requestToAggReduceContextBuilder.apply(request.source()).forPartialReduction(); + rc.setSliceLevel(isConcurrentSegmentSearchEnabled()); + return rc; } @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java b/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java index 0b36fc8b0cc5a..1f60ff6503ca8 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java @@ -70,7 +70,7 @@ public ReduceableSearchResult reduce(Collection collectors) throws IO // using reduce is fine here instead of topLevelReduce as pipeline 
aggregation is evaluated on the coordinator after all // documents are collected across shards for an aggregation return new AggregationReduceableSearchResult( - InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partial()) + InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partialOnShard()) ); } else { return new AggregationReduceableSearchResult(internalAggregations); diff --git a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java index b7577fb647be5..c6d86316fa230 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java +++ b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java @@ -89,6 +89,8 @@ public static class ReduceContext { private final ScriptService scriptService; private final IntConsumer multiBucketConsumer; private final PipelineTree pipelineTreeRoot; + + private boolean isSliceLevel; /** * Supplies the pipelines when the result of the reduce is serialized * to node versions that need pipeline aggregators to be serialized @@ -138,6 +140,7 @@ private ReduceContext( this.multiBucketConsumer = multiBucketConsumer; this.pipelineTreeRoot = pipelineTreeRoot; this.pipelineTreeForBwcSerialization = pipelineTreeForBwcSerialization; + this.isSliceLevel = false; } /** @@ -149,6 +152,14 @@ public boolean isFinalReduce() { return pipelineTreeRoot != null; } + public void setSliceLevel(boolean sliceLevel) { + this.isSliceLevel = sliceLevel; + } + + public boolean isSliceLevel() { + return this.isSliceLevel; + } + public BigArrays bigArrays() { return bigArrays; } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java index 6104d2193f6cd..84d148199a7f9 100644 --- 
a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java @@ -232,7 +232,13 @@ public InternalAggregation reduce(List aggregations, Reduce @SuppressWarnings("unchecked") InternalSignificantTerms terms = (InternalSignificantTerms) aggregation; globalSubsetSize += terms.getSubsetSize(); - globalSupersetSize += terms.getSupersetSize(); + // supersetSize is a shard-level count; if we sum it across slices we would produce num_slices_with_bucket * supersetSize where + // num_slices_with_bucket is the number of segment slices that have collected a bucket for the key + if (reduceContext.isSliceLevel()) { + globalSupersetSize = terms.getSupersetSize(); + } else { + globalSupersetSize += terms.getSupersetSize(); + } } Map> buckets = new HashMap<>(); for (InternalAggregation aggregation : aggregations) { @@ -291,7 +297,13 @@ protected B reduceBucket(List buckets, ReduceContext context) { List aggregationsList = new ArrayList<>(buckets.size()); for (B bucket : buckets) { subsetDf += bucket.subsetDf; - supersetDf += bucket.supersetDf; + // supersetDf is a shard-level count; if we sum it across slices we would produce num_slices_with_bucket * supersetDf where + // num_slices_with_bucket is the number of segment slices that have collected a bucket for the key + if (context.isSliceLevel()) { + supersetDf = bucket.supersetDf; + } else { + supersetDf += bucket.supersetDf; + } aggregationsList.add(bucket.aggregations); } InternalAggregations aggs = InternalAggregations.reduce(aggregationsList, context); diff --git a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java index 790d2ed5ee4b7..bb990e69e7722 100644 --- a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java +++ 
b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java @@ -546,8 +546,8 @@ public ReaderContext readerContext() { } @Override - public InternalAggregation.ReduceContext partial() { - return in.partial(); + public InternalAggregation.ReduceContext partialOnShard() { + return in.partialOnShard(); } @Override diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index fd02ba2ba12bb..c2f81b0d4b8b5 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -465,7 +465,7 @@ public String toString() { public abstract ReaderContext readerContext(); - public abstract InternalAggregation.ReduceContext partial(); + public abstract InternalAggregation.ReduceContext partialOnShard(); // processor used for bucket collectors public abstract void setBucketCollectorProcessor(BucketCollectorProcessor bucketCollectorProcessor); diff --git a/server/src/test/java/org/opensearch/test/search/aggregations/bucket/SharedSignificantTermsTestMethods.java b/server/src/test/java/org/opensearch/test/search/aggregations/bucket/SharedSignificantTermsTestMethods.java index 1092bc4f8f47c..34774758dcd0e 100644 --- a/server/src/test/java/org/opensearch/test/search/aggregations/bucket/SharedSignificantTermsTestMethods.java +++ b/server/src/test/java/org/opensearch/test/search/aggregations/bucket/SharedSignificantTermsTestMethods.java @@ -113,4 +113,41 @@ public static void index01Docs(String type, String settings, OpenSearchIntegTest indexRequestBuilderList.add(client().prepareIndex(INDEX_NAME).setId("7").setSource(TEXT_FIELD, "0", CLASS_FIELD, "0")); testCase.indexRandom(true, false, indexRequestBuilderList); } + + public static void index01DocsWithRouting(String type, String settings, OpenSearchIntegTestCase testCase) throws ExecutionException, + InterruptedException { + 
String textMappings = "type=" + type; + if (type.equals("text")) { + textMappings += ",fielddata=true"; + } + assertAcked( + testCase.prepareCreate(INDEX_NAME) + .setSettings(settings, XContentType.JSON) + .setMapping("text", textMappings, CLASS_FIELD, "type=keyword") + ); + String[] gb = { "0", "1" }; + List indexRequestBuilderList = new ArrayList<>(); + indexRequestBuilderList.add( + client().prepareIndex(INDEX_NAME).setId("1").setSource(TEXT_FIELD, "1", CLASS_FIELD, "1").setRouting("0") + ); + indexRequestBuilderList.add( + client().prepareIndex(INDEX_NAME).setId("2").setSource(TEXT_FIELD, "1", CLASS_FIELD, "1").setRouting("0") + ); + indexRequestBuilderList.add( + client().prepareIndex(INDEX_NAME).setId("3").setSource(TEXT_FIELD, "0", CLASS_FIELD, "0").setRouting("0") + ); + indexRequestBuilderList.add( + client().prepareIndex(INDEX_NAME).setId("4").setSource(TEXT_FIELD, "0", CLASS_FIELD, "0").setRouting("1") + ); + indexRequestBuilderList.add( + client().prepareIndex(INDEX_NAME).setId("5").setSource(TEXT_FIELD, gb, CLASS_FIELD, "1").setRouting("1") + ); + indexRequestBuilderList.add( + client().prepareIndex(INDEX_NAME).setId("6").setSource(TEXT_FIELD, gb, CLASS_FIELD, "0").setRouting("0") + ); + indexRequestBuilderList.add( + client().prepareIndex(INDEX_NAME).setId("7").setSource(TEXT_FIELD, "0", CLASS_FIELD, "0").setRouting("0") + ); + testCase.indexRandom(true, false, indexRequestBuilderList); + } } diff --git a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java index 694f88d944f71..4e44791e77566 100644 --- a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java @@ -659,7 +659,7 @@ public ReaderContext readerContext() { } @Override - public InternalAggregation.ReduceContext partial() { + public InternalAggregation.ReduceContext partialOnShard() { return 
InternalAggregationTestCase.emptyReduceContextBuilder().forPartialReduction(); }