Add conv_bias parameter in LPD (Issue CambridgeCIA#139)
D1rk123 committed Oct 16, 2024
1 parent a28be48 commit c5e0539
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions LION/models/iterative_unrolled/LPD.py
@@ -32,7 +32,7 @@ class dataProximal(nn.Module):
     CNN block of the dual variable
     """
 
-    def __init__(self, layers, channels, instance_norm=False):
+    def __init__(self, layers, channels, conv_bias, instance_norm=False):
 
         super().__init__()
         # input parsing
@@ -50,12 +50,12 @@ def __init__(self, layers, channels, instance_norm=False):
             # PReLUs and 3x3 kernels all the way except the last
             if ii < layers - 1:
                 layer_list.append(
-                    nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=False)
+                    nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=conv_bias)
                 )
                 layer_list.append(nn.PReLU())
             else:
                 layer_list.append(
-                    nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=False)
+                    nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=conv_bias)
                 )
         self.block = nn.Sequential(*layer_list)
 
@@ -68,7 +68,7 @@ class RegProximal(nn.Module):
     CNN block of the primal variable
     """
 
-    def __init__(self, layers, channels, instance_norm=False):
+    def __init__(self, layers, channels, conv_bias, instance_norm=False):
         super().__init__()
         if len(channels) != layers + 1:
             raise ValueError(
@@ -84,12 +84,12 @@ def __init__(self, layers, channels, instance_norm=False):
             # PReLUs and 3x3 kernels all the way except the last
             if ii < layers - 1:
                 layer_list.append(
-                    nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=False)
+                    nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=conv_bias)
                 )
                 layer_list.append(nn.PReLU())
             else:
                 layer_list.append(
-                    nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=False)
+                    nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=conv_bias)
                 )
         self.block = nn.Sequential(*layer_list)
 
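Both hunks above make the same change to an identical construction pattern. For reference, the sketch below consolidates that pattern into one standalone module. It is an illustration of the post-commit behaviour, not part of the commit, and the name ProximalBlock is hypothetical.

import torch.nn as nn

# Illustrative consolidation of the layer construction shared by
# dataProximal and RegProximal after this commit: the bias of every
# Conv2d is governed by a single conv_bias flag instead of the
# previously hard-coded bias=False.
class ProximalBlock(nn.Module):
    def __init__(self, layers, channels, conv_bias=True):
        super().__init__()
        if len(channels) != layers + 1:
            raise ValueError("len(channels) must equal layers + 1")
        layer_list = []
        for ii in range(layers):
            if ii < layers - 1:
                # 3x3 convolutions followed by PReLU for all but the last layer
                layer_list.append(
                    nn.Conv2d(channels[ii], channels[ii + 1], 3, padding=1, bias=conv_bias)
                )
                layer_list.append(nn.PReLU())
            else:
                # final 1x1 convolution, no activation
                layer_list.append(
                    nn.Conv2d(channels[ii], channels[ii + 1], 1, padding=0, bias=conv_bias)
                )
        self.block = nn.Sequential(*layer_list)

    def forward(self, x):
        return self.block(x)
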
@@ -122,6 +122,7 @@ def __init__(
             RegProximal(
                 layers=len(self.model_parameters.reg_channels) - 1,
                 channels=self.model_parameters.reg_channels,
+                conv_bias=self.model_parameters.conv_bias,
                 instance_norm=self.model_parameters.instance_norm,
             ),
         )
@@ -130,6 +131,7 @@ def __init__(
             dataProximal(
                 layers=len(self.model_parameters.data_channels) - 1,
                 channels=self.model_parameters.data_channels,
+                conv_bias=self.model_parameters.conv_bias,
                 instance_norm=self.model_parameters.instance_norm,
             ),
         )
@@ -199,6 +201,7 @@ def default_parameters():
         LPD_params.step_positive = False
         LPD_params.mode = "ct"
         LPD_params.instance_norm = False
+        LPD_params.conv_bias = True
         return LPD_params
 
     @staticmethod
@@ -212,6 +215,7 @@ def continous_LPD_paper():
         LPD_params.step_positive = True
         LPD_params.mode = "ct"
         LPD_params.instance_norm = True
+        LPD_params.conv_bias = False
         return LPD_params
 
     @staticmethod
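As a usage note, the sketch below shows how the new flag is selected through the two parameter presets changed above. default_parameters() and continous_LPD_paper() appear in the hunks; the assumption that the model class is named LPD and is importable from the path in the file header is mine, not confirmed by this commit.

# Hypothetical usage sketch; the import path is assumed from the file header.
from LION.models.iterative_unrolled.LPD import LPD

# Post-commit defaults keep the convolution biases enabled ...
params = LPD.default_parameters()
assert params.conv_bias is True

# ... while the continuous-LPD preset reproduces the previous
# hard-coded bias=False behaviour.
paper_params = LPD.continous_LPD_paper()
assert paper_params.conv_bias is False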
