diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 96df6ec1a97fa..6d0dc2dd4073f 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -209,9 +209,16 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
         """
 
-    def setup(self, module: nn.Module, *optimizers: Optimizer, scheduler: Optional[_LRScheduler] = None, move_to_device: bool = True, _reapply_compile: bool = True,) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
+    def setup(
+        self,
+        module: nn.Module,
+        *optimizers: Optimizer,
+        scheduler: Optional[_LRScheduler] = None,
+        move_to_device: bool = True,
+        _reapply_compile: bool = True,
+    ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
         r"""Set up a model and its optimizers for accelerated training.
-
+
         Args:
             module: A :class:`torch.nn.Module` to set up
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
@@ -222,20 +229,20 @@ def setup(self, module: nn.Module, *optimizers: Optimizer, scheduler: Optional[_
                 corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
                 same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
                 FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
-
+
         Returns:
             The tuple containing wrapped module and the optimizers, in the same order they were passed in.
-
+
         """
         self._validate_setup(module, optimizers)
         module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
         original_module = module
-
+
         module = self._precision.convert_module(module)
-
+
         if move_to_device:
             module = self._move_model_to_device(model=module, optimizers=list(optimizers))
-
+
         # Let accelerator/plugin wrap and connect the models and optimizers
         if optimizers:
             module, optimizers, scheduler = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
@@ -243,29 +250,29 @@ def setup(self, module: nn.Module, *optimizers: Optimizer, scheduler: Optional[_
             )
         else:
             module = self._strategy.setup_module(module)
-
+
         if compile_kwargs is not None:
             module = _to_compiled(module, compile_kwargs)
         module = _FabricModule(module, self._strategy, original_module=original_module)
-
+
         # Update the _DeviceDtypeModuleMixin's device parameter
         # NOTE: for sharded strategies or manual device placement, there's no single root device
         _update_properties(
             module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
         )
-
+
         optimizers = [_FabricOptimizer(optimizer, self._strategy, self._callbacks) for optimizer in optimizers]
-
+
         self._models_setup += 1
-
+
         if hasattr(original_module, "_fabric"):  # this is probably a LightningModule
             original_module._fabric = self
             original_module._fabric_optimizers = optimizers
             if original_module not in self._callbacks:
                 self._callbacks.append(original_module)
-
+
         self.call("on_after_setup", fabric=self, module=module)
-
+
         if optimizers:
             # join both types in a tuple for API convenience
             return (module, *optimizers, scheduler)
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 1ee846e079fda..eca16e9b9a4ef 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -311,22 +311,24 @@ def model(self) -> "DeepSpeedEngine":
         return self._deepspeed_engine
 
     @override
-    def setup_module_and_optimizers(self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None) -> Tuple["DeepSpeedEngine", List[Optimizer], Optional[_LRScheduler]]:
+    def setup_module_and_optimizers(
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple["DeepSpeedEngine", List[Optimizer], Optional[_LRScheduler]]:
         """Set up a model and multiple optimizers together along with an optional learning rate scheduler.
-
+
         Currently, only a single optimizer is supported.
-
+
         Return:
             The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
             deepspeed optimizer.
-
+
         """
         if len(optimizers) != 1:
             raise ValueError(
                 f"Currently only one optimizer is supported with DeepSpeed."
                 f" Got {len(optimizers)} optimizers instead."
             )
-
+
         self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
         self._set_deepspeed_activation_checkpointing()
         return self._deepspeed_engine, [optimizer], scheduler
@@ -590,14 +592,16 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
             offload_optimizer_device="nvme",
         )
 
-    def _initialize_engine(self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None) -> Tuple["DeepSpeedEngine", Optimizer, Optional[_LRScheduler]]:
+    def _initialize_engine(
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple["DeepSpeedEngine", Optimizer, Optional[_LRScheduler]]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
-
+
         This calls :func:`deepspeed.initialize` internally.
-
+
         """
         import deepspeed
-
+
         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
         deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
             args=argparse.Namespace(device_rank=self.root_device.index),
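
For reference, a minimal usage sketch (not part of the patch) of how the scheduler-aware Fabric.setup() from the fabric.py hunk above would be called once this change is applied. The model, optimizer, and StepLR scheduler below are placeholder choices, and the unpacking follows the (module, *optimizers, scheduler) return layout added in this diff.

import torch
from torch.optim.lr_scheduler import StepLR

from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices=1)
fabric.launch()

model = torch.nn.Linear(32, 2)                           # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # placeholder optimizer
scheduler = StepLR(optimizer, step_size=10)              # placeholder LR scheduler

# With this patch, the scheduler is forwarded to the strategy (e.g. the DeepSpeed
# strategy passes it to deepspeed.initialize) and returned alongside the wrapped
# module and optimizer(s).
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)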