diff --git a/.gitignore b/.gitignore index fd5449b9fc3b6..2151e666ea209 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,4 @@ testfixtures_shared/ # Generated checkstyle_ide.xml +x-pack/plugin/esql/gen/ diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index 1d268bb6f61b9..af08db3bdfd23 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -30,12 +30,11 @@ import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR; import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR_BUILDER; import static org.elasticsearch.compute.gen.Types.BLOCK; -import static org.elasticsearch.compute.gen.Types.DOUBLE_ARRAY_VECTOR; import static org.elasticsearch.compute.gen.Types.DOUBLE_BLOCK; import static org.elasticsearch.compute.gen.Types.DOUBLE_VECTOR; import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE; import static org.elasticsearch.compute.gen.Types.INT_BLOCK; -import static org.elasticsearch.compute.gen.Types.LONG_ARRAY_VECTOR; +import static org.elasticsearch.compute.gen.Types.INT_VECTOR; import static org.elasticsearch.compute.gen.Types.LONG_BLOCK; import static org.elasticsearch.compute.gen.Types.LONG_VECTOR; import static org.elasticsearch.compute.gen.Types.PAGE; @@ -91,30 +90,39 @@ private TypeName choseStateType() { return ClassName.get("org.elasticsearch.compute.aggregation", firstUpper(initReturn.toString()) + "State"); } - private String primitiveType() { - String initReturn = declarationType.toString().toLowerCase(Locale.ROOT); - if (initReturn.contains("double")) { - return "double"; - } else if (initReturn.contains("long")) { - return "long"; - } else { - throw new IllegalArgumentException("unknown primitive type for " + initReturn); + static String primitiveType(ExecutableElement init, ExecutableElement combine) { + if (combine != null) { + // If there's an explicit combine function it's final parameter is the type of the value. + return combine.getParameters().get(combine.getParameters().size() - 1).asType().toString(); + } + String initReturn = init.getReturnType().toString(); + switch (initReturn) { + case "double": + return "double"; + case "long": + return "long"; + case "int": + return "int"; + default: + throw new IllegalArgumentException("unknown primitive type for " + initReturn); } } - private ClassName valueBlockType() { - return switch (primitiveType()) { + static ClassName valueBlockType(ExecutableElement init, ExecutableElement combine) { + return switch (primitiveType(init, combine)) { case "double" -> DOUBLE_BLOCK; case "long" -> LONG_BLOCK; - default -> throw new IllegalArgumentException("unknown block type for " + primitiveType()); + case "int" -> INT_BLOCK; + default -> throw new IllegalArgumentException("unknown block type for " + primitiveType(init, combine)); }; } - private ClassName valueVectorType() { - return switch (primitiveType()) { + static ClassName valueVectorType(ExecutableElement init, ExecutableElement combine) { + return switch (primitiveType(init, combine)) { case "double" -> DOUBLE_VECTOR; case "long" -> LONG_VECTOR; - default -> throw new IllegalArgumentException("unknown vector type for " + primitiveType()); + case "int" -> INT_VECTOR; + default -> throw new IllegalArgumentException("unknown vector type for " + primitiveType(init, combine)); }; } @@ -187,15 +195,8 @@ private MethodSpec addRawInput() { builder.addStatement("assert channel >= 0"); builder.addStatement("$T type = page.getBlock(channel).elementType()", ELEMENT_TYPE); builder.beginControlFlow("if (type == $T.NULL)", ELEMENT_TYPE).addStatement("return").endControlFlow(); - if (primitiveType().equals("double")) { - builder.addStatement("$T block = page.getBlock(channel)", valueBlockType()); - } else { // long - builder.addStatement("$T block", valueBlockType()); - builder.beginControlFlow("if (type == $T.INT)", ELEMENT_TYPE) // explicit cast, for now - .addStatement("block = page.<$T>getBlock(channel).asLongBlock()", INT_BLOCK); - builder.nextControlFlow("else").addStatement("block = page.getBlock(channel)").endControlFlow(); - } - builder.addStatement("$T vector = block.asVector()", valueVectorType()); + builder.addStatement("$T block = page.getBlock(channel)", valueBlockType(init, combine)); + builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine)); builder.beginControlFlow("if (vector != null)").addStatement("addRawVector(vector)"); builder.nextControlFlow("else").addStatement("addRawBlock(block)").endControlFlow(); return builder.build(); @@ -203,7 +204,7 @@ private MethodSpec addRawInput() { private MethodSpec addRawVector() { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawVector"); - builder.addModifiers(Modifier.PRIVATE).addParameter(valueVectorType(), "vector"); + builder.addModifiers(Modifier.PRIVATE).addParameter(valueVectorType(init, combine), "vector"); builder.beginControlFlow("for (int i = 0; i < vector.getPositionCount(); i++)"); { combineRawInput(builder, "vector"); @@ -217,7 +218,7 @@ private MethodSpec addRawVector() { private MethodSpec addRawBlock() { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawBlock"); - builder.addModifiers(Modifier.PRIVATE).addParameter(valueBlockType(), "block"); + builder.addModifiers(Modifier.PRIVATE).addParameter(valueBlockType(init, combine), "block"); builder.beginControlFlow("for (int i = 0; i < block.getTotalValueCount(); i++)"); { builder.beginControlFlow("if (block.isNull(i) == false)"); @@ -296,6 +297,8 @@ private void combineStates(MethodSpec.Builder builder) { private String primitiveStateMethod() { switch (stateType.toString()) { + case "org.elasticsearch.compute.aggregation.IntState": + return "intValue"; case "org.elasticsearch.compute.aggregation.LongState": return "longValue"; case "org.elasticsearch.compute.aggregation.DoubleState": @@ -339,11 +342,14 @@ private MethodSpec evaluateFinal() { private void primitiveStateToResult(MethodSpec.Builder builder) { switch (stateType.toString()) { + case "org.elasticsearch.compute.aggregation.IntState": + builder.addStatement("return $T.newConstantBlockWith(state.intValue(), 1)", INT_BLOCK); + return; case "org.elasticsearch.compute.aggregation.LongState": - builder.addStatement("return new $T(new long[] { state.longValue() }, 1).asBlock()", LONG_ARRAY_VECTOR); + builder.addStatement("return $T.newConstantBlockWith(state.longValue(), 1)", LONG_BLOCK); return; case "org.elasticsearch.compute.aggregation.DoubleState": - builder.addStatement("return new $T(new double[] { state.doubleValue() }, 1).asBlock()", DOUBLE_ARRAY_VECTOR); + builder.addStatement("return $T.newConstantBlockWith(state.doubleValue(), 1)", DOUBLE_BLOCK); return; default: throw new IllegalArgumentException("don't know how to convert state to result: " + stateType); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 7cb29a3cb347d..41952d5ee6c56 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -24,14 +24,14 @@ import javax.lang.model.element.TypeElement; import javax.lang.model.util.Elements; +import static org.elasticsearch.compute.gen.AggregatorImplementer.valueBlockType; +import static org.elasticsearch.compute.gen.AggregatorImplementer.valueVectorType; import static org.elasticsearch.compute.gen.Methods.findMethod; import static org.elasticsearch.compute.gen.Methods.findRequiredMethod; import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR; import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR_BUILDER; import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS; import static org.elasticsearch.compute.gen.Types.BLOCK; -import static org.elasticsearch.compute.gen.Types.DOUBLE_BLOCK; -import static org.elasticsearch.compute.gen.Types.DOUBLE_VECTOR; import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_FUNCTION; import static org.elasticsearch.compute.gen.Types.LONG_BLOCK; import static org.elasticsearch.compute.gen.Types.LONG_VECTOR; @@ -88,33 +88,6 @@ private TypeName choseStateType() { return ClassName.get("org.elasticsearch.compute.aggregation", head + tail + "ArrayState"); } - private String primitiveType() { - String initReturn = declarationType.toString().toLowerCase(Locale.ROOT); - if (initReturn.contains("double")) { - return "double"; - } else if (initReturn.contains("long")) { - return "long"; - } else { - throw new IllegalArgumentException("unknown primitive type for " + initReturn); - } - } - - private ClassName valueBlockType() { - return switch (primitiveType()) { - case "double" -> DOUBLE_BLOCK; - case "long" -> LONG_BLOCK; - default -> throw new IllegalArgumentException("unknown block type for " + primitiveType()); - }; - } - - private ClassName valueVectorType() { - return switch (primitiveType()) { - case "double" -> DOUBLE_VECTOR; - case "long" -> LONG_VECTOR; - default -> throw new IllegalArgumentException("unknown vector type for " + primitiveType()); - }; - } - public JavaFile sourceFile() { JavaFile.Builder builder = JavaFile.builder(implementation.packageName(), type()); builder.addFileComment(""" @@ -179,8 +152,8 @@ private MethodSpec addRawInputVector() { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); builder.addParameter(LONG_VECTOR, "groups").addParameter(PAGE, "page"); - builder.addStatement("$T valuesBlock = page.getBlock(channel)", valueBlockType()); - builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType()); + builder.addStatement("$T valuesBlock = page.getBlock(channel)", valueBlockType(init, combine)); + builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine)); builder.beginControlFlow("if (valuesVector != null)"); { builder.addStatement("int positions = groups.getPositionCount()"); @@ -203,7 +176,7 @@ private MethodSpec addRawInputVector() { private MethodSpec addRawInputWithBlockValues() { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInputWithBlockValues"); builder.addModifiers(Modifier.PRIVATE); - builder.addParameter(LONG_VECTOR, "groups").addParameter(valueBlockType(), "valuesBlock"); + builder.addParameter(LONG_VECTOR, "groups").addParameter(valueBlockType(init, combine), "valuesBlock"); builder.addStatement("int positions = groups.getPositionCount()"); builder.beginControlFlow("for (int position = 0; position < groups.getPositionCount(); position++)"); { @@ -227,8 +200,8 @@ private MethodSpec addRawInputBlock() { builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); builder.addParameter(LONG_BLOCK, "groups").addParameter(PAGE, "page"); builder.addStatement("assert channel >= 0"); - builder.addStatement("$T valuesBlock = page.getBlock(channel)", valueBlockType()); - builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType()); + builder.addStatement("$T valuesBlock = page.getBlock(channel)", valueBlockType(init, combine)); + builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine)); builder.addStatement("int positions = groups.getPositionCount()"); builder.beginControlFlow("if (valuesVector != null)"); { diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java index b611051687c90..55bc08f915ab6 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java @@ -32,10 +32,9 @@ public class Types { static final ClassName AGGREGATOR_STATE_VECTOR = ClassName.get(DATA_PACKAGE, "AggregatorStateVector"); static final ClassName AGGREGATOR_STATE_VECTOR_BUILDER = ClassName.get(DATA_PACKAGE, "AggregatorStateVector", "Builder"); + static final ClassName INT_VECTOR = ClassName.get(DATA_PACKAGE, "IntVector"); static final ClassName LONG_VECTOR = ClassName.get(DATA_PACKAGE, "LongVector"); - static final ClassName LONG_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "LongArrayVector"); static final ClassName DOUBLE_VECTOR = ClassName.get(DATA_PACKAGE, "DoubleVector"); - static final ClassName DOUBLE_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "DoubleArrayVector"); static final ClassName AGGREGATOR_FUNCTION = ClassName.get(AGGREGATION_PACKAGE, "AggregatorFunction"); static final ClassName GROUPING_AGGREGATOR_FUNCTION = ClassName.get(AGGREGATION_PACKAGE, "GroupingAggregatorFunction"); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FilterIntBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FilterIntBlock.java index 72456e046fa79..60d9ec70a329f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FilterIntBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FilterIntBlock.java @@ -30,11 +30,6 @@ public int getInt(int valueIndex) { return block.getInt(mapPosition(valueIndex)); } - @Override - public LongBlock asLongBlock() { - return new FilterLongBlock(block.asLongBlock(), positions); - } - @Override public ElementType elementType() { return ElementType.INT; diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java index 3301eaf4ec72d..73d5ca9c26710 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java @@ -50,16 +50,6 @@ public ElementType elementType() { return ElementType.INT; } - @Override - public LongBlock asLongBlock() { // copy rather than view, for now - final int positions = getPositionCount(); - long[] longValues = new long[positions]; - for (int i = 0; i < positions; i++) { - longValues[i] = values[i]; - } - return new LongArrayBlock(longValues, getPositionCount(), firstValueIndexes, nullsMask); - } - @Override public boolean equals(Object obj) { if (obj instanceof IntBlock that) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java index 24ea23d9e35a7..4e34da7d1e46c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java @@ -33,8 +33,6 @@ public sealed interface IntBlock extends Block permits FilterIntBlock,IntArrayBl @Override IntBlock filter(int... positions); - LongBlock asLongBlock(); - /** * Compares the given object with this block for equality. Returns {@code true} if and only if the * given object is a IntBlock, and both blocks are {@link #equals(IntBlock, IntBlock) equal}. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVectorBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVectorBlock.java index 9d4033b7a84ab..4c9d5e883705c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVectorBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVectorBlock.java @@ -40,15 +40,6 @@ public ElementType elementType() { return vector.elementType(); } - public LongBlock asLongBlock() { // copy rather than view, for now - final int positions = getPositionCount(); - long[] longValues = new long[positions]; - for (int i = 0; i < positions; i++) { - longValues[i] = vector.getInt(i); - } - return new LongArrayVector(longValues, getPositionCount()).asBlock(); - } - @Override public IntBlock getRow(int position) { return filter(position); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgIntAggregatorFunction.java new file mode 100644 index 0000000000000..cb57376797ca2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgIntAggregatorFunction.java @@ -0,0 +1,104 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link AggregatorFunction} implementation for {@link AvgIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class AvgIntAggregatorFunction implements AggregatorFunction { + private final AvgLongAggregator.AvgState state; + + private final int channel; + + public AvgIntAggregatorFunction(int channel, AvgLongAggregator.AvgState state) { + this.channel = channel; + this.state = state; + } + + public static AvgIntAggregatorFunction create(int channel) { + return new AvgIntAggregatorFunction(channel, AvgIntAggregator.initSingle()); + } + + @Override + public void addRawInput(Page page) { + assert channel >= 0; + ElementType type = page.getBlock(channel).elementType(); + if (type == ElementType.NULL) { + return; + } + IntBlock block = page.getBlock(channel); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + AvgIntAggregator.combine(state, vector.getInt(i)); + } + AvgIntAggregator.combineValueCount(state, vector.getPositionCount()); + } + + private void addRawBlock(IntBlock block) { + for (int i = 0; i < block.getTotalValueCount(); i++) { + if (block.isNull(i) == false) { + AvgIntAggregator.combine(state, block.getInt(i)); + } + } + AvgIntAggregator.combineValueCount(state, block.validPositionCount()); + } + + @Override + public void addIntermediateInput(Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + AvgLongAggregator.AvgState tmpState = new AvgLongAggregator.AvgState(); + for (int i = 0; i < block.getPositionCount(); i++) { + blobVector.get(i, tmpState); + AvgIntAggregator.combineStates(state, tmpState); + } + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, AvgLongAggregator.AvgState> builder = + AggregatorStateVector.builderOfAggregatorState(AvgLongAggregator.AvgState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return AvgIntAggregator.evaluateFinal(state); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..a0ebd4cd10833 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgIntGroupingAggregatorFunction.java @@ -0,0 +1,147 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link AvgIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class AvgIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private final AvgLongAggregator.GroupingAvgState state; + + private final int channel; + + public AvgIntGroupingAggregatorFunction(int channel, AvgLongAggregator.GroupingAvgState state) { + this.channel = channel; + this.state = state; + } + + public static AvgIntGroupingAggregatorFunction create(BigArrays bigArrays, int channel) { + return new AvgIntGroupingAggregatorFunction(channel, AvgIntAggregator.initGrouping(bigArrays)); + } + + @Override + public void addRawInput(LongVector groups, Page page) { + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector != null) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + AvgIntAggregator.combine(state, groupId, valuesVector.getInt(position)); + } + } else { + // move the cold branch out of this method to keep the optimized case vector/vector as small as possible + addRawInputWithBlockValues(groups, valuesBlock); + } + } + + private void addRawInputWithBlockValues(LongVector groups, IntBlock valuesBlock) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + AvgIntAggregator.combine(state, groupId, valuesBlock.getInt(position)); + } + } + } + + @Override + public void addRawInput(LongBlock groups, Page page) { + assert channel >= 0; + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + int positions = groups.getPositionCount(); + if (valuesVector != null) { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position) == false) { + int groupId = Math.toIntExact(groups.getLong(position)); + AvgIntAggregator.combine(state, groupId, valuesVector.getInt(position)); + } + } + } else { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position)) { + continue; + } + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + AvgIntAggregator.combine(state, groupId, valuesBlock.getInt(position)); + } + } + } + } + + @Override + public void addIntermediateInput(LongVector groupIdVector, Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + // TODO exchange big arrays directly without funny serialization - no more copying + BigArrays bigArrays = BigArrays.NON_RECYCLING_INSTANCE; + AvgLongAggregator.GroupingAvgState inState = AvgIntAggregator.initGrouping(bigArrays); + blobVector.get(0, inState); + for (int position = 0; position < groupIdVector.getPositionCount(); position++) { + int groupId = Math.toIntExact(groupIdVector.getLong(position)); + AvgIntAggregator.combineStates(state, groupId, inState, position); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + AvgLongAggregator.GroupingAvgState inState = ((AvgIntGroupingAggregatorFunction) input).state; + AvgIntAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, AvgLongAggregator.GroupingAvgState> builder = + AggregatorStateVector.builderOfAggregatorState(AvgLongAggregator.GroupingAvgState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return AvgIntAggregator.evaluateFinal(state); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunction.java index 001dfb214aa9e..e3958dd8525fa 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/AvgLongAggregatorFunction.java @@ -10,7 +10,6 @@ import org.elasticsearch.compute.data.AggregatorStateVector; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; @@ -41,12 +40,7 @@ public void addRawInput(Page page) { if (type == ElementType.NULL) { return; } - LongBlock block; - if (type == ElementType.INT) { - block = page.getBlock(channel).asLongBlock(); - } else { - block = page.getBlock(channel); - } + LongBlock block = page.getBlock(channel); LongVector vector = block.asVector(); if (vector != null) { addRawVector(vector); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleAggregatorFunction.java index eed92a01032fb..a58cb38cd260e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleAggregatorFunction.java @@ -9,7 +9,6 @@ import java.lang.StringBuilder; import org.elasticsearch.compute.data.AggregatorStateVector; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleArrayVector; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; @@ -89,7 +88,7 @@ public Block evaluateIntermediate() { @Override public Block evaluateFinal() { - return new DoubleArrayVector(new double[] { state.doubleValue() }, 1).asBlock(); + return DoubleBlock.newConstantBlockWith(state.doubleValue(), 1); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntAggregatorFunction.java new file mode 100644 index 0000000000000..57307138a1022 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntAggregatorFunction.java @@ -0,0 +1,102 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link AggregatorFunction} implementation for {@link MaxIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class MaxIntAggregatorFunction implements AggregatorFunction { + private final IntState state; + + private final int channel; + + public MaxIntAggregatorFunction(int channel, IntState state) { + this.channel = channel; + this.state = state; + } + + public static MaxIntAggregatorFunction create(int channel) { + return new MaxIntAggregatorFunction(channel, new IntState(MaxIntAggregator.init())); + } + + @Override + public void addRawInput(Page page) { + assert channel >= 0; + ElementType type = page.getBlock(channel).elementType(); + if (type == ElementType.NULL) { + return; + } + IntBlock block = page.getBlock(channel); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + state.intValue(MaxIntAggregator.combine(state.intValue(), vector.getInt(i))); + } + } + + private void addRawBlock(IntBlock block) { + for (int i = 0; i < block.getTotalValueCount(); i++) { + if (block.isNull(i) == false) { + state.intValue(MaxIntAggregator.combine(state.intValue(), block.getInt(i))); + } + } + } + + @Override + public void addIntermediateInput(Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + IntState tmpState = new IntState(); + for (int i = 0; i < block.getPositionCount(); i++) { + blobVector.get(i, tmpState); + state.intValue(MaxIntAggregator.combine(state.intValue(), tmpState.intValue())); + } + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, IntState> builder = + AggregatorStateVector.builderOfAggregatorState(IntState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return IntBlock.newConstantBlockWith(state.intValue(), 1); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..8deca7b86f6fb --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java @@ -0,0 +1,147 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link MaxIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class MaxIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private final IntArrayState state; + + private final int channel; + + public MaxIntGroupingAggregatorFunction(int channel, IntArrayState state) { + this.channel = channel; + this.state = state; + } + + public static MaxIntGroupingAggregatorFunction create(BigArrays bigArrays, int channel) { + return new MaxIntGroupingAggregatorFunction(channel, new IntArrayState(bigArrays, MaxIntAggregator.init())); + } + + @Override + public void addRawInput(LongVector groups, Page page) { + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector != null) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), valuesVector.getInt(position)), groupId); + } + } else { + // move the cold branch out of this method to keep the optimized case vector/vector as small as possible + addRawInputWithBlockValues(groups, valuesBlock); + } + } + + private void addRawInputWithBlockValues(LongVector groups, IntBlock valuesBlock) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), valuesBlock.getInt(position)), groupId); + } + } + } + + @Override + public void addRawInput(LongBlock groups, Page page) { + assert channel >= 0; + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + int positions = groups.getPositionCount(); + if (valuesVector != null) { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position) == false) { + int groupId = Math.toIntExact(groups.getLong(position)); + state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), valuesVector.getInt(position)), groupId); + } + } + } else { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position)) { + continue; + } + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), valuesBlock.getInt(position)), groupId); + } + } + } + } + + @Override + public void addIntermediateInput(LongVector groupIdVector, Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + // TODO exchange big arrays directly without funny serialization - no more copying + BigArrays bigArrays = BigArrays.NON_RECYCLING_INSTANCE; + IntArrayState inState = new IntArrayState(bigArrays, MaxIntAggregator.init()); + blobVector.get(0, inState); + for (int position = 0; position < groupIdVector.getPositionCount(); position++) { + int groupId = Math.toIntExact(groupIdVector.getLong(position)); + state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + IntArrayState inState = ((MaxIntGroupingAggregatorFunction) input).state; + state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId); + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, IntArrayState> builder = + AggregatorStateVector.builderOfAggregatorState(IntArrayState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return state.toValuesBlock(); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongAggregatorFunction.java index a5164f79aa7a0..a8961c1f06295 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongAggregatorFunction.java @@ -10,8 +10,6 @@ import org.elasticsearch.compute.data.AggregatorStateVector; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.LongArrayVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; @@ -42,12 +40,7 @@ public void addRawInput(Page page) { if (type == ElementType.NULL) { return; } - LongBlock block; - if (type == ElementType.INT) { - block = page.getBlock(channel).asLongBlock(); - } else { - block = page.getBlock(channel); - } + LongBlock block = page.getBlock(channel); LongVector vector = block.asVector(); if (vector != null) { addRawVector(vector); @@ -95,7 +88,7 @@ public Block evaluateIntermediate() { @Override public Block evaluateFinal() { - return new LongArrayVector(new long[] { state.longValue() }, 1).asBlock(); + return LongBlock.newConstantBlockWith(state.longValue(), 1); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregatorFunction.java new file mode 100644 index 0000000000000..f597393b86b3e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregatorFunction.java @@ -0,0 +1,103 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link AggregatorFunction} implementation for {@link MedianAbsoluteDeviationIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class MedianAbsoluteDeviationIntAggregatorFunction implements AggregatorFunction { + private final QuantileStates.SingleState state; + + private final int channel; + + public MedianAbsoluteDeviationIntAggregatorFunction(int channel, + QuantileStates.SingleState state) { + this.channel = channel; + this.state = state; + } + + public static MedianAbsoluteDeviationIntAggregatorFunction create(int channel) { + return new MedianAbsoluteDeviationIntAggregatorFunction(channel, MedianAbsoluteDeviationIntAggregator.initSingle()); + } + + @Override + public void addRawInput(Page page) { + assert channel >= 0; + ElementType type = page.getBlock(channel).elementType(); + if (type == ElementType.NULL) { + return; + } + IntBlock block = page.getBlock(channel); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + MedianAbsoluteDeviationIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawBlock(IntBlock block) { + for (int i = 0; i < block.getTotalValueCount(); i++) { + if (block.isNull(i) == false) { + MedianAbsoluteDeviationIntAggregator.combine(state, block.getInt(i)); + } + } + } + + @Override + public void addIntermediateInput(Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + QuantileStates.SingleState tmpState = new QuantileStates.SingleState(); + for (int i = 0; i < block.getPositionCount(); i++) { + blobVector.get(i, tmpState); + MedianAbsoluteDeviationIntAggregator.combineStates(state, tmpState); + } + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, QuantileStates.SingleState> builder = + AggregatorStateVector.builderOfAggregatorState(QuantileStates.SingleState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return MedianAbsoluteDeviationIntAggregator.evaluateFinal(state); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..051bccb5a191a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java @@ -0,0 +1,149 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link MedianAbsoluteDeviationIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class MedianAbsoluteDeviationIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private final QuantileStates.GroupingState state; + + private final int channel; + + public MedianAbsoluteDeviationIntGroupingAggregatorFunction(int channel, + QuantileStates.GroupingState state) { + this.channel = channel; + this.state = state; + } + + public static MedianAbsoluteDeviationIntGroupingAggregatorFunction create(BigArrays bigArrays, + int channel) { + return new MedianAbsoluteDeviationIntGroupingAggregatorFunction(channel, MedianAbsoluteDeviationIntAggregator.initGrouping(bigArrays)); + } + + @Override + public void addRawInput(LongVector groups, Page page) { + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector != null) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + MedianAbsoluteDeviationIntAggregator.combine(state, groupId, valuesVector.getInt(position)); + } + } else { + // move the cold branch out of this method to keep the optimized case vector/vector as small as possible + addRawInputWithBlockValues(groups, valuesBlock); + } + } + + private void addRawInputWithBlockValues(LongVector groups, IntBlock valuesBlock) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + MedianAbsoluteDeviationIntAggregator.combine(state, groupId, valuesBlock.getInt(position)); + } + } + } + + @Override + public void addRawInput(LongBlock groups, Page page) { + assert channel >= 0; + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + int positions = groups.getPositionCount(); + if (valuesVector != null) { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position) == false) { + int groupId = Math.toIntExact(groups.getLong(position)); + MedianAbsoluteDeviationIntAggregator.combine(state, groupId, valuesVector.getInt(position)); + } + } + } else { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position)) { + continue; + } + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + MedianAbsoluteDeviationIntAggregator.combine(state, groupId, valuesBlock.getInt(position)); + } + } + } + } + + @Override + public void addIntermediateInput(LongVector groupIdVector, Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + // TODO exchange big arrays directly without funny serialization - no more copying + BigArrays bigArrays = BigArrays.NON_RECYCLING_INSTANCE; + QuantileStates.GroupingState inState = MedianAbsoluteDeviationIntAggregator.initGrouping(bigArrays); + blobVector.get(0, inState); + for (int position = 0; position < groupIdVector.getPositionCount(); position++) { + int groupId = Math.toIntExact(groupIdVector.getLong(position)); + MedianAbsoluteDeviationIntAggregator.combineStates(state, groupId, inState, position); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + QuantileStates.GroupingState inState = ((MedianAbsoluteDeviationIntGroupingAggregatorFunction) input).state; + MedianAbsoluteDeviationIntAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, QuantileStates.GroupingState> builder = + AggregatorStateVector.builderOfAggregatorState(QuantileStates.GroupingState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return MedianAbsoluteDeviationIntAggregator.evaluateFinal(state); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongAggregatorFunction.java index d9e10effd24d1..dc587f0f35707 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongAggregatorFunction.java @@ -10,7 +10,6 @@ import org.elasticsearch.compute.data.AggregatorStateVector; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; @@ -42,12 +41,7 @@ public void addRawInput(Page page) { if (type == ElementType.NULL) { return; } - LongBlock block; - if (type == ElementType.INT) { - block = page.getBlock(channel).asLongBlock(); - } else { - block = page.getBlock(channel); - } + LongBlock block = page.getBlock(channel); LongVector vector = block.asVector(); if (vector != null) { addRawVector(vector); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianIntAggregatorFunction.java new file mode 100644 index 0000000000000..1736202ca0969 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianIntAggregatorFunction.java @@ -0,0 +1,102 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link AggregatorFunction} implementation for {@link MedianIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class MedianIntAggregatorFunction implements AggregatorFunction { + private final QuantileStates.SingleState state; + + private final int channel; + + public MedianIntAggregatorFunction(int channel, QuantileStates.SingleState state) { + this.channel = channel; + this.state = state; + } + + public static MedianIntAggregatorFunction create(int channel) { + return new MedianIntAggregatorFunction(channel, MedianIntAggregator.initSingle()); + } + + @Override + public void addRawInput(Page page) { + assert channel >= 0; + ElementType type = page.getBlock(channel).elementType(); + if (type == ElementType.NULL) { + return; + } + IntBlock block = page.getBlock(channel); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + MedianIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawBlock(IntBlock block) { + for (int i = 0; i < block.getTotalValueCount(); i++) { + if (block.isNull(i) == false) { + MedianIntAggregator.combine(state, block.getInt(i)); + } + } + } + + @Override + public void addIntermediateInput(Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + QuantileStates.SingleState tmpState = new QuantileStates.SingleState(); + for (int i = 0; i < block.getPositionCount(); i++) { + blobVector.get(i, tmpState); + MedianIntAggregator.combineStates(state, tmpState); + } + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, QuantileStates.SingleState> builder = + AggregatorStateVector.builderOfAggregatorState(QuantileStates.SingleState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return MedianIntAggregator.evaluateFinal(state); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..658be2cd8e2cc --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianIntGroupingAggregatorFunction.java @@ -0,0 +1,147 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link MedianIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class MedianIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private final QuantileStates.GroupingState state; + + private final int channel; + + public MedianIntGroupingAggregatorFunction(int channel, QuantileStates.GroupingState state) { + this.channel = channel; + this.state = state; + } + + public static MedianIntGroupingAggregatorFunction create(BigArrays bigArrays, int channel) { + return new MedianIntGroupingAggregatorFunction(channel, MedianIntAggregator.initGrouping(bigArrays)); + } + + @Override + public void addRawInput(LongVector groups, Page page) { + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector != null) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + MedianIntAggregator.combine(state, groupId, valuesVector.getInt(position)); + } + } else { + // move the cold branch out of this method to keep the optimized case vector/vector as small as possible + addRawInputWithBlockValues(groups, valuesBlock); + } + } + + private void addRawInputWithBlockValues(LongVector groups, IntBlock valuesBlock) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + MedianIntAggregator.combine(state, groupId, valuesBlock.getInt(position)); + } + } + } + + @Override + public void addRawInput(LongBlock groups, Page page) { + assert channel >= 0; + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + int positions = groups.getPositionCount(); + if (valuesVector != null) { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position) == false) { + int groupId = Math.toIntExact(groups.getLong(position)); + MedianIntAggregator.combine(state, groupId, valuesVector.getInt(position)); + } + } + } else { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position)) { + continue; + } + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + MedianIntAggregator.combine(state, groupId, valuesBlock.getInt(position)); + } + } + } + } + + @Override + public void addIntermediateInput(LongVector groupIdVector, Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + // TODO exchange big arrays directly without funny serialization - no more copying + BigArrays bigArrays = BigArrays.NON_RECYCLING_INSTANCE; + QuantileStates.GroupingState inState = MedianIntAggregator.initGrouping(bigArrays); + blobVector.get(0, inState); + for (int position = 0; position < groupIdVector.getPositionCount(); position++) { + int groupId = Math.toIntExact(groupIdVector.getLong(position)); + MedianIntAggregator.combineStates(state, groupId, inState, position); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + QuantileStates.GroupingState inState = ((MedianIntGroupingAggregatorFunction) input).state; + MedianIntAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, QuantileStates.GroupingState> builder = + AggregatorStateVector.builderOfAggregatorState(QuantileStates.GroupingState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return MedianIntAggregator.evaluateFinal(state); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianLongAggregatorFunction.java index 332be4fa54c0c..27705137d7f31 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianLongAggregatorFunction.java @@ -10,7 +10,6 @@ import org.elasticsearch.compute.data.AggregatorStateVector; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; @@ -41,12 +40,7 @@ public void addRawInput(Page page) { if (type == ElementType.NULL) { return; } - LongBlock block; - if (type == ElementType.INT) { - block = page.getBlock(channel).asLongBlock(); - } else { - block = page.getBlock(channel); - } + LongBlock block = page.getBlock(channel); LongVector vector = block.asVector(); if (vector != null) { addRawVector(vector); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleAggregatorFunction.java index aca15a08ab467..8704cf8c72494 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleAggregatorFunction.java @@ -9,7 +9,6 @@ import java.lang.StringBuilder; import org.elasticsearch.compute.data.AggregatorStateVector; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleArrayVector; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; @@ -89,7 +88,7 @@ public Block evaluateIntermediate() { @Override public Block evaluateFinal() { - return new DoubleArrayVector(new double[] { state.doubleValue() }, 1).asBlock(); + return DoubleBlock.newConstantBlockWith(state.doubleValue(), 1); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntAggregatorFunction.java new file mode 100644 index 0000000000000..af285f97dfcb2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntAggregatorFunction.java @@ -0,0 +1,102 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link AggregatorFunction} implementation for {@link MinIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class MinIntAggregatorFunction implements AggregatorFunction { + private final IntState state; + + private final int channel; + + public MinIntAggregatorFunction(int channel, IntState state) { + this.channel = channel; + this.state = state; + } + + public static MinIntAggregatorFunction create(int channel) { + return new MinIntAggregatorFunction(channel, new IntState(MinIntAggregator.init())); + } + + @Override + public void addRawInput(Page page) { + assert channel >= 0; + ElementType type = page.getBlock(channel).elementType(); + if (type == ElementType.NULL) { + return; + } + IntBlock block = page.getBlock(channel); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + state.intValue(MinIntAggregator.combine(state.intValue(), vector.getInt(i))); + } + } + + private void addRawBlock(IntBlock block) { + for (int i = 0; i < block.getTotalValueCount(); i++) { + if (block.isNull(i) == false) { + state.intValue(MinIntAggregator.combine(state.intValue(), block.getInt(i))); + } + } + } + + @Override + public void addIntermediateInput(Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + IntState tmpState = new IntState(); + for (int i = 0; i < block.getPositionCount(); i++) { + blobVector.get(i, tmpState); + state.intValue(MinIntAggregator.combine(state.intValue(), tmpState.intValue())); + } + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, IntState> builder = + AggregatorStateVector.builderOfAggregatorState(IntState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return IntBlock.newConstantBlockWith(state.intValue(), 1); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..84c4ba608bbdb --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java @@ -0,0 +1,147 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link MinIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class MinIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private final IntArrayState state; + + private final int channel; + + public MinIntGroupingAggregatorFunction(int channel, IntArrayState state) { + this.channel = channel; + this.state = state; + } + + public static MinIntGroupingAggregatorFunction create(BigArrays bigArrays, int channel) { + return new MinIntGroupingAggregatorFunction(channel, new IntArrayState(bigArrays, MinIntAggregator.init())); + } + + @Override + public void addRawInput(LongVector groups, Page page) { + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector != null) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + state.set(MinIntAggregator.combine(state.getOrDefault(groupId), valuesVector.getInt(position)), groupId); + } + } else { + // move the cold branch out of this method to keep the optimized case vector/vector as small as possible + addRawInputWithBlockValues(groups, valuesBlock); + } + } + + private void addRawInputWithBlockValues(LongVector groups, IntBlock valuesBlock) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + state.set(MinIntAggregator.combine(state.getOrDefault(groupId), valuesBlock.getInt(position)), groupId); + } + } + } + + @Override + public void addRawInput(LongBlock groups, Page page) { + assert channel >= 0; + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + int positions = groups.getPositionCount(); + if (valuesVector != null) { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position) == false) { + int groupId = Math.toIntExact(groups.getLong(position)); + state.set(MinIntAggregator.combine(state.getOrDefault(groupId), valuesVector.getInt(position)), groupId); + } + } + } else { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position)) { + continue; + } + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + state.set(MinIntAggregator.combine(state.getOrDefault(groupId), valuesBlock.getInt(position)), groupId); + } + } + } + } + + @Override + public void addIntermediateInput(LongVector groupIdVector, Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + // TODO exchange big arrays directly without funny serialization - no more copying + BigArrays bigArrays = BigArrays.NON_RECYCLING_INSTANCE; + IntArrayState inState = new IntArrayState(bigArrays, MinIntAggregator.init()); + blobVector.get(0, inState); + for (int position = 0; position < groupIdVector.getPositionCount(); position++) { + int groupId = Math.toIntExact(groupIdVector.getLong(position)); + state.set(MinIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + IntArrayState inState = ((MinIntGroupingAggregatorFunction) input).state; + state.set(MinIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId); + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, IntArrayState> builder = + AggregatorStateVector.builderOfAggregatorState(IntArrayState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return state.toValuesBlock(); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongAggregatorFunction.java index 5f2f50d6e2422..3eec5ea00c3bb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongAggregatorFunction.java @@ -10,8 +10,6 @@ import org.elasticsearch.compute.data.AggregatorStateVector; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.LongArrayVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; @@ -42,12 +40,7 @@ public void addRawInput(Page page) { if (type == ElementType.NULL) { return; } - LongBlock block; - if (type == ElementType.INT) { - block = page.getBlock(channel).asLongBlock(); - } else { - block = page.getBlock(channel); - } + LongBlock block = page.getBlock(channel); LongVector vector = block.asVector(); if (vector != null) { addRawVector(vector); @@ -95,7 +88,7 @@ public Block evaluateIntermediate() { @Override public Block evaluateFinal() { - return new LongArrayVector(new long[] { state.longValue() }, 1).asBlock(); + return LongBlock.newConstantBlockWith(state.longValue(), 1); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntAggregatorFunction.java new file mode 100644 index 0000000000000..e03084672dfec --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntAggregatorFunction.java @@ -0,0 +1,103 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link AggregatorFunction} implementation for {@link SumIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class SumIntAggregatorFunction implements AggregatorFunction { + private final LongState state; + + private final int channel; + + public SumIntAggregatorFunction(int channel, LongState state) { + this.channel = channel; + this.state = state; + } + + public static SumIntAggregatorFunction create(int channel) { + return new SumIntAggregatorFunction(channel, new LongState(SumIntAggregator.init())); + } + + @Override + public void addRawInput(Page page) { + assert channel >= 0; + ElementType type = page.getBlock(channel).elementType(); + if (type == ElementType.NULL) { + return; + } + IntBlock block = page.getBlock(channel); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + state.longValue(SumIntAggregator.combine(state.longValue(), vector.getInt(i))); + } + } + + private void addRawBlock(IntBlock block) { + for (int i = 0; i < block.getTotalValueCount(); i++) { + if (block.isNull(i) == false) { + state.longValue(SumIntAggregator.combine(state.longValue(), block.getInt(i))); + } + } + } + + @Override + public void addIntermediateInput(Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + LongState tmpState = new LongState(); + for (int i = 0; i < block.getPositionCount(); i++) { + blobVector.get(i, tmpState); + SumIntAggregator.combineStates(state, tmpState); + } + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, LongState> builder = + AggregatorStateVector.builderOfAggregatorState(LongState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return LongBlock.newConstantBlockWith(state.longValue(), 1); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..669cc58c5567d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java @@ -0,0 +1,147 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.AggregatorStateVector; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.data.Vector; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link SumIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class SumIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private final LongArrayState state; + + private final int channel; + + public SumIntGroupingAggregatorFunction(int channel, LongArrayState state) { + this.channel = channel; + this.state = state; + } + + public static SumIntGroupingAggregatorFunction create(BigArrays bigArrays, int channel) { + return new SumIntGroupingAggregatorFunction(channel, new LongArrayState(bigArrays, SumIntAggregator.init())); + } + + @Override + public void addRawInput(LongVector groups, Page page) { + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector != null) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + state.set(SumIntAggregator.combine(state.getOrDefault(groupId), valuesVector.getInt(position)), groupId); + } + } else { + // move the cold branch out of this method to keep the optimized case vector/vector as small as possible + addRawInputWithBlockValues(groups, valuesBlock); + } + } + + private void addRawInputWithBlockValues(LongVector groups, IntBlock valuesBlock) { + int positions = groups.getPositionCount(); + for (int position = 0; position < groups.getPositionCount(); position++) { + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + state.set(SumIntAggregator.combine(state.getOrDefault(groupId), valuesBlock.getInt(position)), groupId); + } + } + } + + @Override + public void addRawInput(LongBlock groups, Page page) { + assert channel >= 0; + IntBlock valuesBlock = page.getBlock(channel); + IntVector valuesVector = valuesBlock.asVector(); + int positions = groups.getPositionCount(); + if (valuesVector != null) { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position) == false) { + int groupId = Math.toIntExact(groups.getLong(position)); + state.set(SumIntAggregator.combine(state.getOrDefault(groupId), valuesVector.getInt(position)), groupId); + } + } + } else { + for (int position = 0; position < groups.getPositionCount(); position++) { + if (groups.isNull(position)) { + continue; + } + int groupId = Math.toIntExact(groups.getLong(position)); + if (valuesBlock.isNull(position)) { + state.putNull(groupId); + } else { + state.set(SumIntAggregator.combine(state.getOrDefault(groupId), valuesBlock.getInt(position)), groupId); + } + } + } + } + + @Override + public void addIntermediateInput(LongVector groupIdVector, Block block) { + assert channel == -1; + Vector vector = block.asVector(); + if (vector == null || vector instanceof AggregatorStateVector == false) { + throw new RuntimeException("expected AggregatorStateBlock, got:" + block); + } + @SuppressWarnings("unchecked") AggregatorStateVector blobVector = (AggregatorStateVector) vector; + // TODO exchange big arrays directly without funny serialization - no more copying + BigArrays bigArrays = BigArrays.NON_RECYCLING_INSTANCE; + LongArrayState inState = new LongArrayState(bigArrays, SumIntAggregator.init()); + blobVector.get(0, inState); + for (int position = 0; position < groupIdVector.getPositionCount(); position++) { + int groupId = Math.toIntExact(groupIdVector.getLong(position)); + SumIntAggregator.combineStates(state, groupId, inState, position); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + LongArrayState inState = ((SumIntGroupingAggregatorFunction) input).state; + SumIntAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public Block evaluateIntermediate() { + AggregatorStateVector.Builder, LongArrayState> builder = + AggregatorStateVector.builderOfAggregatorState(LongArrayState.class, state.getEstimatedSize()); + builder.add(state); + return builder.build().asBlock(); + } + + @Override + public Block evaluateFinal() { + return state.toValuesBlock(); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channel=").append(channel); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java index aef7a29569e27..aefa51e0593f5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java @@ -10,8 +10,6 @@ import org.elasticsearch.compute.data.AggregatorStateVector; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.LongArrayVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; @@ -42,12 +40,7 @@ public void addRawInput(Page page) { if (type == ElementType.NULL) { return; } - LongBlock block; - if (type == ElementType.INT) { - block = page.getBlock(channel).asLongBlock(); - } else { - block = page.getBlock(channel); - } + LongBlock block = page.getBlock(channel); LongVector vector = block.asVector(); if (vector != null) { addRawVector(vector); @@ -95,7 +88,7 @@ public Block evaluateIntermediate() { @Override public Block evaluateFinal() { - return new LongArrayVector(new long[] { state.longValue() }, 1).asBlock(); + return LongBlock.newConstantBlockWith(state.longValue(), 1); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AggregationType.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AggregationType.java index 97699d29fe215..07b7b0590513a 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AggregationType.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AggregationType.java @@ -12,6 +12,8 @@ public enum AggregationType { agnostic, + ints, + longs, doubles diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AggregatorFunction.java index 803c0c26f34f4..154ba06f47af2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AggregatorFunction.java @@ -23,6 +23,7 @@ import static org.elasticsearch.compute.aggregation.AggregationName.sum; import static org.elasticsearch.compute.aggregation.AggregationType.agnostic; import static org.elasticsearch.compute.aggregation.AggregationType.doubles; +import static org.elasticsearch.compute.aggregation.AggregationType.ints; import static org.elasticsearch.compute.aggregation.AggregationType.longs; @Experimental @@ -53,6 +54,15 @@ static Factory of(AggregationName name, AggregationType type) { case count -> COUNT; default -> throw new IllegalArgumentException("unknown " + name + ", type:" + type); }; + case ints -> switch (name) { + case avg -> AVG_INTS; + case count -> COUNT; + case max -> MAX_INTS; + case median -> MEDIAN_INTS; + case median_absolute_deviation -> MEDIAN_ABSOLUTE_DEVIATION_INTS; + case min -> MIN_INTS; + case sum -> SUM_INTS; + }; case longs -> switch (name) { case avg -> AVG_LONGS; case count -> COUNT; @@ -76,14 +86,17 @@ static Factory of(AggregationName name, AggregationType type) { Factory AVG_DOUBLES = new Factory(avg, doubles, AvgDoubleAggregatorFunction::create); Factory AVG_LONGS = new Factory(avg, longs, AvgLongAggregatorFunction::create); + Factory AVG_INTS = new Factory(avg, ints, AvgIntAggregatorFunction::create); Factory COUNT = new Factory(count, agnostic, CountAggregatorFunction::create); Factory MAX_DOUBLES = new Factory(max, doubles, MaxDoubleAggregatorFunction::create); Factory MAX_LONGS = new Factory(max, longs, MaxLongAggregatorFunction::create); + Factory MAX_INTS = new Factory(max, ints, MaxIntAggregatorFunction::create); Factory MEDIAN_DOUBLES = new Factory(median, doubles, MedianDoubleAggregatorFunction::create); Factory MEDIAN_LONGS = new Factory(median, longs, MedianLongAggregatorFunction::create); + Factory MEDIAN_INTS = new Factory(median, ints, MedianIntAggregatorFunction::create); Factory MEDIAN_ABSOLUTE_DEVIATION_DOUBLES = new Factory( median_absolute_deviation, @@ -95,10 +108,17 @@ static Factory of(AggregationName name, AggregationType type) { longs, MedianAbsoluteDeviationLongAggregatorFunction::create ); + Factory MEDIAN_ABSOLUTE_DEVIATION_INTS = new Factory( + median_absolute_deviation, + ints, + MedianAbsoluteDeviationIntAggregatorFunction::create + ); Factory MIN_DOUBLES = new Factory(min, doubles, MinDoubleAggregatorFunction::create); Factory MIN_LONGS = new Factory(min, longs, MinLongAggregatorFunction::create); + Factory MIN_INTS = new Factory(min, ints, MinIntAggregatorFunction::create); Factory SUM_DOUBLES = new Factory(sum, doubles, SumDoubleAggregatorFunction::create); Factory SUM_LONGS = new Factory(sum, longs, SumLongAggregatorFunction::create); + Factory SUM_INTS = new Factory(sum, ints, SumIntAggregatorFunction::create); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AvgIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AvgIntAggregator.java new file mode 100644 index 0000000000000..b0fad89878ac8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AvgIntAggregator.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.aggregation.AvgLongAggregator.AvgState; +import org.elasticsearch.compute.aggregation.AvgLongAggregator.GroupingAvgState; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; + +@Aggregator +@GroupingAggregator +class AvgIntAggregator { + public static AvgState initSingle() { + return new AvgState(); + } + + public static void combine(AvgState current, int v) { + current.value = Math.addExact(current.value, v); + } + + public static void combineValueCount(AvgState current, int positions) { + current.count += positions; + } + + public static void combineStates(AvgState current, AvgState state) { + current.value = Math.addExact(current.value, state.value); + current.count += state.count; + } + + public static Block evaluateFinal(AvgState state) { + double result = ((double) state.value) / state.count; + return DoubleBlock.newConstantBlockWith(result, 1); + } + + public static GroupingAvgState initGrouping(BigArrays bigArrays) { + return new GroupingAvgState(bigArrays); + } + + public static void combine(GroupingAvgState current, int groupId, int v) { + current.add(v, groupId, 1); + } + + public static void combineStates(GroupingAvgState current, int currentGroupId, GroupingAvgState state, int statePosition) { + current.add(state.values.get(statePosition), currentGroupId, state.counts.get(statePosition)); + } + + public static Block evaluateFinal(GroupingAvgState state) { + int positions = state.largestGroupId + 1; + DoubleBlock.Builder builder = DoubleBlock.newBlockBuilder(positions); + for (int i = 0; i < positions; i++) { + final long count = state.counts.get(i); + if (count > 0) { + builder.appendDouble((double) state.values.get(i) / count); + } else { + assert state.values.get(i) == 0; + builder.appendNull(); + } + } + return builder.build(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DoubleArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DoubleArrayState.java index 58ff32fe52729..a229ee92617fc 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DoubleArrayState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DoubleArrayState.java @@ -12,8 +12,8 @@ import org.elasticsearch.common.util.DoubleArray; import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleArrayVector; import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.core.Releasables; import java.lang.invoke.MethodHandles; @@ -85,11 +85,11 @@ boolean hasValue(int index) { Block toValuesBlock() { final int positions = largestIndex + 1; if (nonNulls == null) { - final double[] vs = new double[positions]; + DoubleVector.Builder builder = DoubleVector.newVectorBuilder(positions); for (int i = 0; i < positions; i++) { - vs[i] = values.get(i); + builder.appendDouble(values.get(i)); } - return new DoubleArrayVector(vs, positions).asBlock(); + return builder.build().asBlock(); } else { final DoubleBlock.Builder builder = DoubleBlock.newBlockBuilder(positions); for (int i = 0; i < positions; i++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java index 539ad323862c3..4493bf908756c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java @@ -27,6 +27,7 @@ import static org.elasticsearch.compute.aggregation.AggregationName.sum; import static org.elasticsearch.compute.aggregation.AggregationType.agnostic; import static org.elasticsearch.compute.aggregation.AggregationType.doubles; +import static org.elasticsearch.compute.aggregation.AggregationType.ints; import static org.elasticsearch.compute.aggregation.AggregationType.longs; @Experimental @@ -70,6 +71,15 @@ static Factory of(AggregationName name, AggregationType type) { case count -> COUNT; default -> throw new IllegalArgumentException("unknown " + name + ", type:" + type); }; + case ints -> switch (name) { + case avg -> AVG_INTS; + case count -> COUNT; + case max -> MAX_INTS; + case median -> MEDIAN_INTS; + case median_absolute_deviation -> MEDIAN_ABSOLUTE_DEVIATION_INTS; + case min -> MIN_INTS; + case sum -> SUM_INTS; + }; case longs -> switch (name) { case avg -> AVG_LONGS; case count -> COUNT; @@ -93,30 +103,39 @@ static Factory of(AggregationName name, AggregationType type) { Factory AVG_DOUBLES = new Factory(avg, doubles, AvgDoubleGroupingAggregatorFunction::create); Factory AVG_LONGS = new Factory(avg, longs, AvgLongGroupingAggregatorFunction::create); + Factory AVG_INTS = new Factory(avg, ints, AvgIntGroupingAggregatorFunction::create); Factory COUNT = new Factory(count, agnostic, CountGroupingAggregatorFunction::create); Factory MIN_DOUBLES = new Factory(min, doubles, MinDoubleGroupingAggregatorFunction::create); Factory MIN_LONGS = new Factory(min, longs, MinLongGroupingAggregatorFunction::create); + Factory MIN_INTS = new Factory(min, ints, MinIntGroupingAggregatorFunction::create); Factory MAX_DOUBLES = new Factory(max, doubles, MaxDoubleGroupingAggregatorFunction::create); Factory MAX_LONGS = new Factory(max, longs, MaxLongGroupingAggregatorFunction::create); + Factory MAX_INTS = new Factory(max, ints, MaxIntGroupingAggregatorFunction::create); Factory MEDIAN_DOUBLES = new Factory(median, doubles, MedianDoubleGroupingAggregatorFunction::create); Factory MEDIAN_LONGS = new Factory(median, longs, MedianLongGroupingAggregatorFunction::create); + Factory MEDIAN_INTS = new Factory(median, ints, MedianIntGroupingAggregatorFunction::create); Factory MEDIAN_ABSOLUTE_DEVIATION_DOUBLES = new Factory( median_absolute_deviation, doubles, MedianAbsoluteDeviationDoubleGroupingAggregatorFunction::create ); - Factory MEDIAN_ABSOLUTE_DEVIATION_LONGS = new Factory( median_absolute_deviation, longs, MedianAbsoluteDeviationLongGroupingAggregatorFunction::create ); + Factory MEDIAN_ABSOLUTE_DEVIATION_INTS = new Factory( + median_absolute_deviation, + ints, + MedianAbsoluteDeviationIntGroupingAggregatorFunction::create + ); Factory SUM_DOUBLES = new Factory(sum, doubles, SumDoubleGroupingAggregatorFunction::create); Factory SUM_LONGS = new Factory(sum, longs, SumLongGroupingAggregatorFunction::create); + Factory SUM_INTS = new Factory(sum, ints, SumIntGroupingAggregatorFunction::create); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/IntArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/IntArrayState.java new file mode 100644 index 0000000000000..35ed1ee63f3dd --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/IntArrayState.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.compute.ann.Experimental; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasables; + +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteOrder; +import java.util.Objects; + +@Experimental +final class IntArrayState implements AggregatorState { + + private final BigArrays bigArrays; + + private final int initialDefaultValue; + + private IntArray values; + // total number of groups; <= values.length + int largestIndex; + + private BitArray nonNulls; + + private final IntArrayStateSerializer serializer; + + IntArrayState(BigArrays bigArrays, int initialDefaultValue) { + this.bigArrays = bigArrays; + this.values = bigArrays.newIntArray(1, false); + this.values.set(0, initialDefaultValue); + this.initialDefaultValue = initialDefaultValue; + this.serializer = new IntArrayStateSerializer(); + } + + int get(int index) { + // TODO bounds check + return values.get(index); + } + + void increment(int value, int index) { + ensureCapacity(index); + values.increment(index, value); + if (nonNulls != null) { + nonNulls.set(index); + } + } + + void set(int value, int index) { + ensureCapacity(index); + values.set(index, value); + if (nonNulls != null) { + nonNulls.set(index); + } + } + + void putNull(int index) { + ensureCapacity(index); + if (nonNulls == null) { + nonNulls = new BitArray(index + 1, bigArrays); + for (int i = 0; i < index; i++) { + nonNulls.set(i); // TODO: bulk API + } + } else { + nonNulls.ensureCapacity(index); + } + } + + boolean hasValue(int index) { + return nonNulls == null || nonNulls.get(index); + } + + Block toValuesBlock() { + final int positions = largestIndex + 1; + if (nonNulls == null) { + IntVector.Builder builder = IntVector.newVectorBuilder(positions); + for (int i = 0; i < positions; i++) { + builder.appendInt(values.get(i)); + } + return builder.build().asBlock(); + } else { + final IntBlock.Builder builder = IntBlock.newBlockBuilder(positions); + for (int i = 0; i < positions; i++) { + if (hasValue(i)) { + builder.appendInt(values.get(i)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + int getOrDefault(int index) { + return index <= largestIndex ? values.get(index) : initialDefaultValue; + } + + private void ensureCapacity(int position) { + if (position > largestIndex) { + largestIndex = position; + } + if (position >= values.size()) { + long prevSize = values.size(); + values = bigArrays.grow(values, position + 1); + values.fill(prevSize, values.size(), initialDefaultValue); + } + } + + @Override + public long getEstimatedSize() { + final long positions = largestIndex + 1L; + return Long.BYTES + (positions * Long.BYTES) + estimateSerializeSize(nonNulls); + } + + @Override + public void close() { + Releasables.close(values, nonNulls); + } + + @Override + public AggregatorStateSerializer serializer() { + return serializer; + } + + private static final VarHandle intHandle = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.BIG_ENDIAN); + private static final VarHandle longHandle = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.BIG_ENDIAN); + + static int estimateSerializeSize(BitArray bits) { + if (bits == null) { + return Long.BYTES; + } else { + return Long.BYTES + Math.toIntExact(bits.getBits().size() * Long.BYTES); + } + } + + static int serializeBitArray(BitArray bits, byte[] ba, int offset) { + if (bits == null) { + intHandle.set(ba, offset, 0); + return Integer.BYTES; + } + final LongArray array = bits.getBits(); + intHandle.set(ba, offset, array.size()); + offset += Long.BYTES; + for (long i = 0; i < array.size(); i++) { + longHandle.set(ba, offset, array.get(i)); + } + return Integer.BYTES + Math.toIntExact(array.size() * Long.BYTES); + } + + static BitArray deseralizeBitArray(BigArrays bigArrays, byte[] ba, int offset) { + long size = (long) intHandle.get(ba, offset); + if (size == 0) { + return null; + } else { + offset += Integer.BYTES; + final LongArray array = bigArrays.newLongArray(size); + for (long i = 0; i < size; i++) { + array.set(i, (long) longHandle.get(ba, offset)); + } + return new BitArray(bigArrays, array); + } + } + + static class IntArrayStateSerializer implements AggregatorStateSerializer { + + static final int BYTES_SIZE = Integer.BYTES; + + @Override + public int size() { + return BYTES_SIZE; + } + + @Override + public int serialize(IntArrayState state, byte[] ba, int offset) { + int positions = state.largestIndex + 1; + intHandle.set(ba, offset, positions); + offset += Integer.BYTES; + for (int i = 0; i < positions; i++) { + intHandle.set(ba, offset, state.values.get(i)); + offset += BYTES_SIZE; + } + final int valuesBytes = Integer.BYTES + (BYTES_SIZE * positions) + Long.BYTES; + return valuesBytes + serializeBitArray(state.nonNulls, ba, offset); + } + + @Override + public void deserialize(IntArrayState state, byte[] ba, int offset) { + Objects.requireNonNull(state); + int positions = (int) intHandle.get(ba, offset); + offset += Integer.BYTES; + for (int i = 0; i < positions; i++) { + state.set((int) intHandle.get(ba, offset), i); + offset += BYTES_SIZE; + } + state.largestIndex = positions - 1; + state.nonNulls = deseralizeBitArray(state.bigArrays, ba, offset); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/IntState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/IntState.java new file mode 100644 index 0000000000000..b77b4f1f24c8b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/IntState.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Experimental; + +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteOrder; +import java.util.Objects; + +@Experimental +final class IntState implements AggregatorState { + private int intValue; + + private final LongStateSerializer serializer; + + IntState() { + this(0); + } + + IntState(int value) { + this.intValue = value; + this.serializer = new LongStateSerializer(); + } + + int intValue() { + return intValue; + } + + void intValue(int value) { + this.intValue = value; + } + + @Override + public long getEstimatedSize() { + return Integer.BYTES; + } + + @Override + public void close() {} + + @Override + public AggregatorStateSerializer serializer() { + return serializer; + } + + static class LongStateSerializer implements AggregatorStateSerializer { + + static final int BYTES_SIZE = Integer.BYTES; + + @Override + public int size() { + return BYTES_SIZE; + } + + private static final VarHandle intHandle = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.BIG_ENDIAN); + + @Override + public int serialize(IntState state, byte[] ba, int offset) { + intHandle.set(ba, offset, state.intValue); + return BYTES_SIZE; // number of bytes written + } + + // sets the long value in the given state. + @Override + public void deserialize(IntState state, byte[] ba, int offset) { + Objects.requireNonNull(state); + state.intValue = (int) intHandle.get(ba, offset); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/LongArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/LongArrayState.java index 83ca6cda715f1..ef7294e284d9c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/LongArrayState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/LongArrayState.java @@ -12,8 +12,8 @@ import org.elasticsearch.common.util.LongArray; import org.elasticsearch.compute.ann.Experimental; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.LongArrayVector; import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.core.Releasables; import java.lang.invoke.MethodHandles; @@ -84,11 +84,11 @@ boolean hasValue(int index) { Block toValuesBlock() { final int positions = largestIndex + 1; if (nonNulls == null) { - final long[] vs = new long[positions]; + LongVector.Builder builder = LongVector.newVectorBuilder(positions); for (int i = 0; i < positions; i++) { - vs[i] = values.get(i); + builder.appendLong(values.get(i)); } - return new LongArrayVector(vs, positions).asBlock(); + return builder.build().asBlock(); } else { final LongBlock.Builder builder = LongBlock.newBlockBuilder(positions); for (int i = 0; i < positions; i++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIntAggregator.java new file mode 100644 index 0000000000000..88420e14df35c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIntAggregator.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; + +@Aggregator +@GroupingAggregator +class MaxIntAggregator { + public static int init() { + return Integer.MIN_VALUE; + } + + public static int combine(int current, int v) { + return Math.max(current, v); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregator.java new file mode 100644 index 0000000000000..a745683c52aa0 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregator.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.data.Block; + +@Aggregator +@GroupingAggregator +class MedianAbsoluteDeviationIntAggregator { + public static QuantileStates.SingleState initSingle() { + return new QuantileStates.SingleState(); + } + + public static void combine(QuantileStates.SingleState current, int v) { + current.add(v); + } + + public static void combineStates(QuantileStates.SingleState current, QuantileStates.SingleState state) { + current.add(state); + } + + public static Block evaluateFinal(QuantileStates.SingleState state) { + return state.evaluateMedianAbsoluteDeviation(); + } + + public static QuantileStates.GroupingState initGrouping(BigArrays bigArrays) { + return new QuantileStates.GroupingState(bigArrays); + } + + public static void combine(QuantileStates.GroupingState state, int groupId, int v) { + state.add(groupId, v); + } + + public static void combineStates( + QuantileStates.GroupingState current, + int currentGroupId, + QuantileStates.GroupingState state, + int statePosition + ) { + current.add(currentGroupId, state.get(statePosition)); + } + + public static Block evaluateFinal(QuantileStates.GroupingState state) { + return state.evaluateMedianAbsoluteDeviation(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianIntAggregator.java new file mode 100644 index 0000000000000..3a55c2db4bc32 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianIntAggregator.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.data.Block; + +@Aggregator +@GroupingAggregator +class MedianIntAggregator { + public static QuantileStates.SingleState initSingle() { + return new QuantileStates.SingleState(); + } + + public static void combine(QuantileStates.SingleState current, int v) { + current.add(v); + } + + public static void combineStates(QuantileStates.SingleState current, QuantileStates.SingleState state) { + current.add(state); + } + + public static Block evaluateFinal(QuantileStates.SingleState state) { + return state.evaluateMedian(); + } + + public static QuantileStates.GroupingState initGrouping(BigArrays bigArrays) { + return new QuantileStates.GroupingState(bigArrays); + } + + public static void combine(QuantileStates.GroupingState state, int groupId, int v) { + state.add(groupId, v); + } + + public static void combineStates( + QuantileStates.GroupingState current, + int currentGroupId, + QuantileStates.GroupingState state, + int statePosition + ) { + current.add(currentGroupId, state.get(statePosition)); + } + + public static Block evaluateFinal(QuantileStates.GroupingState state) { + return state.evaluateMedian(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIntAggregator.java new file mode 100644 index 0000000000000..4215c7a9439b7 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIntAggregator.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; + +@Aggregator +@GroupingAggregator +class MinIntAggregator { + public static int init() { + return Integer.MAX_VALUE; + } + + public static int combine(int current, int v) { + return Math.min(current, v); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumIntAggregator.java new file mode 100644 index 0000000000000..e32ae49c73df6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumIntAggregator.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; + +@Aggregator +@GroupingAggregator +class SumIntAggregator { + public static long init() { + return 0; + } + + public static long combine(long current, int v) { + return Math.addExact(current, v); + } + + public static void combineStates(LongState current, LongState state) { + current.longValue(Math.addExact(current.longValue(), state.longValue())); + } + + public static void combineStates(LongArrayState current, int groupId, LongArrayState state, int position) { + current.set(Math.addExact(current.getOrDefault(groupId), state.get(position)), groupId); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st index be02fd7c1db4a..701eb93d3c49b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st @@ -70,18 +70,6 @@ $endif$ return ElementType.$TYPE$; } -$if(int)$ - @Override - public LongBlock asLongBlock() { // copy rather than view, for now - final int positions = getPositionCount(); - long[] longValues = new long[positions]; - for (int i = 0; i < positions; i++) { - longValues[i] = values[i]; - } - return new LongArrayBlock(longValues, getPositionCount(), firstValueIndexes, nullsMask); - } -$endif$ - @Override public boolean equals(Object obj) { if (obj instanceof $Type$Block that) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st index ad0ee8be89e50..e8fa4890c1cb0 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st @@ -45,10 +45,6 @@ $endif$ @Override $Type$Block filter(int... positions); -$if(int)$ - LongBlock asLongBlock(); -$endif$ - /** * Compares the given object with this block for equality. Returns {@code true} if and only if the * given object is a $Type$Block, and both blocks are {@link #equals($Type$Block, $Type$Block) equal}. diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-FilterBlock.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-FilterBlock.java.st index 8c80c0c803a63..844cddd31555d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-FilterBlock.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-FilterBlock.java.st @@ -39,13 +39,6 @@ $else$ $endif$ } -$if(int)$ - @Override - public LongBlock asLongBlock() { - return new FilterLongBlock(block.asLongBlock(), positions); - } -$endif$ - @Override public ElementType elementType() { return ElementType.$TYPE$; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-VectorBlock.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-VectorBlock.java.st index 4198825023e12..f86ee4296379b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-VectorBlock.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-VectorBlock.java.st @@ -49,17 +49,6 @@ $endif$ return vector.elementType(); } -$if(int)$ - public LongBlock asLongBlock() { // copy rather than view, for now - final int positions = getPositionCount(); - long[] longValues = new long[positions]; - for (int i = 0; i < positions; i++) { - longValues[i] = vector.getInt(i); - } - return new LongArrayVector(longValues, getPositionCount()).asBlock(); - } -$endif$ - @Override public $Type$Block getRow(int position) { return filter(position); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgIntAggregatorFunctionTests.java new file mode 100644 index 0000000000000..7a89a9c78a371 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgIntAggregatorFunctionTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.List; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class AvgIntAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(int size) { + int max = between(1, (int) Math.min(Integer.MAX_VALUE, Long.MAX_VALUE / size)); + return new SequenceIntBlockSourceOperator(LongStream.range(0, size).mapToInt(l -> between(-max, max))); + } + + @Override + protected AggregatorFunction.Factory aggregatorFunction() { + return AggregatorFunction.AVG_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "avg of ints"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + long sum = input.stream() + .flatMapToLong( + b -> IntStream.range(0, b.getTotalValueCount()) + .filter(p -> false == b.isNull(p)) + .mapToLong(p -> (long) ((IntBlock) b).getInt(p)) + ) + .sum(); + long count = input.stream().flatMapToInt(b -> IntStream.range(0, b.getPositionCount()).filter(p -> false == b.isNull(p))).count(); + assertThat(((DoubleBlock) result).getDouble(0), equalTo(((double) sum) / count)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgIntGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgIntGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..2476f315c9da1 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AvgIntGroupingAggregatorFunctionTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.LongIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class AvgIntGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + @Override + protected GroupingAggregatorFunction.Factory aggregatorFunction() { + return GroupingAggregatorFunction.AVG_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "avg of ints"; + } + + @Override + protected SourceOperator simpleInput(int size) { + int max = between(1, (int) Math.min(Integer.MAX_VALUE, Long.MAX_VALUE / size)); + return new LongIntBlockSourceOperator( + LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), between(-max, max))) + ); + } + + @Override + public void assertSimpleGroup(List input, Block result, int position, long group) { + long[] sum = new long[] { 0 }; + long[] count = new long[] { 0 }; + forEachGroupAndValue(input, (groups, groupOffset, values, valueOffset) -> { + if (groups.getLong(groupOffset) == group) { + sum[0] = Math.addExact(sum[0], ((IntBlock) values).getInt(valueOffset)); + count[0]++; + } + }); + assertThat(((DoubleBlock) result).getDouble(position), equalTo(((double) sum[0]) / count[0])); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxIntAggregatorFunctionTests.java new file mode 100644 index 0000000000000..584adaea3e892 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxIntAggregatorFunctionTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.List; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; + +public class MaxIntAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(int size) { + return new SequenceIntBlockSourceOperator(IntStream.range(0, size).map(l -> randomInt())); + } + + @Override + protected AggregatorFunction.Factory aggregatorFunction() { + return AggregatorFunction.MAX_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "max of ints"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + int max = input.stream() + .flatMapToInt( + b -> IntStream.range(0, b.getTotalValueCount()).filter(p -> false == b.isNull(p)).map(p -> ((IntBlock) b).getInt(p)) + ) + .max() + .getAsInt(); + assertThat(((IntBlock) result).getInt(0), equalTo(max)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..31a86af126a87 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunctionTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.LongIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class MaxIntGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + @Override + protected GroupingAggregatorFunction.Factory aggregatorFunction() { + return GroupingAggregatorFunction.MAX_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "max of ints"; + } + + @Override + protected SourceOperator simpleInput(int size) { + return new LongIntBlockSourceOperator(LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomInt()))); + } + + @Override + public void assertSimpleGroup(List input, Block result, int position, long group) { + int[] max = new int[] { Integer.MIN_VALUE }; + forEachGroupAndValue(input, (groups, groupOffset, values, valueOffset) -> { + if (groups.getLong(groupOffset) == group) { + max[0] = Math.max(max[0], ((IntBlock) values).getInt(valueOffset)); + } + }); + assertThat(((IntBlock) result).getInt(position), equalTo(max[0])); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregatorFunctionTests.java new file mode 100644 index 0000000000000..8baf738df4e9e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregatorFunctionTests.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class MedianAbsoluteDeviationIntAggregatorFunctionTests extends AggregatorFunctionTestCase { + + @Override + protected SourceOperator simpleInput(int end) { + List values = Arrays.asList(12, 125, 20, 20, 43, 60, 90); + Randomness.shuffle(values); + return new SequenceIntBlockSourceOperator(values); + } + + @Override + protected AggregatorFunction.Factory aggregatorFunction() { + return AggregatorFunction.MEDIAN_ABSOLUTE_DEVIATION_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "median_absolute_deviation of ints"; + } + + @Override + protected void assertSimpleOutput(List input, Block result) { + assertThat(((DoubleBlock) result).getDouble(0), equalTo(23.0)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..116848b3739f1 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunctionTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.LongIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class MedianAbsoluteDeviationIntGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + + @Override + protected SourceOperator simpleInput(int end) { + int[][] samples = new int[][] { + { 12, 125, 20, 20, 43, 60, 90 }, + { 1, 15, 20, 30, 40, 75, 1000 }, + { 2, 175, 20, 25 }, + { 5, 30, 30, 30, 43 }, + { 7, 15, 30 } }; + List> values = new ArrayList<>(); + for (int i = 0; i < samples.length; i++) { + List list = Arrays.stream(samples[i]).boxed().collect(Collectors.toList()); + Randomness.shuffle(list); + for (int v : list) { + values.add(Tuple.tuple((long) i, v)); + } + } + return new LongIntBlockSourceOperator(values); + } + + @Override + protected GroupingAggregatorFunction.Factory aggregatorFunction() { + return GroupingAggregatorFunction.MEDIAN_ABSOLUTE_DEVIATION_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "median_absolute_deviation of ints"; + } + + @Override + protected void assertSimpleGroup(List input, Block result, int position, long group) { + int bucket = Math.toIntExact(group); + double[] expectedValues = new double[] { 23.0, 15, 11.5, 0.0, 8.0 }; + assertThat(bucket, allOf(greaterThanOrEqualTo(0), lessThanOrEqualTo(4))); + assertThat(((DoubleBlock) result).getDouble(position), equalTo(expectedValues[bucket])); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianIntAggregatorFunctionTests.java new file mode 100644 index 0000000000000..f3539ba5c8009 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianIntAggregatorFunctionTests.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class MedianIntAggregatorFunctionTests extends AggregatorFunctionTestCase { + + @Override + protected SourceOperator simpleInput(int end) { + List values = Arrays.asList(12, 20, 20, 43, 60, 90, 125); + Randomness.shuffle(values); + return new SequenceIntBlockSourceOperator(values); + } + + @Override + protected AggregatorFunction.Factory aggregatorFunction() { + return AggregatorFunction.MEDIAN_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "median of ints"; + } + + @Override + protected void assertSimpleOutput(List input, Block result) { + assertThat(((DoubleBlock) result).getDouble(0), equalTo(43.0)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianIntGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianIntGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..73c7f62257b6b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MedianIntGroupingAggregatorFunctionTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.LongIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class MedianIntGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + + @Override + protected SourceOperator simpleInput(int end) { + int[][] samples = new int[][] { + { 12, 20, 20, 43, 60, 90, 125 }, + { 1, 15, 20, 30, 40, 75, 1000 }, + { 2, 20, 25, 175 }, + { 5, 30, 30, 30, 43 }, + { 7, 15, 30 } }; + List> values = new ArrayList<>(); + for (int i = 0; i < samples.length; i++) { + for (int v : samples[i]) { + values.add(Tuple.tuple((long) i, v)); + } + } + Randomness.shuffle(values); + return new LongIntBlockSourceOperator(values); + } + + @Override + protected GroupingAggregatorFunction.Factory aggregatorFunction() { + return GroupingAggregatorFunction.MEDIAN_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "median of ints"; + } + + @Override + protected void assertSimpleGroup(List input, Block result, int position, long group) { + int bucket = Math.toIntExact(group); + double[] expectedValues = new double[] { 43.0, 30, 22.5, 30, 15 }; + assertThat(bucket, allOf(greaterThanOrEqualTo(0), lessThanOrEqualTo(4))); + assertThat(((DoubleBlock) result).getDouble(position), equalTo(expectedValues[bucket])); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinIntAggregatorFunctionTests.java new file mode 100644 index 0000000000000..466e5094f9a4d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinIntAggregatorFunctionTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.List; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; + +public class MinIntAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(int size) { + return new SequenceIntBlockSourceOperator(IntStream.range(0, size).map(l -> randomInt())); + } + + @Override + protected AggregatorFunction.Factory aggregatorFunction() { + return AggregatorFunction.MIN_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "min of ints"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + int max = input.stream() + .flatMapToInt( + b -> IntStream.range(0, b.getTotalValueCount()).filter(p -> false == b.isNull(p)).map(p -> ((IntBlock) b).getInt(p)) + ) + .min() + .getAsInt(); + assertThat(((IntBlock) result).getInt(0), equalTo(max)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..44bd590d15de2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunctionTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.LongIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class MinIntGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + @Override + protected GroupingAggregatorFunction.Factory aggregatorFunction() { + return GroupingAggregatorFunction.MIN_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "min of ints"; + } + + @Override + protected SourceOperator simpleInput(int size) { + return new LongIntBlockSourceOperator(LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomInt()))); + } + + @Override + public void assertSimpleGroup(List input, Block result, int position, long group) { + int[] min = new int[] { Integer.MAX_VALUE }; + forEachGroupAndValue(input, (groups, groupOffset, values, valueOffset) -> { + if (groups.getLong(groupOffset) == group) { + min[0] = Math.min(min[0], ((IntBlock) values).getInt(valueOffset)); + } + }); + assertThat(((IntBlock) result).getInt(position), equalTo(min[0])); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java new file mode 100644 index 0000000000000..957abb5919054 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntAggregatorFunctionTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleArrayVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.CannedSourceOperator; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.PageConsumerOperator; +import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.List; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class SumIntAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(int size) { + int max = between(1, (int) Math.min(Integer.MAX_VALUE, Long.MAX_VALUE / size)); + return new SequenceIntBlockSourceOperator(LongStream.range(0, size).mapToInt(l -> between(-max, max))); + } + + @Override + protected AggregatorFunction.Factory aggregatorFunction() { + return AggregatorFunction.SUM_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "sum of ints"; + } + + @Override + protected void assertSimpleOutput(List input, Block result) { + long sum = input.stream() + .flatMapToLong( + b -> IntStream.range(0, b.getTotalValueCount()) + .filter(p -> false == b.isNull(p)) + .mapToLong(p -> (long) ((IntBlock) b).getInt(p)) + ) + .sum(); + assertThat(((LongBlock) result).getLong(0), equalTo(sum)); + } + + public void testRejectsDouble() { + try ( + Driver d = new Driver( + new CannedSourceOperator(Iterators.single(new Page(new DoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), + List.of(simple(nonBreakingBigArrays()).get()), + new PageConsumerOperator(page -> fail("shouldn't have made it this far")), + () -> {} + ) + ) { + expectThrows(Exception.class, d::run); // ### find a more specific exception type + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..116238db3ccdb --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunctionTests.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.LongIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class SumIntGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + @Override + protected GroupingAggregatorFunction.Factory aggregatorFunction() { + return GroupingAggregatorFunction.SUM_INTS; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "sum of ints"; + } + + @Override + protected SourceOperator simpleInput(int size) { + int max = between(1, (int) Math.min(Integer.MAX_VALUE, Long.MAX_VALUE / size)); + return new LongIntBlockSourceOperator( + LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), between(-max, max))) + ); + } + + @Override + protected void assertSimpleGroup(List input, Block result, int position, long group) { + long[] sum = new long[] { 0 }; + forEachGroupAndValue(input, (groups, groupOffset, values, valueOffset) -> { + if (groups.getLong(groupOffset) == group) { + sum[0] = Math.addExact(sum[0], (long) ((IntBlock) values).getInt(valueOffset)); + } + }); + assertThat(((LongBlock) result).getLong(position), equalTo(sum[0])); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LongIntBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LongIntBlockSourceOperator.java new file mode 100644 index 0000000000000..85ed36656675a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LongIntBlockSourceOperator.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Tuple; + +import java.util.List; +import java.util.stream.Stream; + +/** + * A source operator whose output is the given tuple values. This operator produces pages + * with two Blocks. The returned pages preserve the order of values as given in the in initial list. + */ +public class LongIntBlockSourceOperator extends AbstractBlockSourceOperator { + + private static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; + + private final List> values; + + public LongIntBlockSourceOperator(Stream> values) { + this(values, DEFAULT_MAX_PAGE_POSITIONS); + } + + public LongIntBlockSourceOperator(Stream> values, int maxPagePositions) { + super(maxPagePositions); + this.values = values.toList(); + } + + public LongIntBlockSourceOperator(List> values) { + this(values, DEFAULT_MAX_PAGE_POSITIONS); + } + + public LongIntBlockSourceOperator(List> values, int maxPagePositions) { + super(maxPagePositions); + this.values = values; + } + + @Override + protected Page createPage(int positionOffset, int length) { + var blockBuilder1 = LongBlock.newBlockBuilder(length); + var blockBuilder2 = IntBlock.newBlockBuilder(length); + for (int i = 0; i < length; i++) { + Tuple item = values.get(positionOffset + i); + if (item.v1() == null) { + blockBuilder1.appendNull(); + } else { + blockBuilder1.appendLong(item.v1()); + } + if (item.v2() == null) { + blockBuilder2.appendNull(); + } else { + blockBuilder2.appendInt(item.v2()); + } + } + currentPosition += length; + return new Page(blockBuilder1.build(), blockBuilder2.build()); + } + + @Override + protected int remaining() { + return values.size() - currentPosition; + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/NullInsertingSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/NullInsertingSourceOperator.java index bcd7d8aafba0d..fea688cad782b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/NullInsertingSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/NullInsertingSourceOperator.java @@ -10,6 +10,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; @@ -37,6 +38,9 @@ protected Page map(Page page) { case LONG: builders[b] = LongBlock.newBlockBuilder(page.getPositionCount()); break; + case INT: + builders[b] = IntBlock.newBlockBuilder(page.getPositionCount()); + break; case DOUBLE: builders[b] = DoubleBlock.newBlockBuilder(page.getPositionCount()); break; @@ -88,6 +92,9 @@ private void copyValue(Block from, int valueIndex, Block.Builder into) { case LONG: ((LongBlock.Builder) into).appendLong(((LongBlock) from).getLong(valueIndex)); break; + case INT: + ((IntBlock.Builder) into).appendInt(((IntBlock) from).getInt(valueIndex)); + break; case DOUBLE: ((DoubleBlock.Builder) into).appendDouble(((DoubleBlock) from).getDouble(valueIndex)); break; diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceIntBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceIntBlockSourceOperator.java new file mode 100644 index 0000000000000..7a28bca9052e2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceIntBlockSourceOperator.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * A source operator whose output is the given long values. This operator produces pages + * containing a single Block. The Block contains the long values from the given list, in order. + */ +public class SequenceIntBlockSourceOperator extends AbstractBlockSourceOperator { + + static final int DEFAULT_MAX_PAGE_POSITIONS = 8 * 1024; + + private final int[] values; + + public SequenceIntBlockSourceOperator(IntStream values) { + this(values, DEFAULT_MAX_PAGE_POSITIONS); + } + + public SequenceIntBlockSourceOperator(IntStream values, int maxPagePositions) { + super(maxPagePositions); + this.values = values.toArray(); + } + + public SequenceIntBlockSourceOperator(List values) { + this(values, DEFAULT_MAX_PAGE_POSITIONS); + } + + public SequenceIntBlockSourceOperator(List values, int maxPagePositions) { + super(maxPagePositions); + this.values = values.stream().mapToInt(Integer::intValue).toArray(); + } + + @Override + protected Page createPage(int positionOffset, int length) { + IntVector.Builder builder = IntVector.newVectorBuilder(length); + for (int i = 0; i < length; i++) { + builder.appendInt(values[positionOffset + i]); + } + currentPosition += length; + return new Page(builder.build().asBlock()); + } + + protected int remaining() { + return values.length - currentPosition; + } +} diff --git a/x-pack/plugin/esql/qa/server/src/main/resources/row.csv-spec b/x-pack/plugin/esql/qa/server/src/main/resources/row.csv-spec index d1ad48d8848b7..da33c2e5015cb 100644 --- a/x-pack/plugin/esql/qa/server/src/main/resources/row.csv-spec +++ b/x-pack/plugin/esql/qa/server/src/main/resources/row.csv-spec @@ -206,16 +206,16 @@ avg:double | min(x):integer | max(x):integer | count(x):long | avg(x):double | a rowWithMultipleStatsOverNull row x=1, y=2 | eval tot = null + y + x | stats c=count(tot), a=avg(tot), mi=min(tot), ma=max(tot), s=sum(tot); -c:long | a:double | mi:integer | ma:integer | s:long - 0 | NaN | 9223372036854775807 | -9223372036854775808 | 0 +c:long | a:double | mi:integer | ma:integer | s:long + 0 | NaN | 2147483647 | -2147483648 | 0 ; min row l=1, d=1.0, ln=1 + null, dn=1.0 + null | stats min(l), min(d), min(ln), min(dn); -min(l):integer | min(d):double | min(ln):integer | min(dn):double - 1 | 1.0 | 9223372036854775807 | Infinity +min(l):integer | min(d):double | min(ln):integer | min(dn):double + 1 | 1.0 | 2147483647 | Infinity ; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 8e98f3fc74b77..a582d31ca35e1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -10,6 +10,7 @@ import org.elasticsearch.compute.aggregation.AggregationName; import org.elasticsearch.compute.aggregation.AggregationType; import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.ql.type.DataTypes; import java.util.Locale; @@ -19,7 +20,17 @@ class AggregateMapper { static AggregationType mapToType(AggregateFunction aggregateFunction) { - return aggregateFunction.field().dataType().isRational() ? AggregationType.doubles : AggregationType.longs; + if (aggregateFunction.field().dataType() == DataTypes.LONG) { + return AggregationType.longs; + } + if (aggregateFunction.field().dataType() == DataTypes.INTEGER) { + return AggregationType.ints; + } + if (aggregateFunction.field().dataType() == DataTypes.DOUBLE) { + return AggregationType.doubles; + } + // agnostic here means "only works if the aggregation doesn't care about type". + return AggregationType.agnostic; } static AggregationName mapToName(AggregateFunction aggregateFunction) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index 1f89c702f415c..918a7afd2f849 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -12,9 +12,7 @@ import org.elasticsearch.compute.aggregation.BlockHash; import org.elasticsearch.compute.aggregation.GroupingAggregator; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.Operator; @@ -142,7 +140,7 @@ private class TestFieldExtractOperator implements Operator { @Override public void addInput(Page page) { - Block block = maybeConvertToLongBlock(extractBlockForColumn(page, columnName)); + Block block = extractBlockForColumn(page, columnName); lastPage = page.appendBlock(block); } @@ -256,22 +254,6 @@ public String describe() { } } - private Block maybeConvertToLongBlock(Block block) { - int positionCount = block.getPositionCount(); - if (block.elementType() == ElementType.INT) { - LongBlock.Builder builder = LongBlock.newBlockBuilder(positionCount); - for (int i = 0; i < positionCount; i++) { - if (block.isNull(i)) { - builder.appendNull(); - } else { - builder.appendLong(((IntBlock) block).getInt(i)); - } - } - return builder.build(); - } - return block; - } - private Block extractBlockForColumn(Page page, String columnName) { var columnIndex = -1; var i = 0; diff --git a/x-pack/plugin/esql/src/test/resources/project.csv-spec b/x-pack/plugin/esql/src/test/resources/project.csv-spec index 63e83f38a631d..7a1aa79ebb000 100644 --- a/x-pack/plugin/esql/src/test/resources/project.csv-spec +++ b/x-pack/plugin/esql/src/test/resources/project.csv-spec @@ -2,7 +2,7 @@ projectFrom from test | project languages, emp_no, first_name, last_name | limit 10; -languages:long | emp_no:long | first_name:keyword | last_name:keyword +languages:integer | emp_no:integer | first_name:keyword | last_name:keyword 2 | 10001 | Georgi | Facello 5 | 10002 | Bezalel | Simmel 4 | 10003 | Parto | Bamford @@ -18,7 +18,7 @@ languages:long | emp_no:long | first_name:keyword | last_name:keyword projectFromWithFilter from test | project languages, emp_no, first_name, last_name | eval x = emp_no + 10 | where x > 10040 and x < 10050 | limit 5; -languages:long | emp_no:long | first_name:keyword | last_name:keyword | x:integer +languages:integer | emp_no:integer | first_name:keyword | last_name:keyword | x:integer 4 | 10031 | null | Joslin | 10041 3 | 10032 | null | Reistad | 10042 1 | 10033 | null | Merlo | 10043 @@ -158,7 +158,7 @@ med:double | languages:long multiConditionalWhere from test | eval abc = 1+2 | where (abc + emp_no > 10100 or languages == 1) or (abc + emp_no < 10005 and gender == "F") | project emp_no, languages, gender, first_name, abc; -emp_no:long | languages:long | gender:keyword | first_name:keyword | abc:integer +emp_no:integer | languages:integer | gender:keyword | first_name:keyword | abc:integer 10005 | 1 | M | Kyoichi | 3 10009 | 1 | F | Sumant | 3 10013 | 1 | null | Eberhardt | 3 @@ -182,7 +182,7 @@ emp_no:long | languages:long | gender:keyword | first_name:keyword | abc:integer projectFromWithFilterPushedToES from test | project languages, emp_no, first_name, last_name, x = emp_no | where emp_no > 10030 and x < 10040 | limit 5; -languages:long | emp_no:long | first_name:keyword | last_name:keyword | x:long +languages:integer | emp_no:integer | first_name:keyword | last_name:keyword | x:integer 4 | 10031 | null | Joslin | 10031 3 | 10032 | null | Reistad | 10032 1 | 10033 | null | Merlo | 10033 @@ -223,7 +223,7 @@ emp_no:long | languages:long | first_name:keyword | last_name:keyword sortWithLimitOne from test | sort languages | limit 1; -avg_worked_seconds:long | emp_no:long | first_name:keyword | gender:keyword | height:double | languages:long | languages.long:long | last_name:keyword | salary:long | still_hired:keyword +avg_worked_seconds:long | emp_no:integer | first_name:keyword | gender:keyword | height:double | languages:integer | languages.long:long | last_name:keyword | salary:integer | still_hired:keyword 244294991 | 10005 | Kyoichi | M | 2.05 | 1 | 1 | Maliniak | 63528 | true ; @@ -252,7 +252,7 @@ height:double | languages.long:long | still_hired:keyword simpleEvalWithSortAndLimitOne from test | eval x = languages + 7 | sort x | limit 1; -avg_worked_seconds:long | emp_no:long | first_name:keyword | gender:keyword | height:double | languages:long | languages.long:long | last_name:keyword | salary:long | still_hired:keyword | x:integer +avg_worked_seconds:long | emp_no:integer | first_name:keyword | gender:keyword | height:double | languages:integer | languages.long:long | last_name:keyword | salary:integer | still_hired:keyword | x:integer 244294991 | 10005 | Kyoichi | M | 2.05 | 1 | 1 | Maliniak | 63528 | true | 8 ; @@ -273,7 +273,7 @@ avg(ratio):double simpleWhere from test | where salary > 70000 | project first_name, last_name, salary; -first_name:keyword | last_name:keyword | salary:long +first_name:keyword | last_name:keyword | salary:integer Tzvetan | Zielinski | 74572 Lillian | Haddadi | 73717 Divier | Reistad | 73851 @@ -287,7 +287,7 @@ Valter | Sullins | 73578 whereAfterProject from test | project salary | where salary > 70000; -salary:long +salary:integer 74572 73717 73851 @@ -301,7 +301,7 @@ salary:long whereWithEvalGeneratedValue from test | eval x = salary / 2 | where x > 37000; -avg_worked_seconds:long | emp_no:long | first_name:keyword | gender:keyword | height:double | languages:long | languages.long:long | last_name:keyword | salary:long | still_hired:keyword | x:integer +avg_worked_seconds:long | emp_no:integer | first_name:keyword | gender:keyword | height:double | languages:integer | languages.long:long | last_name:keyword | salary:integer | still_hired:keyword | x:integer 393084805 | 10007 | Tzvetan | F | 1.7 | 4 | 4 | Zielinski | 74572 | true | 37286 257694181 | 10029 | Otmar | M | 1.99 | null | null | Herbst | 74999 | false | 37499 371418933 | 10045 | Moss | M | 1.7 | 3 | 3 | Shanbhogue | 74970 | false | 37485 @@ -380,7 +380,7 @@ count(height):long | h1:double whereNegatedCondition from test | eval abc=1+2 | where abc + languages > 4 and languages.long != 1 | eval x=abc+languages | project x, languages, languages.long | limit 3; -x:integer | languages:long | languages.long:long +x:integer | languages:integer | languages.long:long 5 | 2 | 2 8 | 5 | 5 7 | 4 | 4 @@ -400,27 +400,25 @@ languages.long:long | last_name:keyword | languages:integer projectRename from test | project x = languages, y = languages | limit 3; -x:long | y:long +x:integer | y:integer 2 | 2 5 | 5 4 | 4 ; projectRenameEval -// TODO why are x2 and y2 ints if x and y are longs? And why are x and y longs? from test | project x = languages, y = languages | eval x2 = x + 1 | eval y2 = y + 2 | limit 3; -x:long | y:long | x2:integer | y2:integer +x:integer | y:integer | x2:integer | y2:integer 2 | 2 | 3 | 4 5 | 5 | 6 | 7 4 | 4 | 5 | 6 ; projectRenameEvalProject -// x and y should be integers but they are longs from test | project x = languages, y = languages | eval z = x + y | project x, y, z | limit 3; -x:long | y:long | z:integer +x:integer | y:integer | z:integer 2 | 2 | 4 5 | 5 | 10 4 | 4 | 8 @@ -429,7 +427,7 @@ x:long | y:long | z:integer projectOverride from test | project languages, first_name = languages | limit 3; -languages:long | first_name:long +languages:integer | first_name:integer 2 | 2 5 | 5 4 | 4 @@ -438,7 +436,7 @@ languages:long | first_name:long evalWithNull from test | eval nullsum = salary + null | sort nullsum asc, salary desc | project nullsum, salary | limit 1; -nullsum:integer | salary:long +nullsum:integer | salary:integer null | 74999 ; @@ -467,21 +465,21 @@ Bezalel projectAfterTopN from test | sort salary | limit 1 | project first_name, salary; -first_name:keyword | salary:long +first_name:keyword | salary:integer Guoxiang | 25324 ; projectAfterTopNDesc from test | sort salary desc | limit 1 | project first_name, salary; -first_name:keyword | salary:long +first_name:keyword | salary:integer Otmar | 74999 ; topNProjectEval from test | sort salary | limit 1 | project languages, salary | eval x = languages + 1; -languages:long | salary:long | x:integer +languages:integer | salary:integer | x:integer 5 | 25324 | 6 ; diff --git a/x-pack/plugin/esql/src/test/resources/stats.csv-spec b/x-pack/plugin/esql/src/test/resources/stats.csv-spec new file mode 100644 index 0000000000000..5534277e3c846 --- /dev/null +++ b/x-pack/plugin/esql/src/test/resources/stats.csv-spec @@ -0,0 +1,62 @@ +maxOfLong +from test | stats l = max(languages.long); + +l:long +5 +; + +maxOfInteger +from test | stats l = max(languages); + +l:integer +5 +; + +maxOfDouble +from test | stats h = max(height); + +h:double +2.1 +; + +avgOfLong +from test | stats l = avg(languages.long); + +l:double +3.1222222222222222 +; + +avgOfInteger +from test | stats l = avg(languages); + +l:double +3.1222222222222222 +; + +avgOfDouble +from test | stats h = avg(height); + +h:double +1.7682 +; + +sumOfLong +from test | stats l = sum(languages.long); + +l:long +281 +; + +sumOfInteger +from test | stats l = sum(languages); + +l:long +281 +; + +sumOfDouble +from test | stats h = sum(height); + +h:double +176.82 +;