diff --git a/src/leibnetz/leibnet.py b/src/leibnetz/leibnet.py index 5d1f33a..522f2d3 100644 --- a/src/leibnetz/leibnet.py +++ b/src/leibnetz/leibnet.py @@ -340,7 +340,7 @@ def mps(self): else: logger.error('Unable to move model to Apple Silicon ("mps")') - def forward(self, inputs: dict[str, torch.Tensor]): + def forward(self, inputs: dict[str, torch.Tensor]): # function for forwarding data through the network # inputs is a dictionary of tensors # outputs is a dictionary of tensors diff --git a/src/leibnetz/nets/unet.py b/src/leibnetz/nets/unet.py index 63460a5..4b97294 100644 --- a/src/leibnetz/nets/unet.py +++ b/src/leibnetz/nets/unet.py @@ -12,6 +12,8 @@ def build_unet( output_nc=1, base_nc=12, nc_increase_factor=2, + norm_layer=None, + residual=False, ): # define downsample nodes downsample_factors = np.array(downsample_factors) @@ -28,6 +30,8 @@ base_nc * nc_increase_factor**i, kernel_sizes, identifier=output_key, + norm_layer=norm_layer, + residual=residual, ), ) c += 1 @@ -53,6 +57,8 @@ base_nc * nc_increase_factor ** (i + 1), kernel_sizes, identifier=output_key, + norm_layer=norm_layer, + residual=residual, ) ) input_key = output_key @@ -80,6 +86,8 @@ base_nc * nc_increase_factor**i, kernel_sizes, identifier=output_key, + norm_layer=norm_layer, + residual=residual, ) ) input_key = output_key @@ -94,6 +102,8 @@ # kernel_sizes, [(1,) * len(top_resolution)], identifier="output", + norm_layer=norm_layer, # TODO: remove? 
+ residual=residual, ) ) diff --git a/src/leibnetz/nodes/node_ops.py b/src/leibnetz/nodes/node_ops.py index 32213c7..10215d6 100644 --- a/src/leibnetz/nodes/node_ops.py +++ b/src/leibnetz/nodes/node_ops.py @@ -73,7 +73,7 @@ def __init__( for i, kernel_size in enumerate(kernel_sizes): if norm_layer is not None: - layers.append(norm_layer(output_nc)) + layers.append(norm_layer(input_nc)) layers.append(self.activation) @@ -115,7 +115,7 @@ def __init__( def crop(self, x, shape): """Center-crop x to match spatial dimensions given by shape.""" - x_target_size = x.size()[: -self.dims] + np.array(shape) + x_target_size = x.shape[: -self.dims] + tuple(shape) offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size)) @@ -132,3 +132,4 @@ def forward(self, x): init_x = self.crop(self.x_init_map(x), res.size()[-self.dims :]) else: init_x = self.x_init_map(x) + return res + init_x