From bea8e36dad9dc75b98fddf7369c1b8eb90607e43 Mon Sep 17 00:00:00 2001 From: heisenbug237 Date: Sun, 14 Aug 2022 17:05:48 +0530 Subject: [PATCH] fixed bugs --- trailmet/models/resnet.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/trailmet/models/resnet.py b/trailmet/models/resnet.py index d18b3b6..ea995b0 100644 --- a/trailmet/models/resnet.py +++ b/trailmet/models/resnet.py @@ -179,16 +179,19 @@ def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=Fal for b in l.children(): downs = next(b.downsample.children()) if b.downsample is not None else None - assert block is Bottleneck prev = self.bn1 - for l_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + l_blocks = [self.layer1, self.layer2, self.layer3] + if block is Bottleneck: + l_blocks.append(self.layer4) + for l_block in l_blocks: for b in l_block: self.prev_module[b.bn1] = prev self.prev_module[b.bn2] = b.bn1 - self.prev_module[b.bn3] = b.bn2 + if block is Bottleneck: + self.prev_module[b.bn3] = b.bn2 if b.downsample is not None: self.prev_module[b.downsample[1]] = prev - prev = b.bn3 + prev = b.bn3 if block is Bottleneck else b.bn2 def _make_layer(self, block, planes, blocks, stride=1): downsample = None