
Commit 1025875

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 5, 2024
1 parent 67089a1 commit 1025875
Showing 2 changed files with 34 additions and 23 deletions.
35 changes: 21 additions & 14 deletions src/lightning/fabric/fabric.py
@@ -209,9 +209,16 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
"""

def setup(self, module: nn.Module, *optimizers: Optimizer, scheduler: Optional[_LRScheduler] = None, move_to_device: bool = True, _reapply_compile: bool = True,) -> Any: # no specific return because the way we want our API to look does not play well with mypy
def setup(
self,
module: nn.Module,
*optimizers: Optimizer,
scheduler: Optional[_LRScheduler] = None,
move_to_device: bool = True,
_reapply_compile: bool = True,
) -> Any: # no specific return because the way we want our API to look does not play well with mypy
r"""Set up a model and its optimizers for accelerated training.
Args:
module: A :class:`torch.nn.Module` to set up
*optimizers: The optimizer(s) to set up (no optimizers is also possible)
@@ -222,50 +229,50 @@ def setup(self, module: nn.Module, *optimizers: Optimizer, scheduler: Optional[_
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
"""
self._validate_setup(module, optimizers)
module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
original_module = module

module = self._precision.convert_module(module)

if move_to_device:
module = self._move_model_to_device(model=module, optimizers=list(optimizers))

# Let accelerator/plugin wrap and connect the models and optimizers
if optimizers:
module, optimizers, scheduler = self._strategy.setup_module_and_optimizers( # type: ignore[assignment]
module, list(optimizers), scheduler
)
else:
module = self._strategy.setup_module(module)

if compile_kwargs is not None:
module = _to_compiled(module, compile_kwargs)
module = _FabricModule(module, self._strategy, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
# NOTE: for sharded strategies or manual device placement, there's no single root device
_update_properties(
module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
)

optimizers = [_FabricOptimizer(optimizer, self._strategy, self._callbacks) for optimizer in optimizers]

self._models_setup += 1

if hasattr(original_module, "_fabric"): # this is probably a LightningModule
original_module._fabric = self
original_module._fabric_optimizers = optimizers
if original_module not in self._callbacks:
self._callbacks.append(original_module)

self.call("on_after_setup", fabric=self, module=module)

if optimizers:
# join both types in a tuple for API convenience
return (module, *optimizers, scheduler)
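
For orientation, the following is a minimal usage sketch of the setup() signature reformatted above, assuming the scheduler keyword shown in this diff; the model, optimizer, and scheduler objects are illustrative placeholders, not part of the change.

    import torch
    from torch import nn
    from lightning.fabric import Fabric

    # Sketch only: exercises the setup() signature shown above. When a scheduler is
    # passed, it is returned last, after the wrapped module and optimizer(s).
    fabric = Fabric(accelerator="cpu", devices=1)

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)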
22 changes: 13 additions & 9 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -311,22 +311,24 @@ def model(self) -> "DeepSpeedEngine":
return self._deepspeed_engine

@override
def setup_module_and_optimizers(self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None) -> Tuple["DeepSpeedEngine", List[Optimizer], Optional[_LRScheduler]]:
def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
) -> Tuple["DeepSpeedEngine", List[Optimizer], Optional[_LRScheduler]]:
"""Set up a model and multiple optimizers together along with an optional learning rate scheduler.
Currently, only a single optimizer is supported.
Return:
The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
deepspeed optimizer.
"""
if len(optimizers) != 1:
raise ValueError(
f"Currently only one optimizer is supported with DeepSpeed."
f" Got {len(optimizers)} optimizers instead."
)

self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
self._set_deepspeed_activation_checkpointing()
return self._deepspeed_engine, [optimizer], scheduler
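
As a hedged illustration of the single-optimizer constraint documented above, the sketch below assumes DeepSpeed and a CUDA device are available; the model and optimizer names are placeholders, and the expected failure mode is the ValueError raised in this method.

    import torch
    from torch import nn
    from lightning.fabric import Fabric

    # Sketch only: with the DeepSpeed strategy, passing two optimizers to setup()
    # should reach setup_module_and_optimizers() and trigger the ValueError above.
    fabric = Fabric(strategy="deepspeed", accelerator="cuda", devices=1)
    fabric.launch()

    model = nn.Linear(4, 2)
    opt_a = torch.optim.AdamW(model.parameters(), lr=1e-3)
    opt_b = torch.optim.SGD(model.parameters(), lr=1e-2)

    try:
        fabric.setup(model, opt_a, opt_b)
    except ValueError as err:
        print(err)  # Currently only one optimizer is supported with DeepSpeed. ...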
@@ -590,14 +592,16 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
offload_optimizer_device="nvme",
)

def _initialize_engine(self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None) -> Tuple["DeepSpeedEngine", Optimizer, Optional[_LRScheduler]]:
def _initialize_engine(
self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None
) -> Tuple["DeepSpeedEngine", Optimizer, Optional[_LRScheduler]]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.
This calls :func:`deepspeed.initialize` internally.
"""
import deepspeed

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
args=argparse.Namespace(device_rank=self.root_device.index),
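
For reference, a rough sketch of the deepspeed.initialize() call shape that _initialize_engine wraps; the config dict and objects below are minimal stand-ins, not the configuration the strategy actually generates, and DeepSpeed plus a distributed environment are assumed to be available.

    import deepspeed
    import torch
    from torch import nn

    # Sketch only: mirrors the (engine, optimizer, _, scheduler) unpacking used in
    # _initialize_engine(). The config is an illustrative minimal stand-in.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    engine, ds_optimizer, _, ds_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        model_parameters=(p for p in model.parameters() if p.requires_grad),
        config={"train_micro_batch_size_per_gpu": 1},
    )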
