Enable parthenon::par_reduce for MD loops with Kokkos 1D Range #1130
This is a nice fix. Thanks for implementing it!
LGTM, thanks for the fix!
src/kokkos_abstraction.hpp
Outdated
template <typename Function, typename R, typename T, typename Index, typename... FArgs>
class FlatFunctor<Function, R (T::*)(Index, Index, Index, FArgs...) const> {
  int NjNi, Nj, Ni, kl, jl, il;
  Function function;

 public:
  FlatFunctor(const Function _function, const int _NjNi, const int _Nj, const int _Ni,
              const int _kl, const int _jl, const int _il)
      : function(_function), NjNi(_NjNi), Nj(_Nj), Ni(_Ni), kl(_kl), jl(_jl), il(_il) {}
  KOKKOS_INLINE_FUNCTION
  void operator()(const int &idx, FArgs &&...fargs) const {
    int k = idx / NjNi;
    int j = (idx - k * NjNi) / Ni;
    int i = idx - k * NjNi - j * Ni;
    k += kl;
    j += jl;
    i += il;
    function(k, j, i, std::forward<FArgs>(fargs)...);
  }
};
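For readers following the index arithmetic in `operator()`, here is a minimal host-only sketch of the same flat-to-(k, j, i) decomposition, with no Kokkos dependency. The names `Unflatten3D` and `MatchesNestedLoops` are hypothetical and not part of Parthenon; the sketch assumes inclusive upper bounds and that `i` is the fastest-varying index.

```cpp
#include <array>

// Hypothetical helper mirroring the arithmetic in FlatFunctor::operator().
// Nj and Ni are the extents of the j and i ranges; kl, jl, il are the lower bounds.
inline std::array<int, 3> Unflatten3D(int idx, int Nj, int Ni, int kl, int jl, int il) {
  const int NjNi = Nj * Ni;
  const int k = idx / NjNi;
  const int j = (idx - k * NjNi) / Ni;
  const int i = idx - k * NjNi - j * Ni;
  return {k + kl, j + jl, i + il};
}

// Checks that walking the flat index 0..Nk*Nj*Ni-1 visits the same (k, j, i)
// triples, in the same order, as three nested loops with inclusive bounds.
inline bool MatchesNestedLoops(int kl, int ku, int jl, int ju, int il, int iu) {
  const int Nj = ju - jl + 1;
  const int Ni = iu - il + 1;
  int idx = 0;
  for (int k = kl; k <= ku; ++k) {
    for (int j = jl; j <= ju; ++j) {
      for (int i = il; i <= iu; ++i) {
        if (Unflatten3D(idx, Nj, Ni, kl, jl, il) != std::array<int, 3>{k, j, i}) {
          return false;
        }
        ++idx;
      }
    }
  }
  return true;
}
```

Calling `MatchesNestedLoops(2, 4, 1, 3, 0, 5)` is one way to convince yourself that the 1D range covers the same cells as the nested loops.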
Probably not for this PR, but I wonder if we could make this work for an arbitrary-dimensional index space using IndexRange in utils/ with a little bit of template magic.
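As a rough illustration of that idea (not the code behind the godbolt links in this thread), the decomposition can be written once for any number of dimensions by peeling off the fastest-varying index first. `Range` below is a hypothetical stand-in for an index range with inclusive bounds `s` and `e`, and `UnflattenND` is an invented name; neither is taken from Parthenon.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for an index range; bounds are assumed inclusive.
struct Range {
  int s;
  int e;
};

// Converts a flat index into N indices, treating the last range as the
// fastest-varying one (same ordering as nested loops with it innermost).
template <std::size_t N>
std::array<int, N> UnflattenND(int idx, const std::array<Range, N> &ranges) {
  std::array<int, N> out{};
  for (std::size_t d = N; d-- > 0;) {
    const int n = ranges[d].e - ranges[d].s + 1;  // extent of dimension d
    out[d] = ranges[d].s + idx % n;               // offset within dimension d
    idx /= n;                                     // carry on to the next-slower dimension
  }
  return out;
}
```

With three ranges this reproduces the (k, j, i) ordering of the 3D functor; handling the extra reduction arguments would still need the signature deduction discussed in this PR.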
I think it could work. This seems to work and at least gets part way there https://godbolt.org/z/fbYP5vW6r
This (https://godbolt.org/z/n48951jGK) seems to work, but I didn't explicitly pull out the function signature like you did (which I think matters for references).
Yeah, I remember CUDA builds had given me trouble with that. This one uses some more template helper structs to also handle IndexRanges in the constructor. I imagine you could use similar things to handle all the par_dispatch overloads as well.
Saved for future reference here #1134
Head branch was pushed to by a user without write access
Thanks for the PR!
Out of curiosity, was there a specific use and/or performance consideration that inspired these changes?
Not any performance consideration. I was wanting to use
All tests pass. I'm force merging.
PR Summary
Replaces KOKKOS_LAMBDAs with templated functors in par_dispatch(LoopPatternFlatRange, ...) that deduce the signature of the provided function to allow for extra arguments used in reductions. Makes it possible to use par_reduce(DEFAULT_LOOP_PATTERN, ...) & par_reduce(name, ...)
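To make the signature deduction concrete, here is a minimal sketch of the technique without Kokkos. A partial specialization on the type of `&F::operator()` recovers any trailing arguments after the three indices, so the wrapper can accept and pass on a reduction accumulator. `FlatWrapper` and `SumFlat` are hypothetical names, the plain `for` loop only stands in for a 1D parallel reduction, and lower bounds are assumed to be zero for brevity.

```cpp
#include <utility>

// Primary template: Sig defaults to the type of the callable's operator().
template <typename F, typename Sig = decltype(&F::operator())>
class FlatWrapper;

// Specialization that matches a const operator() taking three indices
// followed by any extra arguments (e.g. a reduction accumulator).
template <typename F, typename R, typename T, typename Index, typename... FArgs>
class FlatWrapper<F, R (T::*)(Index, Index, Index, FArgs...) const> {
  F f_;
  int nj_, ni_;

 public:
  FlatWrapper(F f, int nj, int ni) : f_(f), nj_(nj), ni_(ni) {}

  // Takes a flat index plus the deduced extra arguments and forwards them.
  void operator()(int idx, FArgs... fargs) const {
    const int k = idx / (nj_ * ni_);
    const int j = (idx / ni_) % nj_;
    const int i = idx % ni_;
    f_(k, j, i, std::forward<FArgs>(fargs)...);
  }
};

// Sequential stand-in for a 1D reduction over nk * nj * ni cells.
inline double SumFlat(int nk, int nj, int ni) {
  double total = 0.0;
  auto body = [](int k, int j, int i, double &acc) { acc += 100 * k + 10 * j + i; };
  FlatWrapper<decltype(body)> wrapped(body, nj, ni);
  for (int idx = 0; idx < nk * nj * ni; ++idx) {
    wrapped(idx, total);  // total binds to the deduced double& parameter
  }
  return total;
}
```

The same wrapper also accepts a body with no extra arguments, in which case `FArgs` is empty and the call operator takes only the flat index.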
PR Checklist