FlashInfer generating NANs on A100 GPU #574

Open
dbarbuzzi opened this issue Oct 30, 2024 · 3 comments

@dbarbuzzi

When running a particular vLLM test on an A100 GPU, FlashInfer appears to generate NaNs in one specific scenario. The test fails in that scenario on an A100 but passes every scenario on both an H100 and an L4. We are using flashinfer-0.1.6+cu124torch2.4.

The test that fails is test_flashinfer_decode_with_paged_fp8_kv.

The failure occurs when three of the parameters take the following values at the same time:

  • block_size = 32
  • head_size = 256
  • num_heads = (32, 8)
    • 32 gets assigned to num_query_heads
    • 8 gets assigned to num_kv_heads

If any one of these parameters takes one of its other possible values, the test passes on the A100.
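The failing combination can be pictured as a single cell in the test's parameter grid. The sketch below is only an illustration: the failing values come from the report above, while the alternative values in each list and the helper names are hypothetical placeholders, not necessarily what the vLLM test uses.

```python
from itertools import product

# Hypothetical grids. Only the failing values (32, 256, (32, 8)) come from
# the report; the other entries are placeholders for illustration.
BLOCK_SIZES = [16, 32]
HEAD_SIZES = [128, 256]
NUM_HEADS = [(16, 16), (32, 8)]  # (num_query_heads, num_kv_heads)

# The one combination reported to fail on the A100.
FAILING = (32, 256, (32, 8))


def is_reported_failure(block_size, head_size, num_heads):
    """True only when all three parameters match the reported values."""
    return (block_size, head_size, num_heads) == FAILING


def failing_cases():
    """Return every grid cell that matches the reported failure."""
    grid = product(BLOCK_SIZES, HEAD_SIZES, NUM_HEADS)
    return [combo for combo in grid if is_reported_failure(*combo)]


if __name__ == "__main__":
    for block_size, head_size, (num_query_heads, num_kv_heads) in failing_cases():
        print(block_size, head_size, num_query_heads, num_kv_heads)
```

With these placeholder grids, exactly one of the eight combinations matches, which mirrors the observation that changing any single parameter makes the test pass.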

The failure message seems to indicate that, under this scenario, NaNs are being generated:

AssertionError: Tensor-likes are not close!

Mismatched elements: 1024 / 24576 (4.2%)
Greatest absolute difference: nan at index (0, 0, 0) (up to 0.02 allowed)
Greatest relative difference: nan at index (0, 0, 0) (up to 0.01 allowed)
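As a rough illustration of why a NaN shows up as a mismatch with a "nan" difference, here is a minimal pure-Python sketch of a tolerance check of the form |actual - expected| <= atol + rtol * |expected|. The function name `is_close` is hypothetical, and the tolerances are taken from the "allowed" values in the message above. It is not the test's actual comparison code.

```python
import math


def is_close(actual, expected, atol=0.02, rtol=0.01):
    """Tolerance check: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


nan = float("nan")

print(is_close(1.00, 1.01))        # True: the difference is within tolerance
print(is_close(nan, 1.0))          # False: any comparison with NaN is False
print(math.isnan(abs(nan - 1.0)))  # True: the difference itself is NaN
```

Because every ordered comparison involving NaN evaluates to False, a NaN output can never pass such a check, no matter how loose the tolerances are.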

The error message is essentially the same across failures. The only variations are the total number of elements (either 24576 or 32768, while the number of mismatched elements is always 1024) and the index (either (0, 0, 0) or (3, 0, 0)).
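The element counts can be broken down with some quick arithmetic. The sketch below assumes the output tensor has shape (num_seqs, num_query_heads, head_size), which matches the three-component index in the message but is not stated in the report. The function name `decompose` is hypothetical.

```python
NUM_QUERY_HEADS = 32
NUM_KV_HEADS = 8
HEAD_SIZE = 256


def decompose(total_elements, mismatched):
    """Split the reported counts, assuming shape (num_seqs, heads, head_size)."""
    per_seq = NUM_QUERY_HEADS * HEAD_SIZE  # elements per sequence
    num_seqs, rem = divmod(total_elements, per_seq)
    assert rem == 0, "total is not a whole number of sequences"
    bad_heads, rem = divmod(mismatched, HEAD_SIZE)
    assert rem == 0, "mismatch count is not a whole number of heads"
    return {
        "num_seqs": num_seqs,
        # Number of full query heads the mismatches would fill.
        "heads_worth_of_mismatches": bad_heads,
        # Query heads that share one KV head.
        "query_heads_per_kv_head": NUM_QUERY_HEADS // NUM_KV_HEADS,
        "mismatch_pct": round(100 * mismatched / total_elements, 1),
    }


print(decompose(24576, 1024))
print(decompose(32768, 1024))
```

Under that assumption, the two totals correspond to 3 and 4 sequences, and 1024 mismatched elements equal 4 heads of 256 values, which is also the number of query heads per KV head (32 / 8). This is consistent with the NaNs being confined to one KV-head group of one sequence, but the error message alone does not prove it.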

yzh119 self-assigned this Oct 30, 2024
@yzh119
Collaborator

yzh119 commented Oct 30, 2024

Hi @dbarbuzzi, thanks for reporting this issue. It appears for fp8, is that correct?

@dbarbuzzi
Author

> Hi @dbarbuzzi, thanks for reporting this issue. It appears for fp8, is that correct?

Yes, that is correct; I should have clarified that originally.

@yzh119
Collaborator

yzh119 commented Nov 7, 2024

Bug confirmed, working on a fix :)
