diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index bc5c35b27c..629e00d0a1 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -98,8 +98,8 @@ def __init__(self, config: Config): torch_distributed = config.typed_value("torch_distributed") if torch_distributed is not None: self._use_torch_distributed = True - self._torch_distributed_class = torch_distributed.get("class", None) - self._torch_distributed_options = torch_distributed.get("options", None) + self._torch_distributed_class = torch_distributed.get("class", DistributedDataParallel) + self._torch_distributed_options = torch_distributed.get("options", {}) import returnn.torch.distributed