Commit 7d174a4: amend

vmoens committed Oct 18, 2023
1 parent 0403add
Showing 2 changed files with 4 additions and 18 deletions.
2 changes: 1 addition & 1 deletion tensordict/tensordict.py

@@ -4742,7 +4742,7 @@ def load_memmap(cls, prefix: str) -> T:
         prefix = Path(prefix)
         metadata = torch.load(prefix / "meta.pt")
         # TODO: remove this
-        assert metadata["device"] == torch.device("cpu")
+        assert metadata["device"] == torch.device("cpu"), metadata
         out = cls({}, batch_size=metadata["batch_size"], device=metadata["device"])
 
         for path in prefix.glob("**/*meta.pt"):
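The only behavioral change in this hunk is the assertion message: giving assert a second argument makes Python attach it to the AssertionError, so a failing load_memmap now reports the metadata it actually loaded instead of a bare assertion failure. A minimal sketch of the pattern; the meta dict below is invented for illustration and is not taken from the repository:

import torch

# Stand-in for torch.load(prefix / "meta.pt"); the values here are hypothetical.
meta = {"device": torch.device("cuda:0"), "batch_size": torch.Size([4])}

try:
    # With the dict as the assert message, a failure reports what was loaded.
    assert meta["device"] == torch.device("cpu"), meta
except AssertionError as err:
    print(err)  # {'device': device(type='cuda', index=0), 'batch_size': torch.Size([4])}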
20 changes: 3 additions & 17 deletions test/test_tensordict.py

@@ -2272,20 +2272,6 @@ def test_chunk(self, td_name, device, dim, chunks):
         assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim]
         assert (torch.cat(td_chunks, dim) == td).all()
 
-    # def test_as_tensor(self, td_name, device):
-    #     td = getattr(self, td_name)(device)
-    #     if "memmap" in td_name and device == torch.device("cpu"):
-    #         tdt = td.as_tensor()
-    #         assert (tdt == td).all()
-    #     elif "memmap" in td_name:
-    #         with pytest.raises(
-    #             RuntimeError, match="can only be called with MemmapTensors stored"
-    #         ):
-    #             td.as_tensor()
-    #     else:
-    #         with pytest.raises(AttributeError):
-    #             td.as_tensor()
-
     def test_items_values_keys(self, td_name, device):
         torch.manual_seed(1)
         td = getattr(self, td_name)(device)
@@ -5240,11 +5226,11 @@ def test_memmap_as_tensor(device):
     td_memmap = td.clone().memmap_()
     assert (td == td_memmap).all()
 
-    assert (td == td_memmap.apply(lambda x: x.as_tensor())).all()
+    assert (td == td_memmap.apply(lambda x: x.clone())).all()
     if device.type == "cuda":
         td = td.pin_memory()
         td_memmap = td.clone().memmap_()
-        td_memmap_pm = td_memmap.apply(lambda x: x.as_tensor()).pin_memory()
+        td_memmap_pm = td_memmap.apply(lambda x: x.clone()).pin_memory()
         assert (td.pin_memory().to(device) == td_memmap_pm.to(device)).all()
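This hunk swaps the earlier as_tensor() calls for clone() when converting memory-mapped entries back into ordinary tensors. A rough sketch of that pattern, assuming the tensordict package is installed; the keys and shapes below are made up for illustration:

import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.randn(3, 4), "reward": torch.zeros(3, 1)}, batch_size=[3])

# memmap_() moves every entry to memory-mapped storage in place.
td_memmap = td.clone().memmap_()

# clone() each leaf to recover ordinary in-memory tensors; the values are
# unchanged, only the storage differs.
td_plain = td_memmap.apply(lambda x: x.clone())
assert (td_plain == td).all()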


@@ -5764,7 +5750,7 @@ def test_memmap_td(self):
         assert td.names == list("abcd")
         td.rename_(c="g")
         assert td.names == list("abgd")
-        assert td.as_tensor().names == list("abgd")
+        assert td.clone().names == list("abgd")
 
     def test_h5_td(self):
         td = self.td_h5("cpu")
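The same as_tensor() to clone() substitution covers the named-dimension check: cloning the memory-mapped TensorDict is expected to keep its dimension names. A hedged sketch, assuming the tensordict version used here accepts names in the constructor; the shape and key are invented:

import torch
from tensordict import TensorDict

td = TensorDict(
    {"x": torch.zeros(1, 2, 3, 4)},
    batch_size=[1, 2, 3, 4],
    names=["a", "b", "c", "d"],  # assumed constructor support for dimension names
)
td.memmap_()       # switch to memory-mapped storage in place
td.rename_(c="g")  # rename one batch dimension in place
assert td.names == list("abgd")
assert td.clone().names == list("abgd")  # names survive the clone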

0 comments on commit 7d174a4
