-
-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closures #2384
Closures #2384
Conversation
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
A quick example: vector[N] v = rep_vector(1.0, N);
real myfunc(row_vector s) {
return s * v;
} The C++ type of template<bool Ref>
class myfunc_type {
capture_type_t<VectorXd, Ref> v;
public:
auto operator()(const auto& s) {
return s * v;
}
// ...
// + methods for accessing autodiff stack of the private member v
} When the closure is first created the template type class myfunc_type {
const VectorXd& v; // capture_type_t<VectorXd, true> That's efficient but means the closure object cannot outlive the scope in which class myfunc_type {
VectorXd v; // capture_type_t<VectorXd, false> |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left a few questions/comments
inline void elementwise_check(const F& is_good, const char* function, | ||
const char* name, const T& x, const char* must_be, | ||
const Indexings&... indexings) { | ||
// XXX skip closures |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. We'll need to implement these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, thinking more, I don't think there's a big advantage to implementing this.
We could expose the variables captured by a closure to checks, but the Math checks wouldn't know in what order its getting them, and then depending on which function was accepting closures it would need to decide which checks to do on which inputs.
I think instead in the ODE solves we check only the arguments passed in explicitly (which this is effectively doing) or we get rid of the infinity checks on the inputs to the ODE solves. I'll make an issue and see if getting rid of the checks altogether is an option. (Edit: Issue #2406)
stan/math/prim/fun/value_of.hpp
Outdated
template <typename F, require_stan_closure_t<F>* = nullptr, | ||
require_not_st_arithmetic<F>* = nullptr> | ||
inline auto value_of(const F& f) { | ||
return f.value_of__(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh wow I vaguely remember this. Is this used anywhere yet?
} | ||
|
||
template <bool Propto, typename F, bool Ref> | ||
struct lpdf_wrapper { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is lpdf_wrapper
used for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So you know how one can do
real func_lpdf(real[] slice, ...) {
...
}
target += reduce_sum(func_lpdf, ...);
target += reduce_sum(func_lupdf, ...);
and both reduce_sum
calls take the same closure object as the first argument so you should be asking yourself, how does the closure remember if it's lpdf
or lupdf
?
The answer is that it's wrapped in lpdf_wrapper
right before calling reduce_sum
. Propto=true
means lupdf
and Propto=false
means lpdf
.
template <typename F> | ||
auto lp_from_lambda(const F& f) { | ||
return empty_closure_lp<F>(f); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So for each kind of function we might have in Stan, we can also have a closure version of that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Each function kind follows a different calling convention so each kind needs its own adapter closure. These aren't used in math library but stanc3 allows userdefined higher order functions that might need them.
* @ingroup type_trait | ||
*/ | ||
template <typename F, typename... Args> | ||
using fn_return_type_t |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason this isn't part of return_type_t
? One less template function to remember would be nice, but if there's a reason to keep it separate that's good too.
using type | ||
= scalar_lub_t<scalar_type_t<T>, typename return_type<Ts...>::type>; | ||
using type = scalar_lub_t<scalar_type_t<T>, | ||
typename return_type<scalar_type_t<Ts>...>::type>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where'd the extra scalar_type_t
come from here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a very good question.
If I understand it correctly return_type_t
already applies scalar_type_t
to each argument due to the recursive definition and scalar_type_t
is idempotent so the extra one does nothing. You can revert this and compile reduce_sum_closure_test
to see what happens. (Spoiler alert: I do not understand it correctly.)
* | ||
* @tparam F A closure type | ||
* @param f A closure of vars | ||
* @return A new std::vector of vars |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's returning a new closure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bad copy-paste.
res[1][1].grad(); | ||
EXPECT_FLOAT_EQ(t0v.adj(), -0.38494826636037426937); | ||
stan::math::set_zero_all_adjoints(); | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will need to check the adjoint of a
in here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, these tests are quite incomplete. The function doesn't even do anything with a
so it only tests that autodiffable closures compile. How do people come up with good test cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably it will be a depressing amount of copy-paste. I think we can do something like:
y' = -recursive_sum(captured_args...) * y;
where recursive_sum
just keeps summing everything until out pops a scalar. And then we can test that this basic example works for capturing at least a scalar, a matrix, and array, and then a couple combinations of the above.
And similarly we'll add tests for reduce_sum
and whatnot. Let's only do the boring tests once everything else is in place though (integrate + algebra_solver tests). I will help with those. Only need enough tests now to convince ourselves things can work (so just a couple are fine).
I think we already have templates that do this for existing types in the math library. The general recursive template pattern is also used for the vectorized form of scalar real functions. |
|
||
template <typename... Args> | ||
auto operator()(std::ostream* msgs, Args... args) const { | ||
return f_(s_, args..., msgs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having functions where the variadic arguments aren't last makes writing the C++ painful (cause no type deduction). Could it go captured_args..., msgs, args...
?
The PR looks good to me but I'm not very familiar with If you're looking into the algebra solver then I'm pinging @charlesm93 because the last post in the ancient Discourse thread is him saying he'd do it. Not sure if he's still up to it but anyway. And I think |
Yeah, this looks right. I'd have to dig in more to the MPI to know. As the caching is implemented now can lead to bugs. The current map_rect takes a function:
So if we took a closure, So
And so what you have looks right to me. I'd have to think more closely to know for sure though. |
I think the routine would be define another Regular Regardless, we would need to build a special |
Also there are some weird macros floating around that I'm not sure about: https://github.com/stan-dev/math/blob/develop/stan/math/prim/functor/map_rect_mpi.hpp#L48 |
Is this still being worked on? Happy to take on the review if so! |
I haven't resolved the merge conflicts here because I figured everyone is too busy with the adjoint ODE stuff anyway. I do have a branch that is up to date with it feature/adjoint-ode...nhuurre:closures-adjoint-ode (huge diff because the adjoint ODE PR hasn't merged in develop for a while, I think) IIRC I found some benchmark model in the adjoint ode discussion and tried running it with no changes to Stan code, so basically an empty closure, it was like 10% slower which was pretty depressing because I have no idea where to start debugging that. |
Don‘t wait for adjoint ODE...I am not sure how busy this is keeping others than myself. However, I will merge develop into the adjoint branch rather soon, since now the ODE testing branch is in develop. |
FYI: the adjoint ODE branch is now up to date with develop |
Your code seems to test a closure on |
Ok, I see. So I get:
so it looks like the closure stuff is taking off a little bit. |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay! Got a few things, like names etc but overall looks good! I made some changes in the branch below that handle references and things a bit nicer and removes fn_return_type
if you want to merge it in. I only ran
./runTests.py -j28 ./test/unit/math/ -f closure
though those seem to pass
nhuurre/math@feature/closures-v2...stan-dev:review1/closures
/** | ||
* A closure that wraps a C++ lambda and captures values. | ||
*/ | ||
template <bool Ref, typename F, typename... Ts> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This all needs docs for template parameters etc.
return apply([this, msgs, &args...]( | ||
const auto&... s) { return f_(s..., args..., msgs); }, | ||
captures_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[optional] When capturing this
I prefer to use this->
before the member functions so they are a bit easier to see
using captured_scalar_t__ = return_type_t<Ts...>; | ||
using ValueOf__ | ||
= base_closure<false, F, decltype(eval(value_of(std::declval<Ts>())))...>; | ||
using CopyOf__ = base_closure<false, F, Ts...>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to have docs for what these are as their definitions also change across the different types of closures
|
||
template <typename... Args> | ||
auto operator()(std::ostream* msgs, const Args&... args) const { | ||
return apply([this, msgs, &args...]( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[optional] I like having things captured by reference at the front and then things copied after
template <bool propto> | ||
auto with_propto() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use camelcase for template parameters. How is this propto different from Propto
in the template parameters of the class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this for like lupdf or something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's an lupdf
hack. This normalization nonsense is rather confusing so here's an example:
parameters {
real y[100];
}
model {
function
real higher_lpdf(real[] x, real(real[], int, int) f_lpdf) {
real lp = 0;
lp += reduce_sum(f_lpdf, x, 1); // <-- A
lp += reduce_sum(f_lupdf, x, 1); // <-- B
return lp;
}
function
real partial_lpdf(real[] x, int s, int e) {
return std_normal_lupdf(x|);
}
target += higher_lpdf( y| partial_lpdf); // <-- 1
target += higher_lupdf(y| partial_lupdf); // <-- 2
}
Using lpdf
instead of lupdf
anywhere makes the return value normalized.
In the example reduce_sum
is called four times, at (1A), (1B), (2A), (2B), and only (2B) is unnormalized.
The above compiles to C++ that looks something like
auto higher_lpdf = from_lambda([&](auto f_lpdf) {
var lp = 0;
lp += reduce_sum(f_lpdf.with_propto<false>(), x, 1); // <-- A
lp += reduce_sum(f_lpdf.with_propto<true>(), x, 1); // <-- B
return lp;
});
auto partial_lpdf = from_lambda([]<bool propto>(auto x, int s, int e) {
return std_normal_lpdf<propto>(x);
});
lp_accum__.add(higher_lpdf(y, partial_lpdf.with_propto<false>()); // <-- 1
lp_accum__.add(higher_lpdf(y, partial_lpdf.with_propto<true>()); // <-- 2
Every time the closure object is passed to a higher-order function Propto
records if it's in its lpdf
or lupdf
form.
= base_closure<false, F, decltype(eval(value_of(std::declval<Ts>())))...>; | ||
using CopyOf__ = base_closure<false, F, Ts...>; | ||
F f_; | ||
std::tuple<capture_type_t<Ts, Ref>...> captures_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whats the higher level logic for ref? Aka why can't these always just be references?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value_of()
and deep_copy_vars()
cannot return references.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value_of()
can return references. But yeah deep_copy_vars()
cannot. How about something like this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also could you add some tests and docs for all these? It would help me understand what your design goal is and how things should work
struct closure_lp { | ||
using captured_scalar_t__ = return_type_t<Ts...>; | ||
using ValueOf__ = closure_lp<Propto, true, F, Ts...>; | ||
using CopyOf__ = closure_lp<Propto, true, F, Ts...>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The C++ standard reserves double underscore for the compiler implementations (yes we do this at the stanc3 level but it's not good and we should not do it here)
template <typename F, require_stan_closure_t<F>* = nullptr, | ||
require_arithmetic_t<return_type_t<F>>* = nullptr> | ||
inline double integrate_1d(const F& f, double a, double b, | ||
const std::vector<double>& theta, | ||
const std::vector<double>& x_r, | ||
const std::vector<int>& x_i, std::ostream* msgs, | ||
const double relative_tolerance | ||
= std::sqrt(EPSILON)) { | ||
return integrate_1d_impl(integrate_1d_closure_adapter(), a, b, | ||
relative_tolerance, msgs, f, theta, x_r, x_i); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Q] Think I'm just missing some context, why is f
being passed as an argument here? instead of being pushed to integrate_1d_adapter
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
integrate_1d_impl
already knows how to handle the autodiff variables passed as arguments but would need new logic for extracting them from the functor.
|
||
std::vector<std::vector<return_type_t<T_y0, T_param, T_t0, T_ts>>> | ||
std::vector<std::vector<fn_return_type_t<F, T_y0, T_param, T_t0, T_ts>>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we just have the logic in return_type_t
for handling closures?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I did have some logic for handling closures in return_type_t
but for backwards compatibility some places need logic for also handling arbitrary C++ lambdas. I created fn_return_type_t
because I wasn't sure if adding a catchall case for "anything that isn't a known type" would have some undesirable side effects. Your branch seems to have resolved that issue.
template <typename T> | ||
struct capture_type<T, false, | ||
require_stan_closure_t<std::remove_reference_t<T>>> { | ||
using type = typename std::remove_reference_t<T>::CopyOf__; | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd make this stan_lambda_capture_type
or something of the sort
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
@nhuurre sorry for not coming back to this, is this ready to be looked at again? |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
@nhuurre and @SteveBronder, I'm marking it as a draft; there hasn't been a response in 28 days. |
Closing this PR for now. @nhuurre, if this is still active, please reopen. |
Summary
Issue #2197
Extends the following functions to support closure-like objects:
That should be enough for any variadic higher-order function to support closures almost automatically.
The stanc3 PR exposing this in the language is stan-dev/stanc3#742
Tests
I added a couple of tests for
integrate_ode_rk45
,ode_rk45
, andreduce_sum
.Side Effects
I don't think so.
Release notes
Implement closures.
Checklist
Math issue Implement closures #2197
Copyright holder: Niko Huurre
The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
the basic tests are passing
./runTests.py test/unit
)make test-headers
)make test-math-dependencies
)make doxygen
)make cpplint
)the code is written in idiomatic C++ and changes are documented in the doxygen
the new changes are tested