From 551a423761883f642caa4741c8e4a61a31fbc9ce Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Mon, 15 Jan 2024 15:40:49 +0100 Subject: [PATCH] Fix support for non-type template params --- .../Differentiator/BaseForwardModeVisitor.h | 2 ++ .../clad/Differentiator/ReverseModeVisitor.h | 2 ++ include/clad/Differentiator/VisitorBase.h | 4 +-- lib/Differentiator/BaseForwardModeVisitor.cpp | 7 ++++- lib/Differentiator/ReverseModeVisitor.cpp | 8 ++++-- test/FirstDerivative/BasicArithmeticMulDiv.C | 21 +++++++++++++++ test/Gradient/Gradients.C | 26 +++++++++++++++++++ 7 files changed, 65 insertions(+), 5 deletions(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 370097b42..7b09cc282 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -102,6 +102,8 @@ class BaseForwardModeVisitor StmtDiff VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE); StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE); + StmtDiff VisitSubstNonTypeTemplateParmExpr( + const clang::SubstNonTypeTemplateParmExpr* NTTP); virtual clang::QualType GetPushForwardDerivativeType(clang::QualType ParamType); diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 31e19e51f..c0010bce3 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -373,6 +373,8 @@ namespace clad { VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE); StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE); VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); + StmtDiff VisitSubstNonTypeTemplateParmExpr( + const clang::SubstNonTypeTemplateParmExpr* NTTP); /// A helper method to differentiate a single Stmt in the reverse mode. /// Internally, calls Visit(S, expr). Its result is wrapped into a diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 91870f6e8..f85085a86 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -320,14 +320,14 @@ namespace clad { /// \n Variable declaration cannot be added to code directly, instead we /// have to build a declaration staement. /// \param[in] D The declaration to build a declaration statement from. - /// \returns The declration statement expression corresponding to the input + /// \returns The declaration statement expression corresponding to the input /// variable declaration. clang::DeclStmt* BuildDeclStmt(clang::Decl* D); /// Wraps a set of declarations in a DeclStmt. /// \n This function is useful to wrap multiple variable declarations in one /// single declaration statement. /// \param[in] D The declarations to build a declaration statement from. - /// \returns The declration statemetn expression corresponding to the input + /// \returns The declaration statement expression corresponding to the input /// variable declaration. clang::DeclStmt* BuildDeclStmt(llvm::MutableArrayRef DS); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 1dc39bce1..2b67f940b 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -2081,4 +2081,9 @@ StmtDiff BaseForwardModeVisitor::VisitPseudoObjectExpr( return {Clone(POE), ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0)}; } -} // end namespace clad + +StmtDiff BaseForwardModeVisitor::VisitSubstNonTypeTemplateParmExpr( + const clang::SubstNonTypeTemplateParmExpr* NTTP) { + return Visit(NTTP->getReplacement()); +} +} // end namespace clad \ No newline at end of file diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b67f23043..0d88ce8ae 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -739,8 +739,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { diag( - DiagnosticsEngine::Warning, - S->getBeginLoc(), + DiagnosticsEngine::Warning, S->getBeginLoc(), "attempted to differentiate unsupported statement, no changes applied"); // Unknown stmt, just clone it. return StmtDiff(Clone(S)); @@ -3499,6 +3498,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return MTEDiff; } + StmtDiff ReverseModeVisitor::VisitSubstNonTypeTemplateParmExpr( + const clang::SubstNonTypeTemplateParmExpr* NTTP) { + return Visit(NTTP->getReplacement()); + } + QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType, QualType xType) { diff --git a/test/FirstDerivative/BasicArithmeticMulDiv.C b/test/FirstDerivative/BasicArithmeticMulDiv.C index ac20d70d5..cc13ca49c 100644 --- a/test/FirstDerivative/BasicArithmeticMulDiv.C +++ b/test/FirstDerivative/BasicArithmeticMulDiv.C @@ -104,6 +104,23 @@ double m_10(double x, bool flag) { // CHECK-NEXT: return flag ? (((_d_x = _d_x * 2 + x * 0) , (x *= 2)) , (_d_x * x + x * _d_x)) : (((_d_x += 0) , (x += 1)) , (_d_x * x + x * _d_x)); // CHECK-NEXT: } +template +double m_11(double x) { + const size_t maxN = 53; + const size_t m = maxN < N ? maxN : N; + return x*m; +} + +// CHECK: double m_11_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: const size_t _d_maxN = 0; +// CHECK-NEXT: const size_t maxN = 53; +// CHECK-NEXT: bool _t0 = maxN < 64UL; +// CHECK-NEXT: const size_t _d_m = _t0 ? _d_maxN : 0UL; +// CHECK-NEXT: const size_t m = _t0 ? maxN : 64UL; +// CHECK-NEXT: return _d_x * m + x * _d_m; +// CHECK-NEXT: } + int d_1(int x) { int y = 4; return y / y; // == 0 @@ -190,6 +207,7 @@ double m_7_darg0(double x); double m_8_darg0(double x); double m_9_darg0(double x); double m_10_darg0(double x, bool flag); +double m_11_darg0(double x); int d_1_darg0(int x); int d_2_darg0(int x); int d_3_darg0(int x); @@ -230,6 +248,9 @@ int main () { printf("Result is = %f\n", m_10_darg0(1, true)); // CHECK-EXEC: Result is = 8 printf("Result is = %f\n", m_10_darg0(1, false)); // CHECK-EXEC: Result is = 4 + clad::differentiate(m_11<64>, 0); + printf("Result is = %f\n", m_11_darg0(1)); // CHECK-EXEC: Result is = 53 + clad::differentiate(d_1, 0); printf("Result is = %d\n", d_1_darg0(1)); // CHECK-EXEC: Result is = 0 diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index d4523c39a..e88546989 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -761,6 +761,27 @@ void fn_increment_in_return_grad(double i, double j, clad::array_ref _d_ // CHECK-NEXT: * _d_i += _d_temp; // CHECK-NEXT: } +template +double fn_template_non_type(double x) { + const size_t maxN = 53; + const size_t m = maxN < N ? maxN : N; + return x*m; +} + +// CHECK: void fn_template_non_type_grad(double x, clad::array_ref _d_x) { +// CHECK-NEXT: size_t _d_maxN = 0; +// CHECK-NEXT: bool _cond0; +// CHECK-NEXT: size_t _d_m = 0; +// CHECK-NEXT: const size_t maxN = 53; +// CHECK-NEXT: _cond0 = maxN < 15UL; +// CHECK-NEXT: const size_t m = _cond0 ? maxN : 15UL; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: * _d_x += 1 * m; +// CHECK-NEXT: if (_cond0) +// CHECK-NEXT: _d_maxN += _d_m; +// CHECK-NEXT: } + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -812,4 +833,9 @@ int main() { TEST_GRADIENT(fn_global_var_use, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {7.00, 0.00} TEST(fn_increment_in_return, 3, 2); // CHECK-EXEC: Result is = {7.00, 0.00} + + auto fn_template_non_type_dx = clad::gradient(fn_template_non_type<15>); + double x = 5, dx = 0; + fn_template_non_type_dx.execute(x, &dx); + printf("Result is = %.2f\n", dx); // CHECK-EXEC: Result is = 15.00 }