diff --git a/models/res16unet.py b/models/res16unet.py index 765fab2..8d672ac 100644 --- a/models/res16unet.py +++ b/models/res16unet.py @@ -224,7 +224,7 @@ def forward(self, x): out = self.bntr4(out) out = self.relu(out) - out = me.cat((out, out_b3p8)) + out = me.cat(out, out_b3p8) out = self.block5(out) # pixel_dist=4 @@ -232,7 +232,7 @@ def forward(self, x): out = self.bntr5(out) out = self.relu(out) - out = me.cat((out, out_b2p4)) + out = me.cat(out, out_b2p4) out = self.block6(out) # pixel_dist=2 @@ -240,7 +240,7 @@ def forward(self, x): out = self.bntr6(out) out = self.relu(out) - out = me.cat((out, out_b1p2)) + out = me.cat(out, out_b1p2) out = self.block7(out) # pixel_dist=1 @@ -248,7 +248,7 @@ def forward(self, x): out = self.bntr7(out) out = self.relu(out) - out = me.cat((out, out_p1)) + out = me.cat(out, out_p1) out = self.block8(out) return self.final(out) diff --git a/models/resunet.py b/models/resunet.py index 6efdef0..6861d72 100644 --- a/models/resunet.py +++ b/models/resunet.py @@ -197,21 +197,21 @@ def forward(self, x): out = self.bntr4(out) out = self.relu(out) - out = me.cat((out, out_b3p4)) + out = me.cat(out, out_b3p4) out = self.block5(out) out = self.convtr5p4s2(out) out = self.bntr5(out) out = self.relu(out) - out = me.cat((out, out_b2p2)) + out = me.cat(out, out_b2p2) out = self.block6(out) out = self.convtr6p2s2(out) out = self.bntr6(out) out = self.relu(out) - out = me.cat((out, out_b1p1)) + out = me.cat(out, out_b1p1) return self.final(out) @@ -460,7 +460,7 @@ def forward(self, x): out = self.bntr4(out) out = self.relu(out) - out = me.cat((out, out_b3p4)) + out = me.cat(out, out_b3p4) out = self.block5(out) out_5 = self.pool_tr5(out) @@ -468,7 +468,7 @@ def forward(self, x): out = self.bntr5(out) out = self.relu(out) - out = me.cat((out, out_b2p2)) + out = me.cat(out, out_b2p2) out = self.block6(out) out_6 = self.pool_tr6(out) @@ -476,7 +476,7 @@ def forward(self, x): out = self.bntr6(out) out = self.relu(out) - out = me.cat((out, out_b1p1, out_6, out_5)) + out = me.cat(out, out_b1p1, out_6, out_5) return self.final(out)