Skip to content

Commit

Permalink
Update fsdp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead authored Oct 5, 2024
1 parent 1025875 commit 9b45b99
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def setup_environment(self) -> None:
@override
def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
) -> Tuple[Module, List[Optimizer]]:
) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
"""Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
use_orig_params = self._fsdp_kwargs.get("use_orig_params")
Expand All @@ -281,7 +281,7 @@ def setup_module_and_optimizers(
" call `setup_optimizer`."
)
module = self.setup_module(module)
return module, optimizers
return module, optimizers, scheduler

@override
def setup_module(self, module: Module) -> Module:
Expand Down

0 comments on commit 9b45b99

Please sign in to comment.