Skip to content

Commit

Permalink
batch norms fix for 3d core
Browse files Browse the repository at this point in the history
  • Loading branch information
pollytur committed May 28, 2024
1 parent 504a5a6 commit 0b21f01
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 3 additions & 1 deletion neuralpredictors/layers/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict

from torch import nn
import warnings


class Core(ABC):
Expand Down Expand Up @@ -77,7 +78,8 @@ def add_bn_layer(self, layer: OrderedDict, layer_idx: int):
raise NotImplementedError(f"Subclasses must have a `{attr}` attribute.")
for attr in ["batch_norm", "hidden_channels", "bias", "batch_norm_scale"]:
if not isinstance(getattr(self, attr), list):
raise ValueError(f"`{attr}` must be a list.")
setattr(self, attr, [getattr(self, attr)]*self.layers)
warnings.warn(f"The {attr} is applied to all layers", UserWarning)

if self.batch_norm[layer_idx]:
hidden_channels = self.hidden_channels[layer_idx]
Expand Down
8 changes: 4 additions & 4 deletions neuralpredictors/layers/cores/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
padding=(0, input_kernel[1] // 2, input_kernel[2] // 2) if self.padding else 0,
)

self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[0])
self.add_bn_layer(layer=layer, layer_idx=0)

if layers > 1 or self.final_nonlinearity:
if hidden_nonlinearities == "adaptive_elu":
Expand All @@ -185,7 +185,7 @@ def __init__(
padding=(0, self.hidden_kernel[l][1] // 2, self.hidden_kernel[l][2] // 2) if self.padding else 0,
)

self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[l + 1])
self.add_bn_layer(layer=layer, layer_idx=l+1)

if self.final_nonlinearity or l < self.layers:
if hidden_nonlinearities == "adaptive_elu":
Expand Down Expand Up @@ -363,7 +363,7 @@ def __init__(
dilation=(self.temporal_dilation, 1, 1),
)

self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[0])
self.add_bn_layer(layer=layer, layer_idx=0,)

if layers > 1 or final_nonlin:
if hidden_nonlinearities == "adaptive_elu":
Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(
dilation=(self.hidden_temporal_dilation[l], 1, 1),
)

self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[l + 1])
self.add_bn_layer(layer=layer, layer_idx=l+1)

if final_nonlin or l < self.layers:
if hidden_nonlinearities == "adaptive_elu":
Expand Down

0 comments on commit 0b21f01

Please sign in to comment.