Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 2, 2024
1 parent e871b7d commit 02b3c3a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8679,7 +8679,7 @@ def _clone_recurse(self) -> TensorDictBase: # noqa: D417
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=False,
filter_empty=True,
filter_empty=False,
default=None,
)
if items:
Expand Down
11 changes: 11 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,17 @@ def test_cat_from_tensordict(self):
assert (tensor[:, :4] == 0).all()
assert (tensor[:, 4:] == 1).all()

@pytest.mark.parametrize("recurse", [True, False])
def test_clone_empty(self, recurse):
td = TensorDict()
assert td.clone(recurse=recurse) is not None
td = TensorDict(device="cpu")
assert td.clone(recurse=recurse) is not None
td = TensorDict(batch_size=[2])
assert td.clone(recurse=recurse) is not None
td = TensorDict(device="cpu", batch_size=[2])
assert td.clone(recurse=recurse) is not None

@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("device", [None, *get_available_devices()])
@pytest.mark.parametrize("num_threads", [0, 1, 2])
Expand Down

0 comments on commit 02b3c3a

Please sign in to comment.