diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index 341d1503d0..9d74615a39 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -90,13 +90,14 @@ def __init__(self, config: Config): print("Using device:", self._device, f"({dev_.reason or '?'})", file=log.v2) self._use_torch_distributed = False - self._torch_distributed_class = None # type: Optional[Callable] + self._torch_distributed_class = None # type: Optional[Union[DistributedDataParallel,Callable]] self._torch_distributed_options = None # type: Optional[dict] self._ddp_pt_model = None # type: Optional[torch.nn.Module] self._accum_grad_multiple_step = config.int("accum_grad_multiple_step", 1) torch_distributed = config.typed_value("torch_distributed") if torch_distributed is not None: + assert isinstance(torch_distributed, dict) self._use_torch_distributed = True self._torch_distributed_class = torch_distributed.get("class", DistributedDataParallel) self._torch_distributed_options = torch_distributed.get("options", {})