From 0b69f4bdb4985ef08cdc2b8382f19b2d855118c0 Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Wed, 12 Jun 2024 09:29:28 -0700 Subject: [PATCH] added extra optionally wrapper to FSDP --- .../nlp/models/language_modeling/megatron_base_model.py | 2 ++ nemo/collections/nlp/parts/nlp_overrides.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 0828d88a8133..aca526b894e3 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -1255,6 +1255,8 @@ def find_frozen_submodules(model): # TODO: Currently the main parameter data type is kept in fp32 (when O2=False). This needs to be # extended to support lower precision main parameters. frozen_submodule_names, frozen_submodules = find_frozen_submodules(self.model) + for submodule in frozen_submodule_names: + logging.debug(f"Ignoring state {submodule} in FSDP.") self.trainer.strategy.kwargs['ignored_states'] = frozen_submodules # FSDP requires uniform status of require_grads # Diffusion models like SD has frozen parts and needs to be added to 'ignored_states' from sharding for FSDP to work diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 2fdb1906c31f..7382fa380de1 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -650,6 +650,7 @@ def __init__( nccl_communicator_config_path: Optional[str] = None, sharp: bool = False, set_buffer_dtype: Optional[str] = None, + extra_fsdp_wrap_module: Optional[set] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not HAVE_APEX: @@ -679,6 +680,11 @@ def __init__( ParallelTransformerLayer, BasicTransformerBlock, } + + # if extra wrap modules are provided, use them + if extra_fsdp_wrap_module is not None: + self.fsdp_wrap_module.update(extra_fsdp_wrap_module) + kwargs['auto_wrap_policy'] = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=self.fsdp_wrap_module )