-
-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add perfect forwarding and constexpr to reverse mode functions #3092
base: develop
Are you sure you want to change the base?
Changes from 17 commits
f164443
d70fb0f
f08c711
ddea3c0
d98f4f0
3012408
af37cbc
5152942
888e2cc
9887597
d4ebc3d
d7350f2
a54fb01
bf36144
5d989a2
fdb5d03
be242f0
aea8afa
d5bcc83
a3f3cd8
c3d8136
ab679fe
ad456a1
bc96704
d1fc936
5c8bd28
17ffd02
506fe50
c4c7605
82fc4f6
c826d98
d5de333
855546a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -62,5 +62,12 @@ template <typename T> | |||||
struct is_constant<T, require_eigen_t<T>> | ||||||
: bool_constant<is_constant<typename std::decay_t<T>::Scalar>::value> {}; | ||||||
|
||||||
template <typename... Types> | ||||||
inline constexpr bool is_constant_all_v = is_constant_all<Types...>::value; | ||||||
|
||||||
template <typename... Types> | ||||||
inline constexpr bool is_constant_v | ||||||
= std::conjunction<is_constant<Types>...>::value; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
For consistency There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also same comment about redundancy - is there a situation where There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I think we should start using the std version! We just used the stan math version because we didn't have c++17 available previously
Since I made |
||||||
|
||||||
} // namespace stan | ||||||
#endif |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -28,6 +28,10 @@ struct is_stan_scalar | |||||
is_fvar<std::decay_t<T>>, std::is_arithmetic<std::decay_t<T>>, | ||||||
is_complex<std::decay_t<T>>>::value> {}; | ||||||
|
||||||
template <typename... Types> | ||||||
inline constexpr bool is_stan_scalar_v | ||||||
= std::conjunction<is_stan_scalar<Types>...>::value; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
/*! \ingroup require_stan_scalar_real */ | ||||||
/*! \defgroup stan_scalar_types stan_scalar */ | ||||||
/*! \addtogroup stan_scalar_types */ | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,24 +35,24 @@ template <typename T1, typename T2, require_any_var_matrix_t<T1, T2>* = nullptr> | |
inline auto append_col(const T1& A, const T2& B) { | ||
check_size_match("append_col", "columns of A", A.rows(), "columns of B", | ||
B.rows()); | ||
if (!is_constant<T1>::value && !is_constant<T2>::value) { | ||
arena_t<promote_scalar_t<var, T1>> arena_A = A; | ||
arena_t<promote_scalar_t<var, T2>> arena_B = B; | ||
if constexpr (is_autodiffable_v<T1, T2>) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Applies to the other changes as well) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I'm confused on definitions. If something is not constant then shouldn't it be autodiffable (i.e. var and fvar?) |
||
arena_t<T1> arena_A = A; | ||
arena_t<T2> arena_B = B; | ||
return make_callback_var( | ||
append_col(value_of(arena_A), value_of(arena_B)), | ||
[arena_A, arena_B](auto& vi) mutable { | ||
arena_A.adj() += vi.adj().leftCols(arena_A.cols()); | ||
arena_B.adj() += vi.adj().rightCols(arena_B.cols()); | ||
}); | ||
} else if (!is_constant<T1>::value) { | ||
arena_t<promote_scalar_t<var, T1>> arena_A = A; | ||
} else if constexpr (is_autodiffable_v<T1>) { | ||
arena_t<T1> arena_A = A; | ||
return make_callback_var(append_col(value_of(arena_A), value_of(B)), | ||
[arena_A](auto& vi) mutable { | ||
arena_A.adj() | ||
+= vi.adj().leftCols(arena_A.cols()); | ||
}); | ||
} else { | ||
arena_t<promote_scalar_t<var, T2>> arena_B = B; | ||
arena_t<T2> arena_B = B; | ||
return make_callback_var(append_col(value_of(A), value_of(arena_B)), | ||
[arena_B](auto& vi) mutable { | ||
arena_B.adj() | ||
|
@@ -79,21 +79,21 @@ template <typename Scal, typename RowVec, | |
require_stan_scalar_t<Scal>* = nullptr, | ||
require_t<is_eigen_row_vector<RowVec>>* = nullptr> | ||
inline auto append_col(const Scal& A, const var_value<RowVec>& B) { | ||
if (!is_constant<Scal>::value && !is_constant<RowVec>::value) { | ||
if constexpr (is_autodiffable_v<Scal, RowVec>) { | ||
var arena_A = A; | ||
arena_t<promote_scalar_t<var, RowVec>> arena_B = B; | ||
arena_t<RowVec> arena_B = B; | ||
return make_callback_var(append_col(value_of(arena_A), value_of(arena_B)), | ||
[arena_A, arena_B](auto& vi) mutable { | ||
arena_A.adj() += vi.adj().coeff(0); | ||
arena_B.adj() += vi.adj().tail(arena_B.size()); | ||
}); | ||
} else if (!is_constant<Scal>::value) { | ||
} else if constexpr (is_autodiffable_v<Scal>) { | ||
var arena_A = A; | ||
return make_callback_var( | ||
append_col(value_of(arena_A), value_of(B)), | ||
[arena_A](auto& vi) mutable { arena_A.adj() += vi.adj().coeff(0); }); | ||
} else { | ||
arena_t<promote_scalar_t<var, RowVec>> arena_B = B; | ||
arena_t<RowVec> arena_B = B; | ||
return make_callback_var(append_col(value_of(A), value_of(arena_B)), | ||
[arena_B](auto& vi) mutable { | ||
arena_B.adj() += vi.adj().tail(arena_B.size()); | ||
|
@@ -119,17 +119,17 @@ template <typename RowVec, typename Scal, | |
require_t<is_eigen_row_vector<RowVec>>* = nullptr, | ||
require_stan_scalar_t<Scal>* = nullptr> | ||
inline auto append_col(const var_value<RowVec>& A, const Scal& B) { | ||
if (!is_constant<RowVec>::value && !is_constant<Scal>::value) { | ||
arena_t<promote_scalar_t<var, RowVec>> arena_A = A; | ||
if constexpr (is_autodiffable_v<RowVec, Scal>) { | ||
arena_t<RowVec> arena_A = A; | ||
var arena_B = B; | ||
return make_callback_var(append_col(value_of(arena_A), value_of(arena_B)), | ||
[arena_A, arena_B](auto& vi) mutable { | ||
arena_A.adj() += vi.adj().head(arena_A.size()); | ||
arena_B.adj() | ||
+= vi.adj().coeff(vi.adj().size() - 1); | ||
}); | ||
} else if (!is_constant<RowVec>::value) { | ||
arena_t<promote_scalar_t<var, RowVec>> arena_A = A; | ||
} else if constexpr (is_autodiffable_v<RowVec>) { | ||
arena_t<RowVec> arena_A = A; | ||
return make_callback_var(append_col(value_of(arena_A), value_of(B)), | ||
[arena_A](auto& vi) mutable { | ||
arena_A.adj() += vi.adj().head(arena_A.size()); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two typedefs seem to test the same thing. Also, shouldn't this be
is_autodiff_all_v
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These need better names.
is_autodiff_v<>
looks to see is the type is a fvar or var and fails otherwise.is_autodiffable_v
looks into the scalar type of the type to see if it is autodiffable, so things like eigen matrices and vectors of var or fvar types would be true foris_autodiffable_v
.I left
is_autodiff
alone to not mess with the other functions that use it (mostly functions that use it in a requires). Maybe the currentis_autodiff
should be namedis_autodiff_scalar
?