diff --git a/LION/models/iterative_unrolled/LPD.py b/LION/models/iterative_unrolled/LPD.py index 33158a5..9f143d1 100644 --- a/LION/models/iterative_unrolled/LPD.py +++ b/LION/models/iterative_unrolled/LPD.py @@ -32,7 +32,7 @@ class dataProximal(nn.Module): CNN block of the dual variable """ - def __init__(self, layers, channels, instance_norm=False): + def __init__(self, layers, channels, conv_bias, instance_norm=False): super().__init__() # imput parsing @@ -50,12 +50,12 @@ def __init__(self, layers, channels, instance_norm=False): # PReLUs and 3x3 kernels all the way except the last if ii < layers - 1: layer_list.append( - nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=False) + nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=conv_bias) ) layer_list.append(nn.PReLU()) else: layer_list.append( - nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=False) + nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=conv_bias) ) self.block = nn.Sequential(*layer_list) @@ -68,7 +68,7 @@ class RegProximal(nn.Module): CNN block of the primal variable """ - def __init__(self, layers, channels, instance_norm=False): + def __init__(self, layers, channels, conv_bias, instance_norm=False): super().__init__() if len(channels) != layers + 1: raise ValueError( @@ -84,12 +84,12 @@ def __init__(self, layers, channels, instance_norm=False): # PReLUs and 3x3 kernels all the way except the last if ii < layers - 1: layer_list.append( - nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=False) + nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=conv_bias) ) layer_list.append(nn.PReLU()) else: layer_list.append( - nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=False) + nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=conv_bias) ) self.block = nn.Sequential(*layer_list) @@ -122,6 +122,7 @@ def __init__( RegProximal( layers=len(self.model_parameters.reg_channels) - 1, channels=self.model_parameters.reg_channels, + conv_bias=self.model_parameters.conv_bias, instance_norm=self.model_parameters.instance_norm, ), ) @@ -130,6 +131,7 @@ def __init__( dataProximal( layers=len(self.model_parameters.data_channels) - 1, channels=self.model_parameters.data_channels, + conv_bias=self.model_parameters.conv_bias, instance_norm=self.model_parameters.instance_norm, ), ) @@ -199,6 +201,7 @@ def default_parameters(): LPD_params.step_positive = False LPD_params.mode = "ct" LPD_params.instance_norm = False + LPD_params.conv_bias = True return LPD_params @staticmethod @@ -212,6 +215,7 @@ def continous_LPD_paper(): LPD_params.step_positive = True LPD_params.mode = "ct" LPD_params.instance_norm = True + LPD_params.conv_bias = False return LPD_params @staticmethod