Skip to content

Commit

Permalink
[BugFix] Reference cycle in TensorDict._cast_reduction (#1056)
Browse files Browse the repository at this point in the history
  • Loading branch information
egaznep authored Oct 24, 2024
1 parent 8c65dcb commit 3316bdc
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,24 +978,30 @@ def _cast_reduction(
agglomerate, keepdim=keepdim, dim=dim
)

def proc_dim(dim, tuple_ok=True):
# IMPORTANT: do not directly access batch_dims (or any other property)
# via self.batch_dims otherwise a reference cycle is introduced
def proc_dim(dim, batch_dims, tuple_ok=True):
if dim is None:
return dim
if isinstance(dim, tuple):
if tuple_ok:
return tuple(_d for d in dim for _d in proc_dim(d, tuple_ok=False))
return tuple(
_d
for d in dim
for _d in proc_dim(d, batch_dims, tuple_ok=False)
)
return dim
if dim >= self.batch_dims or dim < -self.batch_dims:
if dim >= batch_dims or dim < -batch_dims:
raise RuntimeError(
"dim must be greater than or equal to -tensordict.batch_dims and "
"smaller than tensordict.batch_dims"
)
if dim < 0:
return (self.batch_dims + dim,)
return (batch_dims + dim,)
return (dim,)

if dim is not NO_DEFAULT:
dim = proc_dim(dim, tuple_ok=tuple_ok)
dim = proc_dim(dim, self.batch_dims, tuple_ok=tuple_ok)
if not tuple_ok:
dim = dim[0]
if dim is not NO_DEFAULT or keepdim:
Expand Down

0 comments on commit 3316bdc

Please sign in to comment.