-
-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalize operator-like functions #1628
Changes from 4 commits
84b5ab6
7d7779e
4b50381
3c23987
7cd74ef
8263fdb
fb2de8a
1fedd01
9029b14
d73d0c8
45319be
f989b11
01a87f4
d34d6fe
c68504a
aaa8c17
b98a5a0
9445e07
9180fc6
ed9b73a
f2b4e13
d989040
ca27b44
e7d4d53
87d205c
a9490b3
c7162fd
b160720
f88d89a
d9284e8
cf2b72d
fb619c0
5290f6b
9529ccc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
#ifndef STAN_MATH_FWD_FUN_MULTIPLY_HPP | ||
#define STAN_MATH_FWD_FUN_MULTIPLY_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/err.hpp> | ||
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <stan/math/fwd/core.hpp> | ||
|
@@ -10,120 +11,64 @@ | |
namespace stan { | ||
namespace math { | ||
|
||
template <typename T, int R1, int C1> | ||
inline Eigen::Matrix<fvar<T>, R1, C1> multiply( | ||
const Eigen::Matrix<fvar<T>, R1, C1>& m, const fvar<T>& c) { | ||
Eigen::Matrix<fvar<T>, R1, C1> res(m.rows(), m.cols()); | ||
for (int i = 0; i < m.size(); i++) | ||
res(i) = c * m(i); | ||
return res; | ||
} | ||
|
||
template <typename T, int R2, int C2> | ||
inline Eigen::Matrix<fvar<T>, R2, C2> multiply( | ||
const Eigen::Matrix<fvar<T>, R2, C2>& m, double c) { | ||
Eigen::Matrix<fvar<T>, R2, C2> res(m.rows(), m.cols()); | ||
for (int i = 0; i < m.size(); i++) | ||
res(i) = c * m(i); | ||
return res; | ||
} | ||
|
||
template <typename T, int R1, int C1> | ||
inline Eigen::Matrix<fvar<T>, R1, C1> multiply( | ||
const Eigen::Matrix<double, R1, C1>& m, const fvar<T>& c) { | ||
Eigen::Matrix<fvar<T>, R1, C1> res(m.rows(), m.cols()); | ||
for (int i = 0; i < m.size(); i++) | ||
res(i) = c * m(i); | ||
return res; | ||
} | ||
|
||
template <typename T, int R1, int C1> | ||
inline Eigen::Matrix<fvar<T>, R1, C1> multiply( | ||
const fvar<T>& c, const Eigen::Matrix<fvar<T>, R1, C1>& m) { | ||
return multiply(m, c); | ||
} | ||
|
||
template <typename T, int R1, int C1> | ||
inline Eigen::Matrix<fvar<T>, R1, C1> multiply( | ||
double c, const Eigen::Matrix<fvar<T>, R1, C1>& m) { | ||
return multiply(m, c); | ||
} | ||
|
||
template <typename T, int R1, int C1> | ||
inline Eigen::Matrix<fvar<T>, R1, C1> multiply( | ||
const fvar<T>& c, const Eigen::Matrix<double, R1, C1>& m) { | ||
return multiply(m, c); | ||
} | ||
|
||
template <typename T, int R1, int C1, int R2, int C2> | ||
inline Eigen::Matrix<fvar<T>, R1, C2> multiply( | ||
const Eigen::Matrix<fvar<T>, R1, C1>& m1, | ||
const Eigen::Matrix<fvar<T>, R2, C2>& m2) { | ||
template < | ||
typename T1, typename T2, typename = require_all_eigen_vt<is_fvar, T1, T2>, | ||
typename = require_same_t<typename value_type_t<T1>::Scalar, | ||
typename value_type_t<T2>::Scalar>, | ||
typename | ||
= require_t<bool_constant<T1::ColsAtCompileTime == T2::RowsAtCompileTime>>, | ||
typename = require_not_t< | ||
conjunction<is_eigen_row_vector<T1>, is_eigen_col_vector<T2>>>, | ||
int = 0> | ||
inline auto multiply(const T1& m1, const T2& m2) { | ||
check_multiplicable("multiply", "m1", m1, "m2", m2); | ||
Eigen::Matrix<fvar<T>, R1, C2> result(m1.rows(), m2.cols()); | ||
for (size_type i = 0; i < m1.rows(); i++) { | ||
Eigen::Matrix<fvar<T>, 1, C1> crow = m1.row(i); | ||
for (size_type j = 0; j < m2.cols(); j++) { | ||
Eigen::Matrix<fvar<T>, R2, 1> ccol = m2.col(j); | ||
result(i, j) = dot_product(crow, ccol); | ||
} | ||
} | ||
return result; | ||
return m1 * m2; | ||
SteveBronder marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
template <typename T, int R1, int C1, int R2, int C2> | ||
inline Eigen::Matrix<fvar<T>, R1, C2> multiply( | ||
const Eigen::Matrix<fvar<T>, R1, C1>& m1, | ||
const Eigen::Matrix<double, R2, C2>& m2) { | ||
template <typename T1, typename T2, typename = require_all_eigen_t<T1, T2>, | ||
typename = require_fvar_t<value_type_t<T1>>, | ||
typename = require_same_vt<double, T2>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could shorten this up with
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
typename = require_t< | ||
bool_constant<T1::ColsAtCompileTime == T2::RowsAtCompileTime>>, | ||
typename = require_not_t< | ||
conjunction<is_eigen_row_vector<T1>, is_eigen_col_vector<T2>>>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe just a meta function for |
||
int = 0> | ||
inline auto multiply(const T1& m1, const T2& m2) { | ||
check_multiplicable("multiply", "m1", m1, "m2", m2); | ||
Eigen::Matrix<fvar<T>, R1, C2> result(m1.rows(), m2.cols()); | ||
Eigen::Matrix<value_type_t<T1>, T1::RowsAtCompileTime, T2::ColsAtCompileTime> | ||
result(m1.rows(), m2.cols()); | ||
for (size_type i = 0; i < m1.rows(); i++) { | ||
Eigen::Matrix<fvar<T>, 1, C1> crow = m1.row(i); | ||
Eigen::Matrix<value_type_t<T1>, 1, T1::ColsAtCompileTime> crow = m1.row(i); | ||
SteveBronder marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for (size_type j = 0; j < m2.cols(); j++) { | ||
Eigen::Matrix<double, R2, 1> ccol = m2.col(j); | ||
auto ccol = m2.col(j); | ||
result(i, j) = dot_product(crow, ccol); | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
template <typename T, int R1, int C1, int R2, int C2> | ||
inline Eigen::Matrix<fvar<T>, R1, C2> multiply( | ||
const Eigen::Matrix<double, R1, C1>& m1, | ||
const Eigen::Matrix<fvar<T>, R2, C2>& m2) { | ||
template <typename T1, typename T2, typename = require_all_eigen_t<T1, T2>, | ||
typename = require_same_vt<double, T1>, | ||
typename = require_fvar_t<value_type_t<T2>>, | ||
typename = require_t< | ||
bool_constant<T1::ColsAtCompileTime == T2::RowsAtCompileTime>>, | ||
typename = require_not_t< | ||
conjunction<is_eigen_row_vector<T1>, is_eigen_col_vector<T2>>>, | ||
char = 0> | ||
inline auto multiply(const T1& m1, const T2& m2) { | ||
check_multiplicable("multiply", "m1", m1, "m2", m2); | ||
Eigen::Matrix<fvar<T>, R1, C2> result(m1.rows(), m2.cols()); | ||
Eigen::Matrix<value_type_t<T2>, T1::RowsAtCompileTime, T2::ColsAtCompileTime> | ||
result(m1.rows(), m2.cols()); | ||
for (size_type i = 0; i < m1.rows(); i++) { | ||
Eigen::Matrix<double, 1, C1> crow = m1.row(i); | ||
Eigen::Matrix<double, 1, T1::ColsAtCompileTime> crow = m1.row(i); | ||
for (size_type j = 0; j < m2.cols(); j++) { | ||
Eigen::Matrix<fvar<T>, R2, 1> ccol = m2.col(j); | ||
auto ccol = m2.col(j); | ||
result(i, j) = dot_product(crow, ccol); | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason to not put this as for (size_type i = 0; i < m1.rows(); i++) {
for (size_type j = 0; j < m2.cols(); j++) {
result(i, j) = dot_product(m1.row(i), m2.col(j));
}
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not a great assembly guy, but my usual heuristic of 'less is better' says the above looks a bit nicer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a case where less is not better. Doing Since we go over them in every iteration of inner loop it is better to inefficiently access them just once and make remaining accesses efficient. If you want I can make line 51 na 52 into a single line, but it won't affect performance in any way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Word that's sensible to me.
idt it will effect performance though I think zipping 51 into 52 would be nice There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
return result; | ||
} | ||
|
||
template <typename T, int C1, int R2> | ||
inline fvar<T> multiply(const Eigen::Matrix<fvar<T>, 1, C1>& rv, | ||
const Eigen::Matrix<fvar<T>, R2, 1>& v) { | ||
check_multiplicable("multiply", "rv", rv, "v", v); | ||
return dot_product(rv, v); | ||
} | ||
|
||
template <typename T, int C1, int R2> | ||
inline fvar<T> multiply(const Eigen::Matrix<fvar<T>, 1, C1>& rv, | ||
const Eigen::Matrix<double, R2, 1>& v) { | ||
check_multiplicable("multiply", "rv", rv, "v", v); | ||
return dot_product(rv, v); | ||
} | ||
|
||
template <typename T, int C1, int R2> | ||
inline fvar<T> multiply(const Eigen::Matrix<double, 1, C1>& rv, | ||
const Eigen::Matrix<fvar<T>, R2, 1>& v) { | ||
check_multiplicable("multiply", "rv", rv, "v", v); | ||
return dot_product(rv, v); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This int template parameter is required to make compiler treat these as different overloads (if they have same arguments and template parameters it treats them as same overload and gives a lot of errors about redefining default template parameter values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
! O that's how you do that. I was fighting with that compile error like a week ago. Thank you. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me this solution seems a bit hacky, but I don't know any better one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a fan of this approach. It's feels like it fights against the language. I think it would be good to post a stackoverflow Q about this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found a nice solution here: https://stackoverflow.com/questions/38305222/default-template-argument-when-using-stdenable-if-as-templ-param-why-ok-wit
The trick is to use requires as type of template parameters instead of defaults to typename parameters -
require_t<...>* = nullptr
instead oftypename = require_t<...>