diff --git a/nemo/lightning/run/plugins.py b/nemo/lightning/run/plugins.py index b12f04f3bf60..0b509df6415f 100644 --- a/nemo/lightning/run/plugins.py +++ b/nemo/lightning/run/plugins.py @@ -314,7 +314,7 @@ class PerfEnvPlugin(run.Plugin): enable_layernorm_sm_margin: bool = True layernorm_sm_margin: int = 16 enable_vboost: bool = False - nccl_pp_comm_chunksize: int = None + nccl_pp_comm_chunksize: Optional[int] = None def get_vboost_srun_cmd(self, nodes, job_dir): "Create the vboost `sudo nvidia-smi boost-slider --vboost 1` command" diff --git a/scripts/llm/performance/README.md b/scripts/llm/performance/README.md new file mode 100644 index 000000000000..62bf58329633 --- /dev/null +++ b/scripts/llm/performance/README.md @@ -0,0 +1,27 @@ +# Performance Recipes + +- Scripts defined in `scripts/llm/performance` are recipes optimized for performance. These scripts can launch pre-training experiments on Slurm based clusters. +- You will need a virtual environemnt with NeMo and Nemo-Run related dependencies installed as the experiment configuration is resolved before launching it inside NeMo container. + +## Example + +The following line shows an example of how you can launch a pre-training experiment- + +`python3 scripts/llm/performance/llama3_8b.py --account -partition ` + +## Configuration Options + +- Slurm account and partition are mandatory arguments for launching the experiment. +- You can use the following optional arguments as needed- + - -l/--log_dir: Location to store your experiment artifacts and logs. + - Make sure the environemnt variable `NEMORUN_HOME=` is accessible and set correctly in your virtual environment. + - You can run `export NEMORUN_HOME=` in your terminal. You can add it your bashrc file (or equivalent for your OS/Linux distro) for setting it permanently. + - -t/--time_limit: Maximum time limit for your experiment. Your slurm job will be cancelled after this. Default is 30 minutes. + - -i/--container_image: The NeMo container you want to use. Defaults to latest dev container- 'nvcr.io/nvidia/nemo:dev'. + - -c/--compute_dtype: Specifies whether you want to use bf16 or fp8 precision for training. Defaults to 'bf16'. You can choose to use 'fp8'. + - -ep/--enable_profiling: Enable nsys profiling. It is disabled by default. When enabled, profiling will be enabled for 1 step from step 5 to step 6. You can change the step in the respective recipe script. + - -tb/--tensorboard: Enable tensorboard logging. It is disabled by default. + - CAUTION: Tensorboard logging may cause performance overhead. + - -d/--dryrun: Using this argument will not launch the experiment. It will simply print the sbatch script to stdout. This can be helpful to verify you have set your experiment correctly as needed. +- You don't need to set any value for `--enable_profiling`, `--tensorboard` and `--dryrun`. See the below example for reference- + `python3 scripts/llm/performance/llama3_8b.py --account -p -ep --tensorboard -d` diff --git a/scripts/llm/performance/llama3_405b.py b/scripts/llm/performance/llama3_405b.py new file mode 100644 index 000000000000..56980eed8c67 --- /dev/null +++ b/scripts/llm/performance/llama3_405b.py @@ -0,0 +1,179 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import nemo_run as run +from nemo_run.config import NEMORUN_HOME +from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor + +from nemo.collections.llm.recipes.llama31_405b import pretrain_recipe +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed +from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback +from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin +from nemo.utils import logging + +NUM_NODES = 72 +NUM_GPUS_PER_NODE = 8 +MICRO_BATCH_SIZE = 1 +GLOBAL_BATCH_SIZE = 252 +TP_SIZE = 8 +PP_SIZE = 9 +CP_SIZE = 2 +VP_SIZE = 7 +MAX_STEPS = 100 + + +def llama3_405b_performance_recipe( + compute_dtype: str, + num_nodes: int, + num_gpus_per_node: int, + mbs: int, + gbs: int, + tp_size: int, + pp_size: int, + cp_size: int, + vp_size: Optional[int], + max_steps: int, +): + """ + llama3 405b pre-train recipe aimed at achieving best possible performance. + + NOTE: Use fp8 precision training with caution. It might not give desirable results. + """ + recipe = pretrain_recipe(performance_mode=True) + + # data module configs + recipe.data.micro_batch_size = mbs + recipe.data.global_batch_size = gbs + recipe.data.num_train_samples = max_steps * gbs # ensure only 1 epoch for whole run + recipe.data.tokenizer = hf_tokenizer("meta-llama/Llama-3.1-405B") + + recipe.trainer.max_steps = max_steps + recipe.trainer.num_nodes = num_nodes + recipe.trainer.devices = num_gpus_per_node + + # parallelism configs + recipe.trainer.strategy.tensor_model_parallel_size = tp_size + recipe.trainer.strategy.pipeline_model_parallel_size = pp_size + recipe.trainer.strategy.context_parallel_size = cp_size + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size + if tp_size > 1: + recipe.trainer.strategy.sequence_parallel = True + else: + recipe.trainer.strategy.sequence_parallel = False + + comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks) + + # compute dtype configs + if compute_dtype.lower() == "fp8": + recipe.trainer.plugins = bf16_with_fp8_mixed() + recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True + recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True + + recipe.trainer.plugins.grad_reduce_in_fp32 = False # bf16 grad dtype + + # callback configs + garbage_collection_callback = run.Config( + GarbageCollectionCallback, + gc_interval_train=100, + gc_interval_val=500, + ) + recipe.trainer.callbacks.extend( + [ + garbage_collection_callback, + ] + ) + dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size) + if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1: + if comm_overlap_callback_idx >= 0: + recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True + + # Misc. for overall faster experiment runtime + recipe.log.ckpt = None + recipe.trainer.enable_checkpointing = False + recipe.trainer.val_check_interval = max_steps * gbs / dp_size + recipe.trainer.log_every_n_steps = 1 + + return recipe + + +if __name__ == "__main__": + args = parse_cli_args().parse_args() + if args.log_dir != NEMORUN_HOME: + import sys + + logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.") + sys.exit(1) + + exp_name = "_".join( + [ + f"llama3_405b", + args.compute_dtype, + f"{NUM_NODES}nodes", + f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}", + f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs", + ] + ) + + executor = slurm_executor( + args.account, + args.partition, + args.log_dir, + NUM_NODES, + NUM_GPUS_PER_NODE, + args.time_limit, + args.container_image, + custom_mounts=[], + custom_env_vars={}, + retries=0, + ) + + recipe = llama3_405b_performance_recipe( + args.compute_dtype, + NUM_NODES, + NUM_GPUS_PER_NODE, + MICRO_BATCH_SIZE, + GLOBAL_BATCH_SIZE, + TP_SIZE, + PP_SIZE, + CP_SIZE, + VP_SIZE, + MAX_STEPS, + ) + + if not args.tensorboard: # tensorboard adds performance overhead. + recipe.log.tensorboard = None + recipe.trainer.logger = False + else: + # default path is NOT intuitive- `/code/nemo_experiments/tb_logs/default/` + # following line ensures file is at- `/lightning_logs/tb_logs/default/` + recipe.log.log_dir = "/nemo_run/lightning_logs" + + plugins = [PerfEnvPlugin(enable_vboost=True, nccl_pp_comm_chunksize=2097152)] + if args.enable_profiling: + plugins.append(NsysPlugin(start_step=5, end_step=6)) + + with run.Experiment(exp_name) as exp: + exp.add( + recipe, + executor=executor, + name=exp_name, + plugins=plugins, + ) + + if not args.dryrun: + exp.run(sequential=True, detach=True) + else: + exp.dryrun() diff --git a/scripts/llm/performance/llama3_70b.py b/scripts/llm/performance/llama3_70b.py new file mode 100644 index 000000000000..ab41c05f2e8d --- /dev/null +++ b/scripts/llm/performance/llama3_70b.py @@ -0,0 +1,179 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import nemo_run as run +from nemo_run.config import NEMORUN_HOME +from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor + +from nemo.collections.llm.recipes.llama3_70b import pretrain_recipe +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed +from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback +from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin +from nemo.utils import logging + +NUM_NODES = 8 +NUM_GPUS_PER_NODE = 8 +MICRO_BATCH_SIZE = 1 +GLOBAL_BATCH_SIZE = 128 +TP_SIZE = 4 +PP_SIZE = 4 +CP_SIZE = 2 +VP_SIZE = 5 +MAX_STEPS = 100 + + +def llama3_70b_performance_recipe( + compute_dtype: str, + num_nodes: int, + num_gpus_per_node: int, + mbs: int, + gbs: int, + tp_size: int, + pp_size: int, + cp_size: int, + vp_size: Optional[int], + max_steps: int, +): + """ + llama3 70b pre-train recipe aimed at achieving best possible performance. + + NOTE: Use fp8 precision training with caution. It might not give desirable results. + """ + recipe = pretrain_recipe(performance_mode=True) + + # data module configs + recipe.data.micro_batch_size = mbs + recipe.data.global_batch_size = gbs + recipe.data.num_train_samples = max_steps * gbs # ensure only 1 epoch for whole run + recipe.data.tokenizer = hf_tokenizer("meta-llama/Meta-Llama-3-70B") + + recipe.trainer.max_steps = max_steps + recipe.trainer.num_nodes = num_nodes + recipe.trainer.devices = num_gpus_per_node + + # parallelism configs + recipe.trainer.strategy.tensor_model_parallel_size = tp_size + recipe.trainer.strategy.pipeline_model_parallel_size = pp_size + recipe.trainer.strategy.context_parallel_size = cp_size + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size + if tp_size > 1: + recipe.trainer.strategy.sequence_parallel = True + else: + recipe.trainer.strategy.sequence_parallel = False + + comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks) + + # compute dtype configs + if compute_dtype.lower() == "fp8": + recipe.trainer.plugins = bf16_with_fp8_mixed() + recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.proj_fprop.fp8_buf = True + recipe.trainer.callbacks[comm_overlap_callback_idx].tp_comm_overlap_cfg.fc2_fprop.fp8_buf = True + + recipe.trainer.plugins.grad_reduce_in_fp32 = False # bf16 grad dtype + + # callback configs + garbage_collection_callback = run.Config( + GarbageCollectionCallback, + gc_interval_train=100, + gc_interval_val=500, + ) + recipe.trainer.callbacks.extend( + [ + garbage_collection_callback, + ] + ) + dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size) + if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1: + if comm_overlap_callback_idx >= 0: + recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True + + # Misc. for overall faster experiment runtime + recipe.log.ckpt = None + recipe.trainer.enable_checkpointing = False + recipe.trainer.val_check_interval = max_steps * gbs / dp_size + recipe.trainer.log_every_n_steps = 1 + + return recipe + + +if __name__ == "__main__": + args = parse_cli_args().parse_args() + if args.log_dir != NEMORUN_HOME: + import sys + + logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.") + sys.exit(1) + + exp_name = "_".join( + [ + f"llama3_70b", + args.compute_dtype, + f"{NUM_NODES}nodes", + f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}", + f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs", + ] + ) + + executor = slurm_executor( + args.account, + args.partition, + args.log_dir, + NUM_NODES, + NUM_GPUS_PER_NODE, + args.time_limit, + args.container_image, + custom_mounts=[], + custom_env_vars={}, + retries=0, + ) + + recipe = llama3_70b_performance_recipe( + args.compute_dtype, + NUM_NODES, + NUM_GPUS_PER_NODE, + MICRO_BATCH_SIZE, + GLOBAL_BATCH_SIZE, + TP_SIZE, + PP_SIZE, + CP_SIZE, + VP_SIZE, + MAX_STEPS, + ) + + if not args.tensorboard: # tensorboard adds performance overhead. + recipe.log.tensorboard = None + recipe.trainer.logger = False + else: + # default path is NOT intuitive- `/code/nemo_experiments/tb_logs/default/` + # following line ensures file is at- `/lightning_logs/tb_logs/default/` + recipe.log.log_dir = "/nemo_run/lightning_logs" + + plugins = [PerfEnvPlugin(enable_vboost=True, nccl_pp_comm_chunksize=2097152)] + if args.enable_profiling: + plugins.append(NsysPlugin(start_step=5, end_step=6)) + + with run.Experiment(exp_name) as exp: + exp.add( + recipe, + executor=executor, + name=exp_name, + plugins=plugins, + ) + + if not args.dryrun: + exp.run(sequential=True, detach=True) + else: + exp.dryrun() diff --git a/scripts/llm/performance/llama3_8b.py b/scripts/llm/performance/llama3_8b.py new file mode 100644 index 000000000000..38fa9e5bd621 --- /dev/null +++ b/scripts/llm/performance/llama3_8b.py @@ -0,0 +1,176 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import nemo_run as run +from nemo_run.config import NEMORUN_HOME +from utils import get_comm_overlap_callback_idx, hf_tokenizer, parse_cli_args, slurm_executor + +from nemo.collections.llm.recipes.llama3_8b import pretrain_recipe +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed +from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback +from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin +from nemo.utils import logging + +NUM_NODES = 1 +NUM_GPUS_PER_NODE = 8 +MICRO_BATCH_SIZE = 1 +GLOBAL_BATCH_SIZE = 128 +TP_SIZE = 1 +PP_SIZE = 1 +CP_SIZE = 2 +VP_SIZE = None +MAX_STEPS = 100 + + +def llama3_8b_performance_recipe( + compute_dtype: str, + num_nodes: int, + num_gpus_per_node: int, + mbs: int, + gbs: int, + tp_size: int, + pp_size: int, + cp_size: int, + vp_size: Optional[int], + max_steps: int, +): + """ + llama3 8b pre-train recipe aimed at achieving best possible performance. + + NOTE: Use fp8 precision training with caution. It might not give desirable results. + """ + recipe = pretrain_recipe(performance_mode=True) + + # data module configs + recipe.data.micro_batch_size = mbs + recipe.data.global_batch_size = gbs + recipe.data.num_train_samples = max_steps * gbs # ensure only 1 epoch for whole run + recipe.data.tokenizer = hf_tokenizer("meta-llama/Meta-Llama-3-8B") + + recipe.trainer.max_steps = max_steps + recipe.trainer.num_nodes = num_nodes + recipe.trainer.devices = num_gpus_per_node + + # parallelism configs + recipe.trainer.strategy.tensor_model_parallel_size = tp_size + recipe.trainer.strategy.pipeline_model_parallel_size = pp_size + recipe.trainer.strategy.context_parallel_size = cp_size + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size + if tp_size > 1: + recipe.trainer.strategy.sequence_parallel = True + else: + recipe.trainer.strategy.sequence_parallel = False + + comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks) + + # compute dtype configs + if compute_dtype.lower() == "fp8": + recipe.trainer.plugins = bf16_with_fp8_mixed() + recipe.trainer.plugins.grad_reduce_in_fp32 = False # bf16 grad dtype + + # callback configs + garbage_collection_callback = run.Config( + GarbageCollectionCallback, + gc_interval_train=100, + gc_interval_val=500, + ) + recipe.trainer.callbacks.extend( + [ + garbage_collection_callback, + ] + ) + dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size) + if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1: + if comm_overlap_callback_idx >= 0: + recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True + + # Misc. for overall faster experiment runtime + recipe.log.ckpt = None + recipe.trainer.enable_checkpointing = False + recipe.trainer.val_check_interval = max_steps * gbs / dp_size + recipe.trainer.log_every_n_steps = 1 + + return recipe + + +if __name__ == "__main__": + args = parse_cli_args().parse_args() + if args.log_dir != NEMORUN_HOME: + import sys + + logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.") + sys.exit(1) + + exp_name = "_".join( + [ + f"llama3_8b", + args.compute_dtype, + f"{NUM_NODES}nodes", + f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}", + f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs", + ] + ) + + executor = slurm_executor( + args.account, + args.partition, + args.log_dir, + NUM_NODES, + NUM_GPUS_PER_NODE, + args.time_limit, + args.container_image, + custom_mounts=[], + custom_env_vars={}, + retries=0, + ) + + recipe = llama3_8b_performance_recipe( + args.compute_dtype, + NUM_NODES, + NUM_GPUS_PER_NODE, + MICRO_BATCH_SIZE, + GLOBAL_BATCH_SIZE, + TP_SIZE, + PP_SIZE, + CP_SIZE, + VP_SIZE, + MAX_STEPS, + ) + + if not args.tensorboard: # tensorboard adds performance overhead. + recipe.log.tensorboard = None + recipe.trainer.logger = False + else: + # default path is NOT intuitive- `/code/nemo_experiments/tb_logs/default/` + # following line ensures file is at- `/lightning_logs/tb_logs/default/` + recipe.log.log_dir = "/nemo_run/lightning_logs" + + plugins = [PerfEnvPlugin(enable_vboost=True)] + if args.enable_profiling: + plugins.append(NsysPlugin(start_step=5, end_step=6)) + + with run.Experiment(exp_name) as exp: + exp.add( + recipe, + executor=executor, + name=exp_name, + plugins=plugins, + ) + + if not args.dryrun: + exp.run(sequential=True, detach=True) + else: + exp.dryrun() diff --git a/scripts/llm/performance/utils.py b/scripts/llm/performance/utils.py new file mode 100644 index 000000000000..8574b4f30f2b --- /dev/null +++ b/scripts/llm/performance/utils.py @@ -0,0 +1,197 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from typing import Dict, List, Optional + +import nemo_run as run +from lightning.pytorch.callbacks.callback import Callback +from nemo_run.config import NEMORUN_HOME + +from nemo.collections.common.tokenizers.huggingface import AutoTokenizer +from nemo.collections.llm.recipes.llama3_8b import MegatronCommOverlapCallback + + +def slurm_executor( + account: str, + partition: str, + log_dir: str, + nodes: int, + num_gpus_per_node: int, + time_limit: str = "01:00:00", + container_image: str = "nvcr.io/nvidia/nemo:dev", + custom_mounts: Optional[List[str]] = None, + custom_env_vars: Optional[Dict[str, str]] = None, + custom_srun_args: Optional[List[str]] = None, + retries: int = 0, +) -> run.SlurmExecutor: + """ + Slurm cluster definition with appropriate cluster params and NeMo container params needed for pre-training + and fine-tuning experiments + """ + if not (log_dir and account and partition and nodes and num_gpus_per_node): + raise RuntimeError( + "Please set user, host, remote_job_dir, account, partition, nodes and devices args for using this ", + "function.", + ) + + mounts = [] + if custom_mounts: + mounts.extend(custom_mounts) + + env_vars = { + "TRANSFORMERS_OFFLINE": "1", + "TOKENIZERS_PARALLELISM": "False", + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", + "NCCL_NVLS_ENABLE": "0", + "NVTE_DP_AMAX_REDUCE_INTERVAL": "0", + "NVTE_ASYNC_AMAX_REDUCTION": "1", + "NVTE_FUSED_ATTN": "1", + "NVTE_FLASH_ATTN": "0", + "NEMO_LOG_MEMORY_USAGE": "1", + "NEMORUN_HOME": log_dir, + } + if custom_env_vars: + env_vars |= custom_env_vars + + srun_args = ["--mpi=pmix"] + if custom_srun_args: + srun_args.extend(custom_srun_args) + + executor = run.SlurmExecutor( + account=account, + partition=partition, + tunnel=run.LocalTunnel( + job_dir=os.path.join(log_dir, "experiments"), + ), + nodes=nodes, + ntasks_per_node=num_gpus_per_node, + mem="0", + exclusive=True, + packager=run.GitArchivePackager(), + ) + + executor.container_image = container_image + executor.container_mounts = mounts + executor.env_vars = env_vars + executor.srun_args = srun_args + executor.retries = retries + executor.time = time_limit + + return executor + + +def hf_tokenizer(model_name: str) -> run.Config[AutoTokenizer]: + """ + AutoTokenizer first searches for tokenizer files locally in env var 'NEMO_HOME'. + If tokenizer files are not present locally, AutoTokenizer will try downloading from HuggingFace. + In the case tokenizer needs downloading, make sure env vars- 'TRANSFORMERS_OFFLINE=0' and + 'HF_TOKEN:' are defined in SlurmExecutor.env_vars + """ + return run.Config( + AutoTokenizer, + pretrained_model_name=model_name, + use_fast=True, + ) + + +def get_comm_overlap_callback_idx(callbacks: List[Callback]): + """ + nemo.lightning.Trainer has a list of callbacks defined. This method identifies index of MegatronCommOverlapCallback + from the list defined in recipes in nemo.collections.llm.recipes. The index is needed to override ddp communication + params + """ + if callbacks: # default is None in lightning + for idx, callback in enumerate(callbacks): + if isinstance(callback, MegatronCommOverlapCallback): + return idx + return -1 + + +def parse_cli_args(): + """ + Command line arguments correspong to Slurm cluster and NeMo2.0 for running pre-training and + fine-tuning experiments. + """ + parser = argparse.ArgumentParser(description="NeMo2.0 Performance Pretraining and Fine-Tuning") + + parser.add_argument( + "-a", + "--account", + type=str, + help="Slurm account to use for experiment", + required=True, + ) + parser.add_argument( + "-p", + "--partition", + type=str, + help="Slurm partition to use for experiment", + required=True, + ) + parser.add_argument( + "-l", + "--log_dir", + type=str, + help=f"Directory for logging experiment results. Defaults to {NEMORUN_HOME}", + required=False, + default=NEMORUN_HOME, + ) + parser.add_argument( + "-t", + "--time_limit", + type=str, + help="Maximum time limit to run experiment for. Defaults to 30 minutes (format- 'HH:MM:SS')", + required=False, + default="00:30:00", + ) + parser.add_argument( + "-i", + "--container_image", + type=str, + help="NeMo container to use for experiment. Defaults to latest dev container- 'nvcr.io/nvidia/nemo:dev'\ + Make sure your NGC credentials are accessible in your environment.", + required=False, + default="nvcr.io/nvidia/nemo:dev", + ) + parser.add_argument( + "-c", + "--compute_dtype", + type=str, + help="Compute precision. Options- bf16 or fp8. Defaults to bf16", + required=False, + default="bf16", + ) + parser.add_argument( + "-ep", + "--enable_profiling", + help="Enable Nsys profiling. Diabled by default", + action="store_true", + ) + parser.add_argument( + "-tb", + "--tensorboard", + help="Enable tensorboard logging. Disabled by default", + action="store_true", + ) + parser.add_argument( + "-d", + "--dryrun", + help="If true, prints sbatch script to terminal without launching experiment.", + required=False, + action="store_true", + ) + + return parser