From 3485c2c890ea4672a7775ff0cdc5c87ee41729c4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 08:51:03 +0000 Subject: [PATCH] [Feature] from_any with UserDict ghstack-source-id: 420464209cff29c3a1c58ec521fbf4ed69d1355f Pull Request resolved: https://github.com/pytorch/tensordict/pull/1106 --- tensordict/base.py | 3 +++ test/test_tensordict.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/tensordict/base.py b/tensordict/base.py index 6c600b11f..f3034fd96 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -18,6 +18,7 @@ import uuid import warnings import weakref +from collections import UserDict from collections.abc import MutableMapping from concurrent.futures import Future, ThreadPoolExecutor, wait @@ -9871,6 +9872,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return obj if isinstance(obj, dict): return cls.from_dict(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, UserDict): + return cls.from_dict(dict(obj), auto_batch_size=auto_batch_size) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) if isinstance(obj, tuple): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 73d401c03..50fc7170e 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -18,6 +18,7 @@ import sys import uuid import warnings +from collections import UserDict from dataclasses import dataclass from pathlib import Path from typing import Any @@ -996,6 +997,13 @@ class MyClass: td.keys(True, True) ).symmetric_difference(expected) + def test_from_any_userdict(self): + class D(UserDict): ... + + d = D(a=0) + assert TensorDict.from_any(d)["a"] == 0 + assert isinstance(TensorDict.from_any(d)["a"], torch.Tensor) + def test_from_dataclass(self): @dataclass class MyClass: