From bb20577cf907bbbac991646b056f9cbdd24ba345 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 3 Dec 2024 01:04:23 -0800 Subject: [PATCH] enable use_checkpoint flag for Attention Block (#153) --- torchcfm/models/unet/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchcfm/models/unet/unet.py b/torchcfm/models/unet/unet.py index e29df92..205ecab 100644 --- a/torchcfm/models/unet/unet.py +++ b/torchcfm/models/unet/unet.py @@ -270,7 +270,7 @@ def __init__( self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) def _forward(self, x): b, c, *spatial = x.shape