First update for FSDP pytorch librispeech conformer
davidtweedle committed Dec 5, 2024
1 parent 10a32d4 commit 11676d6
Showing 1 changed file with 16 additions and 1 deletion.
@@ -9,6 +9,12 @@
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
 
 from algorithmic_efficiency import data_utils
 from algorithmic_efficiency import param_utils
@@ -101,7 +107,15 @@ def init_model_fn(
     if N_GPUS > 1:
       if USE_PYTORCH_DDP:
         self.requires_sync_before_eval = True
-        model = DDP(model, device_ids=[RANK], output_device=RANK)
+        auto_wrap_policy = functools.partial(
+            size_based_auto_wrap_policy, min_num_params=2 ** 10
+        )
+        model = FSDP(
+            model,
+            use_orig_params=True,
+            auto_wrap_policy=auto_wrap_policy,
+            device_id=RANK
+        )
       else:
         model = torch.nn.DataParallel(model)
     return model, None
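
For reference, below is a minimal, self-contained sketch (not part of this commit) of the same wrapping pattern: a size_based_auto_wrap_policy built with functools.partial and passed to FullyShardedDataParallel together with use_orig_params=True and a device id. The stand-in Sequential model, the torchrun launch, and the NCCL backend are assumptions for illustration only; the workload wraps its Conformer model instead.

# Minimal FSDP wrapping sketch; assumes launch via
#   torchrun --nproc_per_node=<num_gpus> fsdp_sketch.py
# and one CUDA device per process.
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def main():
  dist.init_process_group(backend="nccl")
  local_rank = int(os.environ["LOCAL_RANK"])
  torch.cuda.set_device(local_rank)

  # Stand-in for the Conformer model built by init_model_fn.
  model = torch.nn.Sequential(
      torch.nn.Linear(1024, 4096),
      torch.nn.ReLU(),
      torch.nn.Linear(4096, 1024))

  # Every submodule with at least 2 ** 10 parameters becomes its own FSDP unit;
  # device_id moves the (CPU) module to the local GPU during wrapping.
  auto_wrap_policy = functools.partial(
      size_based_auto_wrap_policy, min_num_params=2 ** 10)
  model = FSDP(
      model,
      use_orig_params=True,  # expose the original parameters, not FlatParameters
      auto_wrap_policy=auto_wrap_policy,
      device_id=local_rank)

  # The wrapped model is a drop-in replacement for forward/backward.
  out = model(torch.randn(8, 1024, device=f"cuda:{local_rank}"))
  out.sum().backward()
  dist.destroy_process_group()


if __name__ == "__main__":
  main()

Setting use_orig_params=True keeps the module's original parameter names and objects visible, which helps code that iterates over named parameters (for example optimizer setup or parameter counting) keep working after the switch from DDP to FSDP.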
@@ -122,6 +136,7 @@ def model_fn(
 
     model = params
     if mode == spec.ForwardPassMode.EVAL:
+      model.zero_grad()
       model.eval()
     if mode == spec.ForwardPassMode.TRAIN:
       model.train()
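
The one-line change above calls model.zero_grad() before switching to eval mode. A plausible reading (not stated in the commit) is that it drops gradients left over from the last training step so they are neither kept in memory nor carried across the evaluation forward passes. A small illustrative pattern follows; the helper name eval_forward is an assumption, not part of the workload API.

import torch


def eval_forward(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
  # Illustrative helper: clear leftover gradients, switch to eval mode,
  # and run the forward pass without building an autograd graph.
  model.zero_grad()
  model.eval()
  with torch.no_grad():
    return model(batch)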