Cross Attention does not work on CPU and older GPUs #6
Comments
I created a monkey patch for CrossAttention, which replaces the FlexAttention part with standard SDPA. This seems to fix the problem, though it would be nice to integrate it upstream somehow. (Note: this particular patch seems to have an unchecked memory leak; I have no idea why cross attention is causing that here; something about …)
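For reference, a minimal sketch of what such a patch could look like, swapping in `torch.nn.functional.scaled_dot_product_attention` for FlexAttention. The projection and attribute names (`wq`, `wk`, `wv`, `wo`, `n_heads`, `head_dim`, `cross_attn_norm_kv`) are placeholders based on a common transformer layout, not the repo's actual code:

```python
import torch
import torch.nn.functional as F

def sdpa_cross_attention_forward(self, x, kv, mask=None) -> torch.Tensor:
    # x: (B, S_q, D) queries, kv: (B, S_kv, D) keys/values
    bsz, seq_len, _ = x.shape
    _, slen_kv, _ = kv.shape

    x = self.cross_attn_norm_q(x)
    kv = self.cross_attn_norm_kv(kv)  # hypothetical KV norm, mirroring the query norm

    # Project and split into heads: (B, H, S, head_dim)
    q = self.wq(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
    k = self.wk(kv).view(bsz, slen_kv, self.n_heads, self.head_dim).transpose(1, 2)
    v = self.wv(kv).view(bsz, slen_kv, self.n_heads, self.head_dim).transpose(1, 2)

    # Standard SDPA in place of flex_attention; `mask` is assumed to be an
    # attention mask broadcastable to (B, H, S_q, S_kv), or None.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    out = out.transpose(1, 2).reshape(bsz, seq_len, -1)
    return self.wo(out)

# Monkey patch (class path is hypothetical):
# CrossAttention.forward = sdpa_cross_attention_forward
```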
After more investigation, I found that we can break the unchecked memory growth by detaching `x` at the top of the forward pass:

```python
def forward(self, x, kv, mask=None) -> torch.Tensor:
    # B S D
    bsz, seq_len, _ = x.shape
    _, slen_kv, _ = kv.shape
    x = x.detach()
    x = self.cross_attn_norm_q(x)
```

I don't know how much of a "solution" this really is, but it does indicate that the problem relates to gradient growth. Again, this only happens when using cross attention in the encoder, not the decoder. So I'm thinking there must be a huge gradient matrix here, which comes from comparing inputs via cross attention, and somehow FlexAttention makes that a lot more efficient.
We ran our experiments using FlexAttention and didn't test any other variants. I'm not sure whether you will reproduce the results with SDPA, but it's worth a shot.
The current CrossAttention code has a hardcoded dependency on FlexAttention. This is a problem for people like me, who need to use older, consumer GPUs.
I'm not sure of the feasibility, but it would be great if we could fall back to eager execution where FlexAttention isn't available (see the fallback sketch below).
On CPU, FlexAttention doesn't work at all.
On older GPUs, it fails as well.
If this could be easily fixed, I would greatly appreciate a patch! But if not, I'll likely be looking to implement this myself.
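One possible shape for such a fallback is a small backend check at module init or call time. A rough sketch, assuming FlexAttention is only usable on CUDA devices with compute capability >= 8.0 (that cutoff is a guess, not a documented requirement) and that an SDPA path like the one above exists:

```python
import torch
import torch.nn.functional as F

def pick_attention_impl(device: torch.device) -> str:
    """Return "flex" when FlexAttention is likely supported, else "sdpa"."""
    if device.type == "cuda":
        major, _ = torch.cuda.get_device_capability(device)
        if major >= 8:
            return "flex"
    return "sdpa"

def cross_attend(q, k, v, mask=None, block_mask=None, impl="sdpa"):
    # q, k, v: (B, H, S, head_dim)
    if impl == "flex":
        from torch.nn.attention.flex_attention import flex_attention
        return flex_attention(q, k, v, block_mask=block_mask)
    # CPU / older-GPU fallback via standard SDPA.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

This would keep FlexAttention as the default where it works and drop to SDPA on CPU and older GPUs.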