Skip to content

Commit

Permalink
Add test cases for nulls and wrong types to aggregation tests (elasti…
Browse files Browse the repository at this point in the history
…c#111482)

- Migrated the anyNullIsNull and wrong types cases to `AbstractFunctionTestCase`
- Minor fixes to anyNullIsNull so that it works with multi-row values: a few conditions now return a List of nulls instead of a single null. Everything else in these functions was left mostly untouched
- Applied these checks to some aggregations
- Fixed some errors around the aggregation tests code

Not all aggregations were migrated. Many of them have edge cases that don't work with some of these checks.
For example, if `WEIGHTED_AVG(value, weight)` is given a literal as the value, it ignores the weight, which makes anyNullIsNull fail because it expects a null result.
Such cases can be handled later.

Closes elastic#109917
  • Loading branch information
ivancea authored Aug 1, 2024
1 parent dfbedb2 commit c5da257
Show file tree
Hide file tree
Showing 14 changed files with 417 additions and 510 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public Max replaceChildren(List<Expression> newChildren) {
@Override
protected TypeResolution resolveType() {
return TypeResolutions.isType(
this,
field(),
e -> e == DataType.BOOLEAN || e == DataType.DATETIME || e == DataType.IP || (e.isNumeric() && e != DataType.UNSIGNED_LONG),
sourceText(),
DEFAULT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public Min replaceChildren(List<Expression> newChildren) {
@Override
protected TypeResolution resolveType() {
return TypeResolutions.isType(
this,
field(),
e -> e == DataType.BOOLEAN || e == DataType.DATETIME || e == DataType.IP || (e.isNumeric() && e != DataType.UNSIGNED_LONG),
sourceText(),
DEFAULT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
Expand All @@ -30,6 +30,7 @@
import java.util.List;

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG;

public class Values extends AggregateFunction implements ToAggregator {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Values", Values::new);
Expand Down Expand Up @@ -84,7 +85,13 @@ public DataType dataType() {

@Override
protected TypeResolution resolveType() {
return EsqlTypeResolutions.isNotSpatial(field(), sourceText(), DEFAULT);
return TypeResolutions.isType(
field(),
dt -> DataType.isSpatial(dt) == false && dt != UNSIGNED_LONG,
sourceText(),
DEFAULT,
"any type except unsigned_long and spatial types"
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1834,13 +1834,13 @@ public void testUnsupportedTypesInStats() {
line 2:20: argument of [count_distinct(x)] must be [any exact type except unsigned_long, _source, or counter types],\
found value [x] type [unsigned_long]
line 2:39: argument of [max(x)] must be [boolean, datetime, ip or numeric except unsigned_long or counter types],\
found value [max(x)] type [unsigned_long]
found value [x] type [unsigned_long]
line 2:47: argument of [median(x)] must be [numeric except unsigned_long or counter types],\
found value [x] type [unsigned_long]
line 2:58: argument of [median_absolute_deviation(x)] must be [numeric except unsigned_long or counter types],\
found value [x] type [unsigned_long]
line 2:88: argument of [min(x)] must be [boolean, datetime, ip or numeric except unsigned_long or counter types],\
found value [min(x)] type [unsigned_long]
found value [x] type [unsigned_long]
line 2:96: first argument of [percentile(x, 10)] must be [numeric except unsigned_long],\
found value [x] type [unsigned_long]
line 2:115: argument of [sum(x)] must be [numeric except unsigned_long or counter types],\
Expand All @@ -1854,13 +1854,13 @@ public void testUnsupportedTypesInStats() {
line 2:10: argument of [avg(x)] must be [numeric except unsigned_long or counter types],\
found value [x] type [version]
line 2:18: argument of [max(x)] must be [boolean, datetime, ip or numeric except unsigned_long or counter types],\
found value [max(x)] type [version]
found value [x] type [version]
line 2:26: argument of [median(x)] must be [numeric except unsigned_long or counter types],\
found value [x] type [version]
line 2:37: argument of [median_absolute_deviation(x)] must be [numeric except unsigned_long or counter types],\
found value [x] type [version]
line 2:67: argument of [min(x)] must be [boolean, datetime, ip or numeric except unsigned_long or counter types],\
found value [min(x)] type [version]
found value [x] type [version]
line 2:75: first argument of [percentile(x, 10)] must be [numeric except unsigned_long], found value [x] type [version]
line 2:94: argument of [sum(x)] must be [numeric except unsigned_long or counter types], found value [x] type [version]""");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ public void testAggregateOnCounter() {
equalTo(
"1:20: argument of [min(network.bytes_in)] must be"
+ " [boolean, datetime, ip or numeric except unsigned_long or counter types],"
+ " found value [min(network.bytes_in)] type [counter_long]"
+ " found value [network.bytes_in] type [counter_long]"
)
);

Expand All @@ -503,7 +503,7 @@ public void testAggregateOnCounter() {
equalTo(
"1:20: argument of [max(network.bytes_in)] must be"
+ " [boolean, datetime, ip or numeric except unsigned_long or counter types],"
+ " found value [max(network.bytes_in)] type [counter_long]"
+ " found value [network.bytes_in] type [counter_long]"
)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,25 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa
* Use if possible, as this method may get updated with new checks in the future.
* </p>
*/
protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultChecks(
    List<TestCaseSupplier> suppliers,
    boolean entirelyNullPreservesType,
    PositionalErrorMessageSupplier positionalErrorMessageSupplier
) {
    // Apply the helpers from the innermost outwards, in the same order as a single nested call would.
    var withRandomOffsets = randomizeBytesRefsOffset(suppliers);
    var withNullChecks = anyNullIsNull(entirelyNullPreservesType, withRandomOffsets);
    var withNoRowsChecks = withNoRowsExpectingNull(withNullChecks);
    var withErrorCases = errorsForCasesWithoutExamples(withNoRowsChecks, positionalErrorMessageSupplier);
    return parameterSuppliersFromTypedData(withErrorCases);
}

// TODO: Remove and migrate everything to the method with all the parameters
/**
 * Reduced variant that only passes the suppliers through {@code randomizeBytesRefsOffset} and
 * {@code withNoRowsExpectingNull} before converting them to parameters.
 *
 * @deprecated Use {@link #parameterSuppliersFromTypedDataWithDefaultChecks(List, boolean, PositionalErrorMessageSupplier)} instead.
 *             This method doesn't add all the default checks.
 */
@Deprecated
protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultChecks(List<TestCaseSupplier> suppliers) {
    var withRandomOffsets = randomizeBytesRefsOffset(suppliers);
    var withNoRowsChecks = withNoRowsExpectingNull(withRandomOffsets);
    return parameterSuppliersFromTypedData(withNoRowsChecks);
}
Expand Down Expand Up @@ -119,24 +138,9 @@ public void testFold() {
Expression expression = buildLiteralExpression(testCase);

resolveExpression(expression, aggregatorFunctionSupplier -> {
// An aggregation cannot be folded
}, evaluableExpression -> {
assertTrue(evaluableExpression.foldable());
if (testCase.foldingExceptionClass() == null) {
Object result = evaluableExpression.fold();
// Decode unsigned longs into BigIntegers
if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) {
result = NumericUtils.unsignedLongAsBigInteger((Long) result);
}
assertThat(result, testCase.getMatcher());
if (testCase.getExpectedWarnings() != null) {
assertWarnings(testCase.getExpectedWarnings());
}
} else {
Throwable t = expectThrows(testCase.foldingExceptionClass(), evaluableExpression::fold);
assertThat(t.getMessage(), equalTo(testCase.foldingExceptionMessage()));
}
});
// An aggregation cannot be folded.
// It's not an error either as not all aggregations are foldable.
}, this::evaluate);
}

private void aggregateSingleMode(Expression expression) {
Expand Down Expand Up @@ -263,13 +267,19 @@ private void aggregateWithIntermediates(Expression expression) {
}

private void evaluate(Expression evaluableExpression) {
Object result;
try (var evaluator = evaluator(evaluableExpression).get(driverContext())) {
try (Block block = evaluator.eval(row(testCase.getDataValues()))) {
result = toJavaObjectUnsignedLongAware(block, 0);
}
assertTrue(evaluableExpression.foldable());

if (testCase.foldingExceptionClass() != null) {
Throwable t = expectThrows(testCase.foldingExceptionClass(), evaluableExpression::fold);
assertThat(t.getMessage(), equalTo(testCase.foldingExceptionMessage()));
return;
}

Object result = evaluableExpression.fold();
// Decode unsigned longs into BigIntegers
if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) {
result = NumericUtils.unsignedLongAsBigInteger((Long) result);
}
assertThat(result, not(equalTo(Double.NaN)));
assert testCase.getMatcher().matches(Double.POSITIVE_INFINITY) == false;
assertThat(result, not(equalTo(Double.POSITIVE_INFINITY)));
Expand Down Expand Up @@ -435,16 +445,23 @@ private IntBlock makeGroupsVector(int groupStart, int groupEnd, int rowCount) {
*/
private void processPageGrouping(GroupingAggregator aggregator, Page inputPage, int groupCount) {
var groupSliceSize = 1;
var allValuesNull = IntStream.range(0, inputPage.getBlockCount())
.<Block>mapToObj(inputPage::getBlock)
.anyMatch(Block::areAllValuesNull);
// Add data to chunks of groups
for (int currentGroupOffset = 0; currentGroupOffset < groupCount;) {
var seenGroupIds = new SeenGroupIds.Range(0, currentGroupOffset + groupSliceSize);
int groupSliceRemainingSize = Math.min(groupSliceSize, groupCount - currentGroupOffset);
var seenGroupIds = new SeenGroupIds.Range(0, allValuesNull ? 0 : currentGroupOffset + groupSliceRemainingSize);
var addInput = aggregator.prepareProcessPage(seenGroupIds, inputPage);

var positionCount = inputPage.getPositionCount();
var dataSliceSize = 1;
// Divide data in chunks
for (int currentDataOffset = 0; currentDataOffset < positionCount;) {
try (var groups = makeGroupsVector(currentGroupOffset, currentGroupOffset + groupSliceSize, dataSliceSize)) {
int dataSliceRemainingSize = Math.min(dataSliceSize, positionCount - currentDataOffset);
try (
var groups = makeGroupsVector(currentGroupOffset, currentGroupOffset + groupSliceRemainingSize, dataSliceRemainingSize)
) {
addInput.add(currentDataOffset, groups);
}

Expand Down
Loading

0 comments on commit c5da257

Please sign in to comment.