From 3316bdcc69ef6bbccfc34c6b17d9d83de3ecdb2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=9Cnal=20Ege=20Gaznepo=C4=9Flu?= Date: Thu, 24 Oct 2024 16:59:28 +0200 Subject: [PATCH] [BugFix] Reference cycle in TensorDict._cast_reduction (#1056) --- tensordict/_td.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 4387839b5..2785b7da4 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -978,24 +978,30 @@ def _cast_reduction( agglomerate, keepdim=keepdim, dim=dim ) - def proc_dim(dim, tuple_ok=True): + # IMPORTANT: do not directly access batch_dims (or any other property) + # via self.batch_dims otherwise a reference cycle is introduced + def proc_dim(dim, batch_dims, tuple_ok=True): if dim is None: return dim if isinstance(dim, tuple): if tuple_ok: - return tuple(_d for d in dim for _d in proc_dim(d, tuple_ok=False)) + return tuple( + _d + for d in dim + for _d in proc_dim(d, batch_dims, tuple_ok=False) + ) return dim - if dim >= self.batch_dims or dim < -self.batch_dims: + if dim >= batch_dims or dim < -batch_dims: raise RuntimeError( "dim must be greater than or equal to -tensordict.batch_dims and " "smaller than tensordict.batch_dims" ) if dim < 0: - return (self.batch_dims + dim,) + return (batch_dims + dim,) return (dim,) if dim is not NO_DEFAULT: - dim = proc_dim(dim, tuple_ok=tuple_ok) + dim = proc_dim(dim, self.batch_dims, tuple_ok=tuple_ok) if not tuple_ok: dim = dim[0] if dim is not NO_DEFAULT or keepdim: