Optimizing Shared Memory Usage #4756
Hi. I am working on writing a Triton kernel for the backward pass of a sub-quadratic attention architecture. Currently, I'm receiving the following error when compiling the kernel:
The operations involved in the kernel are complex, and I have many loads and intermediate variables created during the derivation. I had a few questions on the SRAM usage inside the kernel:
1. Does the placement of the `tl.load` calls matter, or is Triton smart enough to compile it into the most memory-optimal form? I.e., can I `tl.load` all required variables at the beginning and expect the same memory usage as if we `tl.load` them right before the operation they are involved in?
2. If I `tl.store` and then `tl.load` the same data in the same kernel, will this force Triton to write it out to HBM and then reload it from HBM?
3. If I load a variable, `x1 = tl.load(ptr)`, and then later load another variable into it, `x1 = tl.load(ptr2)`, will this overwrite the memory in SRAM?

Note: I'm using a simple grid of shape [Batch, Heads] (like Flash Attention). I don't think blocks or num_stages is relevant here.
I'm also happy to share the kernel code, if needed. Hopefully there's some way I can re-arrange operations and evict from SRAM to optimize usage.
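To make the questions concrete, here is a minimal, self-contained sketch of the patterns I'm asking about. This is not the real kernel: the pointer names, strides, and block sizes are placeholders, and it assumes contiguous [Batch, Heads, M, D] tensors with a single tile per program.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _bwd_sketch(q_ptr, k_ptr, dq_ptr,
                stride_b, stride_h,
                BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    # One program per (batch, head), mirroring a Flash Attention style grid.
    off_b = tl.program_id(0)
    off_h = tl.program_id(1)
    base = off_b * stride_b + off_h * stride_h

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)
    offs = base + offs_m[:, None] * BLOCK_D + offs_d[None, :]

    # Question 1: load everything up front, even though each value is only
    # needed later. Is this equivalent (in SRAM usage) to loading lazily?
    q = tl.load(q_ptr + offs)
    k = tl.load(k_ptr + offs)
    s = q * k  # stand-in for a chain of intermediate values

    # Question 3: re-bind the same name to a new load. Is the first tile's
    # memory reused, or do both tiles stay resident?
    q = tl.load(k_ptr + offs)

    # Question 2: store an intermediate and load it back in the same kernel.
    # Does this round-trip through HBM, or can it stay on-chip?
    tl.store(dq_ptr + offs, s)
    s = tl.load(dq_ptr + offs)

    tl.store(dq_ptr + offs, s + q)


if __name__ == "__main__":
    # Requires a CUDA device; shapes here are arbitrary placeholders.
    B, H, M, D = 2, 4, 64, 64
    q = torch.randn(B, H, M, D, device="cuda")
    k = torch.randn_like(q)
    dq = torch.empty_like(q)
    _bwd_sketch[(B, H)](q, k, dq,
                        q.stride(0), q.stride(1),
                        BLOCK_M=M, BLOCK_D=D)
```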
Any progress? I am quite interested in this