From 81713731fbe645f3c65adccf3221210a4e7d63cd Mon Sep 17 00:00:00 2001 From: kchristin Date: Wed, 20 Nov 2024 22:03:35 +0200 Subject: [PATCH] Change name of cloned r value ref args and not originals --- .../ReverseModeForwPassVisitor.cpp | 7 ++++++- lib/Differentiator/ReverseModeVisitor.cpp | 11 ++++++----- test/Gradient/UserDefinedTypes.C | 16 ++++++++-------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index d05884095..8c1a4df68 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -168,13 +168,18 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { if (newPVD->getIdentifier()) m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), /*AddToContext=*/false); + else { + IdentifierInfo* newName = CreateUniqueIdentifier("arg"); + newPVD->setDeclName(newName); + m_DeclReplacements[PVD] = newPVD; + } auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) { *it = newPVD; QualType dType = derivativeFnType->getParamType(dParamTypesIdx); IdentifierInfo* dII = - CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + CreateUniqueIdentifier("_d_" + newPVD->getNameAsString()); auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, PVD->getStorageClass()); paramDerivatives.push_back(dPVD); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 12afd756c..b019034ad 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -4779,10 +4779,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } for (auto* PVD : m_DiffReq->parameters()) { - if (PVD->getNameAsString().empty()) { - IdentifierInfo* newName = CreateUniqueIdentifier("_r"); - const_cast(PVD)->setDeclName(newName); - } auto* newPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); @@ -4791,6 +4787,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (newPVD->getIdentifier()) m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), /*AddToContext=*/false); + else { + IdentifierInfo* newName = CreateUniqueIdentifier("arg"); + newPVD->setDeclName(newName); + m_DeclReplacements[PVD] = newPVD; + } auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) { @@ -4799,7 +4800,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_DiffReq.Mode == DiffMode::experimental_pullback) { QualType dType = derivativeFnType->getParamType(dParamTypesIdx); IdentifierInfo* dII = - CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + CreateUniqueIdentifier("_d_" + newPVD->getNameAsString()); auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, PVD->getStorageClass()); paramDerivatives.push_back(dPVD); diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index 846b9fee5..6bc61395a 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -577,29 +577,29 @@ int main() { // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK: inline constexpr void operator_equal_pullback(MyStruct &&_r0, MyStruct _d_y, MyStruct *_d_this, MyStruct *_d__r0) noexcept { +// CHECK: inline constexpr void operator_equal_pullback(MyStruct &&arg, MyStruct _d_y, MyStruct *_d_this, MyStruct *_d_arg) noexcept { // CHECK-NEXT: double _t0 = this->a; -// CHECK-NEXT: this->a = _r0.a; +// CHECK-NEXT: this->a = arg.a; // CHECK-NEXT: double _t1 = this->b; -// CHECK-NEXT: this->b = _r0.b; +// CHECK-NEXT: this->b = arg.b; // CHECK-NEXT: { // CHECK-NEXT: this->b = _t1; // CHECK-NEXT: double _r_d1 = (*_d_this).b; // CHECK-NEXT: (*_d_this).b = 0.; -// CHECK-NEXT: (*_d__r0).b += _r_d1; +// CHECK-NEXT: (*_d_arg).b += _r_d1; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: this->a = _t0; // CHECK-NEXT: double _r_d0 = (*_d_this).a; // CHECK-NEXT: (*_d_this).a = 0.; -// CHECK-NEXT: (*_d__r0).a += _r_d0; +// CHECK-NEXT: (*_d_arg).a += _r_d0; // CHECK-NEXT: } // CHECK-NEXT:} -// CHECK: inline constexpr clad::ValueAndAdjoint operator_equal_forw(MyStruct &&_r0, MyStruct *_d_this, MyStruct &&_d__r0) noexcept { +// CHECK: inline constexpr clad::ValueAndAdjoint operator_equal_forw(MyStruct &&arg, MyStruct *_d_this, MyStruct &&_d_arg) noexcept { // CHECK-NEXT: double _t0 = this->a; -// CHECK-NEXT: this->a = _r0.a; +// CHECK-NEXT: this->a = arg.a; // CHECK-NEXT: double _t1 = this->b; -// CHECK-NEXT: this->b = _r0.b; +// CHECK-NEXT: this->b = arg.b; // CHECK-NEXT: return {*this, (*_d_this)}; // CHECK-NEXT:} \ No newline at end of file