From a67db3e0d87519502786b60ba3d05445f5d575c9 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Sat, 12 Oct 2024 14:50:18 +0200 Subject: [PATCH] playing around with vector constr --- .../clad/Differentiator/ReverseModeVisitor.h | 2 + include/clad/Differentiator/STLBuiltins.h | 49 ++++++++++++------- lib/Differentiator/ReverseModeVisitor.cpp | 8 +++ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index b044ee0ec..b45838f77 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -420,6 +420,8 @@ namespace clad { StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS); StmtDiff VisitCaseStmt(const clang::CaseStmt* CS); StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); + StmtDiff VisitCXXBindTemporaryExpr( + const clang::CXXBindTemporaryExpr* BTE); DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD, bool keepLocal = false); StmtDiff VisitSubstNonTypeTemplateParmExpr( diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 40b562dc4..c26881547 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -451,27 +451,42 @@ void at_pullback(::std::vector* vec, (*d_vec)[idx] += d_y; } -template +template ::clad::ValueAndAdjoint<::std::vector, ::std::vector> -constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector>, - S count, U val, - typename ::std::vector::allocator_type alloc, - S d_count, U d_val, - typename ::std::vector::allocator_type d_alloc) { - ::std::vector v(count, val); - ::std::vector d_v(count, 0); +constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector>) { + ::std::vector v; + ::std::vector d_v; return {v, d_v}; } -template -void constructor_pullback(::std::vector* v, S count, U val, - typename ::std::vector::allocator_type alloc, - ::std::vector* d_v, S* d_count, U* d_val, - typename ::std::vector::allocator_type* d_alloc) { - for (unsigned i = 0; i < count; ++i) - *d_val += (*d_v)[i]; - d_v->clear(); -} +// template +// ::clad::ValueAndAdjoint<::std::vector, ::std::vector> +// constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector>, +// S count, U val, +// A alloc, +// dS d_count, dU d_val, +// dA d_alloc) { +// ::std::vector v(count, val); +// ::std::vector d_v(count, 0); +// return {v, d_v}; +// } + + +// template +// void constructor_pullback(::std::vector* v, S count, U val, Alloc alloc, +// ::std::vector* d_v, dS* d_count, dU* d_val, +// dAlloc* d_alloc) { +// for (unsigned i = 0; i < count; ++i) +// *d_val += (*d_v)[i]; +// d_v->clear(); +// } +// template +// void constructor_pullback(::std::vector* v, S count, U val, +// ::std::vector* d_v, dS* d_count, dU* d_val) { +// for (unsigned i = 0; i < count; ++i) +// *d_val += (*d_v)[i]; +// d_v->clear(); +// } template void assign_pullback(::std::vector* v, diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 2a7a16aa4..01679a499 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -3935,6 +3935,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {endBlock(direction::forward), endBlock(direction::reverse)}; } + StmtDiff ReverseModeVisitor::VisitCXXBindTemporaryExpr( + const clang::CXXBindTemporaryExpr* BTE) { + // `CXXBindTemporaryExpr` node will be created automatically, if it is + // required, by `ActOn`/`Build` Sema functions. + StmtDiff BTEDiff = Visit(BTE->getSubExpr(), dfdx()); + return BTEDiff; + } + StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff,