-
-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalize operator-like functions #1628
Conversation
stan/math/fwd/fun/multiply.hpp
Outdated
= require_t<bool_constant<T1::ColsAtCompileTime == T2::RowsAtCompileTime>>, | ||
typename = require_not_t< | ||
conjunction<is_eigen_row_vector<T1>, is_eigen_col_vector<T2>>>, | ||
int = 0> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This int template parameter is required to make compiler treat these as different overloads (if they have same arguments and template parameters it treats them as same overload and gives a lot of errors about redefining default template parameter values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
! O that's how you do that. I was fighting with that compile error like a week ago. Thank you. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me this solution seems a bit hacky, but I don't know any better one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a fan of this approach. It's feels like it fights against the language. I think it would be good to post a stackoverflow Q about this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found a nice solution here: https://stackoverflow.com/questions/38305222/default-template-argument-when-using-stdenable-if-as-templ-param-why-ok-wit
The trick is to use requires as type of template parameters instead of defaults to typename parameters - require_t<...>* = nullptr
instead of typename = require_t<...>
stan/math/prim/fun/dot_product.hpp
Outdated
check_vector("dot_product", "v1", v1); | ||
check_vector("dot_product", "v2", v2); | ||
template <typename T1, typename T2, | ||
typename = require_all_eigen_vector_t<T1, T2>> | ||
inline auto dot_product(const T1 &v1, const T2 &v2) { | ||
check_matching_sizes("dot_product", "v1", v1, "v2", v2); | ||
return v1.dot(v2); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These check_vector
checks were redundant, as Eigen's .dot()
is only defined for compile time vectors, so there is no way this could compile and check fail.
@@ -68,10 +60,11 @@ inline return_type_t<TD, TA, TB> trace_gen_quad_form( | |||
* @throw std::domain_error if A cannot be multiplied by B or B cannot | |||
* be multiplied by D. | |||
*/ | |||
template <int RD, int CD, int RA, int CA, typename TB, int RB, int CB> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typename TB
is never used and therefore can not be deduced. Due to the bug this overload was never chosen.
tols.hessian_hessian_ = 1e-1; | ||
tols.hessian_fvar_hessian_ = 1e-1; | ||
tols.hessian_hessian_ = 0.15; | ||
tols.hessian_fvar_hessian_ = 0.15; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I fixed a bug that prevented all-double overload of trace_gen_quad_form()
to be ever chosen. While faster, that overload seems to be a bit less accurate, so I had to increase these tolerances.
…xp1~20180509124008.99 (branches/release_50)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just looked over the metaprogramming. I like it! I'll sit down and look at more once the tests pass
stan/math/fwd/fun/multiply.hpp
Outdated
const Eigen::Matrix<fvar<T>, R2, C2>& m2) { | ||
template <typename T1, typename T2, | ||
typename = require_all_eigen_vt<is_fvar, T1, T2>, | ||
typename = require_same_t<typename value_type_t<T1>::Scalar, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is require_same_st and require_same_vt
https://github.com/stan-dev/math/blob/develop/stan/math/prim/meta/require_generics.hpp#L283
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem here is that scalar_t<T1>
is fvar<X>
and I want to compare those X
es.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoops, I was wrong. Of course require_same_vt
is fine.
stan/math/fwd/fun/multiply.hpp
Outdated
typename = require_same_t<typename value_type_t<T1>::Scalar, | ||
typename value_type_t<T2>::Scalar>, | ||
typename = require_not_t< | ||
conjunction<is_eigen_row_vector<T1>, is_eigen_col_vector<T2>>>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[optional]
We could have something like `require_eigen_dot_product (I'm not sure if there is some other math word for general ops with a row and column vector.
Also I think this could use require_any_not_t
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also names like RowVec and ColVec may be good here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we have require_NOT, so these are not row and col vec.
stan/math/fwd/fun/multiply.hpp
Outdated
const Eigen::Matrix<double, R2, C2>& m2) { | ||
template <typename T1, typename T2, typename = require_all_eigen_t<T1, T2>, | ||
typename = require_fvar_t<value_type_t<T1>>, | ||
typename = require_same_vt<double, T2>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could shorten this up with
template <typename Mat1, typename Mat2, typename = require_eigen_vt<is_fvar, Mat1>,
typename = require_eigen_vt<std::is_floating_point, Mat2>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
stan/math/fwd/fun/multiply.hpp
Outdated
typename = require_not_t< | ||
conjunction<is_eigen_row_vector<T1>, is_eigen_col_vector<T2>>>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe just a meta function for is_dot_product
for these instead of a requires specialization
typename = require_all_not_var_t<value_type_t<TD>, value_type_t<TA>, | ||
value_type_t<TB>>, | ||
typename = require_any_not_same_t<double, value_type_t<TD>, | ||
value_type_t<TA>, value_type_t<TB>>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be a weird way to express it but one way to look at this is
require_not_same_t<double, return_type_t<TD, TA, TB>>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While this is shorter, I find it less clear.
template <typename TD, typename TA, typename TB, | ||
typename = require_all_same_t<double, value_type_t<TD>, | ||
value_type_t<TA>, value_type_t<TB>>, | ||
typename = require_all_eigen_t<TD, TA, TB>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
require_all_eigen_vt<std::is_floating_point, TD, TA, TB>
If we ever worry about mixing types of floats we could also have
require_same_same_vt<TD, TA, TB>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though the bottom may be better as a static assert to the user
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Static assert is not what we want here - if that require is not satisfied, other overload must be chosen.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Anyway I would not complicate here, since this overload can be removed once transpose is generalized too.
stan/math/rev/fun/divide.hpp
Outdated
= new internal::matrix_scalar_divide_vd_vari<R, C>(m, c); | ||
Eigen::Matrix<var, R, C> result(m.rows(), m.cols()); | ||
template <typename T, typename = require_eigen_vt<is_var, T>> | ||
inline Eigen::Matrix<var, T::RowsAtCompileTime, T::ColsAtCompileTime> divide( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does auto not work here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does. I missed that one.
…stable/2017-11-14)
…stable/2017-11-14)
…stable/2017-11-14)
For the errors up in the stan repo I started working on generalizing the https://github.com/stan-dev/stan/blob/refactor/rvalue-pass/src/stan/model/indexing/rvalue.hpp#L114 If that branch is helpful and you want to make any changes to that branch feel free to push straight to it |
I think the lockstep here is updating the stan repo first then merged this |
I agree. Looking over the branch you linked I see many changes, none of which seem like they could fix test failures in this PR. Did you link the right branch? |
Crap actually sorry this branch doesn't cover that. I thought I had a branch with these but looking on my local I must not have saved it If we use the generic template with require should we have inside each function that does coeff access like if (is_eigen_expression<EigMat>::value) {
the_mat.eval();
} Or is there some better pattern? |
What you suggested has multiple issues. The correct pattern would be:
where
where T is the type of |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
# Conflicts: # test/unit/math/mix/fun/trace_gen_quad_form_test.cpp
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
@SteveBronder This is still waiting for a review. |
Sorry for the delay it's been a v busy week. I should have time tomorrow night to look at this |
# Conflicts: # stan/math/fwd/fun/multiply.hpp # stan/math/prim/fun/multiply.hpp # stan/math/rev/fun/multiply.hpp
stan/math/fwd/fun/multiply.hpp
Outdated
template <typename Mat1, typename Mat2, | ||
typename = require_eigen_vt<is_fvar, Mat1>, | ||
typename = require_eigen_vt<std::is_floating_point, Mat2>, | ||
typename = require_not_eigen_row_and_col_t<Mat1, Mat2>, int = 0> | ||
inline auto multiply(const Mat1& m1, const Mat2& m2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to do these multiplies differently. This and the ones in prim are pretty confusing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
template <typename Mat, typename Scal, typename = require_eigen_t<Mat>, | ||
typename = require_stan_scalar_t<Scal>, | ||
typename = require_all_not_var_t<scalar_type_t<Mat>, Scal>> | ||
inline auto divide(const Mat& m, Scal c) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this so that divide works for fvar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one works for fvar. I just generalized it the same way as other functions.
stan/math/prim/fun/multiply.hpp
Outdated
typename = require_any_not_same_t<double, value_type_t<Mat1>, | ||
value_type_t<Mat2>>, | ||
typename = require_not_eigen_row_and_col_t<Mat1, Mat2>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[sidenote]
We should probably have a require_*_vt
for checking the value types of any arbitrary container
const Eigen::Matrix<double, RB, CB> &B) { | ||
template <typename TD, typename TA, typename TB, | ||
typename = require_all_eigen_t<TD, TA, TB>, | ||
typename = require_all_same_vt<double, TD, TA, TB>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we not do this for all arithmetics?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right. Done.
I looked this over and like a lot of it! But the multiply needs cleaned up. The templates there are v complicated, the requires stuff is pretty complicated to read for those. Do you have any ideas on how else to do those? I can look at this again on Thursday and see if I can hack up something else |
If @syclik or @bob-carpenter sign off on the multiply templates then I'm fine with it but I'm not sure about them |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
@SteveBronder Now everything you mentioned is fixed, including multiply templates, and tests passed so this is ready for next review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
Summary
Generalizes function signatures of functions that act kind of like operators (
add
,subtract
,multiply
,divide
,elt_multiply
,elt_divide
). That is they can now accept and return general Eigen expressions.EDIT: I also had to generalize functions that were called directely by return value of one of generalized functions. Those were:
trace
,trace_gen_quad_form
,transpose
andmatrix_exp
.In many cases I replaced function specializations with overloads. This results in more
require
s, but it also allowed me to replace multiple specializations with one general overload, reducing total amount of code.There are also some changes to autodiff testing framework that allow it to work with Eigen expressions. I did some of those in earlier PR #1471, but I missed some places.
Tests
This is mostly a refactor so no new tests.
Side Effects
None.
Checklist
Math issue Generalize matrix function signatures #1470
Copyright holder: Tadej Ciglarič
The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
the basic tests are passing
./runTests.py test/unit
)make test-headers
)make doxygen
)make cpplint
)the code is written in idiomatic C++ and changes are documented in the doxygen
the new changes are tested