From 9b45b9920a9dd8bdd056e9a153342743980df044 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Fri, 4 Oct 2024 22:21:00 -0500
Subject: [PATCH] Update fsdp.py

---
 src/lightning/fabric/strategies/fsdp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 74bfe56395020..6efc372db627b 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -269,7 +269,7 @@ def setup_environment(self) -> None:
     @override
     def setup_module_and_optimizers(
         self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
-    ) -> Tuple[Module, List[Optimizer]]:
+    ) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
         """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
         module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
         use_orig_params = self._fsdp_kwargs.get("use_orig_params")
@@ -281,7 +281,7 @@ def setup_module_and_optimizers(
             " call `setup_optimizer`."
         )
         module = self.setup_module(module)
-        return module, optimizers
+        return module, optimizers, scheduler
 
     @override
     def setup_module(self, module: Module) -> Module:
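
For context, a minimal sketch of how a caller might consume the updated return signature. The strategy instantiation and the model, optimizer, and scheduler names below are illustrative assumptions and are not part of this patch; only the three-element return of `setup_module_and_optimizers` comes from the change above.

```python
# Sketch of assumed usage, not part of this patch: after this change,
# setup_module_and_optimizers returns (module, optimizers, scheduler),
# where the scheduler may be None if none was passed in.
import torch
from torch.optim.lr_scheduler import StepLR
from lightning.fabric.strategies import FSDPStrategy

# Hypothetical instantiation; a real run requires a launched distributed environment.
strategy = FSDPStrategy()

model = torch.nn.Linear(32, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=1)

# Previously the call returned (module, optimizers); it now also hands back the
# scheduler so the caller keeps a reference to it after FSDP wrapping.
module, optimizers, scheduler = strategy.setup_module_and_optimizers(
    model, [optimizer], scheduler
)
```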