Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Aliasing issue in OpenCL #2943

Merged
merged 11 commits into from
Sep 28, 2023
2 changes: 1 addition & 1 deletion stan/math/opencl/kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>

#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
#include <stan/math/opencl/kernel_generator/as_column_vector_or_scalar.hpp>
#include <stan/math/opencl/kernel_generator/load.hpp>
#include <stan/math/opencl/kernel_generator/scalar.hpp>
Expand Down
36 changes: 25 additions & 11 deletions stan/math/opencl/kernel_generator/as_operation_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/load.hpp>
#include <stan/math/opencl/kernel_generator/scalar.hpp>
Expand All @@ -19,11 +20,12 @@ namespace math {
/**
* Converts any valid kernel generator expression into an operation. This is an
* overload for operations - a no-op
* @tparam AssignOp ignored
* @tparam T_operation type of the input operation
* @param a an operation
* @return operation
*/
template <typename T_operation,
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_operation,
typename = std::enable_if_t<std::is_base_of<
operation_cl_base, std::remove_reference_t<T_operation>>::value>>
inline T_operation&& as_operation_cl(T_operation&& a) {
Expand All @@ -33,11 +35,13 @@ inline T_operation&& as_operation_cl(T_operation&& a) {
/**
* Converts any valid kernel generator expression into an operation. This is an
* overload for scalars (arithmetic types). It wraps them into \c scalar_.
* @tparam AssignOp ignored
* @tparam T_scalar type of the input scalar
* @param a scalar
* @return \c scalar_ wrapping the input
*/
template <typename T_scalar, typename = require_arithmetic_t<T_scalar>,
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_scalar,
typename = require_arithmetic_t<T_scalar>,
require_not_same_t<T_scalar, bool>* = nullptr>
inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
return scalar_<T_scalar>(a);
Expand All @@ -47,23 +51,29 @@ inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
* Converts any valid kernel generator expression into an operation. This is an
* overload for bool scalars. It wraps them into \c scalar_<char> as \c bool can
* not be used as a type of a kernel argument.
* @tparam AssignOp ignored
* @param a scalar
* @return \c scalar_<char> wrapping the input
*/
inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
template <assign_op_cl AssignOp = assign_op_cl::equals>
inline scalar_<char> as_operation_cl(const bool a) {
return scalar_<char>(a);
}

/**
* Converts any valid kernel generator expression into an operation. This is an
* overload for \c matrix_cl. It wraps them into into \c load_.
* @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
* is assigned using standard or compound assign.
* @tparam T_matrix_cl \c matrix_cl
* @param a \c matrix_cl
* @return \c load_ wrapping the input
*/
template <typename T_matrix_cl,
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_matrix_cl,
typename = require_any_t<is_matrix_cl<T_matrix_cl>,
is_arena_matrix_cl<T_matrix_cl>>>
inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
inline load_<T_matrix_cl, AssignOp> as_operation_cl(T_matrix_cl&& a) {
return load_<T_matrix_cl, AssignOp>(std::forward<T_matrix_cl>(a));
}

/**
Expand All @@ -73,12 +83,16 @@ inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
* as_operation_cl_t<T>. If the return value of \c as_operation_cl() would be a
* rvalue reference, the reference is removed, so that a variable of this type
* actually stores the value.
* @tparam T a `matrix_cl` or `Scalar` type
* @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
* is assigned using standard or compound assign.
*/
template <typename T>
using as_operation_cl_t = std::conditional_t<
std::is_lvalue_reference<T>::value,
decltype(as_operation_cl(std::declval<T>())),
std::remove_reference_t<decltype(as_operation_cl(std::declval<T>()))>>;
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
using as_operation_cl_t
= std::conditional_t<std::is_lvalue_reference<T>::value,
decltype(as_operation_cl<AssignOp>(std::declval<T>())),
std::remove_reference_t<decltype(
as_operation_cl<AssignOp>(std::declval<T>()))>>;

/** @}*/
} // namespace math
Expand Down
74 changes: 74 additions & 0 deletions stan/math/opencl/kernel_generator/assignment_ops.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
#ifdef STAN_OPENCL
#include <stan/math/prim/meta/is_detected.hpp>

namespace stan {
namespace math {

/**
* Ops that decide the type of assignment for LHS operations
*/
enum class assign_op_cl {
equals,
plus_equals,
minus_equals,
divide_equals,
multiply_equals
};

namespace internal {
/**
* @param value A static constexpr const char* member for printing assignment
* ops
*/
template <assign_op_cl assign_op>
struct assignment_op_str_impl;

template <>
struct assignment_op_str_impl<assign_op_cl::equals> {
static constexpr const char* value = " = ";
};

template <>
struct assignment_op_str_impl<assign_op_cl::plus_equals> {
static constexpr const char* value = " += ";
};

template <>
struct assignment_op_str_impl<assign_op_cl::minus_equals> {
static constexpr const char* value = " -= ";
};

template <>
struct assignment_op_str_impl<assign_op_cl::divide_equals> {
static constexpr const char* value = " /= ";
};

template <>
struct assignment_op_str_impl<assign_op_cl::multiply_equals> {
static constexpr const char* value = " *= ";
};

template <typename, typename = void>
struct assignment_op_str : assignment_op_str_impl<assign_op_cl::equals> {};

template <typename T>
struct assignment_op_str<T, void_t<decltype(T::assignment_op)>>
: assignment_op_str_impl<T::assignment_op> {};

} // namespace internal

/**
* @tparam T A type that has an `assignment_op` static constexpr member type
* @return The types assignment op as a constexpr const char*
*/
template <typename T>
inline constexpr const char* assignment_op() noexcept {
return internal::assignment_op_str<std::decay_t<T>>::value;
}

} // namespace math
} // namespace stan
#endif
#endif
22 changes: 16 additions & 6 deletions stan/math/opencl/kernel_generator/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/matrix_cl_view.hpp>
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>

#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
Expand All @@ -23,17 +25,20 @@ namespace math {
/**
* Represents an access to a \c matrix_cl in kernel generator expressions
* @tparam T \c matrix_cl
* @tparam AssignOp tells higher level operations whether the final operation
* should be an assignment or a type of compound assignment.
*/
template <typename T>
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
class load_
: public operation_cl_lhs<load_<T>,
: public operation_cl_lhs<load_<T, AssignOp>,
typename std::remove_reference_t<T>::type> {
protected:
T a_;

public:
static constexpr assign_op_cl assignment_op = AssignOp;
using Scalar = typename std::remove_reference_t<T>::type;
using base = operation_cl<load_<T>, Scalar>;
using base = operation_cl<load_<T, AssignOp>, Scalar>;
using base::var_name_;
static_assert(disjunction<is_matrix_cl<T>, is_arena_matrix_cl<T>>::value,
"load_: argument a must be a matrix_cl<T>!");
Expand All @@ -51,9 +56,13 @@ class load_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline load_<T&> deep_copy() & { return load_<T&>(a_); }
inline load_<const T&> deep_copy() const& { return load_<const T&>(a_); }
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }
inline load_<T&, AssignOp> deep_copy() & { return load_<T&, AssignOp>(a_); }
inline load_<const T&, AssignOp> deep_copy() const& {
return load_<const T&, AssignOp>(a_);
}
inline load_<T, AssignOp> deep_copy() && {
return load_<T, AssignOp>(std::forward<T>(a_));
}

/**
* Generates kernel code for this expression.
Expand Down Expand Up @@ -327,6 +336,7 @@ class load_
}
}
};

/** @}*/
} // namespace math
} // namespace stan
Expand Down
Loading