From 03a71d2deee7bb2788fc40b8d21d90cc75b787e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 3 Dec 2024 14:47:40 +0100 Subject: [PATCH] ESQL: Make Categorize usable in aggs when identical to a grouping (#117835) Cases like `STATS MV_APPEND(cat, CATEGORIZE(x)) BY cat=CATEGORIZE(x)` should work, as they're moved to an EVAL by a rule. Also, these cases were discarded, as they fail because of other verifications (which also fail for BUCKET): ``` STATS x = category BY category=CATEGORIZE(message) STATS x = CATEGORIZE(message) BY CATEGORIZE(message) STATS x = CATEGORIZE(message) BY category=CATEGORIZE(message) ``` --- .../src/main/resources/bucket.csv-spec | 21 +++ .../src/main/resources/categorize.csv-spec | 121 ++++++++++++------ .../src/main/resources/docs.csv-spec | 2 +- .../xpack/esql/action/EsqlCapabilities.java | 2 +- .../xpack/esql/analysis/Verifier.java | 39 +++--- ...ReplaceAggregateAggExpressionWithEval.java | 16 +++ ...laceAggregateNestedExpressionWithEval.java | 6 +- .../xpack/esql/analysis/VerifierTests.java | 34 +++-- .../optimizer/LogicalPlanOptimizerTests.java | 4 +- 9 files changed, 167 insertions(+), 78 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec index 7bbf011176693..b29c489910f65 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec @@ -503,6 +503,27 @@ FROM employees //end::reuseGroupingFunctionWithExpression-result[] ; +reuseGroupingFunctionImplicitAliasWithExpression#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| STATS s1 = `BUCKET(salary / 100 + 99, 50.)` + 1, s2 = BUCKET(salary / 1000 + 999, 50.) + 2 BY BUCKET(salary / 100 + 99, 50.), b2 = BUCKET(salary / 1000 + 999, 50.) 
+| SORT `BUCKET(salary / 100 + 99, 50.)`, b2 +| KEEP s1, `BUCKET(salary / 100 + 99, 50.)`, s2, b2 +; + + s1:double | BUCKET(salary / 100 + 99, 50.):double | s2:double | b2:double +351.0 |350.0 |1002.0 |1000.0 +401.0 |400.0 |1002.0 |1000.0 +451.0 |450.0 |1002.0 |1000.0 +501.0 |500.0 |1002.0 |1000.0 +551.0 |550.0 |1002.0 |1000.0 +601.0 |600.0 |1002.0 |1000.0 +601.0 |600.0 |1052.0 |1050.0 +651.0 |650.0 |1052.0 |1050.0 +701.0 |700.0 |1052.0 |1050.0 +751.0 |750.0 |1052.0 |1050.0 +801.0 |800.0 |1052.0 |1050.0 +; + reuseGroupingFunctionWithinAggs#[skip:-8.13.99, reason:BUCKET renamed in 8.14] FROM employees | STATS sum = 1 + MAX(1 + BUCKET(salary, 1000.)) BY BUCKET(salary, 1000.) + 1 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec index e45b10d1aa122..804c1c56a1eb5 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec @@ -1,5 +1,5 @@ standard aggs -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS count=COUNT(), @@ -17,7 +17,7 @@ count:long | sum:long | avg:double | count_distinct:long | category:keyw ; values aggs -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS values=MV_SORT(VALUES(message)), @@ -33,7 +33,7 @@ values:keyword | top ; mv -required_capability: categorize_v4 +required_capability: categorize_v5 FROM mv_sample_data | STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message) @@ -48,7 +48,7 @@ COUNT():long | SUM(event_duration):long | category:keyword ; row mv -required_capability: categorize_v4 +required_capability: categorize_v5 ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"] | STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message) @@ -61,7 +61,7 @@ COUNT():long | VALUES(str):keyword | 
category:keyword ; skips stopwords -required_capability: categorize_v4 +required_capability: categorize_v5 ROW message = ["Mon Tue connected to a", "Jul Aug connected to b September ", "UTC connected GMT to c UTC"] | STATS COUNT() BY category=CATEGORIZE(message) @@ -73,7 +73,7 @@ COUNT():long | category:keyword ; with multiple indices -required_capability: categorize_v4 +required_capability: categorize_v5 required_capability: union_types FROM sample_data* @@ -88,7 +88,7 @@ COUNT():long | category:keyword ; mv with many values -required_capability: categorize_v4 +required_capability: categorize_v5 FROM employees | STATS COUNT() BY category=CATEGORIZE(job_positions) @@ -105,7 +105,7 @@ COUNT():long | category:keyword ; mv with many values and SUM -required_capability: categorize_v4 +required_capability: categorize_v5 FROM employees | STATS SUM(languages) BY category=CATEGORIZE(job_positions) @@ -120,7 +120,7 @@ SUM(languages):long | category:keyword ; mv with many values and nulls and SUM -required_capability: categorize_v4 +required_capability: categorize_v5 FROM employees | STATS SUM(languages) BY category=CATEGORIZE(job_positions) @@ -134,7 +134,7 @@ SUM(languages):long | category:keyword ; mv via eval -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | EVAL message = MV_APPEND(message, "Banana") @@ -150,7 +150,7 @@ COUNT():long | category:keyword ; mv via eval const -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | EVAL message = ["Banana", "Bread"] @@ -164,7 +164,7 @@ COUNT():long | category:keyword ; mv via eval const without aliases -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | EVAL message = ["Banana", "Bread"] @@ -178,7 +178,7 @@ COUNT():long | CATEGORIZE(message):keyword ; mv const in parameter -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY c = CATEGORIZE(["Banana", 
"Bread"]) @@ -191,7 +191,7 @@ COUNT():long | c:keyword ; agg alias shadowing -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"]) @@ -206,7 +206,7 @@ c:keyword ; chained aggregations using categorize -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(message) @@ -221,7 +221,7 @@ COUNT():long | category:keyword ; stats without aggs -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS BY category=CATEGORIZE(message) @@ -235,7 +235,7 @@ category:keyword ; text field -required_capability: categorize_v4 +required_capability: categorize_v5 FROM hosts | STATS COUNT() BY category=CATEGORIZE(host_group) @@ -253,7 +253,7 @@ COUNT():long | category:keyword ; on TO_UPPER -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message)) @@ -267,7 +267,7 @@ COUNT():long | category:keyword ; on CONCAT -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana")) @@ -281,7 +281,7 @@ COUNT():long | category:keyword ; on CONCAT with unicode -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊")) @@ -295,7 +295,7 @@ COUNT():long | category:keyword ; on REVERSE(CONCAT()) -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊"))) @@ -309,7 +309,7 @@ COUNT():long | category:keyword ; and then TO_LOWER -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(message) @@ -324,7 +324,7 @@ COUNT():long | category:keyword ; on const empty 
string -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY category=CATEGORIZE("") @@ -336,7 +336,7 @@ COUNT():long | category:keyword ; on const empty string from eval -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | EVAL x = "" @@ -349,7 +349,7 @@ COUNT():long | category:keyword ; on null -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | EVAL x = null @@ -362,7 +362,7 @@ COUNT():long | SUM(event_duration):long | category:keyword ; on null string -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | EVAL x = null::string @@ -375,7 +375,7 @@ COUNT():long | category:keyword ; filtering out all data -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | WHERE @timestamp < "2023-10-23T00:00:00Z" @@ -387,7 +387,7 @@ COUNT():long | category:keyword ; filtering out all data with constant -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(message) @@ -398,7 +398,7 @@ COUNT():long | category:keyword ; drop output columns -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS count=COUNT() BY category=CATEGORIZE(message) @@ -413,7 +413,7 @@ x:integer ; category value processing -required_capability: categorize_v4 +required_capability: categorize_v5 ROW message = ["connected to a", "connected to b", "disconnected"] | STATS COUNT() BY category=CATEGORIZE(message) @@ -427,7 +427,7 @@ COUNT():long | category:keyword ; row aliases -required_capability: categorize_v4 +required_capability: categorize_v5 ROW message = "connected to xyz" | EVAL x = message @@ -441,7 +441,7 @@ COUNT():long | category:keyword | y:keyword ; from aliases -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | EVAL x = message @@ -457,7 +457,7 
@@ COUNT():long | category:keyword | y:keyword ; row aliases with keep -required_capability: categorize_v4 +required_capability: categorize_v5 ROW message = "connected to xyz" | EVAL x = message @@ -473,7 +473,7 @@ COUNT():long | y:keyword ; from aliases with keep -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | EVAL x = message @@ -491,7 +491,7 @@ COUNT():long | y:keyword ; row rename -required_capability: categorize_v4 +required_capability: categorize_v5 ROW message = "connected to xyz" | RENAME message as x @@ -505,7 +505,7 @@ COUNT():long | y:keyword ; from rename -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | RENAME message as x @@ -521,7 +521,7 @@ COUNT():long | y:keyword ; row drop -required_capability: categorize_v4 +required_capability: categorize_v5 ROW message = "connected to a" | STATS c = COUNT() BY category=CATEGORIZE(message) @@ -534,7 +534,7 @@ c:long ; from drop -required_capability: categorize_v4 +required_capability: categorize_v5 FROM sample_data | STATS c = COUNT() BY category=CATEGORIZE(message) @@ -547,3 +547,48 @@ c:long 3 3 ; + +categorize in aggs inside function +required_capability: categorize_v5 + +FROM sample_data + | STATS COUNT(), x = MV_APPEND(category, category) BY category=CATEGORIZE(message) + | SORT x + | KEEP `COUNT()`, x +; + +COUNT():long | x:keyword + 3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?] + 3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?] + 1 | [.*?Disconnected.*?,.*?Disconnected.*?] +; + +categorize in aggs same as grouping inside function +required_capability: categorize_v5 + +FROM sample_data + | STATS COUNT(), x = MV_APPEND(CATEGORIZE(message), `CATEGORIZE(message)`) BY CATEGORIZE(message) + | SORT x + | KEEP `COUNT()`, x +; + +COUNT():long | x:keyword + 3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?] + 3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?] + 1 | [.*?Disconnected.*?,.*?Disconnected.*?] 
+; + +categorize in aggs same as grouping inside function with explicit alias +required_capability: categorize_v5 + +FROM sample_data + | STATS COUNT(), x = MV_APPEND(CATEGORIZE(message), category) BY category=CATEGORIZE(message) + | SORT x + | KEEP `COUNT()`, x +; + +COUNT():long | x:keyword + 3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?] + 3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?] + 1 | [.*?Disconnected.*?,.*?Disconnected.*?] +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec index 24baf1263d06a..aa89c775da4cf 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec @@ -678,7 +678,7 @@ Bangalore | 9 | 72 ; docsCategorize -required_capability: categorize_v4 +required_capability: categorize_v5 // tag::docsCategorize[] FROM sample_data | STATS count=COUNT() BY category=CATEGORIZE(message) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 646c4f8240c3e..b5d6dd8584e8c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -407,7 +407,7 @@ public enum Cap { /** * Supported the text categorization function "CATEGORIZE". 
*/ - CATEGORIZE_V4, + CATEGORIZE_V5, /** * QSTR function diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index 5f8c011cff53a..49d8a5ee8caad 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; -import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; @@ -63,12 +62,10 @@ import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -364,35 +361,35 @@ private static void checkCategorizeGrouping(Aggregate agg, Set failures ); }); - // Forbid CATEGORIZE being used in the aggregations - agg.aggregates().forEach(a -> { - a.forEachDown( - Categorize.class, - categorize -> failures.add( - fail(categorize, "cannot use CATEGORIZE grouping function [{}] within the aggregations", categorize.sourceText()) + // Forbid CATEGORIZE being used in the aggregations, unless it appears as a grouping + agg.aggregates() + .forEach( + a -> a.forEachDown( + AggregateFunction.class, + aggregateFunction -> aggregateFunction.forEachDown( + Categorize.class, + categorize -> failures.add( + fail(categorize, "cannot use CATEGORIZE 
grouping function [{}] within an aggregation", categorize.sourceText()) + ) + ) ) ); - }); - // Forbid CATEGORIZE being referenced in the aggregation functions - Map categorizeByAliasId = new HashMap<>(); + // Forbid CATEGORIZE being referenced as a child of an aggregation function + AttributeMap categorizeByAttribute = new AttributeMap<>(); agg.groupings().forEach(g -> { g.forEachDown(Alias.class, alias -> { if (alias.child() instanceof Categorize categorize) { - categorizeByAliasId.put(alias.id(), categorize); + categorizeByAttribute.put(alias.toAttribute(), categorize); } }); }); agg.aggregates() .forEach(a -> a.forEachDown(AggregateFunction.class, aggregate -> aggregate.forEachDown(Attribute.class, attribute -> { - var categorize = categorizeByAliasId.get(attribute.id()); + var categorize = categorizeByAttribute.get(attribute); if (categorize != null) { failures.add( - fail( - attribute, - "cannot reference CATEGORIZE grouping function [{}] within the aggregations", - attribute.sourceText() - ) + fail(attribute, "cannot reference CATEGORIZE grouping function [{}] within an aggregation", attribute.sourceText()) ); } }))); @@ -449,7 +446,7 @@ private static void checkInvalidNamedExpressionUsage( // check the bucketing function against the group else if (c instanceof GroupingFunction gf) { if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) { - failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText())); + failures.add(fail(gf, "can only use grouping function [{}] as part of the BY clause", gf.sourceText())); } } }); @@ -466,7 +463,7 @@ else if (c instanceof GroupingFunction gf) { // optimizer will later unroll expressions with aggs and non-aggs with a grouping function into an EVAL, but that will no longer // be verified (by check above in checkAggregate()), so do it explicitly here if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && 
a.child().semanticEquals(gf)) == false) { - failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText())); + failures.add(fail(gf, "can only use grouping function [{}] as part of the BY clause", gf.sourceText())); } else if (level == 0) { addFailureOnGroupingUsedNakedInAggs(failures, gf, "function"); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java index 2361b46b2be6f..c36d4caf7f599 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java @@ -9,18 +9,21 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -51,6 +54,16 @@ protected LogicalPlan rule(Aggregate aggregate) { AttributeMap 
aliases = new AttributeMap<>(); aggregate.forEachExpressionUp(Alias.class, a -> aliases.put(a.toAttribute(), a.child())); + // Build Categorize grouping functions map. + // Functions like BUCKET() shouldn't reach this point, + // as they are moved to an early EVAL by ReplaceAggregateNestedExpressionWithEval + Map groupingAttributes = new HashMap<>(); + aggregate.forEachExpressionUp(Alias.class, a -> { + if (a.child() instanceof Categorize groupingFunction) { + groupingAttributes.put(groupingFunction, a.toAttribute()); + } + }); + // break down each aggregate into AggregateFunction and/or grouping key // preserve the projection at the end List aggs = aggregate.aggregates(); @@ -109,6 +122,9 @@ protected LogicalPlan rule(Aggregate aggregate) { return alias.toAttribute(); }); + // replace grouping functions with their references + aggExpression = aggExpression.transformUp(Categorize.class, groupingAttributes::get); + Alias alias = as.replaceChild(aggExpression); newEvals.add(alias); newProjections.add(alias.toAttribute()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java index 985e68252a1f9..4dbc43454a023 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java @@ -51,6 +51,7 @@ protected LogicalPlan rule(Aggregate aggregate) { // Exception: Categorize is internal to the aggregation and remains in the groupings. We move its child expression into an eval. 
if (g instanceof Alias as) { if (as.child() instanceof Categorize cat) { + // For Categorize grouping function, we only move the child expression into an eval if (cat.field() instanceof Attribute == false) { groupingChanged = true; var fieldAs = new Alias(as.source(), as.name(), cat.field(), null, true); @@ -59,7 +60,6 @@ protected LogicalPlan rule(Aggregate aggregate) { evalNames.put(fieldAs.name(), fieldAttr); Categorize replacement = cat.replaceChildren(List.of(fieldAttr)); newGroupings.set(i, as.replaceChild(replacement)); - groupingAttributes.put(cat, fieldAttr); } } else { groupingChanged = true; @@ -135,6 +135,10 @@ protected LogicalPlan rule(Aggregate aggregate) { }); // replace any grouping functions with their references pointing to the added synthetic eval replaced = replaced.transformDown(GroupingFunction.class, gf -> { + // Categorize in aggs depends on the grouping result, not on an early eval + if (gf instanceof Categorize) { + return gf; + } aggsChanged.set(true); // should never return null, as it's verified. // but even if broken, the transform will fail safely; otoh, returning `gf` will fail later due to incorrect plan. 
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index d02e78202e0c2..74e2de1141728 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -407,12 +407,12 @@ public void testAggFilterOnBucketingOrAggFunctions() { // but fails if it's different assertEquals( - "1:32: can only use grouping function [bucket(a, 3)] part of the BY clause", + "1:32: can only use grouping function [bucket(a, 3)] as part of the BY clause", error("row a = 1 | stats sum(a) where bucket(a, 3) > -1 by bucket(a,2)") ); assertEquals( - "1:40: can only use grouping function [bucket(salary, 10)] part of the BY clause", + "1:40: can only use grouping function [bucket(salary, 10)] as part of the BY clause", error("from test | stats max(languages) WHERE bucket(salary, 10) > 1 by emp_no") ); @@ -444,19 +444,19 @@ public void testAggWithNonBooleanFilter() { public void testGroupingInsideAggsAsAgg() { assertEquals( - "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause", + "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause", error("from test| stats bucket(emp_no, 5.) by emp_no") ); assertEquals( - "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause", + "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause", error("from test| stats bucket(emp_no, 5.)") ); assertEquals( - "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause", + "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause", error("from test| stats bucket(emp_no, 5.) 
by bucket(emp_no, 6.)") ); assertEquals( - "1:22: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause", + "1:22: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause", error("from test| stats 3 + bucket(emp_no, 5.) by bucket(emp_no, 6.)") ); } @@ -1846,7 +1846,7 @@ public void testIntervalAsString() { } public void testCategorizeSingleGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)"); query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)"); @@ -1875,7 +1875,7 @@ public void testCategorizeSingleGrouping() { } public void testCategorizeNestedGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)"); @@ -1890,27 +1890,33 @@ public void testCategorizeNestedGrouping() { } public void testCategorizeWithinAggregations() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)"); + query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY cat = CATEGORIZE(first_name)"); + query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY CATEGORIZE(first_name)"); assertEquals( - "1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within the aggregations", + "1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within an aggregation", error("FROM test | STATS COUNT(CATEGORIZE(first_name)) BY CATEGORIZE(first_name)") 
); - assertEquals( - "1:25: cannot reference CATEGORIZE grouping function [cat] within the aggregations", + "1:25: cannot reference CATEGORIZE grouping function [cat] within an aggregation", error("FROM test | STATS COUNT(cat) BY cat = CATEGORIZE(first_name)") ); assertEquals( - "1:30: cannot reference CATEGORIZE grouping function [cat] within the aggregations", + "1:30: cannot reference CATEGORIZE grouping function [cat] within an aggregation", error("FROM test | STATS SUM(LENGTH(cat::keyword) + LENGTH(last_name)) BY cat = CATEGORIZE(first_name)") ); assertEquals( - "1:25: cannot reference CATEGORIZE grouping function [`CATEGORIZE(first_name)`] within the aggregations", + "1:25: cannot reference CATEGORIZE grouping function [`CATEGORIZE(first_name)`] within an aggregation", error("FROM test | STATS COUNT(`CATEGORIZE(first_name)`) BY CATEGORIZE(first_name)") ); + + assertEquals( + "1:28: can only use grouping function [CATEGORIZE(last_name)] as part of the BY clause", + error("FROM test | STATS MV_COUNT(CATEGORIZE(last_name)) BY CATEGORIZE(first_name)") + ); } public void testSortByAggregate() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index a74efca3b3d99..b76781f76f4af 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -1212,7 +1212,7 @@ public void testCombineProjectionWithAggregationFirstAndAliasedGroupingUsedInAgg * \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..] 
*/ public void testCombineProjectionWithCategorizeGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); var plan = plan(""" from test @@ -3949,7 +3949,7 @@ public void testNestedExpressionsInGroups() { * \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..] */ public void testNestedExpressionsInGroupsWithCategorize() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); var plan = optimizedPlan(""" from test