From ff3cab4219b021868b1774a36263e53f583d3764 Mon Sep 17 00:00:00 2001 From: Jan Kuipers <148754765+jan-elastic@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:07:26 +0100 Subject: [PATCH] Disallow ES|QL CATEGORIZE in aggregation filters (#118319) (#118330) --- .../xpack/esql/analysis/Verifier.java | 20 ++++++++++++++++--- .../xpack/esql/analysis/VerifierTests.java | 20 +++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index ecfe1aa7f9169..a0728c9a91088 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -382,6 +382,18 @@ private static void checkCategorizeGrouping(Aggregate agg, Set failures ); } }))); + agg.aggregates().forEach(a -> a.forEachDown(FilteredExpression.class, fe -> fe.filter().forEachDown(Attribute.class, attribute -> { + var categorize = categorizeByAttribute.get(attribute); + if (categorize != null) { + failures.add( + fail( + attribute, + "cannot reference CATEGORIZE grouping function [{}] within an aggregation filter", + attribute.sourceText() + ) + ); + } + }))); } private static void checkRateAggregates(Expression expr, int nestedLevel, Set failures) { @@ -421,7 +433,8 @@ private static void checkInvalidNamedExpressionUsage( Expression filter = fe.filter(); failures.add(fail(filter, "WHERE clause allowed only for aggregate functions, none found in [{}]", fe.sourceText())); } - Expression f = fe.filter(); // check the filter has to be a boolean term, similar as checkFilterConditionType + Expression f = fe.filter(); + // check the filter has to be a boolean term, similar as checkFilterConditionType if (f.dataType() != NULL && f.dataType() != BOOLEAN) { failures.add(fail(f, "Condition expression needs to be boolean, found [{}]", f.dataType())); } @@ -432,9 +445,10 @@ private static void checkInvalidNamedExpressionUsage( fail(af, "cannot use aggregate function [{}] in aggregate WHERE clause [{}]", af.sourceText(), fe.sourceText()) ); } - // check the bucketing function against the group + // check the grouping function against the group else if (c instanceof GroupingFunction gf) { - if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) { + if (c instanceof Categorize + || Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) { failures.add(fail(gf, "can only use grouping function [{}] as part of the BY clause", gf.sourceText())); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 92cac30f1bb20..d58d233168e2b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1968,6 +1968,26 @@ public void testCategorizeWithinAggregations() { ); } + public void testCategorizeWithFilteredAggregations() { + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); + + query("FROM test | STATS COUNT(*) WHERE first_name == \"John\" BY CATEGORIZE(last_name)"); + query("FROM test | STATS COUNT(*) WHERE last_name == \"Doe\" BY CATEGORIZE(last_name)"); + + assertEquals( + "1:34: can only use grouping function [CATEGORIZE(first_name)] as part of the BY clause", + error("FROM test | STATS COUNT(*) WHERE CATEGORIZE(first_name) == \"John\" BY CATEGORIZE(last_name)") + ); + assertEquals( + "1:34: can only use grouping function [CATEGORIZE(last_name)] as part of the BY clause", + error("FROM test | STATS COUNT(*) WHERE CATEGORIZE(last_name) == \"Doe\" BY CATEGORIZE(last_name)") + ); + assertEquals( + "1:34: cannot reference CATEGORIZE grouping function [category] within an aggregation filter", + error("FROM test | STATS COUNT(*) WHERE category == \"Doe\" BY category = CATEGORIZE(last_name)") + ); + } + public void testSortByAggregate() { assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT count(*)")); assertEquals("1:28: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT to_string(count(*))"));