Skip to content

Commit

Permalink
Fix type casting of ValueAndPushforward types
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jul 18, 2024
1 parent 2a5ac8b commit b3308a1
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 34 deletions.
8 changes: 8 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ namespace clad {
/// Pair returned by pushforward functions: the primal value of the
/// original function together with its directional derivative.
template <typename T, typename U> struct ValueAndPushforward {
  T value;        // Primal (original function) value.
  U pushforward;  // Tangent (directional-derivative) value.

  // Define the cast operator from ValueAndPushforward<T, U> to
  // ValueAndPushforward<V, W>. The conversion direction is T -> V and
  // U -> W, i.e. it is usable whenever T is convertible to V and U is
  // convertible to W (e.g. <int, int> -> <unsigned int, unsigned int>).
  template <typename V = T, typename W = U>
  operator ValueAndPushforward<V, W>() const {
    return {static_cast<V>(value), static_cast<W>(pushforward)};
  }
};

/// It is used to identify constructor custom pushforwards. For
Expand Down
21 changes: 19 additions & 2 deletions lib/Differentiator/PushForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,25 @@ StmtDiff PushForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
return nullptr;

StmtDiff retValDiff = Visit(RS->getRetValue());
llvm::SmallVector<Expr*, 2> returnValues = {retValDiff.getExpr(),
retValDiff.getExpr_dx()};
Expr* retVal = retValDiff.getExpr();
Expr* retVal_dx = retValDiff.getExpr_dx();
if (!m_Context.hasSameUnqualifiedType(retVal->getType(),
m_DiffReq->getReturnType())) {
// Check if an implicit cast would work.
// Add a cast to the return type.
TypeSourceInfo* TSI =
m_Context.getTrivialTypeSourceInfo(m_DiffReq->getReturnType());
retVal = m_Sema
.BuildCStyleCastExpr(RS->getBeginLoc(), TSI, RS->getEndLoc(),
BuildParens(retVal))
.get();
retVal_dx =
m_Sema
.BuildCStyleCastExpr(RS->getBeginLoc(), TSI, RS->getEndLoc(),
BuildParens(retVal_dx))
.get();
}
llvm::SmallVector<Expr*, 2> returnValues = {retVal, retVal_dx};
// This can instantiate as part of the move or copy initialization and
// needs a fake source location.
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
Expand Down
33 changes: 33 additions & 0 deletions test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,25 @@ float f_literal_args_func(float x, float y) {
// CHECK-NEXT: return _d_x * _t0 + x * 0.F;
// CHECK-NEXT: }

// Maps `val` into one of `numBins` equal-width bins over [low, high);
// values at or above `high` are clamped to the last bin (numBins - 1).
// Returns unsigned int while the branches mix `unsigned int` and the
// double result of std::abs — exercising the return-type cast path.
inline unsigned int getBin(double low, double high, double val, unsigned int numBins) {
double binWidth = (high - low) / numBins;
return val >= high ? numBins - 1 : std::abs((val - low) / binWidth);
}

// Test fixture: differentiates a function that calls the inline helper
// `getBin` (returning unsigned int), so clad must generate and cast a
// ValueAndPushforward for it; `obs` is unused and only pads the signature.
float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) {
const float t116 = *(xlArr + getBin(0., 1., params[0], 1));
return t116 * params[0];
}

// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins);

// CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) {
// CHECK-NEXT: clad::ValueAndPushforward<unsigned int, unsigned int> _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0);
// CHECK-NEXT: const float _d_t116 = 0;
// CHECK-NEXT: const float t116 = *(xlArr + _t0.value);
// CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F;
// CHECK-NEXT: }

extern "C" int printf(const char* fmt, ...);
int main () { // expected-no-diagnostics
auto f = clad::differentiate(g, 0);
Expand Down Expand Up @@ -187,9 +206,23 @@ int main () { // expected-no-diagnostics
auto f9 = clad::differentiate(f_literal_args_func, 0);
printf("f9_darg0=%.2f\n", f9.execute(1.F,2.F));
//CHECK-EXEC: hello world f9_darg0=0.25
auto f10 = clad::differentiate(f_call_inline_fxn, "params[0]");
float params = 1.0, obs = 5.0, xlArr = 7.0;
printf("f10_darg0_0=%.2f\n", f10.execute(&params, &obs, &xlArr));
//CHECK-EXEC: f10_darg0_0=7.00

// CHECK: clad::ValueAndPushforward<float, float> f_const_helper_pushforward(const float x, const float _d_x) {
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }

// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins) {
// CHECK-NEXT: double _t0 = (high - low);
// CHECK-NEXT: double _d_binWidth = ((_d_high - _d_low) * numBins - _t0 * _d_numBins) / (numBins * numBins);
// CHECK-NEXT: double binWidth = _t0 / numBins;
// CHECK-NEXT: double _t1 = (val - low);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t2 = clad::custom_derivatives::std::abs_pushforward(_t1 / binWidth, ((_d_val - _d_low) * binWidth - _t1 * _d_binWidth) / (binWidth * binWidth));
// CHECK-NEXT: bool _t3 = val >= high;
// CHECK-NEXT: return {(unsigned int)(_t3 ? numBins - 1 : _t2.value), (unsigned int)(_t3 ? _d_numBins - 0 : _t2.pushforward)};
// CHECK-NEXT: }

return 0;
Expand Down
4 changes: 2 additions & 2 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ int main () {

// CHECK: clad::ValueAndPushforward<double, double> fn6_pushforward(double i, double j, double k, double _d_i, double _d_j, double _d_k) {
// CHECK-NEXT: if (i < 0.5)
// CHECK-NEXT: return {0, 0};
// CHECK-NEXT: return {(double)0, (double)0};
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = fn6_pushforward(i - 1, j - 1, k - 1, _d_i - 0, _d_j - 0, _d_k - 0);
// CHECK-NEXT: return {i + j + k + _t0.value, _d_i + _d_j + _d_k + _t0.pushforward};
// CHECK-NEXT: }
Expand Down Expand Up @@ -420,7 +420,7 @@ int main () {
// CHECK: clad::ValueAndPushforward<double, double> check_and_return_pushforward(double x, char c, double _d_x, char _d_c) {
// CHECK-NEXT: if (c == 'a')
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(double)1, (double)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double, double> g_pushforward(double x, double _d_x) {
Expand Down
6 changes: 3 additions & 3 deletions test/FirstDerivative/FunctionsInNamespaces.C
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ int main () {

// CHECK: clad::ValueAndPushforward<double, double> someFn_pushforward(double &i, double &j, double &_d_i, double &_d_j) {
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = someFn_1_pushforward(i, j, _d_i, _d_j);
// CHECK-NEXT: return {3, 0};
// CHECK-NEXT: return {(double)3, (double)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<int, int> func4_pushforward(int x, int y, int _d_x, int _d_y) {
Expand All @@ -118,13 +118,13 @@ int main () {

// CHECK: clad::ValueAndPushforward<double, double> someFn_1_pushforward(double &i, double j, double &_d_i, double _d_j) {
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = someFn_1_pushforward(i, j, j, _d_i, _d_j, _d_j);
// CHECK-NEXT: return {2, 0};
// CHECK-NEXT: return {(double)2, (double)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double, double> someFn_1_pushforward(double &i, double j, double k, double &_d_i, double _d_j, double _d_k) {
// CHECK-NEXT: _d_i = _d_j;
// CHECK-NEXT: i = j;
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(double)1, (double)0};
// CHECK-NEXT: }

return 0;
Expand Down
52 changes: 26 additions & 26 deletions test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ int main() {
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double &, double &> operator_subscript_pushforward(std::size_t idx, Tensor<double, 5> *_d_this, std::size_t _d_idx) {
// CHECK-NEXT: return {this->data[idx], _d_this->data[idx]};
// CHECK-NEXT: return {(double &)this->data[idx], (double &)_d_this->data[idx]};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > operator_plus_pushforward(const Tensor<double, 5U> &a, const Tensor<double, 5U> &b, const Tensor<double, 5U> &_d_a, const Tensor<double, 5U> &_d_b) {
Expand Down Expand Up @@ -1162,7 +1162,7 @@ int main() {
// CHECK-NEXT: this->data[i] = t.data[i];
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return {*this, *_d_this};
// CHECK-NEXT: return {(Tensor<double, 5> &)*this, (Tensor<double, 5> &)*_d_this};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > operator_caret_pushforward(const Tensor<double, 5U> &a, const Tensor<double, 5U> &b, const Tensor<double, 5U> &_d_a, const Tensor<double, 5U> &_d_b) {
Expand All @@ -1182,13 +1182,13 @@ int main() {
// CHECK: clad::ValueAndPushforward<Tensor<double, 5U> &, Tensor<double, 5U> &> operator_plus_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > _t0 = operator_plus_pushforward(lhs, rhs, _d_lhs, _d_rhs);
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5> &, Tensor<double, 5> &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward);
// CHECK-NEXT: return {lhs, _d_lhs};
// CHECK-NEXT: return {(Tensor<double, 5U> &)lhs, (Tensor<double, 5U> &)_d_lhs};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U> &, Tensor<double, 5U> &> operator_minus_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > _t0 = operator_minus_pushforward(lhs, rhs, _d_lhs, _d_rhs);
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5> &, Tensor<double, 5> &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward);
// CHECK-NEXT: return {lhs, _d_lhs};
// CHECK-NEXT: return {(Tensor<double, 5U> &)lhs, (Tensor<double, 5U> &)_d_lhs};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5> &, Tensor<double, 5> &> operator_plus_plus_pushforward(Tensor<double, 5> *_d_this) {
Expand All @@ -1199,7 +1199,7 @@ int main() {
// CHECK-NEXT: this->data[i] += 1;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return {*this, *_d_this};
// CHECK-NEXT: return {(Tensor<double, 5> &)*this, (Tensor<double, 5> &)*_d_this};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5> &, Tensor<double, 5> &> operator_minus_minus_pushforward(Tensor<double, 5> *_d_this) {
Expand All @@ -1210,7 +1210,7 @@ int main() {
// CHECK-NEXT: this->data[i] += 1;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return {*this, *_d_this};
// CHECK-NEXT: return {(Tensor<double, 5> &)*this, (Tensor<double, 5> &)*_d_this};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5>, Tensor<double, 5> > operator_plus_plus_pushforward(int param, Tensor<double, 5> *_d_this, int _d_param) {
Expand All @@ -1229,19 +1229,19 @@ int main() {
// CHECK: clad::ValueAndPushforward<Tensor<double, 5U> &, Tensor<double, 5U> &> operator_star_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > _t0 = operator_star_pushforward(lhs, rhs, _d_lhs, _d_rhs);
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5> &, Tensor<double, 5> &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward);
// CHECK-NEXT: return {lhs, _d_lhs};
// CHECK-NEXT: return {(Tensor<double, 5U> &)lhs, (Tensor<double, 5U> &)_d_lhs};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U> &, Tensor<double, 5U> &> operator_slash_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > _t0 = operator_slash_pushforward(lhs, rhs, _d_lhs, _d_rhs);
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5> &, Tensor<double, 5> &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward);
// CHECK-NEXT: return {lhs, _d_lhs};
// CHECK-NEXT: return {(Tensor<double, 5U> &)lhs, (Tensor<double, 5U> &)_d_lhs};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U> &, Tensor<double, 5U> &> operator_caret_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > _t0 = operator_slash_pushforward(lhs, rhs, _d_lhs, _d_rhs);
// CHECK-NEXT: clad::ValueAndPushforward<Tensor<double, 5> &, Tensor<double, 5> &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward);
// CHECK-NEXT: return {lhs, _d_lhs};
// CHECK-NEXT: return {(Tensor<double, 5U> &)lhs, (Tensor<double, 5U> &)_d_lhs};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_less_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
Expand All @@ -1258,7 +1258,7 @@ int main() {
// CHECK-NEXT: rsum += lhs.data[i];
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_greater_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
Expand All @@ -1275,23 +1275,23 @@ int main() {
// CHECK-NEXT: rsum += lhs.data[i];
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_less_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_greater_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_equal_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_exclaim_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: void operator_comma_pushforward(const Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, const Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
Expand All @@ -1313,49 +1313,49 @@ int main() {
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U> &, Tensor<double, 5U> &> operator_percent_equal_pushforward(Tensor<double, 5U> &a, const Tensor<double, 5U> &b, Tensor<double, 5U> &_d_a, const Tensor<double, 5U> &_d_b) {
// CHECK-NEXT: return {a, _d_a};
// CHECK-NEXT: return {(Tensor<double, 5U> &)a, (Tensor<double, 5U> &)_d_a};
// CHECK-NEXT: }

// CHECK: void operator_tilde_pushforward(const Tensor<double, 5U> &a, const Tensor<double, 5U> &_d_a) {
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_less_less_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_greater_greater_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_AmpAmp_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_pipe_pipe_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_less_less_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_greater_greater_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_amp_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_pipe_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_pipe_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<bool, bool> operator_amp_equal_pushforward(Tensor<double, 5U> &lhs, const Tensor<double, 5U> &rhs, Tensor<double, 5U> &_d_lhs, const Tensor<double, 5U> &_d_rhs) {
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(bool)1, (bool)0};
// CHECK-NEXT: }
}
2 changes: 1 addition & 1 deletion test/NthDerivative/CustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ int main() {
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t1 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x);
// CHECK-NEXT: float &_t2 = _t1.value;
// CHECK-NEXT: return {{[{][{]}}_t0.value, _t2 * d_x}, {_t0.pushforward, _t1.pushforward * d_x + _t2 * _d_d_x{{[}][}]}};
// CHECK-NEXT: return {{[{][(]ValueAndPushforward<float, float>[)][{]}}_t0.value, _t2 * d_x}, (ValueAndPushforward<float, float>){_t0.pushforward, _t1.pushforward * d_x + _t2 * _d_d_x{{[}][}]}};
// CHECK-NEXT:}

}

0 comments on commit b3308a1

Please sign in to comment.