diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index d05884095..8c1a4df68 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -168,13 +168,18 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { if (newPVD->getIdentifier()) m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), /*AddToContext=*/false); + else { + IdentifierInfo* newName = CreateUniqueIdentifier("arg"); + newPVD->setDeclName(newName); + m_DeclReplacements[PVD] = newPVD; + } auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) { *it = newPVD; QualType dType = derivativeFnType->getParamType(dParamTypesIdx); IdentifierInfo* dII = - CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + CreateUniqueIdentifier("_d_" + newPVD->getNameAsString()); auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, PVD->getStorageClass()); paramDerivatives.push_back(dPVD); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b90dacb99..9f1eb6f6d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1216,7 +1216,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const Expr* value = RS->getRetValue(); QualType type = value->getType(); auto* dfdf = m_Pullback; - if (dfdf && (isa(dfdf) || isa(dfdf))) { + if (dfdf && (isa(dfdf) || isa(dfdf)) && + type->isScalarType()) { ExprResult tmp = dfdf; dfdf = m_Sema .ImpCastExprToType(tmp.get(), type, @@ -1277,6 +1278,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) { + if (!dfdx()) + return StmtDiff(Clone(ILE)); QualType ILEType = ILE->getType(); llvm::SmallVector clonedExprs(ILE->getNumInits()); if (isArrayOrPointerType(ILEType)) { @@ -4499,6 +4502,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (newPVD->getIdentifier()) m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), /*AddToContext=*/false); + else { + IdentifierInfo* newName = CreateUniqueIdentifier("arg"); + newPVD->setDeclName(newName); + m_DeclReplacements[PVD] = newPVD; + } auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) { @@ -4507,7 +4515,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_DiffReq.Mode == DiffMode::experimental_pullback) { QualType dType = derivativeFnType->getParamType(dParamTypesIdx); IdentifierInfo* dII = - CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + CreateUniqueIdentifier("_d_" + newPVD->getNameAsString()); auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, PVD->getStorageClass()); paramDerivatives.push_back(dPVD); diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index 81413ba95..e430d4f03 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -535,6 +535,8 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) { auto it = m_DeclReplacements.find(VD); if (it != std::end(m_DeclReplacements)) { DRE->setDecl(it->second); + DRE->getDecl()->setReferenced(); + DRE->getDecl()->setIsUsed(); QualType NonRefQT = it->second->getType().getNonReferenceType(); if (NonRefQT != DRE->getType()) DRE->setType(NonRefQT); @@ -552,7 +554,7 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) { // FIXME: Handle the case when there are overloads found. Update // it with the best match. // - // FIXME: This is the right way to go in principe, however there is no + // FIXME: This is the right way to go in principle, however there is no // properly built decl context. // m_Sema.MarkDeclRefReferenced(clonedDRE); if (!R.isSingleResult()) diff --git a/test/ForwardMode/STLCustomDerivatives.C b/test/ForwardMode/STLCustomDerivatives.C index 7ee45affa..35a321b2c 100644 --- a/test/ForwardMode/STLCustomDerivatives.C +++ b/test/ForwardMode/STLCustomDerivatives.C @@ -426,4 +426,4 @@ int main() { TEST_DIFFERENTIATE(fnArr1, 3); // CHECK-EXEC: {3.00} TEST_DIFFERENTIATE(fnArr2, 3); // CHECK-EXEC: {108.00} TEST_DIFFERENTIATE(fnTuple1, 3, 4); // CHECK-EXEC: {2.00} -} +} \ No newline at end of file diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C index 58ae5d64d..488e9c619 100644 --- a/test/Gradient/STLCustomDerivatives.C +++ b/test/Gradient/STLCustomDerivatives.C @@ -841,4 +841,4 @@ int main() { // CHECK-NEXT: {{.*}}value_type _r0 = 0.; // CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0); // CHECK-NEXT: } -// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index b9c880682..6bc61395a 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -383,6 +383,27 @@ double fn11(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT: } +struct MyStruct{ + double a; + double b; +}; + +MyStruct fn12(MyStruct s) { + s = {2 * s.a, 2 * s.b + 2}; + return s; +} + +// CHECK: void fn12_grad(MyStruct s, MyStruct *_d_s) { +// CHECK-NEXT: MyStruct _t0 = s; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = _t0.operator_equal_forw({2 * s.a, 2 * s.b + 2}, &(*_d_s), {}); +// CHECK-NEXT: { +// CHECK-NEXT: MyStruct _r0 = {}; +// CHECK-NEXT: _t0.operator_equal_pullback({2 * s.a, 2 * s.b + 2}, {}, &(*_d_s), &_r0); +// CHECK-NEXT: (*_d_s).a += 2 * _r0.a; +// CHECK-NEXT: (*_d_s).b += 2 * _r0.b; +// CHECK-NEXT: } +// CHECK-NEXT:} + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -391,6 +412,10 @@ void print(const Tangent& t) { } } +void print(const MyStruct& s) { + printf("{%.2f, %.2f}\n", s.a, s.b); +} + int main() { pairdd p(3, 5), d_p; double i = 3, d_i, d_j; @@ -425,6 +450,10 @@ int main() { TEST_GRADIENT(fn9, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00, 1.00, 5.00, 10.00} TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 5, 10, &d_i, &d_j); // CHECK-EXEC: {1.00, 0.00} TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, -14, &d_i, &d_j); // CHECK-EXEC: {1.00, -1.00} + MyStruct s = {1.0, 2.0}, d_s = {1.0, 1.0}; + auto fn12_test = clad::gradient(fn12); + fn12_test.execute(s, &d_s); + print(d_s); // CHECK-EXEC: {2.00, 2.00} } // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) { @@ -546,4 +575,31 @@ int main() { // CHECK-NEXT: *_d_x += _d_y; // CHECK-NEXT: (*_d_t).data[0] += _d_y; // CHECK-NEXT: } -// CHECK-NEXT: } \ No newline at end of file +// CHECK-NEXT: } + +// CHECK: inline constexpr void operator_equal_pullback(MyStruct &&arg, MyStruct _d_y, MyStruct *_d_this, MyStruct *_d_arg) noexcept { +// CHECK-NEXT: double _t0 = this->a; +// CHECK-NEXT: this->a = arg.a; +// CHECK-NEXT: double _t1 = this->b; +// CHECK-NEXT: this->b = arg.b; +// CHECK-NEXT: { +// CHECK-NEXT: this->b = _t1; +// CHECK-NEXT: double _r_d1 = (*_d_this).b; +// CHECK-NEXT: (*_d_this).b = 0.; +// CHECK-NEXT: (*_d_arg).b += _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: this->a = _t0; +// CHECK-NEXT: double _r_d0 = (*_d_this).a; +// CHECK-NEXT: (*_d_this).a = 0.; +// CHECK-NEXT: (*_d_arg).a += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT:} + +// CHECK: inline constexpr clad::ValueAndAdjoint operator_equal_forw(MyStruct &&arg, MyStruct *_d_this, MyStruct &&_d_arg) noexcept { +// CHECK-NEXT: double _t0 = this->a; +// CHECK-NEXT: this->a = arg.a; +// CHECK-NEXT: double _t1 = this->b; +// CHECK-NEXT: this->b = arg.b; +// CHECK-NEXT: return {*this, (*_d_this)}; +// CHECK-NEXT:} \ No newline at end of file