diff --git a/include/clad/Differentiator/Array.h b/include/clad/Differentiator/Array.h index 7bd7ab5e1..dad1a9014 100644 --- a/include/clad/Differentiator/Array.h +++ b/include/clad/Differentiator/Array.h @@ -1,6 +1,7 @@ #ifndef CLAD_ARRAY_H #define CLAD_ARRAY_H +#include "clad/Differentiator/ArrayExpression.h" #include "clad/Differentiator/CladConfig.h" #include @@ -49,6 +50,14 @@ template class array { m_arr[i] = arr[i]; } + template + CUDA_HOST_DEVICE array(std::size_t size, + const array_expression& expression) + : m_arr(new T[size]{static_cast(T())}), m_size(size) { + for (std::size_t i = 0; i < size; ++i) + m_arr[i] = expression[i]; + } + // initializing all entries using the same value template CUDA_HOST_DEVICE array(std::size_t size, U val) @@ -237,25 +246,55 @@ template class array { m_arr[i] /= static_cast(arr[i]); return *this; } + /// Performs element wise addition with array_expression + template + CUDA_HOST_DEVICE array& + operator+=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] += static_cast(arr_exp[i]); + return *this; + } + /// Performs element wise subtraction with array_expression + template + CUDA_HOST_DEVICE array& + operator-=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] -= static_cast(arr_exp[i]); + return *this; + } + /// Performs element wise multiplication with array_expression + template + CUDA_HOST_DEVICE array& + operator*=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] *= static_cast(arr_exp[i]); + return *this; + } + /// Performs element wise division with array_expression + template + CUDA_HOST_DEVICE array& + operator/=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] /= static_cast(arr_exp[i]); + return *this; + } /// 
Negate the array and return a new array. - CUDA_HOST_DEVICE array operator-() const { - array arr2(m_size); - for (std::size_t i = 0; i < m_size; i++) - arr2[i] = -m_arr[i]; - return arr2; + CUDA_HOST_DEVICE array_expression> operator-() const { + return array_expression>(static_cast(0), *this); } /// Subtracts the number from every element in the array and returns a new /// array, when the number is on the left side. template ::value, int>::type = 0> - CUDA_HOST_DEVICE friend array operator-(U n, const array& arr) { - size_t size = arr.size(); - array arr2(size); - for (std::size_t i = 0; i < size; i++) - arr2[i] = n - arr[i]; - return arr2; + CUDA_HOST_DEVICE friend array_expression> + operator-(U n, const array& arr) { + return array_expression>(n, arr); } /// Implicitly converts from clad::array to pointer to an array of type T @@ -281,79 +320,73 @@ template CUDA_HOST_DEVICE array zero_vector(std::size_t n) { /// Overloaded operators for clad::array which return a new array. -/// Multiplies the number to every element in the array and returns a new -/// array. +/// Multiplies the number to every element in the array and returns an array +/// expression. template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator*(const array& arr, U n) { - array arr2(arr); - arr2 *= n; - return arr2; +CUDA_HOST_DEVICE array_expression, BinaryMul, U> +operator*(const array& arr, U n) { + return array_expression, BinaryMul, U>(arr, n); } -/// Multiplies the number to every element in the array and returns a new -/// array, when the number is on the left side. +/// Multiplies the number to every element in the array and returns an array +/// expression, when the number is on the left side. 
template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator*(U n, const array& arr) { - return arr * n; +CUDA_HOST_DEVICE array_expression, BinaryMul, U> +operator*(U n, const array& arr) { + return array_expression, BinaryMul, U>(arr, n); } -/// Divides the number from every element in the array and returns a new -/// array +/// Divides the number from every element in the array and returns an array +/// expression. template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator/(const array& arr, U n) { - array arr2(arr); - arr2 /= n; - return arr2; +CUDA_HOST_DEVICE array_expression, BinaryDiv, U> +operator/(const array& arr, U n) { + return array_expression, BinaryDiv, U>(arr, n); } /// Adds the number to every element in the array and returns a new array template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator+(const array& arr, U n) { - array arr2(arr); - arr2 += n; - return arr2; +CUDA_HOST_DEVICE array_expression, BinaryAdd, U> +operator+(const array& arr, U n) { + return array_expression, BinaryAdd, U>(arr, n); } -/// Adds the number to every element in the array and returns a new array, -/// when the number is on the left side. +/// Adds the number to every element in the array and returns an array +/// expression, when the number is on the left side. template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator+(U n, const array& arr) { - return arr + n; +CUDA_HOST_DEVICE array_expression, BinaryAdd, U> +operator+(U n, const array& arr) { + return array_expression, BinaryAdd, U>(arr, n); } -/// Subtracts the number from every element in the array and returns a new -/// array +/// Subtracts the number from every element in the array and returns an array +/// expression. 
template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator-(const array& arr, U n) { - array arr2(arr); - arr2 -= n; - return arr2; +CUDA_HOST_DEVICE array_expression, BinarySub, U> +operator-(const array& arr, U n) { + return array_expression, BinarySub, U>(arr, n); } /// Function to define element wise adding of two arrays. template -CUDA_HOST_DEVICE array operator+(const array& arr1, - const array& arr2) { +CUDA_HOST_DEVICE array_expression, BinaryAdd, array> +operator+(const array& arr1, const array& arr2) { assert(arr1.size() == arr2.size()); - array arr(arr1); - arr += arr2; - return arr; + return array_expression, BinaryAdd, array>(arr1, arr2); } /// Function to define element wise subtraction of two arrays. template -CUDA_HOST_DEVICE array operator-(const array& arr1, - const array& arr2) { +CUDA_HOST_DEVICE array_expression, BinarySub, array> +operator-(const array& arr1, const array& arr2) { assert(arr1.size() == arr2.size()); - array arr(arr1); - arr -= arr2; - return arr; + return array_expression, BinarySub, array>(arr1, arr2); } } // namespace clad diff --git a/include/clad/Differentiator/ArrayExpression.h b/include/clad/Differentiator/ArrayExpression.h new file mode 100644 index 000000000..8ea75e5dc --- /dev/null +++ b/include/clad/Differentiator/ArrayExpression.h @@ -0,0 +1,125 @@ +#ifndef ARRAY_EXPRESSION_H +#define ARRAY_EXPRESSION_H + +#include +#include + +// This is a helper class to implement expression templates for clad::array. + +// NOLINTBEGIN(*-pointer-arithmetic) +namespace clad { + +// Operator to add two elements. +struct BinaryAdd { + template + static auto apply(T const& t, U const& u) -> decltype(t + u) { + return t + u; + } +}; + +// Operator to multiply two elements. +struct BinaryMul { + template + static auto apply(T const& t, U const& u) -> decltype(t * u) { + return t * u; + } +}; + +// Operator to divide two elements. 
+struct BinaryDiv { + template + static auto apply(T const& t, U const& u) -> decltype(t / u) { + return t / u; + } +}; + +// Operator to subtract two elements. +struct BinarySub { + template + static auto apply(T const& t, U const& u) -> decltype(t - u) { + return t - u; + } +}; + +// Class to represent an array expression using templates. +template +class array_expression { + LeftExp l; + RightExp r; + +public: + array_expression(LeftExp const& _l, RightExp const& _r) : l(_l), r(_r) {} + + // for scalars + template ::value, + int>::type = 0> + std::size_t get_size(T const& t) const { + return 1; + } + template ::value, + int>::type = 0> + T get(T const& t, std::size_t i) const { + return t; + } + + // for vectors + template ::value, + int>::type = 0> + std::size_t get_size(T const& t) const { + return t.size(); + } + template ::value, + int>::type = 0> + auto get(T const& t, std::size_t i) const -> decltype(t[i]) { + return t[i]; + } + + // We also need to handle the case when any of the operands is a scalar. + auto operator[](std::size_t i) const + -> decltype(BinaryOp::apply(get(l, i), get(r, i))) { + return BinaryOp::apply(get(l, i), get(r, i)); + } + + std::size_t size() const { return std::max(get_size(l), get_size(r)); } + + // Operator overload for addition. + template + array_expression, BinaryAdd, RE> + operator+(RE const& r) const { + return array_expression, + BinaryAdd, RE>(*this, r); + } + + // Operator overload for multiplication. + template + array_expression, BinaryMul, RE> + operator*(RE const& r) const { + return array_expression, + BinaryMul, RE>(*this, r); + } +}; + +// Operator overload for addition, when the right operand is an array_expression +// and the left operand is a scalar. 
+template ::value, int>::type = 0> +array_expression> +operator+(T const& l, array_expression const& r) { + return array_expression>(l, r); +} + +// Operator overload for multiplication, when the right operand is an +// array_expression and the left operand is a scalar. +template ::value, int>::type = 0> +array_expression> +operator*(T const& l, array_expression const& r) { + return array_expression>(l, r); +} + +} // namespace clad +// NOLINTEND(*-pointer-arithmetic) + +#endif // ARRAY_EXPRESSION_H \ No newline at end of file diff --git a/include/clad/Differentiator/ArrayRef.h b/include/clad/Differentiator/ArrayRef.h index 8a4aa4988..4d99391a9 100644 --- a/include/clad/Differentiator/ArrayRef.h +++ b/include/clad/Differentiator/ArrayRef.h @@ -158,114 +158,104 @@ template class array_ref { } }; -/// Overloaded operators for clad::array_ref which returns a new clad::array -/// object. +/// Overloaded operators for clad::array_ref which returns an array +/// expression. /// Multiplies the arrays element wise template -CUDA_HOST_DEVICE array operator*(const array_ref& Ar, - const array_ref& Br) { +CUDA_HOST_DEVICE array_expression, BinaryMul, array_ref> +operator*(const array_ref& Ar, const array_ref& Br) { assert(Ar.size() == Br.size() && - "Size of both the array_refs must be equal for carrying out addition " - "assignment"); - array C(Ar); - C *= Br; - return C; + "Size of both the array_refs must be equal for carrying out " + "multiplication assignment"); + return array_expression, BinaryMul, array_ref>(Ar, Br); } /// Adds the arrays element wise template -CUDA_HOST_DEVICE array operator+(const array_ref& Ar, - const array_ref& Br) { +CUDA_HOST_DEVICE array_expression, BinaryAdd, array_ref> +operator+(const array_ref& Ar, const array_ref& Br) { assert(Ar.size() == Br.size() && "Size of both the array_refs must be equal for carrying out addition " "assignment"); - array C(Ar); - C += Br; - return C; + return array_expression, BinaryAdd, array_ref>(Ar, Br); } /// 
Subtracts the arrays element wise template -CUDA_HOST_DEVICE array operator-(const array_ref& Ar, - const array_ref& Br) { - assert(Ar.size() == Br.size() && - "Size of both the array_refs must be equal for carrying out addition " - "assignment"); - array C(Ar); - C -= Br; - return C; +CUDA_HOST_DEVICE array_expression, BinarySub, array_ref> +operator-(const array_ref& Ar, const array_ref& Br) { + assert( + Ar.size() == Br.size() && + "Size of both the array_refs must be equal for carrying out subtraction " + "assignment"); + return array_expression, BinarySub, array_ref>(Ar, Br); } /// Divides the arrays element wise template -CUDA_HOST_DEVICE array operator/(const array_ref& Ar, - const array_ref& Br) { +CUDA_HOST_DEVICE array_expression, BinaryDiv, array_ref> +operator/(const array_ref& Ar, const array_ref& Br) { assert(Ar.size() == Br.size() && - "Size of both the array_refs must be equal for carrying out addition " + "Size of both the array_refs must be equal for carrying out division " "assignment"); - array C(Ar); - C /= Br; - return C; + return array_expression, BinaryDiv, array_ref>(Ar, Br); } /// Multiplies array_ref by a scalar template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator*(const array_ref& Ar, U a) { - array C(Ar); - C *= a; - return C; +CUDA_HOST_DEVICE array_expression, BinaryMul, U> +operator*(const array_ref& Ar, U a) { + return array_expression, BinaryMul, U>(Ar, a); } /// Multiplies array_ref by a scalar (reverse order) template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator*(U a, const array_ref& Ar) { - return Ar * a; +CUDA_HOST_DEVICE array_expression, BinaryMul, U> +operator*(U a, const array_ref& Ar) { + return array_expression, BinaryMul, U>(Ar, a); } /// Divides array_ref by a scalar template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator/(const array_ref& Ar, U a) { - array C(Ar); - C /= a; - return C; +CUDA_HOST_DEVICE array_expression, BinaryDiv, U> +operator/(const array_ref& Ar, U a) { + 
return array_expression, BinaryDiv, U>(Ar, a); } /// Adds array_ref by a scalar template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator+(const array_ref& Ar, U a) { - array C(Ar); - C += a; - return C; +CUDA_HOST_DEVICE array_expression, BinaryAdd, U> +operator+(const array_ref& Ar, U a) { + return array_expression, BinaryAdd, U>(Ar, a); } /// Adds array_ref by a scalar (reverse order) template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator+(U a, const array_ref& Ar) { - return Ar + a; +CUDA_HOST_DEVICE array_expression, BinaryAdd, U> +operator+(U a, const array_ref& Ar) { + return array_expression, BinaryAdd, U>(Ar, a); } /// Subtracts array_ref by a scalar template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator-(const array_ref& Ar, U a) { - array C(Ar); - C -= a; - return C; +CUDA_HOST_DEVICE array_expression, BinarySub, U> +operator-(const array_ref& Ar, U a) { + return array_expression, BinarySub, U>(Ar, a); } /// Subtracts array_ref by a scalar (reverse order) template ::value, int>::type = 0> -CUDA_HOST_DEVICE array operator-(U a, const array_ref& Ar) { - array C(Ar.size(), a); - C -= Ar; - return C; +CUDA_HOST_DEVICE array_expression> +operator-(U a, const array_ref& Ar) { + return array_expression>(a, Ar); } /// `array_ref` specialisation is created to be used as a placeholder