From 93917cfd8827ac51f22b276b4dea1ff34411ce73 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Tue, 5 Mar 2024 16:50:15 +0000 Subject: [PATCH] [ML] Fix `categorize_text` aggregation nested under empty buckets Previously the `categorize_text` aggregation could throw an exception if nested as a sub-aggregation of another aggregation that produced empty buckets at the end of its results. This change avoids this possibility. Fixes #105836 --- .../CategorizeTextAggregator.java | 3 +- .../CategorizeTextAggregatorTests.java | 84 +++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index 520d554379cfc..cedaced0f57ee 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -113,7 +113,8 @@ protected void doClose() { public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOException { Bucket[][] topBucketsPerOrd = new Bucket[ordsToCollect.length][]; for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { - final TokenListCategorizer categorizer = categorizers.get(ordsToCollect[ordIdx]); + final long ord = ordsToCollect[ordIdx]; + final TokenListCategorizer categorizer = (ord < categorizers.size()) ? categorizers.get(ord) : null; if (categorizer == null) { topBucketsPerOrd[ordIdx] = new Bucket[0]; continue; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java index cb5b98af29d57..29f298894477a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java @@ -322,6 +322,90 @@ public void testCategorizationWithSubAggsManyDocs() throws Exception { ); } + public void testCategorizationAsSubAggWithExtendedBounds() throws Exception { + // Test with more buckets than we have data for (via extended bounds in the histogram config). + // This will confirm that we don't try to read beyond the end of arrays of categorizers. + int numHistoBuckets = 50; + HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(1) + .extendedBounds(0, numHistoBuckets - 1) + .subAggregation(new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME)); + testCase(CategorizeTextAggregatorTests::writeTestDocs, (InternalHistogram histo) -> { + assertThat(histo.getBuckets(), hasSize(numHistoBuckets)); + assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(2L)); + assertThat(histo.getBuckets().get(0).getKeyAsString(), equalTo("0.0")); + InternalCategorizationAggregation categorizationAggregation = histo.getBuckets().get(0).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L)); + assertThat( + categorizationAggregation.getBuckets().get(0).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(84)); + assertThat( + categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), + equalTo(".*?Failed.+?to.+?shutdown.+?error.+?org\\.aaaa\\.bbbb\\.Cccc.+?line.+?caused.+?by.+?foo.+?exception.*?") + ); + assertThat(categorizationAggregation.getBuckets().get(1).getDocCount(), equalTo(1L)); + assertThat(categorizationAggregation.getBuckets().get(1).getKeyAsString(), equalTo("Node started")); + assertThat(categorizationAggregation.getBuckets().get(1).getSerializableCategory().maxMatchingStringLen(), equalTo(15)); + assertThat(categorizationAggregation.getBuckets().get(1).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?")); + assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(1L)); + assertThat(histo.getBuckets().get(1).getKeyAsString(), equalTo("1.0")); + categorizationAggregation = histo.getBuckets().get(1).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(15)); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?")); + assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(1L)); + assertThat(histo.getBuckets().get(2).getKeyAsString(), equalTo("2.0")); + categorizationAggregation = histo.getBuckets().get(2).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(15)); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?")); + assertThat(histo.getBuckets().get(3).getDocCount(), equalTo(1L)); + assertThat(histo.getBuckets().get(3).getKeyAsString(), equalTo("3.0")); + categorizationAggregation = histo.getBuckets().get(3).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(15)); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?")); + assertThat(histo.getBuckets().get(4).getDocCount(), equalTo(2L)); + assertThat(histo.getBuckets().get(4).getKeyAsString(), equalTo("4.0")); + categorizationAggregation = histo.getBuckets().get(4).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L)); + assertThat( + categorizationAggregation.getBuckets().get(0).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(84)); + assertThat( + categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), + equalTo(".*?Failed.+?to.+?shutdown.+?error.+?org\\.aaaa\\.bbbb\\.Cccc.+?line.+?caused.+?by.+?foo.+?exception.*?") + ); + assertThat(categorizationAggregation.getBuckets().get(1).getDocCount(), equalTo(1L)); + assertThat(categorizationAggregation.getBuckets().get(1).getKeyAsString(), equalTo("Node started")); + assertThat(categorizationAggregation.getBuckets().get(1).getSerializableCategory().maxMatchingStringLen(), equalTo(15)); + assertThat(categorizationAggregation.getBuckets().get(1).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?")); + assertThat(histo.getBuckets().get(5).getDocCount(), equalTo(1L)); + assertThat(histo.getBuckets().get(5).getKeyAsString(), equalTo("5.0")); + categorizationAggregation = histo.getBuckets().get(5).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(15)); + assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?")); + for (int bucket = 6; bucket < numHistoBuckets; ++bucket) { + assertThat(histo.getBuckets().get(bucket).getDocCount(), equalTo(0L)); + } + }, + new AggTestConfig( + aggBuilder, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME, randomBoolean()), + longField(NUMERIC_FIELD_NAME) + ) + ); + } + private static void writeTestDocs(RandomIndexWriter w) throws IOException { w.addDocument( Arrays.asList(