diff --git a/models/networks.py b/models/networks.py index d84ad63..976e9db 100644 --- a/models/networks.py +++ b/models/networks.py @@ -53,6 +53,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, def gated(self, mask): #return torch.clamp(mask, -1, 1) return self.sigmoid(mask) + def forward(self, input): x = self.conv2d(input) mask = self.mask_conv2d(input) @@ -80,7 +81,7 @@ def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride= def forward(self, input): #print(input.size()) - x = F.interpolate(input, scale_factor=2) + x = F.interpolate(input, scale_factor=self.scale_factor) return self.conv2d(x) class SNGatedConv2dWithActivation(torch.nn.Module):