Skip to content

Commit

Permalink
[BugFix] resilient _exclude_td_from_pytree
Browse files Browse the repository at this point in the history
ghstack-source-id: 7b3ee829689779777d301f0cfff119e48567f9bb
Pull Request resolved: #1038
  • Loading branch information
vmoens committed Oct 11, 2024
1 parent 1e32195 commit 49d226c
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tensordict/nn/functional_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,14 @@ def __init__(self):

def __enter__(self):
for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS:
self.tdnodes[tdtype] = SUPPORTED_NODES.pop(tdtype)
node = SUPPORTED_NODES.pop(tdtype, None)
if node is None:
continue
self.tdnodes[tdtype] = node

def __exit__(self, exc_type, exc_val, exc_tb):
for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS:
SUPPORTED_NODES[tdtype] = self.tdnodes[tdtype]
for tdtype, node in self.tdnodes.items():
SUPPORTED_NODES[tdtype] = node

def set(self):
self.__enter__()
Expand Down

0 comments on commit 49d226c

Please sign in to comment.