diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 208e26707..8a22ed846 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -20,6 +20,7 @@ TensorDict, TensorDictBase, ) +from tensordict.memmap_refact import MemoryMappedTensor from tensordict.utils import is_tensorclass try: diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 552344c6c..480d19e73 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -31,9 +31,8 @@ from tensordict import ( is_tensorclass, LazyStackedTensorDict, - MemmapTensor, tensorclass, - TensorDict, + TensorDict,MemoryMappedTensor, ) from tensordict.tensordict import ( _PermutedTensorDict, @@ -379,16 +378,16 @@ def test_setitem_memmap(): @tensorclass class MyDataMemMap1: x: torch.Tensor - y: MemmapTensor + y: MemoryMappedTensor data1 = MyDataMemMap1( x=torch.zeros(3, 4, 5), - y=MemmapTensor.from_tensor(torch.zeros(3, 4, 5)), + y=MemoryMappedTensor.from_tensor(torch.zeros(3, 4, 5)), batch_size=[3, 4], ) data2 = MyDataMemMap1( - x=MemmapTensor.from_tensor(torch.ones(3, 4, 5)), + x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)), y=torch.ones(3, 4, 5), batch_size=[3, 4], ) @@ -407,22 +406,22 @@ def test_setitem_other_cls(): @tensorclass class MyData1: x: torch.Tensor - y: MemmapTensor + y: MemoryMappedTensor data1 = MyData1( x=torch.zeros(3, 4, 5), - y=MemmapTensor.from_tensor(torch.zeros(3, 4, 5)), + y=MemoryMappedTensor.from_tensor(torch.zeros(3, 4, 5)), batch_size=[3, 4], ) # Set Item should work for other tensorclass @tensorclass class MyData2: - x: MemmapTensor + x: MemoryMappedTensor y: torch.Tensor data_other_cls = MyData2( - x=MemmapTensor.from_tensor(torch.ones(3, 4, 5)), + x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)), y=torch.ones(3, 4, 5), batch_size=[3, 4], ) @@ -432,11 +431,11 @@ class MyData2: # Set Item should raise if other tensorclass with different members @tensorclass class MyData3: - x: MemmapTensor + x: MemoryMappedTensor z: torch.Tensor data_wrong_cls = MyData3( - x=MemmapTensor.from_tensor(torch.ones(3, 4, 5)), + x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)), z=torch.ones(3, 4, 5), batch_size=[3, 4], ) @@ -476,7 +475,7 @@ class MyDataNested: elif broadcast_type == "tensordict": val = TensorDict({"X": torch.zeros(2, 4, 5)}, batch_size=[2, 4]) elif broadcast_type == "maptensor": - val = MemmapTensor.from_tensor(torch.zeros(4, 5)) + val = MemoryMappedTensor.from_tensor(torch.zeros(4, 5)) data[:2] = val assert (data[:2] == 0).all() @@ -1349,7 +1348,7 @@ class MyClass: batch_size=[], ) tc.memmap_() - assert isinstance(tc.y.x, MemmapTensor) + assert isinstance(tc.y.x, MemoryMappedTensor) assert tc.z == z app_state = { @@ -1364,7 +1363,7 @@ class MyClass: batch_size=[], ) tc_dest.memmap_() - assert isinstance(tc_dest.y.x, MemmapTensor) + assert isinstance(tc_dest.y.x, MemoryMappedTensor) app_state = { "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True)) } @@ -1372,7 +1371,7 @@ class MyClass: assert (tc_dest == tc).all() assert tc_dest.y.batch_size == tc.y.batch_size - assert isinstance(tc_dest.y.x, MemmapTensor) + assert isinstance(tc_dest.y.x, MemoryMappedTensor) # torchsnapshot does not support updating strings and such assert tc_dest.z != z @@ -1386,7 +1385,7 @@ class MyClass: tc_dest.load_state_dict(tc.state_dict()) assert (tc_dest == tc).all() assert tc_dest.y.batch_size == tc.y.batch_size - assert isinstance(tc_dest.y.x, MemmapTensor) + assert isinstance(tc_dest.y.x, MemoryMappedTensor) # load_state_dict outperforms snapshot in this case assert tc_dest.z == z @@ -1610,8 +1609,8 @@ class MyClass: cmemmap = c.memmap_() assert cmemmap is c - assert isinstance(c.x, MemmapTensor) - assert isinstance(c.y.x, MemmapTensor) + assert isinstance(c.x, MemoryMappedTensor) + assert isinstance(c.y.x, MemoryMappedTensor) assert c.z == "foo" @@ -1633,8 +1632,8 @@ class MyClass: assert cmemmap is not c assert cmemmap.y is not c.y assert (cmemmap == 0).all() - assert isinstance(cmemmap.x, MemmapTensor) - assert isinstance(cmemmap.y.x, MemmapTensor) + assert isinstance(cmemmap.x, MemoryMappedTensor) + assert isinstance(cmemmap.y.x, MemoryMappedTensor) assert cmemmap.z == "foo"