Skip to content

Commit

Permalink
doc, extra check
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 5, 2023
1 parent 82924a1 commit ac75d8c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion returnn/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,14 @@ def __init__(self, config: Config):
print("Using device:", self._device, f"({dev_.reason or '?'})", file=log.v2)

self._use_torch_distributed = False
self._torch_distributed_class = None # type: Optional[Callable]
self._torch_distributed_class = None # type: Optional[Union[DistributedDataParallel,Callable]]
self._torch_distributed_options = None # type: Optional[dict]
self._ddp_pt_model = None # type: Optional[torch.nn.Module]
self._accum_grad_multiple_step = config.int("accum_grad_multiple_step", 1)

torch_distributed = config.typed_value("torch_distributed")
if torch_distributed is not None:
assert isinstance(torch_distributed, dict)
self._use_torch_distributed = True
self._torch_distributed_class = torch_distributed.get("class", DistributedDataParallel)
self._torch_distributed_options = torch_distributed.get("options", {})
Expand Down

0 comments on commit ac75d8c

Please sign in to comment.