Skip to content

Commit

Permalink
cleanup, fix warnings, doc
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 28, 2023
1 parent 90c7548 commit 88d53f3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 0 additions & 4 deletions returnn/torch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@
import os
import socket

from contextlib import contextmanager
import torch
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel

from returnn.config import Config
import returnn.frontend as rf


class DistributedContext:
Expand Down
5 changes: 5 additions & 0 deletions returnn/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,11 @@ def __init__(self, *, module: torch.nn.Module, engine: Engine):
self.engine = engine

def forward(self, *args, **kwargs):
"""
Call run step function (train_step or forward_step).
:return: all produced raw tensors via the run context (:func:`rf.get_run_ctx`).
"""
# noinspection PyProtectedMember
self.engine._run_step(*args, **kwargs, _inside_wrapped=True)

Expand Down

0 comments on commit 88d53f3

Please sign in to comment.