From 1a3e8a53bd7a9bb7172be3c3cc81cda146d16106 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Thu, 18 Jul 2024 13:29:45 +0200 Subject: [PATCH] Prevent Clad from trying to create a void zero literal Previously, clad used to try to synthesise a void zero literal when differentiating a call to a void function with literal arguments in the forward mode. This caused it to crash. Fixes: #988 --- lib/Differentiator/BaseForwardModeVisitor.cpp | 3 +-- lib/Differentiator/VisitorBase.cpp | 2 ++ test/FirstDerivative/CallArguments.C | 2 +- test/FirstDerivative/FunctionCalls.C | 16 ++++++++++++++++ test/FirstDerivative/FunctionCallsWithResults.C | 4 ++-- 5 files changed, 22 insertions(+), 5 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index b27140ff9..b6eebb8b2 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1216,8 +1216,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { validLoc, llvm::MutableArrayRef(CallArgs), validLoc) .get(); - auto* zero = ConstantFolder::synthesizeLiteral(CE->getType(), m_Context, - /*val=*/0); + auto* zero = getZeroInit(CE->getType()); return StmtDiff(call, zero); } } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 2d36e34ed..7c4b0abb1 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -376,6 +376,8 @@ namespace clad { Expr* VisitorBase::getZeroInit(QualType T) { // FIXME: Consolidate other uses of synthesizeLiteral for creation 0 or 1. + if (T->isVoidType()) + return nullptr; if (T->isScalarType()) { ExprResult Zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); diff --git a/test/FirstDerivative/CallArguments.C b/test/FirstDerivative/CallArguments.C index 454a8a816..efe05909e 100644 --- a/test/FirstDerivative/CallArguments.C +++ b/test/FirstDerivative/CallArguments.C @@ -148,7 +148,7 @@ float f_literal_args_func(float x, float y, float *z) { // CHECK-NEXT: float _d_y = 0; // CHECK-NEXT: printf("hello world "); // CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr); -// CHECK-NEXT: return _d_x * _t0 + x * 0.F; +// CHECK-NEXT: return _d_x * _t0 + x * 0; // CHECK-NEXT: } inline unsigned int getBin(double low, double high, double val, unsigned int numBins) { diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index 33e15c2c6..2d2f0f651 100644 --- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -183,6 +183,21 @@ double test_9(double x) { // CHECK-NEXT: return _t0.pushforward; // CHECK-NEXT: } +void some_important_void_func(double y) { + assert(y < 1); +} + +double test_10(double x) { + some_important_void_func(1); + return x; +} + +// CHECK: double test_10_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: some_important_void_func(1); +// CHECK-NEXT: return _d_x; +// CHECK-NEXT: } + int main () { clad::differentiate(test_1, 0); clad::differentiate(test_2, 0); @@ -196,6 +211,7 @@ int main () { clad::differentiate(test_8); // expected-error {{Both enable and disable TBR options are specified.}} clad::differentiate(test_8); // expected-error {{Diagonal only option is only valid for Hessian mode.}} clad::differentiate(test_9); + clad::differentiate(test_10); return 0; // CHECK: void increment_pushforward(int &i, int &_d_i) { diff --git a/test/FirstDerivative/FunctionCallsWithResults.C b/test/FirstDerivative/FunctionCallsWithResults.C index 2f26288f6..733698a62 100644 --- a/test/FirstDerivative/FunctionCallsWithResults.C +++ b/test/FirstDerivative/FunctionCallsWithResults.C @@ -165,7 +165,7 @@ double fn4(double i, double j) { // CHECK: double fn4_darg0(double i, double j) { // CHECK-NEXT: double _d_i = 1; // CHECK-NEXT: double _d_j = 0; -// CHECK-NEXT: double _d_res = 0.; +// CHECK-NEXT: double _d_res = 0; // CHECK-NEXT: double res = nonRealParamFn(0, 0); // CHECK-NEXT: _d_res += _d_i; // CHECK-NEXT: res += i; @@ -266,7 +266,7 @@ double fn8(double i, double j) { // CHECK-NEXT: clad::ValueAndPushforward _t1 = check_and_return_pushforward(_t0.value, 'a', _t0.pushforward, 0); // CHECK-NEXT: double &_t2 = _t1.value; // CHECK-NEXT: double _t3 = std::tanh(1.); -// CHECK-NEXT: return _t1.pushforward * _t3 + _t2 * 0.; +// CHECK-NEXT: return _t1.pushforward * _t3 + _t2 * 0; // CHECK-NEXT: } double g (double x) { return x; }