diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index a9b3d4361f5be..87c2e82980d3e 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -24,7 +24,15 @@ from typing_extensions import Annotated import nemo.lightning as nl -from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io +from nemo.lightning import ( + AutoResume, + NeMoLogger, + OptimizerModule, + Trainer, + configure_no_restart_validation_training_loop, + io, +) +from nemo.lightning.base import NEMO_MODELS_CACHE from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform from nemo.utils import logging @@ -492,6 +500,7 @@ def _setup( tokenizer: Optional[TokenizerType], model_transform: Optional[Union[PEFT, ModelTransform, Callable]], ) -> Any: # Return type is Any because app_state's type is not specified + configure_no_restart_validation_training_loop(trainer) _log = log or NeMoLogger() if resume and isinstance(model_transform, PEFT) and _log.ckpt: logging.info("Disabling try_restore_best_ckpt restoration for adapters") diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 7e70a970913e6..5c6b71c747970 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -15,6 +15,7 @@ import pytorch_lightning as pl import torch from torch.utils.data import DataLoader +from nemo.lightning.pytorch.plugins import MegatronDataSampler class HfDatasetDataModule(pl.LightningDataModule): @@ -24,6 +25,7 @@ def __init__( num_workers=2, pin_memory=True, persistent_workers=True, + seq_length=1024, micro_batch_size=2, global_batch_size=2, pad_token_id=0, @@ -37,6 +39,7 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = persistent_workers + self.seq_length = seq_length self.micro_batch_size = micro_batch_size self.global_batch_size = global_batch_size self.pad_token_id = pad_token_id @@ -58,6 +61,7 @@ def pad_within_micro(batch, pad_token_id): max_len = max(map(len, batch)) return [item + [pad_token_id] * (max_len - len(item)) for item in batch] + keys = list(filter(lambda x: x in batch[0], ['tokens', 'labels', 'position_ids', 'loss_mask'])) return { key: batchify( torch.LongTensor( @@ -67,16 +71,26 @@ def pad_within_micro(batch, pad_token_id): ) ) ) - for key in ['tokens', 'labels'] + for key in keys } + def setup(self, stage: str): + if not self.use_mcore_sampler: + return + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + dataloader_type=self.mcore_dataloader_type, + ) + def train_dataloader(self, collate_fn=None): from nemo.lightning.data import add_megatron_sampler if collate_fn is None: collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) - dataloader = DataLoader( + return DataLoader( self.dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, @@ -84,20 +98,3 @@ def train_dataloader(self, collate_fn=None): collate_fn=collate_fn, batch_size=self.micro_batch_size, ) - if not self.use_mcore_sampler: - return dataloader - - rank = 0 - world_size = 1 - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - - return add_megatron_sampler( - dataloader, - self.micro_batch_size, - self.global_batch_size, - dataloader_type=self.mcore_dataloader_type, - rank=rank, - world_size=world_size, - ) diff --git 
a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index b48f99e061c97..496ec1bd262ee 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -274,7 +274,12 @@ def make_vocab_size_divisible_by(vocab_size): base //= 2 return base - output = LlamaConfig( + if getattr(source, 'rope_scaling', None) is not None and source.rope_scaling.get('rope_type') == 'llama3': + # Apply Llama3.1 customize rope scaling + cls = Llama31Config + else: + cls = LlamaConfig + output = cls( num_layers=source.num_hidden_layers, hidden_size=source.hidden_size, ffn_hidden_size=source.intermediate_size, diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 21994b75f60dc..54f5045ccef23 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -26,6 +26,8 @@ llama3_70b, llama3_70b_16k, llama3_70b_64k, + llama31_8b, + llama31_70b, llama31_405b, mamba2_1_3b, mamba2_2_7b, @@ -65,6 +67,8 @@ "llama3_70b", "llama3_70b_16k", "llama3_70b_64k", + "llama31_8b", + "llama31_70b", "llama31_405b", "mamba2_130m", "mamba2_370m", diff --git a/nemo/collections/llm/recipes/finetune_default.py b/nemo/collections/llm/recipes/finetune_default.py index 5a1ff58e86610..e38eb4e7ec649 100644 --- a/nemo/collections/llm/recipes/finetune_default.py +++ b/nemo/collections/llm/recipes/finetune_default.py @@ -16,6 +16,7 @@ import nemo_run as run import pytorch_lightning as pl +import torch import nemo.lightning as nl from nemo.collections import llm @@ -82,7 +83,7 @@ def default_finetune_recipe( def default_finetune_trainer( tensor_parallelism=1, pipeline_parallelism=1, - pipeline_parallelism_type=None, + pipeline_parallelism_type=torch.bfloat16, virtual_pipeline_parallelism=None, context_parallelism=1, sequence_parallelism=False, @@ -93,6 +94,19 @@ def default_finetune_trainer( limit_val_batches=None, val_check_interval=30, ): + """ + Create a default fine-tuning trainer for any model. + + This function sets up a template for strategy and trainer. + + Args: + See docstrings of MegatronStrategy and Trainer. + + Returns: + run.Config: Config for a finetuning trainer. + + See usages of this in recipes for further details. + """ strategy = run.Config( nl.MegatronStrategy, tensor_model_parallel_size=tensor_parallelism, @@ -125,7 +139,8 @@ def default_finetune_trainer( def nemo_resume(model_id: str) -> run.Config[nl.AutoResume]: """ - Configure automatic resumption from a NeMo checkpoint converted from Huggingface for https://huggingface.co/{model_id}. + Configure automatic resumption from a NeMo checkpoint converted from Huggingface for + https://huggingface.co/{model_id}. This NeMo checkpoint should be converted from Huggingface beforehand, using nemo.collections.llm.import_ckpt. When converting the checkpoint, the NeMo checkpoint will be saved in NEMO_HOME (set to ~/.cache/nemo by default). 
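For context, a minimal sketch of the conversion step that nemo_resume's docstring refers to, assuming the llm.import_ckpt(model=..., source="hf://<model_id>") entry point in nemo.collections.llm and the LlamaModel / Llama31Config8B classes imported elsewhere in this patch; the exact arguments may differ:

    # Convert the Hugging Face checkpoint once. The resulting NeMo-format checkpoint is
    # written under NEMO_HOME (defaults to ~/.cache/nemo), which is where
    # nemo_resume("meta-llama/Meta-Llama-3.1-8B") later expects to find it.
    from nemo.collections import llm
    from nemo.collections.llm.gpt.model.llama import Llama31Config8B, LlamaModel

    if __name__ == "__main__":
        llm.import_ckpt(
            model=LlamaModel(Llama31Config8B()),
            source="hf://meta-llama/Meta-Llama-3.1-8B",
        )
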
diff --git a/nemo/collections/llm/recipes/gemma_2b.py b/nemo/collections/llm/recipes/gemma_2b.py index cead1f2e5689d..0d637bb63c0a4 100644 --- a/nemo/collections/llm/recipes/gemma_2b.py +++ b/nemo/collections/llm/recipes/gemma_2b.py @@ -278,7 +278,7 @@ def finetune_recipe( model(), "google/gemma-2b", dir, name, num_nodes, num_gpus_per_node, packed_sequence ) if peft_scheme is None or peft_scheme.lower() == 'none': - recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 2 recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) diff --git a/nemo/collections/llm/recipes/llama31_405b.py b/nemo/collections/llm/recipes/llama31_405b.py index 055e9a06fcbaa..85500b4157d72 100644 --- a/nemo/collections/llm/recipes/llama31_405b.py +++ b/nemo/collections/llm/recipes/llama31_405b.py @@ -24,6 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config405B, LlamaModel from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -31,6 +32,7 @@ from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, ) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback @@ -237,3 +239,162 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: ) return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 3, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 405B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + Returns: + run.Partial: Partial configuration for fine-tuning. 
+ + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_405b + $ nemo llm finetune --factory "llama31_405b(num_nodes=3, name='my_llama31_405b_finetune')" + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_405b_finetune", num_nodes=3) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 405B model + requires substantial computational resources. + """ + if packed_sequence is None: + packed_sequence = performance_mode + + if seq_length is None: + seq_length = 2048 + + if num_nodes is None: + if peft_scheme is None or peft_scheme.lower() == 'none': + num_nodes = 12 + elif peft_scheme.lower() == 'lora': + num_nodes = 3 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.1-405B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 14 + recipe.data.global_batch_size = 6 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 6 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + recipe.data.global_batch_size = 6 + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + # Note: limited support. 
This is not necessarily the most optimized setting + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 14 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 6 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama31_70b.py b/nemo/collections/llm/recipes/llama31_70b.py new file mode 100644 index 0000000000000..91e4e10c83e6b --- /dev/null +++ b/nemo/collections/llm/recipes/llama31_70b.py @@ -0,0 +1,403 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama31Config70B, LlamaModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, +) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama31_70b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.1 70B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.1 70B model. 
+ + Examples: + CLI usage: + $ nemo llm pretrain model=llama31_70b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama31Config70B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 4, + pipeline_parallelism: int = 4, + pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 5, + context_parallelism: int = 2, + sequence_parallelism: bool = True, + num_nodes: int = 4, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.1 70B model. + + This function sets up the distributed training strategy optimized for the large 70B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama31_70b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=4, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.1 70B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. 
+ num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama31_70b + $ nemo llm pretrain --factory "llama31_70b(num_nodes=4, name='my_70b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama31_70b_pretrain", num_nodes=4) + >>> print(recipe) + + Note: + This recipe is optimized for the large 70B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + if performance_mode: + recipe = pretrain_performance_optimizations(recipe) + + return recipe + + +def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Llama3.1 70B model. + + This method enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=50, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + align_param_gather=True, + ) + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = None, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 70B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. 
+ name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_70b + $ nemo llm finetune --factory "llama31_70b(num_nodes=4, name='my_70b_finetune')" + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_70b_finetune", num_nodes=4) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 70B model + requires substantial computational resources. + """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + if num_nodes is None: + if peft_scheme is None or peft_scheme.lower() == 'none': + num_nodes = 4 + elif peft_scheme.lower() == 'lora': + num_nodes = 1 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.1-70B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. 
+ + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama31_8b.py b/nemo/collections/llm/recipes/llama31_8b.py new file mode 100644 index 0000000000000..a4f0082e85358 --- /dev/null +++ b/nemo/collections/llm/recipes/llama31_8b.py @@ -0,0 +1,385 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Callable, Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama31Config8B, LlamaModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, +) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama31_8b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.1 8B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.1 8B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama31_8b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama31Config8B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 2, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.1 8B model. + + This function sets up the distributed training strategy optimized for the large 8B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama31_8b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. 
+ """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.1 8B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama31_8b + $ nemo llm pretrain --factory "llama31_8b(num_nodes=4, name='my_8b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama31_8b_pretrain", num_nodes=4) + >>> print(recipe) + + Note: + This recipe is optimized for the large 8B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + if performance_mode: + recipe = pretrain_performance_optimizations(recipe) + + return recipe + + +def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Llama3.1 8B model. + + This method enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Note: + Use this method with caution and only when you need maximum performance. 
+ It may not be suitable for all hardware configurations or use cases. + """ + + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=50, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + align_param_gather=True, + ) + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 8B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_8b + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_8b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. 
+ """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "meta-llama/Meta-Llama-3.1-8B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe.trainer.strategy.tensor_model_parallel_size = 1 + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ) + ) + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index 6e9da5c5116d5..bfb7567005099 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -259,7 +259,11 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. 
- peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. Returns: run.Partial: Partial configuration for fine-tuning. @@ -291,4 +295,80 @@ def finetune_recipe( recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. 
+ """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + return recipe diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 394a7718b8bd7..b66130198e3c4 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -24,7 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe @@ -247,7 +247,11 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. Returns: run.Partial: Partial configuration for fine-tuning. @@ -276,4 +280,70 @@ def finetune_recipe( recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. 
+ + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe.trainer.strategy.tensor_model_parallel_size = 1 + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ) + ) + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + return recipe diff --git a/nemo/collections/llm/recipes/starcoder_15b.py b/nemo/collections/llm/recipes/starcoder_15b.py new file mode 100644 index 0000000000000..cb0ba14df868f --- /dev/null +++ b/nemo/collections/llm/recipes/starcoder_15b.py @@ -0,0 +1,310 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.model.starcoder import StarcoderConfig15B, StarcoderModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed, fp16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "starcoder_15b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Starcoder 15B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Starcoder 15B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=starcoder_15b ... 
+ + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + + return run.Config(StarcoderModel, config=run.Config(StarcoderConfig15B)) + + +def starcoder_trainer( + tensor_parallelism: int = 4, + pipeline_parallelism: int = 2, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + precision: str = "bf16-mixed", + accumulate_grad_batches: int = 1, + limit_test_batches: int = 32, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + val_check_interval: int = 2000, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Starcoder 15B models. + + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed. + accumulate_grad_batches (int): Number of steps per gradient accumulation. + limit_test_batches (int): Limit the number of test batches. + limit_val_batches (int): Limit the number of validation batches. + log_every_n_steps (int): Log every n steps. + val_check_interval (int): Run validation every N steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. 
+ """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ) + + precision_plugin = None + if precision == "16-mixed": + precision_plugin = fp16_mixed() + elif precision == "bf16-mixed": + precision_plugin = bf16_mixed() + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + callbacks=callbacks, + devices=num_gpus_per_node, + accumulate_grad_batches=accumulate_grad_batches, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=precision_plugin, + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=val_check_interval, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + # General + dir: Optional[str] = None, + name: str = "default", + # Trainer + tensor_parallelism: int = 1, + pipeline_parallelism: int = 8, + pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 300000, + precision: str = "bf16-mixed", + accumulate_grad_batches: int = 1, + gradient_clip_val: float = 1.0, + limit_test_batches: int = 32, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + val_check_interval: int = 1000, + # Data + global_batch_size=32, + micro_batch_size=2, + seq_length=4096, + # Optimizer + warmup_steps=500, + constant_steps=0, + min_lr=3e-5, + max_lr=3e-4, + # Training function + fn=pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Starcoder 15B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed. + accumulate_grad_batches (int): Number of steps per gradient accumulation. + gradient_clip_val (float): Value for gradient clipping. + limit_test_batches (int): Limit the number of test batches. + limit_val_batches (int): Limit the number of validation batches. + log_every_n_steps (int): Log every n steps. + val_check_interval (int): Run validation every N steps. + global_batch_size (int): Global batch size. + micro_batch_size (int): Micro batch size. + seq_length (int): Sequence length. 
+ warmup_steps (int): Number of warmup steps. + constant_steps (int): Number of constant steps. + min_lr (float): Minimum learning rate. + max_lr (float): Maximum learning rate. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory starcoder_15b + $ nemo llm pretrain --factory "starcoder_15b(num_nodes=1, name='my_starcoder2_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="starcoder2_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses a mock dataset, look for the finetune examples to see how to change the dataset. + """ + return run.Partial( + fn, + model=model(), + trainer=starcoder_trainer( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_type=pipeline_parallelism_type, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + max_steps=max_steps, + precision=precision, + accumulate_grad_batches=accumulate_grad_batches, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, + val_check_interval=val_check_interval, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config( + MockDataModule, + seq_length=seq_length, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing( + precision=precision, + warmup_steps=warmup_steps, + constant_steps=constant_steps, + min_lr=min_lr, + max_lr=max_lr, + clip_grad=gradient_clip_val, + ), + resume=default_resume(), + ) + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', +) -> run.Partial: + """ + Create a fine-tuning recipe for Starcoder 15B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory starcoder_15b + + Python API usage: + >>> recipe = finetune_recipe(name="starcoder_15b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. 
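A hedged end-to-end sketch of how the recipe factories here compose with nemo_run (the module path and the local executor are assumptions, not part of the patch):

```python
# Sketch: build the LoRA fine-tuning recipe defined here and launch it locally.
import nemo_run as run

from nemo.collections.llm.recipes import starcoder as starcoder_recipe  # path is an assumption

recipe = starcoder_recipe.finetune_recipe(
    name="starcoder_15b_lora",
    num_nodes=1,
    num_gpus_per_node=8,
    peft_scheme="lora",
)
recipe.trainer.max_steps = 100  # shorten the default SQuAD run for a smoke test

run.run(recipe, executor=run.LocalExecutor())
```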
+ """ + recipe = default_finetune_recipe(model(), "bigcode/starcoder", dir, name, num_nodes, num_gpus_per_node) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.pipeline_model_parallel_size = 8 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + return recipe diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index db1aec0f5a557..b0e134ab0c35d 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -161,7 +161,7 @@ def convert_model_to_trt_llm_ckpt( or nemo_model_config.get("layernorm_zero_centered_gamma", False), "tp_size": training_tp_size, "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"] + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu", "openai-gelu"] and (decoder_type == "gptnext" or is_mcore), "num_attention_heads": num_attention_heads, "num_kv_heads": num_kv_heads, @@ -336,7 +336,7 @@ def dist_model_to_trt_llm_ckpt( "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", "tp_size": tp_size, "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"], + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu", "openai-gelu"], "num_attention_heads": nemo_model_config["num_attention_heads"], "num_kv_heads": nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), "convert_on_device": True, diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index 2cc720e148d4a..91d3b3f936d0b 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -33,7 +33,7 @@ from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy from nemo.lightning.pytorch.strategies.utils import RestoreConfig -from nemo.lightning.pytorch.trainer import Trainer +from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop from nemo.lightning.resume import AutoResume @@ -66,6 +66,7 @@ def _is_slurm_interactive_mode(): "ModelCheckpoint", "OptimizerModule", "Trainer", + "configure_no_restart_validation_training_loop", "get_vocab_size", "teardown", ] diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index 5d1738e348b1a..aadff50d4968a 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -561,7 +561,9 @@ def _io_flatten_object(instance): def _io_unflatten_object(values, metadata): - assert hasattr(_thread_local, "output_dir") + if not hasattr(_thread_local, "output_dir"): + return fdl.Config.__unflatten__(values, metadata) + output_dir = _thread_local.output_dir if len(values) == 1: diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index 8b10f9aca50a5..a901a3a8842ac 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -220,7 +220,7 @@ def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None): if callback.dirpath is None: callback.dirpath = Path(log_dir / "checkpoints") if callback.filename is None: - callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}" + callback.filename = 
f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}-{{consumed_samples}}" ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + "-last" def _handle_task_config(self, task_config, log_dir): diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py index 164c07fe5b808..4aabc2b45293f 100644 --- a/nemo/lightning/pytorch/trainer.py +++ b/nemo/lightning/pytorch/trainer.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from copy import deepcopy import fiddle as fdl import pytorch_lightning as pl +from pytorch_lightning.loops import _TrainingEpochLoop +from pytorch_lightning.loops.fetchers import _DataFetcher from typing_extensions import Self from nemo.lightning.fabric.conversion import to_fabric @@ -23,8 +26,40 @@ from nemo.lightning.io.mixin import IOMixin, serialization, track_io -class Trainer(pl.Trainer, IOMixin): +class NoValOnRestartTrainingLoop(_TrainingEpochLoop): + """ + Extend the PTL Epoch loop to skip validation when restarting. + This happens when resuming a checkpoint that has already run validation, but loading restores + the training state before validation has run. + """ + + def _should_check_val_fx(self, data_fetcher) -> bool: + if self.skip_val_on_restart: + return False + return super()._should_check_val_fx(data_fetcher) + + def load_state_dict(self, state_dict: dict, prefix: str = "") -> None: + super().load_state_dict(state_dict, prefix) + + self.skip_val_on_restart = True + + def advance(self, data_fetcher: _DataFetcher) -> None: + super().advance(data_fetcher) + + self.skip_val_on_restart = False + +def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None: + if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop): + warnings.warn("Detected custom epoch loop. 
Skipping no validation on restart support.", UserWarning) + return + + ## Pass trainer object to avoid trainer getting overwritten as None + loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps) + trainer.fit_loop.epoch_loop = loop + + +class Trainer(pl.Trainer, IOMixin): def add_io(self, obj): """Recurse to the leaves of a container and add io functionality to non-serializable leaves""" if isinstance(obj, (dict, list)): diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 16b6c574d2fa8..6a86dacbfefb0 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -14,7 +14,7 @@ matplotlib>=3.3.2 #megatron_core>0.6.0 # add back once mcore on pypi is compatible again nltk>=3.6.5 numpy<2 # tensorstore has an implicit compiled dependency on numpy<2 -opencc<1.1.7 +opencc pangu prettytable rapidfuzz diff --git a/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py index 18ddb89359420..e393779ddc7cb 100644 --- a/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py @@ -123,7 +123,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py b/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py index 59bc0a64bbe99..b11cb7b385341 100644 --- a/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py @@ -121,7 +121,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups # 32 / 2 = 16 qkv_total_dim = head_num + 2 * num_query_groups # 32 + 2 * 2 = 36 diff --git a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py index 8da15148dfd87..1c2d49cbddc4f 100644 --- a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py @@ -26,7 +26,7 @@ from nemo.utils import logging """ -Script to convert a llama2 checkpoint in nemo (mcore path) into a HuggingFace checkpoint. +Script to convert a llama checkpoint in nemo (mcore path) into a HuggingFace checkpoint. This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder. 
1) Generate only HF weights from a nemo file: @@ -37,13 +37,21 @@ 2) Generate the full HF model folder + python convert_llama_nemo_to_hf.py \ + --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ + --output_path /path/to/pytorch_model.bin \ + --hf_input_path /path/to/input_hf_folder \ + --hf_output_path /path/to/output_hf_folder + +3) Generate the full HF model folder with a custom tokenizer + python convert_llama_nemo_to_hf.py \ --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ --output_path /path/to/pytorch_model.bin \ --hf_input_path /path/to/input_hf_folder \ --hf_output_path /path/to/output_hf_folder \ - --input_tokenizer /path/to/tokenizer \ - --hf_output_tokenizer /path/to/output_tokenizer \ + --input_tokenizer /path/to/custom_nemo_tokenizer.model \ + --hf_output_tokenizer /path/to/output_tokenizer Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Llama2 70b). However this option makes the conversion script significantly slower. @@ -138,7 +146,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups @@ -232,17 +240,25 @@ def replace_hf_weights_and_tokenizer( nemo_exported = torch.load(weights_file) if tokenizer_path: - tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path, local_files_only=True, legacy=False,) - tmp_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(tokenizer) - fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer) - tokenizer_length = len(fast_tokenizer) - model.resize_token_embeddings(tokenizer_length) + try: + tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_path, + local_files_only=True, + legacy=False, + ) + tmp_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(tokenizer) + fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer) + tokenizer_length = len(fast_tokenizer) + model.resize_token_embeddings(tokenizer_length) + except: + tokenizer = None + logging.warning("Could not load custom tokenizer, proceeding with default tokenizer") model.load_state_dict(nemo_exported) model.save_pretrained(output_hf_path) logging.info(f"Full HF model saved to {output_hf_path}") - if tokenizer_path: + if tokenizer_path and (tokenizer is not None): fast_tokenizer.save_pretrained(output_hf_tokenizer) tokenizer.save_pretrained(output_hf_tokenizer) logging.info(f"Tokenizer saved to {output_hf_tokenizer}") diff --git a/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py b/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py index 796819c38ba44..3a06b3d069541 100644 --- a/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py @@ -133,7 +133,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = model.cfg.get('kv_channels', hidden_size // head_num) + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * 
num_query_groups diff --git a/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py b/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py index 58311d0324c2a..1b61d489fc63a 100644 --- a/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py @@ -136,7 +136,7 @@ def convert(in_file, precision=None) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py b/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py new file mode 100644 index 0000000000000..12e56e9f1793b --- /dev/null +++ b/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py @@ -0,0 +1,242 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format. +Available model listed in MODEL_CONFIG_MAPPING +Example usage: + +a. Convert a .nemo checkpoint + python /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \ + --input_path=Meta-Llama-3-8B.nemo \ + --output_path=your_output_dir \ + --model_id=meta-llama/Meta-Llama-3-8B + +b. Convert a model weight directory. + The checkpoint should be similar to `model_weights` subdir after extracting the .nemo file. + Please also provide tokenizer_library and tokenizer_path when loading from weight directory. 
+ python /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \ + --input_path=nemotron3-8b-extracted/model_weights \ + --tokenizer_path=path_to_your_tokenizer_model.model \ + --tokenizer_library=sentencepiece \ + --output_path=your_output_dir \ + --model_id=nvidia/nemotron-3-8b-base-4k + +""" + +import os +import shutil +import tempfile +from argparse import ArgumentParser +from pathlib import Path + +import torch +from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace +from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedObject +from omegaconf import OmegaConf +from transformers import AutoTokenizer as HFAutoTokenizer + +from nemo.collections import llm +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir +from nemo.utils import logging + +MODEL_CONFIG_MAPPING = { + "meta-llama/Llama-2-7b-hf": (llm.LlamaModel, llm.Llama2Config7B), + "meta-llama/Llama-2-13b-hf": (llm.LlamaModel, llm.Llama2Config13B), + "meta-llama/Llama-2-70b-hf": (llm.LlamaModel, llm.Llama2Config70B), + "meta-llama/Meta-Llama-3-8B": (llm.LlamaModel, llm.Llama3Config8B), + "meta-llama/Meta-Llama-3-70B": (llm.LlamaModel, llm.Llama3Config70B), + "mistralai/Mixtral-8x7B-v0.1": (llm.MixtralModel, llm.MixtralConfig8x7B), + "mistralai/Mixtral-8x22B-v0.1": (llm.MixtralModel, llm.MixtralConfig8x22B), + "mistralai/Mistral-7B-v0.1": (llm.MistralModel, llm.MistralConfig7B), + "nvidia/nemotron-3-8b-base-4k": (llm.NemotronModel, llm.Nemotron3Config8B), + "nemotron4-22b": (llm.NemotronModel, llm.Nemotron3Config22B), + "nemotron4-15b": (llm.NemotronModel, llm.Nemotron4Config15B), + "nemotron4-340b": (llm.NemotronModel, llm.Nemotron4Config340B), +} + + +def get_args(): + """ + Parse the command line arguments. + """ + parser = ArgumentParser( + description="""Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format. + This script may download from Hugging Face, make sure you have + access to gate repo and have logged into Hugging Face (e.g. huggingface-cli login)""" + ) + parser.add_argument( + "--input_path", + type=str, + default=None, + required=True, + help="""Path to NeMo 1.0 checkpoints. Could be .nemo file, or `model_weights` directory a + fter untar the .nemo. Please also provide tokenizer_library and tokenizer_path if you pass + in `model_weights` directory.""", + ) + parser.add_argument( + "--output_path", type=str, default=None, required=True, help="Path to output NeMo 2.0 directory." + ) + parser.add_argument( + "--model_id", type=str, default=None, required=True, help="Hugging Face or nemotron model id for the model" + ) + parser.add_argument( + "--tokenizer_path", + type=str, + default=None, + required=False, + help="""Path to tokenizer. If not provided, will 1. try instantiate from nemo1 config + 2. pull AutoTokenizer from Hugging Face according to model_id if 1 fails""", + ) + parser.add_argument( + "--tokenizer_library", + type=str, + default=None, + required=False, + help="Tokenizer library, e.g. `sentencepiece`, `megatron`. 
Defaults to `sentencepiece`", + ) + args = parser.parse_args() + return args + + +def get_nemo2_model(model_id, tokenizer) -> llm.GPTModel: + """ + Get NeMo 2.0 model class from model_id and tokenizer. Use bf16 for NeMo 1.0 ckpts. + + Returns: + llm.GPTModel: NeMo 2.0 model instance + """ + + if model_id not in MODEL_CONFIG_MAPPING: + valid_ids = "\n- ".join([""] + list(MODEL_CONFIG_MAPPING.keys())) + raise ValueError(f"Unsupported model_id: {model_id}. Please provide a valid model_id from {valid_ids}") + model_cls, config_cls = MODEL_CONFIG_MAPPING[model_id] + # nemo1 ckpts are bf16 + return model_cls(config_cls(bf16=True, params_dtype=torch.bfloat16), tokenizer=tokenizer) + + +def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer: + """ + Get tokenizer from input .nemo file, or args.tokenizer_path, or Hugging Face. + Only SentencePiece and Hugging Face tokenizers are supported. + + Returns: + AutoTokenizer: tokenizer instance + """ + if not input_path.is_dir(): # if .nemo tar + with tempfile.TemporaryDirectory() as tmp_dir: # we want to clean up this tmp dir + NLPSaveRestoreConnector._unpack_nemo_file(input_path, tmp_dir) + cfg = OmegaConf.load(f"{tmp_dir}/model_config.yaml") + tokenizer_lib = cfg.tokenizer.library + tokenizer_model = cfg.tokenizer.get("model") and cfg.tokenizer.get("model").split("nemo:", 1)[-1] + if tokenizer_model: + shutil.copy(f"{tmp_dir}/{tokenizer_model}", f"{tokenizer_tmp_dir}/{tokenizer_model}") + elif cfg.tokenizer.library == "huggingface": + HFAutoTokenizer.from_pretrained(cfg.tokenizer.type).save_pretrained(tokenizer_tmp_dir) + tokenizer_model = f"{tokenizer_tmp_dir}/{tokenizer_model}" if tokenizer_model else None + else: + if args.tokenizer_path: # not .nemo file, only weight dir need to specify tokenizer lib and path + tokenizer_lib = args.tokenizer_library or "sentencepiece" + if args.tokenizer_library is None: + logging.warning( + "You specified tokenizer_path but did not provide tokenizer_library using default sentencepiece" + ) + tokenizer_model = args.tokenizer_path + else: # no .nemo config, no tokenizer path specified, grab from HF, reload + tokenizer_lib = "huggingface" + HFAutoTokenizer.from_pretrained(args.model_id).save_pretrained(tokenizer_tmp_dir) + + if tokenizer_lib == "huggingface": + return AutoTokenizer(tokenizer_tmp_dir) + else: # not directly use huggingface tokenizer in get_nmt_tokenizer since it will pull from HF and no reload + return get_nmt_tokenizer(library=tokenizer_lib, tokenizer_model=tokenizer_model) + + +def main() -> None: + """ + Main function to convert NeMo 1.0 checkpoint to NeMo 2.0 format. 
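For orientation, once `main()` below has finished, the `--output_path` directory should contain a `weights` subdirectory (the distributed checkpoint) and a `context` subdirectory (the serialized `TrainerContext`), matching the `ckpt_to_weights_subdir` and `ckpt_to_context_subdir` helpers used in the code. A hedged sanity check (only the two directory names are taken from those helpers; file names vary between NeMo versions):

```python
# Hedged post-conversion check; "weights" and "context" follow the
# ckpt_to_*_subdir helpers used in main(), everything else may vary.
from pathlib import Path

out = Path("your_output_dir")  # the value passed as --output_path
for sub in ("weights", "context"):
    print(sub, "->", sorted(p.name for p in (out / sub).iterdir()))
```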
+ """ + tokenizer_tmp_dir = Path("/tmp/nemo_tokenizer") + tokenizer_tmp_dir.mkdir(parents=True, exist_ok=True) + tokenizer = get_tokenizer(Path(args.input_path), tokenizer_tmp_dir) + model = get_nemo2_model(args.model_id, tokenizer=tokenizer) + model.optim = None + + trainer = Trainer( + devices=1, + accelerator="cpu", + strategy=MegatronStrategy(ddp="pytorch", setup_optimizers=False, plugins=bf16_mixed()), + ) + + trainer.strategy.connect(model) + trainer.strategy.setup_environment() + if not model.state_dict(): + with _strategy_lib.megatron_cpu_init_context(model.config): + model.configure_model() + + trainer.strategy.setup(trainer) + + logging.info(f"loading checkpoint {args.input_path}") + + sharded_state_dict = {"state_dict": trainer.strategy.megatron_parallel.sharded_state_dict()} + + for key in list(sharded_state_dict['state_dict'].keys()): + new_key = key.replace('module', 'model', 1) + sharded_state_dict['state_dict'][new_key] = sharded_state_dict['state_dict'].pop(key) + sharded_state_dict['state_dict'][new_key].key = sharded_state_dict['state_dict'][new_key].key.replace( + 'module', 'model', 1 + ) + + def skip_fp8_load(x): + if isinstance(x, ShardedObject) and 'core_attention' in x.key and '_extra_state' in x.key: + x = LocalNonpersistentObject(x.data) # use the FP8 state from initialization, not from ckpt + return x + + dict_list_map_inplace(skip_fp8_load, sharded_state_dict) + if not Path(args.input_path).is_dir(): + with tempfile.TemporaryDirectory() as tmp_dir: + NLPSaveRestoreConnector._unpack_nemo_file(args.input_path, tmp_dir) + model_weight_dir = f"{tmp_dir}/model_weights" + model_ckpt = trainer.strategy.checkpoint_io.load_checkpoint(model_weight_dir, sharded_state_dict, None) + else: + model_ckpt = trainer.strategy.checkpoint_io.load_checkpoint(args.input_path, sharded_state_dict, None) + + logging.info(f"Saving checkpoint to {args.output_path}") + model_ckpt['state_dict'] = {k.replace('model', 'module', 1): v for k, v in model_ckpt['state_dict'].items()} + trainer.model.module.load_state_dict(model_ckpt['state_dict']) + trainer.save_checkpoint(ckpt_to_weights_subdir(args.output_path, is_saving=False)) + if getattr(trainer.strategy, "async_save", False): + trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True) + + # Corresponding to Connector: on_import_ckpt + if hasattr(trainer.model, "__io__") and hasattr(trainer.model.tokenizer, '__io__'): + trainer.model.__io__.tokenizer = trainer.model.tokenizer.__io__ + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(args.output_path), yaml_attrs=["model"]) + + # remove tmp dir + if os.path.isdir(tokenizer_tmp_dir): + shutil.rmtree(tokenizer_tmp_dir) + + logging.info(f"NeMo 2.0 checkpoint saved at {args.output_path}") + + +if __name__ == '__main__': + args = get_args() + main() diff --git a/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py index c6a218020c213..17682c5cc1abf 100644 --- a/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py @@ -141,7 +141,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * 
num_query_groups diff --git a/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py index 043d1fd35261e..9d70544ee401d 100644 --- a/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py @@ -140,7 +140,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/tests/collections/llm/bitexact/mixtral/run.sh b/tests/collections/llm/bitexact/mixtral/run.sh index c32dbbc95b981..87bf7c382b996 100644 --- a/tests/collections/llm/bitexact/mixtral/run.sh +++ b/tests/collections/llm/bitexact/mixtral/run.sh @@ -43,4 +43,4 @@ python3 /workspace/tests/collections/llm/bitexact/mixtral/pretrain_mini_mixtral. # Compare outputs python3 /workspace/tests/collections/llm/bitexact/mixtral/compare_ckpts.py \ - "$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/" "$MCORE_OUTPUT_PATH/iter_0000010/" + "$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0-consumed_samples=20.0/weights" "$MCORE_OUTPUT_PATH/iter_0000010/" diff --git a/tests/collections/llm/megatron_mixtral_pretraining.py b/tests/collections/llm/megatron_mixtral_pretraining.py index 82188f75351ee..4123c7b37987c 100644 --- a/tests/collections/llm/megatron_mixtral_pretraining.py +++ b/tests/collections/llm/megatron_mixtral_pretraining.py @@ -158,7 +158,7 @@ def main(args): ) # Confirm checkpoint directory structure - output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/" + output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0-consumed_samples=8.0/weights" assert output_path.exists(), f"Expected {output_path} to exist" assert output_path.is_dir(), f"Expected {output_path} to be a directory" output_files = ['__0_0.distcp', '__0_1.distcp', 'common.pt', 'metadata.json', '.metadata'] diff --git a/tests/lightning/test_nemo_run.py b/tests/lightning/test_nemo_run.py index 934eaa853bf0b..c7e8c6a921975 100644 --- a/tests/lightning/test_nemo_run.py +++ b/tests/lightning/test_nemo_run.py @@ -16,7 +16,12 @@ ("llama3_70b", "finetune_recipe", "llama3_70b_finetune"), ("llama3_70b_16k", "pretrain_recipe", "llama3_70b_16k_pretrain"), ("llama3_70b_64k", "pretrain_recipe", "llama3_70b_64k_pretrain"), + ("llama31_8b", "pretrain_recipe", "llama31_8b_pretrain"), + ("llama31_8b", "finetune_recipe", "llama31_8b_finetune"), + ("llama31_70b", "pretrain_recipe", "llama31_70b_pretrain"), + ("llama31_70b", "finetune_recipe", "llama31_70b_finetune"), ("llama31_405b", "pretrain_recipe", "llama31_405b_pretrain"), + ("llama31_405b", "finetune_recipe", "llama31_405b_finetune"), ("mistral_7b", "pretrain_recipe", "mistral_pretrain"), ("mistral_7b", "finetune_recipe", "mistral_finetune"), ("mixtral_8x7b", "pretrain_recipe", "mixtral_8x7b_pretrain"), diff --git a/tests/lightning/test_state_restoration.py b/tests/lightning/test_state_restoration.py index 076a2f931f57a..03b01a00e759b 100644 --- a/tests/lightning/test_state_restoration.py +++ b/tests/lightning/test_state_restoration.py @@ -225,7 +225,7 @@ def run_resume_train(mbs, gbs, num_dev): resume=AutoResume( resume_if_exists=True, 
resume_ignore_no_checkpoint=False, - resume_from_path=f'{EXP_DIR}default/v1/checkpoints/default--None=0.0000-epoch=0/', + resume_from_path=f'{EXP_DIR}default/v1/checkpoints/default--None=0.0000-epoch=0-consumed_samples=20.0/', ), ) trainer._teardown() diff --git a/tutorials/llm/llama-3/README.rst b/tutorials/llm/llama-3/README.rst index bb6171e6f5824..5b2d66ed5b017 100755 --- a/tutorials/llm/llama-3/README.rst +++ b/tutorials/llm/llama-3/README.rst @@ -2,7 +2,7 @@ Getting Started with Llama 3 and Llama 3.1 ========================================== -This repository contains jupyter notebook tutorials using NeMo Framework for Llama-3 and Llama-3.1 models by Meta. +This repository contains Jupyter Notebook tutorials using the NeMo Framework for Llama-3 and Llama-3.1 models by Meta. .. list-table:: :widths: 100 25 100 @@ -16,7 +16,7 @@ This repository contains jupyter notebook tutorials using NeMo Framework for Lla - Perform LoRA PEFT on Llama 3 8B Instruct using a dataset for bio-medical domain question answering. Deploy multiple LoRA adapters with NVIDIA NIM. * - `Llama 3.1 Law-Domain LoRA Fine-Tuning and Deployment with NeMo Framework and NVIDIA NIM <./sdg-law-title-generation>`_ - `Law StackExchange `_ - - Perform LoRA PEFT on Llama 3.1 8B Instruct using a synthetically augmented version of Law StackExchange with NeMo Framework, followed by deployment with NVIDIA NIM. As a pre-requisite, follow the tutorial for `data curation using NeMo Curator `__. - * - `Llama 3.1 WikiText Pruning and Distillation with NeMo Framework <./pruning-distillation>`_ + - Perform LoRA PEFT on Llama 3.1 8B Instruct using a synthetically augmented version of Law StackExchange with NeMo Framework, followed by deployment with NVIDIA NIM. As a prerequisite, follow the tutorial for `data curation using NeMo Curator `_. + * - `Llama 3.1 Pruning and Distillation with NeMo Framework <./pruning-distillation>`_ - `WikiText-103-v1 `_ - Perform pruning and distillation on Llama 3.1 8B Instruct using the WikiText-103-v1 dataset with NeMo Framework. diff --git a/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb new file mode 100644 index 0000000000000..8548c0cfb1d0e --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ab9e2e97-7f10-4353-859e-693842bde465", + "metadata": {}, + "source": [ + "### Step 1: Prepare the dataset\n", + "\n", + "The dataset has to be preprocessed using the [preprocess_data_for_megatron.py](https://github.com/NVIDIA/NeMo/blob/main/scripts/nlp_language_modeling/preprocess_data_for_megatron.py) script included in the NeMo Framework. This step will also tokenize data using the `meta-llama/Meta-Llama-3.1-8B` tokenizer model to convert the data into a memory map format.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your train, test, and validation data files." 
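The preprocessing commands below read JSONL files from `./wikitext-data/`. If you still need to create them, a minimal sketch is shown here (the use of the Hugging Face `datasets` package is an assumption; the file names mirror the `--input` paths used in this notebook):

```python
# Minimal sketch: write WikiText-103-v1 splits to the JSONL files consumed by
# the preprocessing commands below.
import json
import os

from datasets import load_dataset

os.makedirs("wikitext-data", exist_ok=True)
ds = load_dataset("wikitext", "wikitext-103-v1")
for split, filename in [("train", "wikitext-train.jsonl"),
                        ("test", "wikitext-test.jsonl"),
                        ("validation", "wikitext-val.jsonl")]:
    with open(os.path.join("wikitext-data", filename), "w") as f:
        for example in ds[split]:
            f.write(json.dumps({"text": example["text"]}) + "\n")
```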
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6505c00b-9eb4-4087-9e49-423f6228e690", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-train.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_train \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb1aa80f-70bc-4dff-8b08-3bff48d9a1c3", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-test.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_test \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42bec54a-94f6-4c87-8e14-2726ef6c2625", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-val.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_val \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "markdown", + "id": "5d77ee8a-e0dc-44f7-b5e8-3b6025d979d7", + "metadata": {}, + "source": [ + "After running the above scripts, you will see the preprocessed `wikitext_tokenized_{train/val/test}_text_document.{idx/bin}` files. These output files will be used in the next step." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb new file mode 100644 index 0000000000000..7d58ac4779aac --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "84b146ba-08b6-4adb-a858-8e4294c5e781", + "metadata": {}, + "source": [ + "\n", + "### Step 2: Fine-tune the teacher on the dataset\n", + "\n", + "NeMo Framework includes a standard Python script, [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py), for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", + "\n", + "We fine-tune the unpruned model on our dataset to correct the distribution shift from the original dataset the model was trained on.
According to the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), experiments showed that without correcting for this distribution shift, the teacher provides suboptimal guidance on the dataset during distillation.\n", + "\n", + "For this demonstration, this training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as the path to the teacher .nemo model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12007ac8-2fd5-4de8-8964-97821c2198c0", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Set path(s) if different:\n", + "\n", + "MODEL=\"/workspace/llama-3_1-8b-nemo_v1.0/llama3_1_8b.nemo\"\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_ft\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \\\n", + " --config-path /opt/NeMo/examples/nlp/language_modeling/conf/ \\\n", + " --config-name megatron_llama_distill.yaml \\\n", + " \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${MODEL} \\\n", + " +model.dist_ckpt_load_strictness=log_all \\\n", + " \\\n", + " ~model.tokenizer \\\n", + " +model.tokenizer='{library: huggingface, type: meta-llama/Meta-Llama-3.1-8B, use_fast: True}' \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.encoder_seq_length=8192 \\\n", + " model.num_layers=32 \\\n", + " model.hidden_size=4096 \\\n", + " model.ffn_hidden_size=14336 \\\n", + " model.num_attention_heads=32 \\\n", + " model.hidden_dropout=0.0 \\\n", + " model.attention_dropout=0.0 \\\n", + " model.apply_query_key_layer_scaling=True \\\n", + " 
model.normalization='rmsnorm' \\\n", + " model.bias=False \\\n", + " model.activation='fast-swiglu' \\\n", + " model.position_embedding_type='rope' \\\n", + " model.share_embeddings_and_output_weights=False \\\n", + " model.num_query_groups=8 \\\n", + " ++model.scale_positional_embedding=True \\\n", + " ++model.rotary_base=500000.0 \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS}" + ] + }, + { + "cell_type": "markdown", + "id": "3040a993-8423-475f-8bc6-d1dd1ce16a83", + "metadata": {}, + "source": [ + "This will create a fine-tuned teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", + "> `NOTE:` This script takes at least 20 minutes to run (depending on GPU) and will generate the fine-tuned teacher model." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb new file mode 100644 index 0000000000000..d64f8c15bd006 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb @@ -0,0 +1,77 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", + "metadata": {}, + "source": [ + "### Step 3: Prune the fine-tuned teacher model to create a student\n", + "In this step, we will explore two methods to prune the fine-tuned teacher model. Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore.\n", + "\n", + "In the first method, depth-pruning, we trim the layers of the model." + ] + }, + { + "cell_type": "markdown", + "id": "72fa494e-6268-4044-a1d6-c0518d450cfd", + "metadata": {}, + "source": [ + "#### Step 3.a.: Using depth-pruning \n", + "To depth-prune, we will trim the last 16 layers in the fine-tuned teacher model. For depth-pruning, we will use the [megatron_gpt_drop_layers](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_drop_layers.py) script. \n", + "\n", + "Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), removing contiguous layers from the second-to-last block (layers 16 to 31 continuously) yields the best overall results. \n", + "\n", + "> `NOTE:` In the block of code below, pass the path to your fine-tuned teacher .nemo model."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60cae073-a192-4d47-b220-b09736d39a93", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python -m torch.distributed.launch --nproc_per_node=8 \\\n", + " /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_drop_layers.py \\\n", + " --path_to_nemo \"./distill_trainings/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\" \\\n", + " --path_to_save \"/workspace/4b_depth_pruned_model.nemo\" \\\n", + " --tensor_model_parallel_size 8 \\\n", + " --pipeline_model_parallel_size 1 \\\n", + " --gpus_per_node 8 \\\n", + " --drop_layers 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31" + ] + }, + { + "cell_type": "markdown", + "id": "375f298a-0363-4f44-b40c-2c8e9bab7d76", + "metadata": {}, + "source": [ + "Running this script will save the depth-pruned model `4b_depth_pruned_model.nemo` to your workspace." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb new file mode 100644 index 0000000000000..5c4a47872afbe --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb @@ -0,0 +1,92 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", + "metadata": {}, + "source": [ + "### Step 3: Prune the fine-tuned teacher model to create a student\n", + "In the second method, we will width-prune. In width-pruning, we trim the neurons, attention heads, and embedding channels.\n", + "\n", + "Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore." + ] + }, + { + "cell_type": "markdown", + "id": "9207ed14-2f37-4712-88f3-543a128663ac", + "metadata": { + "tags": [] + }, + "source": [ + "#### Step 3.b.: Using width-pruning\n", + "To width-prune the model, we do the following:\n", + "- Prune (trim) the MLP intermediate dimension from 14336 to 9216.\n", + "- Prune the hidden size from 4096 to 3072.\n", + "- Retain the attention head count and number of layers.\n", + "\n", + "For width-pruning, we will use the [megatron_gpt_prune.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_prune.py) script in the NeMo Framework. To see the detailed list of parameters for width-pruning, you can view the [megatron_gpt_prune.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml) file.\n", + "\n", + "We use the above parameters to get a competitive model for this demonstration. You can use other strategies or parameters from the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) or the [tech report](https://arxiv.org/pdf/2408.11796) for your experiments.
\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your fine-tuned teacher .nemo model.\n", + "\n", + "> `TIP:` You can increase the ``batch_size`` (upto 1024) to speed up the width-pruning script execution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "571d1483-dd4c-403e-b321-293342e7a62a", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!torchrun --nproc-per-node=8 /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_prune.py \\\n", + " model.restore_from_path=\"./distill_trainings/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\" \\\n", + " model.tensor_model_parallel_size=1 \\\n", + " model.pipeline_model_parallel_size=8 \\\n", + " +model.dist_ckpt_load_strictness=log_all \\\n", + " inference.batch_size=64 \\\n", + " trainer.num_nodes=1 \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=8 \\\n", + " prune.ffn_hidden_size=9216 \\\n", + " prune.num_attention_heads=null \\\n", + " prune.num_query_groups=null \\\n", + " prune.hidden_size=3072 \\\n", + " export.save_path=\"/workspace/4b_width_pruned_model.nemo\"" + ] + }, + { + "cell_type": "markdown", + "id": "e9fb0977-5c02-4ecc-b602-54d74b2e2184", + "metadata": {}, + "source": [ + "Running this script will save the width-pruned model `4b_width_pruned_model.nemo` to your workspace." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb new file mode 100644 index 0000000000000..4882258377311 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb @@ -0,0 +1,136 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "09d30e35-8e9d-4d2e-bd14-738c627a3963", + "metadata": {}, + "source": [ + "### Step 4: Distill knowledge from teacher into student\n", + "Distillation of a model with NeMo Framework is also possible using a Python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). In this notebook, we will explore distillation with the depth-pruned model as the `STUDENT` model.\n", + "\n", + "For this demonstration, the `TEACHER` would be the fine-tuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + ] + }, + { + "cell_type": "markdown", + "id": "c33cf641-0d27-417f-b3ee-c06701698184", + "metadata": {}, + "source": [ + "#### Step 4.a.: Using depth-pruned student\n", + "While distilling knowledge from the teacher to depth-pruned model, the `STUDENT` model would be `4b_depth_pruned_model.nemo` as produced by the [depth-pruning](./03_a_depth_pruning.ipynb) notebook. 
This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as path to the teacher and student .nemo models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d23a01e-4912-47cb-bf21-b4fd72007ec1", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_distill_depth_pruned_student\"\n", + "\n", + "TEACHER=\"${EXPERIMENT_DIR}/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\"\n", + "STUDENT=\"/workspace/4b_depth_pruned_model.nemo\"\n", + "\n", + "FINAL_MODEL_PATH=\"${EXPERIMENT_DIR}/${EXPERIMENT_NAME}/checkpoints/depth_pruned_distilled_4b_model.nemo\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_distillation.py \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${STUDENT} \\\n", + " model.kd_teacher_restore_from_path=${TEACHER} \\\n", + " model.nemo_path=${FINAL_MODEL_PATH} \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS}" + ] + }, + { + "cell_type": "markdown", + "id": "42d910d9-14dd-44ba-bf2c-0064737c70fa", + "metadata": {}, + "source": [ + "This will create the final distilled model named `depth_pruned_distilled_4b_model.nemo` in `./distill_trainings/megatron_llama_distill_depth_pruned_student/checkpoints`.\n", + "> `NOTE:`This script takes at least 35 minutes to run (depends on GPU) and generate the final distilled model." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb new file mode 100644 index 0000000000000..95110dd19dd92 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb @@ -0,0 +1,138 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d5062f23-c604-479b-9a4e-69989598b131", + "metadata": {}, + "source": [ + "### Step 4: Distill knowledge from teacher into student\n", + "Distillation of a model with NeMo Framework is also possible using a Python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", + "In this notebook, we will explore distillation with the width-pruned model as the `STUDENT` model.\n", + "\n", + "For this demonstration, the `TEACHER` would be the fine-tuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + ] + }, + { + "cell_type": "markdown", + "id": "be7de691-dd1d-4719-9872-98501a22e3c9", + "metadata": {}, + "source": [ + "#### Step 4.b.: Using width-pruned student\n", + "While distilling knowledge from the teacher to width-pruned model, the `STUDENT` model would be `4b_width_pruned_model.nemo` as produced by the [width-pruning](./03_b_width_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as path to the teacher and student .nemo models." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0070b526-771a-4a8d-b0ba-ab218b382bd9", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_distill_width_pruned_student\"\n", + "\n", + "TEACHER=\"${EXPERIMENT_DIR}/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\"\n", + "STUDENT=\"/workspace/4b_width_pruned_model.nemo\"\n", + "\n", + "FINAL_MODEL_PATH=\"${EXPERIMENT_DIR}/${EXPERIMENT_NAME}/checkpoints/width_pruned_distilled_4b_model.nemo\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_distillation.py \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${STUDENT} \\\n", + " model.kd_teacher_restore_from_path=${TEACHER} \\\n", + " model.nemo_path=${FINAL_MODEL_PATH} \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS} \\\n", + " +model.dist_ckpt_load_strictness=log_all" + ] + }, + { + "cell_type": "markdown", + "id": "d9dbc377-e19a-49e0-b245-fa828cca415a", + "metadata": {}, + "source": [ + "This will create the final width-pruned distilled model named `width_pruned_distilled_4b_model.nemo` in `./distill_trainings/megatron_llama_distill_width_pruned_student/checkpoints`.\n", + "> `NOTE:`This script takes at least 20 minutes to run (depends on GPU) and generate the final distilled model." 
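Both distillation runs write TensorBoard event files under `distill_trainings/`. Besides the `%tensorboard` cells in the next notebook, the logged validation loss can be pulled out programmatically; a hedged sketch follows (the `val_loss` tag name is an assumption, so inspect `acc.Tags()` for the exact tag your NeMo version logs):

```python
# Sketch: read validation-loss scalars from the TensorBoard logs of both
# distillation runs.
from pathlib import Path

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

for run_dir in sorted(Path("distill_trainings").glob("megatron_llama_distill_*_student")):
    for event_file in run_dir.rglob("events.out.tfevents.*"):
        acc = EventAccumulator(str(event_file))
        acc.Reload()
        if "val_loss" in acc.Tags().get("scalars", []):
            points = acc.Scalars("val_loss")
            print(run_dir.name, [(p.step, round(p.value, 3)) for p in points])
```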
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb new file mode 100644 index 0000000000000..dcb483c55ab66 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6c91263b-b312-4ab2-b13f-0ee4b6e8bd0f", + "metadata": {}, + "source": [ + "### Step 5: Display the validation loss\n", + "\n", + "Now that the results are in, let's visualize the validation loss of the two distilled models using the `tensorboard` library. \n", + "\n", + "> `NOTE:` This notebook demonstrates the use of the teacher fine-tuning, pruning, and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." + ] + }, + { + "cell_type": "markdown", + "id": "b5822d62-8131-4046-8c22-0bf0fce81df7", + "metadata": {}, + "source": [ + "#### Validation Loss Using Depth-Pruned Model as Student in Distillation Script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script, where we distill the knowledge from the fine-tuned teacher model to the depth-pruned student." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a665fe1-df45-4126-8694-f182af113133", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir \"distill_trainings/megatron_llama_distill_depth_pruned_student/\" --port=6007" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "db6fcf26-8ae8-40e1-875a-0a10bf85be81", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Validation Loss over 30 Training Steps with Depth-Pruned Model as Student
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Image, display, HTML\n", + "title = \"Validation Loss over 30 Training Steps with Depth-Pruned Model as Student\"\n", + "display(HTML(f\"
{title}
\"))\n", + "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_depth_pruned_student_distillation.png\", width=400))" + ] + }, + { + "cell_type": "markdown", + "id": "f10041ae-6533-47de-9f76-f97d4469c27a", + "metadata": {}, + "source": [ + "#### Validation Loss Using Width-Pruned Model as Student in Distillation Script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script, where we distill the knowledge from the fine-tuned teacher model to the width-pruned student." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b0c3118-4987-4df3-88bd-fcffdb521c5d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir \"distill_trainings/megatron_llama_distill_width_pruned_student/\" --port=6008" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ecd79583-f662-40c6-a690-9f4bb847de4e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Validation Loss over 30 Training Steps with Width-Pruned Model as Student
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Image, display, HTML\n", + "title = \"Validation Loss over 30 Training Steps with Width-Pruned Model as Student\"\n", + "display(HTML(f\"
{title}
\"))\n", + "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png\", width=400))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/README.rst b/tutorials/llm/llama-3/pruning-distillation/README.rst index 9d4207a5c9688..51e2a7cede47a 100644 --- a/tutorials/llm/llama-3/pruning-distillation/README.rst +++ b/tutorials/llm/llama-3/pruning-distillation/README.rst @@ -1,36 +1,48 @@ Llama 3.1 WikiText Pruning and Distillation with NeMo Framework ======================================================================================= -`Llama 3.1 `_ are open-source large language models by Meta that deliver state-of-the-art performance on popular industry benchmarks. They have been pretrained on over 15 trillion tokens, and support a 128K token context length. They are available in three sizes, 8B, 70B, and 405B, and each size has two variants—base pretrained and instruction tuned. +`Llama 3.1 `_ models, developed by Meta, are open-source large language models that deliver state-of-the-art performance on popular industry benchmarks. Pretrained on over 15 trillion tokens, they support a 128K token context length. These models are available in three sizes: 8B, 70B, and 405B. Each size offers two variants: base pretrained and instruction tuned. -`NVIDIA NeMo Framework `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 to fit your use case. +`NVIDIA NeMo Framework `_ provides tools to perform teacher fine-tuning, pruning, and distillation on Llama 3.1 to fit your use case. -`LLM Pruning and Distillation in Practice: The Minitron Approach `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 as described in the `tech report `_. +`NVIDIA TensorRT Model Optimizer `_ is a library (referred to as **Model Optimizer**, or **ModelOpt**) comprising state-of-the-art model optimization techniques including `quantization `_, `sparsity `_, `distillation `_, and `pruning `_ to compress models. + +`LLM Pruning and Distillation in Practice: The Minitron Approach `_ provides tools to perform teacher fine-tuning, pruning, and distillation on Llama 3.1 as described in the `tech report `_. + +`How to Prune and Distill Llama-3.1 8B to an NVIDIA Llama-3.1-Minitron 4B Model `_ provides practical and effective structured compression best practices for LLMs that combine depth, width, attention, and MLP pruning with knowledge distillation-based retraining. These strategies are presented in the `Compact Language Models via Pruning and Knowledge Distillation `_ paper. + +`Mistral-NeMo-Minitron 8B Model Delivers Unparalleled Accuracy `_ introduces the Mistral-NeMo-Minitron 8B, a state-of-the-art 8 billion parameter language model created by pruning and distilling the larger Mistral NeMo 12B model. Objectives ---------- -This tutorial shows how to perform depth-pruning, teacher finetuning and distillation on **Llama 3.1 8B Instruct** using the `WikiText-103-v1 `_ dataset with NeMo Framework. 
The `WikiText-103-v1 `_ language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. For this demonstration, we will perform a light finetuning procedure on the ``Meta Llama 3.1 8B Instruct`` teacher model to generate a finetuned teacher model ``megatron_llama_ft.nemo`` needed for optimal distillation. This finetuned teacher model is then depth-pruned to create a trimmed model ``4b_trimmed_model.nemo``. These models will serve as a starting point for distillation to create a final distilled 4B model. +This tutorial demonstrates how to perform depth-pruning, width-pruning, teacher fine-tuning, and distillation on **Llama 3.1 8B** using the `WikiText-103-v1 `_ dataset with the NeMo Framework. The `WikiText-103-v1 `_ language modeling dataset comprises over 100 million tokens extracted from verified Good and Featured articles on Wikipedia. + +For this demonstration, we will perform teacher correction by running a light fine-tuning procedure on the ``Meta Llama 3.1 8B`` teacher model to generate a fine-tuned teacher model, ``megatron_llama_ft.nemo``, needed for optimal distillation. This fine-tuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will explore both techniques, yielding ``4b_depth_pruned_model.nemo`` and ``4b_width_pruned_model.nemo``, respectively. These models will serve as starting points for distillation to create the final distilled 4B models. + We are using models utilizing the ``meta-llama/Meta-Llama-3.1-8B`` tokenizer for this demonstration. +``NOTE:`` A subset of functions is demonstrated in the notebooks. Some features, like Neural Architecture Search (NAS), are currently unavailable but will be supported in future releases. + Requirements ------------- * System Configuration - * Access to at least 8 NVIDIA GPU with an individual memory of at least 80GB, for example: 8 x H100-80GB or 8 x A100-80GB. + * Access to at least 8 NVIDIA GPUs, each with a memory of at least 80GB (e.g., 8 x H100-80GB or 8 x A100-80GB). * A Docker-enabled environment, with `NVIDIA Container Runtime `_ installed, which will make the container GPU-aware. -* `Authenticate with NVIDIA NGC `_, and download `NGC CLI Tool `_. You will use this tool to download the model and customize it with NeMo Framework. +* `Authenticate with NVIDIA NGC `_ and download `NGC CLI Tool `_. You will use this tool to download the model and customize it with NeMo Framework. * Get your Hugging Face `access token `_, which will be used to obtain the tokenizer required during training. -``NOTE:`` The default configuration in the notebook runs on 8 x 80GB NVIDIA GPUs but you can potentially reduce Tensor Parallel size ``(TENSOR_PARALLEL_SIZE)`` along with the Micro-Batchsize ``(MICRO_BATCH_SIZE)`` in the teacher finetuning and distillation scripts to accommodate lower resource availability. +``NOTE:`` The default configuration in the notebook runs on 8 x 80GB NVIDIA GPUs. However, you can potentially reduce the Tensor Parallel size ``(TENSOR_PARALLEL_SIZE)`` along with the Micro-Batch Size ``(MICRO_BATCH_SIZE)`` in the teacher fine-tuning and distillation scripts to accommodate lower resource availability.
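+
+As an optional sanity check for the Hugging Face access token requirement above, the short sketch below (illustrative only; it assumes the ``transformers`` and ``huggingface_hub`` Python packages are installed in your environment, and the token value is a placeholder) verifies that your token can fetch the gated ``meta-llama/Meta-Llama-3.1-8B`` tokenizer before you start the notebooks::
+
+    # Pre-flight check (illustrative sketch, not part of the tutorial scripts): confirm the
+    # Hugging Face access token can download the gated Llama 3.1 tokenizer used in these notebooks.
+    from huggingface_hub import login
+    from transformers import AutoTokenizer
+
+    login(token="<YOUR_HF_ACCESS_TOKEN>")  # placeholder; paste your own access token
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
+    print(tokenizer("Pruning and distillation with NeMo Framework")["input_ids"][:8])
+
+If this call fails with an authentication or gated-repository error, re-check your access token and confirm that you have accepted the model's license on Hugging Face.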
-Create a pruned and distilled model with NeMo Framework +Create a Pruned and Distilled Model with NeMo Framework ------------------------------------------------------------------------------ -For pruning and distilling the model, you will use the NeMo Framework which is available as a `docker container `_. +For pruning and distilling the model, you will use the NeMo Framework, which is available as a `Docker container `_. +``NOTE:`` These notebooks use the `NVIDIA TensorRT Model Optimizer `_ under the hood for pruning and distillation. 1. Download the `Llama 3.1 8B Instruct .nemo `_ from NVIDIA NGC using the `NGC CLI `_. Generate the ``NGC_API_KEY`` following these `instructions `_. The following command saves the ``.nemo`` format model in a folder named ``llama-3_1-8b-instruct-nemo_v1.0`` in the current directory. You can specify another path using the ``-d`` option in the CLI tool. @@ -63,17 +75,38 @@ For pruning and distilling the model, you will use the NeMo Framework which is a jupyter lab --ip 0.0.0.0 --port=8888 --allow-root -4. Then, navigate to `this notebook <./llama3-pruning-distillation-nemofw.ipynb>`_. +4. Then, navigate to `this notebook <./introduction.ipynb>`_ to get started. +This directory contains a list of notebooks that cover all the steps to create a distilled 4B model. + +:: + + <$pruning_distillation> + └── introduction.ipynb + └── 01_data_preparation.ipynb + └── 02_teacher_finetuning.ipynb + └── 03_a_depth_pruning.ipynb + └── 03_b_width_pruning.ipynb + └── 04_a_distilling_depth_pruned_student.ipynb + └── 04_b_distilling_width_pruned_student.ipynb + └── 05_display_results.ipynb + Results ------------------------------------------------------------------------------ -``NOTE:`` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. +``NOTE:`` This notebook demonstrates the use of the teacher fine-tuning, pruning, and the distillation scripts. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. + +Here are the validation loss plots over 30 steps of running the training step in the distillation script (at the end of the `notebook <./05_display_results.ipynb>`_). -Here is the validation loss over 30 steps of running the training step in the distillation script (at the end of the `notebook <./llama3-pruning-distillation-nemofw.ipynb>`_). +.. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_depth_pruned_student_distillation.png + :width: 400px + :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the depth-pruned model as the student + :align: center -.. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_distillation.png + Figure 1: Validation Loss Plot When Using the Depth-Pruned Model as the Student + +.. 
figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png :width: 400px :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script :align: center - Figure 1: Validation Loss Plot \ No newline at end of file + Figure 2: Validation Loss Plot When Using the Width-Pruned Model as the Student \ No newline at end of file diff --git a/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb new file mode 100644 index 0000000000000..71a5a6cfb03ce --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "411e6711-60fc-4488-8aa1-c6463cac8695", + "metadata": { + "tags": [] + }, + "source": [ + "# Efficient Model Reduction with Pruning and Distillation of Llama 3.1 Using NeMo Framework" + ] + }, + { + "cell_type": "markdown", + "id": "03fd1cf4-c67a-4b8d-a5e5-46531be0f991", + "metadata": {}, + "source": [ + "This tutorial demonstrates how to perform depth-pruning, width-pruning, teacher fine-tuning, and distillation on **Llama 3.1 8B** using the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset with NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset comprises over 100 million tokens extracted from verified Good and Featured articles on Wikipedia.\n", + "\n", + "For this demonstration, we will perform teacher correction by running a light fine-tuning procedure on the `Meta Llama 3.1 8B` teacher model to generate a fine-tuned teacher model, `megatron_llama_ft.nemo`, needed for optimal distillation. This fine-tuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will explore both techniques, yielding `4b_depth_pruned_model.nemo` and `4b_width_pruned_model.nemo`, respectively. These models will serve as starting points for distillation to create the final distilled 4B models.\n", + "\n", + "> We are using models utilizing the `meta-llama/Meta-Llama-3.1-8B` tokenizer for this demonstration.\n", + "\n", + "> `NOTE:` Ensure that you run this notebook inside the [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo), which has all the required dependencies. \n", + "\n", + "**Instructions for downloading the model and the container are available in the [README](./README.rst).**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a5026ce-39f1-43e3-93af-4c4f1e9da1f2", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install --upgrade ipywidgets notebook\n", + "!pip install datasets" + ] + }, + { + "cell_type": "markdown", + "id": "afe59b07-bb48-4913-90cc-bb416b48196c", + "metadata": { + "tags": [] + }, + "source": [ + "---\n", + "## Prerequisites\n", + "Ensure you meet the prerequisites listed in this section.\n", + "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo Framework container."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9d48b81-e978-4894-8ba4-4f183f698bb1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!ls /workspace/llama-3_1-8b-nemo_v1.0/llama3_1_8b.nemo" + ] + }, + { + "cell_type": "markdown", + "id": "7129d44e-0536-4e62-bdbc-0f1ad44dc84a", + "metadata": {}, + "source": [ + "2. **Set the Hugging Face Access Token**: You can obtain this from your [Hugging Face account](https://huggingface.co/docs/hub/en/security-tokens). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "481417ed-1456-4962-8f67-4350bde1aabd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "login(token=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "245eda8d-c999-431e-9ebc-5c92c4f21f3b", + "metadata": {}, + "source": [ + "3. **Obtain the dataset**: Generate the `wikitext-{train/val/test}.jsonl` splits after loading the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eaef2c7d-41f7-41ad-a76a-2d714e9c35de", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "# Split into train, test and val files\n", + "\n", + "import json\n", + "import os\n", + "from datasets import load_dataset\n", + "\n", + "# Load the WikiText-103 dataset\n", + "dataset = load_dataset(\"wikitext\", \"wikitext-103-v1\")\n", + "\n", + "# Define the destination folder\n", + "data_folder = 'wikitext-data'\n", + "os.makedirs(data_folder, exist_ok=True)\n", + "\n", + "# Define file paths and destination paths\n", + "file_paths = {\n", + " 'train': os.path.join(data_folder, 'wikitext-train.jsonl'),\n", + " 'validation': os.path.join(data_folder, 'wikitext-val.jsonl'),\n", + " 'test': os.path.join(data_folder, 'wikitext-test.jsonl')\n", + "}\n", + "\n", + "# Function to save dataset split to a JSONL file\n", + "def save_to_jsonl(file_path, data):\n", + " with open(file_path, 'w') as file:\n", + " for item in data:\n", + " file.write(json.dumps(item) + '\\n')\n", + "\n", + "# Define splits\n", + "splits = [\"train\", \"validation\", \"test\"]\n", + "\n", + "# Save splits to JSONL files and calculate their sizes\n", + "for split in splits:\n", + " if split in dataset:\n", + " save_to_jsonl(file_paths[split], dataset[split])\n", + " else:\n", + " print(f\"Split {split} not found in the dataset.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "2d0cc359-0598-40aa-af80-9503ecd4dac1", + "metadata": { + "tags": [] + }, + "source": [ + "---\n", + "## Step-by-Step Instructions\n", + "\n", + "This workflow is structured into seven notebooks:\n", + "1. [Prepare the dataset](./01_data_preparation.ipynb)\n", + "2. [Fine-tune the teacher on the dataset](./02_teacher_finetuning.ipynb)\n", + "3. Prune the fine-tuned teacher model to create a student \n", + " - 3.a. [Using depth-pruning](./03_a_depth_pruning.ipynb)\n", + " - 3.b. [Using width-pruning](./03_b_width_pruning.ipynb)\n", + "4. Distill knowledge from teacher into student\n", + " - 4.a. [Using depth-pruned student](./04_a_distilling_depth_pruned_student.ipynb)\n", + " - 4.b. [Using width-pruned student](./04_b_distilling_width_pruned_student.ipynb)\n", + "5. 
[Display the validation loss](./05_display_results.ipynb)\n", + "\n", + "> `NOTE:` We are exploring two methods to prune the fine-tuned teacher model: [depth-pruning](./03_a_depth_pruning.ipynb) and [width-pruning](./03_b_width_pruning.ipynb). Per the [tech report](https://arxiv.org/pdf/2408.11796), width-pruning generally outperforms depth-pruning, so you can choose to run [depth-pruning](./03_a_depth_pruning.ipynb), [width-pruning](./03_b_width_pruning.ipynb), or both." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
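For readers who want to pull the validation-loss values out of the TensorBoard logs programmatically, in addition to the ``%tensorboard`` cells in ``05_display_results.ipynb``, the following minimal sketch reads the scalars directly from the event files written under ``distill_trainings/``. It is illustrative only: it assumes the ``tensorboard`` Python package is installed and that the distillation runs log a scalar tag named ``val_loss``; the actual tag name may differ, so inspect ``ea.Tags()`` if nothing is returned::

    # Illustrative sketch: read validation-loss scalars from the distillation TensorBoard logs.
    # The "val_loss" tag name is an assumption; check ea.Tags()["scalars"] for the real tags.
    import glob
    import os

    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    def read_val_loss(logdir, tag="val_loss"):
        points = []
        # A run directory may contain several events.out.tfevents.* files; load them all.
        for event_file in glob.glob(os.path.join(logdir, "**", "events.out.tfevents.*"), recursive=True):
            ea = EventAccumulator(event_file)
            ea.Reload()
            if tag in ea.Tags().get("scalars", []):
                points.extend((s.step, s.value) for s in ea.Scalars(tag))
        return sorted(points)

    for run in ("megatron_llama_distill_depth_pruned_student", "megatron_llama_distill_width_pruned_student"):
        print(run, read_val_loss(os.path.join("distill_trainings", run)))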