diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index aa92ec53e4..2d1dd2d470 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -8,13 +8,9 @@ import os import socket -from contextlib import contextmanager import torch -from torch.distributed.algorithms.join import Join -from torch.nn.parallel import DistributedDataParallel from returnn.config import Config -import returnn.frontend as rf class DistributedContext: diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index 0889cfb36a..30fd3a5d9d 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -1057,6 +1057,11 @@ def __init__(self, *, module: torch.nn.Module, engine: Engine): self.engine = engine def forward(self, *args, **kwargs): + """ + Call run step function (train_step or forward_step). + + :return: all produced raw tensors via the run context (:func:`rf.get_run_ctx`). + """ # noinspection PyProtectedMember self.engine._run_step(*args, **kwargs, _inside_wrapped=True)