Skip to content

Commit

Permalink
Added ast tree to simplify expression lifetime management (#17156)
Browse files Browse the repository at this point in the history
This merge request follows up on #10744.
It attempts to simplify managing expressions by adding a class called an ast tree. The ast tree manages and holds related expressions together. When the tree is destroyed, all the expressions are also destroyed. Ideally we would use a bump allocator for allocating the expressions instead of `std::vector<std::unique_ptr<expression>>`.

We'd also ideally use a `cuda::std::inplace_vector` for storing the operands of the `operation` class, but that's in a newer version of CCCL.

Authors:
  - Basit Ayantunde (https://github.com/lamarrr)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Lawrence Mitchell (https://github.com/wence-)
  - Bradley Dice (https://github.com/bdice)
  - Karthikeyan (https://github.com/karthikeyann)

URL: #17156
  • Loading branch information
lamarrr authored Nov 7, 2024
1 parent e29e0ab commit 4cbc15a
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 20 deletions.
6 changes: 5 additions & 1 deletion cpp/include/cudf/ast/detail/expression_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#include <cudf/ast/expressions.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/span.hpp>

#include <thrust/scan.h>

#include <functional>
#include <numeric>
Expand Down Expand Up @@ -296,7 +300,7 @@ class expression_parser {
* @return The indices of the operands stored in the data references.
*/
std::vector<cudf::size_type> visit_operands(
std::vector<std::reference_wrapper<expression const>> operands);
cudf::host_span<std::reference_wrapper<cudf::ast::expression const> const> operands);

/**
* @brief Add a data reference to the internal list.
Expand Down
100 changes: 97 additions & 3 deletions cpp/include/cudf/ast/expressions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <cudf/utilities/error.hpp>

#include <cstdint>
#include <memory>
#include <vector>

namespace CUDF_EXPORT cudf {
namespace ast {
Expand Down Expand Up @@ -478,7 +480,7 @@ class operation : public expression {
*
* @return Vector of operands
*/
[[nodiscard]] std::vector<std::reference_wrapper<expression const>> get_operands() const
[[nodiscard]] std::vector<std::reference_wrapper<expression const>> const& get_operands() const
{
return operands;
}
Expand Down Expand Up @@ -506,8 +508,8 @@ class operation : public expression {
};

private:
ast_operator const op;
std::vector<std::reference_wrapper<expression const>> const operands;
ast_operator op;
std::vector<std::reference_wrapper<expression const>> operands;
};

/**
Expand Down Expand Up @@ -552,6 +554,98 @@ class column_name_reference : public expression {
std::string column_name;
};

/**
* @brief An AST expression tree. It owns and contains multiple dependent expressions. All the
* expressions are destroyed when the tree is destructed.
*/
class tree {
public:
/**
* @brief construct an empty ast tree
*/
tree() = default;

/**
* @brief Moves the ast tree
*/
tree(tree&&) = default;

/**
* @brief move-assigns the AST tree
* @returns a reference to the move-assigned tree
*/
tree& operator=(tree&&) = default;

~tree() = default;

// the tree is not copyable
tree(tree const&) = delete;
tree& operator=(tree const&) = delete;

/**
* @brief Add an expression to the AST tree
* @param args Arguments to use to construct the ast expression
* @returns a reference to the added expression
*/
template <typename Expr, typename... Args>
Expr const& emplace(Args&&... args)
{
static_assert(std::is_base_of_v<expression, Expr>);
auto expr = std::make_shared<Expr>(std::forward<Args>(args)...);
Expr const& expr_ref = *expr;
expressions.emplace_back(std::static_pointer_cast<expression>(std::move(expr)));
return expr_ref;
}

/**
* @brief Add an expression to the AST tree
* @param expr AST expression to be added
* @returns a reference to the added expression
*/
template <typename Expr>
Expr const& push(Expr expr)
{
return emplace<Expr>(std::move(expr));
}

/**
* @brief get the first expression in the tree
* @returns the first inserted expression into the tree
*/
expression const& front() const { return *expressions.front(); }

/**
* @brief get the last expression in the tree
* @returns the last inserted expression into the tree
*/
expression const& back() const { return *expressions.back(); }

/**
* @brief get the number of expressions added to the tree
* @returns the number of expressions added to the tree
*/
size_t size() const { return expressions.size(); }

/**
* @brief get the expression at an index in the tree. Index is checked.
* @param index index of expression in the ast tree
* @returns the expression at the specified index
*/
expression const& at(size_t index) { return *expressions.at(index); }

/**
* @brief get the expression at an index in the tree. Index is unchecked.
* @param index index of expression in the ast tree
* @returns the expression at the specified index
*/
expression const& operator[](size_t index) const { return *expressions[index]; }

private:
// TODO: use better ownership semantics, the shared_ptr here is redundant. Consider using a bump
// allocator with type-erased deleters.
std::vector<std::shared_ptr<expression>> expressions;
};

/** @} */ // end of group
} // namespace ast

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/ast/expression_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ cudf::data_type expression_parser::output_type() const
}

std::vector<cudf::size_type> expression_parser::visit_operands(
std::vector<std::reference_wrapper<expression const>> operands)
cudf::host_span<std::reference_wrapper<expression const> const> operands)
{
auto operand_data_reference_indices = std::vector<cudf::size_type>();
for (auto const& operand : operands) {
Expand Down
24 changes: 16 additions & 8 deletions cpp/src/ast/expressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,41 @@
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>

#include <stdexcept>

namespace cudf {
namespace ast {

operation::operation(ast_operator op, expression const& input) : op(op), operands({input})
operation::operation(ast_operator op, expression const& input) : op{op}, operands{input}
{
if (cudf::ast::detail::ast_operator_arity(op) != 1) {
CUDF_FAIL("The provided operator is not a unary operator.");
}
CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 1,
"The provided operator is not a unary operator.",
std::invalid_argument);
}

operation::operation(ast_operator op, expression const& left, expression const& right)
: op(op), operands({left, right})
: op{op}, operands{left, right}
{
if (cudf::ast::detail::ast_operator_arity(op) != 2) {
CUDF_FAIL("The provided operator is not a binary operator.");
}
CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2,
"The provided operator is not a binary operator.",
std::invalid_argument);
}

cudf::size_type literal::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
}

cudf::size_type column_reference::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
}

cudf::size_type operation::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
}

cudf::size_type column_name_reference::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
Expand All @@ -60,16 +65,19 @@ auto literal::accept(detail::expression_transformer& visitor) const
{
return visitor.visit(*this);
}

auto column_reference::accept(detail::expression_transformer& visitor) const
-> decltype(visitor.visit(*this))
{
return visitor.visit(*this);
}

auto operation::accept(detail::expression_transformer& visitor) const
-> decltype(visitor.visit(*this))
{
return visitor.visit(*this);
}

auto column_name_reference::accept(detail::expression_transformer& visitor) const
-> decltype(visitor.visit(*this))
{
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/io/parquet/predicate_pushdown.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/span.hpp>
#include <cudf/utilities/traits.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

Expand Down Expand Up @@ -373,7 +374,7 @@ class stats_expression_converter : public ast::detail::expression_transformer {

private:
std::vector<std::reference_wrapper<ast::expression const>> visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands)
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
std::vector<std::reference_wrapper<ast::expression const>> transformed_operands;
for (auto const& operand : operands) {
Expand Down Expand Up @@ -553,7 +554,7 @@ std::reference_wrapper<ast::expression const> named_to_reference_converter::visi

std::vector<std::reference_wrapper<ast::expression const>>
named_to_reference_converter::visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands)
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
std::vector<std::reference_wrapper<ast::expression const>> transformed_operands;
for (auto const& operand : operands) {
Expand Down Expand Up @@ -623,7 +624,7 @@ class names_from_expression : public ast::detail::expression_transformer {
}

private:
void visit_operands(std::vector<std::reference_wrapper<ast::expression const>> operands)
void visit_operands(cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
for (auto const& operand : operands) {
operand.get().accept(*this);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/io/parquet/reader_impl_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ class named_to_reference_converter : public ast::detail::expression_transformer

private:
std::vector<std::reference_wrapper<ast::expression const>> visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands);
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands);

std::unordered_map<std::string, size_type> column_name_to_index;
std::optional<std::reference_wrapper<ast::expression const>> _stats_expr;
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ ConfigureTest(ENCODE_TEST encode/encode_tests.cpp)

# ##################################################################################################
# * ast tests -------------------------------------------------------------------------------------
ConfigureTest(AST_TEST ast/transform_tests.cpp)
ConfigureTest(AST_TEST ast/transform_tests.cpp ast/ast_tree_tests.cpp)

# ##################################################################################################
# * lists tests ----------------------------------------------------------------------------------
Expand Down
79 changes: 79 additions & 0 deletions cpp/tests/ast/ast_tree_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf_test/column_utilities.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/testing_main.hpp>

#include <cudf/ast/expressions.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/transform.hpp>
#include <cudf/types.hpp>

template <typename T>
using column_wrapper = cudf::test::fixed_width_column_wrapper<T>;

TEST(AstTreeTest, ExpressionTree)
{
namespace ast = cudf::ast;
using op = ast::ast_operator;
using operation = ast::operation;

// computes (y = mx + c)... and linearly interpolates them using interpolator t
auto m0_col = column_wrapper<float>{10, 20, 50, 100};
auto x0_col = column_wrapper<float>{10, 5, 2, 1};
auto c0_col = column_wrapper<float>{100, 100, 100, 100};

auto m1_col = column_wrapper<float>{10, 20, 50, 100};
auto x1_col = column_wrapper<float>{20, 10, 4, 2};
auto c1_col = column_wrapper<float>{200, 200, 200, 200};

auto one_scalar = cudf::numeric_scalar<float>{1};
auto t_scalar = cudf::numeric_scalar<float>{0.5F};

auto table = cudf::table_view{{m0_col, x0_col, c0_col, m1_col, x1_col, c1_col}};

ast::tree tree{};

auto const& one = tree.push(ast::literal{one_scalar});
auto const& t = tree.push(ast::literal{t_scalar});
auto const& m0 = tree.push(ast::column_reference(0));
auto const& x0 = tree.push(ast::column_reference(1));
auto const& c0 = tree.push(ast::column_reference(2));
auto const& m1 = tree.push(ast::column_reference(3));
auto const& x1 = tree.push(ast::column_reference(4));
auto const& c1 = tree.push(ast::column_reference(5));

// compute: y0 = m0 x0 + c0
auto const& y0 = tree.push(operation{op::ADD, tree.push(operation{op::MUL, m0, x0}), c0});

// compute: y1 = m1 x1 + c1
auto const& y1 = tree.push(operation{op::ADD, tree.push(operation{op::MUL, m1, x1}), c1});

// compute weighted: (1 - t) * y0
auto const& y0_w = tree.push(operation{op::MUL, tree.push(operation{op::SUB, one, t}), y0});

// compute weighted: y = t * y1
auto const& y1_w = tree.push(operation{op::MUL, t, y1});

// add weighted: result = lerp(y0, y1, t) = (1 - t) * y0 + t * y1
auto result = cudf::compute_column(table, tree.push(operation{op::ADD, y0_w, y1_w}));

auto expected = column_wrapper<float>{300, 300, 300, 300};

CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view());
}
5 changes: 3 additions & 2 deletions cpp/tests/ast/transform_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,9 +530,10 @@ TEST_F(TransformTest, UnaryTrigonometry)
TEST_F(TransformTest, ArityCheckFailure)
{
auto col_ref_0 = cudf::ast::column_reference(0);
EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0), cudf::logic_error);
EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0),
std::invalid_argument);
EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ABS, col_ref_0, col_ref_0),
cudf::logic_error);
std::invalid_argument);
}

TEST_F(TransformTest, StringComparison)
Expand Down

0 comments on commit 4cbc15a

Please sign in to comment.