Skip to content

Commit

Permalink
fix: release memory when forward mode is not used
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris authored Aug 2, 2024
1 parent cc950b2 commit f1b89dd
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions torchcomp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,19 @@ def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ..
if ctx.needs_input_grad[3]:
grad_rt = torch.where(~at_mask, grad_combined, 0.0).sum(1)

if hasattr(ctx, "y"):
del ctx.y
if hasattr(ctx, "x"):
del ctx.x
if hasattr(ctx, "zi"):
del ctx.zi
if hasattr(ctx, "at"):
del ctx.at
if hasattr(ctx, "rt"):
del ctx.rt
if hasattr(ctx, "at_mask"):
del ctx.at_mask

return grad_x, grad_zi, grad_at, grad_rt

@staticmethod
Expand Down Expand Up @@ -180,6 +193,7 @@ def jvp(
x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)
)

del ctx.x, ctx.y, ctx.zi, ctx.at, ctx.rt, ctx.at_mask
return sample_wise_lpc(
fwd_combined,
coeffs.unsqueeze(2) - 1,
Expand Down

0 comments on commit f1b89dd

Please sign in to comment.