Skip to content

Commit

Permalink
Merge pull request #164 from IBM/improve/cnn_autogen
Browse files Browse the repository at this point in the history
Special decorator for dealing with different CNN types
  • Loading branch information
Joao-L-S-Almeida authored Sep 27, 2023
2 parents a549beb + 52bf903 commit d3f9404
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions simulai/regression/_pytorch/_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@

from simulai.templates import ConvNetworkTemplate, as_tensor


def channels_dim(method):
def inside(self, input_data=None):
# Any input must be at least 3D to allow the creation
# of a 'channels dim'. Otherwise use a Unflatten operation
# at the bottom of the model.
if 3 <= len(input_data.shape) < self.n_dimensions:
return method(self, input_data=input_data[:, None, ...])
if self.case == '1d':
if len(input_data.shape) < self.n_dimensions:
return method(self, input_data=input_data[:, None, ...])
else:
return method(self, input_data=input_data)
else:
return method(self, input_data=input_data)

# Any input must be at least 3D to allow the creation
# of a 'channels dim'. Otherwise use a Unflatten operation
# at the bottom of the model.
if 3 <= len(input_data.shape) < self.n_dimensions:
return method(self, input_data=input_data[:, None, ...])
else:
return method(self, input_data=input_data)
return inside


# High-level class for assembling different kinds of convolutional networks
class ConvolutionalNetwork(ConvNetworkTemplate):
name = "conv"
Expand Down Expand Up @@ -125,6 +128,12 @@ def __init__(
) -> None:
super(ResConvolutionalNetwork, self).__init__(name=name)

if case == '1d':
self.channels_dim = channels_dim
else:
self.channels_dim = channels_dim_higher


self.args = ["in_channels", "out_channels", "kernel_size"]

# The operation coming in the sequence of each convolution layer can be
Expand Down

0 comments on commit d3f9404

Please sign in to comment.