diff --git a/model/network.py b/model/network.py index 88b3a3f..8ee89a4 100644 --- a/model/network.py +++ b/model/network.py @@ -75,13 +75,13 @@ def __init__(self, in_channel=3, out_channel=3, channel=None, time_channel=256, # size: size / 4 self.sa4 = SelfAttention(channels=channel[1], size=int(self.image_size / 4), act=act) # channel: 256 -> 64 in_channels: up2(256) = sa4(128) + sa1(128) - # size: size / 4 + # size: size / 2 self.up2 = UpBlock(in_channels=channel[2], out_channels=channel[0], act=act) # channel: 128 # size: size / 2 self.sa5 = SelfAttention(channels=channel[0], size=int(self.image_size / 2), act=act) # channel: 128 -> 64 in_channels: up3(128) = sa5(64) + inc(64) - # size: size / 4 + # size: size self.up3 = UpBlock(in_channels=channel[1], out_channels=channel[0], act=act) # channel: 128 # size: size