From c532206190991cdb0437755c2411764832289ddc Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 9 Oct 2024 22:33:51 +0300 Subject: [PATCH] Improve support for operations between clad::array and clad::array_ref --- include/clad/Differentiator/ArrayRef.h | 64 ++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/include/clad/Differentiator/ArrayRef.h b/include/clad/Differentiator/ArrayRef.h index 5fe173ee4..42ee89c00 100644 --- a/include/clad/Differentiator/ArrayRef.h +++ b/include/clad/Differentiator/ArrayRef.h @@ -52,6 +52,14 @@ template class array_ref { m_size = a.size(); return *this; } + template + CUDA_HOST_DEVICE array_ref& + operator=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] = arr_exp[i]; + return *this; + } /// Returns the size of the underlying array constexpr CUDA_HOST_DEVICE std::size_t size() const { return m_size; } constexpr CUDA_HOST_DEVICE PUREFUNC T* ptr() const { return m_arr; } @@ -71,7 +79,7 @@ template class array_ref { // Arithmetic overloads /// Divides the arrays element wise template - CUDA_HOST_DEVICE array_ref& operator/=(array_ref& Ar) { + CUDA_HOST_DEVICE array_ref& operator/=(const array_ref& Ar) { assert(m_size == Ar.size() && "Size of both the array_refs must be equal " "for carrying out addition assignment"); for (std::size_t i = 0; i < m_size; i++) @@ -80,7 +88,7 @@ template class array_ref { } /// Multiplies the arrays element wise template - CUDA_HOST_DEVICE array_ref& operator*=(array_ref& Ar) { + CUDA_HOST_DEVICE array_ref& operator*=(const array_ref& Ar) { assert(m_size == Ar.size() && "Size of both the array_refs must be equal " "for carrying out addition assignment"); for (std::size_t i = 0; i < m_size; i++) @@ -89,7 +97,7 @@ template class array_ref { } /// Adds the arrays element wise template - CUDA_HOST_DEVICE array_ref& operator+=(array_ref& Ar) { + CUDA_HOST_DEVICE array_ref& operator+=(const array_ref& Ar) { assert(m_size == Ar.size() && "Size of both the array_refs must be equal " "for carrying out addition assignment"); for (std::size_t i = 0; i < m_size; i++) @@ -98,7 +106,7 @@ template class array_ref { } /// Subtracts the arrays element wise template - CUDA_HOST_DEVICE array_ref& operator-=(array_ref& Ar) { + CUDA_HOST_DEVICE array_ref& operator-=(const array_ref& Ar) { assert(m_size == Ar.size() && "Size of both the array_refs must be equal " "for carrying out addition assignment"); for (std::size_t i = 0; i < m_size; i++) @@ -106,28 +114,68 @@ template class array_ref { return *this; } /// Divides the elements of the array_ref by elements of the array - template CUDA_HOST_DEVICE array_ref& operator/=(array& A) { + template + CUDA_HOST_DEVICE array_ref& operator/=(const array& A) { assert(m_size == A.size() && "Size of arrays must be equal"); for (std::size_t i = 0; i < m_size; i++) m_arr[i] /= A[i]; return *this; } /// Multiplies the elements of the array_ref by elements of the array - template CUDA_HOST_DEVICE array_ref& operator*=(array& A) { + template + CUDA_HOST_DEVICE array_ref& + operator*=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] *= arr_exp[i]; + return *this; + } + /// Adds the elements of the array_ref by elements of the array + template + CUDA_HOST_DEVICE array_ref& + operator+=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] += arr_exp[i]; + return *this; + } + /// Subtracts the elements of the array_ref by elements of the array + template + CUDA_HOST_DEVICE array_ref& + operator-=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] -= arr_exp[i]; + return *this; + } + /// Divides the elements of the array_ref by elements of the array + template + CUDA_HOST_DEVICE array_ref& + operator/=(const array_expression& arr_exp) { + assert(arr_exp.size() == m_size); + for (std::size_t i = 0; i < m_size; i++) + m_arr[i] /= arr_exp[i]; + return *this; + } + /// Multiplies the elements of the array_ref by elements of the array + template + CUDA_HOST_DEVICE array_ref& operator*=(const array& A) { assert(m_size == A.size() && "Size of arrays must be equal"); for (std::size_t i = 0; i < m_size; i++) m_arr[i] *= A[i]; return *this; } /// Adds the elements of the array_ref by elements of the array - template CUDA_HOST_DEVICE array_ref& operator+=(array& A) { + template + CUDA_HOST_DEVICE array_ref& operator+=(const array& A) { assert(m_size == A.size() && "Size of arrays must be equal"); for (std::size_t i = 0; i < m_size; i++) m_arr[i] += A[i]; return *this; } /// Subtracts the elements of the array_ref by elements of the array - template CUDA_HOST_DEVICE array_ref& operator-=(array& A) { + template + CUDA_HOST_DEVICE array_ref& operator-=(const array& A) { assert(m_size == A.size() && "Size of arrays must be equal"); for (std::size_t i = 0; i < m_size; i++) m_arr[i] -= A[i];