Skip to content

Commit

Permalink
Add test case for param naming
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 20, 2024
1 parent f00afbc commit b5f381a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) {
if (!dfdx())
return StmtDiff(Clone(ILE));
QualType ILEType = ILE->getType();
llvm::SmallVector<Expr*, 16> clonedExprs(ILE->getNumInits());
if (isArrayOrPointerType(ILEType)) {
Expand Down
58 changes: 57 additions & 1 deletion test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,27 @@ double fn11(double x, double y) {
// CHECK-NEXT: }
// CHECK-NEXT: }

struct MyStruct{
double a;
double b;
};

MyStruct fn12(MyStruct s) {
s = {2 * s.a, 2 * s.b + 2};
return s;
}

// CHECK: void fn12_grad(MyStruct s, MyStruct *_d_s) {
// CHECK-NEXT: MyStruct _t0 = s;
// CHECK-NEXT: clad::ValueAndAdjoint<MyStruct &, MyStruct &> _t1 = _t0.operator_equal_forw({2 * s.a, 2 * s.b + 2}, &(*_d_s), {});
// CHECK-NEXT: {
// CHECK-NEXT: MyStruct _r0 = {};
// CHECK-NEXT: _t0.operator_equal_pullback({2 * s.a, 2 * s.b + 2}, {}, &(*_d_s), &_r0);
// CHECK-NEXT: (*_d_s).a += 2 * _r0.a;
// CHECK-NEXT: (*_d_s).b += 2 * _r0.b;
// CHECK-NEXT: }
// CHECK-NEXT:}

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand All @@ -391,6 +412,10 @@ void print(const Tangent& t) {
}
}

void print(const MyStruct& s) {
printf("{%.2f, %.2f}\n", s.a, s.b);
}

int main() {
pairdd p(3, 5), d_p;
double i = 3, d_i, d_j;
Expand Down Expand Up @@ -425,6 +450,10 @@ int main() {
TEST_GRADIENT(fn9, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00, 1.00, 5.00, 10.00}
TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 5, 10, &d_i, &d_j); // CHECK-EXEC: {1.00, 0.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, -14, &d_i, &d_j); // CHECK-EXEC: {1.00, -1.00}
MyStruct s = {1.0, 2.0}, d_s = {1.0, 1.0};
auto fn12_test = clad::gradient(fn12);
fn12_test.execute(s, &d_s);
print(d_s); // CHECK-EXEC: {2.00, 2.00}
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down Expand Up @@ -546,4 +575,31 @@ int main() {
// CHECK-NEXT: *_d_x += _d_y;
// CHECK-NEXT: (*_d_t).data[0] += _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: inline constexpr void operator_equal_pullback(MyStruct &&_r0, MyStruct _d_y, MyStruct *_d_this, MyStruct *_d__r0) noexcept {
// CHECK-NEXT: double _t0 = this->a;
// CHECK-NEXT: this->a = _r0.a;
// CHECK-NEXT: double _t1 = this->b;
// CHECK-NEXT: this->b = _r0.b;
// CHECK-NEXT: {
// CHECK-NEXT: this->b = _t1;
// CHECK-NEXT: double _r_d1 = (*_d_this).b;
// CHECK-NEXT: (*_d_this).b = 0.;
// CHECK-NEXT: (*_d__r0).b += _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: this->a = _t0;
// CHECK-NEXT: double _r_d0 = (*_d_this).a;
// CHECK-NEXT: (*_d_this).a = 0.;
// CHECK-NEXT: (*_d__r0).a += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT:}

// CHECK: inline constexpr clad::ValueAndAdjoint<MyStruct &, MyStruct &> operator_equal_forw(MyStruct &&_r0, MyStruct *_d_this, MyStruct &&_d__r0) noexcept {
// CHECK-NEXT: double _t0 = this->a;
// CHECK-NEXT: this->a = _r0.a;
// CHECK-NEXT: double _t1 = this->b;
// CHECK-NEXT: this->b = _r0.b;
// CHECK-NEXT: return {*this, (*_d_this)};
// CHECK-NEXT:}

0 comments on commit b5f381a

Please sign in to comment.