From b102010358072fb4f2c49d14c2700e5194ee97ba Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 19 Dec 2024 13:00:16 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/base.py | 18 +++++++++++++----- test/test_tensordict.py | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index c868d1447..8ae783905 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6544,12 +6544,16 @@ def update_( named = True - def inplace_update(name, dest, source): + def inplace_update(name, source, dest): if source is None: return None name = _unravel_key_to_tuple(name) for key in keys_to_update: if key == name[: len(key)]: + if dest is None: + raise KeyError( + f"The key {name} was not found in the dest tensordict." + ) dest.copy_(source, non_blocking=non_blocking) else: @@ -6564,16 +6568,20 @@ def inplace_update(name, dest, source): vals = [vals[k] for k in new_keys] _foreach_copy_(vals, other_val, non_blocking=non_blocking) return self - named = False + named = True - def inplace_update(dest, source): + def inplace_update(name, source, dest): if source is None: return None + if dest is None: + raise KeyError( + f"The key {name} was not found in the dest tensordict." + ) dest.copy_(source, non_blocking=non_blocking) - self._apply_nest( + input_dict_or_td._apply_nest( inplace_update, - input_dict_or_td, + self, nested_keys=True, default=None, filter_empty=True, diff --git a/test/test_tensordict.py b/test/test_tensordict.py index de9fad1e7..fe2e1d0f6 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3033,6 +3033,28 @@ def make(val, todict=False, stack=False): assert (td1.select(("a", "b")) == 2).all() assert (td1.exclude(("a", "b")) == 1).all() + # Any extra key in dest will raise an exception + with pytest.raises(KeyError): + td_dest = TensorDict(a=0) + td_source = TensorDict(b=1) + td_dest.update_(td_source) + with pytest.raises(KeyError): + td_dest = TensorDict(a=0) + td_source = {"b": torch.ones(())} + td_dest.update_(td_source) + with pytest.raises(KeyError): + td_dest = TensorDict(a=0) + td_source = TensorDict(b=1) + td_dest.update_(td_source, keys_to_update="b") + with pytest.raises(KeyError): + td_dest = TensorDict(a=0) + td_source = {"b": torch.ones(())} + td_dest.update_(td_source, keys_to_update="b") + + td_dest = TensorDict(a=0, b=1) + td_source = TensorDict(a=0) + td_dest.update_(td_source) + def test_update_nested_dict(self): t = TensorDict({"a": {"d": [[[0]] * 3] * 2}}, [2, 3]) assert ("a", "d") in t.keys(include_nested=True)