diff --git a/be/src/exprs/column_ref.cpp b/be/src/exprs/column_ref.cpp index 8799f91c1647c..b5c9874f22ee5 100644 --- a/be/src/exprs/column_ref.cpp +++ b/be/src/exprs/column_ref.cpp @@ -30,6 +30,9 @@ int ColumnRef::get_slot_ids(std::vector* slot_ids) const { slot_ids->push_back(_column_id); return 1; } +void ColumnRef::for_each_slot_id(const std::function& cb) const { + cb(_column_id); +} bool ColumnRef::is_bound(const std::vector& tuple_ids) const { for (int tuple_id : tuple_ids) { diff --git a/be/src/exprs/column_ref.h b/be/src/exprs/column_ref.h index 041e341b9e54f..2d6e768833e0c 100644 --- a/be/src/exprs/column_ref.h +++ b/be/src/exprs/column_ref.h @@ -45,6 +45,7 @@ class ColumnRef final : public Expr { bool is_constant() const override { return false; } int get_slot_ids(std::vector* slot_ids) const override; + void for_each_slot_id(const std::function& cb) const override; std::string debug_string() const override; diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp index f1c915387960b..c5cd19af9638e 100644 --- a/be/src/exprs/expr.cpp +++ b/be/src/exprs/expr.cpp @@ -655,6 +655,12 @@ int Expr::get_slot_ids(std::vector* slot_ids) const { return n; } +void Expr::for_each_slot_id(const std::function& cb) const { + for (auto child : _children) { + child->for_each_slot_id(cb); + } +} + int Expr::get_subfields(std::vector>* subfields) const { int n = 0; diff --git a/be/src/exprs/expr.h b/be/src/exprs/expr.h index 27ea64c804c5f..a82d29681df52 100644 --- a/be/src/exprs/expr.h +++ b/be/src/exprs/expr.h @@ -149,6 +149,8 @@ class Expr { virtual int get_subfields(std::vector>* subfields) const; + virtual void for_each_slot_id(const std::function& cb) const; + /// Create expression tree from the list of nodes contained in texpr within 'pool'. /// Returns the root of expression tree in 'expr' and the corresponding ExprContext in /// 'ctx'. diff --git a/be/src/exprs/lambda_function.cpp b/be/src/exprs/lambda_function.cpp index d2a186e0a521d..b7ea43890dac3 100644 --- a/be/src/exprs/lambda_function.cpp +++ b/be/src/exprs/lambda_function.cpp @@ -110,12 +110,9 @@ Status LambdaFunction::collect_lambda_argument_ids() { } SlotId LambdaFunction::max_used_slot_id() const { - std::vector ids; - for (auto child : _children) { - child->get_slot_ids(&ids); - } - DCHECK(!ids.empty()); - return *std::max_element(ids.begin(), ids.end()); + SlotId max_slot_id = 0; + for_each_slot_id([&max_slot_id](SlotId slot_id) { max_slot_id = std::max(max_slot_id, slot_id); }); + return max_slot_id; } Status LambdaFunction::collect_common_sub_exprs() { @@ -216,12 +213,13 @@ StatusOr LambdaFunction::evaluate_checked(ExprContext* context, Chunk } int LambdaFunction::get_slot_ids(std::vector* slot_ids) const { + // get_slot_ids only return capture slot ids, + // if expr is already prepared, we can get result from _captured_slot_ids, otherwise, get result from lambda expr if (_is_prepared) { slot_ids->insert(slot_ids->end(), _captured_slot_ids.begin(), _captured_slot_ids.end()); - slot_ids->insert(slot_ids->end(), _arguments_ids.begin(), _arguments_ids.end()); - return _captured_slot_ids.size() + _arguments_ids.size(); + return _captured_slot_ids.size(); } else { - return Expr::get_slot_ids(slot_ids); + return get_child(0)->get_slot_ids(slot_ids); } } diff --git a/be/src/exprs/placeholder_ref.h b/be/src/exprs/placeholder_ref.h index 6f2cbf97d11ed..ab7a4695872f2 100644 --- a/be/src/exprs/placeholder_ref.h +++ b/be/src/exprs/placeholder_ref.h @@ -32,6 +32,7 @@ class PlaceHolderRef final : public Expr { slot_ids->emplace_back(_column_id); return 1; } + void for_each_slot_id(const std::function& cb) const override { cb(_column_id); } private: SlotId _column_id; diff --git a/test/sql/test_array_fn/R/test_array_map_2 b/test/sql/test_array_fn/R/test_array_map_2 index db23d0e0695b7..41e7e11c89768 100644 --- a/test/sql/test_array_fn/R/test_array_map_2 +++ b/test/sql/test_array_fn/R/test_array_map_2 @@ -154,4 +154,16 @@ None None None None +-- !result +select array_map(x->array_map(x->x+100,x), [[1,2,3]]); +-- result: +[[101,102,103]] +-- !result +select array_map(x->array_map(x->x+100,x), [[1,2,3], [null]]); +-- result: +[[101,102,103],[null]] +-- !result +select array_map(x->array_map(x->array_map(x->x+100,x),x), [[[1,2,3]]]); +-- result: +[[[101,102,103]]] -- !result \ No newline at end of file diff --git a/test/sql/test_array_fn/T/test_array_map_2 b/test/sql/test_array_fn/T/test_array_map_2 index 3743301ed94c8..2cbbdc304e264 100644 --- a/test/sql/test_array_fn/T/test_array_map_2 +++ b/test/sql/test_array_fn/T/test_array_map_2 @@ -60,3 +60,6 @@ select array_map((x,y,z)->x+y+z, [1,2],[2,3],[3,4]) from t; select array_map((x,y,z)->x+y+z, [1,2],[2,null],[3,4]) from t; select array_map((x,y,z)->x+y+z, [1,2],[2,null],null) from t; +select array_map(x->array_map(x->x+100,x), [[1,2,3]]); +select array_map(x->array_map(x->x+100,x), [[1,2,3], [null]]); +select array_map(x->array_map(x->array_map(x->x+100,x),x), [[[1,2,3]]]);