Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 11, 2024
1 parent 8929877 commit ff4e4e9
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tensordict/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,16 +405,21 @@ class _set_dispatch_td_nn_modules(_DecoratorContextManager):

def __init__(self, mode):
self.mode = mode
self._saved_mode = None

def clone(self):
return type(self)(self.mode)

def __enter__(self):
global DISPATCH_TDNN_MODULES
self._saved_mode = DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self.mode
# We want to avoid changing global variables because compile puts guards on them
if DISPATCH_TDNN_MODULES != self.mode:
self._saved_mode = DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self.mode

def __exit__(self, exc_type, exc_val, exc_tb):
if self._saved_mode is None:
return
global DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self._saved_mode

Expand Down

0 comments on commit ff4e4e9

Please sign in to comment.