diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ChangePointLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ChangePointLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..b43fad5bb29c5 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ChangePointLongAggregatorFunctionSupplier.java @@ -0,0 +1,38 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link ChangePointLongAggregator}. + * This class is generated. Do not edit it. 
+ */ +public final class ChangePointLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public ChangePointLongAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public AggregatorFunction aggregator(DriverContext driverContext) { + throw new UnsupportedOperationException("non-grouping aggregator is not supported"); + } + + @Override + public ChangePointLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return ChangePointLongGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "change_point of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ChangePointLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ChangePointLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..eca924f8327b0 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ChangePointLongGroupingAggregatorFunction.java @@ -0,0 +1,226 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link ChangePointLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class ChangePointLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("timestamps", ElementType.LONG), + new IntermediateStateDesc("values", ElementType.LONG) ); + + private final ChangePointLongAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + public ChangePointLongGroupingAggregatorFunction(List channels, + ChangePointLongAggregator.GroupingState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static ChangePointLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new ChangePointLongGroupingAggregatorFunction(channels, ChangePointLongAggregator.initGrouping(driverContext), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock valuesBlock = 
page.getBlock(channels.get(0)); + LongVector valuesVector = valuesBlock.asVector(); + LongBlock timestampsBlock = page.getBlock(channels.get(1)); + LongVector timestampsVector = timestampsBlock.asVector(); + if (timestampsVector == null) { + throw new IllegalStateException("expected @timestamp vector; but got a block"); + } + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + ChangePointLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector values, + LongVector timestamps) 
{ + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + var valuePosition = groupPosition + positionOffset; + ChangePointLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + ChangePointLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongVector values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + var valuePosition = groupPosition + positionOffset; + ChangePointLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public 
void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + ChangePointLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + ChangePointLongAggregator.GroupingState inState = ((ChangePointLongGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + ChangePointLongAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = ChangePointLongAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void 
close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ChangePointLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ChangePointLongAggregator.java new file mode 100644 index 0000000000000..932cc8dcec457 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ChangePointLongAggregator.java @@ -0,0 +1,257 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; +import org.elasticsearch.xpack.ml.aggs.MlAggsHelper; +import org.elasticsearch.xpack.ml.aggs.changepoint.ChangePointDetector; +import org.elasticsearch.xpack.ml.aggs.changepoint.ChangeType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import 
java.util.Map; +import java.util.function.Function; + +/** + * Aggregates field values for long. + * TODO: make .java.st from this to support other types + * TODO: add "includeTimestamp" to @Aggregator + */ +// TODO: add normal @Aggregator +// @Aggregator({ +// includeTimestamps = true, +// @IntermediateState(name = "timestamps", type = "LONG_BLOCK"), +// @IntermediateState(name = "values", type = "LONG_BLOCK") }) +@GroupingAggregator( + includeTimestamps = true, + value = { @IntermediateState(name = "timestamps", type = "LONG_BLOCK"), @IntermediateState(name = "values", type = "LONG_BLOCK") } +) +class ChangePointLongAggregator { + + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext.bigArrays()); + } + + public static void combine(GroupingState current, int groupId, long timestamp, long value) { + current.add(groupId, timestamp, value); + } + + public static void combineIntermediate(GroupingState current, int groupId, LongBlock timestamps, LongBlock values, int otherPosition) { + current.combine(groupId, timestamps, values, otherPosition); + } + + public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { + current.combineState(currentGroupId, otherState, otherGroupId); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return state.evaluateFinal(selected, driverContext.blockFactory()); + } + + public static class SingleState implements Releasable { + private final BigArrays bigArrays; + private int count; + private LongArray timestamps; + private LongArray values; + + private SingleState(BigArrays bigArrays) { + this.bigArrays = bigArrays; + count = 0; + timestamps = bigArrays.newLongArray(0); + values = bigArrays.newLongArray(0); + } + + void add(long timestamp, long value) { + count++; + timestamps = bigArrays.grow(timestamps, count); + timestamps.set(count - 1, timestamp); 
+ values = bigArrays.grow(values, count); + values.set(count - 1, value); + } + + void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(timestamps, driverContext.blockFactory()); + blocks[offset + 1] = toBlock(values, driverContext.blockFactory()); + } + + Block toBlock(LongArray arr, BlockFactory blockFactory) { + if (arr.size() == 0) { + return blockFactory.newConstantNullBlock(1); + } + if (values.size() == 1) { + return blockFactory.newConstantLongBlockWith(arr.get(0), 1); + } + try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder((int) arr.size())) { + builder.beginPositionEntry(); + for (int id = 0; id < arr.size(); id++) { + builder.appendLong(arr.get(id)); + } + builder.endPositionEntry(); + return builder.build(); + } + } + + record TimeAndValue(long timestamp, long value) implements Comparable { + @Override + public int compareTo(TimeAndValue other) { + return Long.compare(timestamp, other.timestamp); + } + } + + void sort() { + // TODO: this is very inefficient and doesn't account for memory! 
+ List list = new ArrayList<>(count); + for (int i = 0; i < count; i++) { + list.add(new TimeAndValue(timestamps.get(i), values.get(i))); + } + Collections.sort(list); + for (int i = 0; i < count; i++) { + timestamps.set(i, list.get(i).timestamp); + values.set(i, list.get(i).value); + } + } + + @Override + public void close() { + timestamps.close(); + values.close(); + } + } + + public static class GroupingState implements Releasable { + private final BigArrays bigArrays; + private final Map states; + + private GroupingState(BigArrays bigArrays) { + this.bigArrays = bigArrays; + states = new HashMap<>(); + } + + void add(int groupId, long timestamp, long value) { + SingleState state = states.computeIfAbsent(groupId, key -> new SingleState(bigArrays)); + state.add(timestamp, value); + } + + void combine(int groupId, LongBlock timestamps, LongBlock values, int otherPosition) { + final int valueCount = timestamps.getValueCount(otherPosition); + if (valueCount == 0) { + return; + } + final int firstIndex = timestamps.getFirstValueIndex(otherPosition); + SingleState state = states.computeIfAbsent(groupId, key -> new SingleState(bigArrays)); + for (int i = 0; i < valueCount; i++) { + state.add(timestamps.getLong(firstIndex + i), values.getLong(firstIndex + i)); + } + } + + void combineState(int groupId, GroupingState otherState, int otherGroupId) { + SingleState other = otherState.states.get(otherGroupId); + if (other == null) { + return; + } + var state = states.computeIfAbsent(groupId, key -> new SingleState(bigArrays)); + for (int i = 0; i < other.timestamps.size(); i++) { + state.add(state.timestamps.get(i), state.values.get(i)); + } + } + + void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(s -> s.timestamps, driverContext.blockFactory(), selected); + blocks[offset + 1] = toBlock(s -> s.values, driverContext.blockFactory(), selected); + } + + public Block evaluateFinal(IntVector selected, 
BlockFactory blockFactory) { + try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { + for (int s = 0; s < selected.getPositionCount(); s++) { + SingleState state = states.get(selected.getInt(s)); + state.sort(); + double[] values = new double[state.count]; + for (int i = 0; i < state.count; i++) { + values[i] = state.values.get(i); + } + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues(null, values); + ChangeType changeType = ChangePointDetector.getChangeType(bucketValues); + try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) { + xContentBuilder.startObject(); + NamedXContentObjectHelper.writeNamedObject(xContentBuilder, ToXContent.EMPTY_PARAMS, "type", changeType); + xContentBuilder.endObject(); + String xContent = Strings.toString(xContentBuilder); + builder.appendBytesRef(new BytesRef(xContent)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return builder.build(); + } + } + + Block toBlock(Function getArray, BlockFactory blockFactory, IntVector selected) { + if (states.isEmpty()) { + return blockFactory.newConstantNullBlock(selected.getPositionCount()); + } + try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int selectedGroup = selected.getInt(s); + SingleState state = states.get(selectedGroup); + LongArray values = getArray.apply(state); + int count = 0; + long first = 0; + for (int i = 0; i < state.count; i++) { + long value = values.get(i); + switch (count) { + case 0 -> first = value; + case 1 -> { + builder.beginPositionEntry(); + builder.appendLong(first); + builder.appendLong(value); + } + default -> builder.appendLong(value); + } + count++; + } + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendLong(first); + default -> builder.endPositionEntry(); + } + } + return builder.build(); + } + } + 
+ void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` + } + + @Override + public void close() { + for (SingleState state : states.values()) { + state.close(); + } + } + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/change_point.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/change_point.csv-spec new file mode 100644 index 0000000000000..f6bd269de5de3 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/change_point.csv-spec @@ -0,0 +1,14 @@ +change point +# required_capability: change_point + +FROM k8s + | STATS count=COUNT() BY timestamp=BUCKET(@timestamp, 1 minute), cluster + | STATS cp=CHANGE_POINT(count, timestamp) BY cluster + | SORT cluster +; + +cp:keyword | cluster:keyword +"{""type"":{""indeterminable"":{""reason"":""not enough buckets to calculate change_point. Requires at least [22]; found [21]""}}}" | prod +"{""type"":{""stationary"":{}}}" | qa +"{""type"":{""indeterminable"":{""reason"":""not enough buckets to calculate change_point. 
Requires at least [22]; found [18]""}}}" | staging +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 50d0d2438d8a1..086ee014bfe9d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.Check; import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg; +import org.elasticsearch.xpack.esql.expression.function.aggregate.ChangePoint; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct; import org.elasticsearch.xpack.esql.expression.function.aggregate.Max; @@ -425,6 +426,7 @@ private static FunctionDefinition[][] snapshotFunctions() { new FunctionDefinition[] { // The delay() function is for debug/snapshot environments only and should never be enabled in a non-snapshot build. // This is an experimental function and can be removed without notice. 
+ def(ChangePoint.class, bi(ChangePoint::new), "change_point"), def(Delay.class, Delay::new, "delay"), def(Kql.class, Kql::new, "kql"), def(Rate.class, Rate::withUnresolvedTimestamp, "rate"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java index db1d2a9e6f254..d94298eb55ec7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java @@ -16,6 +16,7 @@ public class AggregateWritables { public static List getNamedWriteables() { return List.of( Avg.ENTRY, + ChangePoint.ENTRY, Count.ENTRY, CountDistinct.ENTRY, Max.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ChangePoint.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ChangePoint.java new file mode 100644 index 0000000000000..d056892c065d4 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ChangePoint.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.ChangePointLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.planner.ToAggregator; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; + +public class ChangePoint extends AggregateFunction implements ToAggregator { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "ChangePoint", + ChangePoint::new + ); + + @FunctionInfo(returnType = { "string" }, description = "...", isAggregation = true) + public ChangePoint( + Source source, + @Param(name = "field", type = { "double", "integer", "long" }, description = "field") Expression field, + @Param(name = "timestamp", type = { "date_nanos", "datetime", "double", "integer", "long" }) Expression timestamp + ) { + this(source, field, Literal.TRUE, timestamp); + } + + public ChangePoint(Source source, 
Expression field, Expression filter, Expression timestamp) { + super(source, field, filter, List.of(timestamp)); + } + + private ChangePoint(StreamInput in) throws IOException { + super( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteableCollectionAsList(Expression.class) + ); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + Expression timestamp() { + return parameters().get(0); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, ChangePoint::new, field(), timestamp()); + } + + @Override + public ChangePoint replaceChildren(List newChildren) { + return new ChangePoint(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); + } + + @Override + public DataType dataType() { + return DataType.KEYWORD; + } + + @Override + protected TypeResolution resolveType() { + return isType(field(), dt -> dt.isNumeric(), sourceText(), FIRST, "numeric").and( + isType(timestamp(), dt -> dt.isDate() || dt.isNumeric(), sourceText(), SECOND, "date_nanos or datetime or numeric") + ); + } + + @Override + public AggregateFunction withFilter(Expression filter) { + return new ChangePoint(source(), field(), filter, timestamp()); + } + + @Override + public AggregatorFunctionSupplier supplier(List inputChannels) { + // if (inputChannels.size() != 2 && inputChannels.size() != 3) { + // throw new IllegalArgumentException("change point requires two for raw input or three channels for partial input; got " + + // inputChannels); + // } + final DataType type = field().dataType(); + return switch (type) { + case LONG -> new ChangePointLongAggregatorFunctionSupplier(inputChannels); + default -> throw EsqlIllegalArgumentException.illegalDataType(type); + }; + } + + @Override + public String toString() { + return "change_point{field=" + field() + ",timestamp=" + timestamp() + "}"; + } +} diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 1918e3036e2b0..d7ae4259e5d2b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.aggregate.ChangePoint; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct; import org.elasticsearch.xpack.esql.expression.function.aggregate.FromPartial; @@ -71,6 +72,7 @@ final class AggregateMapper { /** List of all mappable ESQL agg functions (excludes surrogates like AVG = SUM/COUNT). 
*/
     private static final List> AGG_FUNCTIONS = List.of(
+        ChangePoint.class,
         Count.class,
         CountDistinct.class,
         Max.class,
@@ -195,6 +197,8 @@ private static Stream, Tuple>> typeAndNames(Class
             types = List.of(""); // no type
         } else if (CountDistinct.class.isAssignableFrom(clazz)) {
             types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();
+        } else if (ChangePoint.class.isAssignableFrom(clazz)) {
+            types = List.of("Long"); // TODO: add Int, Double
         } else {
             assert false : "unknown aggregate type " + clazz;
             throw new IllegalArgumentException("unknown aggregate type " + clazz);
@@ -211,7 +215,8 @@ private static Stream> combinations(List types, Li
     }
 
     private static Stream groupingAndNonGrouping(Tuple, Tuple> tuple) {
-        if (tuple.v1().isAssignableFrom(Rate.class)) {
+        // TODO: also non-grouping change point
+        if (tuple.v1().isAssignableFrom(Rate.class) || tuple.v1().isAssignableFrom(ChangePoint.class)) {
             // rate doesn't support non-grouping aggregations
             return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true));
         } else {
diff --git a/x-pack/plugin/ml/src/main/java/module-info.java b/x-pack/plugin/ml/src/main/java/module-info.java
index
4984fa8912e28..8752cdc149484 100644 --- a/x-pack/plugin/ml/src/main/java/module-info.java +++ b/x-pack/plugin/ml/src/main/java/module-info.java @@ -37,7 +37,9 @@ exports org.elasticsearch.xpack.ml; exports org.elasticsearch.xpack.ml.action; + exports org.elasticsearch.xpack.ml.aggs; exports org.elasticsearch.xpack.ml.aggs.categorization; + exports org.elasticsearch.xpack.ml.aggs.changepoint; exports org.elasticsearch.xpack.ml.autoscaling; exports org.elasticsearch.xpack.ml.job.categorization; exports org.elasticsearch.xpack.ml.notifications;