Commit e77ccac

add ConvCore for convolutional models only
MaxFBurg committed Mar 7, 2024
1 parent a385198 commit e77ccac
Showing 3 changed files with 33 additions and 31 deletions.
56 changes: 29 additions & 27 deletions neuralpredictors/layers/cores/base.py
@@ -9,10 +9,6 @@ class Core(ABC):
     Base class for the core models, taking 2d inputs and computing nonlinear features.
     """

-    def __init__(self) -> None:
-        super().__init__()
-        self.set_batchnorm_type()
-
     def initialize(self):
         """
         Initialization applied on the core.
@@ -34,6 +30,35 @@ def init_conv(m):
             if m.bias is not None:
                 m.bias.data.fill_(0)

+    @abstractmethod
+    def regularizer(self):
+        """
+        Regularization applied on the core. Returns a scalar value.
+        """
+
+    @abstractmethod
+    def forward(self, x):
+        """
+        Forward function for pytorch nn module.
+        Args:
+            x (torch.tensor): input of shape (batch, channels, height, width)
+        """
+
+    def __repr__(self):
+        s = super().__repr__()
+        s += f" [{self.__class__.__name__} regularizers: "
+        ret = []
+        for attr in filter(lambda x: "gamma" in x or "skip" in x, dir(self)):
+            ret.append(f"{attr} = {getattr(self, attr)}")
+        return s + "|".join(ret) + "]\n"
+
+
+class ConvCore(Core):
+    def __init__(self) -> None:
+        super().__init__()
+        self.set_batchnorm_type()
+
     @abstractmethod
     def set_batchnorm_type(self):
         """
@@ -63,26 +88,3 @@ def add_bn_layer(self, layer: OrderedDict, layer_idx: int):
+            layer["bias"] = self.bias_layer_cls(hidden_channels)
+        elif not bias and scale:
+            layer["scale"] = self.scale_layer_cls(hidden_channels)
-    @abstractmethod
-    def regularizer(self):
-        """
-        Regularization applied on the core. Returns a scalar value.
-        """
-
-    @abstractmethod
-    def forward(self, x):
-        """
-        Forward function for pytorch nn module.
-        Args:
-            x (torch.tensor): input of shape (batch, channels, height, width)
-        """
-
-    def __repr__(self):
-        s = super().__repr__()
-        s += f" [{self.__class__.__name__} regularizers: "
-        ret = []
-        for attr in filter(lambda x: "gamma" in x or "skip" in x, dir(self)):
-            ret.append(f"{attr} = {getattr(self, attr)}")
-        return s + "|".join(ret) + "]\n"

(Codecov: added lines 88 and 90 in base.py are not covered by tests.)
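To make the new split concrete, here is a minimal, hypothetical subclass (`ToyConvCore` is illustrative and not part of this commit): `Core` keeps the abstract `regularizer` and `forward` plus the `__repr__` that lists `gamma*`/`skip*` attributes, while `ConvCore.__init__` additionally calls the abstract `set_batchnorm_type`, so only convolutional cores must supply batchnorm/bias/scale layer classes.

```python
# Hypothetical sketch, not from the commit: the smallest ConvCore subclass
# that satisfies all abstract methods kept or introduced by this refactor.
import torch
from torch import nn

from neuralpredictors.layers.cores.base import ConvCore


class ToyConvCore(ConvCore, nn.Module):
    def __init__(self, in_channels=1, hidden_channels=8, gamma_input=0.1):
        # Chains through ConvCore.__init__, which calls set_batchnorm_type().
        super().__init__()
        self.gamma_input = gamma_input  # "gamma" attrs are listed by Core.__repr__
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            self.batchnorm_layer_cls(hidden_channels),
            nn.ELU(),
        )

    def set_batchnorm_type(self):
        # The hook ConvCore makes abstract; nn.Identity is a placeholder where
        # the library's cores use their own bias/scale layer classes.
        self.batchnorm_layer_cls = nn.BatchNorm2d
        self.bias_layer_cls = nn.Identity
        self.scale_layer_cls = nn.Identity

    def forward(self, x):
        return self.features(x)

    def regularizer(self):
        # Returns a scalar, as Core.regularizer's docstring requires.
        return self.gamma_input * sum(p.abs().sum() for p in self.parameters())


core = ToyConvCore()
print(core(torch.zeros(4, 1, 16, 16)).shape)  # torch.Size([4, 8, 16, 16])
print(core)  # repr ends with "[ToyConvCore regularizers: gamma_input = 0.1]"
```

A fully-connected or attention-based core can now subclass `Core` directly and skip the batchnorm machinery entirely, which is the point of the split named in the commit message.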
4 changes: 2 additions & 2 deletions neuralpredictors/layers/cores/conv2d.py
@@ -27,12 +27,12 @@
     RotationEquivariantScale2DLayer,
 )
 from ..squeeze_excitation import SqueezeExcitationBlock
-from .base import Core
+from .base import ConvCore, Core

 logger = logging.getLogger(__name__)


-class Stacked2dCore(Core, nn.Module):
+class Stacked2dCore(ConvCore, nn.Module):
"""
An instantiation of the Core base class. Made up of layers layers of nn.sequential modules.
Allows for the flexible implementations of many different architectures, such as convolutional layers,
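Since `Stacked2dCore` now mixes `ConvCore` with `nn.Module`, the base-class order matters: `ConvCore` (and `Core`) must precede `nn.Module` in the MRO so that the `super().__init__()` call inside `ConvCore.__init__` reaches `nn.Module.__init__` before `set_batchnorm_type` assigns any attributes. A quick sanity check (expected output under these class definitions, not recorded from the repo):

```python
from neuralpredictors.layers.cores.conv2d import Stacked2dCore

# Cooperative super() calls walk this order, so nn.Module is initialized
# via ConvCore.__init__ before any attributes are set on the instance.
print([cls.__name__ for cls in Stacked2dCore.__mro__])
# expected: ['Stacked2dCore', 'ConvCore', 'Core', 'ABC', 'Module', 'object']
```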
4 changes: 2 additions & 2 deletions neuralpredictors/layers/cores/conv3d.py
@@ -11,10 +11,10 @@
 from ...regularizers import DepthLaplaceL21d
 from ..affine import Bias3DLayer, Scale3DLayer
-from .base import Core
+from .base import ConvCore


-class Core3d(Core):
+class Core3d(ConvCore):
     def initialize(self, cuda=False):
         self.apply(self.init_conv)
         self.put_to_cuda(cuda=cuda)

(Codecov: added lines 14 and 17 in conv3d.py are not covered by tests.)
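`Core3d` only swaps its base class here. The diff does not show how it satisfies `ConvCore`'s abstract hook, but the `Bias3DLayer`/`Scale3DLayer` imports above suggest roughly the following sketch; treat the `set_batchnorm_type` body as an assumption, while `initialize` is verbatim from the hunk:

```python
import torch.nn as nn

from neuralpredictors.layers.affine import Bias3DLayer, Scale3DLayer
from neuralpredictors.layers.cores.base import ConvCore


class Core3d(ConvCore):
    def set_batchnorm_type(self):
        # Assumed body: the 3d variants of the batchnorm/bias/scale classes.
        self.batchnorm_layer_cls = nn.BatchNorm3d
        self.bias_layer_cls = Bias3DLayer
        self.scale_layer_cls = Scale3DLayer

    def initialize(self, cuda=False):
        # Verbatim from the diff: initialize all conv layers, then optionally
        # move the module to the GPU via the base-class helper put_to_cuda.
        self.apply(self.init_conv)
        self.put_to_cuda(cuda=cuda)
```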
