XLA loops without a kernel launch on each iteration #16186
carlosgmartin started this conversation in General
Copied from here, since the XLA repo might be a better place to discuss this.
My understanding is that, currently, each iteration of jax.lax.scan requires a kernel launch on GPU backends. This per-iteration launch overhead causes an appreciable performance penalty. For context, consider the following comments:
July 2020:
March 2021:
July 2021:
May 2023:
September 2023:
February 2024:
My question is this: Is this a fundamental limitation of XLA and/or GPU hardware? Can it be resolved? The first two comments above suggest it's possible. If so, is this currently being discussed or worked on somewhere?
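The following is a minimal sketch that makes the setup concrete. It is not taken from the linked discussions. It runs a jax.lax.scan with a deliberately tiny body and inspects the lowered StableHLO program, which expresses the scan as a while loop. It also tries the existing `unroll` argument of jax.lax.scan, which puts several copies of the body into each loop trip and so reduces the number of trips. Assuming launches are incurred per trip, that means fewer launches, but it is a mitigation, not a removal of the per-iteration cost the question asks about. The `step` function, the 0.99 coefficient, the length 1000, and `unroll=8` are arbitrary illustrative choices. No timings are claimed, since they depend on the backend and hardware.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Deliberately tiny loop body: one multiply-add per iteration, so any
    # fixed per-iteration cost (e.g. a kernel launch) dominates the runtime.
    new_carry = carry * 0.99 + x
    return new_carry, new_carry


@jax.jit
def run_scan(xs):
    # Default scan: the underlying while loop makes one trip per element of xs.
    return lax.scan(step, jnp.float32(0.0), xs)


@jax.jit
def run_scan_unrolled(xs):
    # unroll=8 places 8 copies of the body in each loop trip, so the while
    # loop makes roughly len(xs) / 8 trips for the same computation.
    return lax.scan(step, jnp.float32(0.0), xs, unroll=8)


def time_call(fn, xs, repeats=10):
    # Compile and warm up once, then time steady-state calls.
    # block_until_ready() waits for the asynchronous dispatch to finish.
    jax.block_until_ready(fn(xs))
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(xs))
    return (time.perf_counter() - start) / repeats


xs = jnp.arange(1000, dtype=jnp.float32)
final, ys = run_scan(xs)
final_u, ys_u = run_scan_unrolled(xs)

# The lowered program shows that scan is expressed as a while loop.
hlo_text = run_scan.lower(xs).as_text()
print("while loop in lowered program:", "while" in hlo_text)
print(f"scan:          {time_call(run_scan, xs) * 1e6:.1f} us/call")
print(f"scan unroll=8: {time_call(run_scan_unrolled, xs) * 1e6:.1f} us/call")
```

Both versions compute the same result; only the number of loop trips differs. Comparing the two timings on a GPU backend gives a rough sense of how much of the runtime is per-trip overhead rather than actual work.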