Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 19, 2023
1 parent de0c028 commit 6b40de8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
1 change: 1 addition & 0 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TensorDict,
TensorDictBase,
)
from tensordict.memmap_refact import MemoryMappedTensor
from tensordict.utils import is_tensorclass

try:
Expand Down
39 changes: 19 additions & 20 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@
from tensordict import (
is_tensorclass,
LazyStackedTensorDict,
MemmapTensor,
tensorclass,
TensorDict,
TensorDict,MemoryMappedTensor,
)
from tensordict.tensordict import (
_PermutedTensorDict,
Expand Down Expand Up @@ -379,16 +378,16 @@ def test_setitem_memmap():
@tensorclass
class MyDataMemMap1:
x: torch.Tensor
y: MemmapTensor
y: MemoryMappedTensor

data1 = MyDataMemMap1(
x=torch.zeros(3, 4, 5),
y=MemmapTensor.from_tensor(torch.zeros(3, 4, 5)),
y=MemoryMappedTensor.from_tensor(torch.zeros(3, 4, 5)),
batch_size=[3, 4],
)

data2 = MyDataMemMap1(
x=MemmapTensor.from_tensor(torch.ones(3, 4, 5)),
x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)),
y=torch.ones(3, 4, 5),
batch_size=[3, 4],
)
Expand All @@ -407,22 +406,22 @@ def test_setitem_other_cls():
@tensorclass
class MyData1:
x: torch.Tensor
y: MemmapTensor
y: MemoryMappedTensor

data1 = MyData1(
x=torch.zeros(3, 4, 5),
y=MemmapTensor.from_tensor(torch.zeros(3, 4, 5)),
y=MemoryMappedTensor.from_tensor(torch.zeros(3, 4, 5)),
batch_size=[3, 4],
)

# Set Item should work for other tensorclass
@tensorclass
class MyData2:
x: MemmapTensor
x: MemoryMappedTensor
y: torch.Tensor

data_other_cls = MyData2(
x=MemmapTensor.from_tensor(torch.ones(3, 4, 5)),
x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)),
y=torch.ones(3, 4, 5),
batch_size=[3, 4],
)
Expand All @@ -432,11 +431,11 @@ class MyData2:
# Set Item should raise if other tensorclass with different members
@tensorclass
class MyData3:
x: MemmapTensor
x: MemoryMappedTensor
z: torch.Tensor

data_wrong_cls = MyData3(
x=MemmapTensor.from_tensor(torch.ones(3, 4, 5)),
x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)),
z=torch.ones(3, 4, 5),
batch_size=[3, 4],
)
Expand Down Expand Up @@ -476,7 +475,7 @@ class MyDataNested:
elif broadcast_type == "tensordict":
val = TensorDict({"X": torch.zeros(2, 4, 5)}, batch_size=[2, 4])
elif broadcast_type == "maptensor":
val = MemmapTensor.from_tensor(torch.zeros(4, 5))
val = MemoryMappedTensor.from_tensor(torch.zeros(4, 5))

data[:2] = val
assert (data[:2] == 0).all()
Expand Down Expand Up @@ -1349,7 +1348,7 @@ class MyClass:
batch_size=[],
)
tc.memmap_()
assert isinstance(tc.y.x, MemmapTensor)
assert isinstance(tc.y.x, MemoryMappedTensor)
assert tc.z == z

app_state = {
Expand All @@ -1364,15 +1363,15 @@ class MyClass:
batch_size=[],
)
tc_dest.memmap_()
assert isinstance(tc_dest.y.x, MemmapTensor)
assert isinstance(tc_dest.y.x, MemoryMappedTensor)
app_state = {
"state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True))
}
snapshot.restore(app_state=app_state)

assert (tc_dest == tc).all()
assert tc_dest.y.batch_size == tc.y.batch_size
assert isinstance(tc_dest.y.x, MemmapTensor)
assert isinstance(tc_dest.y.x, MemoryMappedTensor)
# torchsnapshot does not support updating strings and such
assert tc_dest.z != z

Expand All @@ -1386,7 +1385,7 @@ class MyClass:
tc_dest.load_state_dict(tc.state_dict())
assert (tc_dest == tc).all()
assert tc_dest.y.batch_size == tc.y.batch_size
assert isinstance(tc_dest.y.x, MemmapTensor)
assert isinstance(tc_dest.y.x, MemoryMappedTensor)
# load_state_dict outperforms snapshot in this case
assert tc_dest.z == z

Expand Down Expand Up @@ -1610,8 +1609,8 @@ class MyClass:

cmemmap = c.memmap_()
assert cmemmap is c
assert isinstance(c.x, MemmapTensor)
assert isinstance(c.y.x, MemmapTensor)
assert isinstance(c.x, MemoryMappedTensor)
assert isinstance(c.y.x, MemoryMappedTensor)
assert c.z == "foo"


Expand All @@ -1633,8 +1632,8 @@ class MyClass:
assert cmemmap is not c
assert cmemmap.y is not c.y
assert (cmemmap == 0).all()
assert isinstance(cmemmap.x, MemmapTensor)
assert isinstance(cmemmap.y.x, MemmapTensor)
assert isinstance(cmemmap.x, MemoryMappedTensor)
assert isinstance(cmemmap.y.x, MemoryMappedTensor)
assert cmemmap.z == "foo"


Expand Down

0 comments on commit 6b40de8

Please sign in to comment.