Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 1, 2024
1 parent d1c028b commit 035ce7d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
29 changes: 22 additions & 7 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6008,17 +6008,32 @@ def _items_list(
vals = [source.get(key, default) for key in new_keys]
return new_keys, vals

@cache # noqa: B019
def _grad(self):
result = self._fast_apply(
lambda x: x.grad, propagate_lock=True, filter_empty=True
# We can't cache this because zero_grad can be called outside (eg from optimizer) and we want the tensors
# to clear out when that is done.
keys, vals = self._items_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
grads = [val.grad for val in vals]
items = dict(zip(keys, grads))
return self._fast_apply(
lambda name, val: items[name],
named=True,
nested_keys=True,
propagate_lock=True,
filter_empty=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
)
return result

@cache # noqa: B019
def _data(self):
result = self._fast_apply(lambda x: x.data, propagate_lock=True)
return result
keys, vals = self._items_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
data = [val.data for val in vals]
items = dict(zip(keys, data))
return self._fast_apply(
lambda name, val: items.get(name),
named=True,
nested_keys=True,
propagate_lock=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
)

@abc.abstractmethod
def keys(
Expand Down
11 changes: 11 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2863,6 +2863,17 @@ def test_update_nested_dict(self):
assert t["a", "b"].shape == torch.Size([2, 3, 1])
t.update({"a": {"d": [[[1]] * 3] * 2}})

def test_zero_grad_module(self):
x = torch.randn(3, 3)
linear = nn.Linear(3, 4)
y = linear(x)
y.sum().backward()
p = TensorDict.from_module(linear).lock_()
assert not p.grad.is_empty()
linear.zero_grad(set_to_none=True)
assert p.grad is None
assert linear.weight.grad is None


class TestPointwiseOps:
@property
Expand Down

0 comments on commit 035ce7d

Please sign in to comment.