diff --git a/tensordict/_td.py b/tensordict/_td.py index 6a6b2040f..2785b7da4 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -978,6 +978,8 @@ def _cast_reduction( agglomerate, keepdim=keepdim, dim=dim ) + # IMPORTANT: do not directly access batch_dims (or any other property) + # via self.batch_dims otherwise a reference cycle is introduced def proc_dim(dim, batch_dims, tuple_ok=True): if dim is None: return dim