From 644a6f999ad0e9f9deaa48fc54fdfaad30a785ee Mon Sep 17 00:00:00 2001 From: Song Jiacheng Date: Tue, 3 Sep 2024 10:30:52 +0800 Subject: [PATCH] [Enhancement] Support aggregate function map_agg. Signed-off-by: Song Jiacheng --- .../exprs/agg/factory/aggregate_factory.hpp | 1 + .../agg/factory/aggregate_resolver_avg.cpp | 25 +++ be/src/exprs/agg/map_agg.h | 144 +++++++++++++ .../com/starrocks/catalog/FunctionSet.java | 28 +++ test/sql/test_agg_function/R/test_map_agg | 202 ++++++++++++++++++ test/sql/test_agg_function/T/test_map_agg | 76 +++++++ 6 files changed, 476 insertions(+) create mode 100644 be/src/exprs/agg/map_agg.h create mode 100644 test/sql/test_agg_function/R/test_map_agg create mode 100644 test/sql/test_agg_function/T/test_map_agg diff --git a/be/src/exprs/agg/factory/aggregate_factory.hpp b/be/src/exprs/agg/factory/aggregate_factory.hpp index 996900eb295aff..43555b01a8295e 100644 --- a/be/src/exprs/agg/factory/aggregate_factory.hpp +++ b/be/src/exprs/agg/factory/aggregate_factory.hpp @@ -40,6 +40,7 @@ #include "exprs/agg/hll_union_count.h" #include "exprs/agg/intersect_count.h" #include "exprs/agg/mann_whitney.h" +#include "exprs/agg/map_agg.h" #include "exprs/agg/maxmin.h" #include "exprs/agg/maxmin_by.h" #include "exprs/agg/nullable_aggregate.h" diff --git a/be/src/exprs/agg/factory/aggregate_resolver_avg.cpp b/be/src/exprs/agg/factory/aggregate_resolver_avg.cpp index 5bd5ce14db8d9f..cd499f8b46e6f7 100644 --- a/be/src/exprs/agg/factory/aggregate_resolver_avg.cpp +++ b/be/src/exprs/agg/factory/aggregate_resolver_avg.cpp @@ -117,6 +117,30 @@ struct ArrayAggDistinctDispatcher { } }; +struct MapAggDispatcher { + template + void operator()(AggregateFuncResolver* resolver) { + if constexpr (lt_is_aggregate) { + using KeyCppType = RunTimeCppType; + if constexpr (lt_is_largeint) { + using MyHashMap = phmap::flat_hash_map>; + auto func = std::make_shared>(); + resolver->add_aggregate_mapping_notnull("map_agg", false, func); + } else if constexpr (lt_is_fixedlength) { + using MyHashMap = phmap::flat_hash_map>; + auto func = std::make_shared>(); + resolver->add_aggregate_mapping_notnull("map_agg", false, func); + } else if constexpr (lt_is_string) { + using MyHashMap = phmap::flat_hash_map; + auto func = std::make_shared>(); + resolver->add_aggregate_mapping_notnull("map_agg", false, func); + } else { + throw std::runtime_error("map_agg does not support key type " + type_to_string(kt)); + } + } + } +}; + void AggregateFuncResolver::register_avg() { for (auto type : aggregate_types()) { type_dispatch_all(type, AvgDispatcher(), this); @@ -124,6 +148,7 @@ void AggregateFuncResolver::register_avg() { type_dispatch_all(type, ArrayAggDistinctDispatcher(), this); type_dispatch_all(type, ArrayUnionAggDispatcher(), this); type_dispatch_all(type, ArrayUniqueAggDispatcher(), this); + type_dispatch_all(type, MapAggDispatcher(), this); } type_dispatch_all(TYPE_JSON, ArrayAggDispatcher(), this); add_decimal_mapping("decimal_avg"); diff --git a/be/src/exprs/agg/map_agg.h b/be/src/exprs/agg/map_agg.h new file mode 100644 index 00000000000000..1cd4d154538441 --- /dev/null +++ b/be/src/exprs/agg/map_agg.h @@ -0,0 +1,144 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "column/binary_column.h" +#include "column/column.h" +#include "column/column_helper.h" +#include "column/fixed_length_column.h" +#include "column/hash_set.h" +#include "column/map_column.h" +#include "column/type_traits.h" +#include "exprs/agg/aggregate.h" +#include "exprs/function_context.h" +#include "gutil/casts.h" +#include "util/phmap/phmap.h" +#include "util/time.h" + +namespace starrocks { + +template > +struct MapAggAggregateFunctionState : public AggregateFunctionEmptyState { + using KeyColumnType = RunTimeColumnType; + using KeyType = typename SliceHashSet::key_type; + + MyHashMap hash_map; + void update(MemPool* mem_pool, const KeyColumnType& arg_key_column, const Column& arg_value_column, size_t offset, + size_t count) { + for (int i = offset; i < offset + count; i++) { + if constexpr (!lt_is_string) { + auto key = arg_key_column.get_data()[i]; + auto value = arg_value_column.get(i); + hash_map.emplace(key, value); + } else { + auto raw_key = arg_key_column.get_slice(i); + KeyType key(raw_key); + if (!hash_map.contains(key)) { + uint8_t* pos = mem_pool->allocate(key.size); + memcpy(pos, key.data, key.size); + hash_map.emplace(Slice(pos, key.size), arg_value_column.get(i)); + } + } + } + } +}; + +template > +class MapAggAggregateFunction final : public AggregateFunctionBatchHelper, + MapAggAggregateFunction> { +public: + using KeyColumnType = RunTimeColumnType; + using CppType = RunTimeCppType; + + void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, + size_t row_num) const override { + // Key could not be null. + if ((columns[0]->is_nullable() && columns[0]->is_null(row_num)) || columns[0]->only_null()) { + return; + } + const auto& key_column = down_cast(*ColumnHelper::get_data_column(columns[0])); + this->data(state).update(ctx->mem_pool(), key_column, *columns[1], row_num, 1); + } + + void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override { + auto map_column = down_cast(ColumnHelper::get_data_column(column)); + auto& offsets = map_column->offsets().get_data(); + if (offsets[row_num + 1] > offsets[row_num]) { + this->data(state).update( + ctx->mem_pool(), + *down_cast(ColumnHelper::get_data_column(map_column->keys_column().get())), + map_column->values(), offsets[row_num], offsets[row_num + 1] - offsets[row_num]); + } + } + + void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override { + auto& state_impl = this->data(state); + auto* map_column = down_cast(ColumnHelper::get_data_column(to)); + + auto elem_size = state_impl.hash_map.size(); + + auto* key_column = down_cast(ColumnHelper::get_data_column(map_column->keys_column().get())); + if constexpr (lt_is_string) { + for (const auto& entry : state_impl.hash_map) { + key_column->append(Slice(entry.first.data, entry.first.size)); + map_column->values_column()->append_datum(entry.second); + } + } else { + for (const auto& entry : state_impl.hash_map) { + key_column->append(entry.first); + map_column->values_column()->append_datum(entry.second); + } + } + + if (to->is_nullable()) { + down_cast(to)->null_column_data().emplace_back(0); + } + if (map_column->keys_column()->is_nullable()) { + // Key could not be NULL. + auto* nullable_column = down_cast(map_column->keys_column().get()); + nullable_column->null_column_data().resize(nullable_column->null_column_data().size() + elem_size); + } + + auto& offsets = map_column->offsets_column()->get_data(); + offsets.push_back(offsets.back() + elem_size); + } + + void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override { + serialize_to_column(ctx, state, to); + } + + void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size, + ColumnPtr* dst) const override { + auto* column = down_cast(ColumnHelper::get_data_column(dst->get())); + auto key_column = column->keys_column(); + auto value_column = column->values_column(); + auto& offsets = column->offsets_column()->get_data(); + for (size_t i = 0; i < chunk_size; i++) { + if ((src[0]->is_nullable() && src[0]->is_null(i)) || src[0]->only_null()) { + offsets.push_back(offsets.back()); + continue; + } + key_column->append(*src[0], i, 1); + value_column->append(*src[1], i, 1); + offsets.push_back(offsets.back() + 1); + } + } + + std::string get_name() const override { return "map_agg"; } +}; + +} // namespace starrocks diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java index f46efbdcad4e29..af53a3513e5b55 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java @@ -461,6 +461,9 @@ public class FunctionSet { public static final String MAP_FROM_ARRAYS = "map_from_arrays"; public static final String MAP_KEYS = "map_keys"; public static final String MAP_SIZE = "map_size"; + + public static final String MAP_AGG = "map_agg"; + public static final String TRANSFORM_VALUES = "transform_values"; public static final String TRANSFORM_KEYS = "transform_keys"; @@ -1093,6 +1096,9 @@ private void initAggregateBuiltins() { // Percentile registerBuiltinPercentileAggFunction(); + // map_agg + registerBuiltinMapAggFunction(); + // HLL_UNION_AGG addBuiltin(AggregateFunction.createBuiltin(HLL_UNION_AGG, Lists.newArrayList(Type.HLL), Type.BIGINT, Type.HLL, @@ -1323,6 +1329,28 @@ private void registerBuiltinArrayAggDistinctFunction() { false, false, false)); } + private void registerBuiltinMapAggFunction() { + for (ScalarType keyType : Type.getNumericTypes()) { + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.MAP_AGG, + Lists.newArrayList(keyType, Type.ANY_ELEMENT), Type.ANY_MAP, null, + false, false, false)); + } + for (ScalarType keyType : Type.STRING_TYPES) { + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.MAP_AGG, + Lists.newArrayList(keyType, Type.ANY_ELEMENT), Type.ANY_MAP, null, + false, false, false)); + } + + for (ScalarType keyType : Type.DATE_TYPES) { + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.MAP_AGG, + Lists.newArrayList(keyType, Type.ANY_ELEMENT), Type.ANY_MAP, null, + false, false, false)); + } + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.MAP_AGG, + Lists.newArrayList(Type.TIME, Type.ANY_ELEMENT), Type.ANY_MAP, null, + false, false, false)); + } + private void registerBuiltinArrayUniqueAggFunction() { // array_unique_agg mapping array_agg_distinct while array as input. for (ScalarType type : Type.getNumericTypes()) { diff --git a/test/sql/test_agg_function/R/test_map_agg b/test/sql/test_agg_function/R/test_map_agg new file mode 100644 index 00000000000000..68f0c7a75924eb --- /dev/null +++ b/test/sql/test_agg_function/R/test_map_agg @@ -0,0 +1,202 @@ +-- name: test_map_agg +CREATE TABLE t1 ( + c1 int, + c2 boolean, + c3 tinyint, + c4 int, + c5 bigint, + c6 largeint, + c7 string, + c8 double, + c9 date, + c10 datetime, + c11 array, + c12 map, + c13 struct + ) +DUPLICATE KEY(c1) +DISTRIBUTED BY HASH(c1) BUCKETS 3 +PROPERTIES ("replication_num" = "1"); +-- result: +-- !result +INSERT INTO t1 values + (1, true, 11, 111, 1111, 11111, "111111", 1.1, "2024-09-01", "2024-09-01 18:00:00", [1, 2, 3], map('key', 5.5), row(100, "abc")), + (2, false, 22, 222, 2222, 22222, "222222", 2.2, "2024-09-02", "2024-09-02 11:00:00", [3, 4, 5], map('key', 511.2), row(200, "bcd")), + (3, true, 33, 333, 3333, 33333, "333333", 3.3, "2024-09-03", "2024-09-03 00:00:00", [4, 1, 2], map('key', 666.6), row(300, "cccecd")), + (4, false, 11, 444, 4444, 44444, "444444", 4.4, "2024-09-04", "2024-09-04 12:00:00", [7, 7, 5], map('key', 444.4), row(400, "efdg")), + (5, null, null, null, null, null, null, null, null, null, null, null, null); +-- result: +-- !result +set streaming_preaggregation_mode=force_preaggregation; +-- result: +-- !result +select map_size(map_agg(c1, c3)) from t1; +-- result: +5 +-- !result +select map_agg(c1, c3)[1] from t1; +-- result: +11 +-- !result +select map_agg(c1, c3)[2] from t1; +-- result: +22 +-- !result +select map_agg(c1, c3)[3] from t1; +-- result: +33 +-- !result +select map_agg(c1, c3)[4] from t1; +-- result: +11 +-- !result +select map_agg(c1, c3)[5] from t1; +-- result: +None +-- !result +select map_size(map_agg(c5, c6)) from t1; +-- result: +4 +-- !result +select map_agg(c5, c6)[1111] from t1; +-- result: +11111 +-- !result +select map_agg(c5, c6)[null] from t1; +-- result: +E: (1064, 'Key not present in map: 67') +-- !result +select map_size(map_agg(c6, c10)) from t1; +-- result: +4 +-- !result +select map_agg(c6, c10)[11111] from t1; +-- result: +2024-09-01 18:00:00 +-- !result +select map_agg(c6, c10)[22222] from t1; +-- result: +2024-09-02 11:00:00 +-- !result +select map_size(map_agg(c8, c5)) from t1; +-- result: +4 +-- !result +select map_agg(c8, c5)[1.1] from t1; +-- result: +1111 +-- !result +select map_agg(c8, c5)[4.4] from t1; +-- result: +4444 +-- !result +select c11, map_agg(c10, c11) res from t1 group by c10 order by c11[1]; +-- result: +None {} +[1,2,3] {"2024-09-01 18:00:00":[1,2,3]} +[3,4,5] {"2024-09-02 11:00:00":[3,4,5]} +[4,1,2] {"2024-09-03 00:00:00":[4,1,2]} +[7,7,5] {"2024-09-04 12:00:00":[7,7,5]} +-- !result +select c12, map_agg(c9, c12) res from t1 group by c10 order by c12['key']; +-- result: +None {} +{"key":5.5} {"2024-09-01":{"Ope":5.5}} +{"key":444.4} {"2024-09-04":{"key":444.4}} +{"key":511.2} {"2024-09-02":{"key":511.2}} +{"key":666.6} {"2024-09-03":{"key":666.6}} +-- !result +select c13, map_agg(c9, c13) res from t1 group by c10 order by c13.a; +-- result: +None {} +{"a":100,"b":"abc"} {"2024-09-01":{"a":100,"b":"abc"}} +{"a":200,"b":"bcd"} {"2024-09-02":{"a":200,"b":"bcd"}} +{"a":300,"b":"cccecd"} {"2024-09-03":{"a":300,"b":"cccecd"}} +{"a":400,"b":"efdg"} {"2024-09-04":{"a":400,"b":"Oper"}} +-- !result +set streaming_preaggregation_mode=force_streaming; +-- result: +-- !result +select map_size(map_agg(c1, c3)) from t1; +-- result: +5 +-- !result +select map_agg(c1, c3)[1] from t1; +-- result: +11 +-- !result +select map_agg(c1, c3)[2] from t1; +-- result: +22 +-- !result +select map_agg(c1, c3)[3] from t1; +-- result: +33 +-- !result +select map_agg(c1, c3)[4] from t1; +-- result: +11 +-- !result +select map_agg(c1, c3)[5] from t1; +-- result: +None +-- !result +select map_size(map_agg(c5, c6)) from t1; +-- result: +4 +-- !result +select map_agg(c5, c6)[1111] from t1; +-- result: +11111 +-- !result +select map_agg(c5, c6)[null] from t1; +-- result: +E: (1064, 'Key not present in map: 139852725092352') +-- !result +select map_size(map_agg(c6, c10)) from t1; +-- result: +4 +-- !result +select map_agg(c6, c10)[11111] from t1; +-- result: +2024-09-01 18:00:00 +-- !result +select map_agg(c6, c10)[22222] from t1; +-- result: +2024-09-02 11:00:00 +-- !result +select map_size(map_agg(c8, c5)) from t1; +-- result: +4 +-- !result +select map_agg(c8, c5)[1.1] from t1; +-- result: +1111 +-- !result +select map_agg(c8, c5)[4.4] from t1; +-- result: +4444 +-- !result +select c11, map_agg(c10, c11) res from t1 group by c10 order by c11[1]; +-- result: +None {} +[1,2,3] {"2024-09-01 18:00:00":[1,2,3]} +[3,4,5] {"2024-09-02 11:00:00":[3,4,5]} +[4,1,2] {"2024-09-03 00:00:00":[4,1,2]} +[7,7,5] {"2024-09-04 12:00:00":[7,7,5]} +-- !result +select c12, map_agg(c9, c12) res from t1 group by c10 order by c12['key']; +-- result: +None {} +{"key":5.5} {"2024-09-01":{"key":5.5}} +{"key":444.4} {"2024-09-04":{"Joi":444.4}} +{"key":511.2} {"2024-09-02":{"key":511.2}} +{"key":666.6} {"2024-09-03":{"key":666.6}} +-- !result +select c13, map_agg(c9, c13) res from t1 group by c10 order by c13.a; +-- result: +None {} +{"a":100,"b":"abc"} {"2024-09-01":{"a":100,"b":"abc"}} +{"a":200,"b":"bcd"} {"2024-09-02":{"a":200,"b":"bcd"}} +{"a":300,"b":"cccecd"} {"2024-09-03":{"a":300,"b":"cccecd"}} +{"a":400,"b":"efdg"} {"2024-09-04":{"a":400,"b":"Runt"}} \ No newline at end of file diff --git a/test/sql/test_agg_function/T/test_map_agg b/test/sql/test_agg_function/T/test_map_agg new file mode 100644 index 00000000000000..079f2a79ad5ff1 --- /dev/null +++ b/test/sql/test_agg_function/T/test_map_agg @@ -0,0 +1,76 @@ +-- name: test_map_agg +CREATE TABLE t1 ( + c1 int, + c2 boolean, + c3 tinyint, + c4 int, + c5 bigint, + c6 largeint, + c7 string, + c8 double, + c9 date, + c10 datetime, + c11 array, + c12 map, + c13 struct + ) +DUPLICATE KEY(c1) +DISTRIBUTED BY HASH(c1) BUCKETS 3 +PROPERTIES ("replication_num" = "1"); + +INSERT INTO t1 values + (1, true, 11, 111, 1111, 11111, "111111", 1.1, "2024-09-01", "2024-09-01 18:00:00", [1, 2, 3], map('key', 5.5), row(100, "abc")), + (2, false, 22, 222, 2222, 22222, "222222", 2.2, "2024-09-02", "2024-09-02 11:00:00", [3, 4, 5], map('key', 511.2), row(200, "bcd")), + (3, true, 33, 333, 3333, 33333, "333333", 3.3, "2024-09-03", "2024-09-03 00:00:00", [4, 1, 2], map('key', 666.6), row(300, "cccecd")), + (4, false, 11, 444, 4444, 44444, "444444", 4.4, "2024-09-04", "2024-09-04 12:00:00", [7, 7, 5], map('key', 444.4), row(400, "efdg")), + (5, null, null, null, null, null, null, null, null, null, null, null, null); + + +set streaming_preaggregation_mode=force_preaggregation; +select map_size(map_agg(c1, c3)) from t1; +select map_agg(c1, c3)[1] from t1; +select map_agg(c1, c3)[2] from t1; +select map_agg(c1, c3)[3] from t1; +select map_agg(c1, c3)[4] from t1; +select map_agg(c1, c3)[5] from t1; + +select map_size(map_agg(c5, c6)) from t1; +select map_agg(c5, c6)[1111] from t1; +select map_agg(c5, c6)[null] from t1; + +select map_size(map_agg(c6, c10)) from t1; +select map_agg(c6, c10)[11111] from t1; +select map_agg(c6, c10)[22222] from t1; + +select map_size(map_agg(c8, c5)) from t1; +select map_agg(c8, c5)[1.1] from t1; +select map_agg(c8, c5)[4.4] from t1; + +select c11, map_agg(c10, c11) res from t1 group by c10 order by c11[1]; +select c12, map_agg(c9, c12) res from t1 group by c10 order by c12['key']; +select c13, map_agg(c9, c13) res from t1 group by c10 order by c13.a; + +set streaming_preaggregation_mode=force_streaming; + +select map_size(map_agg(c1, c3)) from t1; +select map_agg(c1, c3)[1] from t1; +select map_agg(c1, c3)[2] from t1; +select map_agg(c1, c3)[3] from t1; +select map_agg(c1, c3)[4] from t1; +select map_agg(c1, c3)[5] from t1; + +select map_size(map_agg(c5, c6)) from t1; +select map_agg(c5, c6)[1111] from t1; +select map_agg(c5, c6)[null] from t1; + +select map_size(map_agg(c6, c10)) from t1; +select map_agg(c6, c10)[11111] from t1; +select map_agg(c6, c10)[22222] from t1; + +select map_size(map_agg(c8, c5)) from t1; +select map_agg(c8, c5)[1.1] from t1; +select map_agg(c8, c5)[4.4] from t1; + +select c11, map_agg(c10, c11) res from t1 group by c10 order by c11[1]; +select c12, map_agg(c9, c12) res from t1 group by c10 order by c12['key']; +select c13, map_agg(c9, c13) res from t1 group by c10 order by c13.a; \ No newline at end of file