From 3daeb517a75bdbeaaef8062be6b1f46f8ab11e36 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 18 Jul 2024 10:29:02 +0200 Subject: [PATCH] Fix type casting of ValueAndPushForward types #986 fixes #983 --- .../clad/Differentiator/BuiltinDerivatives.h | 8 +++ include/clad/Differentiator/CladUtils.h | 1 + lib/Differentiator/BaseForwardModeVisitor.cpp | 17 +++--- lib/Differentiator/CladUtils.cpp | 16 ++++++ lib/Differentiator/PushForwardModeVisitor.cpp | 21 ++++++- test/FirstDerivative/CallArguments.C | 56 +++++++++++++++++++ .../FunctionCallsWithResults.C | 4 +- test/FirstDerivative/FunctionsInNamespaces.C | 6 +- test/ForwardMode/UserDefinedTypes.C | 52 ++++++++--------- test/NthDerivative/CustomDerivatives.C | 2 +- 10 files changed, 139 insertions(+), 44 deletions(-) diff --git a/include/clad/Differentiator/BuiltinDerivatives.h b/include/clad/Differentiator/BuiltinDerivatives.h index 85937f72b..e706a3937 100644 --- a/include/clad/Differentiator/BuiltinDerivatives.h +++ b/include/clad/Differentiator/BuiltinDerivatives.h @@ -20,6 +20,14 @@ namespace clad { template struct ValueAndPushforward { T value; U pushforward; + + // Define the cast operator from ValueAndPushforward to + // ValueAndPushforward where V is convertible to T and W is + // convertible to U. + template + operator ValueAndPushforward() const { + return {static_cast(value), static_cast(pushforward)}; + } }; /// It is used to identify constructor custom pushforwards. For diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index ae8b55813..05899cad7 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -317,6 +317,7 @@ namespace clad { void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt); bool IsLiteral(const clang::Expr* E); + bool IsZeroOrNullValue(const clang::Expr* E); bool IsMemoryFunction(const clang::FunctionDecl* FD); bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 7853205a7..b27140ff9 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1071,7 +1071,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // Returning the function call and zero derivative return StmtDiff(Call, zero); } - // Find the built-in derivatives namespace. std::string s = std::to_string(m_DerivativeOrder); if (m_DerivativeOrder == 1) @@ -1200,19 +1199,17 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // FIXME: revert this when this is integrated in the activity analysis pass. if (!callDiff) { if (!isa(CE) && !isa(CE)) { - bool allArgsAreConstantLiterals = true; + bool allArgsHaveZeroDerivatives = true; for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) { - const Expr* arg = CE->getArg(i); - // if it's of type MaterializeTemporaryExpr, then check its - // subexpression. - if (const auto* MTE = dyn_cast(arg)) - arg = clad_compat::GetSubExpr(MTE); - if (!arg->isEvaluatable(m_Context)) { - allArgsAreConstantLiterals = false; + Expr* dArg = diffArgs[i]; + // If argDiff.expr_dx is nullptr or is a constant 0, then the derivative + // of the function call is 0. + if (!clad::utils::IsZeroOrNullValue(dArg)) { + allArgsHaveZeroDerivatives = false; break; } } - if (allArgsAreConstantLiterals) { + if (allArgsHaveZeroDerivatives) { Expr* call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 96b35d6aa..3b6a379e0 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -658,6 +658,22 @@ namespace clad { isa(E); } + bool IsZeroOrNullValue(const clang::Expr* E) { + if (!E) + return true; + if (const auto* ICE = dyn_cast(E)) + return IsZeroOrNullValue(ICE->getSubExpr()); + if (isa(E)) + return true; + if (const auto* FL = dyn_cast(E)) + return FL->getValue().isZero(); + if (const auto* IL = dyn_cast(E)) + return IL->getValue() == 0; + if (const auto* SL = dyn_cast(E)) + return SL->getLength() == 0; + return false; + } + bool IsMemoryFunction(const clang::FunctionDecl* FD) { #if CLANG_VERSION_MAJOR > 12 diff --git a/lib/Differentiator/PushForwardModeVisitor.cpp b/lib/Differentiator/PushForwardModeVisitor.cpp index 962eda0b9..0bb0c858d 100644 --- a/lib/Differentiator/PushForwardModeVisitor.cpp +++ b/lib/Differentiator/PushForwardModeVisitor.cpp @@ -26,8 +26,25 @@ StmtDiff PushForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { return nullptr; StmtDiff retValDiff = Visit(RS->getRetValue()); - llvm::SmallVector returnValues = {retValDiff.getExpr(), - retValDiff.getExpr_dx()}; + Expr* retVal = retValDiff.getExpr(); + Expr* retVal_dx = retValDiff.getExpr_dx(); + if (!m_Context.hasSameUnqualifiedType(retVal->getType(), + m_DiffReq->getReturnType())) { + // Check if implficit cast would work. + // Add a cast to the return type. + TypeSourceInfo* TSI = + m_Context.getTrivialTypeSourceInfo(m_DiffReq->getReturnType()); + retVal = m_Sema + .BuildCStyleCastExpr(RS->getBeginLoc(), TSI, RS->getEndLoc(), + BuildParens(retVal)) + .get(); + retVal_dx = + m_Sema + .BuildCStyleCastExpr(RS->getBeginLoc(), TSI, RS->getEndLoc(), + BuildParens(retVal_dx)) + .get(); + } + llvm::SmallVector returnValues = {retVal, retVal_dx}; // This can instantiate as part of the move or copy initialization and // needs a fake source location. SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema); diff --git a/test/FirstDerivative/CallArguments.C b/test/FirstDerivative/CallArguments.C index 9c543a655..454a8a816 100644 --- a/test/FirstDerivative/CallArguments.C +++ b/test/FirstDerivative/CallArguments.C @@ -132,6 +132,44 @@ float f_const_args_func_8(const float x, float y) { // CHECK-NEXT: return _t0.pushforward + _t1.pushforward - _d_y; // CHECK-NEXT: } +float f_literal_helper(float x, char ch, float* p, float* q) { + if (ch == 'a') + return x * x; + return -x * x; +} + +float f_literal_args_func(float x, float y, float *z) { + printf("hello world "); + return x * f_literal_helper(0.5, 'a', z, nullptr); +} + +// CHECK: float f_literal_args_func_darg0(float x, float y, float *z) { +// CHECK-NEXT: float _d_x = 1; +// CHECK-NEXT: float _d_y = 0; +// CHECK-NEXT: printf("hello world "); +// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr); +// CHECK-NEXT: return _d_x * _t0 + x * 0.F; +// CHECK-NEXT: } + +inline unsigned int getBin(double low, double high, double val, unsigned int numBins) { + double binWidth = (high - low) / numBins; + return val >= high ? numBins - 1 : std::abs((val - low) / binWidth); +} + +float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) { + const float t116 = *(xlArr + getBin(0., 1., params[0], 1)); + return t116 * params[0]; +} + +// CHECK: inline clad::ValueAndPushforward getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins); + +// CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) { +// CHECK-NEXT: clad::ValueAndPushforward _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0); +// CHECK-NEXT: const float _d_t116 = 0; +// CHECK-NEXT: const float t116 = *(xlArr + _t0.value); +// CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F; +// CHECK-NEXT: } + extern "C" int printf(const char* fmt, ...); int main () { // expected-no-diagnostics auto f = clad::differentiate(g, 0); @@ -165,9 +203,27 @@ int main () { // expected-no-diagnostics const float f8x = 1.F; printf("f8_darg0=%f\n", f8.execute(f8x,2.F)); //CHECK-EXEC: f8_darg0=2.000000 + auto f9 = clad::differentiate(f_literal_args_func, 0); + float z = 3.0; + printf("f9_darg0=%.2f\n", f9.execute(1.F, 2.F, &z)); + //CHECK-EXEC: hello world f9_darg0=0.25 + auto f10 = clad::differentiate(f_call_inline_fxn, "params[0]"); + float params = 1.0, obs = 5.0, xlArr = 7.0; + printf("f10_darg0_0=%.2f\n", f10.execute(¶ms, &obs, &xlArr)); + //CHECK-EXEC: f10_darg0_0=7.00 // CHECK: clad::ValueAndPushforward f_const_helper_pushforward(const float x, const float _d_x) { // CHECK-NEXT: return {x * x, _d_x * x + x * _d_x}; +// CHECK-NEXT: } + +// CHECK: inline clad::ValueAndPushforward getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins) { +// CHECK-NEXT: double _t0 = (high - low); +// CHECK-NEXT: double _d_binWidth = ((_d_high - _d_low) * numBins - _t0 * _d_numBins) / (numBins * numBins); +// CHECK-NEXT: double binWidth = _t0 / numBins; +// CHECK-NEXT: double _t1 = (val - low); +// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::abs_pushforward(_t1 / binWidth, ((_d_val - _d_low) * binWidth - _t1 * _d_binWidth) / (binWidth * binWidth)); +// CHECK-NEXT: bool _t3 = val >= high; +// CHECK-NEXT: return {(unsigned int)(_t3 ? numBins - 1 : _t2.value), (unsigned int)(_t3 ? _d_numBins - 0 : _t2.pushforward)}; // CHECK-NEXT: } return 0; diff --git a/test/FirstDerivative/FunctionCallsWithResults.C b/test/FirstDerivative/FunctionCallsWithResults.C index 72289ca78..2f26288f6 100644 --- a/test/FirstDerivative/FunctionCallsWithResults.C +++ b/test/FirstDerivative/FunctionCallsWithResults.C @@ -385,7 +385,7 @@ int main () { // CHECK: clad::ValueAndPushforward fn6_pushforward(double i, double j, double k, double _d_i, double _d_j, double _d_k) { // CHECK-NEXT: if (i < 0.5) -// CHECK-NEXT: return {0, 0}; +// CHECK-NEXT: return {(double)0, (double)0}; // CHECK-NEXT: clad::ValueAndPushforward _t0 = fn6_pushforward(i - 1, j - 1, k - 1, _d_i - 0, _d_j - 0, _d_k - 0); // CHECK-NEXT: return {i + j + k + _t0.value, _d_i + _d_j + _d_k + _t0.pushforward}; // CHECK-NEXT: } @@ -420,7 +420,7 @@ int main () { // CHECK: clad::ValueAndPushforward check_and_return_pushforward(double x, char c, double _d_x, char _d_c) { // CHECK-NEXT: if (c == 'a') // CHECK-NEXT: return {x, _d_x}; -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(double)1, (double)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward g_pushforward(double x, double _d_x) { diff --git a/test/FirstDerivative/FunctionsInNamespaces.C b/test/FirstDerivative/FunctionsInNamespaces.C index edb1f7b64..7c676be9b 100644 --- a/test/FirstDerivative/FunctionsInNamespaces.C +++ b/test/FirstDerivative/FunctionsInNamespaces.C @@ -107,7 +107,7 @@ int main () { // CHECK: clad::ValueAndPushforward someFn_pushforward(double &i, double &j, double &_d_i, double &_d_j) { // CHECK-NEXT: clad::ValueAndPushforward _t0 = someFn_1_pushforward(i, j, _d_i, _d_j); - // CHECK-NEXT: return {3, 0}; + // CHECK-NEXT: return {(double)3, (double)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward func4_pushforward(int x, int y, int _d_x, int _d_y) { @@ -118,13 +118,13 @@ int main () { // CHECK: clad::ValueAndPushforward someFn_1_pushforward(double &i, double j, double &_d_i, double _d_j) { // CHECK-NEXT: clad::ValueAndPushforward _t0 = someFn_1_pushforward(i, j, j, _d_i, _d_j, _d_j); - // CHECK-NEXT: return {2, 0}; + // CHECK-NEXT: return {(double)2, (double)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward someFn_1_pushforward(double &i, double j, double k, double &_d_i, double _d_j, double _d_k) { // CHECK-NEXT: _d_i = _d_j; // CHECK-NEXT: i = j; - // CHECK-NEXT: return {1, 0}; + // CHECK-NEXT: return {(double)1, (double)0}; // CHECK-NEXT: } return 0; diff --git a/test/ForwardMode/UserDefinedTypes.C b/test/ForwardMode/UserDefinedTypes.C index acdc66727..82ef5f54f 100644 --- a/test/ForwardMode/UserDefinedTypes.C +++ b/test/ForwardMode/UserDefinedTypes.C @@ -1082,7 +1082,7 @@ int main() { // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_subscript_pushforward(std::size_t idx, Tensor *_d_this, std::size_t _d_idx) { -// CHECK-NEXT: return {this->data[idx], _d_this->data[idx]}; +// CHECK-NEXT: return {(double &)this->data[idx], (double &)_d_this->data[idx]}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward, Tensor > operator_plus_pushforward(const Tensor &a, const Tensor &b, const Tensor &_d_a, const Tensor &_d_b) { @@ -1162,7 +1162,7 @@ int main() { // CHECK-NEXT: this->data[i] = t.data[i]; // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: return {*this, *_d_this}; +// CHECK-NEXT: return {(Tensor &)*this, (Tensor &)*_d_this}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward, Tensor > operator_caret_pushforward(const Tensor &a, const Tensor &b, const Tensor &_d_a, const Tensor &_d_b) { @@ -1182,13 +1182,13 @@ int main() { // CHECK: clad::ValueAndPushforward &, Tensor &> operator_plus_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { // CHECK-NEXT: clad::ValueAndPushforward, Tensor > _t0 = operator_plus_pushforward(lhs, rhs, _d_lhs, _d_rhs); // CHECK-NEXT: clad::ValueAndPushforward &, Tensor &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward); -// CHECK-NEXT: return {lhs, _d_lhs}; +// CHECK-NEXT: return {(Tensor &)lhs, (Tensor &)_d_lhs}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward &, Tensor &> operator_minus_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { // CHECK-NEXT: clad::ValueAndPushforward, Tensor > _t0 = operator_minus_pushforward(lhs, rhs, _d_lhs, _d_rhs); // CHECK-NEXT: clad::ValueAndPushforward &, Tensor &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward); -// CHECK-NEXT: return {lhs, _d_lhs}; +// CHECK-NEXT: return {(Tensor &)lhs, (Tensor &)_d_lhs}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward &, Tensor &> operator_plus_plus_pushforward(Tensor *_d_this) { @@ -1199,7 +1199,7 @@ int main() { // CHECK-NEXT: this->data[i] += 1; // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: return {*this, *_d_this}; +// CHECK-NEXT: return {(Tensor &)*this, (Tensor &)*_d_this}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward &, Tensor &> operator_minus_minus_pushforward(Tensor *_d_this) { @@ -1210,7 +1210,7 @@ int main() { // CHECK-NEXT: this->data[i] += 1; // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: return {*this, *_d_this}; +// CHECK-NEXT: return {(Tensor &)*this, (Tensor &)*_d_this}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward, Tensor > operator_plus_plus_pushforward(int param, Tensor *_d_this, int _d_param) { @@ -1229,19 +1229,19 @@ int main() { // CHECK: clad::ValueAndPushforward &, Tensor &> operator_star_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { // CHECK-NEXT: clad::ValueAndPushforward, Tensor > _t0 = operator_star_pushforward(lhs, rhs, _d_lhs, _d_rhs); // CHECK-NEXT: clad::ValueAndPushforward &, Tensor &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward); -// CHECK-NEXT: return {lhs, _d_lhs}; +// CHECK-NEXT: return {(Tensor &)lhs, (Tensor &)_d_lhs}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward &, Tensor &> operator_slash_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { // CHECK-NEXT: clad::ValueAndPushforward, Tensor > _t0 = operator_slash_pushforward(lhs, rhs, _d_lhs, _d_rhs); // CHECK-NEXT: clad::ValueAndPushforward &, Tensor &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward); -// CHECK-NEXT: return {lhs, _d_lhs}; +// CHECK-NEXT: return {(Tensor &)lhs, (Tensor &)_d_lhs}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward &, Tensor &> operator_caret_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { // CHECK-NEXT: clad::ValueAndPushforward, Tensor > _t0 = operator_slash_pushforward(lhs, rhs, _d_lhs, _d_rhs); // CHECK-NEXT: clad::ValueAndPushforward &, Tensor &> _t1 = lhs.operator_equal_pushforward(_t0.value, & _d_lhs, _t0.pushforward); -// CHECK-NEXT: return {lhs, _d_lhs}; +// CHECK-NEXT: return {(Tensor &)lhs, (Tensor &)_d_lhs}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_less_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { @@ -1258,7 +1258,7 @@ int main() { // CHECK-NEXT: rsum += lhs.data[i]; // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_greater_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { @@ -1275,23 +1275,23 @@ int main() { // CHECK-NEXT: rsum += lhs.data[i]; // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_less_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_greater_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_equal_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_exclaim_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: void operator_comma_pushforward(const Tensor &lhs, const Tensor &rhs, const Tensor &_d_lhs, const Tensor &_d_rhs) { @@ -1313,49 +1313,49 @@ int main() { // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward &, Tensor &> operator_percent_equal_pushforward(Tensor &a, const Tensor &b, Tensor &_d_a, const Tensor &_d_b) { -// CHECK-NEXT: return {a, _d_a}; +// CHECK-NEXT: return {(Tensor &)a, (Tensor &)_d_a}; // CHECK-NEXT: } // CHECK: void operator_tilde_pushforward(const Tensor &a, const Tensor &_d_a) { // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_less_less_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_greater_greater_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_AmpAmp_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_pipe_pipe_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_less_less_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_greater_greater_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_amp_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_pipe_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_pipe_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } // CHECK: clad::ValueAndPushforward operator_amp_equal_pushforward(Tensor &lhs, const Tensor &rhs, Tensor &_d_lhs, const Tensor &_d_rhs) { -// CHECK-NEXT: return {1, 0}; +// CHECK-NEXT: return {(bool)1, (bool)0}; // CHECK-NEXT: } } \ No newline at end of file diff --git a/test/NthDerivative/CustomDerivatives.C b/test/NthDerivative/CustomDerivatives.C index 5a72c2dc5..e5bcbecbb 100644 --- a/test/NthDerivative/CustomDerivatives.C +++ b/test/NthDerivative/CustomDerivatives.C @@ -163,7 +163,7 @@ int main() { // CHECK-NEXT: {{(clad::)?}}ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x); // CHECK-NEXT: {{(clad::)?}}ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x); // CHECK-NEXT: float &_t2 = _t1.value; -// CHECK-NEXT: return {{[{][{]}}_t0.value, _t2 * d_x}, {_t0.pushforward, _t1.pushforward * d_x + _t2 * _d_d_x{{[}][}]}}; +// CHECK-NEXT: return {{[{][(]ValueAndPushforward[)][{]}}_t0.value, _t2 * d_x}, (ValueAndPushforward){_t0.pushforward, _t1.pushforward * d_x + _t2 * _d_d_x{{[}][}]}}; // CHECK-NEXT:} } \ No newline at end of file