diff --git a/tensordict/base.py b/tensordict/base.py
index 33a8957e0..40aa60f36 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -6008,17 +6008,32 @@ def _items_list(
         vals = [source.get(key, default) for key in new_keys]
         return new_keys, vals
 
-    @cache  # noqa: B019
     def _grad(self):
-        result = self._fast_apply(
-            lambda x: x.grad, propagate_lock=True, filter_empty=True
+        # We can't cache this because zero_grad can be called outside (eg from optimizer) and we want the tensors
+        # to clear out when that is done.
+        keys, vals = self._items_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
+        grads = [val.grad for val in vals]
+        items = dict(zip(keys, grads))
+        return self._fast_apply(
+            lambda name, val: items[name],
+            named=True,
+            nested_keys=True,
+            propagate_lock=True,
+            filter_empty=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
         )
-        return result
 
-    @cache  # noqa: B019
     def _data(self):
-        result = self._fast_apply(lambda x: x.data, propagate_lock=True)
-        return result
+        keys, vals = self._items_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
+        data = [val.data for val in vals]
+        items = dict(zip(keys, data))
+        return self._fast_apply(
+            lambda name, val: items.get(name),
+            named=True,
+            nested_keys=True,
+            propagate_lock=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+        )
 
     @abc.abstractmethod
     def keys(
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index a100590bc..be0ff5fef 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -2863,6 +2863,17 @@ def test_update_nested_dict(self):
         assert t["a", "b"].shape == torch.Size([2, 3, 1])
         t.update({"a": {"d": [[[1]] * 3] * 2}})
 
+    def test_zero_grad_module(self):
+        x = torch.randn(3, 3)
+        linear = nn.Linear(3, 4)
+        y = linear(x)
+        y.sum().backward()
+        p = TensorDict.from_module(linear).lock_()
+        assert not p.grad.is_empty()
+        linear.zero_grad(set_to_none=True)
+        assert p.grad is None
+        assert linear.weight.grad is None
+
 
 class TestPointwiseOps:
     @property
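
Note: a minimal sketch of the "zero_grad called outside, e.g. from the optimizer" scenario
mentioned in the new comment in _grad, mirroring the test_zero_grad_module test added above.
The optimizer, module, and shapes are illustrative; only TensorDict.from_module and the
.grad view are taken from the diff, and the final assertion assumes the same behaviour the
new test checks for module.zero_grad.

    import torch
    from torch import nn
    from tensordict import TensorDict

    linear = nn.Linear(3, 4)
    linear(torch.randn(3, 3)).sum().backward()

    params = TensorDict.from_module(linear)
    assert not params.grad.is_empty()  # gradients populated by backward()

    # zero_grad is invoked outside the TensorDict, here via the optimizer;
    # since _grad is no longer cached, the .grad view reflects the cleared state.
    optim = torch.optim.SGD(linear.parameters(), lr=0.1)
    optim.zero_grad(set_to_none=True)
    assert params.grad is None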