From f00afbcdae21b26a6f2d2483670abb8cd4f61b2e Mon Sep 17 00:00:00 2001 From: kchristin Date: Wed, 20 Nov 2024 15:51:54 +0200 Subject: [PATCH] Move setting of param name in pullback creation --- lib/Differentiator/ReverseModeVisitor.cpp | 8 ++++---- test/Gradient/STLCustomDerivatives.C | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 8fdd41809..fb52ca874 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1893,10 +1893,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const Expr* arg = CE->getArg(i); const auto* PVD = FD->getParamDecl( i - static_cast(isMethodOperatorCall)); - if (PVD->getType()->isRValueReferenceType()) { - IdentifierInfo* RValueName = CreateUniqueIdentifier("_r"); - const_cast(PVD)->setDeclName(RValueName); - } StmtDiff argDiff{}; // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly @@ -4781,6 +4777,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } for (auto* PVD : m_DiffReq->parameters()) { + if (PVD->getNameAsString().empty()) { + IdentifierInfo* newName = CreateUniqueIdentifier("_r"); + const_cast(PVD)->setDeclName(newName); + } auto* newPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C index a2612c60a..488e9c619 100644 --- a/test/Gradient/STLCustomDerivatives.C +++ b/test/Gradient/STLCustomDerivatives.C @@ -817,16 +817,16 @@ int main() { // CHECK-NEXT: std::vector _d_a({}); // CHECK-NEXT: std::vector a; // CHECK-NEXT: std::vector _t0 = a; -// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r1); +// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r0); // CHECK-NEXT: std::vector _t1 = a; -// CHECK-NEXT: {{.*}}ValueAndAdjoint _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2); +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r1); // CHECK-NEXT: double _t3 = _t2.value; // CHECK-NEXT: _t2.value = x * x; // CHECK-NEXT: std::vector _t4 = a; -// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r3); +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2); // CHECK-NEXT: { -// CHECK-NEXT: {{.*}}size_type _r3 = 0{{.*}}; -// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r3); +// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}}; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r2); // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: _t2.value = _t3; @@ -834,11 +834,11 @@ int main() { // CHECK-NEXT: _t2.adjoint = 0{{.*}}; // CHECK-NEXT: *_d_x += _r_d0 * x; // CHECK-NEXT: *_d_x += x * _r_d0; -// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}}; -// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r2); +// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}}; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r1); // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: {{.*}}value_type _r1 = 0.; -// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r1); +// CHECK-NEXT: {{.*}}value_type _r0 = 0.; +// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0); // CHECK-NEXT: } // CHECK-NEXT: } \ No newline at end of file