From ff4e4e9da3e8bf25acca0ee4f29f2be0174fc8ff Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 11 Nov 2024 17:39:59 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/nn/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 6080d8637..177217b98 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -405,16 +405,21 @@ class _set_dispatch_td_nn_modules(_DecoratorContextManager): def __init__(self, mode): self.mode = mode + self._saved_mode = None def clone(self): return type(self)(self.mode) def __enter__(self): global DISPATCH_TDNN_MODULES - self._saved_mode = DISPATCH_TDNN_MODULES - DISPATCH_TDNN_MODULES = self.mode + # We want to avoid changing global variables because compile puts guards on them + if DISPATCH_TDNN_MODULES != self.mode: + self._saved_mode = DISPATCH_TDNN_MODULES + DISPATCH_TDNN_MODULES = self.mode def __exit__(self, exc_type, exc_val, exc_tb): + if self._saved_mode is None: + return global DISPATCH_TDNN_MODULES DISPATCH_TDNN_MODULES = self._saved_mode