-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initialize adjoints of aggregate types with init lists #1163
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,9 @@ | |
#include <clang/AST/OperationKinds.h> | ||
#include <clang/Sema/Ownership.h> | ||
|
||
#include "llvm/ADT/SmallString.h" | ||
#include "llvm/Support/SaveAndRestore.h" | ||
#include <llvm/Support/raw_ostream.h> | ||
|
||
#include <algorithm> | ||
#include <numeric> | ||
|
@@ -1309,7 +1311,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
clonedExprs[i] = Visit(ILE->getInit(i), member_acess).getExpr(); | ||
} | ||
Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get(); | ||
return StmtDiff(clonedILE); | ||
|
||
const CXXRecordDecl* RD = ILEType->getAsCXXRecordDecl(); | ||
Expr* adjointInit = nullptr; | ||
if (RD && RD->isAggregate()) { | ||
llvm::SmallVector<Expr*, 4> adjParams; | ||
for (const FieldDecl* FD : RD->fields()) | ||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
adjParams.push_back(getZeroInit(FD->getType())); | ||
adjointInit = m_Sema.ActOnInitList(noLoc, adjParams, noLoc).get(); | ||
} | ||
return StmtDiff(clonedILE, nullptr, adjointInit); | ||
} | ||
|
||
// FIXME: This is a makeshift arrangement to differentiate an InitListExpr | ||
|
@@ -2753,6 +2764,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
|
||
ConstructorPullbackCallInfo constructorPullbackInfo; | ||
|
||
bool isConstructInit = | ||
VD->getInit() && isa<CXXConstructExpr>(VD->getInit()->IgnoreImplicit()); | ||
|
||
// VDDerivedInit now serves two purposes -- as the initial derivative value | ||
// or the size of the derivative array -- depending on the primal type. | ||
if (promoteToFnScope) | ||
|
@@ -2798,7 +2812,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
VDDerivedInit = initDiff.getForwSweepExpr_dx(); | ||
} | ||
|
||
if (VDType->isStructureOrClassType()) { | ||
if (isConstructInit) { | ||
m_TrackConstructorPullbackInfo = true; | ||
initDiff = Visit(VD->getInit()); | ||
m_TrackConstructorPullbackInfo = false; | ||
|
@@ -2870,13 +2884,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); | ||
} | ||
|
||
if (VD->getInit()) { | ||
if (VDType->isStructureOrClassType()) { | ||
if (!initDiff.getExpr()) | ||
initDiff = Visit(VD->getInit()); | ||
} else | ||
initDiff = Visit(VD->getInit(), derivedE); | ||
} | ||
if (VD->getInit() && !isConstructInit) | ||
initDiff = Visit(VD->getInit(), derivedE); | ||
|
||
// If we are differentiating `VarDecl` corresponding to a local variable | ||
// inside a loop, then we need to reset it to 0 at each iteration. | ||
|
@@ -4155,7 +4164,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
|
||
StmtDiff | ||
ReverseModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) { | ||
|
||
llvm::SmallVector<Expr*, 4> primalArgs; | ||
llvm::SmallVector<Expr*, 4> adjointArgs; | ||
llvm::SmallVector<Expr*, 4> reverseForwAdjointArgs; | ||
|
@@ -4214,8 +4222,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
|
||
// Try to create a pullback constructor call | ||
llvm::SmallVector<Expr*, 4> pullbackArgs; | ||
QualType recordType = | ||
m_Context.getRecordType(CE->getConstructor()->getParent()); | ||
const CXXRecordDecl* RD = CE->getConstructor()->getParent(); | ||
QualType recordType = m_Context.getRecordType(RD); | ||
QualType recordPointerType = m_Context.getPointerType(recordType); | ||
// thisE = object being created by this constructor call. | ||
// dThisE = adjoint of the object being created by this constructor call. | ||
|
@@ -4274,6 +4282,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |
if (Expr* customReverseForwFnCall = BuildCallToCustomForwPassFn( | ||
CE->getConstructor(), primalArgs, reverseForwAdjointArgs, | ||
/*baseExpr=*/nullptr)) { | ||
if (RD->isAggregate()) { | ||
SmallString<128> Name_class; | ||
PetroZarytskyi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
llvm::raw_svector_ostream OS_class(Name_class); | ||
PetroZarytskyi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
RD->getNameForDiagnostic(OS_class, m_Context.getPrintingPolicy(), | ||
/*qualified=*/true); | ||
diag(DiagnosticsEngine::Warning, CE->getBeginLoc(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is here the place where we found the custom forward reverse function and we diagnose we do not need it? If not we should move the check there, and point to the declaration itself with a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I kept the diagnostics in the same place but now the location comes from the declaration. Is it better now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need both a warning on the call site and a note on our he definition. We also need a test. |
||
"'%0' is an aggregate type and its constructor does not require a " | ||
"user-defined forward sweep function", | ||
{OS_class.str()}); | ||
const FunctionDecl* constr_forw = | ||
cast<CallExpr>(customReverseForwFnCall)->getDirectCallee(); | ||
SmallString<128> Name_forw; | ||
llvm::raw_svector_ostream OS_forw(Name_forw); | ||
constr_forw->getNameForDiagnostic( | ||
OS_forw, m_Context.getPrintingPolicy(), /*qualified=*/true); | ||
diag(DiagnosticsEngine::Note, constr_forw->getBeginLoc(), | ||
"'%0' is defined here", {OS_forw.str()}); | ||
} | ||
Expr* callRes = StoreAndRef(customReverseForwFnCall); | ||
Expr* val = | ||
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,8 @@ | |
#include "clad/Differentiator/ErrorEstimator.h" | ||
#include "clad/Differentiator/Sins.h" | ||
#include "clad/Differentiator/StmtClone.h" | ||
|
||
#include "clang/AST/ASTContext.h" | ||
#include "clang/AST/Decl.h" | ||
#include "clang/AST/Expr.h" | ||
#include "clang/AST/TemplateBase.h" | ||
#include "clang/Lex/Preprocessor.h" | ||
|
@@ -26,6 +26,7 @@ | |
#include "clang/Sema/Template.h" | ||
|
||
#include <algorithm> | ||
#include <llvm/ADT/SmallVector.h> | ||
#include <numeric> | ||
|
||
#include "clad/Differentiator/Compatibility.h" | ||
|
@@ -418,6 +419,13 @@ namespace clad { | |
Expr* zero = ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0); | ||
return m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get(); | ||
} | ||
if (const auto* RD = T->getAsCXXRecordDecl()) | ||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (RD->hasDefinition() && !RD->isUnion() && RD->isAggregate()) { | ||
llvm::SmallVector<Expr*, 4> adjParams; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: no header providing "llvm::SmallVector" is directly included [misc-include-cleaner] lib/Differentiator/VisitorBase.cpp:28: - #include <numeric>
+ #include <llvm/ADT/SmallVector.h>
+ #include <numeric> |
||
for (const FieldDecl* FD : RD->fields()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: no header providing "clang::FieldDecl" is directly included [misc-include-cleaner] for (const FieldDecl* FD : RD->fields())
^ |
||
adjParams.push_back(getZeroInit(FD->getType())); | ||
return m_Sema.ActOnInitList(noLoc, adjParams, noLoc).get(); | ||
} | ||
return m_Sema.ActOnInitList(noLoc, {}, noLoc).get(); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: included header SmallString.h is not used directly [misc-include-cleaner]