diff --git a/clip/model.py b/clip/model.py index 9bfd1f079..ed4121fe7 100644 --- a/clip/model.py +++ b/clip/model.py @@ -136,15 +136,11 @@ def _make_layer(self, planes, blocks, stride=1): return nn.Sequential(*layers) def forward(self, x): - def stem(x): - x = self.relu1(self.bn1(self.conv1(x))) - x = self.relu2(self.bn2(self.conv2(x))) - x = self.relu3(self.bn3(self.conv3(x))) - x = self.avgpool(x) - return x - x = x.type(self.conv1.weight.dtype) - x = stem(x) + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x)