feat: ✨ Add final activation specification for nodes and networks
rhoadesScholar committed May 5, 2024
1 parent 1f22595 commit 1c80eaf
Showing 5 changed files with 53 additions and 32 deletions.
2 changes: 2 additions & 0 deletions src/leibnetz/nets/attentive_scalenet.py
@@ -23,6 +23,7 @@ def build_subnet(
     residual=False,
     dropout_prob=None,
     activation="ReLU",
+    final_activation="Sigmoid",
 ):
     # define downsample nodes
     downsample_factors = np.array(downsample_factors)
@@ -143,6 +144,7 @@ def build_subnet(
             residual=residual,
             dropout_prob=dropout_prob,
             activation=activation,
+            final_activation=final_activation,
         )
     )
     outputs = {
2 changes: 2 additions & 0 deletions src/leibnetz/nets/scalenet.py
@@ -19,6 +19,7 @@ def build_subnet(
     residual=False,
     dropout_prob=None,
     activation="ReLU",
+    final_activation="Sigmoid",
 ):
     # define downsample nodes
     downsample_factors = np.array(downsample_factors)
@@ -123,6 +124,7 @@ def build_subnet(
             residual=residual,
             dropout_prob=dropout_prob,
             activation=activation,
+            final_activation=final_activation,
         )
     )
     outputs = {
2 changes: 2 additions & 0 deletions src/leibnetz/nets/unet.py
@@ -16,6 +16,7 @@ def build_unet(
     residual=False,
     dropout_prob=None,
     activation="ReLU",
+    final_activation="Sigmoid",
 ):
     # define downsample nodes
     downsample_factors = np.array(downsample_factors)
@@ -114,6 +115,7 @@ def build_unet(
             residual=residual,
             dropout_prob=dropout_prob,
             activation=activation,
+            final_activation=final_activation,
         )
     )
 
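All three builders touched above (build_subnet in attentive_scalenet.py and scalenet.py, and build_unet in unet.py) gain the same final_activation keyword, defaulting to "Sigmoid", and simply pass it through to the nodes they construct. A minimal usage sketch follows, assuming the package imports as leibnetz.nets.unet; the builders' other required arguments are not visible in this diff, so unet_kwargs below is a hypothetical placeholder to be filled in with a real configuration.

    from leibnetz.nets.unet import build_unet

    # Hypothetical placeholder: the rest of build_unet's arguments are not shown
    # in this diff, so supply your real architecture configuration here.
    unet_kwargs = {}

    # Hidden convolutions keep using ReLU between layers, while each node's final
    # output now goes through Sigmoid (the new default). Any torch.nn activation
    # name should work here, e.g. "Tanh" or "Identity".
    net = build_unet(activation="ReLU", final_activation="Sigmoid", **unet_kwargs)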
3 changes: 3 additions & 0 deletions src/leibnetz/nodes/conv_pass_node.py
@@ -16,6 +16,7 @@ def __init__(
         kernel_sizes,
         output_key_channels=None,
         activation="ReLU",
+        final_activation=None,
         padding="valid",
         residual=False,
         padding_mode="reflect",
@@ -30,6 +31,7 @@ def __init__(
         self.output_nc = output_nc
         self.kernel_sizes = kernel_sizes
         self.activation = activation
+        self.final_activation = final_activation
         self.padding = padding
         self.residual = residual
         self.padding_mode = padding_mode
@@ -40,6 +42,7 @@ def __init__(
             output_nc,
             kernel_sizes,
             activation=activation,
+            final_activation=final_activation,
             padding=padding,
             residual=residual,
             padding_mode=padding_mode,
76 changes: 44 additions & 32 deletions src/leibnetz/nodes/node_ops.py
@@ -10,6 +10,7 @@ def __init__(
         output_nc,
         kernel_sizes,
         activation="ReLU",
+        final_activation=None,
         padding="valid",
         residual=False,
         padding_mode="reflect",
@@ -23,6 +24,7 @@ def __init__(
             output_nc (int): Number of output channels
             kernel_sizes (list(int) or array_like): Kernel sizes for convolution layers.
             activation (str or callable): Name of activation function in 'nn' or the function itself.
+            final_activation (str or callable, optional): Name of activation function in 'nn' or the function itself, to be applied to final output values only. Defaults to the same as activation.
             padding (str, optional): What type of padding to use in convolutions. Defaults to 'valid'.
             residual (bool, optional): Whether to make the blocks calculate the residual. Defaults to False.
             padding_mode (str, optional): What values to use in padding (i.e. 'zeros', 'reflect', 'wrap', etc.). Defaults to 'reflect'.
@@ -42,6 +44,7 @@ def __init__(
         else:
             self.activation = nn.Identity()
 
+        if final_activation is not None:
+            if isinstance(final_activation, str):
+                self.final_activation = getattr(nn, final_activation)()
+            else:
+                self.final_activation = final_activation()  # assume is function
+        else:
+            self.final_activation = self.activation
+
         self.residual = residual
         self.padding = padding
         self.padding_mode = padding_mode
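The constructor resolves the new argument the same way it resolves activation: a string is looked up as an attribute of torch.nn and instantiated, anything else is called (so a class or factory such as nn.Tanh is expected rather than an already-constructed module), and None falls back to the hidden-layer activation. A standalone sketch of that lookup, for illustration only (resolve_activation is not a leibnetz function):

    import torch.nn as nn

    def resolve_activation(spec, fallback):
        # Mirrors the branch added above: None -> reuse the hidden-layer activation,
        # str -> look up and instantiate from torch.nn, otherwise call it.
        if spec is None:
            return fallback
        if isinstance(spec, str):
            return getattr(nn, spec)()   # e.g. "Sigmoid" -> nn.Sigmoid()
        return spec()                    # e.g. nn.Tanh (a class), instantiated here

    hidden = nn.ReLU()
    print(resolve_activation(None, hidden))       # ReLU()  (fallback)
    print(resolve_activation("Sigmoid", hidden))  # Sigmoid()
    print(resolve_activation(nn.Tanh, hidden))    # Tanh()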
@@ -84,39 +94,40 @@ def __init__(
 
             self.dims = len(kernel_size)
 
-            conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims]
-
             try:
-                layers.append(
-                    conv(
-                        input_nc,
-                        output_nc,
-                        kernel_size,
-                        padding=padding,
-                        padding_mode=padding_mode,
-                    )
-                )
-                if residual and i == 0:
-                    if input_nc < output_nc and output_nc % input_nc == 0:
-                        groups = input_nc
-                    elif input_nc % output_nc == 0:
-                        groups = output_nc
-                    else:
-                        groups = 1
-                    self.x_init_map = conv(
-                        input_nc,
-                        output_nc,
-                        np.ones(self.dims, dtype=int),
-                        padding=padding,
-                        padding_mode=padding_mode,
-                        bias=False,
-                        groups=groups,
-                    )
-                else:
-                    layers.append(self.activation)
-
-            except KeyError:
-                raise RuntimeError("%dD convolution not implemented" % self.dims)
+                conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims]
+            except KeyError:
+                raise ValueError(
+                    f"Only 2D, 3D and 4D convolutions are supported, not {self.dims}D"
+                )
+            layers.append(
+                conv(
+                    input_nc,
+                    output_nc,
+                    kernel_size,
+                    padding=padding,
+                    padding_mode=padding_mode,
+                )
+            )
+            if residual and i == 0:
+                if input_nc < output_nc and output_nc % input_nc == 0:
+                    groups = input_nc
+                elif input_nc % output_nc == 0:
+                    groups = output_nc
+                else:
+                    groups = 1
+                self.x_init_map = conv(
+                    input_nc,
+                    output_nc,
+                    np.ones(self.dims, dtype=int),
+                    padding=padding,
+                    padding_mode=padding_mode,
+                    bias=False,
+                    groups=groups,
+                )
+            elif i < len(kernel_sizes) - 1:
+                layers.append(self.activation)
 
             input_nc = output_nc
 
@@ -135,11 +146,12 @@ def crop(self, x, shape):
 
     def forward(self, x):
         if not self.residual:
-            return self.conv_pass(x)
+            x = self.conv_pass(x)
+            return self.final_activation(x)
         else:
             res = self.conv_pass(x)
             if self.padding.lower() == "valid":
                 init_x = self.crop(self.x_init_map(x), res.size()[-self.dims :])
             else:
                 init_x = self.x_init_map(x)
-            return self.activation(res + init_x)
+            return self.final_activation(res + init_x)
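Both branches of forward() now end with self.final_activation: the non-residual path applies it to the output of the convolution stack, and the residual path applies it to the sum of the block output and the mapped input. With the ConvPass default of final_activation=None this is still the hidden-layer activation, since the constructor falls back to it. For illustration only (plain torch, not leibnetz code), the resulting pattern for a two-convolution, non-residual block looks like this:

    import torch
    import torch.nn as nn

    activation = nn.ReLU()           # used between convolution layers
    final_activation = nn.Sigmoid()  # applied once, to the block's final output

    block = nn.Sequential(
        nn.Conv3d(1, 8, kernel_size=3),
        activation,
        nn.Conv3d(8, 1, kernel_size=3),  # no activation after the last conv anymore
    )

    x = torch.randn(1, 1, 16, 16, 16)
    y = final_activation(block(x))  # what forward() now does at the end
    print(y.shape)                  # torch.Size([1, 1, 12, 12, 12])
    print(float(y.min()) >= 0.0 and float(y.max()) <= 1.0)  # True: squashed into (0, 1)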
