[Feature] NonTensorStack.from_list
ghstack-source-id: e8f349cb06a72dcb69a639420b14406c9c08aa99
Pull Request resolved: #1107
vmoens committed Nov 25, 2024
1 parent 3485c2c commit f924afc
Showing 2 changed files with 37 additions and 10 deletions.
13 changes: 13 additions & 0 deletions tensordict/tensorclass.py
@@ -194,6 +194,7 @@ def __subclasscheck__(self, subclass):
     "any",
     "apply",
     "apply_",
+    "as_tensor",
     "asin",
     "asin_",
     "atan",
@@ -3114,6 +3115,18 @@ def maybe_to_stack(self):
             stack_dim=self.stack_dim,
         )
 
+    @classmethod
+    def from_list(cls, non_tensors: List[Any]):
+        # Use a local function because it refers to cls
+        def _maybe_from_list(nontensor):
+            if isinstance(nontensor, list):
+                return cls.from_list(nontensor)
+            if is_non_tensor(nontensor):
+                return nontensor
+            return NonTensorData(nontensor)
+
+        return cls(*[_maybe_from_list(nontensor) for nontensor in non_tensors])
+
     @classmethod
     def from_nontensordata(cls, non_tensor: NonTensorData):
         data = non_tensor.data
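Not part of the diff: a minimal usage sketch of the new classmethod. It assumes NonTensorData and NonTensorStack are importable from the top-level tensordict package, as the test file below suggests.

import torch
from tensordict import NonTensorData, NonTensorStack

# Nested Python lists map to nested stacks: each list level adds a
# leading dimension, and each leaf is wrapped in NonTensorData unless
# it is already a non-tensor instance.
stack = NonTensorStack.from_list([["hello", 1], [None, 2.5]])
print(stack.shape)       # torch.Size([2, 2])
print(stack[0, 0].data)  # "hello"
print(stack[1, 1].data)  # 2.5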
34 changes: 24 additions & 10 deletions test/test_tensordict.py
@@ -142,8 +142,8 @@ def device_fixture():
     device = torch.get_default_device()
     if torch.cuda.is_available():
         torch.set_default_device(torch.device("cuda:0"))
-    elif torch.backends.mps.is_available():
-        torch.set_default_device(torch.device("mps:0"))
+    # elif torch.backends.mps.is_available():
+    #     torch.set_default_device(torch.device("mps:0"))
     yield
     torch.set_default_device(device)

@@ -1468,8 +1468,8 @@ def check_meta(tensor):
 
         if torch.cuda.is_available():
             device = "cuda:0"
-        elif torch.backends.mps.is_available():
-            device = "mps:0"
+        # elif torch.backends.mps.is_available():
+        #     device = "mps:0"
         else:
             pytest.skip("no device to test")
         device_state_dict = TensorDict.load(tmpdir, device=device)
@@ -1717,8 +1717,8 @@ def test_no_batch_size(self):
     def test_non_blocking(self):
         if torch.cuda.is_available():
             device = "cuda"
-        elif torch.backends.mps.is_available():
-            device = "mps"
+        # elif torch.backends.mps.is_available():
+        #     device = "mps"
         else:
             pytest.skip("No device found")
         for _ in range(10):
@@ -1792,9 +1792,9 @@ def test_non_blocking_single_sync(self, _path_td_sync):
         TensorDict(td_dict, device="cpu")
         assert _SYNC_COUNTER == 0
 
-        if torch.backends.mps.is_available():
-            device = "mps"
-        elif torch.cuda.is_available():
+        # if torch.backends.mps.is_available():
+        #     device = "mps"
+        if torch.cuda.is_available():
             device = "cuda"
         else:
             device = None
@@ -9857,7 +9857,8 @@ def check_weakref_count(weakref_list, expected):
         assert count == expected, {id(ref()) for ref in weakref_list}
 
     @pytest.mark.skipif(
-        not torch.cuda.is_available() and not torch.backends.mps.is_available(),
+        not torch.cuda.is_available(),
+        # and not torch.backends.mps.is_available(),
         reason="a device is required.",
     )
     def test_cached_data_lock_device(self):
@@ -10659,6 +10660,19 @@ def test_comparison(self, non_tensor_data):
             ("nested", "bool")
         )
 
+    def test_from_list(self):
+        nd = NonTensorStack.from_list(
+            [[True, "b", torch.randn(())], ["another", 0, NonTensorData("final")]]
+        )
+        assert isinstance(nd, NonTensorStack)
+        assert nd.shape == (2, 3)
+        assert nd[0, 0].data
+        assert nd[0, 1].data == "b"
+        assert isinstance(nd[0, 2].data, torch.Tensor)
+        assert nd[1, 0].data == "another"
+        assert nd[1, 1].data == 0
+        assert nd[1, 2].data == "final"
+
     def test_non_tensor_call(self):
         td0 = TensorDict({"a": 0, "b": 0})
         td1 = TensorDict({"a": 1, "b": 1})
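A cross-check, sketched here rather than taken from the commit: the call exercised by test_from_list is equivalent to assembling the stack by hand, one NonTensorData per leaf and one NonTensorStack per list level, which is what _maybe_from_list does recursively.

from tensordict import NonTensorData, NonTensorStack

# Manual construction mirrors from_list's recursion (hypothetical check,
# grounded in the cls(*...) call in the diff above).
manual = NonTensorStack(
    NonTensorStack(NonTensorData(True), NonTensorData("b")),
    NonTensorStack(NonTensorData("another"), NonTensorData(0)),
)
auto = NonTensorStack.from_list([[True, "b"], ["another", 0]])
assert auto.shape == manual.shape == (2, 2)
assert auto[1, 0].data == manual[1, 0].data == "another"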
