NeMo2.0 llama3 perf scripts (#11702)
* perf scripts llama3 8b

Signed-off-by: Malay Nagda <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* copyright

Signed-off-by: Malay Nagda <[email protected]>

* llama3 70b

Signed-off-by: Malay Nagda <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* 405b recipe

Signed-off-by: Malay Nagda <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* doc strings

Signed-off-by: Malay Nagda <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* remove tb logging and formatting

Signed-off-by: Malay Nagda <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* disable default tb and profiling

Signed-off-by: Malay Nagda <[email protected]>

* num steps per epoch

Signed-off-by: Malay Nagda <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* correct filepaths

Signed-off-by: Malay Nagda <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* remove param

Signed-off-by: Malay Nagda <[email protected]>

* README

Signed-off-by: Malay Nagda <[email protected]>

* updated param

Signed-off-by: Malay Nagda <[email protected]>

---------

Signed-off-by: Malay Nagda <[email protected]>
Signed-off-by: malay-nagda <[email protected]>
Co-authored-by: malay-nagda <[email protected]>
malay-nagda and malay-nagda authored Jan 1, 2025
1 parent c341fb3 commit 91471f0
Showing 6 changed files with 759 additions and 1 deletion.
2 changes: 1 addition & 1 deletion nemo/lightning/run/plugins.py
@@ -314,7 +314,7 @@ class PerfEnvPlugin(run.Plugin):
     enable_layernorm_sm_margin: bool = True
     layernorm_sm_margin: int = 16
     enable_vboost: bool = False
-    nccl_pp_comm_chunksize: int = None
+    nccl_pp_comm_chunksize: Optional[int] = None
 
     def get_vboost_srun_cmd(self, nodes, job_dir):
         "Create the vboost `sudo nvidia-smi boost-slider --vboost 1` command"
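For context, the recipe scripts added below pass this field explicitly. A minimal sketch of how the now-optional `nccl_pp_comm_chunksize` is used, with values taken from the llama3 scripts in this commit (the unit is assumed to be bytes):

```python
from nemo.lightning.run.plugins import PerfEnvPlugin

# Mirrors the plugin setup in the llama3 recipe scripts added in this commit.
plugins = [
    PerfEnvPlugin(
        enable_vboost=True,  # issues `sudo nvidia-smi boost-slider --vboost 1` via srun
        nccl_pp_comm_chunksize=2097152,  # 2 MiB, assuming the value is interpreted as bytes
    )
]
# Leaving nccl_pp_comm_chunksize as None (the new default) keeps NCCL's default
# chunking for pipeline-parallel communication.
```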
27 changes: 27 additions & 0 deletions scripts/llm/performance/README.md
@@ -0,0 +1,27 @@
# Performance Recipes

- Scripts defined in `scripts/llm/performance` are recipes optimized for performance. These scripts can launch pre-training experiments on Slurm-based clusters.
- You will need a virtual environment with NeMo and NeMo-Run related dependencies installed, because the experiment configuration is resolved before it is launched inside the NeMo container.

## Example

The following line shows an example of how you can launch a pre-training experiment:

`python3 scripts/llm/performance/llama3_8b.py --account <your_slurm_account> --partition <your_slurm_partition>`

## Configuration Options

- Slurm account and partition are mandatory arguments for launching the experiment.
- You can use the following optional arguments as needed:
  - -l/--log_dir: Location to store your experiment artifacts and logs.
    - Make sure the environment variable `NEMORUN_HOME=<log_dir>` is set correctly and accessible in your virtual environment.
    - You can run `export NEMORUN_HOME=<log_dir>` in your terminal. You can add it to your bashrc file (or the equivalent for your OS/Linux distro) to set it permanently.
  - -t/--time_limit: Maximum time limit for your experiment. Your Slurm job will be cancelled after this. Defaults to 30 minutes.
  - -i/--container_image: The NeMo container you want to use. Defaults to the latest dev container: 'nvcr.io/nvidia/nemo:dev'.
  - -c/--compute_dtype: Precision used for training, either 'bf16' or 'fp8'. Defaults to 'bf16'.
  - -ep/--enable_profiling: Enable nsys profiling. It is disabled by default. When enabled, profiling runs for one step, from step 5 to step 6. You can change these steps in the respective recipe script.
  - -tb/--tensorboard: Enable TensorBoard logging. It is disabled by default.
    - CAUTION: TensorBoard logging may cause a performance overhead.
  - -d/--dryrun: Using this argument will not launch the experiment; it will simply print the sbatch script to stdout. This can be helpful to verify that you have configured your experiment correctly.
- `--enable_profiling`, `--tensorboard`, and `--dryrun` are boolean flags that take no value. See the example below for reference:
`python3 scripts/llm/performance/llama3_8b.py --account <your_slurm_account> -p <your_slurm_partition> -ep --tensorboard -d`
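For reference, these flags are consumed by `parse_cli_args()` in `scripts/llm/performance/utils.py`, which is not shown in this commit view. The sketch below is illustrative only: the flag names follow the README above, but the exact defaults (for example for `--log_dir` and `--time_limit`) are assumptions.

```python
# Illustrative sketch only: the real parser is parse_cli_args() in
# scripts/llm/performance/utils.py, which is not part of this diff.
# Flag names follow the README above; defaults here are placeholders.
import argparse


def parse_cli_args() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Launch a NeMo performance pre-training recipe")
    parser.add_argument("--account", required=True, help="Slurm account (mandatory)")
    parser.add_argument("-p", "--partition", required=True, help="Slurm partition (mandatory)")
    parser.add_argument("-l", "--log_dir", default=None, help="Location for experiment artifacts and logs")
    parser.add_argument("-t", "--time_limit", default="00:30:00", help="Slurm time limit (default 30 minutes)")
    parser.add_argument("-i", "--container_image", default="nvcr.io/nvidia/nemo:dev", help="NeMo container image")
    parser.add_argument("-c", "--compute_dtype", default="bf16", choices=["bf16", "fp8"], help="Training precision")
    parser.add_argument("-ep", "--enable_profiling", action="store_true", help="Enable nsys profiling")
    parser.add_argument("-tb", "--tensorboard", action="store_true", help="Enable TensorBoard logging")
    parser.add_argument("-d", "--dryrun", action="store_true", help="Print the sbatch script instead of launching")
    return parser
```

The recipe scripts call `parse_cli_args().parse_args()` and read `args.account`, `args.compute_dtype`, and the boolean flags directly.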
179 changes: 179 additions & 0 deletions scripts/llm/performance/llama3_405b.py
@@ -0,0 +1,179 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import nemo_run as run
from nemo_run.config import NEMORUN_HOME
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor

from nemo.collections.llm.recipes.llama31_405b import pretrain_recipe
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
from nemo.utils import logging

NUM_NODES = 72
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 252
TP_SIZE = 8
PP_SIZE = 9
CP_SIZE = 2
VP_SIZE = 7
MAX_STEPS = 100


def llama3_405b_performance_recipe(
    compute_dtype: str,
    num_nodes: int,
    num_gpus_per_node: int,
    mbs: int,
    gbs: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
    vp_size: Optional[int],
    max_steps: int,
):
    """
    llama3 405b pre-train recipe aimed at achieving best possible performance.
    NOTE: Use fp8 precision training with caution. It might not give desirable results.
    """
    recipe = pretrain_recipe(performance_mode=True)

    # data module configs
    recipe.data.micro_batch_size = mbs
    recipe.data.global_batch_size = gbs
    recipe.data.num_train_samples = max_steps * gbs  # ensure only 1 epoch for whole run
    recipe.data.tokenizer = hf_tokenizer("meta-llama/Llama-3.1-405B")

    recipe.trainer.max_steps = max_steps
    recipe.trainer.num_nodes = num_nodes
    recipe.trainer.devices = num_gpus_per_node

    # parallelism configs
    recipe.trainer.strategy.tensor_model_parallel_size = tp_size
    recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
    recipe.trainer.strategy.context_parallel_size = cp_size
    recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
    if tp_size > 1:
        recipe.trainer.strategy.sequence_parallel = True
    else:
        recipe.trainer.strategy.sequence_parallel = False

    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

    # compute dtype configs
    if compute_dtype.lower() == "fp8":
        recipe.trainer.plugins = bf16_with_fp8_mixed()
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True

    recipe.trainer.plugins.grad_reduce_in_fp32 = False  # bf16 grad dtype

    # callback configs
    garbage_collection_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=100,
        gc_interval_val=500,
    )
    recipe.trainer.callbacks.extend(
        [
            garbage_collection_callback,
        ]
    )
    dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
    if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    # Misc. for overall faster experiment runtime
    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    recipe.trainer.val_check_interval = max_steps * gbs / dp_size
    recipe.trainer.log_every_n_steps = 1

    return recipe


if __name__ == "__main__":
    args = parse_cli_args().parse_args()
    if args.log_dir != NEMORUN_HOME:
        import sys

        logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.")
        sys.exit(1)

    exp_name = "_".join(
        [
            "llama3_405b",
            args.compute_dtype,
            f"{NUM_NODES}nodes",
            f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
            f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
        ]
    )

    executor = slurm_executor(
        args.account,
        args.partition,
        args.log_dir,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        args.time_limit,
        args.container_image,
        custom_mounts=[],
        custom_env_vars={},
        retries=0,
    )

    recipe = llama3_405b_performance_recipe(
        args.compute_dtype,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        MICRO_BATCH_SIZE,
        GLOBAL_BATCH_SIZE,
        TP_SIZE,
        PP_SIZE,
        CP_SIZE,
        VP_SIZE,
        MAX_STEPS,
    )

    if not args.tensorboard:  # tensorboard adds performance overhead.
        recipe.log.tensorboard = None
        recipe.trainer.logger = False
    else:
        # default path is NOT intuitive- `<log_dir>/code/nemo_experiments/tb_logs/default/<tfevents_file>`
        # following line ensures file is at- `<log_dir>/lightning_logs/tb_logs/default/<tfevents_file>`
        recipe.log.log_dir = "/nemo_run/lightning_logs"

    plugins = [PerfEnvPlugin(enable_vboost=True, nccl_pp_comm_chunksize=2097152)]
    if args.enable_profiling:
        plugins.append(NsysPlugin(start_step=5, end_step=6))

    with run.Experiment(exp_name) as exp:
        exp.add(
            recipe,
            executor=executor,
            name=exp_name,
            plugins=plugins,
        )

        if not args.dryrun:
            exp.run(sequential=True, detach=True)
        else:
            exp.dryrun()
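As a quick check on the constants defined in this script (not part of the committed file), the data-parallel size and per-step micro-batch count for the 405B defaults work out as follows:

```python
# Worked arithmetic for the llama3_405b defaults above (not part of the commit).
NUM_NODES, NUM_GPUS_PER_NODE = 72, 8
TP_SIZE, PP_SIZE, CP_SIZE = 8, 9, 2
MICRO_BATCH_SIZE, GLOBAL_BATCH_SIZE, MAX_STEPS = 1, 252, 100

world_size = NUM_NODES * NUM_GPUS_PER_NODE    # 576 GPUs
model_parallel = TP_SIZE * PP_SIZE * CP_SIZE  # 144 GPUs per model replica
dp_size = world_size // model_parallel        # 4 data-parallel replicas
grad_accum = GLOBAL_BATCH_SIZE // (dp_size * MICRO_BATCH_SIZE)  # 63 micro-batches per global step
num_train_samples = MAX_STEPS * GLOBAL_BATCH_SIZE  # 25200 samples, i.e. exactly one epoch
val_check_interval = MAX_STEPS * GLOBAL_BATCH_SIZE / dp_size
# 6300; appears intended to push validation past the end of the 100-step run

print(world_size, dp_size, grad_accum, num_train_samples, val_check_interval)
```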
179 changes: 179 additions & 0 deletions scripts/llm/performance/llama3_70b.py
@@ -0,0 +1,179 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import nemo_run as run
from nemo_run.config import NEMORUN_HOME
from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor

from nemo.collections.llm.recipes.llama3_70b import pretrain_recipe
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
from nemo.utils import logging

NUM_NODES = 8
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 128
TP_SIZE = 4
PP_SIZE = 4
CP_SIZE = 2
VP_SIZE = 5
MAX_STEPS = 100


def llama3_70b_performance_recipe(
    compute_dtype: str,
    num_nodes: int,
    num_gpus_per_node: int,
    mbs: int,
    gbs: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
    vp_size: Optional[int],
    max_steps: int,
):
    """
    llama3 70b pre-train recipe aimed at achieving best possible performance.
    NOTE: Use fp8 precision training with caution. It might not give desirable results.
    """
    recipe = pretrain_recipe(performance_mode=True)

    # data module configs
    recipe.data.micro_batch_size = mbs
    recipe.data.global_batch_size = gbs
    recipe.data.num_train_samples = max_steps * gbs  # ensure only 1 epoch for whole run
    recipe.data.tokenizer = hf_tokenizer("meta-llama/Meta-Llama-3-70B")

    recipe.trainer.max_steps = max_steps
    recipe.trainer.num_nodes = num_nodes
    recipe.trainer.devices = num_gpus_per_node

    # parallelism configs
    recipe.trainer.strategy.tensor_model_parallel_size = tp_size
    recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
    recipe.trainer.strategy.context_parallel_size = cp_size
    recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
    if tp_size > 1:
        recipe.trainer.strategy.sequence_parallel = True
    else:
        recipe.trainer.strategy.sequence_parallel = False

    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

    # compute dtype configs
    if compute_dtype.lower() == "fp8":
        recipe.trainer.plugins = bf16_with_fp8_mixed()
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True
        recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True

    recipe.trainer.plugins.grad_reduce_in_fp32 = False  # bf16 grad dtype

    # callback configs
    garbage_collection_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=100,
        gc_interval_val=500,
    )
    recipe.trainer.callbacks.extend(
        [
            garbage_collection_callback,
        ]
    )
    dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
    if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
        if comm_overlap_callback_idx >= 0:
            recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

    # Misc. for overall faster experiment runtime
    recipe.log.ckpt = None
    recipe.trainer.enable_checkpointing = False
    recipe.trainer.val_check_interval = max_steps * gbs / dp_size
    recipe.trainer.log_every_n_steps = 1

    return recipe


if __name__ == "__main__":
    args = parse_cli_args().parse_args()
    if args.log_dir != NEMORUN_HOME:
        import sys

        logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.")
        sys.exit(1)

    exp_name = "_".join(
        [
            "llama3_70b",
            args.compute_dtype,
            f"{NUM_NODES}nodes",
            f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
            f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
        ]
    )

    executor = slurm_executor(
        args.account,
        args.partition,
        args.log_dir,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        args.time_limit,
        args.container_image,
        custom_mounts=[],
        custom_env_vars={},
        retries=0,
    )

    recipe = llama3_70b_performance_recipe(
        args.compute_dtype,
        NUM_NODES,
        NUM_GPUS_PER_NODE,
        MICRO_BATCH_SIZE,
        GLOBAL_BATCH_SIZE,
        TP_SIZE,
        PP_SIZE,
        CP_SIZE,
        VP_SIZE,
        MAX_STEPS,
    )

    if not args.tensorboard:  # tensorboard adds performance overhead.
        recipe.log.tensorboard = None
        recipe.trainer.logger = False
    else:
        # default path is NOT intuitive- `<log_dir>/code/nemo_experiments/tb_logs/default/<tfevents_file>`
        # following line ensures file is at- `<log_dir>/lightning_logs/tb_logs/default/<tfevents_file>`
        recipe.log.log_dir = "/nemo_run/lightning_logs"

    plugins = [PerfEnvPlugin(enable_vboost=True, nccl_pp_comm_chunksize=2097152)]
    if args.enable_profiling:
        plugins.append(NsysPlugin(start_step=5, end_step=6))

    with run.Experiment(exp_name) as exp:
        exp.add(
            recipe,
            executor=executor,
            name=exp_name,
            plugins=plugins,
        )

        if not args.dryrun:
            exp.run(sequential=True, detach=True)
        else:
            exp.dryrun()
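The 70B layout follows the same pattern at a smaller scale. A short sketch (again, not part of the committed file) checking that the defaults above tile the cluster evenly:

```python
# Sanity check for the llama3_70b defaults above (not part of the commit).
NUM_NODES, NUM_GPUS_PER_NODE = 8, 8
TP_SIZE, PP_SIZE, CP_SIZE, VP_SIZE = 4, 4, 2, 5
MICRO_BATCH_SIZE, GLOBAL_BATCH_SIZE = 1, 128

world_size = NUM_NODES * NUM_GPUS_PER_NODE    # 64 GPUs
model_parallel = TP_SIZE * PP_SIZE * CP_SIZE  # 32
assert world_size % model_parallel == 0, "parallelism layout must tile the cluster evenly"
dp_size = world_size // model_parallel        # 2 data-parallel replicas
assert GLOBAL_BATCH_SIZE % (dp_size * MICRO_BATCH_SIZE) == 0  # 64 micro-batches per global step

# With dp_size > 1, pp_size > 1 and vp_size > 1, the script above also turns on
# overlap_param_gather_with_optimizer_step on the comm-overlap callback.
```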