Skip to content

Commit

Permalink
Move setting of param name in pullback creation
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 21, 2024
1 parent 015b255 commit 11f024a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,10 +1673,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const Expr* arg = CE->getArg(i);
const auto* PVD = FD->getParamDecl(
i - static_cast<unsigned long>(isMethodOperatorCall));
if (PVD->getType()->isRValueReferenceType()) {
IdentifierInfo* RValueName = CreateUniqueIdentifier("_r");
const_cast<ParmVarDecl*>(PVD)->setDeclName(RValueName);
}
StmtDiff argDiff{};
// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
Expand Down Expand Up @@ -4495,6 +4491,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

for (auto* PVD : m_DiffReq->parameters()) {
if (PVD->getNameAsString().empty()) {
IdentifierInfo* newName = CreateUniqueIdentifier("_r");
const_cast<ParmVarDecl*>(PVD)->setDeclName(newName);
}
auto* newPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(),
PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo());
Expand Down
18 changes: 9 additions & 9 deletions test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -817,28 +817,28 @@ int main() {
// CHECK-NEXT: std::vector<double> _d_a({});
// CHECK-NEXT: std::vector<double> a;
// CHECK-NEXT: std::vector<double> _t0 = a;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r1);
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r0);
// CHECK-NEXT: std::vector<double> _t1 = a;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2);
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r1);
// CHECK-NEXT: double _t3 = _t2.value;
// CHECK-NEXT: _t2.value = x * x;
// CHECK-NEXT: std::vector<double> _t4 = a;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r3);
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r3 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r3);
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t2.value = _t3;
// CHECK-NEXT: double _r_d0 = _t2.adjoint;
// CHECK-NEXT: _t2.adjoint = 0{{.*}};
// CHECK-NEXT: *_d_x += _r_d0 * x;
// CHECK-NEXT: *_d_x += x * _r_d0;
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r2);
// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}value_type _r1 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r1);
// CHECK-NEXT: {{.*}}value_type _r0 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 11f024a

Please sign in to comment.