Skip to content

Commit

Permalink
Merge branch 'main' into newptl_fix_validation_in_spellmapper
Browse files Browse the repository at this point in the history
  • Loading branch information
bene-ges authored Nov 22, 2023
2 parents 3c7981d + 9c7926d commit 388bb50
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 10 deletions.
4 changes: 1 addition & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,8 @@ WORKDIR /tmp/nemo
COPY requirements .
RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-check --no-cache-dir -r $f; done

# install flash attention dependencies
# install flash attention
RUN pip install flash-attn
# pinned triton version for flash-attention https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
RUN pip install triton==2.0.0.dev20221202
# install numba for latest containers
RUN pip install numba>=0.57.1

Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ Transformer Engine requires PyTorch to be built with CUDA 11.8.

Flash Attention
~~~~~~~~~~~~~~~~~~~~
Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models or use with attention bias (introduced from position encoding, e.g. Alibi), please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_.
Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_. If you want to use Flash Attention with an attention bias (introduced by the position encoding, e.g. Alibi), please also install the pinned version of triton specified in the `implementation <https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3>`_.

.. code-block:: bash
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

if self.megatron_amp_O2:

if not self.with_distributed_adam:
if not self.with_distributed_adam and not self.cfg.get("use_cpu_initialization", False):
# Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type
if isinstance(self.model, list):
for module in self.model:
Expand Down Expand Up @@ -1585,7 +1585,7 @@ def build_transformer_config(self) -> TransformerConfig:
'recompute_method': recompute_method,
'recompute_num_layers': recompute_num_layers,
'distribute_saved_activations': False, # not currently used in NeMo
'ub_tp_comm_overlap': ub_tp_comm_overlap,
'tp_comm_overlap': ub_tp_comm_overlap,
'fp8': fp8,
}

Expand Down
17 changes: 14 additions & 3 deletions nemo/collections/nlp/modules/common/megatron/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,23 @@

HAVE_MEGATRON_CORE = False

try:
# Flash Attention Triton
import pkg_resources
from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton

# pinned triton version for flash-attention triton https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
assert pkg_resources.get_distribution("triton").version == '2.0.0.dev20221202'

except (ImportError, ModuleNotFoundError, AssertionError):

flash_attn_func_triton = None


try:
# Flash Attention 1.X
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton

HAVE_FLASH_ATTENTION = True
flash_attn_func = None
Expand All @@ -85,8 +97,7 @@
except (ImportError, ModuleNotFoundError):

HAVE_FLASH_ATTENTION = False

flash_attn_unpadded_func, flash_attn_func_triton, flash_attn_func = None, None, None
flash_attn_unpadded_func, flash_attn_func = None, None
unpad_input, pad_input = None, None

try:
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ tensorboard
text-unidecode
torch
tqdm>=4.41.0
triton
wget
wrapt
6 changes: 5 additions & 1 deletion tests/collections/nlp/test_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@
HAVE_FA = False

try:
import pkg_resources
import triton

# pinned triton version for flash-attention triton https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
assert pkg_resources.get_distribution("triton").version == '2.0.0.dev20221202'

HAVE_TRITON = True
except (ImportError, ModuleNotFoundError):
except (ImportError, ModuleNotFoundError, AssertionError):
HAVE_TRITON = False

try:
Expand Down

0 comments on commit 388bb50

Please sign in to comment.