[Issue]: RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false. #41

Open
donglixp opened this issue Feb 3, 2024 · 15 comments

@donglixp

donglixp commented Feb 3, 2024

Problem Description

  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 393, in backward
    loss.backward()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 311, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 236, in backward
    _flash_attn_backward(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)

Operating System

Ubuntu 20.04.6 LTS (Focal Fossa)

CPU

AMD EPYC 7V12 64-Core Processor

GPU

AMD Instinct MI250X

ROCm Version

ROCm 6.0.0, ROCm 5.7.1

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

@donglixp
Author

donglixp commented Feb 4, 2024

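from einops import rearrange
from flash_attn import flash_attn_func

# q, k, v arrive as (bsz * nheads, seqlen, headdim); flash_attn_func expects
# (batch, seqlen, nheads, headdim), so split the fused batch/head axis.
q = rearrange(q, '(b h) l d -> b l h d', b=bsz).contiguous()
k = rearrange(k, '(b h) l d -> b l h d', b=bsz).contiguous()
v = rearrange(v, '(b h) l d -> b l h d', b=bsz).contiguous()
print(q.shape)
print(k.shape)
print(v.shape)
attn = flash_attn_func(q, k, v, causal=is_causal)
# Fuse batch and heads again for the rest of the model.
attn = rearrange(attn, 'b l h d -> (b h) l d')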
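To make the layouts explicit, here is a shape-only sketch of the two rearrange calls. It assumes q, k and v enter this block as (bsz * nheads, seqlen, headdim) tensors with bsz = 2 and 48 heads, which reproduces the printed torch.Size([2, 2048, 48, 64]); flash_attn_func takes (batch, seqlen, nheads, headdim) inputs. `to_flash_layout` and `from_flash_layout` are hypothetical helpers for illustration only, not part of einops or flash_attn.

```python
# Shape-only sketch of the two rearrange calls (no tensors, no einops needed).
# Assumption: the incoming q/k/v have shape (bsz * nheads, seqlen, headdim).

def to_flash_layout(shape, bsz):
    """Mimic rearrange(x, '(b h) l d -> b l h d', b=bsz) on a shape tuple."""
    bh, seqlen, headdim = shape
    if bh % bsz != 0:
        raise ValueError(f"leading dim {bh} is not divisible by bsz={bsz}")
    return (bsz, seqlen, bh // bsz, headdim)


def from_flash_layout(shape):
    """Mimic rearrange(x, 'b l h d -> (b h) l d') on a shape tuple."""
    bsz, seqlen, nheads, headdim = shape
    return (bsz * nheads, seqlen, headdim)


qkv_shape = to_flash_layout((96, 2048, 64), bsz=2)
print(qkv_shape)                     # (2, 2048, 48, 64)
print(from_flash_layout(qkv_shape))  # (96, 2048, 64)
```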

The output (the shape prints, followed by the error):

torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
... (the same line repeats; 81 identical shape lines are printed in total before the traceback)
Traceback (most recent call last):
  File "/tmp/amlt-code-download/fairseq/train.py", line 14, in <module>
    cli_main()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 543, in cli_main
    distributed_utils.call_main(cfg, main)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 365, in call_main
    distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 339, in distributed_main
    main(cfg, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 191, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 307, in train
    log_output = trainer.train_step(samples)
  File "/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/trainer.py", line 850, in train_step
    raise e
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/trainer.py", line 818, in train_step
    loss, sample_size_i, logging_output = self.task.train_step(
  File "/tmp/amlt-code-download/fairseq/tasks/gpt.py", line 253, in train_step
    optimizer.backward(loss)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 393, in backward
    loss.backward()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 311, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 236, in backward
    _flash_attn_backward(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false.

@donglixp
Author

donglixp commented Feb 4, 2024

The log shows that the program runs fine for a number of steps before it triggers the bug, rather than failing on the first call.

@donglixp
Author

donglixp commented Feb 4, 2024

I tried both the ROCm 5.7 and ROCm 6.0 Docker images.

@donglixp
Author

donglixp commented Feb 4, 2024

The bug depends on the q/k/v shape, given as [batch, seqlen, nheads, headdim]:

[1, 2048, 48, 64]: works well

[2, 2048, 48, 64]: triggers the bug

[4, 2048, 48, 64]: triggers the bug

[1, 2048, 24, 128]: triggers the bug

[2, 2048, 24, 128]: triggers the bug

[2, 2048, 25, 128]: triggers the bug

[2, 2048, 24, 124]: works well

[2, 2048, 48, 62]: works well
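Grouping the reported shapes by head dimension (the last axis) makes the pattern easier to scan: every failing shape has headdim 64 or 128, the runs with headdim 62 and 124 pass, and the only passing headdim-64 run uses batch size 1. As a purely arithmetic observation, 64 and 128 are multiples of 8 while 62 and 124 are not. The small script below only re-tabulates the data reported above; it does not establish the cause.

```python
# Re-tabulate the reported outcomes by head dimension.
# Shapes are (batch, seqlen, nheads, headdim); data copied from the list above.
from collections import defaultdict

REPORTED = {
    (1, 2048, 48, 64): "works",
    (2, 2048, 48, 64): "fails",
    (4, 2048, 48, 64): "fails",
    (1, 2048, 24, 128): "fails",
    (2, 2048, 24, 128): "fails",
    (2, 2048, 25, 128): "fails",
    (2, 2048, 24, 124): "works",
    (2, 2048, 48, 62): "works",
}

# headdim -> list of (batch, nheads, outcome), in the order reported
by_headdim = defaultdict(list)
for (batch, _seqlen, nheads, headdim), outcome in REPORTED.items():
    by_headdim[headdim].append((batch, nheads, outcome))

for headdim in sorted(by_headdim):
    print(f"headdim={headdim}: {by_headdim[headdim]}")

# Head dimensions with at least one failure, and whether each is a multiple of 8.
failing_headdims = {shape[3] for shape, outcome in REPORTED.items() if outcome == "fails"}
print(sorted(failing_headdims), [hd % 8 == 0 for hd in sorted(failing_headdims)])
```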

@dejay-vu dejay-vu self-assigned this Feb 4, 2024
@dejay-vu

dejay-vu commented Feb 4, 2024

@donglixp Can I have the script you are running?

@donglixp
Author

donglixp commented Feb 5, 2024

@howiejayz
Host VM: ROCm 5.6.0
Docker image: rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1

Running the forward and backward passes with the following shapes:

[2, 2048, 48, 64]: triggers the bug

[4, 2048, 48, 64]: triggers the bug

[1, 2048, 24, 128]: triggers the bug

[2, 2048, 24, 128]: triggers the bug

[2, 2048, 25, 128]: triggers the bug

@donglixp
Author

donglixp commented Feb 5, 2024

q = rearrange(q, '(b h) l d -> b l h d', b=bsz).contiguous()
k = rearrange(k, '(b h) l d -> b l h d', b=bsz).contiguous()
v = rearrange(v, '(b h) l d -> b l h d', b=bsz).contiguous()
attn = flash_attn_func(q, k, v, causal=is_causal)
attn = rearrange(attn, 'b l h d -> (b h) l d')

@donglixp
Author

donglixp commented Feb 5, 2024

Although using [2, 2048, 48, 62] did not trigger the above error, the job ran into loss divergence, whereas a similar recipe previously trained successfully (with ROCm 5.4 on the VM and ROCm 5.7 in the Docker image).

@dejay-vu

dejay-vu commented Feb 5, 2024

The error seems to be triggered when dout is not contiguous. May I ask how you generate the dout that is passed to the backward?
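To make this hypothesis concrete: judging by its message, the failing check compares the sequence-dimension stride of dout with that of out. Below is a pure-Python stride sketch. It assumes out is a contiguous (b, l, h, d) tensor, and it assumes (this is not confirmed in this thread) that the gradient of the (b h) l d tensor produced by `rearrange(attn, 'b l h d -> (b h) l d')` is contiguous and reaches the flash_attn backward through a permute back to (b, l, h, d). Under that assumption dout keeps a last-dimension stride of 1, but its sequence stride is d rather than h * d. `contiguous_strides` and `permute` are hypothetical helpers that only mimic PyTorch's stride bookkeeping.

```python
# Stride arithmetic for the hypothesis that dout reaches flash_attn_cuda.bwd
# as a permuted view. Pure Python; strides are counted in elements.

def contiguous_strides(shape):
    """Row-major (C-contiguous) strides for a shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(shape, strides, dims):
    """Shape and strides after x.permute(*dims); a permute never copies data."""
    return tuple(shape[i] for i in dims), tuple(strides[i] for i in dims)


b, l, h, d = 2, 2048, 48, 64

# `out` from the forward pass: a contiguous (b, l, h, d) tensor.
out_strides = contiguous_strides((b, l, h, d))

# Assumed path for dout: a contiguous grad viewed as (b, h, l, d),
# then permuted back to (b, l, h, d).
grad_bhld = contiguous_strides((b, h, l, d))
dout_shape, dout_strides = permute((b, h, l, d), grad_bhld, (0, 2, 1, 3))

SEQ_DIM = 1  # seqlen is dim 1 in the (batch, seqlen, nheads, headdim) layout
print("out  seq stride:", out_strides[SEQ_DIM])   # 3072 (= h * d)
print("dout seq stride:", dout_strides[SEQ_DIM])  # 64 (= d)
print("dout last-dim stride:", dout_strides[-1])  # 1
```

Same shape, same last-dimension stride, different sequence stride: exactly the combination the error message complains about.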

@donglixp
Author

donglixp commented Feb 5, 2024

@howiejayz Yes, they were contiguous. Contiguity is also handled by this line:

dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
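Whether this line helps depends on the condition inside `maybe_contiguous`. A minimal sketch, assuming `maybe_contiguous` only copies when the last-dimension stride is not 1 (this matches the upstream flash-attention 2.x definition as far as we know; please verify against the installed flash_attn_interface.py): a permuted dout whose last-dimension stride is already 1 is passed through unchanged and can still fail the sequence-stride check. `FakeTensor` and `check_seq_strides` are hypothetical stand-ins; the real check lives in the compiled extension.

```python
# Pure-Python model of maybe_contiguous followed by the sequence-stride check.
# Assumption: maybe_contiguous copies only when x.stride(-1) != 1.
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a tensor: only shape and strides (in elements)."""
    shape: tuple
    strides: tuple

    def stride(self, dim):
        return self.strides[dim]

    def contiguous(self):
        # Return a copy with row-major strides, like Tensor.contiguous().
        strides, step = [], 1
        for size in reversed(self.shape):
            strides.append(step)
            step *= size
        return FakeTensor(self.shape, tuple(reversed(strides)))


def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def check_seq_strides(dout, out):
    """Hypothetical mirror of the failing check; seqlen is dim 1."""
    if dout.stride(1) != out.stride(1):
        raise RuntimeError("Expected dout_seq_stride == out_seq_stride to be true, but got false.")


out = FakeTensor((2, 2048, 48, 64), (6291456, 3072, 64, 1))     # contiguous
dout = FakeTensor((2, 2048, 48, 64), (6291456, 64, 131072, 1))  # permuted view

dout, out = maybe_contiguous(dout), maybe_contiguous(out)
print(dout.stride(-1))  # 1, so maybe_contiguous left dout untouched
try:
    check_seq_strides(dout, out)
except RuntimeError as err:
    print("fails:", err)

# Forcing a real copy makes the strides line up.
check_seq_strides(dout.contiguous(), out)
print("passes after .contiguous()")
```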

@dejay-vu

dejay-vu commented Feb 6, 2024

@donglixp Thanks. Could you also share the repo and branch you are testing, so I can reproduce your result and see which step goes wrong?

@rocking5566
Collaborator

We have changed the backend of FlashAttention 2 in the ck_tile branch.
I have also submitted a PR to support AMD / ROCm in FlashAttention 2, Dao-AILab#1010, which uses composable_kernel as the backend.
I hope this solves your issue.

@zixian-wang-amd


Will I be able to run other models that use FlashAttention-2 on Instinct GPUs if the PR is not merged yet? By the way, what is your work email? I can't find your name in Team.

@rocking5566
Collaborator

In reply to: "Will I be able to run on other models that used Flash-Attention-2 on Instinct GPUs if the PR is not merged yet? Btw, what is your working email? I can't find your name in Team."

[email protected]

@carlushuang
Collaborator

cc @danyao12
