Skip to content

Commit

Permalink
playing around with vector constr
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Oct 12, 2024
1 parent a14a3f6 commit a67db3e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ namespace clad {
StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS);
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
StmtDiff VisitCXXBindTemporaryExpr(
const clang::CXXBindTemporaryExpr* BTE);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD,
bool keepLocal = false);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
Expand Down
49 changes: 32 additions & 17 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,27 +451,42 @@ void at_pullback(::std::vector<T>* vec,
(*d_vec)[idx] += d_y;
}

template <typename T, typename S, typename U>
template <typename T>
::clad::ValueAndAdjoint<::std::vector<T>, ::std::vector<T>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>,
S count, U val,
typename ::std::vector<T>::allocator_type alloc,
S d_count, U d_val,
typename ::std::vector<T>::allocator_type d_alloc) {
::std::vector<T> v(count, val);
::std::vector<T> d_v(count, 0);
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>) {
::std::vector<T> v;
::std::vector<T> d_v;
return {v, d_v};
}

template <typename T, typename S, typename U>
void constructor_pullback(::std::vector<T>* v, S count, U val,
typename ::std::vector<T>::allocator_type alloc,
::std::vector<T>* d_v, S* d_count, U* d_val,
typename ::std::vector<T>::allocator_type* d_alloc) {
for (unsigned i = 0; i < count; ++i)
*d_val += (*d_v)[i];
d_v->clear();
}
// template <typename T, typename S, typename U, typename dS, typename dU, typename A, typename dA>
// ::clad::ValueAndAdjoint<::std::vector<T>, ::std::vector<T>>
// constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>,
// S count, U val,
// A alloc,
// dS d_count, dU d_val,
// dA d_alloc) {
// ::std::vector<T> v(count, val);
// ::std::vector<T> d_v(count, 0);
// return {v, d_v};
// }


// template <typename T, typename S, typename U, typename dS, typename dU, typename Alloc, typename dAlloc>
// void constructor_pullback(::std::vector<T>* v, S count, U val, Alloc alloc,
// ::std::vector<T>* d_v, dS* d_count, dU* d_val,
// dAlloc* d_alloc) {
// for (unsigned i = 0; i < count; ++i)
// *d_val += (*d_v)[i];
// d_v->clear();
// }
// template <typename T, typename S, typename U, typename dS, typename dU>
// void constructor_pullback(::std::vector<T>* v, S count, U val,
// ::std::vector<T>* d_v, dS* d_count, dU* d_val) {
// for (unsigned i = 0; i < count; ++i)
// *d_val += (*d_v)[i];
// d_v->clear();
// }

template <typename T, typename U, typename dU>
void assign_pullback(::std::vector<T>* v,
Expand Down
8 changes: 8 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3935,6 +3935,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::VisitCXXBindTemporaryExpr(
const clang::CXXBindTemporaryExpr* BTE) {
// `CXXBindTemporaryExpr` node will be created automatically, if it is
// required, by `ActOn`/`Build` Sema functions.
StmtDiff BTEDiff = Visit(BTE->getSubExpr(), dfdx());
return BTEDiff;
}

StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body,
LoopCounter& loopCounter,
Stmt* condVarDiff,
Expand Down

0 comments on commit a67db3e

Please sign in to comment.