Skip to content

Commit

Permalink
Merge branch 'develop' into compound-funs
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Sep 30, 2023
2 parents 8cd9390 + 598dba7 commit cc8dc55
Show file tree
Hide file tree
Showing 18 changed files with 451 additions and 49 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ lib/tbb

# local make include
/make/local
/make/ucrt

# python byte code
*.pyc
Expand Down
15 changes: 15 additions & 0 deletions make/compiler_flags
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,21 @@ ifeq ($(OS),Windows_NT)
CXXFLAGS_OS ?= -m64
endif

make/ucrt:
pound := \#
UCRT_STRING := $(shell echo '$(pound)include <windows.h>' | $(CXX) -E -dM - | findstr _UCRT)
ifneq (,$(UCRT_STRING))
IS_UCRT ?= true
else
IS_UCRT ?= false
endif
$(shell echo "IS_UCRT ?= $(IS_UCRT)" > $(MATH)make/ucrt)

include make/ucrt
ifeq ($(IS_UCRT),true)
CXXFLAGS_OS += -D_UCRT
endif

ifneq (gcc,$(CXX_TYPE))
LDLIBS_OS ?= -static-libgcc
else
Expand Down
6 changes: 6 additions & 0 deletions make/libraries
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ ifeq (Linux, $(OS))
SHELL = /usr/bin/env bash
endif

ifeq (Windows_NT, $(OS))
ifeq ($(IS_UCRT),true)
TBB_CXXFLAGS += -D_UCRT
endif
endif

# If brackets or spaces are found in MAKE on Windows
# we error, as those characters cause issues when building.
ifeq (Windows_NT, $(OS))
Expand Down
5 changes: 4 additions & 1 deletion make/tests
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,15 @@ HEADER_TESTS := $(addsuffix -test,$(call findfiles,stan,*.hpp))

ifeq ($(OS),Windows_NT)
DEV_NULL = nul
ifeq ($(IS_UCRT),true)
UCRT_NULL_FLAG = -S
endif
else
DEV_NULL = /dev/null
endif

%.hpp-test : %.hpp test/dummy.cpp
$(COMPILE.cpp) $(CXXFLAGS) -O0 -include $^ -o $(DEV_NULL) -Wunused-local-typedefs
$(COMPILE.cpp) $(CXXFLAGS) -O0 -include $^ $(UCRT_NULL_FLAG) -o $(DEV_NULL) -Wunused-local-typedefs

test/dummy.cpp:
@mkdir -p test
Expand Down
1 change: 1 addition & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ clean-deps:
@$(RM) $(call findfiles,test,*.d.*)
@$(RM) $(call findfiles,lib,*.d.*)
@$(RM) $(call findfiles,stan,*.dSYM)
@$(RM) $(call findfiles,make,ucrt)

clean-all: clean clean-doxygen clean-deps clean-libraries

Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>

#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
#include <stan/math/opencl/kernel_generator/as_column_vector_or_scalar.hpp>
#include <stan/math/opencl/kernel_generator/load.hpp>
#include <stan/math/opencl/kernel_generator/scalar.hpp>
Expand Down
36 changes: 25 additions & 11 deletions stan/math/opencl/kernel_generator/as_operation_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/load.hpp>
#include <stan/math/opencl/kernel_generator/scalar.hpp>
Expand All @@ -19,11 +20,12 @@ namespace math {
/**
* Converts any valid kernel generator expression into an operation. This is an
* overload for operations - a no-op
* @tparam AssignOp ignored
* @tparam T_operation type of the input operation
* @param a an operation
* @return operation
*/
template <typename T_operation,
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_operation,
typename = std::enable_if_t<std::is_base_of<
operation_cl_base, std::remove_reference_t<T_operation>>::value>>
inline T_operation&& as_operation_cl(T_operation&& a) {
Expand All @@ -33,11 +35,13 @@ inline T_operation&& as_operation_cl(T_operation&& a) {
/**
* Converts any valid kernel generator expression into an operation. This is an
* overload for scalars (arithmetic types). It wraps them into \c scalar_.
* @tparam AssignOp ignored
* @tparam T_scalar type of the input scalar
* @param a scalar
* @return \c scalar_ wrapping the input
*/
template <typename T_scalar, typename = require_arithmetic_t<T_scalar>,
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_scalar,
typename = require_arithmetic_t<T_scalar>,
require_not_same_t<T_scalar, bool>* = nullptr>
inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
return scalar_<T_scalar>(a);
Expand All @@ -47,23 +51,29 @@ inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
* Converts any valid kernel generator expression into an operation. This is an
* overload for bool scalars. It wraps them into \c scalar_<char> as \c bool can
* not be used as a type of a kernel argument.
* @tparam AssignOp ignored
* @param a scalar
* @return \c scalar_<char> wrapping the input
*/
inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
template <assign_op_cl AssignOp = assign_op_cl::equals>
inline scalar_<char> as_operation_cl(const bool a) {
return scalar_<char>(a);
}

/**
* Converts any valid kernel generator expression into an operation. This is an
* overload for \c matrix_cl. It wraps them into into \c load_.
* @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
* is assigned using standard or compound assign.
* @tparam T_matrix_cl \c matrix_cl
* @param a \c matrix_cl
* @return \c load_ wrapping the input
*/
template <typename T_matrix_cl,
template <assign_op_cl AssignOp = assign_op_cl::equals, typename T_matrix_cl,
typename = require_any_t<is_matrix_cl<T_matrix_cl>,
is_arena_matrix_cl<T_matrix_cl>>>
inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
inline load_<T_matrix_cl, AssignOp> as_operation_cl(T_matrix_cl&& a) {
return load_<T_matrix_cl, AssignOp>(std::forward<T_matrix_cl>(a));
}

/**
Expand All @@ -73,12 +83,16 @@ inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
* as_operation_cl_t<T>. If the return value of \c as_operation_cl() would be a
* rvalue reference, the reference is removed, so that a variable of this type
* actually stores the value.
* @tparam T a `matrix_cl` or `Scalar` type
* @tparam AssignOp an optional `assign_op_cl` that dictates whether the object
* is assigned using standard or compound assign.
*/
template <typename T>
using as_operation_cl_t = std::conditional_t<
std::is_lvalue_reference<T>::value,
decltype(as_operation_cl(std::declval<T>())),
std::remove_reference_t<decltype(as_operation_cl(std::declval<T>()))>>;
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
using as_operation_cl_t
= std::conditional_t<std::is_lvalue_reference<T>::value,
decltype(as_operation_cl<AssignOp>(std::declval<T>())),
std::remove_reference_t<decltype(
as_operation_cl<AssignOp>(std::declval<T>()))>>;

/** @}*/
} // namespace math
Expand Down
74 changes: 74 additions & 0 deletions stan/math/opencl/kernel_generator/assignment_ops.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
#ifdef STAN_OPENCL
#include <stan/math/prim/meta/is_detected.hpp>

namespace stan {
namespace math {

/**
* Ops that decide the type of assignment for LHS operations
*/
enum class assign_op_cl {
equals,
plus_equals,
minus_equals,
divide_equals,
multiply_equals
};

namespace internal {
/**
* @param value A static constexpr const char* member for printing assignment
* ops
*/
template <assign_op_cl assign_op>
struct assignment_op_str_impl;

template <>
struct assignment_op_str_impl<assign_op_cl::equals> {
static constexpr const char* value = " = ";
};

template <>
struct assignment_op_str_impl<assign_op_cl::plus_equals> {
static constexpr const char* value = " += ";
};

template <>
struct assignment_op_str_impl<assign_op_cl::minus_equals> {
static constexpr const char* value = " -= ";
};

template <>
struct assignment_op_str_impl<assign_op_cl::divide_equals> {
static constexpr const char* value = " /= ";
};

template <>
struct assignment_op_str_impl<assign_op_cl::multiply_equals> {
static constexpr const char* value = " *= ";
};

template <typename, typename = void>
struct assignment_op_str : assignment_op_str_impl<assign_op_cl::equals> {};

template <typename T>
struct assignment_op_str<T, void_t<decltype(T::assignment_op)>>
: assignment_op_str_impl<T::assignment_op> {};

} // namespace internal

/**
* @tparam T A type that has an `assignment_op` static constexpr member type
* @return The types assignment op as a constexpr const char*
*/
template <typename T>
inline constexpr const char* assignment_op() noexcept {
return internal::assignment_op_str<std::decay_t<T>>::value;
}

} // namespace math
} // namespace stan
#endif
#endif
22 changes: 16 additions & 6 deletions stan/math/opencl/kernel_generator/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/matrix_cl_view.hpp>
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>

#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
Expand All @@ -23,17 +25,20 @@ namespace math {
/**
* Represents an access to a \c matrix_cl in kernel generator expressions
* @tparam T \c matrix_cl
* @tparam AssignOp tells higher level operations whether the final operation
* should be an assignment or a type of compound assignment.
*/
template <typename T>
template <typename T, assign_op_cl AssignOp = assign_op_cl::equals>
class load_
: public operation_cl_lhs<load_<T>,
: public operation_cl_lhs<load_<T, AssignOp>,
typename std::remove_reference_t<T>::type> {
protected:
T a_;

public:
static constexpr assign_op_cl assignment_op = AssignOp;
using Scalar = typename std::remove_reference_t<T>::type;
using base = operation_cl<load_<T>, Scalar>;
using base = operation_cl<load_<T, AssignOp>, Scalar>;
using base::var_name_;
static_assert(disjunction<is_matrix_cl<T>, is_arena_matrix_cl<T>>::value,
"load_: argument a must be a matrix_cl<T>!");
Expand All @@ -51,9 +56,13 @@ class load_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline load_<T&> deep_copy() & { return load_<T&>(a_); }
inline load_<const T&> deep_copy() const& { return load_<const T&>(a_); }
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }
inline load_<T&, AssignOp> deep_copy() & { return load_<T&, AssignOp>(a_); }
inline load_<const T&, AssignOp> deep_copy() const& {
return load_<const T&, AssignOp>(a_);
}
inline load_<T, AssignOp> deep_copy() && {
return load_<T, AssignOp>(std::forward<T>(a_));
}

/**
* Generates kernel code for this expression.
Expand Down Expand Up @@ -327,6 +336,7 @@ class load_
}
}
};

/** @}*/
} // namespace math
} // namespace stan
Expand Down
Loading

0 comments on commit cc8dc55

Please sign in to comment.