Skip to content

Commit

Permalink
fix: 🐛 Fix norm_layer and residual block bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Apr 12, 2024
1 parent 565614a commit da4875a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/leibnetz/leibnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def mps(self):
else:
logger.error('Unable to move model to Apple Silicon ("mps")')

def forward(self, inputs: dict[str, torch.Tensor]):
def forward(self, inputs: dict[str, dict[str, Sequence[int | float]]]):
# function for forwarding data through the network
# inputs is a dictionary of tensors
# outputs is a dictionary of tensors
Expand Down
10 changes: 10 additions & 0 deletions src/leibnetz/nets/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def build_unet(
output_nc=1,
base_nc=12,
nc_increase_factor=2,
norm_layer=None,
residual=False,
):
# define downsample nodes
downsample_factors = np.array(downsample_factors)
Expand All @@ -28,6 +30,8 @@ def build_unet(
base_nc * nc_increase_factor**i,
kernel_sizes,
identifier=output_key,
norm_layer=norm_layer,
residual=residual,
),
)
c += 1
Expand All @@ -53,6 +57,8 @@ def build_unet(
base_nc * nc_increase_factor ** (i + 1),
kernel_sizes,
identifier=output_key,
norm_layer=norm_layer,
residual=residual,
)
)
input_key = output_key
Expand Down Expand Up @@ -80,6 +86,8 @@ def build_unet(
base_nc * nc_increase_factor**i,
kernel_sizes,
identifier=output_key,
norm_layer=norm_layer,
residual=residual,
)
)
input_key = output_key
Expand All @@ -94,6 +102,8 @@ def build_unet(
# kernel_sizes,
[(1,) * len(top_resolution)],
identifier="output",
norm_layer=norm_layer, # TODO: remove?
residual=residual,
)
)

Expand Down
5 changes: 3 additions & 2 deletions src/leibnetz/nodes/node_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(

for i, kernel_size in enumerate(kernel_sizes):
if norm_layer is not None:
layers.append(norm_layer(output_nc))
layers.append(norm_layer(input_nc))

layers.append(self.activation)

Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
def crop(self, x, shape):
"""Center-crop x to match spatial dimensions given by shape."""

x_target_size = x.size()[: -self.dims] + np.array(shape)
x_target_size = x.shape[: -self.dims] + tuple(shape)

offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size))

Expand All @@ -132,3 +132,4 @@ def forward(self, x):
init_x = self.crop(self.x_init_map(x), res.size()[-self.dims :])
else:
init_x = self.x_init_map(x)
return res + init_x

0 comments on commit da4875a

Please sign in to comment.