From 6429a7266f3b7f5bbe80b6c98b5eb1af7716d24a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 19:32:07 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/base.py | 3 +++ test/test_tensordict.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/tensordict/base.py b/tensordict/base.py index 6c600b11f..f3034fd96 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -18,6 +18,7 @@ import uuid import warnings import weakref +from collections import UserDict from collections.abc import MutableMapping from concurrent.futures import Future, ThreadPoolExecutor, wait @@ -9871,6 +9872,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return obj if isinstance(obj, dict): return cls.from_dict(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, UserDict): + return cls.from_dict(dict(obj), auto_batch_size=auto_batch_size) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) if isinstance(obj, tuple): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 73d401c03..50fc7170e 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -18,6 +18,7 @@ import sys import uuid import warnings +from collections import UserDict from dataclasses import dataclass from pathlib import Path from typing import Any @@ -996,6 +997,13 @@ class MyClass: td.keys(True, True) ).symmetric_difference(expected) + def test_from_any_userdict(self): + class D(UserDict): ... + + d = D(a=0) + assert TensorDict.from_any(d)["a"] == 0 + assert isinstance(TensorDict.from_any(d)["a"], torch.Tensor) + def test_from_dataclass(self): @dataclass class MyClass: