Skip to content

Commit

Permalink
Fixes #1797 by adding an init_weights keyword argument to Inception3 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
os-gabe authored Jan 30, 2020
1 parent f2600c2 commit 791c172
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
class Inception3(nn.Module):

def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
inception_blocks=None):
inception_blocks=None, init_weights=True):
super(Inception3, self).__init__()
if inception_blocks is None:
inception_blocks = [
Expand Down Expand Up @@ -102,19 +102,19 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
self.Mixed_7b = inception_e(1280)
self.Mixed_7c = inception_e(2048)
self.fc = nn.Linear(2048, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
import scipy.stats as stats
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
X = stats.truncnorm(-2, 2, scale=stddev)
values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
values = values.view(m.weight.size())
with torch.no_grad():
m.weight.copy_(values)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if init_weights:
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
import scipy.stats as stats
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
X = stats.truncnorm(-2, 2, scale=stddev)
values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
values = values.view(m.weight.size())
with torch.no_grad():
m.weight.copy_(values)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def _transform_input(self, x):
if self.transform_input:
Expand Down

0 comments on commit 791c172

Please sign in to comment.