From f24e3d84ac094c54269f69ea64790cddb9f43d4e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 11 Nov 2024 17:40:02 +0000 Subject: [PATCH] [Refactor] Make _set_dispatch_td_nn_modules compatible with compile ghstack-source-id: 85a78cd6086233b414fcfe221dd8129e2e38f71c Pull Request resolved: https://github.com/pytorch/tensordict/pull/1084 (cherry picked from commit 853b7d982bf58ccaa9154485dfaf030a16d3bce7) --- tensordict/nn/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index b503eb4a0..954511c7b 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -406,16 +406,21 @@ class _set_dispatch_td_nn_modules(_DecoratorContextManager): def __init__(self, mode): self.mode = mode + self._saved_mode = None def clone(self): return type(self)(self.mode) def __enter__(self): global DISPATCH_TDNN_MODULES - self._saved_mode = DISPATCH_TDNN_MODULES - DISPATCH_TDNN_MODULES = self.mode + # We want to avoid changing global variables because compile puts guards on them + if DISPATCH_TDNN_MODULES != self.mode: + self._saved_mode = DISPATCH_TDNN_MODULES + DISPATCH_TDNN_MODULES = self.mode def __exit__(self, exc_type, exc_val, exc_tb): + if self._saved_mode is None: + return global DISPATCH_TDNN_MODULES DISPATCH_TDNN_MODULES = self._saved_mode