Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sjtudyq authored Jul 25, 2023
1 parent 66270c6 commit 67c7c39
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,15 +626,15 @@ class ModelFedCon_noheader(nn.Module):
def __init__(self, base_model, out_dim, n_classes, net_configs=None):
super(ModelFedCon_noheader, self).__init__()

if base_model == "resnet50" or base_model == "resnet":
basemodel = models.resnet50(pretrained=False)
self.features = nn.Sequential(*list(basemodel.children())[:-1])
num_ftrs = basemodel.fc.in_features
elif base_model == "resnet18":
# if base_model == "resnet":
# basemodel = models.resnet50(pretrained=False)
# self.features = nn.Sequential(*list(basemodel.children())[:-1])
# num_ftrs = basemodel.fc.in_features
if base_model == "resnet18":
basemodel = models.resnet18(pretrained=False)
self.features = nn.Sequential(*list(basemodel.children())[:-1])
num_ftrs = basemodel.fc.in_features
elif base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel":
elif base_model == "resnet" or base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel":
basemodel = ResNet50_cifar10()
self.features = nn.Sequential(*list(basemodel.children())[:-1])
num_ftrs = basemodel.fc.in_features
Expand Down

0 comments on commit 67c7c39

Please sign in to comment.