Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial support for expression templates in array and array_ref class #628

Merged
merged 1 commit into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 96 additions & 54 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CLAD_ARRAY_H
#define CLAD_ARRAY_H

#include "clad/Differentiator/ArrayExpression.h"
#include "clad/Differentiator/CladConfig.h"

#include <assert.h>
Expand Down Expand Up @@ -36,23 +37,31 @@ template <typename T> class array {

template <typename U>
CUDA_HOST_DEVICE array(U* a, std::size_t size)
: m_arr(new T[size]{static_cast<T>(T())}), m_size(size) {
: m_arr(new T[size]), m_size(size) {
for (std::size_t i = 0; i < size; ++i)
m_arr[i] = static_cast<T>(a[i]);
}

CUDA_HOST_DEVICE array(const array<T>& arr) : array(arr.m_arr, arr.m_size) {}

CUDA_HOST_DEVICE array(std::size_t size, const clad::array<T>& arr)
: m_arr(new T[size]{static_cast<T>(T())}), m_size(size) {
: m_arr(new T[size]), m_size(size) {
for (std::size_t i = 0; i < size; ++i)
m_arr[i] = arr[i];
}

template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array(std::size_t size,
const array_expression<L, BinaryOp, R>& expression)
: m_arr(new T[size]), m_size(size) {
for (std::size_t i = 0; i < size; ++i)
m_arr[i] = expression[i];
}

// initializing all entries using the same value
template <typename U>
CUDA_HOST_DEVICE array(std::size_t size, U val)
: m_arr(new T[size]{static_cast<T>(T())}), m_size(size) {
: m_arr(new T[size]), m_size(size) {
for (std::size_t i = 0; i < size; ++i)
m_arr[i] = static_cast<T>(val);
}
Expand Down Expand Up @@ -229,6 +238,15 @@ template <typename T> class array {
m_arr[i] *= static_cast<T>(arr[i]);
return *this;
}
/// Initializes the clad::array from the given clad::array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] = arr_exp[i];
return *this;
}
/// Performs element wise division
template <typename U>
CUDA_HOST_DEVICE array<T>& operator/=(const array<U>& arr) {
Expand All @@ -237,25 +255,55 @@ template <typename T> class array {
m_arr[i] /= static_cast<T>(arr[i]);
return *this;
}
/// Performs element wise addition with array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator+=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] += arr_exp[i];
return *this;
}
/// Performs element wise subtraction with array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator-=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] -= arr_exp[i];
return *this;
}
/// Performs element wise multiplication with array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator*=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] *= arr_exp[i];
return *this;
}
/// Performs element wise division with array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator/=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] /= arr_exp[i];
return *this;
}

/// Negate the array and return a new array.
CUDA_HOST_DEVICE array<T> operator-() const {
array<T> arr2(m_size);
for (std::size_t i = 0; i < m_size; i++)
arr2[i] = -m_arr[i];
return arr2;
CUDA_HOST_DEVICE array_expression<T, BinarySub, array<T>> operator-() const {
return array_expression<T, BinarySub, array<T>>(static_cast<T>(0), *this);
}

/// Subtracts the number from every element in the array and returns a new
/// array, when the number is on the left side.
template <typename U, typename std::enable_if<std::is_arithmetic<U>::value,
int>::type = 0>
CUDA_HOST_DEVICE friend array<T> operator-(U n, const array<T>& arr) {
size_t size = arr.size();
array<T> arr2(size);
for (std::size_t i = 0; i < size; i++)
arr2[i] = n - arr[i];
return arr2;
CUDA_HOST_DEVICE friend array_expression<U, BinarySub, array<T>>
operator-(U n, const array<T>& arr) {
return array_expression<U, BinarySub, array<T>>(n, arr);
}

/// Implicitly converts from clad::array to pointer to an array of type T
Expand All @@ -281,79 +329,73 @@ template <typename T> CUDA_HOST_DEVICE array<T> zero_vector(std::size_t n) {

/// Overloaded operators for clad::array which return a new array.

/// Multiplies the number to every element in the array and returns a new
/// array.
/// Multiplies the number to every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator*(const array<T>& arr, U n) {
array<T> arr2(arr);
arr2 *= n;
return arr2;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryMul, U>
operator*(const array<T>& arr, U n) {
return array_expression<array<T>, BinaryMul, U>(arr, n);
}

/// Multiplies the number to every element in the array and returns a new
/// array, when the number is on the left side.
/// Multiplies the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator*(U n, const array<T>& arr) {
return arr * n;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryMul, U>
operator*(U n, const array<T>& arr) {
return array_expression<array<T>, BinaryMul, U>(arr, n);
}

/// Divides the number from every element in the array and returns a new
/// array
/// Divides the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator/(const array<T>& arr, U n) {
array<T> arr2(arr);
arr2 /= n;
return arr2;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryDiv, U>
operator/(const array<T>& arr, U n) {
return array_expression<array<T>, BinaryDiv, U>(arr, n);
}

/// Adds the number to every element in the array and returns a new array
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator+(const array<T>& arr, U n) {
array<T> arr2(arr);
arr2 += n;
return arr2;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryAdd, U>
operator+(const array<T>& arr, U n) {
return array_expression<array<T>, BinaryAdd, U>(arr, n);
}

/// Adds the number to every element in the array and returns a new array,
/// when the number is on the left side.
/// Adds the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator+(U n, const array<T>& arr) {
return arr + n;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryAdd, U>
operator+(U n, const array<T>& arr) {
return array_expression<array<T>, BinaryAdd, U>(arr, n);
}

/// Subtracts the number from every element in the array and returns a new
/// array
/// Subtracts the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator-(const array<T>& arr, U n) {
array<T> arr2(arr);
arr2 -= n;
return arr2;
CUDA_HOST_DEVICE array_expression<array<T>, BinarySub, U>
operator-(const array<T>& arr, U n) {
return array_expression<array<T>, BinarySub, U>(arr, n);
}

/// Function to define element wise adding of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array<T> operator+(const array<T>& arr1,
const array<U>& arr2) {
CUDA_HOST_DEVICE array_expression<array<T>, BinaryAdd, array<U>>
operator+(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
array<T> arr(arr1);
arr += arr2;
return arr;
return array_expression<array<T>, BinaryAdd, array<U>>(arr1, arr2);
}

/// Function to define element wise subtraction of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array<T> operator-(const array<T>& arr1,
const array<U>& arr2) {
CUDA_HOST_DEVICE array_expression<array<T>, BinarySub, array<U>>
operator-(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
array<T> arr(arr1);
arr -= arr2;
return arr;
return array_expression<array<T>, BinarySub, array<U>>(arr1, arr2);
}

} // namespace clad
Expand Down
150 changes: 150 additions & 0 deletions include/clad/Differentiator/ArrayExpression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#ifndef CLAD_DIFFERENTIATOR_ARRAYEXPRESSION_H
#define CLAD_DIFFERENTIATOR_ARRAYEXPRESSION_H

#include <algorithm>
#include <type_traits>

// This is a helper class to implement expression templates for clad::array.

// NOLINTBEGIN(*-pointer-arithmetic)
namespace clad {

// Operator to add two elements.
struct BinaryAdd {
template <typename T, typename U>
static auto apply(T const& t, U const& u) -> decltype(t + u) {
return t + u;
}
};

// Operator to add two elements.
struct BinaryMul {
template <typename T, typename U>
static auto apply(T const& t, U const& u) -> decltype(t * u) {
return t * u;
}
};

// Operator to divide two elements.
struct BinaryDiv {
template <typename T, typename U>
static auto apply(T const& t, U const& u) -> decltype(t / u) {
return t / u;
}
};

// Operator to subtract two elements.
struct BinarySub {
template <typename T, typename U>
static auto apply(T const& t, U const& u) -> decltype(t - u) {
return t - u;
}
};

// Class to represent an array expression using templates.
template <typename LeftExp, typename BinaryOp, typename RightExp>
class array_expression {
LeftExp l;
RightExp r;

public:
array_expression(LeftExp const& l, RightExp const& r) : l(l), r(r) {}

// for scalars
template <typename T, typename std::enable_if<std::is_arithmetic<T>::value,
int>::type = 0>
std::size_t get_size(T const& t) const {
return 1;
}
template <typename T, typename std::enable_if<std::is_arithmetic<T>::value,
int>::type = 0>
T get(T const& t, std::size_t i) const {
return t;
}

// for vectors
template <typename T, typename std::enable_if<!std::is_arithmetic<T>::value,
int>::type = 0>
std::size_t get_size(T const& t) const {
return t.size();
}
template <typename T, typename std::enable_if<!std::is_arithmetic<T>::value,
int>::type = 0>
auto get(T const& t, std::size_t i) const -> decltype(t[i]) {
return t[i];
}

// We also need to handle the case when any of the operands is a scalar.
auto operator[](std::size_t i) const
-> decltype(BinaryOp::apply(get(l, i), get(r, i))) {
return BinaryOp::apply(get(l, i), get(r, i));
}

std::size_t size() const { return std::max(get_size(l), get_size(r)); }

// Operator overload for addition.
template <typename RE>
array_expression<array_expression<LeftExp, BinaryOp, RightExp>, BinaryAdd, RE>
operator+(RE const& r) const {
return array_expression<array_expression<LeftExp, BinaryOp, RightExp>,
BinaryAdd, RE>(*this, r);
}

// Operator overload for multiplication.
template <typename RE>
array_expression<array_expression<LeftExp, BinaryOp, RightExp>, BinaryMul, RE>
operator*(RE const& r) const {
return array_expression<array_expression<LeftExp, BinaryOp, RightExp>,
BinaryMul, RE>(*this, r);
}

// Operator overload for subtraction.
template <typename RE>
array_expression<array_expression<LeftExp, BinaryOp, RightExp>, BinarySub, RE>
operator-(RE const& r) const {
return array_expression<array_expression<LeftExp, BinaryOp, RightExp>,
BinarySub, RE>(*this, r);
}

// Operator overload for division.
template <typename RE>
array_expression<array_expression<LeftExp, BinaryOp, RightExp>, BinaryDiv, RE>
operator/(RE const& r) const {
return array_expression<array_expression<LeftExp, BinaryOp, RightExp>,
BinaryDiv, RE>(*this, r);
}
};

// Operator overload for addition, when the right operand is an array_expression
// and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinaryAdd, array_expression<LeftExp, BinaryOp, RightExp>>
operator+(T const& l, array_expression<LeftExp, BinaryOp, RightExp> const& r) {
return array_expression<T, BinaryAdd,
array_expression<LeftExp, BinaryOp, RightExp>>(l, r);
}

// Operator overload for multiplication, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinaryMul, array_expression<LeftExp, BinaryOp, RightExp>>
operator*(T const& l, array_expression<LeftExp, BinaryOp, RightExp> const& r) {
return array_expression<T, BinaryMul,
array_expression<LeftExp, BinaryOp, RightExp>>(l, r);
}

// Operator overload for subtraction, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinarySub, array_expression<LeftExp, BinaryOp, RightExp>>
operator-(T const& l, array_expression<LeftExp, BinaryOp, RightExp> const& r) {
return array_expression<T, BinarySub,
array_expression<LeftExp, BinaryOp, RightExp>>(l, r);
}
} // namespace clad
// NOLINTEND(*-pointer-arithmetic)

#endif // CLAD_DIFFERENTIATOR_ARRAYEXPRESSION_H
Loading
Loading