diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java index 4438ccec04c4c..22224628e23ad 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java @@ -76,7 +76,7 @@ public Max replaceChildren(List newChildren) { @Override protected TypeResolution resolveType() { return TypeResolutions.isType( - this, + field(), e -> e == DataType.BOOLEAN || e == DataType.DATETIME || e == DataType.IP || (e.isNumeric() && e != DataType.UNSIGNED_LONG), sourceText(), DEFAULT, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java index 490d227206e06..8e7bb6bc3e799 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java @@ -76,7 +76,7 @@ public Min replaceChildren(List newChildren) { @Override protected TypeResolution resolveType() { return TypeResolutions.isType( - this, + field(), e -> e == DataType.BOOLEAN || e == DataType.DATETIME || e == DataType.IP || (e.isNumeric() && e != DataType.UNSIGNED_LONG), sourceText(), DEFAULT, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Values.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Values.java index 79276b26be6d5..136e1233601f9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Values.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Values.java @@ -17,10 +17,10 @@ import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; @@ -30,6 +30,7 @@ import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG; public class Values extends AggregateFunction implements ToAggregator { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Values", Values::new); @@ -84,7 +85,13 @@ public DataType dataType() { @Override protected TypeResolution resolveType() { - return EsqlTypeResolutions.isNotSpatial(field(), sourceText(), DEFAULT); + return TypeResolutions.isType( + field(), + dt -> DataType.isSpatial(dt) == false && dt != UNSIGNED_LONG, + sourceText(), + DEFAULT, + "any type except unsigned_long and spatial types" + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 7333bd0e9f8a6..f0dd72e18ac2f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -1834,13 +1834,13 @@ public void testUnsupportedTypesInStats() { line 2:20: argument of [count_distinct(x)] must be [any exact type except unsigned_long, _source, or counter types],\ found value [x] type [unsigned_long] line 2:39: argument of [max(x)] must be [boolean, datetime, ip or numeric except unsigned_long or counter types],\ - found value [max(x)] type [unsigned_long] + found value [x] type [unsigned_long] line 2:47: argument of [median(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [unsigned_long] line 2:58: argument of [median_absolute_deviation(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [unsigned_long] line 2:88: argument of [min(x)] must be [boolean, datetime, ip or numeric except unsigned_long or counter types],\ - found value [min(x)] type [unsigned_long] + found value [x] type [unsigned_long] line 2:96: first argument of [percentile(x, 10)] must be [numeric except unsigned_long],\ found value [x] type [unsigned_long] line 2:115: argument of [sum(x)] must be [numeric except unsigned_long or counter types],\ @@ -1854,13 +1854,13 @@ public void testUnsupportedTypesInStats() { line 2:10: argument of [avg(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [version] line 2:18: argument of [max(x)] must be [boolean, datetime, ip or numeric except unsigned_long or counter types],\ - found value [max(x)] type [version] + found value [x] type [version] line 2:26: argument of [median(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [version] line 2:37: argument of [median_absolute_deviation(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [version] line 2:67: argument of [min(x)] must be [boolean, datetime, ip or numeric except unsigned_long or counter types],\ - found value [min(x)] type [version] + found value [x] type [version] line 2:75: first argument of [percentile(x, 10)] must be [numeric except unsigned_long], found value [x] type [version] line 2:94: argument of [sum(x)] must be [numeric except unsigned_long or counter types], found value [x] type [version]"""); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 08b1ef9f6fef6..49372da04d8c3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -494,7 +494,7 @@ public void testAggregateOnCounter() { equalTo( "1:20: argument of [min(network.bytes_in)] must be" + " [boolean, datetime, ip or numeric except unsigned_long or counter types]," - + " found value [min(network.bytes_in)] type [counter_long]" + + " found value [network.bytes_in] type [counter_long]" ) ); @@ -503,7 +503,7 @@ public void testAggregateOnCounter() { equalTo( "1:20: argument of [max(network.bytes_in)] must be" + " [boolean, datetime, ip or numeric except unsigned_long or counter types]," - + " found value [max(network.bytes_in)] type [counter_long]" + + " found value [network.bytes_in] type [counter_long]" ) ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index 25ff4f9c2122d..65425486ea4e0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -57,6 +57,25 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa * Use if possible, as this method may get updated with new checks in the future. *

*/ + protected static Iterable parameterSuppliersFromTypedDataWithDefaultChecks( + List suppliers, + boolean entirelyNullPreservesType, + PositionalErrorMessageSupplier positionalErrorMessageSupplier + ) { + return parameterSuppliersFromTypedData( + errorsForCasesWithoutExamples( + withNoRowsExpectingNull(anyNullIsNull(entirelyNullPreservesType, randomizeBytesRefsOffset(suppliers))), + positionalErrorMessageSupplier + ) + ); + } + + // TODO: Remove and migrate everything to the method with all the parameters + /** + * @deprecated Use {@link #parameterSuppliersFromTypedDataWithDefaultChecks(List, boolean, PositionalErrorMessageSupplier)} instead. + * This method doesn't add all the default checks. + */ + @Deprecated protected static Iterable parameterSuppliersFromTypedDataWithDefaultChecks(List suppliers) { return parameterSuppliersFromTypedData(withNoRowsExpectingNull(randomizeBytesRefsOffset(suppliers))); } @@ -119,24 +138,9 @@ public void testFold() { Expression expression = buildLiteralExpression(testCase); resolveExpression(expression, aggregatorFunctionSupplier -> { - // An aggregation cannot be folded - }, evaluableExpression -> { - assertTrue(evaluableExpression.foldable()); - if (testCase.foldingExceptionClass() == null) { - Object result = evaluableExpression.fold(); - // Decode unsigned longs into BigIntegers - if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) { - result = NumericUtils.unsignedLongAsBigInteger((Long) result); - } - assertThat(result, testCase.getMatcher()); - if (testCase.getExpectedWarnings() != null) { - assertWarnings(testCase.getExpectedWarnings()); - } - } else { - Throwable t = expectThrows(testCase.foldingExceptionClass(), evaluableExpression::fold); - assertThat(t.getMessage(), equalTo(testCase.foldingExceptionMessage())); - } - }); + // An aggregation cannot be folded. + // It's not an error either as not all aggregations are foldable. + }, this::evaluate); } private void aggregateSingleMode(Expression expression) { @@ -263,13 +267,19 @@ private void aggregateWithIntermediates(Expression expression) { } private void evaluate(Expression evaluableExpression) { - Object result; - try (var evaluator = evaluator(evaluableExpression).get(driverContext())) { - try (Block block = evaluator.eval(row(testCase.getDataValues()))) { - result = toJavaObjectUnsignedLongAware(block, 0); - } + assertTrue(evaluableExpression.foldable()); + + if (testCase.foldingExceptionClass() != null) { + Throwable t = expectThrows(testCase.foldingExceptionClass(), evaluableExpression::fold); + assertThat(t.getMessage(), equalTo(testCase.foldingExceptionMessage())); + return; } + Object result = evaluableExpression.fold(); + // Decode unsigned longs into BigIntegers + if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) { + result = NumericUtils.unsignedLongAsBigInteger((Long) result); + } assertThat(result, not(equalTo(Double.NaN))); assert testCase.getMatcher().matches(Double.POSITIVE_INFINITY) == false; assertThat(result, not(equalTo(Double.POSITIVE_INFINITY))); @@ -435,16 +445,23 @@ private IntBlock makeGroupsVector(int groupStart, int groupEnd, int rowCount) { */ private void processPageGrouping(GroupingAggregator aggregator, Page inputPage, int groupCount) { var groupSliceSize = 1; + var allValuesNull = IntStream.range(0, inputPage.getBlockCount()) + .mapToObj(inputPage::getBlock) + .anyMatch(Block::areAllValuesNull); // Add data to chunks of groups for (int currentGroupOffset = 0; currentGroupOffset < groupCount;) { - var seenGroupIds = new SeenGroupIds.Range(0, currentGroupOffset + groupSliceSize); + int groupSliceRemainingSize = Math.min(groupSliceSize, groupCount - currentGroupOffset); + var seenGroupIds = new SeenGroupIds.Range(0, allValuesNull ? 0 : currentGroupOffset + groupSliceRemainingSize); var addInput = aggregator.prepareProcessPage(seenGroupIds, inputPage); var positionCount = inputPage.getPositionCount(); var dataSliceSize = 1; // Divide data in chunks for (int currentDataOffset = 0; currentDataOffset < positionCount;) { - try (var groups = makeGroupsVector(currentGroupOffset, currentGroupOffset + groupSliceSize, dataSliceSize)) { + int dataSliceRemainingSize = Math.min(dataSliceSize, positionCount - currentDataOffset); + try ( + var groups = makeGroupsVector(currentGroupOffset, currentGroupOffset + groupSliceRemainingSize, dataSliceRemainingSize) + ) { addInput.add(currentDataOffset, groups); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index 20c583d3ac898..0c4bd6fe38b6a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; import org.elasticsearch.xpack.esql.core.session.Configuration; @@ -49,6 +50,8 @@ import org.elasticsearch.xpack.esql.core.util.NumericUtils; import org.elasticsearch.xpack.esql.core.util.StringUtils; import org.elasticsearch.xpack.esql.evaluator.EvalMapper; +import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; +import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; @@ -69,6 +72,7 @@ import org.elasticsearch.xpack.esql.planner.Layout; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.versionfield.Version; +import org.hamcrest.Matcher; import org.junit.After; import org.junit.AfterClass; @@ -95,6 +99,8 @@ import java.util.TreeSet; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; import static java.util.Map.entry; import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; @@ -106,6 +112,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.nullValue; /** * Base class for function tests. @@ -191,6 +198,318 @@ protected static Iterable parameterSuppliersFromTypedData(List + * Note: This won't add more than a single null to any existing test case, + * just to keep the number of test cases from exploding totally. + *

+ * + * @param entirelyNullPreservesType should a test case that only contains parameters + * with the {@code null} type keep it's expected type? + * This is mostly going to be {@code true} + * except for functions that base their type entirely + * on input types like {@link Greatest} or {@link Coalesce}. + */ + protected static List anyNullIsNull(boolean entirelyNullPreservesType, List testCaseSuppliers) { + return anyNullIsNull( + testCaseSuppliers, + (nullPosition, nullValueDataType, original) -> entirelyNullPreservesType == false + && nullValueDataType == DataType.NULL + && original.getData().size() == 1 ? DataType.NULL : original.expectedType(), + (nullPosition, nullData, original) -> original + ); + } + + public interface ExpectedType { + DataType expectedType(int nullPosition, DataType nullValueDataType, TestCaseSupplier.TestCase original); + } + + public interface ExpectedEvaluatorToString { + Matcher evaluatorToString(int nullPosition, TestCaseSupplier.TypedData nullData, Matcher original); + } + + protected static List anyNullIsNull( + List testCaseSuppliers, + ExpectedType expectedType, + ExpectedEvaluatorToString evaluatorToString + ) { + typesRequired(testCaseSuppliers); + List suppliers = new ArrayList<>(testCaseSuppliers.size()); + suppliers.addAll(testCaseSuppliers); + + /* + * For each original test case, add as many copies as there were + * arguments, replacing one of the arguments with null and keeping + * the others. + * + * Also, if this was the first time we saw the signature we copy it + * *again*, replacing the argument with null, but annotating the + * argument's type as `null` explicitly. + */ + Set> uniqueSignatures = new HashSet<>(); + for (TestCaseSupplier original : testCaseSuppliers) { + boolean firstTimeSeenSignature = uniqueSignatures.add(original.types()); + for (int nullPosition = 0; nullPosition < original.types().size(); nullPosition++) { + int finalNullPosition = nullPosition; + suppliers.add(new TestCaseSupplier(original.name() + " null in " + nullPosition, original.types(), () -> { + TestCaseSupplier.TestCase oc = original.get(); + List data = IntStream.range(0, oc.getData().size()).mapToObj(i -> { + TestCaseSupplier.TypedData od = oc.getData().get(i); + if (i != finalNullPosition) { + return od; + } + return od.withData(od.isMultiRow() ? Collections.singletonList(null) : null); + }).toList(); + TestCaseSupplier.TypedData nulledData = oc.getData().get(finalNullPosition); + return new TestCaseSupplier.TestCase( + data, + evaluatorToString.evaluatorToString(finalNullPosition, nulledData, oc.evaluatorToString()), + expectedType.expectedType(finalNullPosition, nulledData.type(), oc), + nullValue(), + null, + oc.getExpectedTypeError(), + null, + null + ); + })); + + if (firstTimeSeenSignature) { + List typesWithNull = IntStream.range(0, original.types().size()) + .mapToObj(i -> i == finalNullPosition ? DataType.NULL : original.types().get(i)) + .toList(); + boolean newSignature = uniqueSignatures.add(typesWithNull); + if (newSignature) { + suppliers.add(new TestCaseSupplier(typesWithNull, () -> { + TestCaseSupplier.TestCase oc = original.get(); + List data = IntStream.range(0, oc.getData().size()) + .mapToObj( + i -> i == finalNullPosition + ? (oc.getData().get(i).isMultiRow() + ? TestCaseSupplier.TypedData.MULTI_ROW_NULL + : TestCaseSupplier.TypedData.NULL) + : oc.getData().get(i) + ) + .toList(); + return new TestCaseSupplier.TestCase( + data, + equalTo("LiteralsEvaluator[lit=null]"), + expectedType.expectedType(finalNullPosition, DataType.NULL, oc), + nullValue(), + null, + oc.getExpectedTypeError(), + null, + null + ); + })); + } + } + } + } + + return suppliers; + } + + @FunctionalInterface + protected interface PositionalErrorMessageSupplier { + /** + * This interface defines functions to supply error messages for incorrect types in specific positions. Functions which have + * the same type requirements for all positions can simplify this with a lambda returning a string constant. + * + * @param validForPosition - the set of {@link DataType}s that the test infrastructure believes to be allowable in the + * given position. + * @param position - the zero-index position in the list of parameters the function has detected the bad argument to be. + * @return The string describing the acceptable parameters for that position. Note that this function should not return + * the full error string; that will be constructed by the test. Just return the type string for that position. + */ + String apply(Set validForPosition, int position); + } + + /** + * Adds test cases containing unsupported parameter types that assert + * that they throw type errors. + */ + protected static List errorsForCasesWithoutExamples( + List testCaseSuppliers, + PositionalErrorMessageSupplier positionalErrorMessageSupplier + ) { + return errorsForCasesWithoutExamples(testCaseSuppliers, (i, v, t) -> typeErrorMessage(i, v, t, positionalErrorMessageSupplier)); + } + + /** + * Build the expected error message for an invalid type signature. + */ + protected static String typeErrorMessage( + boolean includeOrdinal, + List> validPerPosition, + List types, + PositionalErrorMessageSupplier expectedTypeSupplier + ) { + int badArgPosition = -1; + for (int i = 0; i < types.size(); i++) { + if (validPerPosition.get(i).contains(types.get(i)) == false) { + badArgPosition = i; + break; + } + } + if (badArgPosition == -1) { + throw new IllegalStateException( + "Can't generate error message for these types, you probably need a custom error message function" + ); + } + String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(badArgPosition).name().toLowerCase(Locale.ROOT) + " " : ""; + String expectedTypeString = expectedTypeSupplier.apply(validPerPosition.get(badArgPosition), badArgPosition); + String name = types.get(badArgPosition).typeName(); + return ordinal + "argument of [] must be [" + expectedTypeString + "], found value [" + name + "] type [" + name + "]"; + } + + @FunctionalInterface + protected interface TypeErrorMessageSupplier { + String apply(boolean includeOrdinal, List> validPerPosition, List types); + } + + protected static List errorsForCasesWithoutExamples( + List testCaseSuppliers, + TypeErrorMessageSupplier typeErrorMessageSupplier + ) { + typesRequired(testCaseSuppliers); + List suppliers = new ArrayList<>(testCaseSuppliers.size()); + suppliers.addAll(testCaseSuppliers); + + Set> valid = testCaseSuppliers.stream().map(TestCaseSupplier::types).collect(Collectors.toSet()); + List> validPerPosition = validPerPosition(valid); + + testCaseSuppliers.stream() + .map(s -> s.types().size()) + .collect(Collectors.toSet()) + .stream() + .flatMap(count -> allPermutations(count)) + .filter(types -> valid.contains(types) == false) + /* + * Skip any cases with more than one null. Our tests don't generate + * the full combinatorial explosions of all nulls - just a single null. + * Hopefully , cases will function the same as , + * cases. + */.filter(types -> types.stream().filter(t -> t == DataType.NULL).count() <= 1) + .map(types -> typeErrorSupplier(validPerPosition.size() != 1, validPerPosition, types, typeErrorMessageSupplier)) + .forEach(suppliers::add); + return suppliers; + } + + private static List append(List orig, DataType extra) { + List longer = new ArrayList<>(orig.size() + 1); + longer.addAll(orig); + longer.add(extra); + return longer; + } + + protected static Stream representable() { + return DataType.types().stream().filter(DataType::isRepresentable); + } + + protected static TestCaseSupplier typeErrorSupplier( + boolean includeOrdinal, + List> validPerPosition, + List types, + PositionalErrorMessageSupplier errorMessageSupplier + ) { + return typeErrorSupplier(includeOrdinal, validPerPosition, types, (o, v, t) -> typeErrorMessage(o, v, t, errorMessageSupplier)); + } + + /** + * Build a test case that asserts that the combination of parameter types is an error. + */ + protected static TestCaseSupplier typeErrorSupplier( + boolean includeOrdinal, + List> validPerPosition, + List types, + TypeErrorMessageSupplier errorMessageSupplier + ) { + return new TestCaseSupplier( + "type error for " + TestCaseSupplier.nameFromTypes(types), + types, + () -> TestCaseSupplier.TestCase.typeError( + types.stream().map(type -> new TestCaseSupplier.TypedData(randomLiteral(type).value(), type, type.typeName())).toList(), + errorMessageSupplier.apply(includeOrdinal, validPerPosition, types) + ) + ); + } + + private static List> validPerPosition(Set> valid) { + int max = valid.stream().mapToInt(List::size).max().getAsInt(); + List> result = new ArrayList<>(max); + for (int i = 0; i < max; i++) { + result.add(new HashSet<>()); + } + for (List signature : valid) { + for (int i = 0; i < signature.size(); i++) { + result.get(i).add(signature.get(i)); + } + } + return result; + } + + protected static Stream> allPermutations(int argumentCount) { + if (argumentCount == 0) { + return Stream.of(List.of()); + } + if (argumentCount > 3) { + throw new IllegalArgumentException("would generate too many combinations"); + } + Stream> stream = validFunctionParameters().map(List::of); + for (int i = 1; i < argumentCount; i++) { + stream = stream.flatMap(types -> validFunctionParameters().map(t -> append(types, t))); + } + return stream; + } + + /** + * The types that are valid in function parameters. This is used by the + * function tests to enumerate all possible parameters to test error messages + * for invalid combinations. + */ + public static Stream validFunctionParameters() { + return Arrays.stream(DataType.values()).filter(t -> { + if (t == DataType.UNSUPPORTED) { + // By definition, functions never support UNSUPPORTED + return false; + } + if (t == DataType.DOC_DATA_TYPE || t == DataType.PARTIAL_AGG) { + /* + * Doc and partial_agg are special and functions aren't + * defined to take these. They'll use them implicitly if needed. + */ + return false; + } + if (t == DataType.OBJECT || t == DataType.NESTED) { + // Object and nested fields aren't supported by any functions yet + return false; + } + if (t == DataType.SOURCE || t == DataType.TSID_DATA_TYPE) { + // No functions take source or tsid fields yet. We'll make some eventually and remove this. + return false; + } + if (t == DataType.DATE_PERIOD || t == DataType.TIME_DURATION) { + // We don't test that functions don't take date_period or time_duration. We should. + return false; + } + if (t.isCounter()) { + /* + * For now, we're assuming no functions take counters + * as parameters. That's not true - some do. But we'll + * need to update the tests to handle that. + */ + return false; + } + if (t.widenSmallNumeric() != t) { + // Small numeric types are widened long before they arrive at functions. + return false; + } + + return true; + }).sorted(); + } + /** * Build an {@link Attribute} that loads a field. */ @@ -997,6 +1316,17 @@ protected static DataType[] strings() { return DataType.types().stream().filter(DataType::isString).toArray(DataType[]::new); } + /** + * Validate that we know the types for all the test cases already created + * @param suppliers - list of suppliers before adding in the illegal type combinations + */ + protected static void typesRequired(List suppliers) { + String bad = suppliers.stream().filter(s -> s.types() == null).map(s -> s.name()).collect(Collectors.joining("\n")); + if (bad.equals("") == false) { + throw new IllegalArgumentException("types required but not found for these tests:\n" + bad); + } + } + /** * Returns true if the current test case is for an aggregation function. *

diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java index 1caea78e79ad5..f4123af8abd0a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java @@ -22,21 +22,15 @@ import org.elasticsearch.indices.CrankyCircuitBreakerService; import org.elasticsearch.xpack.esql.TestBlockFactory; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.NumericUtils; -import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunctionTestCase; -import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.optimizer.FoldNull; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.hamcrest.Matcher; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; import java.util.List; -import java.util.Locale; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -44,7 +38,6 @@ import java.util.concurrent.Future; import java.util.stream.Collectors; import java.util.stream.IntStream; -import java.util.stream.Stream; import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; import static org.hamcrest.Matchers.either; @@ -372,152 +365,6 @@ public final void testFold() { } } - /** - * Adds cases with {@code null} and asserts that the result is {@code null}. - *

- * Note: This won't add more than a single null to any existing test case, - * just to keep the number of test cases from exploding totally. - *

- * - * @param entirelyNullPreservesType should a test case that only contains parameters - * with the {@code null} type keep it's expected type? - * This is mostly going to be {@code true} - * except for functions that base their type entirely - * on input types like {@link Greatest} or {@link Coalesce}. - */ - protected static List anyNullIsNull(boolean entirelyNullPreservesType, List testCaseSuppliers) { - return anyNullIsNull( - testCaseSuppliers, - (nullPosition, nullValueDataType, original) -> entirelyNullPreservesType == false - && nullValueDataType == DataType.NULL - && original.getData().size() == 1 ? DataType.NULL : original.expectedType(), - (nullPosition, nullData, original) -> original - ); - } - - public interface ExpectedType { - DataType expectedType(int nullPosition, DataType nullValueDataType, TestCaseSupplier.TestCase original); - } - - public interface ExpectedEvaluatorToString { - Matcher evaluatorToString(int nullPosition, TestCaseSupplier.TypedData nullData, Matcher original); - } - - protected static List anyNullIsNull( - List testCaseSuppliers, - ExpectedType expectedType, - ExpectedEvaluatorToString evaluatorToString - ) { - typesRequired(testCaseSuppliers); - List suppliers = new ArrayList<>(testCaseSuppliers.size()); - suppliers.addAll(testCaseSuppliers); - - /* - * For each original test case, add as many copies as there were - * arguments, replacing one of the arguments with null and keeping - * the others. - * - * Also, if this was the first time we saw the signature we copy it - * *again*, replacing the argument with null, but annotating the - * argument's type as `null` explicitly. - */ - Set> uniqueSignatures = new HashSet<>(); - for (TestCaseSupplier original : testCaseSuppliers) { - boolean firstTimeSeenSignature = uniqueSignatures.add(original.types()); - for (int nullPosition = 0; nullPosition < original.types().size(); nullPosition++) { - int finalNullPosition = nullPosition; - suppliers.add(new TestCaseSupplier(original.name() + " null in " + nullPosition, original.types(), () -> { - TestCaseSupplier.TestCase oc = original.get(); - List data = IntStream.range(0, oc.getData().size()).mapToObj(i -> { - TestCaseSupplier.TypedData od = oc.getData().get(i); - return i == finalNullPosition ? od.withData(null) : od; - }).toList(); - TestCaseSupplier.TypedData nulledData = oc.getData().get(finalNullPosition); - return new TestCaseSupplier.TestCase( - data, - evaluatorToString.evaluatorToString(finalNullPosition, nulledData, oc.evaluatorToString()), - expectedType.expectedType(finalNullPosition, nulledData.type(), oc), - nullValue(), - null, - oc.getExpectedTypeError(), - null, - null - ); - })); - - if (firstTimeSeenSignature) { - List typesWithNull = IntStream.range(0, original.types().size()) - .mapToObj(i -> i == finalNullPosition ? DataType.NULL : original.types().get(i)) - .toList(); - boolean newSignature = uniqueSignatures.add(typesWithNull); - if (newSignature) { - suppliers.add(new TestCaseSupplier(typesWithNull, () -> { - TestCaseSupplier.TestCase oc = original.get(); - List data = IntStream.range(0, oc.getData().size()) - .mapToObj(i -> i == finalNullPosition ? TestCaseSupplier.TypedData.NULL : oc.getData().get(i)) - .toList(); - return new TestCaseSupplier.TestCase( - data, - equalTo("LiteralsEvaluator[lit=null]"), - expectedType.expectedType(finalNullPosition, DataType.NULL, oc), - nullValue(), - null, - oc.getExpectedTypeError(), - null, - null - ); - })); - } - } - } - } - - return suppliers; - - } - - /** - * Adds test cases containing unsupported parameter types that assert - * that they throw type errors. - */ - protected static List errorsForCasesWithoutExamples( - List testCaseSuppliers, - PositionalErrorMessageSupplier positionalErrorMessageSupplier - ) { - return errorsForCasesWithoutExamples( - testCaseSuppliers, - (i, v, t) -> AbstractScalarFunctionTestCase.typeErrorMessage(i, v, t, positionalErrorMessageSupplier) - ); - } - - protected static List errorsForCasesWithoutExamples( - List testCaseSuppliers, - TypeErrorMessageSupplier typeErrorMessageSupplier - ) { - typesRequired(testCaseSuppliers); - List suppliers = new ArrayList<>(testCaseSuppliers.size()); - suppliers.addAll(testCaseSuppliers); - - Set> valid = testCaseSuppliers.stream().map(TestCaseSupplier::types).collect(Collectors.toSet()); - List> validPerPosition = validPerPosition(valid); - - testCaseSuppliers.stream() - .map(s -> s.types().size()) - .collect(Collectors.toSet()) - .stream() - .flatMap(count -> allPermutations(count)) - .filter(types -> valid.contains(types) == false) - /* - * Skip any cases with more than one null. Our tests don't generate - * the full combinatorial explosions of all nulls - just a single null. - * Hopefully , cases will function the same as , - * cases. - */.filter(types -> types.stream().filter(t -> t == DataType.NULL).count() <= 1) - .map(types -> typeErrorSupplier(validPerPosition.size() != 1, validPerPosition, types, typeErrorMessageSupplier)) - .forEach(suppliers::add); - return suppliers; - } - public static String errorMessageStringForBinaryOperators( boolean includeOrdinal, List> validPerPosition, @@ -572,178 +419,4 @@ protected static List failureForCasesWithoutExamples(List suppliers) { - String bad = suppliers.stream().filter(s -> s.types() == null).map(s -> s.name()).collect(Collectors.joining("\n")); - if (bad.equals("") == false) { - throw new IllegalArgumentException("types required but not found for these tests:\n" + bad); - } - } - - private static List> validPerPosition(Set> valid) { - int max = valid.stream().mapToInt(List::size).max().getAsInt(); - List> result = new ArrayList<>(max); - for (int i = 0; i < max; i++) { - result.add(new HashSet<>()); - } - for (List signature : valid) { - for (int i = 0; i < signature.size(); i++) { - result.get(i).add(signature.get(i)); - } - } - return result; - } - - private static Stream> allPermutations(int argumentCount) { - if (argumentCount == 0) { - return Stream.of(List.of()); - } - if (argumentCount > 3) { - throw new IllegalArgumentException("would generate too many combinations"); - } - Stream> stream = validFunctionParameters().map(List::of); - for (int i = 1; i < argumentCount; i++) { - stream = stream.flatMap(types -> validFunctionParameters().map(t -> append(types, t))); - } - return stream; - } - - private static List append(List orig, DataType extra) { - List longer = new ArrayList<>(orig.size() + 1); - longer.addAll(orig); - longer.add(extra); - return longer; - } - - @FunctionalInterface - protected interface TypeErrorMessageSupplier { - String apply(boolean includeOrdinal, List> validPerPosition, List types); - } - - @FunctionalInterface - protected interface PositionalErrorMessageSupplier { - /** - * This interface defines functions to supply error messages for incorrect types in specific positions. Functions which have - * the same type requirements for all positions can simplify this with a lambda returning a string constant. - * - * @param validForPosition - the set of {@link DataType}s that the test infrastructure believes to be allowable in the - * given position. - * @param position - the zero-index position in the list of parameters the function has detected the bad argument to be. - * @return The string describing the acceptable parameters for that position. Note that this function should not return - * the full error string; that will be constructed by the test. Just return the type string for that position. - */ - String apply(Set validForPosition, int position); - } - - protected static TestCaseSupplier typeErrorSupplier( - boolean includeOrdinal, - List> validPerPosition, - List types, - PositionalErrorMessageSupplier errorMessageSupplier - ) { - return typeErrorSupplier( - includeOrdinal, - validPerPosition, - types, - (o, v, t) -> AbstractScalarFunctionTestCase.typeErrorMessage(o, v, t, errorMessageSupplier) - ); - } - - /** - * Build a test case that asserts that the combination of parameter types is an error. - */ - protected static TestCaseSupplier typeErrorSupplier( - boolean includeOrdinal, - List> validPerPosition, - List types, - TypeErrorMessageSupplier errorMessageSupplier - ) { - return new TestCaseSupplier( - "type error for " + TestCaseSupplier.nameFromTypes(types), - types, - () -> TestCaseSupplier.TestCase.typeError( - types.stream().map(type -> new TestCaseSupplier.TypedData(randomLiteral(type).value(), type, type.typeName())).toList(), - errorMessageSupplier.apply(includeOrdinal, validPerPosition, types) - ) - ); - } - - /** - * Build the expected error message for an invalid type signature. - */ - protected static String typeErrorMessage( - boolean includeOrdinal, - List> validPerPosition, - List types, - PositionalErrorMessageSupplier expectedTypeSupplier - ) { - int badArgPosition = -1; - for (int i = 0; i < types.size(); i++) { - if (validPerPosition.get(i).contains(types.get(i)) == false) { - badArgPosition = i; - break; - } - } - if (badArgPosition == -1) { - throw new IllegalStateException( - "Can't generate error message for these types, you probably need a custom error message function" - ); - } - String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(badArgPosition).name().toLowerCase(Locale.ROOT) + " " : ""; - String expectedTypeString = expectedTypeSupplier.apply(validPerPosition.get(badArgPosition), badArgPosition); - String name = types.get(badArgPosition).typeName(); - return ordinal + "argument of [] must be [" + expectedTypeString + "], found value [" + name + "] type [" + name + "]"; - } - - /** - * The types that are valid in function parameters. This is used by the - * function tests to enumerate all possible parameters to test error messages - * for invalid combinations. - */ - public static Stream validFunctionParameters() { - return Arrays.stream(DataType.values()).filter(t -> { - if (t == DataType.UNSUPPORTED) { - // By definition, functions never support UNSUPPORTED - return false; - } - if (t == DataType.DOC_DATA_TYPE || t == DataType.PARTIAL_AGG) { - /* - * Doc and partial_agg are special and functions aren't - * defined to take these. They'll use them implicitly if needed. - */ - return false; - } - if (t == DataType.OBJECT || t == DataType.NESTED) { - // Object and nested fields aren't supported by any functions yet - return false; - } - if (t == DataType.SOURCE || t == DataType.TSID_DATA_TYPE) { - // No functions take source or tsid fields yet. We'll make some eventually and remove this. - return false; - } - if (t == DataType.DATE_PERIOD || t == DataType.TIME_DURATION) { - // We don't test that functions don't take date_period or time_duration. We should. - return false; - } - if (t.isCounter()) { - /* - * For now, we're assuming no functions take counters - * as parameters. That's not true - some do. But we'll - * need to update the tests to handle that. - */ - return false; - } - if (t.widenSmallNumeric() != t) { - // Small numeric types are widened long before they arrive at functions. - return false; - } - - return true; - }).sorted(); - } - } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index 3585e58bf97ab..6652cca0c4527 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -30,6 +30,7 @@ import java.time.Period; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.function.BiFunction; import java.util.function.BinaryOperator; @@ -1455,6 +1456,7 @@ public TypedData get() { */ public static class TypedData { public static final TypedData NULL = new TypedData(null, DataType.NULL, ""); + public static final TypedData MULTI_ROW_NULL = TypedData.multiRow(Collections.singletonList(null), DataType.NULL, ""); private final Object data; private final DataType type; @@ -1583,7 +1585,7 @@ public Literal asLiteral() { throw new IllegalStateException("Multirow values require exactly 1 element to be a literal, got " + values.size()); } - return new Literal(Source.synthetic(name), values, type); + return new Literal(Source.synthetic(name), values.get(0), type); } return new Literal(Source.synthetic(name), data, type); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java index f456bd409059a..80737dac1aa58 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java @@ -53,7 +53,7 @@ public static Iterable parameters() { ) ); - return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers, true, (v, p) -> "numeric except unsigned_long or counter types"); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java index 1d489e0146ad3..52e908a51dd1e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java @@ -49,73 +49,6 @@ public static Iterable parameters() { suppliers.addAll( List.of( - // Surrogates - new TestCaseSupplier( - List.of(DataType.INTEGER), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field")), - "Max[field=Attribute[channel=0]]", - DataType.INTEGER, - equalTo(200) - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field")), - "Max[field=Attribute[channel=0]]", - DataType.LONG, - equalTo(200L) - ) - ), - new TestCaseSupplier( - List.of(DataType.DOUBLE), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field")), - "Max[field=Attribute[channel=0]]", - DataType.DOUBLE, - equalTo(200.) - ) - ), - new TestCaseSupplier( - List.of(DataType.DATETIME), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field")), - "Max[field=Attribute[channel=0]]", - DataType.DATETIME, - equalTo(200L) - ) - ), - new TestCaseSupplier( - List.of(DataType.BOOLEAN), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(true, false, false, true), DataType.BOOLEAN, "field")), - "Max[field=Attribute[channel=0]]", - DataType.BOOLEAN, - equalTo(true) - ) - ), - new TestCaseSupplier( - List.of(DataType.IP), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow( - List.of( - new BytesRef(InetAddressPoint.encode(InetAddresses.forString("127.0.0.1"))), - new BytesRef(InetAddressPoint.encode(InetAddresses.forString("::1"))), - new BytesRef(InetAddressPoint.encode(InetAddresses.forString("::"))), - new BytesRef(InetAddressPoint.encode(InetAddresses.forString("ffff::"))) - ), - DataType.IP, - "field" - ) - ), - "Max[field=Attribute[channel=0]]", - DataType.IP, - equalTo(new BytesRef(InetAddressPoint.encode(InetAddresses.forString("ffff::")))) - ) - ), - // Folding new TestCaseSupplier( List.of(DataType.INTEGER), @@ -180,7 +113,11 @@ public static Iterable parameters() { ) ); - return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + return parameterSuppliersFromTypedDataWithDefaultChecks( + suppliers, + false, + (v, p) -> "boolean, datetime, ip or numeric except unsigned_long or counter types" + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java index b5fb5b2c1c414..9514c817df497 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java @@ -49,73 +49,6 @@ public static Iterable parameters() { suppliers.addAll( List.of( - // Surrogates - new TestCaseSupplier( - List.of(DataType.INTEGER), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field")), - "Min[field=Attribute[channel=0]]", - DataType.INTEGER, - equalTo(-2) - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field")), - "Min[field=Attribute[channel=0]]", - DataType.LONG, - equalTo(-2L) - ) - ), - new TestCaseSupplier( - List.of(DataType.DOUBLE), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field")), - "Min[field=Attribute[channel=0]]", - DataType.DOUBLE, - equalTo(-2.) - ) - ), - new TestCaseSupplier( - List.of(DataType.DATETIME), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field")), - "Min[field=Attribute[channel=0]]", - DataType.DATETIME, - equalTo(0L) - ) - ), - new TestCaseSupplier( - List.of(DataType.BOOLEAN), - () -> new TestCaseSupplier.TestCase( - List.of(TestCaseSupplier.TypedData.multiRow(List.of(true, false, false, true), DataType.BOOLEAN, "field")), - "Min[field=Attribute[channel=0]]", - DataType.BOOLEAN, - equalTo(false) - ) - ), - new TestCaseSupplier( - List.of(DataType.IP), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow( - List.of( - new BytesRef(InetAddressPoint.encode(InetAddresses.forString("127.0.0.1"))), - new BytesRef(InetAddressPoint.encode(InetAddresses.forString("::1"))), - new BytesRef(InetAddressPoint.encode(InetAddresses.forString("::"))), - new BytesRef(InetAddressPoint.encode(InetAddresses.forString("ffff::"))) - ), - DataType.IP, - "field" - ) - ), - "Min[field=Attribute[channel=0]]", - DataType.IP, - equalTo(new BytesRef(InetAddressPoint.encode(InetAddresses.forString("::")))) - ) - ), - // Folding new TestCaseSupplier( List.of(DataType.INTEGER), @@ -180,7 +113,11 @@ public static Iterable parameters() { ) ); - return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + return parameterSuppliersFromTypedDataWithDefaultChecks( + suppliers, + false, + (v, p) -> "boolean, datetime, ip or numeric except unsigned_long or counter types" + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesTests.java index 23b70b94d0d7f..55320543d0ec3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesTests.java @@ -53,7 +53,11 @@ public static Iterable parameters() { MultiRowTestCaseSupplier.stringCases(1, 20, DataType.TEXT) ).flatMap(List::stream).map(ValuesTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); - return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + return parameterSuppliersFromTypedDataWithDefaultChecks( + suppliers, + false, + (v, p) -> "any type except unsigned_long and spatial types" + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java index 2ba091437f237..2c2ffc97f268c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java @@ -52,11 +52,11 @@ public static Iterable parameters() { List.of( // Folding new TestCaseSupplier( - List.of(DataType.INTEGER), + List.of(DataType.INTEGER, DataType.INTEGER), () -> new TestCaseSupplier.TestCase( List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5), DataType.INTEGER, "field"), - TestCaseSupplier.TypedData.multiRow(List.of(100), DataType.INTEGER, "field") + TestCaseSupplier.TypedData.multiRow(List.of(5), DataType.INTEGER, "number"), + TestCaseSupplier.TypedData.multiRow(List.of(100), DataType.INTEGER, "weight") ), "WeightedAvg[number=Attribute[channel=0],weight=Attribute[channel=1]]", DataType.DOUBLE, @@ -64,11 +64,11 @@ public static Iterable parameters() { ) ), new TestCaseSupplier( - List.of(DataType.LONG), + List.of(DataType.LONG, DataType.INTEGER), () -> new TestCaseSupplier.TestCase( List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L), DataType.LONG, "field"), - TestCaseSupplier.TypedData.multiRow(List.of(100), DataType.INTEGER, "field") + TestCaseSupplier.TypedData.multiRow(List.of(5L), DataType.LONG, "number"), + TestCaseSupplier.TypedData.multiRow(List.of(100), DataType.INTEGER, "weight") ), "WeightedAvg[number=Attribute[channel=0],weight=Attribute[channel=1]]", DataType.DOUBLE, @@ -76,11 +76,11 @@ public static Iterable parameters() { ) ), new TestCaseSupplier( - List.of(DataType.DOUBLE), + List.of(DataType.DOUBLE, DataType.INTEGER), () -> new TestCaseSupplier.TestCase( List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5.), DataType.DOUBLE, "field"), - TestCaseSupplier.TypedData.multiRow(List.of(100), DataType.INTEGER, "field") + TestCaseSupplier.TypedData.multiRow(List.of(5.), DataType.DOUBLE, "number"), + TestCaseSupplier.TypedData.multiRow(List.of(100), DataType.INTEGER, "weight") ), "WeightedAvg[number=Attribute[channel=0],weight=Attribute[channel=1]]", DataType.DOUBLE,