Skip to content

Commit

Permalink
Change name of cloned r value ref args and not originals
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 21, 2024
1 parent 62f93fd commit e9f3cb7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
7 changes: 6 additions & 1 deletion lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,18 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
if (newPVD->getIdentifier())
m_Sema.PushOnScopeChains(newPVD, getCurrentScope(),
/*AddToContext=*/false);
else {
IdentifierInfo* newName = CreateUniqueIdentifier("arg");
newPVD->setDeclName(newName);
m_DeclReplacements[PVD] = newPVD;
}

auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD);
if (it != std::end(diffParams)) {
*it = newPVD;
QualType dType = derivativeFnType->getParamType(dParamTypesIdx);
IdentifierInfo* dII =
CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
CreateUniqueIdentifier("_d_" + newPVD->getNameAsString());
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType,
PVD->getStorageClass());
paramDerivatives.push_back(dPVD);
Expand Down
11 changes: 6 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4493,10 +4493,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

for (auto* PVD : m_DiffReq->parameters()) {
if (PVD->getNameAsString().empty()) {
IdentifierInfo* newName = CreateUniqueIdentifier("_r");
const_cast<ParmVarDecl*>(PVD)->setDeclName(newName);
}
auto* newPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(),
PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo());
Expand All @@ -4505,6 +4501,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (newPVD->getIdentifier())
m_Sema.PushOnScopeChains(newPVD, getCurrentScope(),
/*AddToContext=*/false);
else {
IdentifierInfo* newName = CreateUniqueIdentifier("arg");
newPVD->setDeclName(newName);
m_DeclReplacements[PVD] = newPVD;
}

auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD);
if (it != std::end(diffParams)) {
Expand All @@ -4513,7 +4514,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_DiffReq.Mode == DiffMode::experimental_pullback) {
QualType dType = derivativeFnType->getParamType(dParamTypesIdx);
IdentifierInfo* dII =
CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
CreateUniqueIdentifier("_d_" + newPVD->getNameAsString());
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType,
PVD->getStorageClass());
paramDerivatives.push_back(dPVD);
Expand Down
16 changes: 8 additions & 8 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -577,29 +577,29 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: inline constexpr void operator_equal_pullback(MyStruct &&_r0, MyStruct _d_y, MyStruct *_d_this, MyStruct *_d__r0) noexcept {
// CHECK: inline constexpr void operator_equal_pullback(MyStruct &&arg, MyStruct _d_y, MyStruct *_d_this, MyStruct *_d_arg) noexcept {
// CHECK-NEXT: double _t0 = this->a;
// CHECK-NEXT: this->a = _r0.a;
// CHECK-NEXT: this->a = arg.a;
// CHECK-NEXT: double _t1 = this->b;
// CHECK-NEXT: this->b = _r0.b;
// CHECK-NEXT: this->b = arg.b;
// CHECK-NEXT: {
// CHECK-NEXT: this->b = _t1;
// CHECK-NEXT: double _r_d1 = (*_d_this).b;
// CHECK-NEXT: (*_d_this).b = 0.;
// CHECK-NEXT: (*_d__r0).b += _r_d1;
// CHECK-NEXT: (*_d_arg).b += _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: this->a = _t0;
// CHECK-NEXT: double _r_d0 = (*_d_this).a;
// CHECK-NEXT: (*_d_this).a = 0.;
// CHECK-NEXT: (*_d__r0).a += _r_d0;
// CHECK-NEXT: (*_d_arg).a += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT:}

// CHECK: inline constexpr clad::ValueAndAdjoint<MyStruct &, MyStruct &> operator_equal_forw(MyStruct &&_r0, MyStruct *_d_this, MyStruct &&_d__r0) noexcept {
// CHECK: inline constexpr clad::ValueAndAdjoint<MyStruct &, MyStruct &> operator_equal_forw(MyStruct &&arg, MyStruct *_d_this, MyStruct &&_d_arg) noexcept {
// CHECK-NEXT: double _t0 = this->a;
// CHECK-NEXT: this->a = _r0.a;
// CHECK-NEXT: this->a = arg.a;
// CHECK-NEXT: double _t1 = this->b;
// CHECK-NEXT: this->b = _r0.b;
// CHECK-NEXT: this->b = arg.b;
// CHECK-NEXT: return {*this, (*_d_this)};
// CHECK-NEXT:}

0 comments on commit e9f3cb7

Please sign in to comment.