Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize adjoints of aggregate types with init lists #1163

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,15 +585,6 @@ template <typename T, ::std::size_t N, typename U>
void size_pullback(::std::array<T, N>* /*a*/, U /*d_y*/,
::std::array<T, N>* /*d_a*/) noexcept {}
// Forward-sweep helper for constructing a std::array from another one:
// copies the primal source `arr` and the adjoint source `d_arr`, then returns
// the two copies together as a value/adjoint pair.
template <typename T, ::std::size_t N>
::clad::ValueAndAdjoint<::std::array<T, N>, ::std::array<T, N>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::array<T, N>>,
const ::std::array<T, N>& arr,
const ::std::array<T, N>& d_arr) {
// Take local copies first so the returned pair never aliases the inputs.
::std::array<T, N> primalCopy(arr);
::std::array<T, N> adjointCopy(d_arr);
return {primalCopy, adjointCopy};
}
template <typename T, ::std::size_t N>
void constructor_pullback(::std::array<T, N>* a, const ::std::array<T, N>& arr,
::std::array<T, N>* d_a, ::std::array<T, N>* d_arr) {
for (size_t i = 0; i < N; ++i)
Expand Down
50 changes: 38 additions & 12 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
#include <clang/AST/OperationKinds.h>
#include <clang/Sema/Ownership.h>

#include "llvm/ADT/SmallString.h"
#include "llvm/Support/SaveAndRestore.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: included header SmallString.h is not used directly [misc-include-cleaner]

Suggested change
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SaveAndRestore.h"

#include <llvm/Support/raw_ostream.h>

#include <algorithm>
#include <numeric>
Expand Down Expand Up @@ -1309,7 +1311,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
clonedExprs[i] = Visit(ILE->getInit(i), member_acess).getExpr();
}
Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get();
return StmtDiff(clonedILE);

const CXXRecordDecl* RD = ILEType->getAsCXXRecordDecl();
Expr* adjointInit = nullptr;
if (RD && RD->isAggregate()) {
llvm::SmallVector<Expr*, 4> adjParams;
for (const FieldDecl* FD : RD->fields())
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
adjParams.push_back(getZeroInit(FD->getType()));
adjointInit = m_Sema.ActOnInitList(noLoc, adjParams, noLoc).get();
}
return StmtDiff(clonedILE, nullptr, adjointInit);
}

// FIXME: This is a makeshift arrangement to differentiate an InitListExpr
Expand Down Expand Up @@ -2753,6 +2764,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

ConstructorPullbackCallInfo constructorPullbackInfo;

bool isConstructInit =
VD->getInit() && isa<CXXConstructExpr>(VD->getInit()->IgnoreImplicit());

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
if (promoteToFnScope)
Expand Down Expand Up @@ -2798,7 +2812,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

if (VDType->isStructureOrClassType()) {
if (isConstructInit) {
m_TrackConstructorPullbackInfo = true;
initDiff = Visit(VD->getInit());
m_TrackConstructorPullbackInfo = false;
Expand Down Expand Up @@ -2870,13 +2884,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
}

if (VD->getInit()) {
if (VDType->isStructureOrClassType()) {
if (!initDiff.getExpr())
initDiff = Visit(VD->getInit());
} else
initDiff = Visit(VD->getInit(), derivedE);
}
if (VD->getInit() && !isConstructInit)
initDiff = Visit(VD->getInit(), derivedE);

// If we are differentiating `VarDecl` corresponding to a local variable
// inside a loop, then we need to reset it to 0 at each iteration.
Expand Down Expand Up @@ -4155,7 +4164,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff
ReverseModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {

llvm::SmallVector<Expr*, 4> primalArgs;
llvm::SmallVector<Expr*, 4> adjointArgs;
llvm::SmallVector<Expr*, 4> reverseForwAdjointArgs;
Expand Down Expand Up @@ -4214,8 +4222,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Try to create a pullback constructor call
llvm::SmallVector<Expr*, 4> pullbackArgs;
QualType recordType =
m_Context.getRecordType(CE->getConstructor()->getParent());
const CXXRecordDecl* RD = CE->getConstructor()->getParent();
QualType recordType = m_Context.getRecordType(RD);
QualType recordPointerType = m_Context.getPointerType(recordType);
// thisE = object being created by this constructor call.
// dThisE = adjoint of the object being created by this constructor call.
Expand Down Expand Up @@ -4274,6 +4282,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (Expr* customReverseForwFnCall = BuildCallToCustomForwPassFn(
CE->getConstructor(), primalArgs, reverseForwAdjointArgs,
/*baseExpr=*/nullptr)) {
if (RD->isAggregate()) {
SmallString<128> Name_class;
PetroZarytskyi marked this conversation as resolved.
Show resolved Hide resolved
llvm::raw_svector_ostream OS_class(Name_class);
PetroZarytskyi marked this conversation as resolved.
Show resolved Hide resolved
RD->getNameForDiagnostic(OS_class, m_Context.getPrintingPolicy(),
/*qualified=*/true);
diag(DiagnosticsEngine::Warning, CE->getBeginLoc(),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the place where we find the custom forward reverse function and diagnose that we do not need it? If not, we should move the check there and point to the declaration itself with a DiagnosticsEngine::Note(CE->getDecl()->getBeginLoc()). And of course we should add a test for it...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept the diagnostics in the same place but now the location comes from the declaration. Is it better now?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need both a warning at the call site and a note on the definition. We also need a test.

"'%0' is an aggregate type and its constructor does not require a "
"user-defined forward sweep function",
{OS_class.str()});
const FunctionDecl* constr_forw =
cast<CallExpr>(customReverseForwFnCall)->getDirectCallee();
SmallString<128> Name_forw;
llvm::raw_svector_ostream OS_forw(Name_forw);
constr_forw->getNameForDiagnostic(
OS_forw, m_Context.getPrintingPolicy(), /*qualified=*/true);
diag(DiagnosticsEngine::Note, constr_forw->getBeginLoc(),
"'%0' is defined here", {OS_forw.str()});
}
Expr* callRes = StoreAndRef(customReverseForwFnCall);
Expr* val =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
Expand Down
10 changes: 9 additions & 1 deletion lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/StmtClone.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Lex/Preprocessor.h"
Expand All @@ -26,6 +26,7 @@
#include "clang/Sema/Template.h"

#include <algorithm>
#include <llvm/ADT/SmallVector.h>
#include <numeric>

#include "clad/Differentiator/Compatibility.h"
Expand Down Expand Up @@ -418,6 +419,13 @@ namespace clad {
Expr* zero = ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0);
return m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
}
if (const auto* RD = T->getAsCXXRecordDecl())
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
if (RD->hasDefinition() && !RD->isUnion() && RD->isAggregate()) {
llvm::SmallVector<Expr*, 4> adjParams;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: no header providing "llvm::SmallVector" is directly included [misc-include-cleaner]

lib/Differentiator/VisitorBase.cpp:28:

- #include <numeric>
+ #include <llvm/ADT/SmallVector.h>
+ #include <numeric>

for (const FieldDecl* FD : RD->fields())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: no header providing "clang::FieldDecl" is directly included [misc-include-cleaner]

        for (const FieldDecl* FD : RD->fields())
                   ^

adjParams.push_back(getZeroInit(FD->getType()));
return m_Sema.ActOnInitList(noLoc, adjParams, noLoc).get();
}
return m_Sema.ActOnInitList(noLoc, {}, noLoc).get();
}

Expand Down
33 changes: 16 additions & 17 deletions test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ int main() {
// CHECK-NEXT: clad::tape<std::array<double, 3> > _t2 = {};
// CHECK-NEXT: clad::tape<double> _t3 = {};
// CHECK-NEXT: clad::tape<std::array<double, 3> > _t4 = {};
// CHECK-NEXT: std::array<double, 3> _d_a({});
// CHECK-NEXT: std::array<double, 3> _d_a({{.*}});
// CHECK-NEXT: std::array<double, 3> a;
// CHECK-NEXT: std::array<double, 3> _t0 = a;
// CHECK-NEXT: {{.*}}fill_reverse_forw(&a, x, &_d_a, *_d_x);
Expand Down Expand Up @@ -544,7 +544,7 @@ int main() {
// CHECK-NEXT: }

// CHECK: void fn16_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 2> _d_a({});
// CHECK-NEXT: std::array<double, 2> _d_a({{.*}});
// CHECK-NEXT: std::array<double, 2> a;
// CHECK-NEXT: std::array<double, 2> _t0 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r0);
Expand All @@ -554,7 +554,7 @@ int main() {
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t4 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1);
// CHECK-NEXT: double _t5 = _t4.value;
// CHECK-NEXT: _t4.value = y;
// CHECK-NEXT: std::array<double, 3> _d__b({});
// CHECK-NEXT: std::array<double, 3> _d__b({{.*}});
// CHECK-NEXT: std::array<double, 3> _b0;
// CHECK-NEXT: std::array<double, 3> _t6 = _b0;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t7 = {{.*}}operator_subscript_reverse_forw(&_b0, 0, &_d__b, _r2);
Expand All @@ -568,23 +568,22 @@ int main() {
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t13 = {{.*}}operator_subscript_reverse_forw(&_b0, 2, &_d__b, _r4);
// CHECK-NEXT: double _t14 = _t13.value;
// CHECK-NEXT: _t13.value = x * x;
// CHECK-NEXT: ::clad::ValueAndAdjoint< ::std::array<double, {{3U|3UL}}>, ::std::array<double, {{3U|3UL}}> > _t15 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag<array<double, 3> >(), _b0, _d__b);
// CHECK-NEXT: std::array<double, 3> _d_b = _t15.adjoint;
// CHECK-NEXT: const std::array<double, 3> b = _t15.value;
// CHECK-NEXT: std::array<double, 2> _t18 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t19 = {{.*}}back_reverse_forw(&a, &_d_a);
// CHECK-NEXT: std::array<double, 3> _d_b = {{.*}};
// CHECK-NEXT: const std::array<double, 3> b = _b0;
// CHECK-NEXT: std::array<double, 2> _t17 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t18 = {{.*}}back_reverse_forw(&a, &_d_a);
// CHECK-NEXT: std::array<double, 3> _t19 = b;
// CHECK-NEXT: {{.*}}value_type _t16 = b.front();
// CHECK-NEXT: std::array<double, 3> _t20 = b;
// CHECK-NEXT: {{.*}}value_type _t17 = b.front();
// CHECK-NEXT: {{.*}}value_type _t15 = b.at(2);
// CHECK-NEXT: std::array<double, 3> _t21 = b;
// CHECK-NEXT: {{.*}}value_type _t16 = b.at(2);
// CHECK-NEXT: std::array<double, 3> _t22 = b;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}back_pullback(&_t18, 1 * _t16 * _t17, &_d_a);
// CHECK-NEXT: {{.*}}front_pullback(&_t20, _t19.value * 1 * _t16, &_d_b);
// CHECK-NEXT: {{.*}}back_pullback(&_t17, 1 * _t15 * _t16, &_d_a);
// CHECK-NEXT: {{.*}}front_pullback(&_t19, _t18.value * 1 * _t15, &_d_b);
// CHECK-NEXT: {{.*}}size_type _r5 = {{0U|0UL}};
// CHECK-NEXT: {{.*}}at_pullback(&_t21, 2, _t19.value * _t17 * 1, &_d_b, &_r5);
// CHECK-NEXT: {{.*}}at_pullback(&_t20, 2, _t18.value * _t16 * 1, &_d_b, &_r5);
// CHECK-NEXT: {{.*}}size_type _r6 = {{0U|0UL}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t22, 1, 1, &_d_b, &_r6);
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t21, 1, 1, &_d_b, &_r6);
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}}constructor_pullback(&b, _b0, &_d_b, &_d__b);
// CHECK-NEXT: {
Expand Down Expand Up @@ -629,7 +628,7 @@ int main() {
// CHECK-NEXT: }

// CHECK: void fn17_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 50> _d_a({});
// CHECK-NEXT: std::array<double, 50> _d_a({{.*}});
// CHECK-NEXT: std::array<double, 50> a;
// CHECK-NEXT: std::array<double, 50> _t0 = a;
// CHECK-NEXT: {{.*}}fill_reverse_forw(&a, y + x + x, &_d_a, _r0);
Expand All @@ -653,7 +652,7 @@ int main() {
// CHECK-NEXT: }

// CHECK: void fn18_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 2> _d_a({});
// CHECK-NEXT: std::array<double, 2> _d_a({{.*}});
// CHECK-NEXT: std::array<double, 2> a;
// CHECK-NEXT: std::array<double, 2> _t0 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r0);
Expand Down
60 changes: 56 additions & 4 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s
// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out -Xclang -verify 2>&1 | %filecheck %s
// RUN: ./UserDefinedTypes.out | %filecheck_exec %s
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out
// RUN: ./UserDefinedTypes.out | %filecheck_exec %s
Expand Down Expand Up @@ -395,10 +395,10 @@ MyStruct fn12(MyStruct s) {

// CHECK: void fn12_grad(MyStruct s, MyStruct *_d_s) {
// CHECK-NEXT: MyStruct _t0 = s;
// CHECK-NEXT: clad::ValueAndAdjoint<MyStruct &, MyStruct &> _t1 = _t0.operator_equal_forw({2 * s.a, 2 * s.b + 2}, &(*_d_s), {});
// CHECK-NEXT: clad::ValueAndAdjoint<MyStruct &, MyStruct &> _t1 = _t0.operator_equal_forw({2 * s.a, 2 * s.b + 2}, &(*_d_s), {0., 0.});
// CHECK-NEXT: {
// CHECK-NEXT: MyStruct _r0 = {};
// CHECK-NEXT: _t0.operator_equal_pullback({2 * s.a, 2 * s.b + 2}, {}, &(*_d_s), &_r0);
// CHECK-NEXT: MyStruct _r0 = {0., 0.};
// CHECK-NEXT: _t0.operator_equal_pullback({2 * s.a, 2 * s.b + 2}, {0., 0.}, &(*_d_s), &_r0);
// CHECK-NEXT: (*_d_s).a += 2 * _r0.a;
// CHECK-NEXT: (*_d_s).b += 2 * _r0.b;
// CHECK-NEXT: }
Expand Down Expand Up @@ -467,6 +467,53 @@ void fn13(double *x, double *y, int size)
// CHECK-NEXT: }
// CHECK-NEXT:}

// Test input: builds a local MyStruct from a braced init list whose two
// elements depend on y and x respectively, then returns the product of its
// `a` and `b` members.
double fn14(double x, double y) {
MyStruct s = {2 * y, 3 * x + 2};
return s.a * s.b;
}

// CHECK: void fn14_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: MyStruct _d_s = {0., 0.};
// CHECK-NEXT: MyStruct s = {2 * y, 3 * x + 2};
// CHECK-NEXT: {
// CHECK-NEXT: _d_s.a += 1 * s.b;
// CHECK-NEXT: _d_s.b += s.a * 1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: *_d_y += 2 * _d_s.a;
// CHECK-NEXT: *_d_x += 3 * _d_s.b;
// CHECK-NEXT: }
// CHECK-NEXT:}

// Minimal fixed-size array wrapper used by the tests below: its only member
// is a built-in array of N elements of type T, and it declares no
// constructors of its own.
template <typename T, std::size_t N>
struct SimpleArray {
T elements[N]; // Aggregate initialization
};

namespace clad {
namespace custom_derivatives {
namespace class_functions {
// User-provided forward-sweep function for the default constructor of
// SimpleArray<double, N>. It default-initializes a primal object and an
// adjoint object (their elements are not explicitly set here) and returns
// them as a value/adjoint pair. The diagnostic directive on the return-type
// line below is tied to that exact line, so keep it in place.
template<::std::size_t N>
::clad::ValueAndAdjoint<SimpleArray<double, N>, SimpleArray<double, N>> // expected-note {{'clad::custom_derivatives::class_functions::constructor_reverse_forw<2}}{{' is defined here}}
constructor_reverse_forw(::clad::ConstructorReverseForwTag<SimpleArray<double, N>>) {
SimpleArray<double, N> a;
SimpleArray<double, N> d_a;
return {a, d_a};
}
}}}

// Test input: default-constructs a SimpleArray<double, 2> and returns its
// first element. The diagnostic directive on the declaration line is tied to
// that exact line, so keep the declaration and its trailing comment together.
double fn15(double x, double y) {
SimpleArray<double, 2> arr; // expected-warning {{'SimpleArray<double, 2>' is an aggregate type and its constructor does not require a user-defined forward sweep function}}
return arr.elements[0];
}

// CHECK:void fn15_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: ::clad::ValueAndAdjoint<SimpleArray<double, {{2U|2UL|2ULL}}>, SimpleArray<double, {{2U|2UL|2ULL}}> > _t0 = clad::custom_derivatives::class_functions::constructor_reverse_forw(clad::ConstructorReverseForwTag<SimpleArray<double, 2> >());
// CHECK-NEXT: SimpleArray<double, 2> _d_arr(_t0.adjoint);
// CHECK-NEXT: SimpleArray<double, 2> arr(_t0.value);
// CHECK-NEXT: _d_arr.elements[0] += 1;
// CHECK-NEXT:}

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand Down Expand Up @@ -535,6 +582,11 @@ int main() {
fn13_test.execute(x, y, 3, d_x, d_y);
printArray(d_x, size); // CHECK-EXEC: {2.00, 2.00, 2.00}
printArray(d_y, size); // CHECK-EXEC: {0.00, 0.00, 0.00}

INIT_GRADIENT(fn14);
TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {30.00, 22.00}

INIT_GRADIENT(fn15);
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down
Loading
Loading