From 67c7c39325126fe2a88c72ddffb94422847bf643 Mon Sep 17 00:00:00 2001 From: sjtudyq <48618508+sjtudyq@users.noreply.github.com> Date: Tue, 25 Jul 2023 15:31:36 +0800 Subject: [PATCH] Update model.py --- model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/model.py b/model.py index 57aae24..b56af2b 100644 --- a/model.py +++ b/model.py @@ -626,15 +626,15 @@ class ModelFedCon_noheader(nn.Module): def __init__(self, base_model, out_dim, n_classes, net_configs=None): super(ModelFedCon_noheader, self).__init__() - if base_model == "resnet50" or base_model == "resnet": - basemodel = models.resnet50(pretrained=False) - self.features = nn.Sequential(*list(basemodel.children())[:-1]) - num_ftrs = basemodel.fc.in_features - elif base_model == "resnet18": + # if base_model == "resnet": + # basemodel = models.resnet50(pretrained=False) + # self.features = nn.Sequential(*list(basemodel.children())[:-1]) + # num_ftrs = basemodel.fc.in_features + if base_model == "resnet18": basemodel = models.resnet18(pretrained=False) self.features = nn.Sequential(*list(basemodel.children())[:-1]) num_ftrs = basemodel.fc.in_features - elif base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel": + elif base_model == "resnet" or base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel": basemodel = ResNet50_cifar10() self.features = nn.Sequential(*list(basemodel.children())[:-1]) num_ftrs = basemodel.fc.in_features