
Fix flash attention on NVIDIA #678

Open
MikeMpapa opened this issue Jul 30, 2024 · 2 comments

Comments

@MikeMpapa

Hi again :) - I am getting the following JAX notifications while running a training job, and I was wondering if you could provide some clarity.

  1. INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
    I do run the training job on an instance with 8 A100 GPUs. Does this notification mean that each GPU is independently training a separate model, instead of all of them training a single model together, or just that I have only one node/instance so there is no use for jax.distributed? I did not find anything in the Levanter documentation about the need to provide a "distributed config", but maybe I missed something? In the wandb logs I can see all 8 GPUs being utilized under the experiment.

  2. INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
    It is not clear to me whether this causes any issues in the training process.

  3. /levanter/src/levanter/models/attention.py:266: UserWarning: transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention.. Falling back to the reference implementation.
    I used ghcr.io/nvidia/jax:levanter-2024-07-24 as my base container and created a new image where I installed transformer_engine with RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable, as explained in NVIDIA's documentation. It doesn't seem to do the job though, since I am still seeing the warning. Do you have any suggestions? (A quick import check I can run inside the new image is sketched below.)
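
For reference, here is the quick diagnostic sketch mentioned in (3). It is something I can run inside the derived image to surface why the import falls back; the transformer_engine.jax submodule is my assumption about what the fused-attention path needs:

```python
# Diagnostic sketch: run inside the derived image to see why transformer_engine
# cannot be used, instead of relying on the silent fallback warning.
# Note: the package is installed from NVIDIA/TransformerEngine, but the import
# name is transformer_engine.
try:
    import transformer_engine
    print("transformer_engine found at:", transformer_engine.__file__)
    import transformer_engine.jax  # JAX bindings (assumed requirement of the fused path)
    print("transformer_engine.jax imported OK")
except Exception as exc:  # ImportError, missing CUDA/cuDNN libraries, version mismatch, ...
    print("import failed:", repr(exc))
```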

Thanks a lot in advance!!

@dlwh
Member

dlwh commented Jul 30, 2024

The first two are totally normal and fine.

  1. You're not using distributed, just multi-GPU, which can be done in a single process in JAX. (You can also do it multi-process, which is supposedly slightly faster, but it's a bit more of a hassle if you're not using SLURM.) A quick way to check what JAX sees is sketched after this list.
  2. Very early on we set INFO-level logging for JAX and for our own code. JAX logs each of the platforms it tries before landing on the one it uses.
  3. This one is a bit of a concern, but just means it's gonna be a bit slower than it could be. I'm on vacation now and not able to work on this. @DwarKapex Do you have any insight on (3)? Thanks!
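
For (1), a minimal sketch of the check mentioned above (run it in the same container/environment as the training job):

```python
# Sanity check for single-process multi-GPU JAX (a sketch; run it in the same
# environment as the training job).
import jax

print("process count:", jax.process_count())  # 1 for a single-process run
print("local devices:", jax.local_devices())  # expect all 8 A100s listed here
print("device count:", jax.device_count())    # also 8 in a single-node setup

# jax.distributed.initialize() is only needed for multi-process / multi-node
# runs (e.g. under SLURM); with one process that sees all the GPUs there is
# nothing to initialize, which is what that INFO message is telling you.
```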

@nouiz

nouiz commented Aug 5, 2024

For 3:

import transformer_engine works when starting that container like this:

docker run -ti --gpus all ghcr.io/nvidia/jax:levanter-2024-07-24

Do you have a repro for 3? Do you see the issue with today's containers?

dlwh changed the title from "Jax on Levanter" to "Fix flash attention on NVIDIA" on Sep 4, 2024