added optional extra wrap modules to FSDP

Rohit Jena authored and rohitrango committed Jun 25, 2024
1 parent 398a18e commit 0b69f4b
Showing 2 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions
@@ -1255,6 +1255,8 @@ def find_frozen_submodules(model):
         # TODO: Currently the main parameter data type is kept in fp32 (when O2=False). This needs to be
         # extended to support lower precision main parameters.
         frozen_submodule_names, frozen_submodules = find_frozen_submodules(self.model)
+        for submodule in frozen_submodule_names:
+            logging.debug(f"Ignoring state {submodule} in FSDP.")
         self.trainer.strategy.kwargs['ignored_states'] = frozen_submodules
         # FSDP requires uniform status of require_grads
         # Diffusion models like SD has frozen parts and needs to be added to 'ignored_states' from sharding for FSDP to work
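What this hunk relies on: PyTorch FSDP can exclude frozen modules or parameters from sharding via its ignored_states argument, and the new loop simply logs which submodule names are being excluded. The helper below is an illustrative sketch only, assuming find_frozen_submodules walks named_modules() and collects fully frozen submodules; the diff shows just its call site, not the NeMo implementation.

# Illustrative sketch (not the NeMo implementation) of a helper like
# find_frozen_submodules: collect submodules whose parameters are all
# frozen, so they can be handed to FSDP as ignored_states.
import torch.nn as nn

def find_frozen_submodules(model: nn.Module):
    frozen_submodule_names, frozen_submodules = [], []
    for name, module in model.named_modules():
        params = list(module.parameters())
        # a submodule counts as "frozen" if it has parameters and none require grad
        if params and all(not p.requires_grad for p in params):
            frozen_submodule_names.append(name)
            frozen_submodules.append(module)
    return frozen_submodule_names, frozen_submodules

# FSDP then leaves these modules unsharded, which sidesteps its requirement
# that all parameters inside a sharded unit share the same requires_grad flag:
#     FullyShardedDataParallel(model, ignored_states=frozen_submodules)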
6 changes: 6 additions & 0 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -650,6 +650,7 @@ def __init__(
         nccl_communicator_config_path: Optional[str] = None,
         sharp: bool = False,
         set_buffer_dtype: Optional[str] = None,
+        extra_fsdp_wrap_module: Optional[set] = None,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         if not HAVE_APEX:
@@ -679,6 +680,11 @@ def __init__(
             ParallelTransformerLayer,
             BasicTransformerBlock,
         }
+
+        # if extra wrap modules are provided, use them
+        if extra_fsdp_wrap_module is not None:
+            self.fsdp_wrap_module.update(extra_fsdp_wrap_module)
+
         kwargs['auto_wrap_policy'] = functools.partial(
             transformer_auto_wrap_policy, transformer_layer_cls=self.fsdp_wrap_module
         )
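A hedged usage sketch for the new knob: in nlp_overrides.py this __init__ plausibly belongs to NeMo's NLPFSDPStrategy (the class name is not visible in the diff), and MyCustomBlock is a hypothetical user-defined module class. Classes passed via extra_fsdp_wrap_module are merged into the default wrap set before it reaches transformer_auto_wrap_policy, so each instance of those classes becomes its own FSDP unit.

# Hypothetical usage, assuming the __init__ above is NLPFSDPStrategy's;
# MyCustomBlock stands in for any user-defined transformer-style block.
from pytorch_lightning import Trainer
from nemo.collections.nlp.parts.nlp_overrides import NLPFSDPStrategy
from my_project.layers import MyCustomBlock  # hypothetical module class

strategy = NLPFSDPStrategy(
    # merged into self.fsdp_wrap_module via set.update()
    extra_fsdp_wrap_module={MyCustomBlock},
)
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=8)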
