From 46127f4f56c9d48ddc316d15debc962fa7bbfa6f Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 9 Nov 2023 15:41:44 +0000 Subject: [PATCH] Revert RF ModuleList cleanup Actually, our ModuleList is the base of Sequential, and our ModuleList is already more like the PyTorch Sequential. This reverts commit 389c844efd5c77cf8ed37e7f40a8b0e8927d702a and commit 40c7aa4468d182f38fa3127c8fd2136a3da78bdc. --- returnn/frontend/container.py | 57 ++++++----------------------------- tests/test_rf_container.py | 35 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 48 deletions(-) diff --git a/returnn/frontend/container.py b/returnn/frontend/container.py index d4470efd6f..fc7fefdba1 100644 --- a/returnn/frontend/container.py +++ b/returnn/frontend/container.py @@ -3,7 +3,6 @@ """ from __future__ import annotations -import operator import returnn.frontend as rf from returnn.tensor import Tensor from typing import Optional, TypeVar, Generic, Iterable, Iterator, Union, Tuple, Dict, Callable @@ -22,9 +21,12 @@ class ModuleList(rf.Module, Generic[__ModT]): Module list, getting passed an Iterable of Modules and creates a list of Modules in that order """ - def __init__(self, *modules: Union[__ModT, Iterable[__ModT], ModuleList]): + def __init__(self, *modules: Union[__ModT, Iterable[__ModT], Dict[str, __ModT], ModuleList]): super().__init__() - if len(modules) == 1 and isinstance(modules[0], ModuleList): + if len(modules) == 1 and isinstance(modules[0], dict): + for key, module in modules[0].items(): + setattr(self, key, _convert_to_module(module)) + elif len(modules) == 1 and isinstance(modules[0], ModuleList): for key, module in modules[0]._get_modules().items(): setattr(self, key, _convert_to_module(module)) elif len(modules) == 1 and _is_iterable(modules[0]): @@ -35,29 +37,7 @@ def __init__(self, *modules: Union[__ModT, Iterable[__ModT], ModuleList]): setattr(self, str(idx), _convert_to_module(module)) def _get_modules(self) -> Dict[str, __ModT]: - # Note: Insertion order is relevant here. We use it in __getitem__ for slicing, etc. - res = {} - i = 0 - while True: - try: - res[str(i)] = getattr(self, str(i)) - i += 1 - except AttributeError: - break - return res - - def _get_abs_index(self, idx: int) -> int: - """Get the absolute index for the list of modules""" - idx = operator.index(idx) - if not (-len(self) <= idx < len(self)): - raise IndexError("index {} is out of range".format(idx)) - if idx < 0: - idx += len(self) - return idx - - def _get_abs_string_index(self, idx: int) -> str: - """Get the absolute index for the list of modules""" - return str(self._get_abs_index(idx)) + return {key: value for (key, value) in vars(self).items() if isinstance(value, rf.Module)} def append(self, module: __ModT) -> ModuleList[__ModT]: """ @@ -88,33 +68,14 @@ def __getitem__(self, idx) -> Union[ModuleList[__ModT], __ModT]: from builtins import slice if isinstance(idx, slice): - return self.__class__(list(self._get_modules().values())[idx]) + return self.__class__(dict(list(self._get_modules().items())[idx])) else: - key = self._get_abs_string_index(idx) - if not hasattr(self, key): - raise IndexError("index {} is out of range".format(idx)) - return getattr(self, key) + return list(self._get_modules().values())[idx] def __setitem__(self, idx: int, module: __ModT) -> None: - key = self._get_abs_string_index(idx) - if not hasattr(self, key): - raise IndexError("index {} is out of range".format(idx)) + key = list(self._get_modules().keys())[idx] return setattr(self, key, _convert_to_module(module)) - def __delitem__(self, idx: Union[int, slice]) -> None: - # To preserve numbering, we reconstruct the list of modules after deletion. - modules = list(self._get_modules().values()) - old_len = len(modules) - del modules[idx] - if isinstance(idx, slice): - min_idx = self._get_abs_index(idx.start) - else: - min_idx = self._get_abs_index(idx) - for i in range(min_idx, len(modules)): - setattr(self, str(i), modules[i]) - for i in range(len(modules), old_len): - delattr(self, str(i)) - __call__ = rf.Module.__call__ # stays abstract diff --git a/tests/test_rf_container.py b/tests/test_rf_container.py index f5604b21bb..52317f6ee9 100644 --- a/tests/test_rf_container.py +++ b/tests/test_rf_container.py @@ -76,6 +76,41 @@ def _forward_step(*, model: _Net, extern_data: TensorDict): run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step) +def test_sequential_named_case(): + time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) + in_dim = Dim(7, name="in") + extern_data = TensorDict( + { + "data": Tensor("data", [batch_dim, time_dim, in_dim], dtype="float32"), + } + ) + + class _Net(rf.Module): + def __init__(self): + super().__init__() + dims = [Dim(1, name="feat1"), Dim(2, name="feat2"), Dim(3, name="feat3")] + self.out_dim = dims[-1] + x = OrderedDict() + x["one"] = rf.Linear(in_dim, dims[0]) + x["two"] = rf.Linear(dims[0], dims[1]) + x["three"] = rf.Linear(dims[1], dims[2]) + self.seq = rf.Sequential(x) + + def __call__(self, data: Tensor) -> Tensor: + """ + Forward + """ + seq = self.seq(data) + return seq + + # noinspection PyShadowingNames + def _forward_step(*, model: _Net, extern_data: TensorDict): + out = model(extern_data["data"]) + out.mark_as_default_output(shape=(batch_dim, time_dim, model.out_dim)) + + run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step) + + def test_parameter_list(): time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) in_dim = Dim(7, name="in")