diff --git a/.gitattributes b/.gitattributes
index 6a8de5462ec3f..04881c92ede00 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -4,6 +4,8 @@ CHANGELOG.asciidoc merge=union
# Windows
build-tools-internal/src/test/resources/org/elasticsearch/gradle/internal/release/*.asciidoc text eol=lf
+x-pack/plugin/esql/compute/src/main/generated/** linguist-generated=true
+x-pack/plugin/esql/compute/src/main/generated-src/** linguist-generated=true
x-pack/plugin/esql/src/main/antlr/*.tokens linguist-generated=true
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/*.interp linguist-generated=true
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseLexer*.java linguist-generated=true
diff --git a/docs/changelog/109386.yaml b/docs/changelog/109386.yaml
new file mode 100644
index 0000000000000..984ee96dde063
--- /dev/null
+++ b/docs/changelog/109386.yaml
@@ -0,0 +1,6 @@
+pr: 109386
+summary: "ESQL: `top_list` aggregation"
+area: ES|QL
+type: feature
+issues:
+ - 109213
diff --git a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java
index 7e92fc5c2734e..0216ea07e5c7c 100644
--- a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java
+++ b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/GroupingAggregator.java
@@ -12,6 +12,10 @@
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
+/**
+ * Annotates a class that implements an aggregation function with grouping.
+ * See {@link Aggregator} for more information.
+ */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.SOURCE)
public @interface GroupingAggregator {
diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle
index 635a53d1ac98a..bc206ee1d78d6 100644
--- a/x-pack/plugin/esql/compute/build.gradle
+++ b/x-pack/plugin/esql/compute/build.gradle
@@ -36,10 +36,11 @@ spotless {
}
}
-def prop(Type, type, TYPE, BYTES, Array, Hash) {
+def prop(Type, type, Wrapper, TYPE, BYTES, Array, Hash) {
return [
"Type" : Type,
"type" : type,
+ "Wrapper": Wrapper,
"TYPE" : TYPE,
"BYTES" : BYTES,
"Array" : Array,
@@ -55,12 +56,13 @@ def prop(Type, type, TYPE, BYTES, Array, Hash) {
}
tasks.named('stringTemplates').configure {
- var intProperties = prop("Int", "int", "INT", "Integer.BYTES", "IntArray", "LongHash")
- var floatProperties = prop("Float", "float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash")
- var longProperties = prop("Long", "long", "LONG", "Long.BYTES", "LongArray", "LongHash")
- var doubleProperties = prop("Double", "double", "DOUBLE", "Double.BYTES", "DoubleArray", "LongHash")
- var bytesRefProperties = prop("BytesRef", "BytesRef", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", "", "BytesRefHash")
- var booleanProperties = prop("Boolean", "boolean", "BOOLEAN", "Byte.BYTES", "BitArray", "")
+ var intProperties = prop("Int", "int", "Integer", "INT", "Integer.BYTES", "IntArray", "LongHash")
+ var floatProperties = prop("Float", "float", "Float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash")
+ var longProperties = prop("Long", "long", "Long", "LONG", "Long.BYTES", "LongArray", "LongHash")
+ var doubleProperties = prop("Double", "double", "Double", "DOUBLE", "Double.BYTES", "DoubleArray", "LongHash")
+ var bytesRefProperties = prop("BytesRef", "BytesRef", "", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", "", "BytesRefHash")
+ var booleanProperties = prop("Boolean", "boolean", "Boolean", "BOOLEAN", "Byte.BYTES", "BitArray", "")
+
// primitive vectors
File vectorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/X-Vector.java.st")
template {
@@ -500,6 +502,24 @@ tasks.named('stringTemplates').configure {
it.outputFile = "org/elasticsearch/compute/aggregation/RateDoubleAggregator.java"
}
+
+ File topListAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st")
+ template {
+ it.properties = intProperties
+ it.inputFile = topListAggregatorInputFile
+ it.outputFile = "org/elasticsearch/compute/aggregation/TopListIntAggregator.java"
+ }
+ template {
+ it.properties = longProperties
+ it.inputFile = topListAggregatorInputFile
+ it.outputFile = "org/elasticsearch/compute/aggregation/TopListLongAggregator.java"
+ }
+ template {
+ it.properties = doubleProperties
+ it.inputFile = topListAggregatorInputFile
+ it.outputFile = "org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java"
+ }
+
File multivalueDedupeInputFile = file("src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st")
template {
it.properties = intProperties
@@ -635,4 +655,21 @@ tasks.named('stringTemplates').configure {
it.inputFile = resultBuilderInputFile
it.outputFile = "org/elasticsearch/compute/operator/topn/ResultBuilderForFloat.java"
}
+
+ File bucketedSortInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st")
+ template {
+ it.properties = intProperties
+ it.inputFile = bucketedSortInputFile
+ it.outputFile = "org/elasticsearch/compute/data/sort/IntBucketedSort.java"
+ }
+ template {
+ it.properties = longProperties
+ it.inputFile = bucketedSortInputFile
+ it.outputFile = "org/elasticsearch/compute/data/sort/LongBucketedSort.java"
+ }
+ template {
+ it.properties = doubleProperties
+ it.inputFile = bucketedSortInputFile
+ it.outputFile = "org/elasticsearch/compute/data/sort/DoubleBucketedSort.java"
+ }
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java
new file mode 100644
index 0000000000000..941722b4424d3
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListDoubleAggregator.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.ann.Aggregator;
+import org.elasticsearch.compute.ann.GroupingAggregator;
+import org.elasticsearch.compute.ann.IntermediateState;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.sort.DoubleBucketedSort;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.search.sort.SortOrder;
+
+/**
+ * Aggregates the top N field values for double.
+ */
+@Aggregator({ @IntermediateState(name = "topList", type = "DOUBLE_BLOCK") })
+@GroupingAggregator
+class TopListDoubleAggregator {
+ public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
+ return new SingleState(bigArrays, limit, ascending);
+ }
+
+ public static void combine(SingleState state, double v) {
+ state.add(v);
+ }
+
+ public static void combineIntermediate(SingleState state, DoubleBlock values) {
+ int start = values.getFirstValueIndex(0);
+ int end = start + values.getValueCount(0);
+ for (int i = start; i < end; i++) {
+ combine(state, values.getDouble(i));
+ }
+ }
+
+ public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
+ return state.toBlock(driverContext.blockFactory());
+ }
+
+ public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
+ return new GroupingState(bigArrays, limit, ascending);
+ }
+
+ public static void combine(GroupingState state, int groupId, double v) {
+ state.add(groupId, v);
+ }
+
+ public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) {
+ int start = values.getFirstValueIndex(valuesPosition);
+ int end = start + values.getValueCount(valuesPosition);
+ for (int i = start; i < end; i++) {
+ combine(state, groupId, values.getDouble(i));
+ }
+ }
+
+ public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
+ current.merge(groupId, state, statePosition);
+ }
+
+ public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
+ return state.toBlock(driverContext.blockFactory(), selected);
+ }
+
+ public static class GroupingState implements Releasable {
+ private final DoubleBucketedSort sort;
+
+ private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
+ this.sort = new DoubleBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
+ }
+
+ public void add(int groupId, double value) {
+ sort.collect(value, groupId);
+ }
+
+ public void merge(int groupId, GroupingState other, int otherGroupId) {
+ sort.merge(groupId, other.sort, otherGroupId);
+ }
+
+ void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ blocks[offset] = toBlock(driverContext.blockFactory(), selected);
+ }
+
+ Block toBlock(BlockFactory blockFactory, IntVector selected) {
+ return sort.toBlock(blockFactory, selected);
+ }
+
+ void enableGroupIdTracking(SeenGroupIds seen) {
+ // we figure out seen values from nulls on the values block
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(sort);
+ }
+ }
+
+ public static class SingleState implements Releasable {
+ private final GroupingState internalState;
+
+ private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
+ this.internalState = new GroupingState(bigArrays, limit, ascending);
+ }
+
+ public void add(double value) {
+ internalState.add(0, value);
+ }
+
+ public void merge(GroupingState other) {
+ internalState.merge(0, other, 0);
+ }
+
+ void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ blocks[offset] = toBlock(driverContext.blockFactory());
+ }
+
+ Block toBlock(BlockFactory blockFactory) {
+ try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
+ return internalState.toBlock(blockFactory, intValues);
+ }
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(internalState);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListIntAggregator.java
new file mode 100644
index 0000000000000..dafbf1c2a3051
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListIntAggregator.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.ann.Aggregator;
+import org.elasticsearch.compute.ann.GroupingAggregator;
+import org.elasticsearch.compute.ann.IntermediateState;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.sort.IntBucketedSort;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.search.sort.SortOrder;
+
+/**
+ * Aggregates the top N field values for int.
+ */
+@Aggregator({ @IntermediateState(name = "topList", type = "INT_BLOCK") })
+@GroupingAggregator
+class TopListIntAggregator {
+ public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
+ return new SingleState(bigArrays, limit, ascending);
+ }
+
+ public static void combine(SingleState state, int v) {
+ state.add(v);
+ }
+
+ public static void combineIntermediate(SingleState state, IntBlock values) {
+ int start = values.getFirstValueIndex(0);
+ int end = start + values.getValueCount(0);
+ for (int i = start; i < end; i++) {
+ combine(state, values.getInt(i));
+ }
+ }
+
+ public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
+ return state.toBlock(driverContext.blockFactory());
+ }
+
+ public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
+ return new GroupingState(bigArrays, limit, ascending);
+ }
+
+ public static void combine(GroupingState state, int groupId, int v) {
+ state.add(groupId, v);
+ }
+
+ public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) {
+ int start = values.getFirstValueIndex(valuesPosition);
+ int end = start + values.getValueCount(valuesPosition);
+ for (int i = start; i < end; i++) {
+ combine(state, groupId, values.getInt(i));
+ }
+ }
+
+ public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
+ current.merge(groupId, state, statePosition);
+ }
+
+ public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
+ return state.toBlock(driverContext.blockFactory(), selected);
+ }
+
+ public static class GroupingState implements Releasable {
+ private final IntBucketedSort sort;
+
+ private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
+ this.sort = new IntBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
+ }
+
+ public void add(int groupId, int value) {
+ sort.collect(value, groupId);
+ }
+
+ public void merge(int groupId, GroupingState other, int otherGroupId) {
+ sort.merge(groupId, other.sort, otherGroupId);
+ }
+
+ void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ blocks[offset] = toBlock(driverContext.blockFactory(), selected);
+ }
+
+ Block toBlock(BlockFactory blockFactory, IntVector selected) {
+ return sort.toBlock(blockFactory, selected);
+ }
+
+ void enableGroupIdTracking(SeenGroupIds seen) {
+ // we figure out seen values from nulls on the values block
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(sort);
+ }
+ }
+
+ public static class SingleState implements Releasable {
+ private final GroupingState internalState;
+
+ private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
+ this.internalState = new GroupingState(bigArrays, limit, ascending);
+ }
+
+ public void add(int value) {
+ internalState.add(0, value);
+ }
+
+ public void merge(GroupingState other) {
+ internalState.merge(0, other, 0);
+ }
+
+ void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ blocks[offset] = toBlock(driverContext.blockFactory());
+ }
+
+ Block toBlock(BlockFactory blockFactory) {
+ try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
+ return internalState.toBlock(blockFactory, intValues);
+ }
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(internalState);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListLongAggregator.java
new file mode 100644
index 0000000000000..c0e7122a4be0b
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopListLongAggregator.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.ann.Aggregator;
+import org.elasticsearch.compute.ann.GroupingAggregator;
+import org.elasticsearch.compute.ann.IntermediateState;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.sort.LongBucketedSort;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.search.sort.SortOrder;
+
+/**
+ * Aggregates the top N field values for long.
+ */
+@Aggregator({ @IntermediateState(name = "topList", type = "LONG_BLOCK") })
+@GroupingAggregator
+class TopListLongAggregator {
+ public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
+ return new SingleState(bigArrays, limit, ascending);
+ }
+
+ public static void combine(SingleState state, long v) {
+ state.add(v);
+ }
+
+ public static void combineIntermediate(SingleState state, LongBlock values) {
+ int start = values.getFirstValueIndex(0);
+ int end = start + values.getValueCount(0);
+ for (int i = start; i < end; i++) {
+ combine(state, values.getLong(i));
+ }
+ }
+
+ public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
+ return state.toBlock(driverContext.blockFactory());
+ }
+
+ public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
+ return new GroupingState(bigArrays, limit, ascending);
+ }
+
+ public static void combine(GroupingState state, int groupId, long v) {
+ state.add(groupId, v);
+ }
+
+ public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) {
+ int start = values.getFirstValueIndex(valuesPosition);
+ int end = start + values.getValueCount(valuesPosition);
+ for (int i = start; i < end; i++) {
+ combine(state, groupId, values.getLong(i));
+ }
+ }
+
+ public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
+ current.merge(groupId, state, statePosition);
+ }
+
+ public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
+ return state.toBlock(driverContext.blockFactory(), selected);
+ }
+
+ public static class GroupingState implements Releasable {
+ private final LongBucketedSort sort;
+
+ private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
+ this.sort = new LongBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
+ }
+
+ public void add(int groupId, long value) {
+ sort.collect(value, groupId);
+ }
+
+ public void merge(int groupId, GroupingState other, int otherGroupId) {
+ sort.merge(groupId, other.sort, otherGroupId);
+ }
+
+ void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ blocks[offset] = toBlock(driverContext.blockFactory(), selected);
+ }
+
+ Block toBlock(BlockFactory blockFactory, IntVector selected) {
+ return sort.toBlock(blockFactory, selected);
+ }
+
+ void enableGroupIdTracking(SeenGroupIds seen) {
+ // we figure out seen values from nulls on the values block
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(sort);
+ }
+ }
+
+ public static class SingleState implements Releasable {
+ private final GroupingState internalState;
+
+ private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
+ this.internalState = new GroupingState(bigArrays, limit, ascending);
+ }
+
+ public void add(long value) {
+ internalState.add(0, value);
+ }
+
+ public void merge(GroupingState other) {
+ internalState.merge(0, other, 0);
+ }
+
+ void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ blocks[offset] = toBlock(driverContext.blockFactory());
+ }
+
+ Block toBlock(BlockFactory blockFactory) {
+ try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
+ return internalState.toBlock(blockFactory, intValues);
+ }
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(internalState);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java
new file mode 100644
index 0000000000000..63318a2189908
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java
@@ -0,0 +1,346 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.DoubleArray;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.search.sort.BucketedSort;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.Arrays;
+import java.util.stream.IntStream;
+
+/**
+ * Aggregates the top N double values per bucket.
+ * See {@link BucketedSort} for more information.
+ * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class DoubleBucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ *     <li>
+ *         Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ *         In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ *         The lowest index contains the offset to the next slot to be filled.
+ *         <p>
+ *             This allows us to insert elements in O(1) time.
+ *         </p>
+ *         <p>
+ *             When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ *         </p>
+ *     </li>
+ *     <li>
+ *         Heap mode: The bucket slots are organized as a min heap structure.
+ *         <p>
+ *             The root of the heap is the minimum value in the bucket,
+ *             which allows us to quickly discard new values that are not in the top N.
+ *         </p>
+ *     </li>
+ * </ul>
+ */
+ private DoubleArray values;
+
+ public DoubleBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
+ this.bigArrays = bigArrays;
+ this.order = order;
+ this.bucketSize = bucketSize;
+ heapMode = new BitArray(0, bigArrays);
+
+ boolean success = false;
+ try {
+ values = bigArrays.newDoubleArray(0, false);
+ success = true;
+ } finally {
+ if (success == false) {
+ close();
+ }
+ }
+ }
+
+ /**
+ * Collects a {@code value} into a {@code bucket}.
+ * <p>
+ *     It may or may not be inserted in the heap, depending on if it is better than the current root.
+ * </p>
+ */
+ public void collect(double value, int bucket) {
+ long rootIndex = (long) bucket * bucketSize;
+ if (inHeapMode(bucket)) {
+ if (betterThan(value, values.get(rootIndex))) {
+ values.set(rootIndex, value);
+ downHeap(rootIndex, 0);
+ }
+ return;
+ }
+ // Gathering mode
+ long requiredSize = rootIndex + bucketSize;
+ if (values.size() < requiredSize) {
+ grow(requiredSize);
+ }
+ int next = getNextGatherOffset(rootIndex);
+ assert 0 <= next && next < bucketSize
+ : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
+ long index = next + rootIndex;
+ values.set(index, value);
+ if (next == 0) {
+ heapMode.set(bucket);
+ heapify(rootIndex);
+ } else {
+ setNextGatherOffset(rootIndex, next - 1);
+ }
+ }
+
+ /**
+ * The order of the sort.
+ */
+ public SortOrder getOrder() {
+ return order;
+ }
+
+ /**
+ * The number of values to store per bucket.
+ */
+ public int getBucketSize() {
+ return bucketSize;
+ }
+
+ /**
+ * Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
+ * Returns [0, 0] if the bucket has never been collected.
+ */
+ private Tuple getBucketValuesIndexes(int bucket) {
+ long rootIndex = (long) bucket * bucketSize;
+ if (rootIndex >= values.size()) {
+ // We've never seen this bucket.
+ return Tuple.tuple(0L, 0L);
+ }
+ long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
+ long end = rootIndex + bucketSize;
+ return Tuple.tuple(start, end);
+ }
+
+ /**
+ * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
+ */
+ public void merge(int groupId, DoubleBucketedSort other, int otherGroupId) {
+ var otherBounds = other.getBucketValuesIndexes(otherGroupId);
+
+ // TODO: This can be improved for heapified buckets by making use of the heap structures
+ for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
+ collect(other.values.get(i), groupId);
+ }
+ }
+
+ /**
+ * Creates a block with the values from the {@code selected} groups.
+ */
+ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
+ // Check if the selected groups are all empty, to avoid allocating extra memory
+ if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
+ var bounds = this.getBucketValuesIndexes(bucket);
+ var size = bounds.v2() - bounds.v1();
+
+ return size > 0;
+ })) {
+ return blockFactory.newConstantNullBlock(selected.getPositionCount());
+ }
+
+ // Used to sort the values in the bucket.
+ var bucketValues = new double[bucketSize];
+
+ try (var builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) {
+ for (int s = 0; s < selected.getPositionCount(); s++) {
+ int bucket = selected.getInt(s);
+
+ var bounds = getBucketValuesIndexes(bucket);
+ var size = bounds.v2() - bounds.v1();
+
+ if (size == 0) {
+ builder.appendNull();
+ continue;
+ }
+
+ if (size == 1) {
+ builder.appendDouble(values.get(bounds.v1()));
+ continue;
+ }
+
+ for (int i = 0; i < size; i++) {
+ bucketValues[i] = values.get(bounds.v1() + i);
+ }
+
+ // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
+ Arrays.sort(bucketValues, 0, (int) size);
+
+ builder.beginPositionEntry();
+ if (order == SortOrder.ASC) {
+ for (int i = 0; i < size; i++) {
+ builder.appendDouble(bucketValues[i]);
+ }
+ } else {
+ for (int i = (int) size - 1; i >= 0; i--) {
+ builder.appendDouble(bucketValues[i]);
+ }
+ }
+ builder.endPositionEntry();
+ }
+ return builder.build();
+ }
+ }
+
+ /**
+ * Is this bucket a min heap {@code true} or in gathering mode {@code false}?
+ */
+ private boolean inHeapMode(int bucket) {
+ return heapMode.get(bucket);
+ }
+
+ /**
+ * Get the next index that should be "gathered" for a bucket rooted
+ * at {@code rootIndex}.
+ */
+ private int getNextGatherOffset(long rootIndex) {
+ return (int) values.get(rootIndex);
+ }
+
+ /**
+ * Set the next index that should be "gathered" for a bucket rooted
+ * at {@code rootIndex}.
+ */
+ private void setNextGatherOffset(long rootIndex, int offset) {
+ values.set(rootIndex, offset);
+ }
+
+ /**
+ * {@code true} if the entry at index {@code lhs} is "better" than
+ * the entry at {@code rhs}. "Better" in this means "lower" for
+ * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
+ */
+ private boolean betterThan(double lhs, double rhs) {
+ return getOrder().reverseMul() * Double.compare(lhs, rhs) < 0;
+ }
+
+ /**
+ * Swap the data at two indices.
+ */
+ private void swap(long lhs, long rhs) {
+ var tmp = values.get(lhs);
+ values.set(lhs, values.get(rhs));
+ values.set(rhs, tmp);
+ }
+
+ /**
+ * Allocate storage for more buckets and store the "next gather offset"
+ * for those new buckets.
+ */
+ private void grow(long minSize) {
+ long oldMax = values.size();
+ values = bigArrays.grow(values, minSize);
+ // Set the next gather offsets for all newly allocated buckets.
+ setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
+ }
+
+ /**
+ * Maintain the "next gather offsets" for newly allocated buckets.
+ */
+ private void setNextGatherOffsets(long startingAt) {
+ int nextOffset = getBucketSize() - 1;
+ for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
+ setNextGatherOffset(bucketRoot, nextOffset);
+ }
+ }
+
+ /**
+ * Heapify a bucket whose entries are in random order.
+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * <a href="https://en.wikipedia.org/wiki/Binary_heap">wikipedia</a>
+ * entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
+ *
+ * @param rootIndex the index the start of the bucket
+ */
+ private void heapify(long rootIndex) {
+ int maxParent = bucketSize / 2 - 1;
+ for (int parent = maxParent; parent >= 0; parent--) {
+ downHeap(rootIndex, parent);
+ }
+ }
+
+ /**
+ * Correct the heap invariant of a parent and its children. This
+ * runs in {@code O(log n)} time.
+ * @param rootIndex index of the start of the bucket
+ * @param parent Index within the bucket of the parent to check.
+ * For example, 0 is the "root".
+ */
+ private void downHeap(long rootIndex, int parent) {
+ while (true) {
+ long parentIndex = rootIndex + parent;
+ int worst = parent;
+ long worstIndex = parentIndex;
+ int leftChild = parent * 2 + 1;
+ long leftIndex = rootIndex + leftChild;
+ if (leftChild < bucketSize) {
+ if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
+ worst = leftChild;
+ worstIndex = leftIndex;
+ }
+ int rightChild = leftChild + 1;
+ long rightIndex = rootIndex + rightChild;
+ if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
+ worst = rightChild;
+ worstIndex = rightIndex;
+ }
+ }
+ if (worst == parent) {
+ break;
+ }
+ swap(worstIndex, parentIndex);
+ parent = worst;
+ }
+ }
+
+ @Override
+ public final void close() {
+ Releasables.close(values, heapMode);
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java
new file mode 100644
index 0000000000000..04a635d75fe52
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java
@@ -0,0 +1,346 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.IntArray;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.search.sort.BucketedSort;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.Arrays;
+import java.util.stream.IntStream;
+
+/**
+ * Aggregates the top N int values per bucket.
+ * See {@link BucketedSort} for more information.
+ * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class IntBucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * </p>
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </p>
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the minimum value in the bucket,
+ * which allows us to quickly discard new values that are not in the top N.
+ * </p>
+ * </li>
+ * </ul>
+ */
+ private IntArray values;
+
+ public IntBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
+ this.bigArrays = bigArrays;
+ this.order = order;
+ this.bucketSize = bucketSize;
+ heapMode = new BitArray(0, bigArrays);
+
+ boolean success = false;
+ try {
+ values = bigArrays.newIntArray(0, false);
+ success = true;
+ } finally {
+ if (success == false) {
+ close();
+ }
+ }
+ }
+
+ /**
+ * Collects a {@code value} into a {@code bucket}.
+ *
+ * It may or may not be inserted in the heap, depending on if it is better than the current root.
+ *
+ */
+ public void collect(int value, int bucket) {
+ long rootIndex = (long) bucket * bucketSize;
+ if (inHeapMode(bucket)) {
+ if (betterThan(value, values.get(rootIndex))) {
+ values.set(rootIndex, value);
+ downHeap(rootIndex, 0);
+ }
+ return;
+ }
+ // Gathering mode
+ long requiredSize = rootIndex + bucketSize;
+ if (values.size() < requiredSize) {
+ grow(requiredSize);
+ }
+ int next = getNextGatherOffset(rootIndex);
+ assert 0 <= next && next < bucketSize
+ : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
+ long index = next + rootIndex;
+ values.set(index, value);
+ if (next == 0) {
+ heapMode.set(bucket);
+ heapify(rootIndex);
+ } else {
+ setNextGatherOffset(rootIndex, next - 1);
+ }
+ }
+
+ /**
+ * The order of the sort.
+ */
+ public SortOrder getOrder() {
+ return order;
+ }
+
+ /**
+ * The number of values to store per bucket.
+ */
+ public int getBucketSize() {
+ return bucketSize;
+ }
+
+ /**
+ * Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
+ * Returns [0, 0] if the bucket has never been collected.
+ */
+ private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
+ long rootIndex = (long) bucket * bucketSize;
+ if (rootIndex >= values.size()) {
+ // We've never seen this bucket.
+ return Tuple.tuple(0L, 0L);
+ }
+ long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
+ long end = rootIndex + bucketSize;
+ return Tuple.tuple(start, end);
+ }
+
+ /**
+ * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
+ */
+ public void merge(int groupId, IntBucketedSort other, int otherGroupId) {
+ var otherBounds = other.getBucketValuesIndexes(otherGroupId);
+
+ // TODO: This can be improved for heapified buckets by making use of the heap structures
+ for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
+ collect(other.values.get(i), groupId);
+ }
+ }
+
+ /**
+ * Creates a block with the values from the {@code selected} groups.
+ */
+ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
+ // Check if the selected groups are all empty, to avoid allocating extra memory
+ if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
+ var bounds = this.getBucketValuesIndexes(bucket);
+ var size = bounds.v2() - bounds.v1();
+
+ return size > 0;
+ })) {
+ return blockFactory.newConstantNullBlock(selected.getPositionCount());
+ }
+
+ // Used to sort the values in the bucket.
+ var bucketValues = new int[bucketSize];
+
+ try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) {
+ for (int s = 0; s < selected.getPositionCount(); s++) {
+ int bucket = selected.getInt(s);
+
+ var bounds = getBucketValuesIndexes(bucket);
+ var size = bounds.v2() - bounds.v1();
+
+ if (size == 0) {
+ builder.appendNull();
+ continue;
+ }
+
+ if (size == 1) {
+ builder.appendInt(values.get(bounds.v1()));
+ continue;
+ }
+
+ for (int i = 0; i < size; i++) {
+ bucketValues[i] = values.get(bounds.v1() + i);
+ }
+
+ // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
+ Arrays.sort(bucketValues, 0, (int) size);
+
+ builder.beginPositionEntry();
+ if (order == SortOrder.ASC) {
+ for (int i = 0; i < size; i++) {
+ builder.appendInt(bucketValues[i]);
+ }
+ } else {
+ for (int i = (int) size - 1; i >= 0; i--) {
+ builder.appendInt(bucketValues[i]);
+ }
+ }
+ builder.endPositionEntry();
+ }
+ return builder.build();
+ }
+ }
+
+ /**
+ * Is this bucket a min heap {@code true} or in gathering mode {@code false}?
+ */
+ private boolean inHeapMode(int bucket) {
+ return heapMode.get(bucket);
+ }
+
+ /**
+ * Get the next index that should be "gathered" for a bucket rooted
+ * at {@code rootIndex}.
+ */
+ private int getNextGatherOffset(long rootIndex) {
+ return values.get(rootIndex);
+ }
+
+ /**
+ * Set the next index that should be "gathered" for a bucket rooted
+ * at {@code rootIndex}.
+ */
+ private void setNextGatherOffset(long rootIndex, int offset) {
+ values.set(rootIndex, offset);
+ }
+
+ /**
+ * {@code true} if the entry at index {@code lhs} is "better" than
+ * the entry at {@code rhs}. "Better" in this means "lower" for
+ * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
+ */
+ private boolean betterThan(int lhs, int rhs) {
+ return getOrder().reverseMul() * Integer.compare(lhs, rhs) < 0;
+ }
+
+ /**
+ * Swap the data at two indices.
+ */
+ private void swap(long lhs, long rhs) {
+ var tmp = values.get(lhs);
+ values.set(lhs, values.get(rhs));
+ values.set(rhs, tmp);
+ }
+
+ /**
+ * Allocate storage for more buckets and store the "next gather offset"
+ * for those new buckets.
+ */
+ private void grow(long minSize) {
+ long oldMax = values.size();
+ values = bigArrays.grow(values, minSize);
+ // Set the next gather offsets for all newly allocated buckets.
+ setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
+ }
+
+ /**
+ * Maintain the "next gather offsets" for newly allocated buckets.
+ */
+ private void setNextGatherOffsets(long startingAt) {
+ int nextOffset = getBucketSize() - 1;
+ for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
+ setNextGatherOffset(bucketRoot, nextOffset);
+ }
+ }
+
+ /**
+ * Heapify a bucket whose entries are in random order.
+ *
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * <a href="https://en.wikipedia.org/wiki/Binary_heap#Building_a_heap">wikipedia</a>
+ * entry on binary heaps for more about this.
+ *
+ *
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ *
+ *
+ * @param rootIndex the index of the start of the bucket
+ */
+ private void heapify(long rootIndex) {
+ int maxParent = bucketSize / 2 - 1;
+ for (int parent = maxParent; parent >= 0; parent--) {
+ downHeap(rootIndex, parent);
+ }
+ }
+
+ /**
+ * Correct the heap invariant of a parent and its children. This
+ * runs in {@code O(log n)} time.
+ * @param rootIndex index of the start of the bucket
+ * @param parent Index within the bucket of the parent to check.
+ * For example, 0 is the "root".
+ */
+ private void downHeap(long rootIndex, int parent) {
+ while (true) {
+ long parentIndex = rootIndex + parent;
+ int worst = parent;
+ long worstIndex = parentIndex;
+ int leftChild = parent * 2 + 1;
+ long leftIndex = rootIndex + leftChild;
+ if (leftChild < bucketSize) {
+ if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
+ worst = leftChild;
+ worstIndex = leftIndex;
+ }
+ int rightChild = leftChild + 1;
+ long rightIndex = rootIndex + rightChild;
+ if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
+ worst = rightChild;
+ worstIndex = rightIndex;
+ }
+ }
+ if (worst == parent) {
+ break;
+ }
+ swap(worstIndex, parentIndex);
+ parent = worst;
+ }
+ }
+
+ @Override
+ public final void close() {
+ Releasables.close(values, heapMode);
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java
new file mode 100644
index 0000000000000..e08c25256944b
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java
@@ -0,0 +1,346 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.LongArray;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.search.sort.BucketedSort;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.Arrays;
+import java.util.stream.IntStream;
+
+/**
+ * Aggregates the top N long values per bucket.
+ * See {@link BucketedSort} for more information.
+ * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class LongBucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * </p>
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </p>
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the minimum value in the bucket,
+ * which allows us to quickly discard new values that are not in the top N.
+ * </p>
+ * </li>
+ * </ul>
+ */
+ private LongArray values;
+
+ public LongBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
+ this.bigArrays = bigArrays;
+ this.order = order;
+ this.bucketSize = bucketSize;
+ heapMode = new BitArray(0, bigArrays);
+
+ boolean success = false;
+ try {
+ values = bigArrays.newLongArray(0, false);
+ success = true;
+ } finally {
+ if (success == false) {
+ close();
+ }
+ }
+ }
+
+ /**
+ * Collects a {@code value} into a {@code bucket}.
+ *
+ * It may or may not be inserted in the heap, depending on if it is better than the current root.
+ *
+ */
+ public void collect(long value, int bucket) {
+ long rootIndex = (long) bucket * bucketSize;
+ if (inHeapMode(bucket)) {
+ if (betterThan(value, values.get(rootIndex))) {
+ values.set(rootIndex, value);
+ downHeap(rootIndex, 0);
+ }
+ return;
+ }
+ // Gathering mode
+ long requiredSize = rootIndex + bucketSize;
+ if (values.size() < requiredSize) {
+ grow(requiredSize);
+ }
+ int next = getNextGatherOffset(rootIndex);
+ assert 0 <= next && next < bucketSize
+ : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
+ long index = next + rootIndex;
+ values.set(index, value);
+ if (next == 0) {
+ heapMode.set(bucket);
+ heapify(rootIndex);
+ } else {
+ setNextGatherOffset(rootIndex, next - 1);
+ }
+ }
+
+ /**
+ * The order of the sort.
+ */
+ public SortOrder getOrder() {
+ return order;
+ }
+
+ /**
+ * The number of values to store per bucket.
+ */
+ public int getBucketSize() {
+ return bucketSize;
+ }
+
+ /**
+ * Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
+ * Returns [0, 0] if the bucket has never been collected.
+ */
+ private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
+ long rootIndex = (long) bucket * bucketSize;
+ if (rootIndex >= values.size()) {
+ // We've never seen this bucket.
+ return Tuple.tuple(0L, 0L);
+ }
+ long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
+ long end = rootIndex + bucketSize;
+ return Tuple.tuple(start, end);
+ }
+
+ /**
+ * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
+ */
+ public void merge(int groupId, LongBucketedSort other, int otherGroupId) {
+ var otherBounds = other.getBucketValuesIndexes(otherGroupId);
+
+ // TODO: This can be improved for heapified buckets by making use of the heap structures
+ for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
+ collect(other.values.get(i), groupId);
+ }
+ }
+
+ /**
+ * Creates a block with the values from the {@code selected} groups.
+ */
+ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
+ // Check if the selected groups are all empty, to avoid allocating extra memory
+ if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
+ var bounds = this.getBucketValuesIndexes(bucket);
+ var size = bounds.v2() - bounds.v1();
+
+ return size > 0;
+ })) {
+ return blockFactory.newConstantNullBlock(selected.getPositionCount());
+ }
+
+ // Used to sort the values in the bucket.
+ var bucketValues = new long[bucketSize];
+
+ try (var builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) {
+ for (int s = 0; s < selected.getPositionCount(); s++) {
+ int bucket = selected.getInt(s);
+
+ var bounds = getBucketValuesIndexes(bucket);
+ var size = bounds.v2() - bounds.v1();
+
+ if (size == 0) {
+ builder.appendNull();
+ continue;
+ }
+
+ if (size == 1) {
+ builder.appendLong(values.get(bounds.v1()));
+ continue;
+ }
+
+ for (int i = 0; i < size; i++) {
+ bucketValues[i] = values.get(bounds.v1() + i);
+ }
+
+ // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
+ Arrays.sort(bucketValues, 0, (int) size);
+
+ builder.beginPositionEntry();
+ if (order == SortOrder.ASC) {
+ for (int i = 0; i < size; i++) {
+ builder.appendLong(bucketValues[i]);
+ }
+ } else {
+ for (int i = (int) size - 1; i >= 0; i--) {
+ builder.appendLong(bucketValues[i]);
+ }
+ }
+ builder.endPositionEntry();
+ }
+ return builder.build();
+ }
+ }
+
+ /**
+ * Is this bucket a min heap {@code true} or in gathering mode {@code false}?
+ */
+ private boolean inHeapMode(int bucket) {
+ return heapMode.get(bucket);
+ }
+
+ /**
+ * Get the next index that should be "gathered" for a bucket rooted
+ * at {@code rootIndex}.
+ */
+ private int getNextGatherOffset(long rootIndex) {
+ return (int) values.get(rootIndex);
+ }
+
+ /**
+ * Set the next index that should be "gathered" for a bucket rooted
+ * at {@code rootIndex}.
+ */
+ private void setNextGatherOffset(long rootIndex, int offset) {
+ values.set(rootIndex, offset);
+ }
+
+ /**
+ * {@code true} if the entry at index {@code lhs} is "better" than
+ * the entry at {@code rhs}. "Better" in this means "lower" for
+ * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
+ */
+ private boolean betterThan(long lhs, long rhs) {
+ return getOrder().reverseMul() * Long.compare(lhs, rhs) < 0;
+ }
+
+ /**
+ * Swap the data at two indices.
+ */
+ private void swap(long lhs, long rhs) {
+ var tmp = values.get(lhs);
+ values.set(lhs, values.get(rhs));
+ values.set(rhs, tmp);
+ }
+
+ /**
+ * Allocate storage for more buckets and store the "next gather offset"
+ * for those new buckets.
+ */
+ private void grow(long minSize) {
+ long oldMax = values.size();
+ values = bigArrays.grow(values, minSize);
+ // Set the next gather offsets for all newly allocated buckets.
+ setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
+ }
+
+ /**
+ * Maintain the "next gather offsets" for newly allocated buckets.
+ */
+ private void setNextGatherOffsets(long startingAt) {
+ int nextOffset = getBucketSize() - 1;
+ for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
+ setNextGatherOffset(bucketRoot, nextOffset);
+ }
+ }
+
+ /**
+ * Heapify a bucket whose entries are in random order.
+ *
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * <a href="https://en.wikipedia.org/wiki/Binary_heap#Building_a_heap">wikipedia</a>
+ * entry on binary heaps for more about this.
+ *
+ *
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ *
+ *
+ * @param rootIndex the index of the start of the bucket
+ */
+ private void heapify(long rootIndex) {
+ int maxParent = bucketSize / 2 - 1;
+ for (int parent = maxParent; parent >= 0; parent--) {
+ downHeap(rootIndex, parent);
+ }
+ }
+
+ /**
+ * Correct the heap invariant of a parent and its children. This
+ * runs in {@code O(log n)} time.
+ * @param rootIndex index of the start of the bucket
+ * @param parent Index within the bucket of the parent to check.
+ * For example, 0 is the "root".
+ */
+ private void downHeap(long rootIndex, int parent) {
+ while (true) {
+ long parentIndex = rootIndex + parent;
+ int worst = parent;
+ long worstIndex = parentIndex;
+ int leftChild = parent * 2 + 1;
+ long leftIndex = rootIndex + leftChild;
+ if (leftChild < bucketSize) {
+ if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
+ worst = leftChild;
+ worstIndex = leftIndex;
+ }
+ int rightChild = leftChild + 1;
+ long rightIndex = rootIndex + rightChild;
+ if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
+ worst = rightChild;
+ worstIndex = rightIndex;
+ }
+ }
+ if (worst == parent) {
+ break;
+ }
+ swap(worstIndex, parentIndex);
+ parent = worst;
+ }
+ }
+
+ @Override
+ public final void close() {
+ Releasables.close(values, heapMode);
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunction.java
new file mode 100644
index 0000000000000..d52d25941780c
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunction.java
@@ -0,0 +1,126 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunction} implementation for {@link TopListDoubleAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListDoubleAggregatorFunction implements AggregatorFunction {
+ private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
+ new IntermediateStateDesc("topList", ElementType.DOUBLE) );
+
+ private final DriverContext driverContext;
+
+ private final TopListDoubleAggregator.SingleState state;
+
+ private final List<Integer> channels;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListDoubleAggregatorFunction(DriverContext driverContext, List<Integer> channels,
+ TopListDoubleAggregator.SingleState state, int limit, boolean ascending) {
+ this.driverContext = driverContext;
+ this.channels = channels;
+ this.state = state;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ public static TopListDoubleAggregatorFunction create(DriverContext driverContext,
+ List<Integer> channels, int limit, boolean ascending) {
+ return new TopListDoubleAggregatorFunction(driverContext, channels, TopListDoubleAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
+ }
+
+ public static List<IntermediateStateDesc> intermediateStateDesc() {
+ return INTERMEDIATE_STATE_DESC;
+ }
+
+ @Override
+ public int intermediateBlockCount() {
+ return INTERMEDIATE_STATE_DESC.size();
+ }
+
+ @Override
+ public void addRawInput(Page page) {
+ DoubleBlock block = page.getBlock(channels.get(0));
+ DoubleVector vector = block.asVector();
+ if (vector != null) {
+ addRawVector(vector);
+ } else {
+ addRawBlock(block);
+ }
+ }
+
+ private void addRawVector(DoubleVector vector) {
+ for (int i = 0; i < vector.getPositionCount(); i++) {
+ TopListDoubleAggregator.combine(state, vector.getDouble(i));
+ }
+ }
+
+ private void addRawBlock(DoubleBlock block) {
+ for (int p = 0; p < block.getPositionCount(); p++) {
+ if (block.isNull(p)) {
+ continue;
+ }
+ int start = block.getFirstValueIndex(p);
+ int end = start + block.getValueCount(p);
+ for (int i = start; i < end; i++) {
+ TopListDoubleAggregator.combine(state, block.getDouble(i));
+ }
+ }
+ }
+
+ @Override
+ public void addIntermediateInput(Page page) {
+ assert channels.size() == intermediateBlockCount();
+ assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+ Block topListUncast = page.getBlock(channels.get(0));
+ if (topListUncast.areAllValuesNull()) {
+ return;
+ }
+ DoubleBlock topList = (DoubleBlock) topListUncast;
+ assert topList.getPositionCount() == 1;
+ TopListDoubleAggregator.combineIntermediate(state, topList);
+ }
+
+ @Override
+ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ state.toIntermediate(blocks, offset, driverContext);
+ }
+
+ @Override
+ public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
+ blocks[offset] = TopListDoubleAggregator.evaluateFinal(state, driverContext);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(getClass().getSimpleName()).append("[");
+ sb.append("channels=").append(channels);
+ sb.append("]");
+ return sb.toString();
+ }
+
+ @Override
+ public void close() {
+ state.close();
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionSupplier.java
new file mode 100644
index 0000000000000..48df091d339b6
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionSupplier.java
@@ -0,0 +1,45 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.util.List;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunctionSupplier} implementation for {@link TopListDoubleAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
+ private final List<Integer> channels;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListDoubleAggregatorFunctionSupplier(List<Integer> channels, int limit,
+ boolean ascending) {
+ this.channels = channels;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ @Override
+ public TopListDoubleAggregatorFunction aggregator(DriverContext driverContext) {
+ return TopListDoubleAggregatorFunction.create(driverContext, channels, limit, ascending);
+ }
+
+ @Override
+ public TopListDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
+ return TopListDoubleGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
+ }
+
+ @Override
+ public String describe() {
+ return "top_list of doubles";
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleGroupingAggregatorFunction.java
new file mode 100644
index 0000000000000..0e3b98bb0f7e5
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListDoubleGroupingAggregatorFunction.java
@@ -0,0 +1,202 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link GroupingAggregatorFunction} implementation for {@link TopListDoubleAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction {
+ private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
+ new IntermediateStateDesc("topList", ElementType.DOUBLE) );
+
+ private final TopListDoubleAggregator.GroupingState state;
+
+ private final List<Integer> channels;
+
+ private final DriverContext driverContext;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListDoubleGroupingAggregatorFunction(List<Integer> channels,
+ TopListDoubleAggregator.GroupingState state, DriverContext driverContext, int limit,
+ boolean ascending) {
+ this.channels = channels;
+ this.state = state;
+ this.driverContext = driverContext;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ public static TopListDoubleGroupingAggregatorFunction create(List<Integer> channels,
+ DriverContext driverContext, int limit, boolean ascending) {
+ return new TopListDoubleGroupingAggregatorFunction(channels, TopListDoubleAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
+ }
+
+ public static List<IntermediateStateDesc> intermediateStateDesc() {
+ return INTERMEDIATE_STATE_DESC;
+ }
+
+ @Override
+ public int intermediateBlockCount() {
+ return INTERMEDIATE_STATE_DESC.size();
+ }
+
+ @Override
+ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+ Page page) {
+ DoubleBlock valuesBlock = page.getBlock(channels.get(0));
+ DoubleVector valuesVector = valuesBlock.asVector();
+ if (valuesVector == null) {
+ if (valuesBlock.mayHaveNulls()) {
+ state.enableGroupIdTracking(seenGroupIds);
+ }
+ return new GroupingAggregatorFunction.AddInput() {
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+ };
+ }
+ return new GroupingAggregatorFunction.AddInput() {
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {
+ addRawInput(positionOffset, groupIds, valuesVector);
+ }
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {
+ addRawInput(positionOffset, groupIds, valuesVector);
+ }
+ };
+ }
+
+ private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ if (values.isNull(groupPosition + positionOffset)) {
+ continue;
+ }
+ int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+ int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+ for (int v = valuesStart; v < valuesEnd; v++) {
+ TopListDoubleAggregator.combine(state, groupId, values.getDouble(v));
+ }
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ TopListDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ if (groups.isNull(groupPosition)) {
+ continue;
+ }
+ int groupStart = groups.getFirstValueIndex(groupPosition);
+ int groupEnd = groupStart + groups.getValueCount(groupPosition);
+ for (int g = groupStart; g < groupEnd; g++) {
+ int groupId = Math.toIntExact(groups.getInt(g));
+ if (values.isNull(groupPosition + positionOffset)) {
+ continue;
+ }
+ int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+ int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+ for (int v = valuesStart; v < valuesEnd; v++) {
+ TopListDoubleAggregator.combine(state, groupId, values.getDouble(v));
+ }
+ }
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ if (groups.isNull(groupPosition)) {
+ continue;
+ }
+ int groupStart = groups.getFirstValueIndex(groupPosition);
+ int groupEnd = groupStart + groups.getValueCount(groupPosition);
+ for (int g = groupStart; g < groupEnd; g++) {
+ int groupId = Math.toIntExact(groups.getInt(g));
+ TopListDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
+ }
+ }
+ }
+
+ @Override
+ public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
+ state.enableGroupIdTracking(new SeenGroupIds.Empty());
+ assert channels.size() == intermediateBlockCount();
+ Block topListUncast = page.getBlock(channels.get(0));
+ if (topListUncast.areAllValuesNull()) {
+ return;
+ }
+ DoubleBlock topList = (DoubleBlock) topListUncast;
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ TopListDoubleAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
+ }
+ }
+
+ @Override
+ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
+ if (input.getClass() != getClass()) {
+ throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
+ }
+ TopListDoubleAggregator.GroupingState inState = ((TopListDoubleGroupingAggregatorFunction) input).state;
+ state.enableGroupIdTracking(new SeenGroupIds.Empty());
+ TopListDoubleAggregator.combineStates(state, groupId, inState, position);
+ }
+
+ @Override
+ public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+ state.toIntermediate(blocks, offset, selected, driverContext);
+ }
+
+ @Override
+ public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
+ DriverContext driverContext) {
+ blocks[offset] = TopListDoubleAggregator.evaluateFinal(state, selected, driverContext);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(getClass().getSimpleName()).append("[");
+ sb.append("channels=").append(channels);
+ sb.append("]");
+ return sb.toString();
+ }
+
+ @Override
+ public void close() {
+ state.close();
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunction.java
new file mode 100644
index 0000000000000..e885b285c4a51
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunction.java
@@ -0,0 +1,126 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunction} implementation for {@link TopListIntAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListIntAggregatorFunction implements AggregatorFunction {
+ private static final List INTERMEDIATE_STATE_DESC = List.of(
+ new IntermediateStateDesc("topList", ElementType.INT) );
+
+ private final DriverContext driverContext;
+
+ private final TopListIntAggregator.SingleState state;
+
+ private final List channels;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListIntAggregatorFunction(DriverContext driverContext, List channels,
+ TopListIntAggregator.SingleState state, int limit, boolean ascending) {
+ this.driverContext = driverContext;
+ this.channels = channels;
+ this.state = state;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ public static TopListIntAggregatorFunction create(DriverContext driverContext,
+ List channels, int limit, boolean ascending) {
+ return new TopListIntAggregatorFunction(driverContext, channels, TopListIntAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
+ }
+
+ public static List intermediateStateDesc() {
+ return INTERMEDIATE_STATE_DESC;
+ }
+
+ @Override
+ public int intermediateBlockCount() {
+ return INTERMEDIATE_STATE_DESC.size();
+ }
+
+ @Override
+ public void addRawInput(Page page) {
+ IntBlock block = page.getBlock(channels.get(0));
+ IntVector vector = block.asVector();
+ if (vector != null) {
+ addRawVector(vector);
+ } else {
+ addRawBlock(block);
+ }
+ }
+
+ private void addRawVector(IntVector vector) {
+ for (int i = 0; i < vector.getPositionCount(); i++) {
+ TopListIntAggregator.combine(state, vector.getInt(i));
+ }
+ }
+
+ private void addRawBlock(IntBlock block) {
+ for (int p = 0; p < block.getPositionCount(); p++) {
+ if (block.isNull(p)) {
+ continue;
+ }
+ int start = block.getFirstValueIndex(p);
+ int end = start + block.getValueCount(p);
+ for (int i = start; i < end; i++) {
+ TopListIntAggregator.combine(state, block.getInt(i));
+ }
+ }
+ }
+
+ @Override
+ public void addIntermediateInput(Page page) {
+ assert channels.size() == intermediateBlockCount();
+ assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+ Block topListUncast = page.getBlock(channels.get(0));
+ if (topListUncast.areAllValuesNull()) {
+ return;
+ }
+ IntBlock topList = (IntBlock) topListUncast;
+ assert topList.getPositionCount() == 1;
+ TopListIntAggregator.combineIntermediate(state, topList);
+ }
+
+ @Override
+ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ state.toIntermediate(blocks, offset, driverContext);
+ }
+
+ @Override
+ public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
+ blocks[offset] = TopListIntAggregator.evaluateFinal(state, driverContext);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(getClass().getSimpleName()).append("[");
+ sb.append("channels=").append(channels);
+ sb.append("]");
+ return sb.toString();
+ }
+
+ @Override
+ public void close() {
+ state.close();
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionSupplier.java
new file mode 100644
index 0000000000000..d8bf91ba85541
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionSupplier.java
@@ -0,0 +1,45 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.util.List;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunctionSupplier} implementation for {@link TopListIntAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
+ private final List channels;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListIntAggregatorFunctionSupplier(List channels, int limit,
+ boolean ascending) {
+ this.channels = channels;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ @Override
+ public TopListIntAggregatorFunction aggregator(DriverContext driverContext) {
+ return TopListIntAggregatorFunction.create(driverContext, channels, limit, ascending);
+ }
+
+ @Override
+ public TopListIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
+ return TopListIntGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
+ }
+
+ @Override
+ public String describe() {
+ return "top_list of ints";
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntGroupingAggregatorFunction.java
new file mode 100644
index 0000000000000..820ebb95e530c
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListIntGroupingAggregatorFunction.java
@@ -0,0 +1,200 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link GroupingAggregatorFunction} implementation for {@link TopListIntAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListIntGroupingAggregatorFunction implements GroupingAggregatorFunction {
+ private static final List INTERMEDIATE_STATE_DESC = List.of(
+ new IntermediateStateDesc("topList", ElementType.INT) );
+
+ private final TopListIntAggregator.GroupingState state;
+
+ private final List channels;
+
+ private final DriverContext driverContext;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListIntGroupingAggregatorFunction(List channels,
+ TopListIntAggregator.GroupingState state, DriverContext driverContext, int limit,
+ boolean ascending) {
+ this.channels = channels;
+ this.state = state;
+ this.driverContext = driverContext;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ public static TopListIntGroupingAggregatorFunction create(List channels,
+ DriverContext driverContext, int limit, boolean ascending) {
+ return new TopListIntGroupingAggregatorFunction(channels, TopListIntAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
+ }
+
+ public static List intermediateStateDesc() {
+ return INTERMEDIATE_STATE_DESC;
+ }
+
+ @Override
+ public int intermediateBlockCount() {
+ return INTERMEDIATE_STATE_DESC.size();
+ }
+
+ @Override
+ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+ Page page) {
+ IntBlock valuesBlock = page.getBlock(channels.get(0));
+ IntVector valuesVector = valuesBlock.asVector();
+ if (valuesVector == null) {
+ if (valuesBlock.mayHaveNulls()) {
+ state.enableGroupIdTracking(seenGroupIds);
+ }
+ return new GroupingAggregatorFunction.AddInput() {
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+ };
+ }
+ return new GroupingAggregatorFunction.AddInput() {
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {
+ addRawInput(positionOffset, groupIds, valuesVector);
+ }
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {
+ addRawInput(positionOffset, groupIds, valuesVector);
+ }
+ };
+ }
+
+ private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ if (values.isNull(groupPosition + positionOffset)) {
+ continue;
+ }
+ int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+ int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+ for (int v = valuesStart; v < valuesEnd; v++) {
+ TopListIntAggregator.combine(state, groupId, values.getInt(v));
+ }
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ TopListIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ if (groups.isNull(groupPosition)) {
+ continue;
+ }
+ int groupStart = groups.getFirstValueIndex(groupPosition);
+ int groupEnd = groupStart + groups.getValueCount(groupPosition);
+ for (int g = groupStart; g < groupEnd; g++) {
+ int groupId = Math.toIntExact(groups.getInt(g));
+ if (values.isNull(groupPosition + positionOffset)) {
+ continue;
+ }
+ int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+ int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+ for (int v = valuesStart; v < valuesEnd; v++) {
+ TopListIntAggregator.combine(state, groupId, values.getInt(v));
+ }
+ }
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntBlock groups, IntVector values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ if (groups.isNull(groupPosition)) {
+ continue;
+ }
+ int groupStart = groups.getFirstValueIndex(groupPosition);
+ int groupEnd = groupStart + groups.getValueCount(groupPosition);
+ for (int g = groupStart; g < groupEnd; g++) {
+ int groupId = Math.toIntExact(groups.getInt(g));
+ TopListIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
+ }
+ }
+ }
+
+ @Override
+ public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
+ state.enableGroupIdTracking(new SeenGroupIds.Empty());
+ assert channels.size() == intermediateBlockCount();
+ Block topListUncast = page.getBlock(channels.get(0));
+ if (topListUncast.areAllValuesNull()) {
+ return;
+ }
+ IntBlock topList = (IntBlock) topListUncast;
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ TopListIntAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
+ }
+ }
+
+ @Override
+ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
+ if (input.getClass() != getClass()) {
+ throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
+ }
+ TopListIntAggregator.GroupingState inState = ((TopListIntGroupingAggregatorFunction) input).state;
+ state.enableGroupIdTracking(new SeenGroupIds.Empty());
+ TopListIntAggregator.combineStates(state, groupId, inState, position);
+ }
+
+ @Override
+ public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+ state.toIntermediate(blocks, offset, selected, driverContext);
+ }
+
+ @Override
+ public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
+ DriverContext driverContext) {
+ blocks[offset] = TopListIntAggregator.evaluateFinal(state, selected, driverContext);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(getClass().getSimpleName()).append("[");
+ sb.append("channels=").append(channels);
+ sb.append("]");
+ return sb.toString();
+ }
+
+ @Override
+ public void close() {
+ state.close();
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunction.java
new file mode 100644
index 0000000000000..1a09a1a860e2f
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunction.java
@@ -0,0 +1,126 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.LongVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunction} implementation for {@link TopListLongAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListLongAggregatorFunction implements AggregatorFunction {
+ private static final List INTERMEDIATE_STATE_DESC = List.of(
+ new IntermediateStateDesc("topList", ElementType.LONG) );
+
+ private final DriverContext driverContext;
+
+ private final TopListLongAggregator.SingleState state;
+
+ private final List channels;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListLongAggregatorFunction(DriverContext driverContext, List channels,
+ TopListLongAggregator.SingleState state, int limit, boolean ascending) {
+ this.driverContext = driverContext;
+ this.channels = channels;
+ this.state = state;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ public static TopListLongAggregatorFunction create(DriverContext driverContext,
+ List channels, int limit, boolean ascending) {
+ return new TopListLongAggregatorFunction(driverContext, channels, TopListLongAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
+ }
+
+ public static List intermediateStateDesc() {
+ return INTERMEDIATE_STATE_DESC;
+ }
+
+ @Override
+ public int intermediateBlockCount() {
+ return INTERMEDIATE_STATE_DESC.size();
+ }
+
+ @Override
+ public void addRawInput(Page page) {
+ LongBlock block = page.getBlock(channels.get(0));
+ LongVector vector = block.asVector();
+ if (vector != null) {
+ addRawVector(vector);
+ } else {
+ addRawBlock(block);
+ }
+ }
+
+ private void addRawVector(LongVector vector) {
+ for (int i = 0; i < vector.getPositionCount(); i++) {
+ TopListLongAggregator.combine(state, vector.getLong(i));
+ }
+ }
+
+ private void addRawBlock(LongBlock block) {
+ for (int p = 0; p < block.getPositionCount(); p++) {
+ if (block.isNull(p)) {
+ continue;
+ }
+ int start = block.getFirstValueIndex(p);
+ int end = start + block.getValueCount(p);
+ for (int i = start; i < end; i++) {
+ TopListLongAggregator.combine(state, block.getLong(i));
+ }
+ }
+ }
+
+ @Override
+ public void addIntermediateInput(Page page) {
+ assert channels.size() == intermediateBlockCount();
+ assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+ Block topListUncast = page.getBlock(channels.get(0));
+ if (topListUncast.areAllValuesNull()) {
+ return;
+ }
+ LongBlock topList = (LongBlock) topListUncast;
+ assert topList.getPositionCount() == 1;
+ TopListLongAggregator.combineIntermediate(state, topList);
+ }
+
+ @Override
+ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ state.toIntermediate(blocks, offset, driverContext);
+ }
+
+ @Override
+ public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
+ blocks[offset] = TopListLongAggregator.evaluateFinal(state, driverContext);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(getClass().getSimpleName()).append("[");
+ sb.append("channels=").append(channels);
+ sb.append("]");
+ return sb.toString();
+ }
+
+ @Override
+ public void close() {
+ state.close();
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionSupplier.java
new file mode 100644
index 0000000000000..617895fbff1a3
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionSupplier.java
@@ -0,0 +1,45 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.util.List;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunctionSupplier} implementation for {@link TopListLongAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
+ private final List channels;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListLongAggregatorFunctionSupplier(List channels, int limit,
+ boolean ascending) {
+ this.channels = channels;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ @Override
+ public TopListLongAggregatorFunction aggregator(DriverContext driverContext) {
+ return TopListLongAggregatorFunction.create(driverContext, channels, limit, ascending);
+ }
+
+ @Override
+ public TopListLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
+ return TopListLongGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
+ }
+
+ @Override
+ public String describe() {
+ return "top_list of longs";
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongGroupingAggregatorFunction.java
new file mode 100644
index 0000000000000..cadb48b7d29d4
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopListLongGroupingAggregatorFunction.java
@@ -0,0 +1,202 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.LongVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link GroupingAggregatorFunction} implementation for {@link TopListLongAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopListLongGroupingAggregatorFunction implements GroupingAggregatorFunction {
+ private static final List INTERMEDIATE_STATE_DESC = List.of(
+ new IntermediateStateDesc("topList", ElementType.LONG) );
+
+ private final TopListLongAggregator.GroupingState state;
+
+ private final List channels;
+
+ private final DriverContext driverContext;
+
+ private final int limit;
+
+ private final boolean ascending;
+
+ public TopListLongGroupingAggregatorFunction(List channels,
+ TopListLongAggregator.GroupingState state, DriverContext driverContext, int limit,
+ boolean ascending) {
+ this.channels = channels;
+ this.state = state;
+ this.driverContext = driverContext;
+ this.limit = limit;
+ this.ascending = ascending;
+ }
+
+ public static TopListLongGroupingAggregatorFunction create(List channels,
+ DriverContext driverContext, int limit, boolean ascending) {
+ return new TopListLongGroupingAggregatorFunction(channels, TopListLongAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
+ }
+
+ public static List intermediateStateDesc() {
+ return INTERMEDIATE_STATE_DESC;
+ }
+
+ @Override
+ public int intermediateBlockCount() {
+ return INTERMEDIATE_STATE_DESC.size();
+ }
+
+ @Override
+ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+ Page page) {
+ LongBlock valuesBlock = page.getBlock(channels.get(0));
+ LongVector valuesVector = valuesBlock.asVector();
+ if (valuesVector == null) {
+ if (valuesBlock.mayHaveNulls()) {
+ state.enableGroupIdTracking(seenGroupIds);
+ }
+ return new GroupingAggregatorFunction.AddInput() {
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {
+ addRawInput(positionOffset, groupIds, valuesBlock);
+ }
+ };
+ }
+ return new GroupingAggregatorFunction.AddInput() {
+ @Override
+ public void add(int positionOffset, IntBlock groupIds) {
+ addRawInput(positionOffset, groupIds, valuesVector);
+ }
+
+ @Override
+ public void add(int positionOffset, IntVector groupIds) {
+ addRawInput(positionOffset, groupIds, valuesVector);
+ }
+ };
+ }
+
+ private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ if (values.isNull(groupPosition + positionOffset)) {
+ continue;
+ }
+ int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+ int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+ for (int v = valuesStart; v < valuesEnd; v++) {
+ TopListLongAggregator.combine(state, groupId, values.getLong(v));
+ }
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ TopListLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ if (groups.isNull(groupPosition)) {
+ continue;
+ }
+ int groupStart = groups.getFirstValueIndex(groupPosition);
+ int groupEnd = groupStart + groups.getValueCount(groupPosition);
+ for (int g = groupStart; g < groupEnd; g++) {
+ int groupId = Math.toIntExact(groups.getInt(g));
+ if (values.isNull(groupPosition + positionOffset)) {
+ continue;
+ }
+ int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+ int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+ for (int v = valuesStart; v < valuesEnd; v++) {
+ TopListLongAggregator.combine(state, groupId, values.getLong(v));
+ }
+ }
+ }
+ }
+
+ private void addRawInput(int positionOffset, IntBlock groups, LongVector values) {
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ if (groups.isNull(groupPosition)) {
+ continue;
+ }
+ int groupStart = groups.getFirstValueIndex(groupPosition);
+ int groupEnd = groupStart + groups.getValueCount(groupPosition);
+ for (int g = groupStart; g < groupEnd; g++) {
+ int groupId = Math.toIntExact(groups.getInt(g));
+ TopListLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
+ }
+ }
+ }
+
+ @Override
+ public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
+ state.enableGroupIdTracking(new SeenGroupIds.Empty());
+ assert channels.size() == intermediateBlockCount();
+ Block topListUncast = page.getBlock(channels.get(0));
+ if (topListUncast.areAllValuesNull()) {
+ return;
+ }
+ LongBlock topList = (LongBlock) topListUncast;
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+ int groupId = Math.toIntExact(groups.getInt(groupPosition));
+ TopListLongAggregator.combineIntermediate(state, groupId, topList, groupPosition + positionOffset);
+ }
+ }
+
+ @Override
+ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
+ if (input.getClass() != getClass()) {
+ throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
+ }
+ TopListLongAggregator.GroupingState inState = ((TopListLongGroupingAggregatorFunction) input).state;
+ state.enableGroupIdTracking(new SeenGroupIds.Empty());
+ TopListLongAggregator.combineStates(state, groupId, inState, position);
+ }
+
+ @Override
+ public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+ state.toIntermediate(blocks, offset, selected, driverContext);
+ }
+
+ @Override
+ public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
+ DriverContext driverContext) {
+ blocks[offset] = TopListLongAggregator.evaluateFinal(state, selected, driverContext);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(getClass().getSimpleName()).append("[");
+ sb.append("channels=").append(channels);
+ sb.append("]");
+ return sb.toString();
+ }
+
+ @Override
+ public void close() {
+ state.close();
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/java/module-info.java b/x-pack/plugin/esql/compute/src/main/java/module-info.java
index 3772d6c83f5aa..dc8cda0fbe3c8 100644
--- a/x-pack/plugin/esql/compute/src/main/java/module-info.java
+++ b/x-pack/plugin/esql/compute/src/main/java/module-info.java
@@ -30,4 +30,5 @@
exports org.elasticsearch.compute.operator.topn;
exports org.elasticsearch.compute.operator.mvdedupe;
exports org.elasticsearch.compute.aggregation.table;
+ exports org.elasticsearch.compute.data.sort;
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st
new file mode 100644
index 0000000000000..810311154503e
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st
@@ -0,0 +1,142 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.ann.Aggregator;
+import org.elasticsearch.compute.ann.GroupingAggregator;
+import org.elasticsearch.compute.ann.IntermediateState;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+$if(!long)$
+import org.elasticsearch.compute.data.$Type$Block;
+$endif$
+import org.elasticsearch.compute.data.IntVector;
+$if(long)$
+import org.elasticsearch.compute.data.$Type$Block;
+$endif$
+import org.elasticsearch.compute.data.sort.$Type$BucketedSort;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.search.sort.SortOrder;
+
+/**
+ * Aggregates the top N field values for $type$.
+ */
+@Aggregator({ @IntermediateState(name = "topList", type = "$TYPE$_BLOCK") })
+@GroupingAggregator
+class TopList$Type$Aggregator {
+ public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
+ return new SingleState(bigArrays, limit, ascending);
+ }
+
+ public static void combine(SingleState state, $type$ v) {
+ state.add(v);
+ }
+
+ public static void combineIntermediate(SingleState state, $Type$Block values) {
+ int start = values.getFirstValueIndex(0);
+ int end = start + values.getValueCount(0);
+ for (int i = start; i < end; i++) {
+ combine(state, values.get$Type$(i));
+ }
+ }
+
+ public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
+ return state.toBlock(driverContext.blockFactory());
+ }
+
+ public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
+ return new GroupingState(bigArrays, limit, ascending);
+ }
+
+ public static void combine(GroupingState state, int groupId, $type$ v) {
+ state.add(groupId, v);
+ }
+
+ public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
+ int start = values.getFirstValueIndex(valuesPosition);
+ int end = start + values.getValueCount(valuesPosition);
+ for (int i = start; i < end; i++) {
+ combine(state, groupId, values.get$Type$(i));
+ }
+ }
+
+ public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
+ current.merge(groupId, state, statePosition);
+ }
+
+ public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
+ return state.toBlock(driverContext.blockFactory(), selected);
+ }
+
+ public static class GroupingState implements Releasable {
+ private final $Type$BucketedSort sort;
+
+ private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
+ this.sort = new $Type$BucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
+ }
+
+ public void add(int groupId, $type$ value) {
+ sort.collect(value, groupId);
+ }
+
+ public void merge(int groupId, GroupingState other, int otherGroupId) {
+ sort.merge(groupId, other.sort, otherGroupId);
+ }
+
+ void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ blocks[offset] = toBlock(driverContext.blockFactory(), selected);
+ }
+
+ Block toBlock(BlockFactory blockFactory, IntVector selected) {
+ return sort.toBlock(blockFactory, selected);
+ }
+
+ void enableGroupIdTracking(SeenGroupIds seen) {
+ // we figure out seen values from nulls on the values block
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(sort);
+ }
+ }
+
+ public static class SingleState implements Releasable {
+ private final GroupingState internalState;
+
+ private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
+ this.internalState = new GroupingState(bigArrays, limit, ascending);
+ }
+
+ public void add($type$ value) {
+ internalState.add(0, value);
+ }
+
+ public void merge(GroupingState other) {
+ internalState.merge(0, other, 0);
+ }
+
+ void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ blocks[offset] = toBlock(driverContext.blockFactory());
+ }
+
+ Block toBlock(BlockFactory blockFactory) {
+ try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
+ return internalState.toBlock(blockFactory, intValues);
+ }
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(internalState);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st
new file mode 100644
index 0000000000000..6587743e34b6f
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st
@@ -0,0 +1,350 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.$Type$Array;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.search.sort.BucketedSort;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.Arrays;
+import java.util.stream.IntStream;
+
+/**
+ * Aggregates the top N $type$ values per bucket.
+ * See {@link BucketedSort} for more information.
+ * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class $Type$BucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ *
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ *
+ *
+ * -
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ *
+ * This allows us to insert elements in O(1) time.
+ *
+ *
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ *
+ *
+ * -
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ *
+ * The root of the heap is the minimum value in the bucket,
+ * which allows us to quickly discard new values that are not in the top N.
+ *
+ *
+ *
+ */
+ private $Type$Array values;
+
+ public $Type$BucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
+ this.bigArrays = bigArrays;
+ this.order = order;
+ this.bucketSize = bucketSize;
+ heapMode = new BitArray(0, bigArrays);
+
+ boolean success = false;
+ try {
+ values = bigArrays.new$Type$Array(0, false);
+ success = true;
+ } finally {
+ if (success == false) {
+ close();
+ }
+ }
+ }
+
+ /**
+ * Collects a {@code value} into a {@code bucket}.
+ *
+ * It may or may not be inserted in the heap, depending on if it is better than the current root.
+ *
+ */
+ public void collect($type$ value, int bucket) {
+ long rootIndex = (long) bucket * bucketSize;
+ if (inHeapMode(bucket)) {
+ if (betterThan(value, values.get(rootIndex))) {
+ values.set(rootIndex, value);
+ downHeap(rootIndex, 0);
+ }
+ return;
+ }
+ // Gathering mode
+ long requiredSize = rootIndex + bucketSize;
+ if (values.size() < requiredSize) {
+ grow(requiredSize);
+ }
+ int next = getNextGatherOffset(rootIndex);
+ assert 0 <= next && next < bucketSize
+ : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]";
+ long index = next + rootIndex;
+ values.set(index, value);
+ if (next == 0) {
+ heapMode.set(bucket);
+ heapify(rootIndex);
+ } else {
+ setNextGatherOffset(rootIndex, next - 1);
+ }
+ }
+
+ /**
+ * The order of the sort.
+ */
+ public SortOrder getOrder() {
+ return order;
+ }
+
+ /**
+ * The number of values to store per bucket.
+ */
+ public int getBucketSize() {
+ return bucketSize;
+ }
+
+ /**
+ * Get the first and last indexes (inclusive, exclusive) of the values for a bucket.
+ * Returns [0, 0] if the bucket has never been collected.
+ */
+ private Tuple<Long, Long> getBucketValuesIndexes(int bucket) {
+ long rootIndex = (long) bucket * bucketSize;
+ if (rootIndex >= values.size()) {
+ // We've never seen this bucket.
+ return Tuple.tuple(0L, 0L);
+ }
+ long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1);
+ long end = rootIndex + bucketSize;
+ return Tuple.tuple(start, end);
+ }
+
+ /**
+ * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
+ */
+ public void merge(int groupId, $Type$BucketedSort other, int otherGroupId) {
+ var otherBounds = other.getBucketValuesIndexes(otherGroupId);
+
+ // TODO: This can be improved for heapified buckets by making use of the heap structures
+ for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) {
+ collect(other.values.get(i), groupId);
+ }
+ }
+
+ /**
+ * Creates a block with the values from the {@code selected} groups.
+ */
+ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
+ // Check if the selected groups are all empty, to avoid allocating extra memory
+ if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
+ var bounds = this.getBucketValuesIndexes(bucket);
+ var size = bounds.v2() - bounds.v1();
+
+ return size > 0;
+ })) {
+ return blockFactory.newConstantNullBlock(selected.getPositionCount());
+ }
+
+ // Used to sort the values in the bucket.
+ var bucketValues = new $type$[bucketSize];
+
+ try (var builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) {
+ for (int s = 0; s < selected.getPositionCount(); s++) {
+ int bucket = selected.getInt(s);
+
+ var bounds = getBucketValuesIndexes(bucket);
+ var size = bounds.v2() - bounds.v1();
+
+ if (size == 0) {
+ builder.appendNull();
+ continue;
+ }
+
+ if (size == 1) {
+ builder.append$Type$(values.get(bounds.v1()));
+ continue;
+ }
+
+ for (int i = 0; i < size; i++) {
+ bucketValues[i] = values.get(bounds.v1() + i);
+ }
+
+ // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
+ Arrays.sort(bucketValues, 0, (int) size);
+
+ builder.beginPositionEntry();
+ if (order == SortOrder.ASC) {
+ for (int i = 0; i < size; i++) {
+ builder.append$Type$(bucketValues[i]);
+ }
+ } else {
+ for (int i = (int) size - 1; i >= 0; i--) {
+ builder.append$Type$(bucketValues[i]);
+ }
+ }
+ builder.endPositionEntry();
+ }
+ return builder.build();
+ }
+ }
+
+ /**
+ * Is this bucket a min heap {@code true} or in gathering mode {@code false}?
+ */
+ private boolean inHeapMode(int bucket) {
+ return heapMode.get(bucket);
+ }
+
+ /**
+ * Get the next index that should be "gathered" for a bucket rooted
+ * at {@code rootIndex}.
+ */
+ private int getNextGatherOffset(long rootIndex) {
+$if(int)$
+ return values.get(rootIndex);
+$else$
+ return (int) values.get(rootIndex);
+$endif$
+ }
+
+ /**
+ * Set the next index that should be "gathered" for a bucket rooted
+ * at {@code rootIndex}.
+ */
+ private void setNextGatherOffset(long rootIndex, int offset) {
+ values.set(rootIndex, offset);
+ }
+
+ /**
+ * {@code true} if the entry at index {@code lhs} is "better" than
+ * the entry at {@code rhs}. "Better" in this means "lower" for
+ * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}.
+ */
+ private boolean betterThan($type$ lhs, $type$ rhs) {
+ return getOrder().reverseMul() * $Wrapper$.compare(lhs, rhs) < 0;
+ }
+
+ /**
+ * Swap the data at two indices.
+ */
+ private void swap(long lhs, long rhs) {
+ var tmp = values.get(lhs);
+ values.set(lhs, values.get(rhs));
+ values.set(rhs, tmp);
+ }
+
+ /**
+ * Allocate storage for more buckets and store the "next gather offset"
+ * for those new buckets.
+ */
+ private void grow(long minSize) {
+ long oldMax = values.size();
+ values = bigArrays.grow(values, minSize);
+ // Set the next gather offsets for all newly allocated buckets.
+ setNextGatherOffsets(oldMax - (oldMax % getBucketSize()));
+ }
+
+ /**
+ * Maintain the "next gather offsets" for newly allocated buckets.
+ */
+ private void setNextGatherOffsets(long startingAt) {
+ int nextOffset = getBucketSize() - 1;
+ for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) {
+ setNextGatherOffset(bucketRoot, nextOffset);
+ }
+ }
+
+ /**
+ * Heapify a bucket whose entries are in random order.
+ *
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * <a href="https://en.wikipedia.org/wiki/Binary_heap#Building_a_heap">wikipedia</a>
+ * entry on binary heaps for more about this.
+ *
+ *
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ *
+ *
+ * @param rootIndex the index the start of the bucket
+ */
+ private void heapify(long rootIndex) {
+ int maxParent = bucketSize / 2 - 1;
+ for (int parent = maxParent; parent >= 0; parent--) {
+ downHeap(rootIndex, parent);
+ }
+ }
+
+ /**
+ * Correct the heap invariant of a parent and its children. This
+ * runs in {@code O(log n)} time.
+ * @param rootIndex index of the start of the bucket
+ * @param parent Index within the bucket of the parent to check.
+ * For example, 0 is the "root".
+ */
+ private void downHeap(long rootIndex, int parent) {
+ while (true) {
+ long parentIndex = rootIndex + parent;
+ int worst = parent;
+ long worstIndex = parentIndex;
+ int leftChild = parent * 2 + 1;
+ long leftIndex = rootIndex + leftChild;
+ if (leftChild < bucketSize) {
+ if (betterThan(values.get(worstIndex), values.get(leftIndex))) {
+ worst = leftChild;
+ worstIndex = leftIndex;
+ }
+ int rightChild = leftChild + 1;
+ long rightIndex = rootIndex + rightChild;
+ if (rightChild < bucketSize && betterThan(values.get(worstIndex), values.get(rightIndex))) {
+ worst = rightChild;
+ worstIndex = rightIndex;
+ }
+ }
+ if (worst == parent) {
+ break;
+ }
+ swap(worstIndex, parentIndex);
+ parent = worst;
+ }
+ }
+
+ @Override
+ public final void close() {
+ Releasables.close(values, heapMode);
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionTests.java
new file mode 100644
index 0000000000000..f708038776032
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListDoubleAggregatorFunctionTests.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator;
+import org.elasticsearch.compute.operator.SourceOperator;
+
+import java.util.List;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.contains;
+
+public class TopListDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase {
+ private static final int LIMIT = 100;
+
+ @Override
+ protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
+ return new SequenceDoubleBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToDouble(l -> randomDouble()));
+ }
+
+ @Override
+ protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
+ return new TopListDoubleAggregatorFunctionSupplier(inputChannels, LIMIT, true);
+ }
+
+ @Override
+ protected String expectedDescriptionOfAggregator() {
+ return "top_list of doubles";
+ }
+
+ @Override
+ public void assertSimpleOutput(List<Block> input, Block result) {
+ Object[] values = input.stream().flatMapToDouble(b -> allDoubles(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
+ assertThat((List<Double>) BlockUtils.toJavaObject(result, 0), contains(values));
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionTests.java
new file mode 100644
index 0000000000000..443604efd5c15
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListIntAggregatorFunctionTests.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator;
+import org.elasticsearch.compute.operator.SourceOperator;
+
+import java.util.List;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.contains;
+
+public class TopListIntAggregatorFunctionTests extends AggregatorFunctionTestCase {
+ private static final int LIMIT = 100;
+
+ @Override
+ protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
+ return new SequenceIntBlockSourceOperator(blockFactory, IntStream.range(0, size).map(l -> randomInt()));
+ }
+
+ @Override
+ protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
+ return new TopListIntAggregatorFunctionSupplier(inputChannels, LIMIT, true);
+ }
+
+ @Override
+ protected String expectedDescriptionOfAggregator() {
+ return "top_list of ints";
+ }
+
+ @Override
+ public void assertSimpleOutput(List<Block> input, Block result) {
+ Object[] values = input.stream().flatMapToInt(b -> allInts(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
+ assertThat((List<Integer>) BlockUtils.toJavaObject(result, 0), contains(values));
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionTests.java
new file mode 100644
index 0000000000000..4a6f101e573b8
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopListLongAggregatorFunctionTests.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
+import org.elasticsearch.compute.operator.SourceOperator;
+
+import java.util.List;
+import java.util.stream.LongStream;
+
+import static org.hamcrest.Matchers.contains;
+
+public class TopListLongAggregatorFunctionTests extends AggregatorFunctionTestCase {
+ private static final int LIMIT = 100;
+
+ @Override
+ protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
+ return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLong()));
+ }
+
+ @Override
+ protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
+ return new TopListLongAggregatorFunctionSupplier(inputChannels, LIMIT, true);
+ }
+
+ @Override
+ protected String expectedDescriptionOfAggregator() {
+ return "top_list of longs";
+ }
+
+ @Override
+ public void assertSimpleOutput(List<Block> input, Block result) {
+ Object[] values = input.stream().flatMapToLong(b -> allLongs(b)).sorted().limit(LIMIT).boxed().toArray(Object[]::new);
+ assertThat((List<Long>) BlockUtils.toJavaObject(result, 0), contains(values));
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/BucketedSortTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/BucketedSortTestCase.java
new file mode 100644
index 0000000000000..9e1bc145ad4ca
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/BucketedSortTestCase.java
@@ -0,0 +1,368 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.MockPageCacheRecycler;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.TestBlockFactory;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
+import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public abstract class BucketedSortTestCase<T extends Releasable> extends ESTestCase {
+ /**
+ * Build a {@link T} to test. Sorts built by this method shouldn't need scores.
+ */
+ protected abstract T build(SortOrder sortOrder, int bucketSize);
+
+ /**
+ * Build the expected correctly typed value for a value.
+ */
+ protected abstract Object expectedValue(double v);
+
+ /**
+ * A random value for testing, with the appropriate precision for the type we're testing.
+ */
+ protected abstract double randomValue();
+
+ /**
+ * Collect a value into the sort.
+ * @param value value to collect, always sent as double just to have
+ * a number to test. Subclasses should cast to their favorite types
+ */
+ protected abstract void collect(T sort, double value, int bucket);
+
+ protected abstract void merge(T sort, int groupId, T other, int otherGroupId);
+
+ protected abstract Block toBlock(T sort, BlockFactory blockFactory, IntVector selected);
+
+ protected abstract void assertBlockTypeAndValues(Block block, Object... values);
+
+ public final void testNeverCalled() {
+ SortOrder order = randomFrom(SortOrder.values());
+ try (T sort = build(order, 1)) {
+ assertBlock(sort, randomNonNegativeInt());
+ }
+ }
+
+ public final void testSingleDoc() {
+ try (T sort = build(randomFrom(SortOrder.values()), 1)) {
+ collect(sort, 1, 0);
+
+ assertBlock(sort, 0, expectedValue(1));
+ }
+ }
+
+ public final void testNonCompetitive() {
+ try (T sort = build(SortOrder.DESC, 1)) {
+ collect(sort, 2, 0);
+ collect(sort, 1, 0);
+
+ assertBlock(sort, 0, expectedValue(2));
+ }
+ }
+
+ public final void testCompetitive() {
+ try (T sort = build(SortOrder.DESC, 1)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+
+ assertBlock(sort, 0, expectedValue(2));
+ }
+ }
+
+ public final void testNegativeValue() {
+ try (T sort = build(SortOrder.DESC, 1)) {
+ collect(sort, -1, 0);
+ assertBlock(sort, 0, expectedValue(-1));
+ }
+ }
+
+ public final void testSomeBuckets() {
+ try (T sort = build(SortOrder.DESC, 1)) {
+ collect(sort, 2, 0);
+ collect(sort, 2, 1);
+ collect(sort, 2, 2);
+ collect(sort, 3, 0);
+
+ assertBlock(sort, 0, expectedValue(3));
+ assertBlock(sort, 1, expectedValue(2));
+ assertBlock(sort, 2, expectedValue(2));
+ assertBlock(sort, 3);
+ }
+ }
+
+ public final void testBucketGaps() {
+ try (T sort = build(SortOrder.DESC, 1)) {
+ collect(sort, 2, 0);
+ collect(sort, 2, 2);
+
+ assertBlock(sort, 0, expectedValue(2));
+ assertBlock(sort, 1);
+ assertBlock(sort, 2, expectedValue(2));
+ assertBlock(sort, 3);
+ }
+ }
+
+ public final void testBucketsOutOfOrder() {
+ try (T sort = build(SortOrder.DESC, 1)) {
+ collect(sort, 2, 1);
+ collect(sort, 2, 0);
+
+ assertBlock(sort, 0, expectedValue(2.0));
+ assertBlock(sort, 1, expectedValue(2.0));
+ assertBlock(sort, 2);
+ }
+ }
+
+ public final void testManyBuckets() {
+ // Collect the buckets in random order
+ Integer[] buckets = new Integer[10000];
+ for (int b = 0; b < buckets.length; b++) {
+ buckets[b] = b;
+ }
+ Collections.shuffle(Arrays.asList(buckets), random());
+
+ double[] maxes = new double[buckets.length];
+
+ try (T sort = build(SortOrder.DESC, 1)) {
+ for (int b : buckets) {
+ maxes[b] = 2;
+ collect(sort, 2, b);
+ if (randomBoolean()) {
+ maxes[b] = 3;
+ collect(sort, 3, b);
+ }
+ if (randomBoolean()) {
+ collect(sort, -1, b);
+ }
+ }
+ for (int b = 0; b < buckets.length; b++) {
+ assertBlock(sort, b, expectedValue(maxes[b]));
+ }
+ assertBlock(sort, buckets.length);
+ }
+ }
+
+ public final void testTwoHitsDesc() {
+ try (T sort = build(SortOrder.DESC, 2)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+ collect(sort, 3, 0);
+
+ assertBlock(sort, 0, expectedValue(3), expectedValue(2));
+ }
+ }
+
+ public final void testTwoHitsAsc() {
+ try (T sort = build(SortOrder.ASC, 2)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+ collect(sort, 3, 0);
+
+ assertBlock(sort, 0, expectedValue(1), expectedValue(2));
+ }
+ }
+
+ public final void testTwoHitsTwoBucket() {
+ try (T sort = build(SortOrder.DESC, 2)) {
+ collect(sort, 1, 0);
+ collect(sort, 1, 1);
+ collect(sort, 2, 0);
+ collect(sort, 2, 1);
+ collect(sort, 3, 0);
+ collect(sort, 3, 1);
+ collect(sort, 4, 1);
+
+ assertBlock(sort, 0, expectedValue(3), expectedValue(2));
+ assertBlock(sort, 1, expectedValue(4), expectedValue(3));
+ }
+ }
+
+ public final void testManyBucketsManyHits() {
+ // Set the values in random order
+ Double[] values = new Double[10000];
+ for (int v = 0; v < values.length; v++) {
+ values[v] = randomValue();
+ }
+ Collections.shuffle(Arrays.asList(values), random());
+
+ int buckets = between(2, 100);
+ int bucketSize = between(2, 100);
+ try (T sort = build(SortOrder.DESC, bucketSize)) {
+ BitArray[] bucketUsed = new BitArray[buckets];
+ Arrays.setAll(bucketUsed, i -> new BitArray(values.length, bigArrays()));
+ for (int doc = 0; doc < values.length; doc++) {
+ for (int bucket = 0; bucket < buckets; bucket++) {
+ if (randomBoolean()) {
+ bucketUsed[bucket].set(doc);
+ collect(sort, values[doc], bucket);
+ }
+ }
+ }
+ for (int bucket = 0; bucket < buckets; bucket++) {
+ List<Double> bucketValues = new ArrayList<>(values.length);
+ for (int doc = 0; doc < values.length; doc++) {
+ if (bucketUsed[bucket].get(doc)) {
+ bucketValues.add(values[doc]);
+ }
+ }
+ bucketUsed[bucket].close();
+ assertBlock(
+ sort,
+ bucket,
+ bucketValues.stream().sorted((lhs, rhs) -> rhs.compareTo(lhs)).limit(bucketSize).map(this::expectedValue).toArray()
+ );
+ }
+ assertBlock(sort, buckets);
+ }
+ }
+
+ public final void testMergeHeapToHeap() {
+ try (T sort = build(SortOrder.ASC, 3)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+ collect(sort, 3, 0);
+
+ try (T other = build(SortOrder.ASC, 3)) {
+ collect(other, 1, 0);
+ collect(other, 2, 0);
+ collect(other, 3, 0);
+
+ merge(sort, 0, other, 0);
+ }
+
+ assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
+ }
+ }
+
+ public final void testMergeNoHeapToNoHeap() {
+ try (T sort = build(SortOrder.ASC, 3)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+
+ try (T other = build(SortOrder.ASC, 3)) {
+ collect(other, 1, 0);
+ collect(other, 2, 0);
+
+ merge(sort, 0, other, 0);
+ }
+
+ assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
+ }
+ }
+
+ public final void testMergeHeapToNoHeap() {
+ try (T sort = build(SortOrder.ASC, 3)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+
+ try (T other = build(SortOrder.ASC, 3)) {
+ collect(other, 1, 0);
+ collect(other, 2, 0);
+ collect(other, 3, 0);
+
+ merge(sort, 0, other, 0);
+ }
+
+ assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
+ }
+ }
+
+ public final void testMergeNoHeapToHeap() {
+ try (T sort = build(SortOrder.ASC, 3)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+ collect(sort, 3, 0);
+
+ try (T other = build(SortOrder.ASC, 3)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+
+ merge(sort, 0, other, 0);
+ }
+
+ assertBlock(sort, 0, expectedValue(1), expectedValue(1), expectedValue(2));
+ }
+ }
+
+ public final void testMergeHeapToEmpty() {
+ try (T sort = build(SortOrder.ASC, 3)) {
+ try (T other = build(SortOrder.ASC, 3)) {
+ collect(other, 1, 0);
+ collect(other, 2, 0);
+ collect(other, 3, 0);
+
+ merge(sort, 0, other, 0);
+ }
+
+ assertBlock(sort, 0, expectedValue(1), expectedValue(2), expectedValue(3));
+ }
+ }
+
+ public final void testMergeEmptyToHeap() {
+ try (T sort = build(SortOrder.ASC, 3)) {
+ collect(sort, 1, 0);
+ collect(sort, 2, 0);
+ collect(sort, 3, 0);
+
+ try (T other = build(SortOrder.ASC, 3)) {
+ merge(sort, 0, other, 0);
+ }
+
+ assertBlock(sort, 0, expectedValue(1), expectedValue(2), expectedValue(3));
+ }
+ }
+
+ public final void testMergeEmptyToEmpty() {
+ try (T sort = build(SortOrder.ASC, 3)) {
+ try (T other = build(SortOrder.ASC, 3)) {
+ merge(sort, 0, other, randomNonNegativeInt());
+ }
+
+ assertBlock(sort, 0);
+ }
+ }
+
+ private void assertBlock(T sort, int groupId, Object... values) {
+ var blockFactory = TestBlockFactory.getNonBreakingInstance();
+
+ try (var intVector = blockFactory.newConstantIntVector(groupId, 1)) {
+ var block = toBlock(sort, blockFactory, intVector);
+
+ assertThat(block.getPositionCount(), equalTo(1));
+ assertThat(block.getTotalValueCount(), equalTo(values.length));
+
+ if (values.length == 0) {
+ assertThat(block.elementType(), equalTo(ElementType.NULL));
+ assertThat(block.isNull(0), equalTo(true));
+ } else {
+ assertBlockTypeAndValues(block, values);
+ }
+ }
+ }
+
+ protected final BigArrays bigArrays() {
+ return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/DoubleBucketedSortTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/DoubleBucketedSortTests.java
new file mode 100644
index 0000000000000..43b5caa092b9a
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/DoubleBucketedSortTests.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.search.sort.SortOrder;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class DoubleBucketedSortTests extends BucketedSortTestCase<DoubleBucketedSort> {
+ @Override
+ protected DoubleBucketedSort build(SortOrder sortOrder, int bucketSize) {
+ return new DoubleBucketedSort(bigArrays(), sortOrder, bucketSize);
+ }
+
+ @Override
+ protected Object expectedValue(double v) {
+ return v;
+ }
+
+ @Override
+ protected double randomValue() {
+ return randomDoubleBetween(-Double.MAX_VALUE, Double.MAX_VALUE, true);
+ }
+
+ @Override
+ protected void collect(DoubleBucketedSort sort, double value, int bucket) {
+ sort.collect(value, bucket);
+ }
+
+ @Override
+ protected void merge(DoubleBucketedSort sort, int groupId, DoubleBucketedSort other, int otherGroupId) {
+ sort.merge(groupId, other, otherGroupId);
+ }
+
+ @Override
+ protected Block toBlock(DoubleBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
+ return sort.toBlock(blockFactory, selected);
+ }
+
+ @Override
+ protected void assertBlockTypeAndValues(Block block, Object... values) {
+ assertThat(block.elementType(), equalTo(ElementType.DOUBLE));
+ var typedBlock = (DoubleBlock) block;
+ for (int i = 0; i < values.length; i++) {
+ assertThat(typedBlock.getDouble(i), equalTo(values[i]));
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/IntBucketedSortTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/IntBucketedSortTests.java
new file mode 100644
index 0000000000000..70d0a79ea7473
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/IntBucketedSortTests.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.search.sort.SortOrder;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class IntBucketedSortTests extends BucketedSortTestCase<IntBucketedSort> {
+ @Override
+ protected IntBucketedSort build(SortOrder sortOrder, int bucketSize) {
+ return new IntBucketedSort(bigArrays(), sortOrder, bucketSize);
+ }
+
+ @Override
+ protected Object expectedValue(double v) {
+ return (int) v;
+ }
+
+ @Override
+ protected double randomValue() {
+ return randomInt();
+ }
+
+ @Override
+ protected void collect(IntBucketedSort sort, double value, int bucket) {
+ sort.collect((int) value, bucket);
+ }
+
+ @Override
+ protected void merge(IntBucketedSort sort, int groupId, IntBucketedSort other, int otherGroupId) {
+ sort.merge(groupId, other, otherGroupId);
+ }
+
+ @Override
+ protected Block toBlock(IntBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
+ return sort.toBlock(blockFactory, selected);
+ }
+
+ @Override
+ protected void assertBlockTypeAndValues(Block block, Object... values) {
+ assertThat(block.elementType(), equalTo(ElementType.INT));
+ var typedBlock = (IntBlock) block;
+ for (int i = 0; i < values.length; i++) {
+ assertThat(typedBlock.getInt(i), equalTo(values[i]));
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/LongBucketedSortTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/LongBucketedSortTests.java
new file mode 100644
index 0000000000000..bceed3b1d95b5
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/LongBucketedSortTests.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.search.sort.SortOrder;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class LongBucketedSortTests extends BucketedSortTestCase<LongBucketedSort> {
+ @Override
+ protected LongBucketedSort build(SortOrder sortOrder, int bucketSize) {
+ return new LongBucketedSort(bigArrays(), sortOrder, bucketSize);
+ }
+
+ @Override
+ protected Object expectedValue(double v) {
+ return (long) v;
+ }
+
+ @Override
+ protected double randomValue() {
+        // 1L << 50 fits in the mantissa of a double, which the test sort of needs.
+        return randomLongBetween(-(1L << 50), 1L << 50);
+ }
+
+ @Override
+ protected void collect(LongBucketedSort sort, double value, int bucket) {
+ sort.collect((long) value, bucket);
+ }
+
+ @Override
+ protected void merge(LongBucketedSort sort, int groupId, LongBucketedSort other, int otherGroupId) {
+ sort.merge(groupId, other, otherGroupId);
+ }
+
+ @Override
+ protected Block toBlock(LongBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
+ return sort.toBlock(blockFactory, selected);
+ }
+
+ @Override
+ protected void assertBlockTypeAndValues(Block block, Object... values) {
+ assertThat(block.elementType(), equalTo(ElementType.LONG));
+ var typedBlock = (LongBlock) block;
+ for (int i = 0; i < values.length; i++) {
+ assertThat(typedBlock.getLong(i), equalTo(values[i]));
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
index 2cdd5c1dfd931..0fb35b4253d6d 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
@@ -38,10 +38,10 @@ double e()
"double log(?base:integer|unsigned_long|long|double, number:integer|unsigned_long|long|double)"
"double log10(number:double|integer|long|unsigned_long)"
"keyword|text ltrim(string:keyword|text)"
-"double|integer|long max(number:double|integer|long)"
+"double|integer|long|date max(number:double|integer|long|date)"
"double|integer|long median(number:double|integer|long)"
"double|integer|long median_absolute_deviation(number:double|integer|long)"
-"double|integer|long min(number:double|integer|long)"
+"double|integer|long|date min(number:double|integer|long|date)"
"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version mv_append(field1:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version, field2:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version)"
"double mv_avg(number:double|integer|long|unsigned_long)"
"keyword mv_concat(string:text|keyword, delim:text|keyword)"
@@ -109,6 +109,7 @@ double tau()
"keyword|text to_upper(str:keyword|text)"
"version to_ver(field:keyword|text|version)"
"version to_version(field:keyword|text|version)"
+"double|integer|long|date top_list(field:double|integer|long|date, limit:integer, order:keyword)"
"keyword|text trim(string:keyword|text)"
"boolean|date|double|integer|ip|keyword|long|text|version values(field:boolean|date|double|integer|ip|keyword|long|text|version)"
;
@@ -155,10 +156,10 @@ locate |[string, substring, start] |["keyword|text", "keyword|te
log |[base, number] |["integer|unsigned_long|long|double", "integer|unsigned_long|long|double"] |["Base of logarithm. If `null`\, the function returns `null`. If not provided\, this function returns the natural logarithm (base e) of a value.", "Numeric expression. If `null`\, the function returns `null`."]
log10 |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`.
ltrim |string |"keyword|text" |String expression. If `null`, the function returns `null`.
-max |number |"double|integer|long" |[""]
+max |number |"double|integer|long|date" |[""]
median |number |"double|integer|long" |[""]
median_absolut|number |"double|integer|long" |[""]
-min |number |"double|integer|long" |[""]
+min |number |"double|integer|long|date" |[""]
mv_append |[field1, field2] |["boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version", "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version"] | ["", ""]
mv_avg |number |"double|integer|long|unsigned_long" |Multivalue expression.
mv_concat |[string, delim] |["text|keyword", "text|keyword"] |[Multivalue expression., Delimiter.]
@@ -226,6 +227,7 @@ to_unsigned_lo|field |"boolean|date|keyword|text|d
to_upper |str |"keyword|text" |String expression. If `null`, the function returns `null`.
to_ver |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression.
to_version |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression.
+top_list |[field, limit, order] |["double|integer|long|date", integer, keyword] |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.]
trim |string |"keyword|text" |String expression. If `null`, the function returns `null`.
values |field |"boolean|date|double|integer|ip|keyword|long|text|version" |[""]
;
@@ -344,6 +346,7 @@ to_unsigned_lo|Converts an input value to an unsigned long value. If the input p
to_upper |Returns a new string representing the input string converted to upper case.
to_ver |Converts an input string to a version value.
to_version |Converts an input string to a version value.
+top_list |Collects the top values for a field. Includes repeated values.
trim |Removes leading and trailing whitespaces from a string.
values |Collect values for a field.
;
@@ -392,10 +395,10 @@ locate |integer
log |double |[true, false] |false |false
log10 |double |false |false |false
ltrim |"keyword|text" |false |false |false
-max |"double|integer|long" |false |false |true
+max |"double|integer|long|date" |false |false |true
median |"double|integer|long" |false |false |true
median_absolut|"double|integer|long" |false |false |true
-min |"double|integer|long" |false |false |true
+min |"double|integer|long|date" |false |false |true
mv_append |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version" |[false, false] |false |false
mv_avg |double |false |false |false
mv_concat |keyword |[false, false] |false |false
@@ -463,6 +466,7 @@ to_unsigned_lo|unsigned_long
to_upper |"keyword|text" |false |false |false
to_ver |version |false |false |false
to_version |version |false |false |false
+top_list |"double|integer|long|date" |[false, false, false] |false |true
trim |"keyword|text" |false |false |false
values |"boolean|date|double|integer|ip|keyword|long|text|version" |false |false |true
;
@@ -483,5 +487,5 @@ countFunctions#[skip:-8.14.99, reason:BIN added]
meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c;
a:long | b:long | c:long
-109 | 109 | 109
+110 | 110 | 110
;
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_list.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_list.csv-spec
new file mode 100644
index 0000000000000..c24f6a7e70954
--- /dev/null
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_list.csv-spec
@@ -0,0 +1,156 @@
+topList
+required_capability: agg_top_list
+// tag::top-list[]
+FROM employees
+| STATS top_salaries = TOP_LIST(salary, 3, "desc"), top_salary = MAX(salary)
+// end::top-list[]
+;
+
+// tag::top-list-result[]
+top_salaries:integer | top_salary:integer
+[74999, 74970, 74572] | 74999
+// end::top-list-result[]
+;
+
+topListAllTypesAsc
+required_capability: agg_top_list
+FROM employees
+| STATS
+ date = TOP_LIST(hire_date, 2, "asc"),
+ double = TOP_LIST(salary_change, 2, "asc"),
+ integer = TOP_LIST(salary, 2, "asc"),
+ long = TOP_LIST(salary_change.long, 2, "asc")
+;
+
+date:date | double:double | integer:integer | long:long
+[1985-02-18T00:00:00.000Z,1985-02-24T00:00:00.000Z] | [-9.81,-9.28] | [25324,25945] | [-9,-9]
+;
+
+topListAllTypesDesc
+required_capability: agg_top_list
+FROM employees
+| STATS
+ date = TOP_LIST(hire_date, 2, "desc"),
+ double = TOP_LIST(salary_change, 2, "desc"),
+ integer = TOP_LIST(salary, 2, "desc"),
+ long = TOP_LIST(salary_change.long, 2, "desc")
+;
+
+date:date | double:double | integer:integer | long:long
+[1999-04-30T00:00:00.000Z,1997-05-19T00:00:00.000Z] | [14.74,14.68] | [74999,74970] | [14,14]
+;
+
+topListAllTypesRow
+required_capability: agg_top_list
+ROW
+ constant_date=TO_DATETIME("1985-02-18T00:00:00.000Z"),
+ constant_double=-9.81,
+ constant_integer=25324,
+ constant_long=TO_LONG(-9)
+| STATS
+ date = TOP_LIST(constant_date, 2, "asc"),
+ double = TOP_LIST(constant_double, 2, "asc"),
+ integer = TOP_LIST(constant_integer, 2, "asc"),
+ long = TOP_LIST(constant_long, 2, "asc")
+| keep date, double, integer, long
+;
+
+date:date | double:double | integer:integer | long:long
+1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9
+;
+
+topListSomeBuckets
+required_capability: agg_top_list
+FROM employees
+| STATS top_salary = TOP_LIST(salary, 2, "desc") by still_hired
+| sort still_hired asc
+;
+
+top_salary:integer | still_hired:boolean
+[74999,74970] | false
+[74572,73578] | true
+;
+
+topListManyBuckets
+required_capability: agg_top_list
+FROM employees
+| STATS top_salary = TOP_LIST(salary, 2, "desc") by x=emp_no, y=emp_no+1
+| sort x asc
+| limit 3
+;
+
+top_salary:integer | x:integer | y:integer
+57305 | 10001 | 10002
+56371 | 10002 | 10003
+61805 | 10003 | 10004
+;
+
+topListMultipleStats
+required_capability: agg_top_list
+FROM employees
+| STATS top_salary = TOP_LIST(salary, 1, "desc") by emp_no
+| STATS top_salary = TOP_LIST(top_salary, 3, "asc")
+;
+
+top_salary:integer
+[25324,25945,25976]
+;
+
+topListAllTypesMin
+required_capability: agg_top_list
+FROM employees
+| STATS
+ date = TOP_LIST(hire_date, 1, "asc"),
+ double = TOP_LIST(salary_change, 1, "asc"),
+ integer = TOP_LIST(salary, 1, "asc"),
+ long = TOP_LIST(salary_change.long, 1, "asc")
+;
+
+date:date | double:double | integer:integer | long:long
+1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9
+;
+
+topListAllTypesMax
+required_capability: agg_top_list
+FROM employees
+| STATS
+ date = TOP_LIST(hire_date, 1, "desc"),
+ double = TOP_LIST(salary_change, 1, "desc"),
+ integer = TOP_LIST(salary, 1, "desc"),
+ long = TOP_LIST(salary_change.long, 1, "desc")
+;
+
+date:date | double:double | integer:integer | long:long
+1999-04-30T00:00:00.000Z | 14.74 | 74999 | 14
+;
+
+topListAscDesc
+required_capability: agg_top_list
+FROM employees
+| STATS top_asc = TOP_LIST(salary, 3, "asc"), top_desc = TOP_LIST(salary, 3, "desc")
+;
+
+top_asc:integer | top_desc:integer
+[25324, 25945, 25976] | [74999, 74970, 74572]
+;
+
+topListEmpty
+required_capability: agg_top_list
+FROM employees
+| WHERE salary < 0
+| STATS top = TOP_LIST(salary, 3, "asc")
+;
+
+top:integer
+null
+;
+
+topListDuplicates
+required_capability: agg_top_list
+FROM employees
+| STATS integer = TOP_LIST(languages, 2, "desc")
+;
+
+integer:integer
+[5, 5]
+;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 3eef9f7356b39..e65f574422dd5 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -42,6 +42,11 @@ public class EsqlCapabilities {
*/
private static final String FN_SUBSTRING_EMPTY_NULL = "fn_substring_empty_null";
+ /**
+ * Support for aggregation function {@code TOP_LIST}.
+ */
+ private static final String AGG_TOP_LIST = "agg_top_list";
+
/**
* Optimization for ST_CENTROID changed some results in cartesian data. #108713
*/
@@ -84,6 +89,7 @@ private static Set capabilities() {
caps.add(FN_CBRT);
caps.add(FN_IP_PREFIX);
caps.add(FN_SUBSTRING_EMPTY_NULL);
+ caps.add(AGG_TOP_LIST);
caps.add(ST_CENTROID_AGG_OPTIMIZED);
caps.add(METADATA_IGNORED_FIELD);
caps.add(FN_MV_APPEND);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index 8fd6ebe8d7d69..7034f23be1662 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -22,6 +22,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
@@ -192,6 +193,7 @@ private FunctionDefinition[][] functions() {
def(Min.class, Min::new, "min"),
def(Percentile.class, Percentile::new, "percentile"),
def(Sum.class, Sum::new, "sum"),
+ def(TopList.class, TopList::new, "top_list"),
def(Values.class, Values::new, "values") },
// math
new FunctionDefinition[] {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java
index 3f6632f66bcee..1c1139c197ac0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java
@@ -24,8 +24,12 @@
public class Max extends NumericAggregate implements SurrogateExpression {
- @FunctionInfo(returnType = { "double", "integer", "long" }, description = "The maximum value of a numeric field.", isAggregation = true)
- public Max(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
+ @FunctionInfo(
+ returnType = { "double", "integer", "long", "date" },
+ description = "The maximum value of a numeric field.",
+ isAggregation = true
+ )
+ public Max(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) {
super(source, field);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java
index 16821752bc7b8..ecfc2200a3643 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java
@@ -24,8 +24,12 @@
public class Min extends NumericAggregate implements SurrogateExpression {
- @FunctionInfo(returnType = { "double", "integer", "long" }, description = "The minimum value of a numeric field.", isAggregation = true)
- public Min(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
+ @FunctionInfo(
+ returnType = { "double", "integer", "long", "date" },
+ description = "The minimum value of a numeric field.",
+ isAggregation = true
+ )
+ public Min(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) {
super(source, field);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java
index b003b981c0709..390cd0d68018e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java
@@ -19,6 +19,28 @@
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
+/**
+ * Aggregate function that receives a numeric, signed field, and returns a single double value.
+ *
+ * Implement the supplier methods to return the correct {@link AggregatorFunctionSupplier}.
+ *
+ *
+ * Some methods can be optionally overridden to support different variations:
+ *
+ *
+ * -
+ * {@link #supportsDates}: override to also support dates. Defaults to false.
+ *
+ * -
+ * {@link #resolveType}: override to support different parameters.
+ * Call {@code super.resolveType()} to add extra checks.
+ *
+ * -
+ * {@link #dataType}: override to return a different datatype.
+ * You can return {@code field().dataType()} to propagate the parameter type.
+ *
+ *
+ */
public abstract class NumericAggregate extends AggregateFunction implements ToAggregator {
NumericAggregate(Source source, Expression field, List<Expression> parameters) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java
new file mode 100644
index 0000000000000..79893b1c7de07
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java
@@ -0,0 +1,181 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.aggregate;
+
+import org.elasticsearch.common.lucene.BytesRefs;
+import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.TopListDoubleAggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.TopListIntAggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.TopListLongAggregatorFunctionSupplier;
+import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
+import org.elasticsearch.xpack.esql.expression.function.Example;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
+import org.elasticsearch.xpack.esql.planner.ToAggregator;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
+
+public class TopList extends AggregateFunction implements ToAggregator, SurrogateExpression {
+ private static final String ORDER_ASC = "ASC";
+ private static final String ORDER_DESC = "DESC";
+
+ @FunctionInfo(
+ returnType = { "double", "integer", "long", "date" },
+ description = "Collects the top values for a field. Includes repeated values.",
+ isAggregation = true,
+ examples = @Example(file = "stats_top_list", tag = "top-list")
+ )
+ public TopList(
+ Source source,
+ @Param(
+ name = "field",
+ type = { "double", "integer", "long", "date" },
+ description = "The field to collect the top values for."
+ ) Expression field,
+ @Param(name = "limit", type = { "integer" }, description = "The maximum number of values to collect.") Expression limit,
+ @Param(
+ name = "order",
+ type = { "keyword" },
+ description = "The order to calculate the top values. Either `asc` or `desc`."
+ ) Expression order
+ ) {
+ super(source, field, Arrays.asList(limit, order));
+ }
+
+ public static TopList readFrom(PlanStreamInput in) throws IOException {
+ return new TopList(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readExpression());
+ }
+
+ public void writeTo(PlanStreamOutput out) throws IOException {
+ source().writeTo(out);
+        List<Expression> fields = children();
+ assert fields.size() == 3;
+ out.writeExpression(fields.get(0));
+ out.writeExpression(fields.get(1));
+ out.writeExpression(fields.get(2));
+ }
+
+ private Expression limitField() {
+ return parameters().get(0);
+ }
+
+ private Expression orderField() {
+ return parameters().get(1);
+ }
+
+ private int limitValue() {
+ return (int) limitField().fold();
+ }
+
+ private String orderRawValue() {
+ return BytesRefs.toString(orderField().fold());
+ }
+
+ private boolean orderValue() {
+ return orderRawValue().equalsIgnoreCase(ORDER_ASC);
+ }
+
+ @Override
+ protected TypeResolution resolveType() {
+ if (childrenResolved() == false) {
+ return new TypeResolution("Unresolved children");
+ }
+
+ var typeResolution = isType(
+ field(),
+ dt -> dt == DataType.DATETIME || dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
+ sourceText(),
+ FIRST,
+ "numeric except unsigned_long or counter types"
+ ).and(isFoldable(limitField(), sourceText(), SECOND))
+ .and(isType(limitField(), dt -> dt == DataType.INTEGER, sourceText(), SECOND, "integer"))
+ .and(isFoldable(orderField(), sourceText(), THIRD))
+ .and(isString(orderField(), sourceText(), THIRD));
+
+ if (typeResolution.unresolved()) {
+ return typeResolution;
+ }
+
+ var limit = limitValue();
+ var order = orderRawValue();
+
+ if (limit <= 0) {
+ return new TypeResolution(format(null, "Limit must be greater than 0 in [{}], found [{}]", sourceText(), limit));
+ }
+
+ if (order.equalsIgnoreCase(ORDER_ASC) == false && order.equalsIgnoreCase(ORDER_DESC) == false) {
+ return new TypeResolution(
+ format(null, "Invalid order value in [{}], expected [{}, {}] but got [{}]", sourceText(), ORDER_ASC, ORDER_DESC, order)
+ );
+ }
+
+ return TypeResolution.TYPE_RESOLVED;
+ }
+
+ @Override
+ public DataType dataType() {
+ return field().dataType();
+ }
+
+ @Override
+    protected NodeInfo<TopList> info() {
+ return NodeInfo.create(this, TopList::new, children().get(0), children().get(1), children().get(2));
+ }
+
+ @Override
+    public TopList replaceChildren(List<Expression> newChildren) {
+ return new TopList(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
+ }
+
+ @Override
+    public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
+ DataType type = field().dataType();
+ if (type == DataType.LONG || type == DataType.DATETIME) {
+ return new TopListLongAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
+ }
+ if (type == DataType.INTEGER) {
+ return new TopListIntAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
+ }
+ if (type == DataType.DOUBLE) {
+ return new TopListDoubleAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
+ }
+ throw EsqlIllegalArgumentException.illegalDataType(type);
+ }
+
+ @Override
+ public Expression surrogate() {
+ var s = source();
+
+ if (limitValue() == 1) {
+ if (orderValue()) {
+ return new Min(s, field());
+ } else {
+ return new Max(s, field());
+ }
+ }
+
+ return null;
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java
new file mode 100644
index 0000000000000..a99c7a8b7ac8d
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java
@@ -0,0 +1,176 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+/**
+ * Functions that aggregate values, with or without grouping within buckets.
+ * Used in `STATS` and similar commands.
+ *
+ * Guide to adding new aggregate function
+ *
+ * -
+ * Aggregation functions are more complex than scalar functions, so it's a good idea to discuss
+ * the new function with the ESQL team before starting to implement it.
+ *
+ * You may also discuss its implementation, as aggregations may require special performance considerations.
+ *
+ *
+ * -
+ * To learn the basics about making functions, check {@link org.elasticsearch.xpack.esql.expression.function.scalar}.
+ *
+ * It has the guide to making a simple function, which should be a good base to start doing aggregations.
+ *
+ *
+ * -
+ * Pick one of the csv-spec files in {@code x-pack/plugin/esql/qa/testFixtures/src/main/resources/}
+ * and add a test for the function you want to write. These files are roughly themed but there
+ * isn't a strong guiding principle in the organization.
+ *
+ * -
+ * Rerun the {@code CsvTests} and watch your new test fail.
+ *
+ * -
+ * Find an aggregate function in this package similar to the one you are working on and copy it to build
+ * yours.
+ * Your function might extend from the available abstract classes. Check the javadoc of each before using them:
+ *
+ * -
+ * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction}: The base class for aggregates
+ *
+ * -
+ * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate}: Aggregation for numeric values
+ *
+ * -
+ * {@link org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction}:
+ * Aggregation for spatial values
+ *
+ *
+ *
+ * -
+ * Fill the required methods in your new function. Check their JavaDoc for more information.
+ * Here are some of the important ones:
+ *
+ * -
+ * Constructor: Review the constructor annotations, and make sure to add the correct types and descriptions.
+ *
+ * - {@link org.elasticsearch.xpack.esql.expression.function.FunctionInfo}, for the constructor itself
+ * - {@link org.elasticsearch.xpack.esql.expression.function.Param}, for the function parameters
+ *
+ *
+ * -
+ * {@code resolveType}: Check the metadata of your function parameters.
+ * This may include types, whether they are foldable or not, or their possible values.
+ *
+ * -
+ * {@code dataType}: This will return the datatype of your function.
+ * May be based on its current parameters.
+ *
+ *
+ *
+ * Finally, you may want to implement some interfaces.
+ * Check their JavaDocs to see if they are suitable for your function:
+ *
+ * -
+ * {@link org.elasticsearch.xpack.esql.planner.ToAggregator}: (More information about aggregators below)
+ *
+ * -
+ * {@link org.elasticsearch.xpack.esql.expression.SurrogateExpression}
+ *
+ *
+ *
+ * -
+ * To introduce your aggregation to the engine:
+ *
+ *
+ *
+ *
+ * Creating aggregators for your function
+ *
+ * Aggregators contain the core logic of your aggregation. That is, how to combine values, what to store, how to process data, etc.
+ *
+ *
+ * -
+ * Copy an existing aggregator to use as a base. You'll usually make one per type. Check other classes to see the naming pattern.
+ * You can find them in {@link org.elasticsearch.compute.aggregation}.
+ *
+ * Note that some aggregators are autogenerated, so they live in different directories.
+ * The base is {@code x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/}
+ *
+ *
+ * -
+ * Make a test for your aggregator.
+ * You can copy an existing one from {@code x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/}.
+ *
+ * Tests extending from {@code org.elasticsearch.compute.aggregation.AggregatorFunctionTestCase}
+ * will already include most required cases. You should only need to fill the required abstract methods.
+ *
+ *
+ * -
+ * Check the Javadoc of the {@link org.elasticsearch.compute.ann.Aggregator}
+ * and {@link org.elasticsearch.compute.ann.GroupingAggregator} annotations.
+ * Add/Modify them on your aggregator.
+ *
+ * -
+ * The {@link org.elasticsearch.compute.ann.Aggregator} JavaDoc explains the static methods you should add.
+ *
+ * -
+ * After implementing the required methods (Even if they have a dummy implementation),
+ * run the CsvTests to generate some extra required classes.
+ *
+ * One of them will be the {@code AggregatorFunctionSupplier} for your aggregator.
+ * Find it by its name ({@code AggregatorFunctionSupplier}),
+ * and return it in the {@code toSupplier} method in your function, under the correct type condition.
+ *
+ *
+ * -
+ * Now, complete the implementation of the aggregator, until the tests pass!
+ *
+ *
+ *
+ * StringTemplates
+ *
+ * Making an aggregator per type may be repetitive. To avoid code duplication, we use StringTemplates:
+ *
+ *
+ * -
+ * Create a new StringTemplate file.
+ * Use another as a reference, like
+ * {@code x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopListAggregator.java.st}.
+ *
+ * -
+ * Add the template scripts to {@code x-pack/plugin/esql/compute/build.gradle}.
+ *
+ * You can also see there which variables you can use, and which types are currently supported.
+ *
+ *
+ * -
+ * After completing your template, run the generation with {@code ./gradlew :x-pack:plugin:esql:compute:compileJava}.
+ *
+ * You may need to tweak some import orders per type so they don't raise warnings.
+ *
+ *
+ *
+ */
+package org.elasticsearch.xpack.esql.expression.function.aggregate;
+
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java
index be5e105c3398e..831d105a89076 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java
@@ -58,6 +58,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
@@ -298,6 +299,7 @@ public static List namedTypeEntries() {
of(AggregateFunction.class, Percentile.class, PlanNamedTypes::writePercentile, PlanNamedTypes::readPercentile),
of(AggregateFunction.class, SpatialCentroid.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
of(AggregateFunction.class, Sum.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
+ of(AggregateFunction.class, TopList.class, (out, prefix) -> prefix.writeTo(out), TopList::readFrom),
of(AggregateFunction.class, Values.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction)
);
List entries = new ArrayList<>(declared);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java
index 863476ba55686..0d45ce10b1966 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java
@@ -29,6 +29,7 @@
* functions, designed to run over a {@link org.elasticsearch.compute.data.Block}
* {@link org.elasticsearch.xpack.esql.session.EsqlSession} - manages state across a query
* {@link org.elasticsearch.xpack.esql.expression.function.scalar} - Guide to writing scalar functions
+ * {@link org.elasticsearch.xpack.esql.expression.function.aggregate} - Guide to writing aggregation functions
* {@link org.elasticsearch.xpack.esql.analysis.Analyzer} - The first step in query processing
* {@link org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer} - Coordinator level logical optimizations
* {@link org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer} - Data node level logical optimizations
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
index 68e6ea4d6cadb..83fdd5dc0c5d2 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
@@ -32,6 +32,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import java.lang.invoke.MethodHandle;
@@ -61,7 +62,8 @@ final class AggregateMapper {
Percentile.class,
SpatialCentroid.class,
Sum.class,
- Values.class
+ Values.class,
+ TopList.class
);
/** Record of agg Class, type, and grouping (or non-grouping). */
@@ -143,6 +145,8 @@ private static Stream, Tuple>> typeAndNames(Class
} else if (Values.class.isAssignableFrom(clazz)) {
// TODO can't we figure this out from the function itself?
types = List.of("Int", "Long", "Double", "Boolean", "BytesRef");
+ } else if (TopList.class.isAssignableFrom(clazz)) {
+ types = List.of("Int", "Long", "Double");
} else {
assert clazz == CountDistinct.class : "Expected CountDistinct, got: " + clazz;
types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();