diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 626bac278..f8cc1cc2a 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,7 +1,14 @@ from typing import Optional +import functools import torch from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + enable_wrap, + wrap, + ) from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -58,7 +65,15 @@ def init_model_fn( self.requires_sync_before_eval = False if N_GPUS > 1: if USE_PYTORCH_DDP: - model = DDP(model, device_ids=[RANK], output_device=RANK) + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=2 ** 10 + ) + model = FSDP( + model, + use_orig_params=True, + auto_wrap_policy=auto_wrap_policy, + device_id=RANK + ) else: model = torch.nn.DataParallel(model) return model, None