Skip to content

Commit

Permalink
Implement native int aggregations (ESQL-701)
Browse files Browse the repository at this point in the history
This implements "native" `int` flavored aggregations for all existing
aggregations. A few of them are just "promoted" to another type before
being passed along to the aggregation infrastructure for another type.
But that seems fine. Here are the types:
```
`avg(int)`    -> `sum(long) / count` -> `double`
`count(int)`  -> `long`
`max(int)`    -> `int`
`median(int)` -> `median(double)`    -> `double`
`median_absolute_deviation(int)`     -> `mad(double)` -> double
`min(int)`    -> `int`
`sum(int)`    -> `sum(long)`         -> `long`
```

This also removes the "funny" cast in the CSV tests which promotes
`int`s to `long`s because it was confusing and got in the way of testing
the `int` flavored versions of these methods.
  • Loading branch information
nik9000 authored Feb 2, 2023
1 parent 80f5fdc commit d929acd
Show file tree
Hide file tree
Showing 65 changed files with 3,010 additions and 230 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,4 @@ testfixtures_shared/

# Generated
checkstyle_ide.xml
x-pack/plugin/esql/gen/
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR;
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR_BUILDER;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.DOUBLE_ARRAY_VECTOR;
import static org.elasticsearch.compute.gen.Types.DOUBLE_BLOCK;
import static org.elasticsearch.compute.gen.Types.DOUBLE_VECTOR;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
import static org.elasticsearch.compute.gen.Types.INT_BLOCK;
import static org.elasticsearch.compute.gen.Types.LONG_ARRAY_VECTOR;
import static org.elasticsearch.compute.gen.Types.INT_VECTOR;
import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
import static org.elasticsearch.compute.gen.Types.LONG_VECTOR;
import static org.elasticsearch.compute.gen.Types.PAGE;
Expand Down Expand Up @@ -91,30 +90,39 @@ private TypeName choseStateType() {
return ClassName.get("org.elasticsearch.compute.aggregation", firstUpper(initReturn.toString()) + "State");
}

private String primitiveType() {
String initReturn = declarationType.toString().toLowerCase(Locale.ROOT);
if (initReturn.contains("double")) {
return "double";
} else if (initReturn.contains("long")) {
return "long";
} else {
throw new IllegalArgumentException("unknown primitive type for " + initReturn);
static String primitiveType(ExecutableElement init, ExecutableElement combine) {
if (combine != null) {
// If there's an explicit combine function it's final parameter is the type of the value.
return combine.getParameters().get(combine.getParameters().size() - 1).asType().toString();
}
String initReturn = init.getReturnType().toString();
switch (initReturn) {
case "double":
return "double";
case "long":
return "long";
case "int":
return "int";
default:
throw new IllegalArgumentException("unknown primitive type for " + initReturn);
}
}

private ClassName valueBlockType() {
return switch (primitiveType()) {
static ClassName valueBlockType(ExecutableElement init, ExecutableElement combine) {
return switch (primitiveType(init, combine)) {
case "double" -> DOUBLE_BLOCK;
case "long" -> LONG_BLOCK;
default -> throw new IllegalArgumentException("unknown block type for " + primitiveType());
case "int" -> INT_BLOCK;
default -> throw new IllegalArgumentException("unknown block type for " + primitiveType(init, combine));
};
}

private ClassName valueVectorType() {
return switch (primitiveType()) {
static ClassName valueVectorType(ExecutableElement init, ExecutableElement combine) {
return switch (primitiveType(init, combine)) {
case "double" -> DOUBLE_VECTOR;
case "long" -> LONG_VECTOR;
default -> throw new IllegalArgumentException("unknown vector type for " + primitiveType());
case "int" -> INT_VECTOR;
default -> throw new IllegalArgumentException("unknown vector type for " + primitiveType(init, combine));
};
}

Expand Down Expand Up @@ -187,23 +195,16 @@ private MethodSpec addRawInput() {
builder.addStatement("assert channel >= 0");
builder.addStatement("$T type = page.getBlock(channel).elementType()", ELEMENT_TYPE);
builder.beginControlFlow("if (type == $T.NULL)", ELEMENT_TYPE).addStatement("return").endControlFlow();
if (primitiveType().equals("double")) {
builder.addStatement("$T block = page.getBlock(channel)", valueBlockType());
} else { // long
builder.addStatement("$T block", valueBlockType());
builder.beginControlFlow("if (type == $T.INT)", ELEMENT_TYPE) // explicit cast, for now
.addStatement("block = page.<$T>getBlock(channel).asLongBlock()", INT_BLOCK);
builder.nextControlFlow("else").addStatement("block = page.getBlock(channel)").endControlFlow();
}
builder.addStatement("$T vector = block.asVector()", valueVectorType());
builder.addStatement("$T block = page.getBlock(channel)", valueBlockType(init, combine));
builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine));
builder.beginControlFlow("if (vector != null)").addStatement("addRawVector(vector)");
builder.nextControlFlow("else").addStatement("addRawBlock(block)").endControlFlow();
return builder.build();
}

private MethodSpec addRawVector() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawVector");
builder.addModifiers(Modifier.PRIVATE).addParameter(valueVectorType(), "vector");
builder.addModifiers(Modifier.PRIVATE).addParameter(valueVectorType(init, combine), "vector");
builder.beginControlFlow("for (int i = 0; i < vector.getPositionCount(); i++)");
{
combineRawInput(builder, "vector");
Expand All @@ -217,7 +218,7 @@ private MethodSpec addRawVector() {

private MethodSpec addRawBlock() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawBlock");
builder.addModifiers(Modifier.PRIVATE).addParameter(valueBlockType(), "block");
builder.addModifiers(Modifier.PRIVATE).addParameter(valueBlockType(init, combine), "block");
builder.beginControlFlow("for (int i = 0; i < block.getTotalValueCount(); i++)");
{
builder.beginControlFlow("if (block.isNull(i) == false)");
Expand Down Expand Up @@ -296,6 +297,8 @@ private void combineStates(MethodSpec.Builder builder) {

private String primitiveStateMethod() {
switch (stateType.toString()) {
case "org.elasticsearch.compute.aggregation.IntState":
return "intValue";
case "org.elasticsearch.compute.aggregation.LongState":
return "longValue";
case "org.elasticsearch.compute.aggregation.DoubleState":
Expand Down Expand Up @@ -339,11 +342,14 @@ private MethodSpec evaluateFinal() {

private void primitiveStateToResult(MethodSpec.Builder builder) {
switch (stateType.toString()) {
case "org.elasticsearch.compute.aggregation.IntState":
builder.addStatement("return $T.newConstantBlockWith(state.intValue(), 1)", INT_BLOCK);
return;
case "org.elasticsearch.compute.aggregation.LongState":
builder.addStatement("return new $T(new long[] { state.longValue() }, 1).asBlock()", LONG_ARRAY_VECTOR);
builder.addStatement("return $T.newConstantBlockWith(state.longValue(), 1)", LONG_BLOCK);
return;
case "org.elasticsearch.compute.aggregation.DoubleState":
builder.addStatement("return new $T(new double[] { state.doubleValue() }, 1).asBlock()", DOUBLE_ARRAY_VECTOR);
builder.addStatement("return $T.newConstantBlockWith(state.doubleValue(), 1)", DOUBLE_BLOCK);
return;
default:
throw new IllegalArgumentException("don't know how to convert state to result: " + stateType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
import javax.lang.model.element.TypeElement;
import javax.lang.model.util.Elements;

import static org.elasticsearch.compute.gen.AggregatorImplementer.valueBlockType;
import static org.elasticsearch.compute.gen.AggregatorImplementer.valueVectorType;
import static org.elasticsearch.compute.gen.Methods.findMethod;
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR;
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR_BUILDER;
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.DOUBLE_BLOCK;
import static org.elasticsearch.compute.gen.Types.DOUBLE_VECTOR;
import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_FUNCTION;
import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
import static org.elasticsearch.compute.gen.Types.LONG_VECTOR;
Expand Down Expand Up @@ -88,33 +88,6 @@ private TypeName choseStateType() {
return ClassName.get("org.elasticsearch.compute.aggregation", head + tail + "ArrayState");
}

private String primitiveType() {
String initReturn = declarationType.toString().toLowerCase(Locale.ROOT);
if (initReturn.contains("double")) {
return "double";
} else if (initReturn.contains("long")) {
return "long";
} else {
throw new IllegalArgumentException("unknown primitive type for " + initReturn);
}
}

private ClassName valueBlockType() {
return switch (primitiveType()) {
case "double" -> DOUBLE_BLOCK;
case "long" -> LONG_BLOCK;
default -> throw new IllegalArgumentException("unknown block type for " + primitiveType());
};
}

private ClassName valueVectorType() {
return switch (primitiveType()) {
case "double" -> DOUBLE_VECTOR;
case "long" -> LONG_VECTOR;
default -> throw new IllegalArgumentException("unknown vector type for " + primitiveType());
};
}

public JavaFile sourceFile() {
JavaFile.Builder builder = JavaFile.builder(implementation.packageName(), type());
builder.addFileComment("""
Expand Down Expand Up @@ -179,8 +152,8 @@ private MethodSpec addRawInputVector() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput");
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);
builder.addParameter(LONG_VECTOR, "groups").addParameter(PAGE, "page");
builder.addStatement("$T valuesBlock = page.getBlock(channel)", valueBlockType());
builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType());
builder.addStatement("$T valuesBlock = page.getBlock(channel)", valueBlockType(init, combine));
builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine));
builder.beginControlFlow("if (valuesVector != null)");
{
builder.addStatement("int positions = groups.getPositionCount()");
Expand All @@ -203,7 +176,7 @@ private MethodSpec addRawInputVector() {
private MethodSpec addRawInputWithBlockValues() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInputWithBlockValues");
builder.addModifiers(Modifier.PRIVATE);
builder.addParameter(LONG_VECTOR, "groups").addParameter(valueBlockType(), "valuesBlock");
builder.addParameter(LONG_VECTOR, "groups").addParameter(valueBlockType(init, combine), "valuesBlock");
builder.addStatement("int positions = groups.getPositionCount()");
builder.beginControlFlow("for (int position = 0; position < groups.getPositionCount(); position++)");
{
Expand All @@ -227,8 +200,8 @@ private MethodSpec addRawInputBlock() {
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);
builder.addParameter(LONG_BLOCK, "groups").addParameter(PAGE, "page");
builder.addStatement("assert channel >= 0");
builder.addStatement("$T valuesBlock = page.getBlock(channel)", valueBlockType());
builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType());
builder.addStatement("$T valuesBlock = page.getBlock(channel)", valueBlockType(init, combine));
builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine));
builder.addStatement("int positions = groups.getPositionCount()");
builder.beginControlFlow("if (valuesVector != null)");
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ public class Types {
static final ClassName AGGREGATOR_STATE_VECTOR = ClassName.get(DATA_PACKAGE, "AggregatorStateVector");
static final ClassName AGGREGATOR_STATE_VECTOR_BUILDER = ClassName.get(DATA_PACKAGE, "AggregatorStateVector", "Builder");

static final ClassName INT_VECTOR = ClassName.get(DATA_PACKAGE, "IntVector");
static final ClassName LONG_VECTOR = ClassName.get(DATA_PACKAGE, "LongVector");
static final ClassName LONG_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "LongArrayVector");
static final ClassName DOUBLE_VECTOR = ClassName.get(DATA_PACKAGE, "DoubleVector");
static final ClassName DOUBLE_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "DoubleArrayVector");

static final ClassName AGGREGATOR_FUNCTION = ClassName.get(AGGREGATION_PACKAGE, "AggregatorFunction");
static final ClassName GROUPING_AGGREGATOR_FUNCTION = ClassName.get(AGGREGATION_PACKAGE, "GroupingAggregatorFunction");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ public int getInt(int valueIndex) {
return block.getInt(mapPosition(valueIndex));
}

@Override
public LongBlock asLongBlock() {
return new FilterLongBlock(block.asLongBlock(), positions);
}

@Override
public ElementType elementType() {
return ElementType.INT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,6 @@ public ElementType elementType() {
return ElementType.INT;
}

@Override
public LongBlock asLongBlock() { // copy rather than view, for now
final int positions = getPositionCount();
long[] longValues = new long[positions];
for (int i = 0; i < positions; i++) {
longValues[i] = values[i];
}
return new LongArrayBlock(longValues, getPositionCount(), firstValueIndexes, nullsMask);
}

@Override
public boolean equals(Object obj) {
if (obj instanceof IntBlock that) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ public sealed interface IntBlock extends Block permits FilterIntBlock,IntArrayBl
@Override
IntBlock filter(int... positions);

LongBlock asLongBlock();

/**
* Compares the given object with this block for equality. Returns {@code true} if and only if the
* given object is a IntBlock, and both blocks are {@link #equals(IntBlock, IntBlock) equal}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,6 @@ public ElementType elementType() {
return vector.elementType();
}

public LongBlock asLongBlock() { // copy rather than view, for now
final int positions = getPositionCount();
long[] longValues = new long[positions];
for (int i = 0; i < positions; i++) {
longValues[i] = vector.getInt(i);
}
return new LongArrayVector(longValues, getPositionCount()).asBlock();
}

@Override
public IntBlock getRow(int position) {
return filter(position);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License
// 2.0; you may not use this file except in compliance with the Elastic License
// 2.0.
package org.elasticsearch.compute.aggregation;

import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import org.elasticsearch.compute.data.AggregatorStateVector;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;

/**
* {@link AggregatorFunction} implementation for {@link AvgIntAggregator}.
* This class is generated. Do not edit it.
*/
public final class AvgIntAggregatorFunction implements AggregatorFunction {
private final AvgLongAggregator.AvgState state;

private final int channel;

public AvgIntAggregatorFunction(int channel, AvgLongAggregator.AvgState state) {
this.channel = channel;
this.state = state;
}

public static AvgIntAggregatorFunction create(int channel) {
return new AvgIntAggregatorFunction(channel, AvgIntAggregator.initSingle());
}

@Override
public void addRawInput(Page page) {
assert channel >= 0;
ElementType type = page.getBlock(channel).elementType();
if (type == ElementType.NULL) {
return;
}
IntBlock block = page.getBlock(channel);
IntVector vector = block.asVector();
if (vector != null) {
addRawVector(vector);
} else {
addRawBlock(block);
}
}

private void addRawVector(IntVector vector) {
for (int i = 0; i < vector.getPositionCount(); i++) {
AvgIntAggregator.combine(state, vector.getInt(i));
}
AvgIntAggregator.combineValueCount(state, vector.getPositionCount());
}

private void addRawBlock(IntBlock block) {
for (int i = 0; i < block.getTotalValueCount(); i++) {
if (block.isNull(i) == false) {
AvgIntAggregator.combine(state, block.getInt(i));
}
}
AvgIntAggregator.combineValueCount(state, block.validPositionCount());
}

@Override
public void addIntermediateInput(Block block) {
assert channel == -1;
Vector vector = block.asVector();
if (vector == null || vector instanceof AggregatorStateVector == false) {
throw new RuntimeException("expected AggregatorStateBlock, got:" + block);
}
@SuppressWarnings("unchecked") AggregatorStateVector<AvgLongAggregator.AvgState> blobVector = (AggregatorStateVector<AvgLongAggregator.AvgState>) vector;
AvgLongAggregator.AvgState tmpState = new AvgLongAggregator.AvgState();
for (int i = 0; i < block.getPositionCount(); i++) {
blobVector.get(i, tmpState);
AvgIntAggregator.combineStates(state, tmpState);
}
}

@Override
public Block evaluateIntermediate() {
AggregatorStateVector.Builder<AggregatorStateVector<AvgLongAggregator.AvgState>, AvgLongAggregator.AvgState> builder =
AggregatorStateVector.builderOfAggregatorState(AvgLongAggregator.AvgState.class, state.getEstimatedSize());
builder.add(state);
return builder.build().asBlock();
}

@Override
public Block evaluateFinal() {
return AvgIntAggregator.evaluateFinal(state);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(getClass().getSimpleName()).append("[");
sb.append("channel=").append(channel);
sb.append("]");
return sb.toString();
}
}
Loading

0 comments on commit d929acd

Please sign in to comment.