Skip to content

Commit

Permalink
RF test_module_slice_set_del
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 9, 2023
1 parent 17915f8 commit 5db7847
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/test_rf_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)


def test_module_slice_set_del():
rf.select_backend_torch()
base_dim = Dim(3, name="linear-out")
dims = [base_dim + i for i in range(4)]
in_dim = Dim(7, name="in")
in_dims = [in_dim] + dims[:-1]
layers = rf.ModuleList([rf.Linear(in_dim_, out_dim_) for in_dim_, out_dim_ in zip(in_dims, dims)])
assert len(layers) == 4 and [k for k, v in layers.items()] == ["0", "1", "2", "3"]
orig_layers = layers[:]
assert isinstance(orig_layers, rf.ModuleList)
assert len(orig_layers) == 4 and [k for k, v in orig_layers.items()] == ["0", "1", "2", "3"]
del layers[2:]
assert len(layers) == 2 and [k for k, v in layers.items()] == ["0", "1"]
layers[:] = orig_layers
assert len(layers) == 4 and [k for k, v in layers.items()] == ["0", "1", "2", "3"]


def test_sequential_base_case():
time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
in_dim = Dim(7, name="in")
Expand Down

0 comments on commit 5db7847

Please sign in to comment.