From d6cd77d66a4cf4d8e7a6ccfb956d5327355330fa Mon Sep 17 00:00:00 2001 From: Song Jiacheng Date: Tue, 3 Sep 2024 10:30:52 +0800 Subject: [PATCH] [Enhancement] Support aggregate function map_agg. Signed-off-by: Song Jiacheng --- .../exprs/agg/factory/aggregate_factory.hpp | 1 + .../agg/factory/aggregate_resolver_avg.cpp | 26 +++ be/src/exprs/agg/map_agg.h | 156 ++++++++++++++ .../com/starrocks/catalog/FunctionSet.java | 29 +++ .../analyzer/PolymorphicFunctionAnalyzer.java | 8 + test/sql/test_agg_function/R/test_map_agg | 195 ++++++++++++++++++ test/sql/test_agg_function/T/test_map_agg | 74 +++++++ 7 files changed, 489 insertions(+) create mode 100644 be/src/exprs/agg/map_agg.h create mode 100644 test/sql/test_agg_function/R/test_map_agg create mode 100644 test/sql/test_agg_function/T/test_map_agg diff --git a/be/src/exprs/agg/factory/aggregate_factory.hpp b/be/src/exprs/agg/factory/aggregate_factory.hpp index e96ed6a55bf593..5b0daf04f01437 100644 --- a/be/src/exprs/agg/factory/aggregate_factory.hpp +++ b/be/src/exprs/agg/factory/aggregate_factory.hpp @@ -40,6 +40,7 @@ #include "exprs/agg/hll_union_count.h" #include "exprs/agg/intersect_count.h" #include "exprs/agg/mann_whitney.h" +#include "exprs/agg/map_agg.h" #include "exprs/agg/maxmin.h" #include "exprs/agg/maxmin_by.h" #include "exprs/agg/nullable_aggregate.h" diff --git a/be/src/exprs/agg/factory/aggregate_resolver_avg.cpp b/be/src/exprs/agg/factory/aggregate_resolver_avg.cpp index 5bd5ce14db8d9f..09b2269f347a20 100644 --- a/be/src/exprs/agg/factory/aggregate_resolver_avg.cpp +++ b/be/src/exprs/agg/factory/aggregate_resolver_avg.cpp @@ -117,6 +117,31 @@ struct ArrayAggDistinctDispatcher { } }; +struct MapAggDispatcher { /* Type-dispatch functor: registers a map_agg implementation with the resolver for one key type (kt). */ + template + void operator()(AggregateFuncResolver* resolver) { + if constexpr (lt_is_aggregate) { /* key types that fail this check register nothing */ + using KeyCppType = RunTimeCppType; + if constexpr (lt_is_largeint) { /* keys matching lt_is_largeint get their own hash map alias */ + using MyHashMap = phmap::flat_hash_map>; + auto func = std::make_shared>();
resolver->add_aggregate_mapping_notnull("map_agg", false, func); + } else if constexpr (lt_is_fixedlength) { /* remaining fixed-length key types */ + using MyHashMap = phmap::flat_hash_map>; + auto func = std::make_shared>(); + resolver->add_aggregate_mapping_notnull("map_agg", false, func); + } else if constexpr (lt_is_string) { /* string key types */ + using MyHashMap = + phmap::flat_hash_map; + auto func = std::make_shared>(); + resolver->add_aggregate_mapping_notnull("map_agg", false, func); + } else { /* any other aggregate key type: throws when this functor is invoked for it */ + throw std::runtime_error("map_agg does not support key type " + type_to_string(kt)); + } + } + } +}; + void AggregateFuncResolver::register_avg() { for (auto type : aggregate_types()) { type_dispatch_all(type, AvgDispatcher(), this); @@ -124,6 +149,7 @@ void AggregateFuncResolver::register_avg() { type_dispatch_all(type, ArrayAggDistinctDispatcher(), this); type_dispatch_all(type, ArrayUnionAggDispatcher(), this); type_dispatch_all(type, ArrayUniqueAggDispatcher(), this); + type_dispatch_all(type, MapAggDispatcher(), this); /* run the map_agg dispatcher for every type yielded by aggregate_types() */ } type_dispatch_all(TYPE_JSON, ArrayAggDispatcher(), this); add_decimal_mapping("decimal_avg"); diff --git a/be/src/exprs/agg/map_agg.h b/be/src/exprs/agg/map_agg.h new file mode 100644 index 00000000000000..9c8eab53cffb3b --- /dev/null +++ b/be/src/exprs/agg/map_agg.h @@ -0,0 +1,156 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include + +#include "column/binary_column.h" +#include "column/column.h" +#include "column/column_helper.h" +#include "column/fixed_length_column.h" +#include "column/hash_set.h" +#include "column/map_column.h" +#include "column/type_traits.h" +#include "exprs/agg/aggregate.h" +#include "exprs/function_context.h" +#include "gutil/casts.h" +#include "util/phmap/phmap.h" +#include "util/time.h" + +namespace starrocks { + +template > +struct MapAggAggregateFunctionState : public AggregateFunctionEmptyState { /* Per-group state of map_agg: the distinct keys seen so far and one stored value per key. */ + using KeyColumnType = RunTimeColumnType; + using KeyType = typename SliceHashSet::key_type; + + MyHashMap hash_map; /* key -> row index of its value inside value_column */ + ColumnPtr value_column; /* one value per distinct key, appended when the key is first seen */ + void update(MemPool* mem_pool, const KeyColumnType& arg_key_column, const Column& arg_value_column, size_t offset, + size_t count) { /* Consumes count rows starting at offset; a key already in hash_map is skipped, so its first value wins. */ + if constexpr (!lt_is_string) { + for (size_t i = offset; i < offset + count; i++) { /* size_t index: offset and count are size_t, an int index would narrow them */ + auto key = arg_key_column.get_data()[i]; + if (!hash_map.contains(key)) { + auto value = arg_value_column.get(i); + value_column->append_datum(value); + hash_map.emplace(key, value_column->size() - 1); + } + } + } else { + for (size_t i = offset; i < offset + count; i++) { /* same size_t index as the non-string branch */ + auto raw_key = arg_key_column.get_slice(i); + KeyType key(raw_key); + if (!hash_map.contains(key)) { + uint8_t* pos = mem_pool->allocate(key.size); /* copy the key bytes into mem_pool; the stored Slice points at that copy */ + memcpy(pos, key.data, key.size); + auto value = arg_value_column.get(i); + value_column->append_datum(value); + hash_map.emplace(Slice(pos, key.size), value_column->size() - 1); + } + } + } + } +}; + +template > +class MapAggAggregateFunction final : public AggregateFunctionBatchHelper, + MapAggAggregateFunction> { /* Aggregate function map_agg: collects (key, value) argument pairs of a group into one map row. */ +public: + using KeyColumnType = RunTimeColumnType; + + void create(FunctionContext* ctx, AggDataPtr __restrict ptr) const override { /* Placement-constructs the state in ptr and creates its value column from the type of argument 1. */ + auto* state = new (ptr) MapAggAggregateFunctionState; + state->value_column = ctx->create_column(*ctx->get_arg_type(1), true); + } + + void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
size_t row_num) const override { + // A row whose key is NULL cannot be aggregated and is skipped. + if ((columns[0]->is_nullable() && columns[0]->is_null(row_num)) || columns[0]->only_null()) { + return; + } + const auto& key_column = down_cast(*ColumnHelper::get_data_column(columns[0])); + this->data(state).update(ctx->mem_pool(), key_column, *columns[1], row_num, 1); + } + + void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override { /* Merges the serialized map row row_num into the state; an empty map row adds nothing. */ + auto map_column = down_cast(ColumnHelper::get_data_column(column)); + auto& offsets = map_column->offsets().get_data(); + if (offsets[row_num + 1] > offsets[row_num]) { + this->data(state).update( + ctx->mem_pool(), + *down_cast(ColumnHelper::get_data_column(map_column->keys_column().get())), + map_column->values(), offsets[row_num], offsets[row_num + 1] - offsets[row_num]); + } + } + + void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override { /* Appends the state as one map row to the destination column. */ + auto& state_impl = this->data(state); + auto* map_column = down_cast(ColumnHelper::get_data_column(to)); + + auto elem_size = state_impl.hash_map.size(); + + auto* key_column = down_cast(ColumnHelper::get_data_column(map_column->keys_column().get())); + if constexpr (lt_is_string) { + for (const auto& entry : state_impl.hash_map) { + key_column->append(Slice(entry.first.data, entry.first.size)); + map_column->values_column()->append_datum(state_impl.value_column->get(entry.second)); + } + } else { + for (const auto& entry : state_impl.hash_map) { + key_column->append(entry.first); + map_column->values_column()->append_datum(state_impl.value_column->get(entry.second)); + } + } + + if (to->is_nullable()) { /* append a 0 null flag for the new map row */ + down_cast(to)->null_column_data().emplace_back(0); + } + if (map_column->keys_column()->is_nullable()) { + // Keys are never NULL; grow the key null flags by elem_size entries to match the appended keys.
+ auto* nullable_column = down_cast(map_column->keys_column().get()); + nullable_column->null_column_data().resize(nullable_column->null_column_data().size() + elem_size); + } + + auto& offsets = map_column->offsets_column()->get_data(); + offsets.push_back(offsets.back() + elem_size); /* close the new map row: it holds elem_size entries */ + } + + void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override { /* The final result is written exactly like the serialized state. */ + serialize_to_column(ctx, state, to); + } + + void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size, + ColumnPtr* dst) const override { /* Converts each input row into a single-entry map row; a row with a NULL key becomes an empty map row. */ + auto* column = down_cast(ColumnHelper::get_data_column(dst->get())); + auto key_column = column->keys_column(); + auto value_column = column->values_column(); + auto& offsets = column->offsets_column()->get_data(); + for (size_t i = 0; i < chunk_size; i++) { + if ((src[0]->is_nullable() && src[0]->is_null(i)) || src[0]->only_null()) { + offsets.push_back(offsets.back()); + continue; + } + key_column->append(*src[0], i, 1); + value_column->append(*src[1], i, 1); + offsets.push_back(offsets.back() + 1); + } + } + + std::string get_name() const override { return "map_agg"; } +}; + +} // namespace starrocks diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java index b321fe0ba41d4b..8d6f1c375e7ec8 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java @@ -474,6 +474,9 @@ public class FunctionSet { public static final String MAP_FROM_ARRAYS = "map_from_arrays"; public static final String MAP_KEYS = "map_keys"; public static final String MAP_SIZE = "map_size"; + + public static final String MAP_AGG = "map_agg"; + public static final String TRANSFORM_VALUES = "transform_values"; public static final String TRANSFORM_KEYS = "transform_keys"; @@ -789,6 +792,7 @@ public class FunctionSet { .add(APPROX_TOP_K)
.add(INTERSECT_COUNT) .add(LC_PERCENTILE_DISC) + .add(MAP_AGG) .build(); public FunctionSet() { @@ -1173,6 +1177,9 @@ private void initAggregateBuiltins() { // Percentile registerBuiltinPercentileAggFunction(); + // map_agg + registerBuiltinMapAggFunction(); + // HLL_UNION_AGG addBuiltin(AggregateFunction.createBuiltin(HLL_UNION_AGG, Lists.newArrayList(Type.HLL), Type.BIGINT, Type.HLL, @@ -1409,6 +1416,28 @@ private void registerBuiltinArrayAggDistinctFunction() { false, false, false)); } + private void registerBuiltinMapAggFunction() { /* Registers the map_agg builtins: one signature per supported key type, each taking (key, ANY_ELEMENT) and returning ANY_MAP. */ + for (ScalarType keyType : Type.getNumericTypes()) { /* numeric key types */ + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.MAP_AGG, + Lists.newArrayList(keyType, Type.ANY_ELEMENT), Type.ANY_MAP, null, + false, false, false)); + } + for (ScalarType keyType : Type.STRING_TYPES) { /* string key types */ + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.MAP_AGG, + Lists.newArrayList(keyType, Type.ANY_ELEMENT), Type.ANY_MAP, null, + false, false, false)); + } + + for (ScalarType keyType : Type.DATE_TYPES) { /* date key types */ + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.MAP_AGG, /* TIME is registered on its own, outside the three loops */ + Lists.newArrayList(keyType, Type.ANY_ELEMENT), Type.ANY_MAP, null, + false, false, false)); + } + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.MAP_AGG, + Lists.newArrayList(Type.TIME, Type.ANY_ELEMENT), Type.ANY_MAP, null, + false, false, false)); + } + private void registerBuiltinArrayUniqueAggFunction() { // array_unique_agg mapping array_agg_distinct while array as input.
for (ScalarType type : Type.getNumericTypes()) { diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java index 1bee662560c99f..857a09d40a720d 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java @@ -191,6 +191,13 @@ public Type apply(Type[] types) { } } + private static class MapAggDeduce implements java.util.function.Function { /* Deduces the map_agg return type: a MapType keyed by the first argument type with the second argument type as value. */ + @Override + public Type apply(Type[] types) { + return new MapType(types[0], types[1]); + } + } + private static final ImmutableMap> DEDUCE_RETURN_TYPE_FUNCTIONS = ImmutableMap.>builder() .put(FunctionSet.MAP_KEYS, new MapKeysDeduce()) @@ -215,6 +222,7 @@ public Type apply(Type[] types) { .put(FunctionSet.getAggStateName(FunctionSet.ARRAY_AGG), new ArrayAggStateDeduce()) .put(FunctionSet.getAggStateUnionName(FunctionSet.ARRAY_AGG), types -> types[0]) .put(FunctionSet.getAggStateMergeName(FunctionSet.ARRAY_AGG), new ArrayAggMergeDeduce()) + .put(FunctionSet.MAP_AGG, new MapAggDeduce()) /* map_agg return type is deduced from its argument types */ .build(); private static Function resolveByDeducingReturnType(Function fn, Type[] inputArgTypes) { diff --git a/test/sql/test_agg_function/R/test_map_agg b/test/sql/test_agg_function/R/test_map_agg new file mode 100644 index 00000000000000..78d5b4b8523d2b --- /dev/null +++ b/test/sql/test_agg_function/R/test_map_agg @@ -0,0 +1,195 @@ +-- name: test_map_agg +CREATE TABLE t1 ( + c1 int, + c2 boolean, + c3 tinyint, + c4 int, + c5 bigint, + c6 largeint, + c7 string, + c8 double, + c9 date, + c10 datetime, + c11 array, + c12 map, + c13 struct + ) +DUPLICATE KEY(c1) +DISTRIBUTED BY HASH(c1) BUCKETS 3 +PROPERTIES ("replication_num" = "1"); +-- result: +-- !result +INSERT INTO t1 values + (1, true, 11, 111, 1111, 11111, "111111", 1.1, "2024-09-01", "2024-09-01 18:00:00", [1, 2, 3], map('key', 5.5),
row(100, "abc")), + (2, false, 22, 222, 2222, 22222, "222222", 2.2, "2024-09-02", "2024-09-02 11:00:00", [3, 4, 5], map('key', 511.2), row(200, "bcd")), + (3, true, 33, 333, 3333, 33333, "333333", 3.3, "2024-09-03", "2024-09-03 00:00:00", [4, 1, 2], map('key', 666.6), row(300, "cccecd")), + (4, false, 11, 444, 4444, 44444, "444444", 4.4, "2024-09-04", "2024-09-04 12:00:00", [7, 7, 5], map('key', 444.4), row(400, "efdg")), + (5, null, null, null, null, null, null, null, null, null, null, null, null); +-- result: +-- !result +set streaming_preaggregation_mode=force_preaggregation; +-- result: +-- !result +select map_size(map_agg(c1, c3)) from t1; +-- result: +5 +-- !result +select map_agg(c1, c3)[1] from t1; +-- result: +11 +-- !result +select map_agg(c1, c3)[2] from t1; +-- result: +22 +-- !result +select map_agg(c1, c3)[3] from t1; +-- result: +33 +-- !result +select map_agg(c1, c3)[4] from t1; +-- result: +11 +-- !result +select map_agg(c1, c3)[5] from t1; +-- result: +None +-- !result +select map_size(map_agg(c5, c6)) from t1; +-- result: +4 +-- !result +select map_agg(c5, c6)[1111] from t1; +-- result: +11111 +-- !result +select map_size(map_agg(c6, c10)) from t1; +-- result: +4 +-- !result +select map_agg(c6, c10)[11111] from t1; +-- result: +2024-09-01 18:00:00 +-- !result +select map_agg(c6, c10)[22222] from t1; +-- result: +2024-09-02 11:00:00 +-- !result +select map_size(map_agg(c8, c5)) from t1; +-- result: +4 +-- !result +select map_agg(c8, c5)[1.1] from t1; +-- result: +1111 +-- !result +select map_agg(c8, c5)[4.4] from t1; +-- result: +4444 +-- !result +select c11, map_agg(c10, c11) res from t1 group by c11 order by c11[1]; +-- result: +None {} +[1,2,3] {"2024-09-01 18:00:00":[1,2,3]} +[3,4,5] {"2024-09-02 11:00:00":[3,4,5]} +[4,1,2] {"2024-09-03 00:00:00":[4,1,2]} +[7,7,5] {"2024-09-04 12:00:00":[7,7,5]} +-- !result +select c12, map_agg(c9, c12) res from t1 group by c12 order by c12['key']; +-- result: +None {} +{"key":5.5} {"2024-09-01":{"key":5.5}} 
+{"key":444.4} {"2024-09-04":{"key":444.4}} +{"key":511.2} {"2024-09-02":{"key":511.2}} +{"key":666.6} {"2024-09-03":{"key":666.6}} +-- !result +select c13, map_agg(c9, c13) res from t1 group by c13 order by c13.a; +-- result: +None {} +{"a":100,"b":"abc"} {"2024-09-01":{"a":100,"b":"abc"}} +{"a":200,"b":"bcd"} {"2024-09-02":{"a":200,"b":"bcd"}} +{"a":300,"b":"cccecd"} {"2024-09-03":{"a":300,"b":"cccecd"}} +{"a":400,"b":"efdg"} {"2024-09-04":{"a":400,"b":"efdg"}} +-- !result +set streaming_preaggregation_mode=force_streaming; +-- result: +-- !result +select map_size(map_agg(c1, c3)) from t1; +-- result: +5 +-- !result +select map_agg(c1, c3)[1] from t1; +-- result: +11 +-- !result +select map_agg(c1, c3)[2] from t1; +-- result: +22 +-- !result +select map_agg(c1, c3)[3] from t1; +-- result: +33 +-- !result +select map_agg(c1, c3)[4] from t1; +-- result: +11 +-- !result +select map_agg(c1, c3)[5] from t1; +-- result: +None +-- !result +select map_size(map_agg(c5, c6)) from t1; +-- result: +4 +-- !result +select map_agg(c5, c6)[1111] from t1; +-- result: +11111 +-- !result +select map_size(map_agg(c6, c10)) from t1; +-- result: +4 +-- !result +select map_agg(c6, c10)[11111] from t1; +-- result: +2024-09-01 18:00:00 +-- !result +select map_agg(c6, c10)[22222] from t1; +-- result: +2024-09-02 11:00:00 +-- !result +select map_size(map_agg(c8, c5)) from t1; +-- result: +4 +-- !result +select map_agg(c8, c5)[1.1] from t1; +-- result: +1111 +-- !result +select map_agg(c8, c5)[4.4] from t1; +-- result: +4444 +-- !result +select c11, map_agg(c10, c11) res from t1 group by c11 order by c11[1]; +-- result: +None {} +[1,2,3] {"2024-09-01 18:00:00":[1,2,3]} +[3,4,5] {"2024-09-02 11:00:00":[3,4,5]} +[4,1,2] {"2024-09-03 00:00:00":[4,1,2]} +[7,7,5] {"2024-09-04 12:00:00":[7,7,5]} +-- !result +select c12, map_agg(c9, c12) res from t1 group by c12 order by c12['key']; +-- result: +None {} +{"key":5.5} {"2024-09-01":{"key":5.5}} +{"key":444.4} {"2024-09-04":{"key":444.4}} 
+{"key":511.2} {"2024-09-02":{"key":511.2}} +{"key":666.6} {"2024-09-03":{"key":666.6}} +-- !result +select c13, map_agg(c9, c13) res from t1 group by c13 order by c13.a; +-- result: +None {} +{"a":100,"b":"abc"} {"2024-09-01":{"a":100,"b":"abc"}} +{"a":200,"b":"bcd"} {"2024-09-02":{"a":200,"b":"bcd"}} +{"a":300,"b":"cccecd"} {"2024-09-03":{"a":300,"b":"cccecd"}} +{"a":400,"b":"efdg"} {"2024-09-04":{"a":400,"b":"efdg"}} +-- !result \ No newline at end of file diff --git a/test/sql/test_agg_function/T/test_map_agg b/test/sql/test_agg_function/T/test_map_agg new file mode 100644 index 00000000000000..da14a14c4dd543 --- /dev/null +++ b/test/sql/test_agg_function/T/test_map_agg @@ -0,0 +1,74 @@ +-- name: test_map_agg +CREATE TABLE t1 ( + c1 int, + c2 boolean, + c3 tinyint, + c4 int, + c5 bigint, + c6 largeint, + c7 string, + c8 double, + c9 date, + c10 datetime, + c11 array, + c12 map, + c13 struct + ) +DUPLICATE KEY(c1) +DISTRIBUTED BY HASH(c1) BUCKETS 3 +PROPERTIES ("replication_num" = "1"); + +INSERT INTO t1 values + (1, true, 11, 111, 1111, 11111, "111111", 1.1, "2024-09-01", "2024-09-01 18:00:00", [1, 2, 3], map('key', 5.5), row(100, "abc")), + (2, false, 22, 222, 2222, 22222, "222222", 2.2, "2024-09-02", "2024-09-02 11:00:00", [3, 4, 5], map('key', 511.2), row(200, "bcd")), + (3, true, 33, 333, 3333, 33333, "333333", 3.3, "2024-09-03", "2024-09-03 00:00:00", [4, 1, 2], map('key', 666.6), row(300, "cccecd")), + (4, false, 11, 444, 4444, 44444, "444444", 4.4, "2024-09-04", "2024-09-04 12:00:00", [7, 7, 5], map('key', 444.4), row(400, "efdg")), + (5, null, null, null, null, null, null, null, null, null, null, null, null); + + +set streaming_preaggregation_mode=force_preaggregation; +select map_size(map_agg(c1, c3)) from t1; +select map_agg(c1, c3)[1] from t1; +select map_agg(c1, c3)[2] from t1; +select map_agg(c1, c3)[3] from t1; +select map_agg(c1, c3)[4] from t1; +select map_agg(c1, c3)[5] from t1; + +select map_size(map_agg(c5, c6)) from t1; +select map_agg(c5, 
c6)[1111] from t1; + +select map_size(map_agg(c6, c10)) from t1; +select map_agg(c6, c10)[11111] from t1; +select map_agg(c6, c10)[22222] from t1; + +select map_size(map_agg(c8, c5)) from t1; +select map_agg(c8, c5)[1.1] from t1; +select map_agg(c8, c5)[4.4] from t1; + +select c11, map_agg(c10, c11) res from t1 group by c11 order by c11[1]; +select c12, map_agg(c9, c12) res from t1 group by c12 order by c12['key']; +select c13, map_agg(c9, c13) res from t1 group by c13 order by c13.a; + +set streaming_preaggregation_mode=force_streaming; + +select map_size(map_agg(c1, c3)) from t1; +select map_agg(c1, c3)[1] from t1; +select map_agg(c1, c3)[2] from t1; +select map_agg(c1, c3)[3] from t1; +select map_agg(c1, c3)[4] from t1; +select map_agg(c1, c3)[5] from t1; + +select map_size(map_agg(c5, c6)) from t1; +select map_agg(c5, c6)[1111] from t1; + +select map_size(map_agg(c6, c10)) from t1; +select map_agg(c6, c10)[11111] from t1; +select map_agg(c6, c10)[22222] from t1; + +select map_size(map_agg(c8, c5)) from t1; +select map_agg(c8, c5)[1.1] from t1; +select map_agg(c8, c5)[4.4] from t1; + +select c11, map_agg(c10, c11) res from t1 group by c11 order by c11[1]; +select c12, map_agg(c9, c12) res from t1 group by c12 order by c12['key']; +select c13, map_agg(c9, c13) res from t1 group by c13 order by c13.a; \ No newline at end of file