diff --git a/tensordict/base.py b/tensordict/base.py index 007189df7..3c3ebe320 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -8679,7 +8679,7 @@ def _clone_recurse(self) -> TensorDictBase: # noqa: D417 nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, propagate_lock=False, - filter_empty=True, + filter_empty=False, default=None, ) if items: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index edf9eb580..a815ee79c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -340,6 +340,17 @@ def test_cat_from_tensordict(self): assert (tensor[:, :4] == 0).all() assert (tensor[:, 4:] == 1).all() + @pytest.mark.parametrize("recurse", [True, False]) + def test_clone_empty(self, recurse): + td = TensorDict() + assert td.clone(recurse=recurse) is not None + td = TensorDict(device="cpu") + assert td.clone(recurse=recurse) is not None + td = TensorDict(batch_size=[2]) + assert td.clone(recurse=recurse) is not None + td = TensorDict(device="cpu", batch_size=[2]) + assert td.clone(recurse=recurse) is not None + @pytest.mark.filterwarnings("error") @pytest.mark.parametrize("device", [None, *get_available_devices()]) @pytest.mark.parametrize("num_threads", [0, 1, 2])