diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 115b16bb8..ffd0f7ce6 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -188,7 +188,9 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, return EData; } -TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) { +TBRAnalyzer::VarData::VarData(QualType QT, const ASTContext& C, + bool forceNonRefType) { + QT = QT.getDesugaredType(C); if (forceNonRefType && QT->isReferenceType()) QT = QT->getPointeeType(); @@ -205,7 +207,7 @@ TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) { elemType = QT->getArrayElementTypeNoTypeQual(); ProfileID nonConstIdxID; auto& idxData = (*m_Val.m_ArrData)[nonConstIdxID]; - idxData = VarData(QualType::getFromOpaquePtr(elemType)); + idxData = VarData(QualType::getFromOpaquePtr(elemType), C); } else if (QT->isBuiltinType()) { m_Type = VarData::FUND_TYPE; m_Val.m_FundData = false; @@ -216,7 +218,7 @@ TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) { newArrMap = std::unique_ptr(new ArrMap()); for (const auto* field : recordDecl->fields()) { const auto varType = field->getType(); - (*newArrMap)[getProfileID(field)] = VarData(varType); + (*newArrMap)[getProfileID(field)] = VarData(varType, C); } } } @@ -287,11 +289,11 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD, bool forceNonRefType) { if (const auto* const pointerType = dyn_cast(varType)) { const auto* elemType = pointerType->getPointeeType().getTypePtrOrNull(); if (elemType && elemType->isRecordType()) { - curBranch[VD] = VarData(QualType::getFromOpaquePtr(elemType)); + curBranch[VD] = VarData(QualType::getFromOpaquePtr(elemType), m_Context); return; } } - curBranch[VD] = VarData(varType, forceNonRefType); + curBranch[VD] = VarData(varType, m_Context, forceNonRefType); } void TBRAnalyzer::markLocation(const clang::Expr* E) { @@ -331,7 +333,7 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) { if (MD && !MD->isStatic()) { const Type* recordType = MD->getParent()->getTypeForDecl(); getCurBlockVarsData()[nullptr] = - VarData(QualType::getFromOpaquePtr(recordType)); + VarData(QualType::getFromOpaquePtr(recordType), m_Context); } auto paramsRef = FD->parameters(); for (std::size_t i = 0; i < FD->getNumParams(); ++i) diff --git a/lib/Differentiator/TBRAnalyzer.h b/lib/Differentiator/TBRAnalyzer.h index 0603e0621..cc666b51f 100644 --- a/lib/Differentiator/TBRAnalyzer.h +++ b/lib/Differentiator/TBRAnalyzer.h @@ -102,7 +102,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor { /// reference type (it will store TBR information itself without referring /// to other VarData's). This is necessary for reference-type parameters, /// when the referenced expressions are out of the function's scope. - VarData(QualType QT, bool forceNonRefType = false); + VarData(QualType QT, const ASTContext& C, bool forceNonRefType = false); /// Erases all children VarData's of this VarData. ~VarData() { diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index f5f954808..510aa060a 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -326,6 +326,29 @@ double fn9(Tangent t, dcomplex c) { // CHECK-NEXT: } // CHECK-NEXT: } +template +struct A { + using PtrType = T*; +}; + +double fn10(double x, double y) { + A::PtrType ptr = &x; + ptr[0] += 6; + return *ptr; +} + +// CHECK: void fn10_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: A::PtrType _d_ptr = &*_d_x; +// CHECK-NEXT: A::PtrType ptr = &x; +// CHECK-NEXT: double _t0 = ptr[0]; +// CHECK-NEXT: ptr[0] += 6; +// CHECK-NEXT: *_d_ptr += 1; +// CHECK-NEXT: { +// CHECK-NEXT: ptr[0] = _t0; +// CHECK-NEXT: double _r_d0 = _d_ptr[0]; +// CHECK-NEXT: } +// CHECK-NEXT: } + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -351,6 +374,7 @@ int main() { INIT_GRADIENT(fn7); INIT_GRADIENT(fn8); INIT_GRADIENT(fn9); + INIT_GRADIENT(fn10); TEST_GRADIENT(fn1, /*numOfDerivativeArgs=*/2, p, i, &d_p, &d_i); // CHECK-EXEC: {1.00, 2.00, 3.00} TEST_GRADIENT(fn2, /*numOfDerivativeArgs=*/2, t, i, &d_t, &d_i); // CHECK-EXEC: {4.00, 2.00, 2.00, 2.00, 2.00, 1.00} @@ -364,6 +388,7 @@ int main() { TEST_GRADIENT(fn7, /*numOfDerivativeArgs=*/2, c1, c2, &d_c1, &d_c2);// CHECK-EXEC: {0.00, 3.00, 5.00, 1.00} TEST_GRADIENT(fn8, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {0.00, 0.00, 0.00, 0.00, 0.00, 5.00, 0.00} TEST_GRADIENT(fn9, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00, 1.00, 5.00, 10.00} + TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 5, 10, &d_i, &d_j); // CHECK-EXEC: {1.00, 0.00} } // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {