Skip to content

Commit

Permalink
Fix struct init using initializer lists
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 19, 2024
1 parent 35ee97d commit 3d8feec
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 55 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class BaseForwardModeVisitor
StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);

virtual clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType);
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ namespace clad {
StmtDiff VisitForStmt(const clang::ForStmt* FS);
StmtDiff VisitIfStmt(const clang::IfStmt* If);
StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE);
StmtDiff
VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitInitListExpr(const clang::InitListExpr* ILE);
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,11 @@ BaseForwardModeVisitor::VisitImplicitCastExpr(const ImplicitCastExpr* ICE) {
return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx());
}

StmtDiff BaseForwardModeVisitor::VisitImplicitValueInitExpr(
const ImplicitValueInitExpr* E) {
return StmtDiff(Clone(E), Clone(E));
}

StmtDiff
BaseForwardModeVisitor::VisitCXXDefaultArgExpr(const CXXDefaultArgExpr* DE) {
// FIXME: Shouldn't we simply clone the CXXDefaultArgExpr?
Expand Down
117 changes: 67 additions & 50 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get();
return StmtDiff(clonedILE);
}
// Check if type is a CXXRecordDecl and a struct.
if (!isCladValueAndPushforwardType(ILEType) && ILEType->isRecordType() &&
ILEType->getAsCXXRecordDecl()->isStruct()) {
for (unsigned i = 0, e = ILE->getNumInits(); i < e; i++) {
// fetch ith field of the struct.
auto field_iterator = ILEType->getAsCXXRecordDecl()->field_begin();
std::advance(field_iterator, i);
Expr* member_acess = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), dfdx(), (*field_iterator)->getName());
clonedExprs[i] = Visit(ILE->getInit(i), member_acess).getExpr();
}
Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get();
return StmtDiff(clonedILE);
}

// FIXME: This is a makeshift arrangement to differentiate an InitListExpr
// that represents a ValueAndPushforward type. Ideally this must be
// differentiated at VisitCXXConstructExpr
Expand Down Expand Up @@ -2582,11 +2597,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
// Check if the variable is pointer type and initialized by new expression
if (isPointerType && VD->getInit()) {
if (isa<CXXNewExpr>(VD->getInit())) {
isInitializedByNewExpr = true;
}
}
if (isPointerType && (VD->getInit() != nullptr) &&
isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
Expand Down Expand Up @@ -2655,7 +2668,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType && VD->getInit()) {
initDiff = Visit(VD->getInit());
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (VD->getType()->getPointeeType().isConstQualified()) {
Expand All @@ -2677,12 +2689,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType && !(isPointerType && !isInitializedByNewExpr)) {
if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) {
Expr* derivedE = BuildDeclRef(VDDerived);
if (isInitializedByNewExpr) {
// derivedE should be dereferenced.
if (isInitializedByNewExpr)
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
}
if (VD->getInit()) {
if (isa<CXXConstructExpr>(VD->getInit()))
initDiff = Visit(VD->getInit());
Expand All @@ -2709,6 +2719,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
getZeroInit(VDDerivedType));
addToCurrentBlock(assignToZero, direction::reverse);
}
} else if (isPointerType && VD->getInit()) {
initDiff = Visit(VD->getInit());
}
VarDecl* VDClone = nullptr;
Expr* derivedVDE = BuildDeclRef(VDDerived);
Expand Down Expand Up @@ -2926,6 +2938,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return Visit(ICE->getSubExpr(), dfdx());
}

StmtDiff ReverseModeVisitor::VisitImplicitValueInitExpr(
const ImplicitValueInitExpr* IVIE) {
return {Clone(IVIE), Clone(IVIE)};
}

StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) {
auto baseDiff = VisitWithExplicitNoDfDx(ME->getBase());
auto* field = ME->getMemberDecl();
Expand Down Expand Up @@ -3722,47 +3739,47 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitCXXNewExpr(const clang::CXXNewExpr* CNE) {
StmtDiff initializerDiff;
if (CNE->hasInitializer())
initializerDiff = Visit(CNE->getInitializer(), dfdx());

Expr* clonedArraySizeE = nullptr;
Expr* derivedArraySizeE = nullptr;
if (CNE->getArraySize()) {
clonedArraySizeE =
Visit(clad_compat::ArraySize_GetValue(CNE->getArraySize())).getExpr();
// Array size is a non-differentiable expression, thus the original value
// should be used in both the cloned and the derived statements.
derivedArraySizeE = Clone(clonedArraySizeE);
}
Expr* clonedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), clonedArraySizeE,
initializerDiff.getExpr(), CNE->getAllocatedTypeSourceInfo());
Expr* derivedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), derivedArraySizeE,
initializerDiff.getExpr_dx(), CNE->getAllocatedTypeSourceInfo());
return {clonedNewE, derivedNewE};
}
StmtDiff initializerDiff;
if (CNE->hasInitializer())
initializerDiff = Visit(CNE->getInitializer(), dfdx());

Expr* clonedArraySizeE = nullptr;
Expr* derivedArraySizeE = nullptr;
if (CNE->getArraySize()) {
clonedArraySizeE =
Visit(clad_compat::ArraySize_GetValue(CNE->getArraySize())).getExpr();
// Array size is a non-differentiable expression, thus the original value
// should be used in both the cloned and the derived statements.
derivedArraySizeE = Clone(clonedArraySizeE);
}
Expr* clonedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), clonedArraySizeE,
initializerDiff.getExpr(), CNE->getAllocatedTypeSourceInfo());
Expr* derivedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), derivedArraySizeE,
initializerDiff.getExpr_dx(), CNE->getAllocatedTypeSourceInfo());
return {clonedNewE, derivedNewE};
}

StmtDiff
ReverseModeVisitor::VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE) {
StmtDiff argDiff = Visit(CDE->getArgument());
Expr* clonedDeleteE =
m_Sema
.ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(),
argDiff.getExpr())
.get();
Expr* derivedDeleteE =
m_Sema
.ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(),
argDiff.getExpr_dx())
.get();
// create a compound statement containing both the cloned and the derived
// delete expressions.
CompoundStmt* CS = MakeCompoundStmt({clonedDeleteE, derivedDeleteE});
m_DeallocExprs.push_back(CS);
return {nullptr, nullptr};
}
StmtDiff
ReverseModeVisitor::VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE) {
StmtDiff argDiff = Visit(CDE->getArgument());
Expr* clonedDeleteE =
m_Sema
.ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(),
argDiff.getExpr())
.get();
Expr* derivedDeleteE =
m_Sema
.ActOnCXXDelete(noLoc, CDE->isGlobalDelete(), CDE->isArrayForm(),
argDiff.getExpr_dx())
.get();
// create a compound statement containing both the cloned and the derived
// delete expressions.
CompoundStmt* CS = MakeCompoundStmt({clonedDeleteE, derivedDeleteE});
m_DeallocExprs.push_back(CS);
return {nullptr, nullptr};
}

// FIXME: Add support for differentiating calls to constructors.
// We currently assume that constructor arguments are non-differentiable.
Expand Down
25 changes: 25 additions & 0 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,41 @@ double fn5(double i, double j) {
// CHECK-NEXT: return *(_d_arr + idx1) + *(_d_arr + idx2);
// CHECK-NEXT: }

struct T {
double i;
int j;
};

double fn6 (double i) {
T* t = new T{i};
double res = t->i;
delete t;
return res;
}

// CHECK: double fn6_darg0(double i) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: T *_d_t = new T({_d_i, /*implicit*/(int)0});
// CHECK-NEXT: T *t = new T({i, /*implicit*/(int)0});
// CHECK-NEXT: double _d_res = _d_t->i;
// CHECK-NEXT: double res = t->i;
// CHECK-NEXT: delete _d_t;
// CHECK-NEXT: delete t;
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

int main() {
INIT_DIFFERENTIATE(fn1, "i");
INIT_DIFFERENTIATE(fn2, "i");
INIT_DIFFERENTIATE(fn3, "i");
INIT_DIFFERENTIATE(fn4, "i");
INIT_DIFFERENTIATE(fn5, "i");
INIT_DIFFERENTIATE(fn6, "i");

TEST_DIFFERENTIATE(fn1, 3, 5); // CHECK-EXEC: {5.00}
TEST_DIFFERENTIATE(fn2, 3, 5); // CHECK-EXEC: {5.00}
TEST_DIFFERENTIATE(fn3, 3, 5); // CHECK-EXEC: {6.00}
TEST_DIFFERENTIATE(fn4, 3, 5); // CHECK-EXEC: {16.00}
TEST_DIFFERENTIATE(fn5, 3, 5); // CHECK-EXEC: {57.00}
TEST_DIFFERENTIATE(fn6, 3); // CHECK-EXEC: {1.00}
}
40 changes: 35 additions & 5 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -395,14 +395,40 @@ double newAndDeletePointer(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: * _d_j += *_d_q;
// CHECK-NEXT: * _d_i += *_d_p;
// CHECK-NEXT: delete [] r;
// CHECK-NEXT: delete [] _d_r;
// CHECK-NEXT: delete q;
// CHECK-NEXT: delete _d_q;
// CHECK-NEXT: delete p;
// CHECK-NEXT: delete _d_p;
// CHECK-NEXT: delete q;
// CHECK-NEXT: delete _d_q;
// CHECK-NEXT: delete [] r;
// CHECK-NEXT: delete [] _d_r;
// CHECK-NEXT: }

struct T {
double x;
int y;
};

double structPointer (double x) {
T* t = new T{x};
double res = t->x;
delete t;
return res;
}

// CHECK: void structPointer_grad(double x, clad::array_ref<double> _d_x) {
// CHECK-NEXT: T *_d_t = 0;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: _d_t = new T;
// CHECK-NEXT: T *t = new T({x, /*implicit*/(int)0});
// CHECK-NEXT: double res = t->x;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: _d_t->x += _d_res;
// CHECK-NEXT: * _d_x += *_d_t.x;
// CHECK-NEXT: delete t;
// CHECK-NEXT: delete _d_t;
// CHECK-NEXT: }


#define NON_MEM_FN_TEST(var)\
res[0]=0;\
Expand Down Expand Up @@ -503,4 +529,8 @@ int main() {
double d_i = 0, d_j = 0;
d_newAndDeletePointer.execute(5, 7, &d_i, &d_j);
printf("%.2f %.2f\n", d_i, d_j); // CHECK-EXEC: 9.00 7.00

auto d_structPointer = clad::gradient(structPointer);
double d_x = 0;
d_structPointer.execute(5, &d_x);
}

0 comments on commit 3d8feec

Please sign in to comment.