diff --git a/tests/test_rf_container.py b/tests/test_rf_container.py index 52317f6ee9..716fae218b 100644 --- a/tests/test_rf_container.py +++ b/tests/test_rf_container.py @@ -44,6 +44,23 @@ def _forward_step(*, model: _Net, extern_data: TensorDict): run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step) +def test_module_slice_set_del(): + rf.select_backend_torch() + base_dim = Dim(3, name="linear-out") + dims = [base_dim + i for i in range(4)] + in_dim = Dim(7, name="in") + in_dims = [in_dim] + dims[:-1] + layers = rf.ModuleList([rf.Linear(in_dim_, out_dim_) for in_dim_, out_dim_ in zip(in_dims, dims)]) + assert len(layers) == 4 and [k for k, v in layers.items()] == ["0", "1", "2", "3"] + orig_layers = layers[:] + assert isinstance(orig_layers, rf.ModuleList) + assert len(orig_layers) == 4 and [k for k, v in orig_layers.items()] == ["0", "1", "2", "3"] + del layers[2:] + assert len(layers) == 2 and [k for k, v in layers.items()] == ["0", "1"] + layers[:] = orig_layers + assert len(layers) == 4 and [k for k, v in layers.items()] == ["0", "1", "2", "3"] + + def test_sequential_base_case(): time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) in_dim = Dim(7, name="in")