diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index b33a4a187..df6d41f68 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -1006,7 +1006,10 @@ def test_pickle( with open(tempdir / "test.pkl", "rb") as f: data2 = pickle.load(f) - assert_allclose_td(data.to_tensordict(), data2.to_tensordict()) + assert_allclose_td( + data.to_tensordict(retain_none=False), + data2.to_tensordict(retain_none=False), + ) assert isinstance(data2, MyData) assert data2.z == data.z @@ -1524,14 +1527,14 @@ def get_data(cls, shift): assert (data0.y.X == 1).all() assert data0.y.z == "test_tensorclass1" data0 = MyDataNested.get_data(0) - data0.update(data1.to_dict()) + data0.update(data1.to_dict(retain_none=False)) assert (data0.X == 1).all() assert data0.z == "test_tensorclass1", data0.z assert (data0.y.X == 1).all() assert data0.y.z == "test_tensorclass1" data0 = MyDataNested.get_data(0) - data0.update(data1.to_tensordict()) + data0.update(data1.to_tensordict(retain_none=False)) assert (data0.X == 1).all() assert data0.z == "test_tensorclass1" assert (data0.y.X == 1).all() @@ -1745,7 +1748,7 @@ class MyClass: batch_size=[3, 4], ) - ctd = c.to_tensordict() + ctd = c.to_tensordict(retain_none=False) assert isinstance(ctd, TensorDictBase) assert "x" in ctd.keys() assert "z" in ctd.keys()