diff --git a/.gitattributes b/.gitattributes index 6a8de5462ec3f..04881c92ede00 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,6 +4,8 @@ CHANGELOG.asciidoc merge=union # Windows build-tools-internal/src/test/resources/org/elasticsearch/gradle/internal/release/*.asciidoc text eol=lf +x-pack/plugin/esql/compute/src/main/generated/** linguist-generated=true +x-pack/plugin/esql/compute/src/main/generated-src/** linguist-generated=true x-pack/plugin/esql/src/main/antlr/*.tokens linguist-generated=true x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/*.interp linguist-generated=true x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseLexer*.java linguist-generated=true diff --git a/docs/changelog/109386.yaml b/docs/changelog/109386.yaml new file mode 100644 index 0000000000000..984ee96dde063 --- /dev/null +++ b/docs/changelog/109386.yaml @@ -0,0 +1,6 @@ +pr: 109386 +summary: "ESQL: `top_list` aggregation" +area: ES|QL +type: feature +issues: + - 109213 diff --git a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java index 7e92fc5c2734e..0216ea07e5c7c 100644 --- a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java +++ b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java @@ -12,6 +12,10 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +/** + * Annotates a class that implements an aggregation function with grouping. + * See {@link Aggregator} for more information. 
+ */ @Target(ElementType.TYPE) @Retention(RetentionPolicy.SOURCE) public @interface GroupingAggregator { diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 635a53d1ac98a..bc206ee1d78d6 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -36,10 +36,11 @@ spotless { } } -def prop(Type, type, TYPE, BYTES, Array, Hash) { +def prop(Type, type, Wrapper, TYPE, BYTES, Array, Hash) { return [ "Type" : Type, "type" : type, + "Wrapper": Wrapper, "TYPE" : TYPE, "BYTES" : BYTES, "Array" : Array, @@ -55,12 +56,13 @@ def prop(Type, type, TYPE, BYTES, Array, Hash) { } tasks.named('stringTemplates').configure { - var intProperties = prop("Int", "int", "INT", "Integer.BYTES", "IntArray", "LongHash") - var floatProperties = prop("Float", "float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash") - var longProperties = prop("Long", "long", "LONG", "Long.BYTES", "LongArray", "LongHash") - var doubleProperties = prop("Double", "double", "DOUBLE", "Double.BYTES", "DoubleArray", "LongHash") - var bytesRefProperties = prop("BytesRef", "BytesRef", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", "", "BytesRefHash") - var booleanProperties = prop("Boolean", "boolean", "BOOLEAN", "Byte.BYTES", "BitArray", "") + var intProperties = prop("Int", "int", "Integer", "INT", "Integer.BYTES", "IntArray", "LongHash") + var floatProperties = prop("Float", "float", "Float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash") + var longProperties = prop("Long", "long", "Long", "LONG", "Long.BYTES", "LongArray", "LongHash") + var doubleProperties = prop("Double", "double", "Double", "DOUBLE", "Double.BYTES", "DoubleArray", "LongHash") + var bytesRefProperties = prop("BytesRef", "BytesRef", "", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", "", "BytesRefHash") + var booleanProperties = prop("Boolean", "boolean", "Boolean", "BOOLEAN", "Byte.BYTES", 
"BitArray", "") + // primitive vectors File vectorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/X-Vector.java.st") template { @@ -500,6 +502,24 @@ tasks.named('stringTemplates').configure { it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java" } + + File topListAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st") + template { + it.properties = intProperties + it.inputFile = topListAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/TopListIntAggregator.java" + } + template { + it.properties = longProperties + it.inputFile = topListAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/TopListLongAggregator.java" + } + template { + it.properties = doubleProperties + it.inputFile = topListAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java" + } + File multivalueDedupeInputFile = file("src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st") template { it.properties = intProperties @@ -635,4 +655,21 @@ tasks.named('stringTemplates').configure { it.inputFile = resultBuilderInputFile it.outputFile = "org/elasticsearch/compute/operator/topn/ResultBuilderForFloat.java" } + + File bucketedSortInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st") + template { + it.properties = intProperties + it.inputFile = bucketedSortInputFile + it.outputFile = "org/elasticsearch/compute/data/sort/IntBucketedSort.java" + } + template { + it.properties = longProperties + it.inputFile = bucketedSortInputFile + it.outputFile = "org/elasticsearch/compute/data/sort/LongBucketedSort.java" + } + template { + it.properties = doubleProperties + it.inputFile = bucketedSortInputFile + it.outputFile = "org/elasticsearch/compute/data/sort/DoubleBucketedSort.java" + } } diff --git 
a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java new file mode 100644 index 0000000000000..941722b4424d3 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.DoubleBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +/** + * Aggregates the top N field values for double. 
+ */ +@Aggregator({ @IntermediateState(name = "topList", type = "DOUBLE_BLOCK") }) +@GroupingAggregator +class TopListDoubleAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, double v) { + state.add(v); + } + + public static void combineIntermediate(SingleState state, DoubleBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getDouble(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, double v) { + state.add(groupId, v); + } + + public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getDouble(i)); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory(), selected); + } + + public static class GroupingState implements Releasable { + private final DoubleBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new DoubleBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, double value) { + sort.collect(value, groupId); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements Releasable { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(double value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListIntAggregator.java new file mode 100644 index 0000000000000..dafbf1c2a3051 --- /dev/null +++ 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListIntAggregator.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.IntBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +/** + * Aggregates the top N field values for int. 
+ */ +@Aggregator({ @IntermediateState(name = "topList", type = "INT_BLOCK") }) +@GroupingAggregator +class TopListIntAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, int v) { + state.add(v); + } + + public static void combineIntermediate(SingleState state, IntBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getInt(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, int v) { + state.add(groupId, v); + } + + public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getInt(i)); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory(), selected); + } + + public static class GroupingState implements Releasable { + private final IntBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new IntBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, int value) { + sort.collect(value, groupId); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements Releasable { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(int value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListLongAggregator.java new file mode 100644 index 0000000000000..c0e7122a4be0b --- /dev/null +++ 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListLongAggregator.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.sort.LongBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +/** + * Aggregates the top N field values for long. 
+ */ +@Aggregator({ @IntermediateState(name = "topList", type = "LONG_BLOCK") }) +@GroupingAggregator +class TopListLongAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, long v) { + state.add(v); + } + + public static void combineIntermediate(SingleState state, LongBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getLong(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, long v) { + state.add(groupId, v); + } + + public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getLong(i)); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory(), selected); + } + + public static class GroupingState implements Releasable { + private final LongBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new LongBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, long value) { + sort.collect(value, groupId); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements Releasable { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(long value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java new file mode 100644 index 0000000000000..63318a2189908 --- /dev/null +++ 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java @@ -0,0 +1,346 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.Arrays; +import java.util.stream.IntStream; + +/** + * Aggregates the top N double values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class DoubleBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *

+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...). + * Then, for each bucket, it can be in 2 states: + *

+ * + */ + private DoubleArray values; + + public DoubleBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newDoubleArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *

+ * It may or may not be inserted in the heap, depending on if it is better than the current root. + *

+ */ + public void collect(double value, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex))) { + values.set(rootIndex, value); + downHeap(rootIndex, 0); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(requiredSize); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. 
+ */ + public void merge(int groupId, DoubleBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public Block toBlock(BlockFactory blockFactory, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + + return size > 0; + })) { + return blockFactory.newConstantNullBlock(selected.getPositionCount()); + } + + // Used to sort the values in the bucket. + var bucketValues = new double[bucketSize]; + + try (var builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendDouble(values.get(bounds.v1())); + continue; + } + + for (int i = 0; i < size; i++) { + bucketValues[i] = values.get(bounds.v1() + i); + } + + // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting + Arrays.sort(bucketValues, 0, (int) size); + + builder.beginPositionEntry(); + if (order == SortOrder.ASC) { + for (int i = 0; i < size; i++) { + builder.appendDouble(bucketValues[i]); + } + } else { + for (int i = (int) size - 1; i >= 0; i--) { + builder.appendDouble(bucketValues[i]); + } + } + builder.endPositionEntry(); + } + return builder.build(); + } + } + + /** + * Is this bucket a min heap {@code true} 
or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(double lhs, double rhs) { + return getOrder().reverseMul() * Double.compare(lhs, rhs) < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. + */ + private void grow(long minSize) { + long oldMax = values.size(); + values = bigArrays.grow(values, minSize); + // Set the next gather offsets for all newly allocated buckets. + setNextGatherOffsets(oldMax - (oldMax % getBucketSize())); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void setNextGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *

+ * This works by validating the heap property on each node, iterating + * "upwards", pushing any out of order parents "down". Check out the + * wikipedia + * entry on binary heaps for more about this. + *

+ *

+ * While this *looks* like it could easily be {@code O(n * log n)}, it is + * a fairly well studied algorithm attributed to Floyd. There's + * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst + * case. + *

+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex) { + int maxParent = bucketSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + */ + private void downHeap(long rootIndex, int parent) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < bucketSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java new file mode 100644 index 0000000000000..04a635d75fe52 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java @@ -0,0 +1,346 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.Arrays; +import java.util.stream.IntStream; + +/** + * Aggregates the top N int values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class IntBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *

+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...). + * Then, for each bucket, it can be in 2 states: + *

+ * + */ + private IntArray values; + + public IntBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newIntArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *

+ * It may or may not be inserted in the heap, depending on whether it is better than the current root. + *

+ */ + public void collect(int value, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex))) { + values.set(rootIndex, value); + downHeap(rootIndex, 0); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(requiredSize); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. 
+ */ + public void merge(int groupId, IntBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public Block toBlock(BlockFactory blockFactory, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + + return size > 0; + })) { + return blockFactory.newConstantNullBlock(selected.getPositionCount()); + } + + // Used to sort the values in the bucket. + var bucketValues = new int[bucketSize]; + + try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendInt(values.get(bounds.v1())); + continue; + } + + for (int i = 0; i < size; i++) { + bucketValues[i] = values.get(bounds.v1() + i); + } + + // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting + Arrays.sort(bucketValues, 0, (int) size); + + builder.beginPositionEntry(); + if (order == SortOrder.ASC) { + for (int i = 0; i < size; i++) { + builder.appendInt(bucketValues[i]); + } + } else { + for (int i = (int) size - 1; i >= 0; i--) { + builder.appendInt(bucketValues[i]); + } + } + builder.endPositionEntry(); + } + return builder.build(); + } + } + + /** + * Is this bucket a min heap {@code true} or in gathering 
mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(int lhs, int rhs) { + return getOrder().reverseMul() * Integer.compare(lhs, rhs) < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. + */ + private void grow(long minSize) { + long oldMax = values.size(); + values = bigArrays.grow(values, minSize); + // Set the next gather offsets for all newly allocated buckets. + setNextGatherOffsets(oldMax - (oldMax % getBucketSize())); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void setNextGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *

+ * This works by validating the heap property on each node, iterating + * "upwards", pushing any out of order parents "down". Check out the + * wikipedia + * entry on binary heaps for more about this. + *

+ *

+ * While this *looks* like it could easily be {@code O(n * log n)}, it is + * a fairly well studied algorithm attributed to Floyd. There's + * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst + * case. + *

+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex) { + int maxParent = bucketSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + */ + private void downHeap(long rootIndex, int parent) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < bucketSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java new file mode 100644 index 0000000000000..e08c25256944b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java @@ -0,0 +1,346 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.Arrays; +import java.util.stream.IntStream; + +/** + * Aggregates the top N long values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file. + */ +public class LongBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *

+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...). + * Then, for each bucket, it can be in 2 states: + *

+ * + */ + private LongArray values; + + public LongBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newLongArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *

+ * It may or may not be inserted in the heap, depending on whether it is better than the current root. + *

+ */ + public void collect(long value, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex))) { + values.set(rootIndex, value); + downHeap(rootIndex, 0); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(requiredSize); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. 
+ */ + public void merge(int groupId, LongBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public Block toBlock(BlockFactory blockFactory, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + + return size > 0; + })) { + return blockFactory.newConstantNullBlock(selected.getPositionCount()); + } + + // Used to sort the values in the bucket. + var bucketValues = new long[bucketSize]; + + try (var builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendLong(values.get(bounds.v1())); + continue; + } + + for (int i = 0; i < size; i++) { + bucketValues[i] = values.get(bounds.v1() + i); + } + + // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting + Arrays.sort(bucketValues, 0, (int) size); + + builder.beginPositionEntry(); + if (order == SortOrder.ASC) { + for (int i = 0; i < size; i++) { + builder.appendLong(bucketValues[i]); + } + } else { + for (int i = (int) size - 1; i >= 0; i--) { + builder.appendLong(bucketValues[i]); + } + } + builder.endPositionEntry(); + } + return builder.build(); + } + } + + /** + * Is this bucket a min heap {@code true} or in 
gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(long lhs, long rhs) { + return getOrder().reverseMul() * Long.compare(lhs, rhs) < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. + */ + private void grow(long minSize) { + long oldMax = values.size(); + values = bigArrays.grow(values, minSize); + // Set the next gather offsets for all newly allocated buckets. + setNextGatherOffsets(oldMax - (oldMax % getBucketSize())); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void setNextGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *

+ * This works by validating the heap property on each node, iterating + * "upwards", pushing any out of order parents "down". Check out the + * wikipedia + * entry on binary heaps for more about this. + *

+ *

+ * While this *looks* like it could easily be {@code O(n * log n)}, it is + * a fairly well studied algorithm attributed to Floyd. There's + * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst + * case. + *

+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex) { + int maxParent = bucketSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + */ + private void downHeap(long rootIndex, int parent) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < bucketSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..d52d25941780c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunction.java @@ -0,0 +1,126 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopListDoubleAggregator}. + * This class is generated. Do not edit it. + */ +public final class TopListDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("topList", ElementType.DOUBLE) ); + + private final DriverContext driverContext; + + private final TopListDoubleAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopListDoubleAggregatorFunction(DriverContext driverContext, List channels, + TopListDoubleAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopListDoubleAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopListDoubleAggregatorFunction(driverContext, channels, TopListDoubleAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + 
public void addRawInput(Page page) { + DoubleBlock block = page.getBlock(channels.get(0)); + DoubleVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(DoubleVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + TopListDoubleAggregator.combine(state, vector.getDouble(i)); + } + } + + private void addRawBlock(DoubleBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + TopListDoubleAggregator.combine(state, block.getDouble(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topListUncast = page.getBlock(channels.get(0)); + if (topListUncast.areAllValuesNull()) { + return; + } + DoubleBlock topList = (DoubleBlock) topListUncast; + assert topList.getPositionCount() == 1; + TopListDoubleAggregator.combineIntermediate(state, topList); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopListDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..48df091d339b6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,45 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopListDoubleAggregator}. + * This class is generated. Do not edit it. 
+ */ +public final class TopListDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopListDoubleAggregatorFunctionSupplier(List channels, int limit, + boolean ascending) { + this.channels = channels; + this.limit = limit; + this.ascending = ascending; + } + + @Override + public TopListDoubleAggregatorFunction aggregator(DriverContext driverContext) { + return TopListDoubleAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopListDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return TopListDoubleGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_list of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..0e3b98bb0f7e5 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleGroupingAggregatorFunction.java @@ -0,0 +1,202 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopListDoubleAggregator}. + * This class is generated. Do not edit it. + */ +public final class TopListDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("topList", ElementType.DOUBLE) ); + + private final TopListDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopListDoubleGroupingAggregatorFunction(List channels, + TopListDoubleAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopListDoubleGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopListDoubleGroupingAggregatorFunction(channels, TopListDoubleAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return 
INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock valuesBlock = page.getBlock(channels.get(0)); + DoubleVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopListDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + TopListDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + 
} + + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = Math.toIntExact(groups.getInt(g)); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopListDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = Math.toIntExact(groups.getInt(g)); + TopListDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topListUncast = page.getBlock(channels.get(0)); + if (topListUncast.areAllValuesNull()) { + return; + } + DoubleBlock topList = (DoubleBlock) topListUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + TopListDoubleAggregator.combineIntermediate(state, 
groupId, topList, groupPosition + positionOffset); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + TopListDoubleAggregator.GroupingState inState = ((TopListDoubleGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + TopListDoubleAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = TopListDoubleAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunction.java new file mode 100644 index 0000000000000..e885b285c4a51 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunction.java @@ -0,0 +1,126 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopListIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class TopListIntAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("topList", ElementType.INT) ); + + private final DriverContext driverContext; + + private final TopListIntAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopListIntAggregatorFunction(DriverContext driverContext, List channels, + TopListIntAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopListIntAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopListIntAggregatorFunction(driverContext, channels, TopListIntAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page) { + IntBlock block = page.getBlock(channels.get(0)); + IntVector vector = block.asVector(); + if (vector != null) { + 
addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + TopListIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawBlock(IntBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + TopListIntAggregator.combine(state, block.getInt(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topListUncast = page.getBlock(channels.get(0)); + if (topListUncast.areAllValuesNull()) { + return; + } + IntBlock topList = (IntBlock) topListUncast; + assert topList.getPositionCount() == 1; + TopListIntAggregator.combineIntermediate(state, topList); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopListIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..d8bf91ba85541 --- 
/dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionSupplier.java @@ -0,0 +1,45 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopListIntAggregator}. + * This class is generated. Do not edit it. + */ +public final class TopListIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopListIntAggregatorFunctionSupplier(List channels, int limit, + boolean ascending) { + this.channels = channels; + this.limit = limit; + this.ascending = ascending; + } + + @Override + public TopListIntAggregatorFunction aggregator(DriverContext driverContext) { + return TopListIntAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopListIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return TopListIntGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_list of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..820ebb95e530c --- /dev/null +++ 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntGroupingAggregatorFunction.java @@ -0,0 +1,200 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopListIntAggregator}. + * This class is generated. Do not edit it. 
+ */ +public final class TopListIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("topList", ElementType.INT) ); + + private final TopListIntAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopListIntGroupingAggregatorFunction(List channels, + TopListIntAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopListIntGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopListIntGroupingAggregatorFunction(channels, TopListIntAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock valuesBlock = page.getBlock(channels.get(0)); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void 
add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopListIntAggregator.combine(state, groupId, values.getInt(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + TopListIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = Math.toIntExact(groups.getInt(g)); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopListIntAggregator.combine(state, groupId, 
values.getInt(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = Math.toIntExact(groups.getInt(g)); + TopListIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topListUncast = page.getBlock(channels.get(0)); + if (topListUncast.areAllValuesNull()) { + return; + } + IntBlock topList = (IntBlock) topListUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + TopListIntAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + TopListIntAggregator.GroupingState inState = ((TopListIntGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + TopListIntAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext 
driverContext) { + blocks[offset] = TopListIntAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunction.java new file mode 100644 index 0000000000000..1a09a1a860e2f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunction.java @@ -0,0 +1,126 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopListLongAggregator}. + * This class is generated. Do not edit it. 
+ */ +public final class TopListLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("topList", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final TopListLongAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopListLongAggregatorFunction(DriverContext driverContext, List channels, + TopListLongAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopListLongAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopListLongAggregatorFunction(driverContext, channels, TopListLongAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page) { + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(LongVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + TopListLongAggregator.combine(state, vector.getLong(i)); + } + } + + private void addRawBlock(LongBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + TopListLongAggregator.combine(state, block.getLong(i)); + } + } + } + + @Override + public void 
addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topListUncast = page.getBlock(channels.get(0)); + if (topListUncast.areAllValuesNull()) { + return; + } + LongBlock topList = (LongBlock) topListUncast; + assert topList.getPositionCount() == 1; + TopListLongAggregator.combineIntermediate(state, topList); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopListLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..617895fbff1a3 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionSupplier.java @@ -0,0 +1,45 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopListLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class TopListLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopListLongAggregatorFunctionSupplier(List channels, int limit, + boolean ascending) { + this.channels = channels; + this.limit = limit; + this.ascending = ascending; + } + + @Override + public TopListLongAggregatorFunction aggregator(DriverContext driverContext) { + return TopListLongAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopListLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return TopListLongGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_list of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..cadb48b7d29d4 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongGroupingAggregatorFunction.java @@ -0,0 +1,202 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.lang.StringBuilder;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link GroupingAggregatorFunction} implementation for {@link TopListLongAggregator}.
 * This class is generated. Do not edit it.
 */
public final class TopListLongGroupingAggregatorFunction implements GroupingAggregatorFunction {
  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
      new IntermediateStateDesc("topList", ElementType.LONG)  );

  private final TopListLongAggregator.GroupingState state;

  private final List<Integer> channels;

  private final DriverContext driverContext;

  private final int limit;

  private final boolean ascending;

  public TopListLongGroupingAggregatorFunction(List<Integer> channels,
      TopListLongAggregator.GroupingState state, DriverContext driverContext, int limit,
      boolean ascending) {
    this.channels = channels;
    this.state = state;
    this.driverContext = driverContext;
    this.limit = limit;
    this.ascending = ascending;
  }

  public static TopListLongGroupingAggregatorFunction create(List<Integer> channels,
      DriverContext driverContext, int limit, boolean ascending) {
    return new TopListLongGroupingAggregatorFunction(channels, TopListLongAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
  }

  public static List<IntermediateStateDesc> intermediateStateDesc() {
    return INTERMEDIATE_STATE_DESC;
  }

  @Override
  public int intermediateBlockCount() {
    return INTERMEDIATE_STATE_DESC.size();
  }

  @Override
  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
      Page page) {
    LongBlock valuesBlock = page.getBlock(channels.get(0));
    LongVector valuesVector = valuesBlock.asVector();
    if (valuesVector == null) {
      // The block may contain nulls; track seen group ids so empty groups emit null.
      if (valuesBlock.mayHaveNulls()) {
        state.enableGroupIdTracking(seenGroupIds);
      }
      return new GroupingAggregatorFunction.AddInput() {
        @Override
        public void add(int positionOffset, IntBlock groupIds) {
          addRawInput(positionOffset, groupIds, valuesBlock);
        }

        @Override
        public void add(int positionOffset, IntVector groupIds) {
          addRawInput(positionOffset, groupIds, valuesBlock);
        }
      };
    }
    return new GroupingAggregatorFunction.AddInput() {
      @Override
      public void add(int positionOffset, IntBlock groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
      }

      @Override
      public void add(int positionOffset, IntVector groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
      }
    };
  }

  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      if (values.isNull(groupPosition + positionOffset)) {
        continue;
      }
      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
      for (int v = valuesStart; v < valuesEnd; v++) {
        TopListLongAggregator.combine(state, groupId, values.getLong(v));
      }
    }
  }

  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      TopListLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
    }
  }

  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
      }
      int groupStart = groups.getFirstValueIndex(groupPosition);
      int groupEnd = groupStart + groups.getValueCount(groupPosition);
      for (int g = groupStart; g < groupEnd; g++) {
        int groupId = Math.toIntExact(groups.getInt(g));
        if (values.isNull(groupPosition + positionOffset)) {
          continue;
        }
        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
        for (int v = valuesStart; v < valuesEnd; v++) {
          TopListLongAggregator.combine(state, groupId, values.getLong(v));
        }
      }
    }
  }

  private void addRawInput(int positionOffset, IntBlock groups, LongVector values) {
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
      }
      int groupStart = groups.getFirstValueIndex(groupPosition);
      int groupEnd = groupStart + groups.getValueCount(groupPosition);
      for (int g = groupStart; g < groupEnd; g++) {
        int groupId = Math.toIntExact(groups.getInt(g));
        TopListLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
      }
    }
  }

  @Override
  public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
    state.enableGroupIdTracking(new SeenGroupIds.Empty());
    assert channels.size() == intermediateBlockCount();
    Block topListUncast = page.getBlock(channels.get(0));
    if (topListUncast.areAllValuesNull()) {
      return;
    }
    LongBlock topList = (LongBlock) topListUncast;
    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      int groupId = Math.toIntExact(groups.getInt(groupPosition));
      TopListLongAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
    }
  }

  @Override
  public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
    if (input.getClass() != getClass()) {
      throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
    }
    TopListLongAggregator.GroupingState inState = ((TopListLongGroupingAggregatorFunction) input).state;
    state.enableGroupIdTracking(new SeenGroupIds.Empty());
    TopListLongAggregator.combineStates(state, groupId, inState, position);
  }

  @Override
  public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
    state.toIntermediate(blocks, offset, selected, driverContext);
  }

  @Override
  public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
      DriverContext driverContext) {
    blocks[offset] = TopListLongAggregator.evaluateFinal(state, selected, driverContext);
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(getClass().getSimpleName()).append("[");
    sb.append("channels=").append(channels);
    sb.append("]");
    return sb.toString();
  }

  @Override
  public void close() {
    state.close();
  }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.GroupingAggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
$if(!long)$
import org.elasticsearch.compute.data.$Type$Block;
$endif$
import org.elasticsearch.compute.data.IntVector;
$if(long)$
import org.elasticsearch.compute.data.$Type$Block;
$endif$
import org.elasticsearch.compute.data.sort.$Type$BucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Aggregates the top N field values for $type$.
 */
@Aggregator({ @IntermediateState(name = "topList", type = "$TYPE$_BLOCK") })
@GroupingAggregator
class TopList$Type$Aggregator {
    public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
        return new SingleState(bigArrays, limit, ascending);
    }

    public static void combine(SingleState state, $type$ v) {
        state.add(v);
    }

    public static void combineIntermediate(SingleState state, $Type$Block values) {
        int start = values.getFirstValueIndex(0);
        int end = start + values.getValueCount(0);
        for (int i = start; i < end; i++) {
            combine(state, values.get$Type$(i));
        }
    }

    public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory());
    }

    public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
        return new GroupingState(bigArrays, limit, ascending);
    }

    public static void combine(GroupingState state, int groupId, $type$ v) {
        state.add(groupId, v);
    }

    public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
        int start = values.getFirstValueIndex(valuesPosition);
        int end = start + values.getValueCount(valuesPosition);
        for (int i = start; i < end; i++) {
            combine(state, groupId, values.get$Type$(i));
        }
    }

    public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
        current.merge(groupId, state, statePosition);
    }

    public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
        return state.toBlock(driverContext.blockFactory(), selected);
    }

    /**
     * Grouping aggregation state: one top-N sort bucket per group id.
     */
    public static class GroupingState implements Releasable {
        private final $Type$BucketedSort sort;

        private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
            this.sort = new $Type$BucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
        }

        public void add(int groupId, $type$ value) {
            sort.collect(value, groupId);
        }

        public void merge(int groupId, GroupingState other, int otherGroupId) {
            sort.merge(groupId, other.sort, otherGroupId);
        }

        void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory(), selected);
        }

        Block toBlock(BlockFactory blockFactory, IntVector selected) {
            return sort.toBlock(blockFactory, selected);
        }

        void enableGroupIdTracking(SeenGroupIds seen) {
            // we figure out seen values from nulls on the values block
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(sort);
        }
    }

    /**
     * Non-grouping aggregation state, implemented as a grouping state that
     * always collects into bucket 0.
     */
    public static class SingleState implements Releasable {
        private final GroupingState internalState;

        private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
            this.internalState = new GroupingState(bigArrays, limit, ascending);
        }

        public void add($type$ value) {
            internalState.add(0, value);
        }

        public void merge(GroupingState other) {
            internalState.merge(0, other, 0);
        }

        void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
            blocks[offset] = toBlock(driverContext.blockFactory());
        }

        Block toBlock(BlockFactory blockFactory) {
            // Select only bucket 0, the single group's bucket.
            try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
                return internalState.toBlock(blockFactory, intValues);
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(internalState);
        }
    }
}
+1,350 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.$Type$Array; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.Arrays; +import java.util.stream.IntStream; + +/** + * Aggregates the top N $type$ values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class $Type$BucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *

+ * <p>
+ *     For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ *     Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ *     <li>
+ *         Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ *         In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ *         The lowest index contains the offset to the next slot to be filled.
+ *         <p>
+ *             This allows us to insert elements in O(1) time.
+ *         </p>
+ *         <p>
+ *             When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ *         </p>
+ *     </li>
+ *     <li>
+ *         Heap mode: The bucket slots are organized as a min heap structure.
+ *         <p>
+ *             The root of the heap is the minimum value in the bucket,
+ *             which allows us to quickly discard new values that are not in the top N.
+ *         </p>
+ *     </li>
+ * </ul>
+ */ + private $Type$Array values; + + public $Type$BucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.new$Type$Array(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *

+ * <p>
+ *     It may or may not be inserted in the heap, depending on if it is better than the current root.
+ * </p>
+ */ + public void collect($type$ value, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex))) { + values.set(rootIndex, value); + downHeap(rootIndex, 0); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(requiredSize); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. 
+ */ + public void merge(int groupId, $Type$BucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public Block toBlock(BlockFactory blockFactory, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + + return size > 0; + })) { + return blockFactory.newConstantNullBlock(selected.getPositionCount()); + } + + // Used to sort the values in the bucket. + var bucketValues = new $type$[bucketSize]; + + try (var builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + continue; + } + + if (size == 1) { + builder.append$Type$(values.get(bounds.v1())); + continue; + } + + for (int i = 0; i < size; i++) { + bucketValues[i] = values.get(bounds.v1() + i); + } + + // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting + Arrays.sort(bucketValues, 0, (int) size); + + builder.beginPositionEntry(); + if (order == SortOrder.ASC) { + for (int i = 0; i < size; i++) { + builder.append$Type$(bucketValues[i]); + } + } else { + for (int i = (int) size - 1; i >= 0; i--) { + builder.append$Type$(bucketValues[i]); + } + } + builder.endPositionEntry(); + } + return builder.build(); + } + } + + /** + * Is this bucket a min heap {@code true} 
or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { +$if(int)$ + return values.get(rootIndex); +$else$ + return (int) values.get(rootIndex); +$endif$ + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan($type$ lhs, $type$ rhs) { + return getOrder().reverseMul() * $Wrapper$.compare(lhs, rhs) < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. + */ + private void grow(long minSize) { + long oldMax = values.size(); + values = bigArrays.grow(values, minSize); + // Set the next gather offsets for all newly allocated buckets. + setNextGatherOffsets(oldMax - (oldMax % getBucketSize())); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void setNextGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *

+ * <p>
+ *     This works by validating the heap property on each node, iterating
+ *     "upwards", pushing any out of order parents "down". Check out the
+ *     <a href="https://en.wikipedia.org/wiki/Binary_heap">wikipedia</a>
+ *     entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ *     While this *looks* like it could easily be {@code O(n * log n)}, it is
+ *     a fairly well studied algorithm attributed to Floyd. There's
+ *     been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ *     case.
+ * </p>
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex) { + int maxParent = bucketSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + */ + private void downHeap(long rootIndex, int parent) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < bucketSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionTests.java new file mode 100644 index 0000000000000..f708038776032 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionTests.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;

import java.util.List;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.contains;

public class TopListDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase {
    private static final int LIMIT = 100;

    @Override
    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
        return new SequenceDoubleBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToDouble(l -> randomDouble()));
    }

    @Override
    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
        return new TopListDoubleAggregatorFunctionSupplier(inputChannels, LIMIT, true);
    }

    @Override
    protected String expectedDescriptionOfAggregator() {
        return "top_list of doubles";
    }

    @Override
    public void assertSimpleOutput(List<Block> input, Block result) {
        // Ascending top list: the smallest LIMIT values, in order.
        Object[] values = input.stream().flatMapToDouble(b -> allDoubles(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
        assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;

import java.util.List;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.contains;

public class TopListIntAggregatorFunctionTests extends AggregatorFunctionTestCase {
    private static final int LIMIT = 100;

    @Override
    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
        return new SequenceIntBlockSourceOperator(blockFactory, IntStream.range(0, size).map(l -> randomInt()));
    }

    @Override
    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
        return new TopListIntAggregatorFunctionSupplier(inputChannels, LIMIT, true);
    }

    @Override
    protected String expectedDescriptionOfAggregator() {
        return "top_list of ints";
    }

    @Override
    public void assertSimpleOutput(List<Block> input, Block result) {
        // Ascending top list: the smallest LIMIT values, in order.
        Object[] values = input.stream().flatMapToInt(b -> allInts(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
        assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;

import java.util.List;
import java.util.stream.LongStream;

import static org.hamcrest.Matchers.contains;

public class TopListLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
    private static final int LIMIT = 100;

    @Override
    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
        return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLong()));
    }

    @Override
    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
        return new TopListLongAggregatorFunctionSupplier(inputChannels, LIMIT, true);
    }

    @Override
    protected String expectedDescriptionOfAggregator() {
        return "top_list of longs";
    }

    @Override
    public void assertSimpleOutput(List<Block> input, Block result) {
        // Ascending top list: the smallest LIMIT values, in order.
        Object[] values = input.stream().flatMapToLong(b -> allLongs(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
        assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
    }
}
0000000000000..9e1bc145ad4ca --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/BucketedSortTestCase.java @@ -0,0 +1,368 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.TestBlockFactory; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public abstract class BucketedSortTestCase extends ESTestCase { + /** + * Build a {@link T} to test. Sorts built by this method shouldn't need scores. + */ + protected abstract T build(SortOrder sortOrder, int bucketSize); + + /** + * Build the expected correctly typed value for a value. + */ + protected abstract Object expectedValue(double v); + + /** + * A random value for testing, with the appropriate precision for the type we're testing. + */ + protected abstract double randomValue(); + + /** + * Collect a value into the sort. + * @param value value to collect, always sent as double just to have + * a number to test. 
Subclasses should cast to their favorite types + */ + protected abstract void collect(T sort, double value, int bucket); + + protected abstract void merge(T sort, int groupId, T other, int otherGroupId); + + protected abstract Block toBlock(T sort, BlockFactory blockFactory, IntVector selected); + + protected abstract void assertBlockTypeAndValues(Block block, Object... values); + + public final void testNeverCalled() { + SortOrder order = randomFrom(SortOrder.values()); + try (T sort = build(order, 1)) { + assertBlock(sort, randomNonNegativeInt()); + } + } + + public final void testSingleDoc() { + try (T sort = build(randomFrom(SortOrder.values()), 1)) { + collect(sort, 1, 0); + + assertBlock(sort, 0, expectedValue(1)); + } + } + + public final void testNonCompetitive() { + try (T sort = build(SortOrder.DESC, 1)) { + collect(sort, 2, 0); + collect(sort, 1, 0); + + assertBlock(sort, 0, expectedValue(2)); + } + } + + public final void testCompetitive() { + try (T sort = build(SortOrder.DESC, 1)) { + collect(sort, 1, 0); + collect(sort, 2, 0); + + assertBlock(sort, 0, expectedValue(2)); + } + } + + public final void testNegativeValue() { + try (T sort = build(SortOrder.DESC, 1)) { + collect(sort, -1, 0); + assertBlock(sort, 0, expectedValue(-1)); + } + } + + public final void testSomeBuckets() { + try (T sort = build(SortOrder.DESC, 1)) { + collect(sort, 2, 0); + collect(sort, 2, 1); + collect(sort, 2, 2); + collect(sort, 3, 0); + + assertBlock(sort, 0, expectedValue(3)); + assertBlock(sort, 1, expectedValue(2)); + assertBlock(sort, 2, expectedValue(2)); + assertBlock(sort, 3); + } + } + + public final void testBucketGaps() { + try (T sort = build(SortOrder.DESC, 1)) { + collect(sort, 2, 0); + collect(sort, 2, 2); + + assertBlock(sort, 0, expectedValue(2)); + assertBlock(sort, 1); + assertBlock(sort, 2, expectedValue(2)); + assertBlock(sort, 3); + } + } + + public final void testBucketsOutOfOrder() { + try (T sort = build(SortOrder.DESC, 1)) { + collect(sort, 2, 1); 
+ collect(sort, 2, 0); + + assertBlock(sort, 0, expectedValue(2.0)); + assertBlock(sort, 1, expectedValue(2.0)); + assertBlock(sort, 2); + } + } + + public final void testManyBuckets() { + // Collect the buckets in random order + int[] buckets = new int[10000]; + for (int b = 0; b < buckets.length; b++) { + buckets[b] = b; + } + Collections.shuffle(Arrays.asList(buckets), random()); + + double[] maxes = new double[buckets.length]; + + try (T sort = build(SortOrder.DESC, 1)) { + for (int b : buckets) { + maxes[b] = 2; + collect(sort, 2, b); + if (randomBoolean()) { + maxes[b] = 3; + collect(sort, 3, b); + } + if (randomBoolean()) { + collect(sort, -1, b); + } + } + for (int b = 0; b < buckets.length; b++) { + assertBlock(sort, b, expectedValue(maxes[b])); + } + assertBlock(sort, buckets.length); + } + } + + public final void testTwoHitsDesc() { + try (T sort = build(SortOrder.DESC, 2)) { + collect(sort, 1, 0); + collect(sort, 2, 0); + collect(sort, 3, 0); + + assertBlock(sort, 0, expectedValue(3), expectedValue(2)); + } + } + + public final void testTwoHitsAsc() { + try (T sort = build(SortOrder.ASC, 2)) { + collect(sort, 1, 0); + collect(sort, 2, 0); + collect(sort, 3, 0); + + assertBlock(sort, 0, expectedValue(1), expectedValue(2)); + } + } + + public final void testTwoHitsTwoBucket() { + try (T sort = build(SortOrder.DESC, 2)) { + collect(sort, 1, 0); + collect(sort, 1, 1); + collect(sort, 2, 0); + collect(sort, 2, 1); + collect(sort, 3, 0); + collect(sort, 3, 1); + collect(sort, 4, 1); + + assertBlock(sort, 0, expectedValue(3), expectedValue(2)); + assertBlock(sort, 1, expectedValue(4), expectedValue(3)); + } + } + + public final void testManyBucketsManyHits() { + // Set the values in random order + double[] values = new double[10000]; + for (int v = 0; v < values.length; v++) { + values[v] = randomValue(); + } + Collections.shuffle(Arrays.asList(values), random()); + + int buckets = between(2, 100); + int bucketSize = between(2, 100); + try (T sort = 
build(SortOrder.DESC, bucketSize)) { + BitArray[] bucketUsed = new BitArray[buckets]; + Arrays.setAll(bucketUsed, i -> new BitArray(values.length, bigArrays())); + for (int doc = 0; doc < values.length; doc++) { + for (int bucket = 0; bucket < buckets; bucket++) { + if (randomBoolean()) { + bucketUsed[bucket].set(doc); + collect(sort, values[doc], bucket); + } + } + } + for (int bucket = 0; bucket < buckets; bucket++) { + List bucketValues = new ArrayList<>(values.length); + for (int doc = 0; doc < values.length; doc++) { + if (bucketUsed[bucket].get(doc)) { + bucketValues.add(values[doc]); + } + } + bucketUsed[bucket].close(); + assertBlock( + sort, + bucket, + bucketValues.stream().sorted((lhs, rhs) -> rhs.compareTo(lhs)).limit(bucketSize).map(this::expectedValue).toArray() + ); + } + assertBlock(sort, buckets); + } + } + + public final void testMergeHeapToHeap() { + try (T sort = build(SortOrder.ASC, 3)) { + collect(sort, 1, 0); + collect(sort, 2, 0); + collect(sort, 3, 0); + + try (T other = build(SortOrder.ASC, 3)) { + collect(other, 1, 0); + collect(other, 2, 0); + collect(other, 3, 0); + + merge(sort, 0, other, 0); + } + + assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2)); + } + } + + public final void testMergeNoHeapToNoHeap() { + try (T sort = build(SortOrder.ASC, 3)) { + collect(sort, 1, 0); + collect(sort, 2, 0); + + try (T other = build(SortOrder.ASC, 3)) { + collect(other, 1, 0); + collect(other, 2, 0); + + merge(sort, 0, other, 0); + } + + assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2)); + } + } + + public final void testMergeHeapToNoHeap() { + try (T sort = build(SortOrder.ASC, 3)) { + collect(sort, 1, 0); + collect(sort, 2, 0); + + try (T other = build(SortOrder.ASC, 3)) { + collect(other, 1, 0); + collect(other, 2, 0); + collect(other, 3, 0); + + merge(sort, 0, other, 0); + } + + assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2)); + } + } + + public final void 
testMergeNoHeapToHeap() { + try (T sort = build(SortOrder.ASC, 3)) { + collect(sort, 1, 0); + collect(sort, 2, 0); + collect(sort, 3, 0); + + try (T other = build(SortOrder.ASC, 3)) { + collect(other, 1, 0); + collect(other, 2, 0); + + merge(sort, 0, other, 0); + } + + assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2)); + } + } + + public final void testMergeHeapToEmpty() { + try (T sort = build(SortOrder.ASC, 3)) { + try (T other = build(SortOrder.ASC, 3)) { + collect(other, 1, 0); + collect(other, 2, 0); + collect(other, 3, 0); + + merge(sort, 0, other, 0); + } + + assertBlock(sort, 0, expectedValue(1), expectedValue(2), expectedValue(3)); + } + } + + public final void testMergeEmptyToHeap() { + try (T sort = build(SortOrder.ASC, 3)) { + collect(sort, 1, 0); + collect(sort, 2, 0); + collect(sort, 3, 0); + + try (T other = build(SortOrder.ASC, 3)) { + merge(sort, 0, other, 0); + } + + assertBlock(sort, 0, expectedValue(1), expectedValue(2), expectedValue(3)); + } + } + + public final void testMergeEmptyToEmpty() { + try (T sort = build(SortOrder.ASC, 3)) { + try (T other = build(SortOrder.ASC, 3)) { + merge(sort, 0, other, randomNonNegativeInt()); + } + + assertBlock(sort, 0); + } + } + + private void assertBlock(T sort, int groupId, Object...
values) { + var blockFactory = TestBlockFactory.getNonBreakingInstance(); + + try (var intVector = blockFactory.newConstantIntVector(groupId, 1)) { + var block = toBlock(sort, blockFactory, intVector); + + assertThat(block.getPositionCount(), equalTo(1)); + assertThat(block.getTotalValueCount(), equalTo(values.length)); + + if (values.length == 0) { + assertThat(block.elementType(), equalTo(ElementType.NULL)); + assertThat(block.isNull(0), equalTo(true)); + } else { + assertBlockTypeAndValues(block, values); + } + } + } + + protected final BigArrays bigArrays() { + return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/DoubleBucketedSortTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/DoubleBucketedSortTests.java new file mode 100644 index 0000000000000..43b5caa092b9a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/DoubleBucketedSortTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.search.sort.SortOrder; + +import static org.hamcrest.Matchers.equalTo; + +public class DoubleBucketedSortTests extends BucketedSortTestCase { + @Override + protected DoubleBucketedSort build(SortOrder sortOrder, int bucketSize) { + return new DoubleBucketedSort(bigArrays(), sortOrder, bucketSize); + } + + @Override + protected Object expectedValue(double v) { + return v; + } + + @Override + protected double randomValue() { + return randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true); + } + + @Override + protected void collect(DoubleBucketedSort sort, double value, int bucket) { + sort.collect(value, bucket); + } + + @Override + protected void merge(DoubleBucketedSort sort, int groupId, DoubleBucketedSort other, int otherGroupId) { + sort.merge(groupId, other, otherGroupId); + } + + @Override + protected Block toBlock(DoubleBucketedSort sort, BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + protected void assertBlockTypeAndValues(Block block, Object... 
values) { + assertThat(block.elementType(), equalTo(ElementType.DOUBLE)); + var typedBlock = (DoubleBlock) block; + for (int i = 0; i < values.length; i++) { + assertThat(typedBlock.getDouble(i), equalTo(values[i])); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/IntBucketedSortTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/IntBucketedSortTests.java new file mode 100644 index 0000000000000..70d0a79ea7473 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/IntBucketedSortTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.search.sort.SortOrder; + +import static org.hamcrest.Matchers.equalTo; + +public class IntBucketedSortTests extends BucketedSortTestCase { + @Override + protected IntBucketedSort build(SortOrder sortOrder, int bucketSize) { + return new IntBucketedSort(bigArrays(), sortOrder, bucketSize); + } + + @Override + protected Object expectedValue(double v) { + return (int) v; + } + + @Override + protected double randomValue() { + return randomInt(); + } + + @Override + protected void collect(IntBucketedSort sort, double value, int bucket) { + sort.collect((int) value, bucket); + } + + @Override + protected void merge(IntBucketedSort sort, int groupId, IntBucketedSort other, int otherGroupId) { + sort.merge(groupId, other, otherGroupId); + } + + @Override + protected 
Block toBlock(IntBucketedSort sort, BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + protected void assertBlockTypeAndValues(Block block, Object... values) { + assertThat(block.elementType(), equalTo(ElementType.INT)); + var typedBlock = (IntBlock) block; + for (int i = 0; i < values.length; i++) { + assertThat(typedBlock.getInt(i), equalTo(values[i])); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/LongBucketedSortTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/LongBucketedSortTests.java new file mode 100644 index 0000000000000..bceed3b1d95b5 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/LongBucketedSortTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.search.sort.SortOrder; + +import static org.hamcrest.Matchers.equalTo; + +public class LongBucketedSortTests extends BucketedSortTestCase { + @Override + protected LongBucketedSort build(SortOrder sortOrder, int bucketSize) { + return new LongBucketedSort(bigArrays(), sortOrder, bucketSize); + } + + @Override + protected Object expectedValue(double v) { + return (long) v; + } + + @Override + protected double randomValue() { + // 2L^50 fits in the mantisa of a double which the test sort of needs. 
+ return randomLongBetween(-(1L << 50), 1L << 50); + } + + @Override + protected void collect(LongBucketedSort sort, double value, int bucket) { + sort.collect((long) value, bucket); + } + + @Override + protected void merge(LongBucketedSort sort, int groupId, LongBucketedSort other, int otherGroupId) { + sort.merge(groupId, other, otherGroupId); + } + + @Override + protected Block toBlock(LongBucketedSort sort, BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + protected void assertBlockTypeAndValues(Block block, Object... values) { + assertThat(block.elementType(), equalTo(ElementType.LONG)); + var typedBlock = (LongBlock) block; + for (int i = 0; i < values.length; i++) { + assertThat(typedBlock.getLong(i), equalTo(values[i])); + } + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec index 2cdd5c1dfd931..0fb35b4253d6d 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec @@ -38,10 +38,10 @@ double e() "double log(?base:integer|unsigned_long|long|double, number:integer|unsigned_long|long|double)" "double log10(number:double|integer|long|unsigned_long)" "keyword|text ltrim(string:keyword|text)" -"double|integer|long max(number:double|integer|long)" +"double|integer|long|date max(number:double|integer|long|date)" "double|integer|long median(number:double|integer|long)" "double|integer|long median_absolute_deviation(number:double|integer|long)" -"double|integer|long min(number:double|integer|long)" +"double|integer|long|date min(number:double|integer|long|date)" "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version mv_append(field1:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version,
field2:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version)" "double mv_avg(number:double|integer|long|unsigned_long)" "keyword mv_concat(string:text|keyword, delim:text|keyword)" @@ -109,6 +109,7 @@ double tau() "keyword|text to_upper(str:keyword|text)" "version to_ver(field:keyword|text|version)" "version to_version(field:keyword|text|version)" +"double|integer|long|date top_list(field:double|integer|long|date, limit:integer, order:keyword)" "keyword|text trim(string:keyword|text)" "boolean|date|double|integer|ip|keyword|long|text|version values(field:boolean|date|double|integer|ip|keyword|long|text|version)" ; @@ -155,10 +156,10 @@ locate |[string, substring, start] |["keyword|text", "keyword|te log |[base, number] |["integer|unsigned_long|long|double", "integer|unsigned_long|long|double"] |["Base of logarithm. If `null`\, the function returns `null`. If not provided\, this function returns the natural logarithm (base e) of a value.", "Numeric expression. If `null`\, the function returns `null`."] log10 |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`. ltrim |string |"keyword|text" |String expression. If `null`, the function returns `null`. -max |number |"double|integer|long" |[""] +max |number |"double|integer|long|date" |[""] median |number |"double|integer|long" |[""] median_absolut|number |"double|integer|long" |[""] -min |number |"double|integer|long" |[""] +min |number |"double|integer|long|date" |[""] mv_append |[field1, field2] |["boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version", "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version"] | ["", ""] mv_avg |number |"double|integer|long|unsigned_long" |Multivalue expression. mv_concat |[string, delim] |["text|keyword", "text|keyword"] |[Multivalue expression., Delimiter.] 
@@ -226,6 +227,7 @@ to_unsigned_lo|field |"boolean|date|keyword|text|d to_upper |str |"keyword|text" |String expression. If `null`, the function returns `null`. to_ver |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression. to_version |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression. +top_list |[field, limit, order] |["double|integer|long|date", integer, keyword] |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.] trim |string |"keyword|text" |String expression. If `null`, the function returns `null`. values |field |"boolean|date|double|integer|ip|keyword|long|text|version" |[""] ; @@ -344,6 +346,7 @@ to_unsigned_lo|Converts an input value to an unsigned long value. If the input p to_upper |Returns a new string representing the input string converted to upper case. to_ver |Converts an input string to a version value. to_version |Converts an input string to a version value. +top_list |Collects the top values for a field. Includes repeated values. trim |Removes leading and trailing whitespaces from a string. values |Collect values for a field. 
; @@ -392,10 +395,10 @@ locate |integer log |double |[true, false] |false |false log10 |double |false |false |false ltrim |"keyword|text" |false |false |false -max |"double|integer|long" |false |false |true +max |"double|integer|long|date" |false |false |true median |"double|integer|long" |false |false |true median_absolut|"double|integer|long" |false |false |true -min |"double|integer|long" |false |false |true +min |"double|integer|long|date" |false |false |true mv_append |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version" |[false, false] |false |false mv_avg |double |false |false |false mv_concat |keyword |[false, false] |false |false @@ -463,6 +466,7 @@ to_unsigned_lo|unsigned_long to_upper |"keyword|text" |false |false |false to_ver |version |false |false |false to_version |version |false |false |false +top_list |"double|integer|long|date" |[false, false, false] |false |true trim |"keyword|text" |false |false |false values |"boolean|date|double|integer|ip|keyword|long|text|version" |false |false |true ; @@ -483,5 +487,5 @@ countFunctions#[skip:-8.14.99, reason:BIN added] meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c; a:long | b:long | c:long -109 | 109 | 109 +110 | 110 | 110 ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_list.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_list.csv-spec new file mode 100644 index 0000000000000..c24f6a7e70954 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_list.csv-spec @@ -0,0 +1,156 @@ +topList +required_capability: agg_top_list +// tag::top-list[] +FROM employees +| STATS top_salaries = TOP_LIST(salary, 3, "desc"), top_salary = MAX(salary) +// end::top-list[] +; + +// tag::top-list-result[] +top_salaries:integer | top_salary:integer +[74999, 74970, 74572] | 74999 +// end::top-list-result[] +; + +topListAllTypesAsc +required_capability: 
agg_top_list +FROM employees +| STATS + date = TOP_LIST(hire_date, 2, "asc"), + double = TOP_LIST(salary_change, 2, "asc"), + integer = TOP_LIST(salary, 2, "asc"), + long = TOP_LIST(salary_change.long, 2, "asc") +; + +date:date | double:double | integer:integer | long:long +[1985-02-18T00:00:00.000Z,1985-02-24T00:00:00.000Z] | [-9.81,-9.28] | [25324,25945] | [-9,-9] +; + +topListAllTypesDesc +required_capability: agg_top_list +FROM employees +| STATS + date = TOP_LIST(hire_date, 2, "desc"), + double = TOP_LIST(salary_change, 2, "desc"), + integer = TOP_LIST(salary, 2, "desc"), + long = TOP_LIST(salary_change.long, 2, "desc") +; + +date:date | double:double | integer:integer | long:long +[1999-04-30T00:00:00.000Z,1997-05-19T00:00:00.000Z] | [14.74,14.68] | [74999,74970] | [14,14] +; + +topListAllTypesRow +required_capability: agg_top_list +ROW + constant_date=TO_DATETIME("1985-02-18T00:00:00.000Z"), + constant_double=-9.81, + constant_integer=25324, + constant_long=TO_LONG(-9) +| STATS + date = TOP_LIST(constant_date, 2, "asc"), + double = TOP_LIST(constant_double, 2, "asc"), + integer = TOP_LIST(constant_integer, 2, "asc"), + long = TOP_LIST(constant_long, 2, "asc") +| keep date, double, integer, long +; + +date:date | double:double | integer:integer | long:long +1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9 +; + +topListSomeBuckets +required_capability: agg_top_list +FROM employees +| STATS top_salary = TOP_LIST(salary, 2, "desc") by still_hired +| sort still_hired asc +; + +top_salary:integer | still_hired:boolean +[74999,74970] | false +[74572,73578] | true +; + +topListManyBuckets +required_capability: agg_top_list +FROM employees +| STATS top_salary = TOP_LIST(salary, 2, "desc") by x=emp_no, y=emp_no+1 +| sort x asc +| limit 3 +; + +top_salary:integer | x:integer | y:integer +57305 | 10001 | 10002 +56371 | 10002 | 10003 +61805 | 10003 | 10004 +; + +topListMultipleStats +required_capability: agg_top_list +FROM employees +| STATS top_salary = TOP_LIST(salary, 1, 
"desc") by emp_no +| STATS top_salary = TOP_LIST(top_salary, 3, "asc") +; + +top_salary:integer +[25324,25945,25976] +; + +topListAllTypesMin +required_capability: agg_top_list +FROM employees +| STATS + date = TOP_LIST(hire_date, 1, "asc"), + double = TOP_LIST(salary_change, 1, "asc"), + integer = TOP_LIST(salary, 1, "asc"), + long = TOP_LIST(salary_change.long, 1, "asc") +; + +date:date | double:double | integer:integer | long:long +1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9 +; + +topListAllTypesMax +required_capability: agg_top_list +FROM employees +| STATS + date = TOP_LIST(hire_date, 1, "desc"), + double = TOP_LIST(salary_change, 1, "desc"), + integer = TOP_LIST(salary, 1, "desc"), + long = TOP_LIST(salary_change.long, 1, "desc") +; + +date:date | double:double | integer:integer | long:long +1999-04-30T00:00:00.000Z | 14.74 | 74999 | 14 +; + +topListAscDesc +required_capability: agg_top_list +FROM employees +| STATS top_asc = TOP_LIST(salary, 3, "asc"), top_desc = TOP_LIST(salary, 3, "desc") +; + +top_asc:integer | top_desc:integer +[25324, 25945, 25976] | [74999, 74970, 74572] +; + +topListEmpty +required_capability: agg_top_list +FROM employees +| WHERE salary < 0 +| STATS top = TOP_LIST(salary, 3, "asc") +; + +top:integer +null +; + +topListDuplicates +required_capability: agg_top_list +FROM employees +| STATS integer = TOP_LIST(languages, 2, "desc") +; + +integer:integer +[5, 5] +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 3eef9f7356b39..e65f574422dd5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -42,6 +42,11 @@ public class EsqlCapabilities { */ private static final String FN_SUBSTRING_EMPTY_NULL = "fn_substring_empty_null"; + /** + * 
Support for aggregation function {@code TOP_LIST}. + */ + private static final String AGG_TOP_LIST = "agg_top_list"; + /** * Optimization for ST_CENTROID changed some results in cartesian data. #108713 */ @@ -84,6 +89,7 @@ private static Set capabilities() { caps.add(FN_CBRT); caps.add(FN_IP_PREFIX); caps.add(FN_SUBSTRING_EMPTY_NULL); + caps.add(AGG_TOP_LIST); caps.add(ST_CENTROID_AGG_OPTIMIZED); caps.add(METADATA_IGNORED_FIELD); caps.add(FN_MV_APPEND); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 8fd6ebe8d7d69..7034f23be1662 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; @@ -192,6 +193,7 @@ private FunctionDefinition[][] functions() { def(Min.class, Min::new, "min"), def(Percentile.class, Percentile::new, "percentile"), def(Sum.class, Sum::new, "sum"), + def(TopList.class, TopList::new, "top_list"), def(Values.class, Values::new, "values") }, // math new FunctionDefinition[] { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java index 3f6632f66bcee..1c1139c197ac0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java @@ -24,8 +24,12 @@ public class Max extends NumericAggregate implements SurrogateExpression { - @FunctionInfo(returnType = { "double", "integer", "long" }, description = "The maximum value of a numeric field.", isAggregation = true) - public Max(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { + @FunctionInfo( + returnType = { "double", "integer", "long", "date" }, + description = "The maximum value of a numeric field.", + isAggregation = true + ) + public Max(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) { super(source, field); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java index 16821752bc7b8..ecfc2200a3643 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java @@ -24,8 +24,12 @@ public class Min extends NumericAggregate implements SurrogateExpression { - @FunctionInfo(returnType = { "double", "integer", "long" }, description = "The minimum value of a numeric field.", isAggregation = true) - public Min(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { + @FunctionInfo( + returnType = { "double", "integer", "long", "date" }, + description = "The minimum value of a numeric field.", + isAggregation = true + ) + public Min(Source 
source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) { super(source, field); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java index b003b981c0709..390cd0d68018e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java @@ -19,6 +19,28 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +/** + * Aggregate function that receives a numeric, signed field, and returns a single double value. + *

+ * <p>
+ *     Implement the supplier methods to return the correct {@link AggregatorFunctionSupplier}.
+ * </p>
+ * <p>
+ *     Some methods can be optionally overridden to support different variations:
+ * </p>
+ * <ul>
+ *     <li>
+ *         {@link #supportsDates}: override to also support dates. Defaults to false.
+ *     </li>
+ *     <li>
+ *         {@link #resolveType}: override to support different parameters.
+ *         Call {@code super.resolveType()} to add extra checks.
+ *     </li>
+ *     <li>
+ *         {@link #dataType}: override to return a different datatype.
+ *         You can return {@code field().dataType()} to propagate the parameter type.
+ *     </li>
+ * </ul>
+ */ public abstract class NumericAggregate extends AggregateFunction implements ToAggregator { NumericAggregate(Source source, Expression field, List parameters) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java new file mode 100644 index 0000000000000..79893b1c7de07 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java @@ -0,0 +1,181 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopListDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopListIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopListLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.SurrogateExpression; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import 
org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import org.elasticsearch.xpack.esql.planner.ToAggregator; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; + +public class TopList extends AggregateFunction implements ToAggregator, SurrogateExpression { + private static final String ORDER_ASC = "ASC"; + private static final String ORDER_DESC = "DESC"; + + @FunctionInfo( + returnType = { "double", "integer", "long", "date" }, + description = "Collects the top values for a field. Includes repeated values.", + isAggregation = true, + examples = @Example(file = "stats_top_list", tag = "top-list") + ) + public TopList( + Source source, + @Param( + name = "field", + type = { "double", "integer", "long", "date" }, + description = "The field to collect the top values for." + ) Expression field, + @Param(name = "limit", type = { "integer" }, description = "The maximum number of values to collect.") Expression limit, + @Param( + name = "order", + type = { "keyword" }, + description = "The order to calculate the top values. Either `asc` or `desc`." 
+ ) Expression order + ) { + super(source, field, Arrays.asList(limit, order)); + } + + public static TopList readFrom(PlanStreamInput in) throws IOException { + return new TopList(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readExpression()); + } + + public void writeTo(PlanStreamOutput out) throws IOException { + source().writeTo(out); + List fields = children(); + assert fields.size() == 3; + out.writeExpression(fields.get(0)); + out.writeExpression(fields.get(1)); + out.writeExpression(fields.get(2)); + } + + private Expression limitField() { + return parameters().get(0); + } + + private Expression orderField() { + return parameters().get(1); + } + + private int limitValue() { + return (int) limitField().fold(); + } + + private String orderRawValue() { + return BytesRefs.toString(orderField().fold()); + } + + private boolean orderValue() { + return orderRawValue().equalsIgnoreCase(ORDER_ASC); + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + var typeResolution = isType( + field(), + dt -> dt == DataType.DATETIME || dt.isNumeric() && dt != DataType.UNSIGNED_LONG, + sourceText(), + FIRST, + "numeric except unsigned_long or counter types" + ).and(isFoldable(limitField(), sourceText(), SECOND)) + .and(isType(limitField(), dt -> dt == DataType.INTEGER, sourceText(), SECOND, "integer")) + .and(isFoldable(orderField(), sourceText(), THIRD)) + .and(isString(orderField(), sourceText(), THIRD)); + + if (typeResolution.unresolved()) { + return typeResolution; + } + + var limit = limitValue(); + var order = orderRawValue(); + + if (limit <= 0) { + return new TypeResolution(format(null, "Limit must be greater than 0 in [{}], found [{}]", sourceText(), limit)); + } + + if (order.equalsIgnoreCase(ORDER_ASC) == false && order.equalsIgnoreCase(ORDER_DESC) == false) { + return new TypeResolution( + format(null, "Invalid order value in [{}], expected 
[{}, {}] but got [{}]", sourceText(), ORDER_ASC, ORDER_DESC, order) + ); + } + + return TypeResolution.TYPE_RESOLVED; + } + + @Override + public DataType dataType() { + return field().dataType(); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, TopList::new, children().get(0), children().get(1), children().get(2)); + } + + @Override + public TopList replaceChildren(List newChildren) { + return new TopList(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); + } + + @Override + public AggregatorFunctionSupplier supplier(List inputChannels) { + DataType type = field().dataType(); + if (type == DataType.LONG || type == DataType.DATETIME) { + return new TopListLongAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue()); + } + if (type == DataType.INTEGER) { + return new TopListIntAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue()); + } + if (type == DataType.DOUBLE) { + return new TopListDoubleAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue()); + } + throw EsqlIllegalArgumentException.illegalDataType(type); + } + + @Override + public Expression surrogate() { + var s = source(); + + if (limitValue() == 1) { + if (orderValue()) { + return new Min(s, field()); + } else { + return new Max(s, field()); + } + } + + return null; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java new file mode 100644 index 0000000000000..a99c7a8b7ac8d --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java @@ -0,0 +1,176 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +/** + * Functions that aggregate values, with or without grouping within buckets. + * Used in `STATS` and similar commands. + * + *

Guide to adding new aggregate function

+ *
    + *
  1. + * Aggregation functions are more complex than scalar functions, so it's a good idea to discuss + * the new function with the ESQL team before starting to implement it. + *

    + * You may also discuss its implementation, as aggregations may require special performance considerations. + *

    + *
  2. + *
  3. + * To learn the basics about making functions, check {@link org.elasticsearch.xpack.esql.expression.function.scalar}. + *

    + * It has the guide to making a simple function, which should be a good base to start doing aggregations. + *

    + *
  4. + *
  5. + * Pick one of the csv-spec files in {@code x-pack/plugin/esql/qa/testFixtures/src/main/resources/} + * and add a test for the function you want to write. These files are roughly themed but there + * isn't a strong guiding principle in the organization. + *
  6. + *
  7. + * Rerun the {@code CsvTests} and watch your new test fail. + *
  8. + *
  9. + * Find an aggregate function in this package similar to the one you are working on and copy it to build + * yours. + * Your function might extend from the available abstract classes. Check the javadoc of each before using them: + *
      + *
    • + * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction}: The base class for aggregates + *
    • + *
    • + * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate}: Aggregation for numeric values + *
    • + *
    • + * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction}: + * Aggregation for spatial values + *
    • + *
    + *
  10. + *
  11. + * Fill the required methods in your new function. Check their JavaDoc for more information. + * Here are some of the important ones: + *
      + *
    • + * Constructor: Review the constructor annotations, and make sure to add the correct types and descriptions. + *
        + *
      • {@link org.elasticsearch.xpack.esql.expression.function.FunctionInfo}, for the constructor itself
      • + *
      • {@link org.elasticsearch.xpack.esql.expression.function.Param}, for the function parameters
      • + *
      + *
    • + *
    • + * {@code resolveType}: Check the metadata of your function parameters. + * This may include types, whether they are foldable or not, or their possible values. + *
    • + *
    • + * {@code dataType}: This will return the datatype of your function. + * May be based on its current parameters. + *
    • + *
    + * + * Finally, you may want to implement some interfaces. + * Check their JavaDocs to see if they are suitable for your function: + *
      + *
    • + * {@link org.elasticsearch.xpack.esql.planner.ToAggregator}: (More information about aggregators below) + *
    • + *
    • + * {@link org.elasticsearch.xpack.esql.expression.SurrogateExpression} + *
    • + *
    + *
  12. + *
  13. + * To introduce your aggregation to the engine: + *
      + *
    • + * Add it to {@code org.elasticsearch.xpack.esql.planner.AggregateMapper}. + * Check all usages of other aggregations there, and replicate the logic. + *
    • + *
• + * Add it to {@link org.elasticsearch.xpack.esql.io.stream.PlanNamedTypes}. + * Consider adding a {@code writeTo} method and a constructor or {@code readFrom} method inside your function, + * to keep all the logic in one place. + *

      + * You can find examples of other aggregations using this method, + * like {@link org.elasticsearch.xpack.esql.expression.function.aggregate.TopList#writeTo(PlanStreamOutput)} + *

      + *
    • + *
    • + * Do the same with {@link org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry}. + *
    • + *
    + *
  14. + *
+ * + *

Creating aggregators for your function

+ *

+ * Aggregators contain the core logic of your aggregation. That is, how to combine values, what to store, how to process data, etc. + *

+ *
    + *
  1. + * Copy an existing aggregator to use as a base. You'll usually make one per type. Check other classes to see the naming pattern. + * You can find them in {@link org.elasticsearch.compute.aggregation}. + *

    + * Note that some aggregators are autogenerated, so they live in different directories. + * The base is {@code x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/} + *

    + *
  2. + *
  3. + * Make a test for your aggregator. + * You can copy an existing one from {@code x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/}. + *

    + * Tests extending from {@code org.elasticsearch.compute.aggregation.AggregatorFunctionTestCase} + * will already include most required cases. You should only need to fill the required abstract methods. + *

    + *
  4. + *
  5. + * Check the Javadoc of the {@link org.elasticsearch.compute.ann.Aggregator} + * and {@link org.elasticsearch.compute.ann.GroupingAggregator} annotations. + * Add/Modify them on your aggregator. + *
  6. + *
  7. + * The {@link org.elasticsearch.compute.ann.Aggregator} JavaDoc explains the static methods you should add. + *
  8. + *
9. + * After implementing the required methods (even if they have a dummy implementation), + * run the CsvTests to generate some extra required classes. + *

    + * One of them will be the {@code AggregatorFunctionSupplier} for your aggregator. + * Find it by its name ({@code AggregatorFunctionSupplier}), + * and return it in the {@code toSupplier} method in your function, under the correct type condition. + *

    + *
  10. + *
11. + * Now, complete the implementation of the aggregator until the tests pass! + *
  12. + *
+ * + *

StringTemplates

+ *

+ * Making an aggregator per type may be repetitive. To avoid code duplication, we use StringTemplates: + *

+ *
    + *
  1. + * Create a new StringTemplate file. + * Use another as a reference, like + * {@code x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st}. + *
  2. + *
  3. + * Add the template scripts to {@code x-pack/plugin/esql/compute/build.gradle}. + *

    + * You can also see there which variables you can use, and which types are currently supported. + *

    + *
  4. + *
  5. + * After completing your template, run the generation with {@code ./gradlew :x-pack:plugin:esql:compute:compileJava}. + *

    + * You may need to tweak some import orders per type so they don't raise warnings. + *

    + *
  6. + *
+ */ +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java index be5e105c3398e..831d105a89076 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java @@ -58,6 +58,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; @@ -298,6 +299,7 @@ public static List namedTypeEntries() { of(AggregateFunction.class, Percentile.class, PlanNamedTypes::writePercentile, PlanNamedTypes::readPercentile), of(AggregateFunction.class, SpatialCentroid.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), of(AggregateFunction.class, Sum.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), + of(AggregateFunction.class, TopList.class, (out, prefix) -> prefix.writeTo(out), TopList::readFrom), of(AggregateFunction.class, Values.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction) ); List entries = new ArrayList<>(declared); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java index 863476ba55686..0d45ce10b1966 
100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java @@ -29,6 +29,7 @@ * functions, designed to run over a {@link org.elasticsearch.compute.data.Block} *
  • {@link org.elasticsearch.xpack.esql.session.EsqlSession} - manages state across a query
  • *
  • {@link org.elasticsearch.xpack.esql.expression.function.scalar} - Guide to writing scalar functions
  • + *
  • {@link org.elasticsearch.xpack.esql.expression.function.aggregate} - Guide to writing aggregation functions
  • *
  • {@link org.elasticsearch.xpack.esql.analysis.Analyzer} - The first step in query processing
  • *
  • {@link org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer} - Coordinator level logical optimizations
  • *
  • {@link org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer} - Data node level logical optimizations
  • diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 68e6ea4d6cadb..83fdd5dc0c5d2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import java.lang.invoke.MethodHandle; @@ -61,7 +62,8 @@ final class AggregateMapper { Percentile.class, SpatialCentroid.class, Sum.class, - Values.class + Values.class, + TopList.class ); /** Record of agg Class, type, and grouping (or non-grouping). */ @@ -143,6 +145,8 @@ private static Stream, Tuple>> typeAndNames(Class } else if (Values.class.isAssignableFrom(clazz)) { // TODO can't we figure this out from the function itself? types = List.of("Int", "Long", "Double", "Boolean", "BytesRef"); + } else if (TopList.class.isAssignableFrom(clazz)) { + types = List.of("Int", "Long", "Double"); } else { assert clazz == CountDistinct.class : "Expected CountDistinct, got: " + clazz; types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();