Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 14, 2023
1 parent f2e624d commit d7668fa
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
18 changes: 17 additions & 1 deletion tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def from_tensor(cls, input, *, filename=None, existsok=False):
"""
if isinstance(input, MemoryMappedTensor):
if dir is None and (
if (
filename is None
or Path(filename).absolute() == Path(input._filename).absolute()
):
Expand Down Expand Up @@ -256,6 +256,10 @@ def ones(cls, *shape, dtype=None, device=None, filename=None):
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
if device is not None:
device = torch.device(device)
if device.type != "cpu":
raise RuntimeError("Only CPU tensors are supported.")
result = torch.ones((), dtype=dtype, device=device)
if shape:
if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
Expand All @@ -282,6 +286,10 @@ def zeros(cls, *shape, dtype=None, device=None, filename=None):
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
if device is not None:
device = torch.device(device)
if device.type != "cpu":
raise RuntimeError("Only CPU tensors are supported.")
result = torch.zeros((), dtype=dtype, device=device)
if shape:
if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
Expand Down Expand Up @@ -309,6 +317,10 @@ def empty(cls, *shape, dtype=None, device=None, filename=None):
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
if device is not None:
device = torch.device(device)
if device.type != "cpu":
raise RuntimeError("Only CPU tensors are supported.")
result = torch.zeros((), dtype=dtype, device=device)
if shape:
if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
Expand All @@ -334,6 +346,10 @@ def full(cls, *shape, fill_value, dtype=None, device=None, filename=None):
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
if device is not None:
device = torch.device(device)
if device.type != "cpu":
raise RuntimeError("Only CPU tensors are supported.")
result = torch.zeros((), dtype=dtype, device=device).fill_(fill_value)
if shape:
if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
Expand Down
2 changes: 1 addition & 1 deletion test/test_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TestH5Serialization:
@classmethod
def worker(cls, cyberbliptronics, q1, q2):
assert isinstance(cyberbliptronics, PersistentTensorDict)
assert cyberbliptronics.file._filename.endswith("groups.hdf5")
assert cyberbliptronics.file.filename.endswith("groups.hdf5")
q1.put(
cyberbliptronics["Base_Group"][
"Sub_Group",
Expand Down
4 changes: 2 additions & 2 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2554,10 +2554,10 @@ def test_memmap_prefix(self, td_name, device, tmp_path):
pass
elif td_name in ("unsqueezed_td", "squeezed_td", "permute_td"):
assert metadata["batch_size"] == td._source.batch_size
assert metadata["device"] == td._source.device
# assert metadata["device"] == td._source.device
else:
assert metadata["batch_size"] == td.batch_size
assert metadata["device"] == td.device
# assert metadata["device"] == td.device

td2 = td.__class__.load_memmap(tmp_path / "tensordict")
assert (td == td2).all()
Expand Down

0 comments on commit d7668fa

Please sign in to comment.