Skip to content

Commit

Permalink
Remove excessive array_expression operators by generalizing the templ…
Browse files Browse the repository at this point in the history
…ates
  • Loading branch information
PetroZarytskyi committed Nov 13, 2024
1 parent cfa04be commit d7c3174
Showing 1 changed file with 9 additions and 41 deletions.
50 changes: 9 additions & 41 deletions include/clad/Differentiator/ArrayExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,51 +121,19 @@ class array_expression {
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryDiv, RE>(
*this, r);
}
// Operator overload for addition.
template <typename L1, typename BinOp1, typename R1, typename L2,
typename BinOp2, typename R2>
array_expression<const array_expression<L1, BinOp1, R1>&, BinaryAdd,
const array_expression<L2, BinOp2, R2>&>
operator+(const array_expression<L2, BinOp2, R2>& r) const {
return array_expression<const array_expression<L1, BinOp1, R1>&, BinaryAdd,
const array_expression<L2, BinOp2, R2>&>(*this, r);
}

// Operator overload for multiplication.
template <typename L1, typename BinOp1, typename R1, typename L2,
typename BinOp2, typename R2>
array_expression<const array_expression<L1, BinOp1, R1>&, BinarySub,
const array_expression<L2, BinOp2, R2>&>
operator*(const array_expression<L2, BinOp2, R2>& r) const {
return array_expression<const array_expression<L1, BinOp1, R1>&, BinaryMul,
const array_expression<L2, BinOp2, R2>&>(*this, r);
}
};

// Operator overload for subtraction.
template <typename L1, typename BinOp1, typename R1, typename L2,
typename BinOp2, typename R2>
array_expression<const array_expression<L1, BinOp1, R1>&, BinarySub,
const array_expression<L2, BinOp2, R2>&>
operator-(const array_expression<L2, BinOp2, R2>& r) const {
return array_expression<const array_expression<L1, BinOp1, R1>&, BinarySub,
const array_expression<L2, BinOp2, R2>&>(*this, r);
}
// A class to determine whether a given type is array_expression.
template <typename T> struct is_array_expr : std::false_type {};

// Operator overload for division.
template <typename L1, typename BinOp1, typename R1, typename L2,
typename BinOp2, typename R2>
array_expression<const array_expression<L1, BinOp1, R1>&, BinaryDiv,
const array_expression<L2, BinOp2, R2>&>
operator/(const array_expression<L2, BinOp2, R2>& r) const {
return array_expression<const array_expression<L1, BinOp1, R1>&, BinaryDiv,
const array_expression<L2, BinOp2, R2>&>(*this, r);
}
};
template <typename LeftExp, typename BinaryOp, typename RightExp>
struct is_array_expr<array_expression<LeftExp, BinaryOp, RightExp>>
: std::true_type {};

// Operator overload for addition, when the right operand is an array_expression
// and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
typename std::enable_if<!is_array_expr<T>::value, int>::type = 0>
array_expression<T, BinaryAdd,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator+(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
Expand All @@ -177,7 +145,7 @@ operator+(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
// Operator overload for multiplication, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
typename std::enable_if<!is_array_expr<T>::value, int>::type = 0>
array_expression<T, BinaryMul,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator*(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
Expand All @@ -189,7 +157,7 @@ operator*(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
// Operator overload for subtraction, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
typename std::enable_if<!is_array_expr<T>::value, int>::type = 0>
array_expression<T, BinarySub,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator-(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
Expand Down

0 comments on commit d7c3174

Please sign in to comment.