diff --git a/network/segmentator_3d_asymm_spconv.py b/network/segmentator_3d_asymm_spconv.py index 08c317c..9cfa559 100644 --- a/network/segmentator_3d_asymm_spconv.py +++ b/network/segmentator_3d_asymm_spconv.py @@ -3,7 +3,8 @@ # @file: segmentator_3d_asymm_spconv.py import numpy as np -import spconv +#import spconv +import spconv.pytorch as spconv import torch from torch import nn @@ -49,8 +50,10 @@ def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, ind self.conv1 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef") self.bn0 = nn.BatchNorm1d(out_filters) self.act1 = nn.LeakyReLU() + + #elf.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") + self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") - self.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") self.bn0_2 = nn.BatchNorm1d(out_filters) self.act1_2 = nn.LeakyReLU() @@ -58,7 +61,8 @@ def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, ind self.act2 = nn.LeakyReLU() self.bn1 = nn.BatchNorm1d(out_filters) - self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") + #self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") + self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") self.act3 = nn.LeakyReLU() self.bn2 = nn.BatchNorm1d(out_filters) @@ -72,21 +76,21 @@ def weight_initialization(self): def forward(self, x): shortcut = self.conv1(x) - shortcut.features = self.act1(shortcut.features) - shortcut.features = self.bn0(shortcut.features) + shortcut = shortcut.replace_feature(self.act1(shortcut.features)) + shortcut = shortcut.replace_feature(self.bn0(shortcut.features)) shortcut = self.conv1_2(shortcut) - shortcut.features = self.act1_2(shortcut.features) - shortcut.features = self.bn0_2(shortcut.features) + shortcut = shortcut.replace_feature(self.act1_2(shortcut.features)) + shortcut = shortcut.replace_feature(self.bn0_2(shortcut.features)) resA = self.conv2(x) - resA.features = self.act2(resA.features) - resA.features = self.bn1(resA.features) + resA = resA.replace_feature(self.act2(resA.features)) + reaA = resA.replace_feature(self.bn1(resA.features)) resA = self.conv3(resA) - resA.features = self.act3(resA.features) - resA.features = self.bn2(resA.features) - resA.features = resA.features + shortcut.features + resA = resA.replace_feature(self.act3(resA.features)) + resA = resA.replace_feature(self.bn2(resA.features)) + resA = resA.replace_feature(resA.features + shortcut.features) return resA @@ -102,7 +106,8 @@ def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), self.act1 = nn.LeakyReLU() self.bn0 = nn.BatchNorm1d(out_filters) - self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") + #self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") + self.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") self.act1_2 = nn.LeakyReLU() self.bn0_2 = nn.BatchNorm1d(out_filters) @@ -110,7 +115,8 @@ def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), self.act2 = nn.LeakyReLU() self.bn1 = nn.BatchNorm1d(out_filters) - self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") + #self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") + self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") self.act3 = nn.LeakyReLU() self.bn2 = nn.BatchNorm1d(out_filters) @@ -131,22 +137,22 @@ def weight_initialization(self): def forward(self, x): shortcut = self.conv1(x) - shortcut.features = self.act1(shortcut.features) - shortcut.features = self.bn0(shortcut.features) + shortcut = shortcut.replace_feature(self.act1(shortcut.features)) + shortcut = shortcut.replace_feature(self.bn0(shortcut.features)) shortcut = self.conv1_2(shortcut) - shortcut.features = self.act1_2(shortcut.features) - shortcut.features = self.bn0_2(shortcut.features) + shortcut = shortcut.replace_feature(self.act1_2(shortcut.features)) + shortcut = shortcut.replace_feature(self.bn0_2(shortcut.features)) resA = self.conv2(x) - resA.features = self.act2(resA.features) - resA.features = self.bn1(resA.features) + resA = resA.replace_feature(self.act2(resA.features)) + resA = resA.replace_feature(self.bn1(resA.features)) resA = self.conv3(resA) - resA.features = self.act3(resA.features) - resA.features = self.bn2(resA.features) + resA = resA.replace_feature(self.act3(resA.features)) + resA = resA.replace_feature(self.bn2(resA.features)) - resA.features = resA.features + shortcut.features + resA = resA.replace_feature(resA.features + shortcut.features) if self.pooling: resB = self.pool(resA) @@ -167,11 +173,13 @@ def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), indice_key=No self.act1 = nn.LeakyReLU() self.bn1 = nn.BatchNorm1d(out_filters) - self.conv2 = conv3x1(out_filters, out_filters, indice_key=indice_key) + #self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") + self.conv2 = conv1x3(out_filters, out_filters, indice_key=indice_key) self.act2 = nn.LeakyReLU() self.bn2 = nn.BatchNorm1d(out_filters) - self.conv3 = conv3x3(out_filters, out_filters, indice_key=indice_key) + #self.conv3 = conv3x3(out_filters, out_filters, indice_key=indice_key) + self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key) self.act3 = nn.LeakyReLU() self.bn3 = nn.BatchNorm1d(out_filters) # self.dropout3 = nn.Dropout3d(p=dropout_rate) @@ -189,25 +197,25 @@ def weight_initialization(self): def forward(self, x, skip): upA = self.trans_dilao(x) - upA.features = self.trans_act(upA.features) - upA.features = self.trans_bn(upA.features) + upA = upA.replace_feature(self.trans_act(upA.features)) + upA = upA.replace_feature(self.trans_bn(upA.features)) ## upsample upA = self.up_subm(upA) - upA.features = upA.features + skip.features + upA = upA.replace_feature(upA.features + skip.features) upE = self.conv1(upA) - upE.features = self.act1(upE.features) - upE.features = self.bn1(upE.features) + upE = upE.replace_feature(self.act1(upE.features)) + upE = upE.replace_feature(self.bn1(upE.features)) upE = self.conv2(upE) - upE.features = self.act2(upE.features) - upE.features = self.bn2(upE.features) + upE = upE.replace_feature(self.act2(upE.features)) + upE = upE.replace_feature(self.bn2(upE.features)) upE = self.conv3(upE) - upE.features = self.act3(upE.features) - upE.features = self.bn3(upE.features) + upE = upE.replace_feature(self.act3(upE.features)) + upE = upE.replace_feature(self.bn3(upE.features)) return upE @@ -229,19 +237,19 @@ def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, ind def forward(self, x): shortcut = self.conv1(x) - shortcut.features = self.bn0(shortcut.features) - shortcut.features = self.act1(shortcut.features) + shortcut = shortcut.replace_feature(self.bn0(shortcut.features)) + shortcut = shortcut.replace_feature(self.act1(shortcut.features)) shortcut2 = self.conv1_2(x) - shortcut2.features = self.bn0_2(shortcut2.features) - shortcut2.features = self.act1_2(shortcut2.features) + shortcut2 = shortcut2.replace_feature(self.bn0_2(shortcut2.features)) + shortcut2 = shortcut2.replace_feature(self.act1_2(shortcut2.features)) shortcut3 = self.conv1_3(x) - shortcut3.features = self.bn0_3(shortcut3.features) - shortcut3.features = self.act1_3(shortcut3.features) - shortcut.features = shortcut.features + shortcut2.features + shortcut3.features + shortcut3 = shortcut.replace_feature(self.bn0_3(shortcut3.features)) + shortcut3 = shortcut3.replace_feature(self.act1_3(shortcut3.features)) + shortcut = shortcut.replace_feature(shortcut.features + shortcut2.features + shortcut3.features) - shortcut.features = shortcut.features * x.features + shortcut = shortcut.replace_feature(shortcut.features * x.features) return shortcut @@ -300,7 +308,7 @@ def forward(self, voxel_features, coors, batch_size): up0e = self.ReconNet(up1e) - up0e.features = torch.cat((up0e.features, up1e.features), 1) + up0e = up0e.replace_feature(torch.cat((up0e.features, up1e.features), 1)) logits = self.logits(up0e) y = logits.dense()