From 7d174a444f155eb4a62b3d622131639b72e55c0d Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 18 Oct 2023 16:56:33 -0700 Subject: [PATCH] amend --- tensordict/tensordict.py | 2 +- test/test_tensordict.py | 20 +++----------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index a4f5d8062..45098cdcf 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -4742,7 +4742,7 @@ def load_memmap(cls, prefix: str) -> T: prefix = Path(prefix) metadata = torch.load(prefix / "meta.pt") # TODO: remove this - assert metadata["device"] == torch.device("cpu") + assert metadata["device"] == torch.device("cpu"), metadata out = cls({}, batch_size=metadata["batch_size"], device=metadata["device"]) for path in prefix.glob("**/*meta.pt"): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 5394adce2..8e9d739b0 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2272,20 +2272,6 @@ def test_chunk(self, td_name, device, dim, chunks): assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim] assert (torch.cat(td_chunks, dim) == td).all() - # def test_as_tensor(self, td_name, device): - # td = getattr(self, td_name)(device) - # if "memmap" in td_name and device == torch.device("cpu"): - # tdt = td.as_tensor() - # assert (tdt == td).all() - # elif "memmap" in td_name: - # with pytest.raises( - # RuntimeError, match="can only be called with MemmapTensors stored" - # ): - # td.as_tensor() - # else: - # with pytest.raises(AttributeError): - # td.as_tensor() - def test_items_values_keys(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) @@ -5240,11 +5226,11 @@ def test_memmap_as_tensor(device): td_memmap = td.clone().memmap_() assert (td == td_memmap).all() - assert (td == td_memmap.apply(lambda x: x.as_tensor())).all() + assert (td == td_memmap.apply(lambda x: x.clone())).all() if device.type == "cuda": td = td.pin_memory() td_memmap = td.clone().memmap_() - td_memmap_pm = td_memmap.apply(lambda x: x.as_tensor()).pin_memory() + td_memmap_pm = td_memmap.apply(lambda x: x.clone()).pin_memory() assert (td.pin_memory().to(device) == td_memmap_pm.to(device)).all() @@ -5764,7 +5750,7 @@ def test_memmap_td(self): assert td.names == list("abcd") td.rename_(c="g") assert td.names == list("abgd") - assert td.as_tensor().names == list("abgd") + assert td.clone().names == list("abgd") def test_h5_td(self): td = self.td_h5("cpu")