Skip to content

Commit

Permalink
Prevent Clad from trying to create a void zero literal
Browse files Browse the repository at this point in the history
Previously, clad used to try to synthesise a void zero literal
when differentiating a call to a void function with
literal arguments in the forward mode. This caused it to crash.

Fixes: vgvassilev#988
  • Loading branch information
gojakuch committed Jul 19, 2024
1 parent e04d04e commit 4e1a142
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
3 changes: 1 addition & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1219,8 +1219,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
validLoc, llvm::MutableArrayRef<Expr*>(CallArgs),
validLoc)
.get();
auto* zero = ConstantFolder::synthesizeLiteral(CE->getType(), m_Context,
/*val=*/0);
auto* zero = getZeroInit(CE->getType());
return StmtDiff(call, zero);
}
}
Expand Down
3 changes: 3 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,9 @@ namespace clad {

Expr* VisitorBase::getZeroInit(QualType T) {
// FIXME: Consolidate other uses of synthesizeLiteral for creation 0 or 1.
if (T->isVoidType()) {
return nullptr;
}
if (T->isScalarType()) {
ExprResult Zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down
16 changes: 16 additions & 0 deletions test/FirstDerivative/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,21 @@ double test_9(double x) {
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

void some_important_void_func(double y) {
assert(y < 1);
}

double test_10(double x) {
some_important_void_func(1);
return x;
}

// CHECK: double test_10_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: some_important_void_func(1);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

int main () {
clad::differentiate(test_1, 0);
clad::differentiate(test_2, 0);
Expand All @@ -196,6 +211,7 @@ int main () {
clad::differentiate<clad::opts::enable_tbr, clad::opts::disable_tbr>(test_8); // expected-error {{Both enable and disable TBR options are specified.}}
clad::differentiate<clad::opts::diagonal_only>(test_8); // expected-error {{Diagonal only option is only valid for Hessian mode.}}
clad::differentiate(test_9);
clad::differentiate(test_10);
return 0;

// CHECK: void increment_pushforward(int &i, int &_d_i) {
Expand Down
4 changes: 2 additions & 2 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ double fn4(double i, double j) {
// CHECK: double fn4_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = nonRealParamFn(0, 0);
// CHECK-NEXT: _d_res += _d_i;
// CHECK-NEXT: res += i;
Expand Down Expand Up @@ -266,7 +266,7 @@ double fn8(double i, double j) {
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t1 = check_and_return_pushforward(_t0.value, 'a', _t0.pushforward, 0);
// CHECK-NEXT: double &_t2 = _t1.value;
// CHECK-NEXT: double _t3 = std::tanh(1.);
// CHECK-NEXT: return _t1.pushforward * _t3 + _t2 * 0.;
// CHECK-NEXT: return _t1.pushforward * _t3 + _t2 * 0;
// CHECK-NEXT: }

double g (double x) { return x; }
Expand Down

0 comments on commit 4e1a142

Please sign in to comment.