Skip to content

Commit

Permalink
[Feature] from_any with UserDict
Browse files Browse the repository at this point in the history
ghstack-source-id: 420464209cff29c3a1c58ec521fbf4ed69d1355f
Pull Request resolved: #1106
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent d5fcace commit 3485c2c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import uuid
import warnings
import weakref
from collections import UserDict
from collections.abc import MutableMapping

from concurrent.futures import Future, ThreadPoolExecutor, wait
Expand Down Expand Up @@ -9871,6 +9872,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False):
return obj
if isinstance(obj, dict):
return cls.from_dict(obj, auto_batch_size=auto_batch_size)
if isinstance(obj, UserDict):
return cls.from_dict(dict(obj), auto_batch_size=auto_batch_size)
if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"):
return cls.from_struct_array(obj, auto_batch_size=auto_batch_size)
if isinstance(obj, tuple):
Expand Down
8 changes: 8 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sys
import uuid
import warnings
from collections import UserDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -996,6 +997,13 @@ class MyClass:
td.keys(True, True)
).symmetric_difference(expected)

def test_from_any_userdict(self):
class D(UserDict): ...

d = D(a=0)
assert TensorDict.from_any(d)["a"] == 0
assert isinstance(TensorDict.from_any(d)["a"], torch.Tensor)

def test_from_dataclass(self):
@dataclass
class MyClass:
Expand Down

0 comments on commit 3485c2c

Please sign in to comment.