Skip to content

Commit

Permalink
Revert RF ModuleList cleanup
Browse files Browse the repository at this point in the history
Actually, our ModuleList is the base class of Sequential,
and our ModuleList already behaves more like PyTorch's Sequential.

This reverts commit 389c844
and commit 40c7aa4.
  • Loading branch information
albertz committed Nov 9, 2023
1 parent 76ab73c commit 46127f4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 48 deletions.
57 changes: 9 additions & 48 deletions returnn/frontend/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

from __future__ import annotations
import operator
import returnn.frontend as rf
from returnn.tensor import Tensor
from typing import Optional, TypeVar, Generic, Iterable, Iterator, Union, Tuple, Dict, Callable
Expand All @@ -22,9 +21,12 @@ class ModuleList(rf.Module, Generic[__ModT]):
Module list, getting passed an Iterable of Modules and creates a list of Modules in that order
"""

def __init__(self, *modules: Union[__ModT, Iterable[__ModT], ModuleList]):
def __init__(self, *modules: Union[__ModT, Iterable[__ModT], Dict[str, __ModT], ModuleList]):
super().__init__()
if len(modules) == 1 and isinstance(modules[0], ModuleList):
if len(modules) == 1 and isinstance(modules[0], dict):
for key, module in modules[0].items():
setattr(self, key, _convert_to_module(module))
elif len(modules) == 1 and isinstance(modules[0], ModuleList):
for key, module in modules[0]._get_modules().items():
setattr(self, key, _convert_to_module(module))
elif len(modules) == 1 and _is_iterable(modules[0]):
Expand All @@ -35,29 +37,7 @@ def __init__(self, *modules: Union[__ModT, Iterable[__ModT], ModuleList]):
setattr(self, str(idx), _convert_to_module(module))

def _get_modules(self) -> Dict[str, __ModT]:
# Note: Insertion order is relevant here. We use it in __getitem__ for slicing, etc.
res = {}
i = 0
while True:
try:
res[str(i)] = getattr(self, str(i))
i += 1
except AttributeError:
break
return res

def _get_abs_index(self, idx: int) -> int:
"""Get the absolute index for the list of modules"""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError("index {} is out of range".format(idx))
if idx < 0:
idx += len(self)
return idx

def _get_abs_string_index(self, idx: int) -> str:
"""Get the absolute index for the list of modules"""
return str(self._get_abs_index(idx))
return {key: value for (key, value) in vars(self).items() if isinstance(value, rf.Module)}

def append(self, module: __ModT) -> ModuleList[__ModT]:
"""
Expand Down Expand Up @@ -88,33 +68,14 @@ def __getitem__(self, idx) -> Union[ModuleList[__ModT], __ModT]:
from builtins import slice

if isinstance(idx, slice):
return self.__class__(list(self._get_modules().values())[idx])
return self.__class__(dict(list(self._get_modules().items())[idx]))
else:
key = self._get_abs_string_index(idx)
if not hasattr(self, key):
raise IndexError("index {} is out of range".format(idx))
return getattr(self, key)
return list(self._get_modules().values())[idx]

def __setitem__(self, idx: int, module: __ModT) -> None:
key = self._get_abs_string_index(idx)
if not hasattr(self, key):
raise IndexError("index {} is out of range".format(idx))
key = list(self._get_modules().keys())[idx]
return setattr(self, key, _convert_to_module(module))

def __delitem__(self, idx: Union[int, slice]) -> None:
# To preserve numbering, we reconstruct the list of modules after deletion.
modules = list(self._get_modules().values())
old_len = len(modules)
del modules[idx]
if isinstance(idx, slice):
min_idx = self._get_abs_index(idx.start)
else:
min_idx = self._get_abs_index(idx)
for i in range(min_idx, len(modules)):
setattr(self, str(i), modules[i])
for i in range(len(modules), old_len):
delattr(self, str(i))

__call__ = rf.Module.__call__ # stays abstract


Expand Down
35 changes: 35 additions & 0 deletions tests/test_rf_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,41 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)


def test_sequential_named_case():
    """Check that rf.Sequential accepts a dict of named modules and chains them in order."""
    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
    in_dim = Dim(7, name="in")
    extern_data = TensorDict(
        {"data": Tensor("data", [batch_dim, time_dim, in_dim], dtype="float32")}
    )

    class _Net(rf.Module):
        def __init__(self):
            super().__init__()
            # Three stacked linear layers with distinct feature dims, registered under explicit names.
            feat_dims = [Dim(1, name="feat1"), Dim(2, name="feat2"), Dim(3, name="feat3")]
            self.out_dim = feat_dims[-1]
            layers = OrderedDict(
                [
                    ("one", rf.Linear(in_dim, feat_dims[0])),
                    ("two", rf.Linear(feat_dims[0], feat_dims[1])),
                    ("three", rf.Linear(feat_dims[1], feat_dims[2])),
                ]
            )
            self.seq = rf.Sequential(layers)

        def __call__(self, data: Tensor) -> Tensor:
            """
            Forward
            """
            return self.seq(data)

    # noinspection PyShadowingNames
    def _forward_step(*, model: _Net, extern_data: TensorDict):
        result = model(extern_data["data"])
        result.mark_as_default_output(shape=(batch_dim, time_dim, model.out_dim))

    run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)


def test_parameter_list():
time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
in_dim = Dim(7, name="in")
Expand Down

0 comments on commit 46127f4

Please sign in to comment.