Skip to content

Commit

Permalink
Merge pull request #4 from CBICA/for_fets
Browse files Browse the repository at this point in the history
  • Loading branch information
Geeks-Sid authored Apr 29, 2021
2 parents 9acfe15 + e582cd7 commit feeb408
Show file tree
Hide file tree
Showing 24 changed files with 1,843 additions and 1,077 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ celerybeat-schedule
.env

# virtualenv
venv/
venv*/
ENV/

# Spyder project settings
Expand Down
110 changes: 59 additions & 51 deletions BrainMaGe/models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@ def __init__(self, n_channels, n_classes, base_filters=16):
self.n_channels = n_channels
self.n_classes = n_classes
self.ins = in_conv(self.n_channels, base_filters)
self.ds_0 = DownsamplingModule(base_filters, base_filters*2)
self.en_1 = EncodingModule(base_filters*2, base_filters*2)
self.ds_1 = DownsamplingModule(base_filters*2, base_filters*4)
self.en_2 = EncodingModule(base_filters*4, base_filters*4)
self.ds_2 = DownsamplingModule(base_filters*4, base_filters*8)
self.en_3 = EncodingModule(base_filters*8, base_filters*8)
self.ds_3 = DownsamplingModule(base_filters*8, base_filters*16)
self.en_4 = EncodingModule(base_filters*16, base_filters*16)
self.us_3 = UpsamplingModule(base_filters*16, base_filters*8)
self.de_3 = DecodingModule(base_filters*16, base_filters*8)
self.us_2 = UpsamplingModule(base_filters*8, base_filters*4)
self.de_2 = DecodingModule(base_filters*8, base_filters*4)
self.us_1 = UpsamplingModule(base_filters*4, base_filters*2)
self.de_1 = DecodingModule(base_filters*4, base_filters*2)
self.us_0 = UpsamplingModule(base_filters*2, 16)
self.out = out_conv(base_filters*2, self.n_classes-1)
self.ds_0 = DownsamplingModule(base_filters, base_filters * 2)
self.en_1 = EncodingModule(base_filters * 2, base_filters * 2)
self.ds_1 = DownsamplingModule(base_filters * 2, base_filters * 4)
self.en_2 = EncodingModule(base_filters * 4, base_filters * 4)
self.ds_2 = DownsamplingModule(base_filters * 4, base_filters * 8)
self.en_3 = EncodingModule(base_filters * 8, base_filters * 8)
self.ds_3 = DownsamplingModule(base_filters * 8, base_filters * 16)
self.en_4 = EncodingModule(base_filters * 16, base_filters * 16)
self.us_3 = UpsamplingModule(base_filters * 16, base_filters * 8)
self.de_3 = DecodingModule(base_filters * 16, base_filters * 8)
self.us_2 = UpsamplingModule(base_filters * 8, base_filters * 4)
self.de_2 = DecodingModule(base_filters * 8, base_filters * 4)
self.us_1 = UpsamplingModule(base_filters * 4, base_filters * 2)
self.de_1 = DecodingModule(base_filters * 4, base_filters * 2)
self.us_0 = UpsamplingModule(base_filters * 2, 16)
self.out = out_conv(base_filters * 2, self.n_classes - 1)

def forward(self, x):
x1 = self.ins(x)
Expand Down Expand Up @@ -64,22 +64,22 @@ def __init__(self, n_channels, n_classes, base_filters=16):
self.n_channels = n_channels
self.n_classes = n_classes
self.ins = in_conv(self.n_channels, base_filters, res=True)
self.ds_0 = DownsamplingModule(base_filters, base_filters*2)
self.en_1 = EncodingModule(base_filters*2, base_filters*2, res=True)
self.ds_1 = DownsamplingModule(base_filters*2, base_filters*4)
self.en_2 = EncodingModule(base_filters*4, base_filters*4, res=True)
self.ds_2 = DownsamplingModule(base_filters*4, base_filters*8)
self.en_3 = EncodingModule(base_filters*8, base_filters*8, res=True)
self.ds_3 = DownsamplingModule(base_filters*8, base_filters*16)
self.en_4 = EncodingModule(base_filters*16, base_filters*16, res=True)
self.us_3 = UpsamplingModule(base_filters*16, base_filters*8)
self.de_3 = DecodingModule(base_filters*16, base_filters*8, res=True)
self.us_2 = UpsamplingModule(base_filters*8, base_filters*4)
self.de_2 = DecodingModule(base_filters*8, base_filters*4, res=True)
self.us_1 = UpsamplingModule(base_filters*4, base_filters*2)
self.de_1 = DecodingModule(base_filters*4, base_filters*2, res=True)
self.us_0 = UpsamplingModule(base_filters*2, base_filters)
self.out = out_conv(base_filters*2, self.n_classes-1, res=True)
self.ds_0 = DownsamplingModule(base_filters, base_filters * 2)
self.en_1 = EncodingModule(base_filters * 2, base_filters * 2, res=True)
self.ds_1 = DownsamplingModule(base_filters * 2, base_filters * 4)
self.en_2 = EncodingModule(base_filters * 4, base_filters * 4, res=True)
self.ds_2 = DownsamplingModule(base_filters * 4, base_filters * 8)
self.en_3 = EncodingModule(base_filters * 8, base_filters * 8, res=True)
self.ds_3 = DownsamplingModule(base_filters * 8, base_filters * 16)
self.en_4 = EncodingModule(base_filters * 16, base_filters * 16, res=True)
self.us_3 = UpsamplingModule(base_filters * 16, base_filters * 8)
self.de_3 = DecodingModule(base_filters * 16, base_filters * 8, res=True)
self.us_2 = UpsamplingModule(base_filters * 8, base_filters * 4)
self.de_2 = DecodingModule(base_filters * 8, base_filters * 4, res=True)
self.us_1 = UpsamplingModule(base_filters * 4, base_filters * 2)
self.de_1 = DecodingModule(base_filters * 4, base_filters * 2, res=True)
self.us_0 = UpsamplingModule(base_filters * 2, base_filters)
self.out = out_conv(base_filters * 2, self.n_classes - 1, res=True)

def forward(self, x):
x1 = self.ins(x)
Expand Down Expand Up @@ -109,21 +109,27 @@ def __init__(self, n_channels, n_classes, base_filters=16):
self.n_channels = n_channels
self.n_classes = n_classes
self.ins = in_conv(self.n_channels, base_filters)
self.ds_0 = DownsamplingModule(base_filters, base_filters*2)
self.en_1 = EncodingModule(base_filters*2, base_filters*2)
self.ds_1 = DownsamplingModule(base_filters*2, base_filters*4)
self.en_2 = EncodingModule(base_filters*4, base_filters*4)
self.ds_2 = DownsamplingModule(base_filters*4, base_filters*8)
self.en_3 = EncodingModule(base_filters*8, base_filters*8)
self.ds_3 = DownsamplingModule(base_filters*8, base_filters*16)
self.en_4 = EncodingModule(base_filters*16, base_filters*16)
self.us_4 = FCNUpsamplingModule(base_filters*16, 1, scale_factor=5)
self.us_3 = FCNUpsamplingModule(base_filters*8, 1, scale_factor=4)
self.us_2 = FCNUpsamplingModule(base_filters*4, 1, scale_factor=3)
self.us_1 = FCNUpsamplingModule(base_filters*2, 1, scale_factor=2)
self.ds_0 = DownsamplingModule(base_filters, base_filters * 2)
self.en_1 = EncodingModule(base_filters * 2, base_filters * 2)
self.ds_1 = DownsamplingModule(base_filters * 2, base_filters * 4)
self.en_2 = EncodingModule(base_filters * 4, base_filters * 4)
self.ds_2 = DownsamplingModule(base_filters * 4, base_filters * 8)
self.en_3 = EncodingModule(base_filters * 8, base_filters * 8)
self.ds_3 = DownsamplingModule(base_filters * 8, base_filters * 16)
self.en_4 = EncodingModule(base_filters * 16, base_filters * 16)
self.us_4 = FCNUpsamplingModule(base_filters * 16, 1, scale_factor=5)
self.us_3 = FCNUpsamplingModule(base_filters * 8, 1, scale_factor=4)
self.us_2 = FCNUpsamplingModule(base_filters * 4, 1, scale_factor=3)
self.us_1 = FCNUpsamplingModule(base_filters * 2, 1, scale_factor=2)
self.us_0 = FCNUpsamplingModule(base_filters, 1, scale_factor=1)
self.conv_0 = nn.Conv3d(in_channels=5, out_channels=self.n_classes-1,
kernel_size=1, stride=1, padding=0, bias=True)
self.conv_0 = nn.Conv3d(
in_channels=5,
out_channels=self.n_classes - 1,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)

def forward(self, x):
x1 = self.ins(x)
Expand All @@ -147,13 +153,15 @@ def forward(self, x):


def fetch_model(modelname, num_channels, num_classes, num_filters):
if modelname == 'resunet':
if modelname == "resunet":
model = resunet(num_channels, num_classes, num_filters)
elif modelname == 'unet':
elif modelname == "unet":
model = resunet(num_channels, num_classes, num_filters)
elif modelname == 'fcn':
elif modelname == "fcn":
model = fcn(num_channels, num_classes, num_filters)
else:
raise ValueError('Check Model spelling, should be one of resunet, unet, fcn in the config'+\
'file!')
raise ValueError(
"Check Model spelling, should be one of resunet, unet, fcn in the config"
+ "file!"
)
return model
Loading

0 comments on commit feeb408

Please sign in to comment.