diff --git a/inference/models/grasp_model.py b/inference/models/grasp_model.py index 5b4a021e..078b340d 100644 --- a/inference/models/grasp_model.py +++ b/inference/models/grasp_model.py @@ -53,20 +53,15 @@ class ResidualBlock(nn.Module): A residual block with dropout option """ - def __init__(self, in_channels, out_channels, kernel_size=3, dropout=False, prob=0.0): + def __init__(self, in_channels, out_channels, kernel_size=3): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1) self.bn1 = nn.BatchNorm2d(in_channels) self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1) self.bn2 = nn.BatchNorm2d(in_channels) - self.dropout = dropout - self.dropout1 = nn.Dropout(p=prob) - def forward(self, x_in): x = self.bn1(self.conv1(x_in)) x = F.relu(x) - if self.dropout: - x = self.dropout1(x) x = self.bn2(self.conv2(x)) return x + x_in diff --git a/inference/models/grconvnet2.py b/inference/models/grconvnet2.py index 3e7a3f8f..a08b1d8c 100644 --- a/inference/models/grconvnet2.py +++ b/inference/models/grconvnet2.py @@ -17,11 +17,11 @@ def __init__(self, input_channels=4, output_channels=1, channel_size=32, dropout self.conv3 = nn.Conv2d(channel_size * 2, channel_size * 4, kernel_size=4, stride=2, padding=1) self.bn3 = nn.BatchNorm2d(channel_size * 4) - self.res1 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob) - self.res2 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob) - self.res3 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob) - self.res4 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob) - self.res5 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob) + self.res1 = ResidualBlock(channel_size * 4, channel_size * 4) + self.res2 = ResidualBlock(channel_size * 4, channel_size * 4) + self.res3 = ResidualBlock(channel_size * 4, channel_size * 4) + self.res4 = ResidualBlock(channel_size * 4, channel_size * 4) + self.res5 = ResidualBlock(channel_size * 4, channel_size * 4) self.conv4 = nn.ConvTranspose2d(channel_size * 4, channel_size * 2, kernel_size=4, stride=2, padding=1, output_padding=1)