Merged 2D and 3D kernels in Warp #75
Conversation
This looks really nice, but I think:
- we should also remove the kernels and warp_implementation in all operators (except the stepper, of course) so that the operators are all consistent (the same logic for removing BC kernels applies to the other operators).
- "_construct_warp" was only ever needed because of the kernels. Now, without kernels, we can also get rid of that closure and just register the Warp functional as the __call__ method, so that in the stepper we call the operators by name just like the JAX implementation (a sketch of this calling convention follows below).
- Optional: instead of removing all kernels altogether, we could just write them where they are actually being used right now (only inside the pytest functions).
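For illustration only, here is a minimal pure-Python sketch of the calling convention the second point asks for. The names (Operator, Equilibrium, Stepper, implementation) are hypothetical, not taken from the XLB code, and the replies below explain why Warp's explicit-input requirement complicates doing this with wp.func directly.

```python
# Hypothetical sketch: operators expose their backend implementation through
# __call__, so the stepper composes them by plain calls, JAX-style.
class Operator:
    def __call__(self, *args):
        # Dispatch to whatever implementation the subclass defined.
        return self.implementation(*args)


class Equilibrium(Operator):
    def implementation(self, rho, u):
        # Placeholder equilibrium; a real one would use lattice weights.
        return [rho * (1.0 + ui) for ui in u]


class Stepper(Operator):
    def __init__(self, equilibrium):
        self.equilibrium = equilibrium

    def implementation(self, rho, u):
        # The stepper just calls the operator, with no kernel bookkeeping here.
        return self.equilibrium(rho, u)


if __name__ == "__main__":
    stepper = Stepper(Equilibrium())
    print(stepper(1.0, [0.1, 0.0, 0.0]))  # [1.1, 1.0, 1.0]
```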
Both points are not possible in Warp (inputs of functions must be explicit). You can try with a dummy class, and if you can create a simple repro here I'll implement them. Otherwise I suggest merging this for the time being.
There are also other issues beyond that. You cannot remove _construct_warp, as auxiliary data is needed for the functional (and placing it in __init__ is wrong for several reasons, since you don't remove the memory allocations).
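For context, a hedged sketch of the closure pattern being defended here. The names (construct_warp_like, functional, kernel) are illustrative rather than the actual XLB _construct_warp, and it assumes Warp's usual resolution of closure-captured constants at code generation: auxiliary data is baked into the functional/kernel when they are constructed instead of being stored or allocated in __init__.

```python
import warp as wp

wp.init()


def construct_warp_like(w0):
    """Hypothetical stand-in for a _construct_warp-style closure (illustrative
    names): the auxiliary datum w0 (e.g. a lattice weight) is captured by the
    closure when the functional/kernel are built, not held by __init__."""

    @wp.func
    def functional(rho: float) -> float:
        # Inputs are explicit and typed; w0 comes from the closure.
        return w0 * rho

    @wp.kernel
    def kernel(rho: wp.array(dtype=float), out: wp.array(dtype=float)):
        i = wp.tid()
        out[i] = functional(rho[i])

    return functional, kernel


functional, kernel = construct_warp_like(4.0 / 9.0)
rho = wp.array([1.0, 2.0], dtype=float)
out = wp.zeros(2, dtype=float)
wp.launch(kernel, dim=2, inputs=[rho, out])
print(out.numpy())  # approximately [0.444, 0.889]
```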
I understand why (2) might be tricky or maybe not possible (?), but I don't understand why request (1) above is not possible. None of the operator kernels are ever used inside the code.
I suggest sending another pull request if you believe (2) is doable. That requires rewriting the operators and doesn't fall within the scope of this PR. The same logic does not apply to all kernels, as you may want to differentiate a kernel (the tape can only differentiate a kernel with respect to a scalar, and if you only want to differentiate part of the full kernel you have to launch it separately). Some kernels, such as equilibrium, are used outside as well.
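For context on the differentiation point, a standalone Warp sketch (not XLB code; squared_sum is an illustrative kernel): wp.Tape records kernel launches and backpropagates from a scalar loss held in a one-element array, which is why a piece you want gradients for needs to be its own kernel launch rather than fused into a larger one.

```python
import warp as wp

wp.init()


@wp.kernel
def squared_sum(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
    i = wp.tid()
    wp.atomic_add(loss, 0, x[i] * x[i])


x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
loss = wp.zeros(1, dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
    # Only launches recorded on the tape are differentiated; anything fused
    # into a bigger kernel is differentiated (or not) as a whole.
    wp.launch(squared_sum, dim=x.shape[0], inputs=[x, loss])

tape.backward(loss=loss)   # gradient of the scalar loss w.r.t. the inputs
print(x.grad.numpy())      # [2.0, 4.0, 6.0]
```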
Registration of the different backends would also have to be rewritten in that case. As I said, I'll bring the BC kernels back later. A single kernel for LBM is only doable for the simplest LBM problems (and when register pressure is not large). We will have to bring all of them back (for the same "consistency") the moment we implement anything more than the most basic simulations.
For example, for moving BCs, even though combining kernels is technically feasible, register pressure may force us to split them into multiple kernels.
The only kernels we don't use outside of the stepper are collision and streaming (and I can certainly see streaming being used for other operations outside the stepper later).
Why not bring back the BC kernels (perhaps in a more abstracted way) now, in this PR, which is about exactly this idea?
Another week (or more) of delays, but fine.
OK, this is fixed now and the BC kernels are back in an efficient manner. Note that warp_implementation cannot be moved to the parent class, as it has to be registered per subclass.
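A hedged sketch of why the registration is per subclass, with illustrative names (the registry and decorator are not the actual XLB mechanism): the backend table is keyed on the concrete operator class, so a single registration on the parent would give every subclass the same entry.

```python
# Hypothetical registry keyed on the concrete operator class.
_WARP_IMPLEMENTATIONS = {}


def register_warp_implementation(cls):
    """Class decorator: each subclass registers its own warp_implementation."""
    _WARP_IMPLEMENTATIONS[cls] = cls.warp_implementation
    return cls


class Operator:
    def __call__(self, *args):
        # Dispatch on the concrete type; registering only on Operator would
        # make every subclass resolve to the same implementation.
        return _WARP_IMPLEMENTATIONS[type(self)](self, *args)


@register_warp_implementation
class Collision(Operator):
    def warp_implementation(self, f):
        return "collision kernel launched on " + f


@register_warp_implementation
class Streaming(Operator):
    def warp_implementation(self, f):
        return "streaming kernel launched on " + f


print(Collision()("f_0"))   # collision kernel launched on f_0
print(Streaming()("f_0"))   # streaming kernel launched on f_0
```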
Perfect 💯 |
129 additions and 1,265 deletions