Skip to content

Commit

Permalink
Clone base decl when having an anonymous struct or union (#1152)
Browse files Browse the repository at this point in the history
Fixes #1151
  • Loading branch information
kchristin22 authored Nov 23, 2024
1 parent 2f268db commit 693b1bd
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 4 deletions.
13 changes: 9 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3314,17 +3314,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* field = ME->getMemberDecl();
assert(!isa<CXXMethodDecl>(field) &&
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
Expr* clonedME = baseDiff.getExpr();
llvm::StringRef fieldName = field->getName();
if (baseDiff.getExpr() && !fieldName.empty())
clonedME = utils::BuildMemberExpr(m_Sema, getCurrentScope(),
baseDiff.getExpr(), fieldName);
if (clad::utils::hasNonDifferentiableAttribute(ME)) {
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);
return {clonedME, zero};
}
if (!baseDiff.getExpr_dx())
return {clonedME, nullptr};
MemberExpr* derivedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr_dx(), field->getName());
Expr* derivedME = baseDiff.getExpr_dx();
if (!fieldName.empty())
derivedME = utils::BuildMemberExpr(m_Sema, getCurrentScope(),
baseDiff.getExpr_dx(), fieldName);
if (dfdx()) {
Expr* addAssign =
BuildOp(BinaryOperatorKind::BO_AddAssign, derivedME, dfdx());
Expand Down
81 changes: 81 additions & 0 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,69 @@ MyStruct fn12(MyStruct s) {
// CHECK-NEXT: }
// CHECK-NEXT:}

typedef int Fint;
typedef union Findex
{
struct
{
Fint j, k, l;
};
Fint dim[3];
} Findex;

void fn13(double *x, double *y, int size)
{
Findex p;

for (p.j = 0; p.j < size; p.j += 1)
{
y[p.j] = 2.0 * x[p.j];
}
}

// CHECK: void fn13_grad_0_1(double *x, double *y, int size, double *_d_x, double *_d_y) {
// CHECK-NEXT: int _d_size = 0;
// CHECK-NEXT: Fint _t1;
// CHECK-NEXT: clad::tape<Fint> _t2 = {};
// CHECK-NEXT: clad::tape<double> _t3 = {};
// CHECK-NEXT: Findex _d_p({});
// CHECK-NEXT: Findex p;
// CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}};
// CHECK-NEXT: _t1 = p.j;
// CHECK-NEXT: for (p.j = 0; ; clad::push(_t2, p.j) , (p.j += 1)) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!(p.j < size))
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t3, y[p.j]);
// CHECK-NEXT: y[p.j] = 2. * x[p.j];
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: p.j = clad::pop(_t2);
// CHECK-NEXT: Fint _r_d1 = _d_p.j;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: y[p.j] = clad::pop(_t3);
// CHECK-NEXT: double _r_d2 = _d_y[p.j];
// CHECK-NEXT: _d_y[p.j] = 0.;
// CHECK-NEXT: _d_x[p.j] += 2. * _r_d2;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: p.j = _t1;
// CHECK-NEXT: Fint _r_d0 = _d_p.j;
// CHECK-NEXT: _d_p.j = 0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT:}

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand All @@ -416,6 +479,16 @@ void print(const MyStruct& s) {
printf("{%.2f, %.2f}\n", s.a, s.b);
}

void printArray(double* arr, int size) {
printf("{");
for (int i = 0; i < size; ++i) {
printf("%.2f", arr[i]);
if (i != size - 1)
printf(", ");
}
printf("}\n");
}

int main() {
pairdd p(3, 5), d_p;
double i = 3, d_i, d_j;
Expand Down Expand Up @@ -454,6 +527,14 @@ int main() {
auto fn12_test = clad::gradient(fn12);
fn12_test.execute(s, &d_s);
print(d_s); // CHECK-EXEC: {2.00, 2.00}

auto fn13_test = clad::gradient(fn13, "x, y");
double x[3] = {1.0, 2.0, 3.0}, y[3] = {0.0, 0.0, 0.0};
double d_x[3] = {0.0, 0.0, 0.0}, d_y[3] = {1.0, 1.0, 1.0};
int size = 3;
fn13_test.execute(x, y, 3, d_x, d_y);
printArray(d_x, size); // CHECK-EXEC: {2.00, 2.00, 2.00}
printArray(d_y, size); // CHECK-EXEC: {0.00, 0.00, 0.00}
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down

0 comments on commit 693b1bd

Please sign in to comment.