Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 25, 2024
1 parent eb4d92a commit 54af82d
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,10 @@ def test_pickle(
with open(tempdir / "test.pkl", "rb") as f:
data2 = pickle.load(f)

assert_allclose_td(data.to_tensordict(), data2.to_tensordict())
assert_allclose_td(
data.to_tensordict(retain_none=False),
data2.to_tensordict(retain_none=False),
)
assert isinstance(data2, MyData)
assert data2.z == data.z

Expand Down Expand Up @@ -1524,14 +1527,14 @@ def get_data(cls, shift):
assert (data0.y.X == 1).all()
assert data0.y.z == "test_tensorclass1"
data0 = MyDataNested.get_data(0)
data0.update(data1.to_dict())
data0.update(data1.to_dict(retain_none=False))
assert (data0.X == 1).all()
assert data0.z == "test_tensorclass1", data0.z
assert (data0.y.X == 1).all()
assert data0.y.z == "test_tensorclass1"

data0 = MyDataNested.get_data(0)
data0.update(data1.to_tensordict())
data0.update(data1.to_tensordict(retain_none=False))
assert (data0.X == 1).all()
assert data0.z == "test_tensorclass1"
assert (data0.y.X == 1).all()
Expand Down Expand Up @@ -1745,7 +1748,7 @@ class MyClass:
batch_size=[3, 4],
)

ctd = c.to_tensordict()
ctd = c.to_tensordict(retain_none=False)
assert isinstance(ctd, TensorDictBase)
assert "x" in ctd.keys()
assert "z" in ctd.keys()
Expand Down

0 comments on commit 54af82d

Please sign in to comment.