[ML] Fix categorize_text aggregation nested under empty buckets
Previously the `categorize_text` aggregation could throw an
exception if nested as a sub-aggregation of another aggregation
that produced empty buckets at the end of its results. This
change avoids that possibility.

Fixes elastic#105836
droberts195 committed Mar 5, 2024
1 parent f1035bb commit 93917cf
Showing 2 changed files with 86 additions and 1 deletion.
@@ -113,7 +113,8 @@ protected void doClose() {
     public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOException {
         Bucket[][] topBucketsPerOrd = new Bucket[ordsToCollect.length][];
         for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) {
-            final TokenListCategorizer categorizer = categorizers.get(ordsToCollect[ordIdx]);
+            final long ord = ordsToCollect[ordIdx];
+            final TokenListCategorizer categorizer = (ord < categorizers.size()) ? categorizers.get(ord) : null;
             if (categorizer == null) {
                 topBucketsPerOrd[ordIdx] = new Bucket[0];
                 continue;
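
The guard in the hunk above is the whole fix: a bucket ordinal handed down by the parent aggregation may lie beyond the range of categorizers that were actually built, and in that case the sub-aggregation should simply emit no category buckets. Below is a minimal standalone sketch of the idiom, using a plain Java list and a hypothetical Categorizer placeholder instead of the real TokenListCategorizer and the aggregation framework:

import java.util.List;

public class EmptyBucketGuardSketch {

    // Hypothetical stand-in for TokenListCategorizer; only the lookup pattern matters here.
    record Categorizer(String name) {}

    public static void main(String[] args) {
        // Categorizers exist only for ordinals 0 and 1 ...
        List<Categorizer> categorizers = List.of(new Categorizer("cat0"), new Categorizer("cat1"));

        // ... but the parent aggregation asks for ordinals 0..4, the last three being empty buckets
        // (for example, trailing histogram buckets created by extended bounds).
        long[] ordsToCollect = { 0, 1, 2, 3, 4 };

        for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) {
            final long ord = ordsToCollect[ordIdx];
            // The fix: check the ordinal against the number of categorizers before looking one up,
            // instead of reading past the end of the collected data.
            final Categorizer categorizer = (ord < categorizers.size()) ? categorizers.get((int) ord) : null;
            if (categorizer == null) {
                System.out.println("ord " + ord + ": empty bucket, no categories to build");
                continue;
            }
            System.out.println("ord " + ord + ": building category buckets from " + categorizer.name());
        }
    }
}

With the earlier code, ordinals 2..4 would have triggered an out-of-range lookup; with the guard they fall through to the empty-bucket branch. The new testCategorizationAsSubAggWithExtendedBounds test below exercises exactly this situation by configuring far more histogram buckets than there are documents.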
@@ -322,6 +322,90 @@ public void testCategorizationWithSubAggsManyDocs() throws Exception {
);
}

public void testCategorizationAsSubAggWithExtendedBounds() throws Exception {
// Test with more buckets than we have data for (via extended bounds in the histogram config).
// This will confirm that we don't try to read beyond the end of arrays of categorizers.
int numHistoBuckets = 50;
HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME)
.interval(1)
.extendedBounds(0, numHistoBuckets - 1)
.subAggregation(new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME));
testCase(CategorizeTextAggregatorTests::writeTestDocs, (InternalHistogram histo) -> {
assertThat(histo.getBuckets(), hasSize(numHistoBuckets));
assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(2L));
assertThat(histo.getBuckets().get(0).getKeyAsString(), equalTo("0.0"));
InternalCategorizationAggregation categorizationAggregation = histo.getBuckets().get(0).getAggregations().get("my_agg");
assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L));
assertThat(
categorizationAggregation.getBuckets().get(0).getKeyAsString(),
equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
);
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(84));
assertThat(
categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(),
equalTo(".*?Failed.+?to.+?shutdown.+?error.+?org\\.aaaa\\.bbbb\\.Cccc.+?line.+?caused.+?by.+?foo.+?exception.*?")
);
assertThat(categorizationAggregation.getBuckets().get(1).getDocCount(), equalTo(1L));
assertThat(categorizationAggregation.getBuckets().get(1).getKeyAsString(), equalTo("Node started"));
assertThat(categorizationAggregation.getBuckets().get(1).getSerializableCategory().maxMatchingStringLen(), equalTo(15));
assertThat(categorizationAggregation.getBuckets().get(1).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?"));
assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(1L));
assertThat(histo.getBuckets().get(1).getKeyAsString(), equalTo("1.0"));
categorizationAggregation = histo.getBuckets().get(1).getAggregations().get("my_agg");
assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L));
assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(15));
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?"));
assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(1L));
assertThat(histo.getBuckets().get(2).getKeyAsString(), equalTo("2.0"));
categorizationAggregation = histo.getBuckets().get(2).getAggregations().get("my_agg");
assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L));
assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(15));
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?"));
assertThat(histo.getBuckets().get(3).getDocCount(), equalTo(1L));
assertThat(histo.getBuckets().get(3).getKeyAsString(), equalTo("3.0"));
categorizationAggregation = histo.getBuckets().get(3).getAggregations().get("my_agg");
assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L));
assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(15));
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?"));
assertThat(histo.getBuckets().get(4).getDocCount(), equalTo(2L));
assertThat(histo.getBuckets().get(4).getKeyAsString(), equalTo("4.0"));
categorizationAggregation = histo.getBuckets().get(4).getAggregations().get("my_agg");
assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L));
assertThat(
categorizationAggregation.getBuckets().get(0).getKeyAsString(),
equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
);
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(84));
assertThat(
categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(),
equalTo(".*?Failed.+?to.+?shutdown.+?error.+?org\\.aaaa\\.bbbb\\.Cccc.+?line.+?caused.+?by.+?foo.+?exception.*?")
);
assertThat(categorizationAggregation.getBuckets().get(1).getDocCount(), equalTo(1L));
assertThat(categorizationAggregation.getBuckets().get(1).getKeyAsString(), equalTo("Node started"));
assertThat(categorizationAggregation.getBuckets().get(1).getSerializableCategory().maxMatchingStringLen(), equalTo(15));
assertThat(categorizationAggregation.getBuckets().get(1).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?"));
assertThat(histo.getBuckets().get(5).getDocCount(), equalTo(1L));
assertThat(histo.getBuckets().get(5).getKeyAsString(), equalTo("5.0"));
categorizationAggregation = histo.getBuckets().get(5).getAggregations().get("my_agg");
assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(1L));
assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().maxMatchingStringLen(), equalTo(15));
assertThat(categorizationAggregation.getBuckets().get(0).getSerializableCategory().getRegex(), equalTo(".*?Node.+?started.*?"));
for (int bucket = 6; bucket < numHistoBuckets; ++bucket) {
assertThat(histo.getBuckets().get(bucket).getDocCount(), equalTo(0L));
}
},
new AggTestConfig(
aggBuilder,
new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME, randomBoolean()),
longField(NUMERIC_FIELD_NAME)
)
);
}

private static void writeTestDocs(RandomIndexWriter w) throws IOException {
w.addDocument(
Arrays.asList(
