Skip to content

Commit

Permalink
Add support for std::string variables and functions with implicit c…
Browse files Browse the repository at this point in the history
…onstructor calls in their arguments

Previously, Clad failed to differentiate functions with `std::string` variables in them or
function calls that call a constructor implicitly. For instance, a call
```
f(x, "text");
```
to a function `double f(double, const std::string&)` would have caused a segmentation fault
before this commit. Same applies to other user defined arguments, not only those of the
`std::string` type.

Fixes: #974

Co-authored-by: Vassil Vassilev <[email protected]>
  • Loading branch information
gojakuch and vgvassilev committed Jul 10, 2024
1 parent 036eb67 commit f309b61
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 25 deletions.
12 changes: 6 additions & 6 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2012,17 +2012,17 @@ BaseForwardModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {
// forward mode derived constructor that would require same arguments as of
// a pushforward function, that is, `{a, 1, b, _d_a, 0., _d_b}`.
if (CE->getNumArgs() != 1) {
if (CE->isListInitialization()) {
clonedArgsE = m_Sema.ActOnInitList(noLoc, clonedArgs, noLoc).get();
derivedArgsE = m_Sema.ActOnInitList(noLoc, derivedArgs, noLoc).get();
} else if (CE->getNumArgs() == 0) {
if (CE->getNumArgs() == 0 && !CE->isListInitialization()) {
// ParenList is empty -- default initialisation.
// Passing empty parenList here will silently cause 'most vexing
// parse' issue.
return StmtDiff();
} else {
clonedArgsE = m_Sema.ActOnParenListExpr(noLoc, noLoc, clonedArgs).get();
derivedArgsE = m_Sema.ActOnParenListExpr(noLoc, noLoc, derivedArgs).get();
// Rely on the initializer list expressions as they seem to be more
// flexible in terms of conversions and other similar scenarios where a
// constructor is called implicitly.
clonedArgsE = m_Sema.ActOnInitList(noLoc, clonedArgs, noLoc).get();
derivedArgsE = m_Sema.ActOnInitList(noLoc, derivedArgs, noLoc).get();
}
} else {
clonedArgsE = clonedArgs[0];
Expand Down
4 changes: 2 additions & 2 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ double fn10(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: std::mt19937 _d_gen64;
// CHECK-NEXT: std::mt19937 gen64;
// CHECK-NEXT: std::uniform_real_distribution<{{(double)?}}> _d_distribution(0., 0.);
// CHECK-NEXT: std::uniform_real_distribution<{{(double)?}}> distribution(0., 1.);
// CHECK-NEXT: std::uniform_real_distribution<{{(double)?}}> _d_distribution({0., 0.});
// CHECK-NEXT: std::uniform_real_distribution<{{(double)?}}> distribution({0., 1.});
// CHECK-NEXT: clad::ValueAndPushforward<result_type, result_type> _t0 = distribution.operator_call_pushforward(gen64, &_d_distribution, _d_gen64);
// CHECK-NEXT: double _d_rand = _t0.pushforward;
// CHECK-NEXT: double rand0 = _t0.value;
Expand Down
38 changes: 37 additions & 1 deletion test/FirstDerivative/Variables.C
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "clad/Differentiator/Differentiator.h"
#include <cmath>
#include <string>
#include <iostream>

double f_x(double x) {
double t0 = x;
Expand Down Expand Up @@ -102,7 +103,7 @@ double f_string(double x) {
namespace clad {
namespace custom_derivatives {
clad::ValueAndPushforward<double, double> string_test_pushforward(double x, const char s[], double _d_x, const char *_d_s) {
return {0, 0};
return {1, 0};
}
}}
double string_test(double x, const char s[]) {
Expand All @@ -118,13 +119,48 @@ double f_string_call(double x) {
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

double f_stdstring(double x) {
std::string s = "string literal";
return x;
}

// CHECK: double f_stdstring_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: std::string _d_s = {{[{]"", std::allocator<char>\(\)[}]|""}};
// CHECK-NEXT: std::string s = {{[{]"string literal", std::allocator<char>\(\)[}]|"string literal"}};
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

namespace clad {
namespace custom_derivatives {
clad::ValueAndPushforward<double, double> stdstring_test_pushforward(double x, const ::std::string& s, double _d_x, const ::std::string& _d_s) {
return {x, 1};
}
}}
double stdstring_test(double x, const std::string& s) {
return x;
}
double f_stdstring_call(double x) {
return stdstring_test(x, "string literal");
}

// CHECK: double f_stdstring_call_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::stdstring_test_pushforward(x, {{[{]"string literal", std::allocator<char>\(\)[}]|"string literal"}}, _d_x, {{[{]"", std::allocator<char>\(\)[}]|""}});
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

int main() {
clad::differentiate(f_x, 0);
clad::differentiate(f_ops1, 0);
clad::differentiate(f_ops2, 0);
clad::differentiate(f_sin, 0);
clad::differentiate(f_string, 0);
clad::differentiate(f_string_call, 0);
auto df_stdstring = clad::differentiate(f_stdstring, 0);
std::cout << df_stdstring.execute(3.0) << '\n'; // CHECK-EXEC: 1
auto df_stdstring_call = clad::differentiate(f_stdstring_call, 0);
std::cout << df_stdstring_call.execute(3.0) << '\n'; // CHECK-EXEC: 1
}


16 changes: 8 additions & 8 deletions test/ForwardMode/NonDifferentiable.C
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,17 @@ int main() {
// CHECK: double fn_s1_mem_fn_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: SimpleFunctions1 _d_obj(0, 0);
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: SimpleFunctions1 _d_obj({0, 0});
// CHECK-NEXT: SimpleFunctions1 obj({2, 3});
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = obj.mem_fn_1_pushforward(i, j, &_d_obj, _d_i, _d_j);
// CHECK-NEXT: return _t0.pushforward + _d_i * j + i * _d_j;
// CHECK-NEXT: }

// CHECK: double fn_s1_field_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: SimpleFunctions1 _d_obj(0, 0);
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: SimpleFunctions1 _d_obj({0, 0});
// CHECK-NEXT: SimpleFunctions1 obj({2, 3});
// CHECK-NEXT: double &_t0 = obj.x;
// CHECK-NEXT: double &_t1 = obj.y;
// CHECK-NEXT: return _d_obj.x * _t1 + _t0 * 0. + _d_i * j + i * _d_j;
Expand All @@ -175,10 +175,10 @@ int main() {
// CHECK: double fn_s1_operator_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: SimpleFunctions1 _d_obj1(0, 0);
// CHECK-NEXT: SimpleFunctions1 obj1(2, 3);
// CHECK-NEXT: SimpleFunctions1 _d_obj2(0, 0);
// CHECK-NEXT: SimpleFunctions1 obj2(3, 5);
// CHECK-NEXT: SimpleFunctions1 _d_obj1({0, 0});
// CHECK-NEXT: SimpleFunctions1 obj1({2, 3});
// CHECK-NEXT: SimpleFunctions1 _d_obj2({0, 0});
// CHECK-NEXT: SimpleFunctions1 obj2({3, 5});
// CHECK-NEXT: clad::ValueAndPushforward<SimpleFunctions1, SimpleFunctions1> _t0 = obj1.operator_plus_pushforward(obj2, &_d_obj1, _d_obj2);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t1 = _t0.value.mem_fn_1_pushforward(i, j, &_t0.pushforward, _d_i, _d_j);
// CHECK-NEXT: return _t1.pushforward;
Expand Down
16 changes: 8 additions & 8 deletions test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ std::pair<double, double> fn1(double i, double j) {
// CHECK: std::pair<double, double> fn1_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: std::pair<double, double> _d_c(0, 0), _d_d({0., 0.});
// CHECK-NEXT: std::pair<double, double> c(3, 5), d({7., 9.});
// CHECK-NEXT: std::pair<double, double> _d_c({0, 0}), _d_d({0., 0.});
// CHECK-NEXT: std::pair<double, double> c({3, 5}), d({7., 9.});
// CHECK-NEXT: std::pair<double, double> _d_e = _d_d;
// CHECK-NEXT: std::pair<double, double> e = d;
// CHECK-NEXT: _d_c.first += _d_i;
Expand Down Expand Up @@ -498,8 +498,8 @@ complexD fn8(double i, TensorD5 t) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: TensorD5 _d_t;
// CHECK-NEXT: t.updateTo_pushforward(i * i, & _d_t, _d_i * i + i * _d_i);
// CHECK-NEXT: complexD _d_c(0., 0.);
// CHECK-NEXT: complexD c(0., 0.);
// CHECK-NEXT: complexD _d_c({0., 0.});
// CHECK-NEXT: complexD c({0., 0.});
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = t.sum_pushforward(& _d_t);
// CHECK-NEXT: double &_t1 = _t0.value;
// CHECK-NEXT: c.real_pushforward(7 * _t1, &_d_c, 0 * _t1 + 7 * _t0.pushforward);
Expand All @@ -525,8 +525,8 @@ complexD fn9(double i, complexD c) {
// CHECK: complexD fn9_darg0(double i, complexD c) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: complexD _d_c;
// CHECK-NEXT: complexD _d_r(0., 0.);
// CHECK-NEXT: complexD r(0., 0.);
// CHECK-NEXT: complexD _d_r({0., 0.});
// CHECK-NEXT: complexD r({0., 0.});
// CHECK-NEXT: c.real_pushforward(i * i, &_d_c, _d_i * i + i * _d_i);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = c.real_pushforward(&_d_c);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t1 = c.real_pushforward(&_d_c);
Expand Down Expand Up @@ -558,8 +558,8 @@ std::complex<double> fn10(double i, double j) {
// CHECK: std::complex<double> fn10_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: std::complex<double> _d_c1(0., 0.), _d_c2(0., 0.);
// CHECK-NEXT: std::complex<double> c1(0., 0.), c2(0., 0.);
// CHECK-NEXT: std::complex<double> _d_c1({0., 0.}), _d_c2({0., 0.});
// CHECK-NEXT: std::complex<double> c1({0., 0.}), c2({0., 0.});
// CHECK-NEXT: c1.real_pushforward(2 * i, &_d_c1, 0 * i + 2 * _d_i);
// CHECK-NEXT: c1.imag_pushforward(5 * i, &_d_c1, 0 * i + 5 * _d_i);
// CHECK-NEXT: c2.real_pushforward(5 * i, &_d_c2, 0 * i + 5 * _d_i);
Expand Down

0 comments on commit f309b61

Please sign in to comment.