-
Notifications
You must be signed in to change notification settings - Fork 252
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Test Plan: ``` with-proxy CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.use_te ``` Reviewers: Subscribers: Tasks: Tags:
- Loading branch information
Showing
4 changed files
with
132 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
# path hack, TODO remove | ||
import sys | ||
# sys.path.insert(0, '/home/vasiliy/local/torchtitan/torchtitan') | ||
import torchtitan.te_utils as te_utils | ||
|
||
import transformer_engine.pytorch as te | ||
from transformer_engine.common.recipe import Format, DelayedScaling | ||
|
||
# FP8 configuration for the smoke test: HYBRID format with delayed scaling,
# using the max over a 16-entry amax history.
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(
    fp8_format=fp8_format,
    amax_history_len=16,
    amax_compute_algo="max",
)
# Autocast context built once at import time; test() enters it a single time.
maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
|
||
def test():
    """Single-GPU smoke test of TransformerEngine fp8.

    Builds a one-layer ``nn.Sequential``, swaps its ``nn.Linear`` for a TE
    linear, runs one forward pass under the module-level fp8 autocast context
    and backpropagates the summed output.
    """
    # for now, single GPU smoke test of TE fp8
    features = 32
    inputs = torch.randn(features, features, device="cuda")

    model = nn.Sequential(nn.Linear(features, features)).cuda()
    te_utils.swap_linear_to_te_linear(model)
    print(model)

    # Forward under fp8 autocast; backward runs after leaving the context.
    with maybe_te_float8_ctx:
        out = model(inputs)
    out.sum().backward()

    print('done')


if __name__ == '__main__':
    test()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
Utilities for testing TransformerEngine | ||
Note: I attempted to hack in DTensor-based TP/SP to te.Linear in the | ||
link below, and gave up for now as it seemed to be a lot of remaining work. | ||
We can power through that if needed later. | ||
* https://gist.github.com/vkuzo/64d5362b63dd6c76410464e020d9a35f | ||
Note: I looked into using te.LayerNormLinear, and that would require changing | ||
how Attention and FFN are defined in torchtitan to use a single gemm for | ||
attn.kqv and ffn.w1_w3. Punting for now but we can do this later if needed. | ||
Note: PyTorch's checkpointing does not work with TE float8, fails with | ||
* https://gist.github.com/vkuzo/54c76c16d6a38610a1d78f4de07a71e7 | ||
TE does have a `transformer_engine.pytorch.checkpoint` function, but | ||
unclear where the code for that lives. For now, we have to use | ||
`--activation_checkpoint.mode none`. | ||
Note: using `--activation_checkpoint.mode none` leads to poor TE performance because | ||
the memory usage is close to my GPU limits; the | ||
`WARNING - 164 CUDA memory allocation retries` message in the logs seems relevant. | ||
Full logs: https://gist.github.com/vkuzo/0d6ebac2df3f7c90464da1e16d75d24c | ||
Need to decrease memory usage (either by using a smaller model or decreasing | ||
sequence_length) to train with TE without issues. | ||
""" | ||
|
||
import contextlib | ||
import os | ||
|
||
# required for current build to work with fp8 on devgpu003.cco3 | ||
# context: https://github.com/NVIDIA/TransformerEngine/pull/575 | ||
# error stack trace if not enabled: https://gist.github.com/vkuzo/8e78282f4a986961753fba25249fdf77 | ||
# os.environ["NVTE_UNFUSED_FP8_UPDATE"] = "1" | ||
|
||
import torch | ||
|
||
# import transformer_engine as te | ||
import transformer_engine.pytorch as te | ||
|
||
from transformer_engine.common.recipe import Format, DelayedScaling | ||
# Shared fp8 recipe used by get_maybe_fp8_autocast: HYBRID format with delayed
# scaling, using the max over a 16-entry amax history.
te_fp8_format = Format.HYBRID
te_fp8_recipe = DelayedScaling(
    fp8_format=te_fp8_format,
    amax_history_len=16,
    amax_compute_algo="max",
)
|
||
def swap_linear_to_te_linear(model, fqn=''):
    """Recursively replace every ``torch.nn.Linear`` child with ``te.Linear``.

    The swap happens in place. Each replacement is constructed with the same
    ``in_features``/``out_features`` and the same bias setting as the layer it
    replaces, and then takes over that layer's ``weight`` and ``bias`` objects
    directly (they are reassigned, not copied).

    Only children are visited, so a root ``model`` that is itself an
    ``nn.Linear`` is not replaced.

    Args:
        model: module whose children are walked; modified in place.
        fqn: dotted name of ``model`` inside the root module. It is only
            threaded through the recursion; leave it empty for the root.

    Returns:
        None.
    """
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Linear):
            # Join with "." only below the root so that nested names read
            # "layers.0" rather than ".layers.0".
            child_fqn = f"{fqn}.{name}" if fqn else name
            swap_linear_to_te_linear(child, child_fqn)
            continue

        te_linear = te.Linear(
            child.in_features,
            child.out_features,
            bias=child.bias is not None,
        )
        # Reuse the existing parameter objects instead of the freshly
        # initialized ones created by te.Linear.
        te_linear.weight = child.weight
        te_linear.bias = child.bias
        # Overwriting an existing child name is safe while iterating
        # named_children(): the set of child names does not change.
        setattr(model, name, te_linear)
|
||
def get_maybe_fp8_autocast(job_config):
    """Return the context manager to use for this training iteration.

    Not for land: sets up TransformerEngine fp8 autocast.

    Args:
        job_config: object exposing ``training.use_te`` and
            ``training.use_te_float8``.

    Returns:
        ``te.fp8_autocast(enabled=True, fp8_recipe=te_fp8_recipe)`` when both
        flags are truthy, otherwise ``contextlib.nullcontext()``.

    Note:
        te.fp8_autocast has to be created at every training iteration. If we
        try to create it once and reuse it, we get this error:
        https://gist.github.com/vkuzo/d9840328c8bdc2901b8d04aa570ecb5b
    """
    if not (job_config.training.use_te and job_config.training.use_te_float8):
        return contextlib.nullcontext()
    return te.fp8_autocast(enabled=True, fp8_recipe=te_fp8_recipe)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters