From 111974b14908b45aea1868f35bc0574f55e0c176 Mon Sep 17 00:00:00 2001
From: QinbinLi
Date: Wed, 22 Jun 2022 16:37:21 +0800
Subject: [PATCH] add moon

---
 experiments.py | 529 ++++++++++++++++++++++++++++++++++++-------------
 model.py       | 191 ++++++++++++++++++
 2 files changed, 585 insertions(+), 135 deletions(-)

diff --git a/experiments.py b/experiments.py
index 2a20136..c21eef7 100644
--- a/experiments.py
+++ b/experiments.py
@@ -33,7 +33,8 @@ def get_args():
     parser.add_argument('--epochs', type=int, default=5, help='number of local epochs')
     parser.add_argument('--n_parties', type=int, default=2, help='number of workers in a distributed cluster')
     parser.add_argument('--alg', type=str, default='fedavg',
-                        help='communication strategy: fedavg/fedprox')
+                        help='fl algorithms: fedavg/fedprox/scaffold/fednova/moon')
+    parser.add_argument('--use_projection_head', type=bool, default=False, help='whether to add an additional projection head to the model (see MOON)')
     parser.add_argument('--comm_round', type=int, default=50, help='maximum number of communication rounds')
     parser.add_argument('--is_same_initial', type=int, default=1, help='whether to initialize all models with the same parameters in fedavg')
     parser.add_argument('--init_seed', type=int, default=0, help="Random seed")
@@ -58,59 +59,80 @@ def init_nets(net_configs, dropout_p, n_parties, args):
     nets = {net_i: None for net_i in range(n_parties)}
 
-    for net_i in range(n_parties):
-        if args.dataset == "generated":
-            net = PerceptronModel()
-        elif args.model == "mlp":
-            if args.dataset == 'covtype':
-                input_size = 54
-                output_size = 2
-                hidden_sizes = [32,16,8]
-            elif args.dataset == 'a9a':
-                input_size = 123
-                output_size = 2
-                hidden_sizes = [32,16,8]
-            elif args.dataset == 'rcv1':
-                input_size = 47236
-                output_size = 2
-                hidden_sizes = [32,16,8]
-            elif args.dataset == 'SUSY':
-                input_size = 18
-                output_size = 2
-                hidden_sizes = [16,8]
-            net = FcNet(input_size, hidden_sizes, output_size, dropout_p)
-        elif args.model == "vgg":
-            net = vgg11()
-        elif args.model == "simple-cnn":
-            if args.dataset in ("cifar10", "cinic10", "svhn"):
-                net = SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=10)
-            elif args.dataset in ("mnist", 'femnist', 'fmnist'):
-                net = SimpleCNNMNIST(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=10)
-            elif args.dataset == 'celeba':
-                net = SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=2)
-        elif args.model == "vgg-9":
-            if args.dataset in ("mnist", 'femnist'):
-                net = ModerateCNNMNIST()
-            elif args.dataset in ("cifar10", "cinic10", "svhn"):
-                # print("in moderate cnn")
-                net = ModerateCNN()
-            elif args.dataset == 'celeba':
-                net = ModerateCNN(output_dim=2)
-        elif args.model == "resnet":
-            net = ResNet50_cifar10()
-        elif args.model == "vgg16":
-            net = vgg16()
-        else:
-            print("not supported yet")
-            exit(1)
-        nets[net_i] = net
-
-    model_meta_data = []
-    layer_type = []
-    for (k, v) in nets[0].state_dict().items():
-        model_meta_data.append(v.shape)
-        layer_type.append(k)
-
+    if args.dataset in {'mnist', 'cifar10', 'svhn', 'fmnist'}:
+        n_classes = 10
+    elif args.dataset == 'celeba':
+        n_classes = 2
+    elif args.dataset == 'cifar100':
+        n_classes = 100
+    elif args.dataset == 'tinyimagenet':
+        n_classes = 200
+    elif args.dataset == 'femnist':
+        n_classes = 62
+    elif args.dataset == 'emnist':
+        n_classes = 47
+    if args.use_projection_head:
+        for net_i in range(n_parties):
+            net = ModelFedCon(args.model, args.out_dim, n_classes, net_configs)
+            nets[net_i] = net
+    else:
+        if args.alg == 'moon':
+            for net_i in range(n_parties):
+                net = ModelFedCon_noheader(args.model, args.out_dim, n_classes, net_configs)
+                nets[net_i] = net
+        else:
+            for net_i in range(n_parties):
+                if args.dataset == "generated":
+                    net = PerceptronModel()
+                elif args.model == "mlp":
+                    if args.dataset == 'covtype':
+                        input_size = 54
+                        output_size = 2
+                        hidden_sizes = [32,16,8]
+                    elif args.dataset == 'a9a':
+                        input_size = 123
+                        output_size = 2
+                        hidden_sizes = [32,16,8]
+                    elif args.dataset == 'rcv1':
+                        input_size = 47236
+                        output_size = 2
+                        hidden_sizes = [32,16,8]
+                    elif args.dataset == 'SUSY':
+                        input_size = 18
+                        output_size = 2
+                        hidden_sizes = [16,8]
+                    net = FcNet(input_size, hidden_sizes, output_size, dropout_p)
+                elif args.model == "vgg":
+                    net = vgg11()
+                elif args.model == "simple-cnn":
+                    if args.dataset in ("cifar10", "cinic10", "svhn"):
+                        net = SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=10)
+                    elif args.dataset in ("mnist", 'femnist', 'fmnist'):
+                        net = SimpleCNNMNIST(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=10)
+                    elif args.dataset == 'celeba':
+                        net = SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=2)
+                elif args.model == "vgg-9":
+                    if args.dataset in ("mnist", 'femnist'):
+                        net = ModerateCNNMNIST()
+                    elif args.dataset in ("cifar10", "cinic10", "svhn"):
+                        # print("in moderate cnn")
+                        net = ModerateCNN()
+                    elif args.dataset == 'celeba':
+                        net = ModerateCNN(output_dim=2)
+                elif args.model == "resnet":
+                    net = ResNet50_cifar10()
+                elif args.model == "vgg16":
+                    net = vgg16()
+                else:
+                    print("not supported yet")
+                    exit(1)
+                nets[net_i] = net
+
+    model_meta_data = []
+    layer_type = []
+    for (k, v) in nets[0].state_dict().items():
+        model_meta_data.append(v.shape)
+        layer_type.append(k)
     return nets, model_meta_data, layer_type
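A note on the `--use_projection_head` flag added above: argparse applies `type=bool` to the raw string, and `bool('False')` is `True`, so the flag cannot actually be switched off from the command line. A minimal sketch of the usual workaround; the `str2bool` converter is illustrative and not part of this patch:

    import argparse

    def str2bool(v):
        # Map common spellings to booleans; reject anything else.
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected')

    parser = argparse.ArgumentParser()
    parser.add_argument('--use_projection_head', type=str2bool, default=False)
    print(parser.parse_args(['--use_projection_head', 'False']).use_projection_head)  # False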
@@ -261,88 +283,6 @@ def train_net_fedprox(net_id, net, global_net, train_dataloader, test_dataloader
     logger.info(' ** Training complete **')
     return train_acc, test_acc
 
-def view_image(train_dataloader):
-    for (x, target) in train_dataloader:
-        np.save("img.npy", x)
-        print(x.shape)
-        exit(0)
-
-
-def local_train_net(nets, selected, args, net_dataidx_map, test_dl = None, device="cpu"):
-    avg_acc = 0.0
-
-    for net_id, net in nets.items():
-        if net_id not in selected:
-            continue
-        dataidxs = net_dataidx_map[net_id]
-
-        logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
-        # move the model to cuda device:
-        net.to(device)
-
-        noise_level = args.noise
-        if net_id == args.n_parties - 1:
-            noise_level = 0
-
-        if args.noise_type == 'space':
-            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
-        else:
-            noise_level = args.noise / (args.n_parties - 1) * net_id
-            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
-        train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
-        n_epoch = args.epochs
-
-
-        trainacc, testacc = train_net(net_id, net, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, device=device)
-        logger.info("net %d final test acc %f" % (net_id, testacc))
-        avg_acc += testacc
-        # saving the trained models here
-        # save_model(net, net_id, args)
-        # else:
-        #     load_model(net, net_id, device=device)
-    avg_acc /= len(selected)
-    if args.alg == 'local_training':
-        logger.info("avg test acc %f" % avg_acc)
-
-    nets_list = list(nets.values())
-    return nets_list
-
-
-def local_train_net_fedprox(nets, selected, global_model, args, net_dataidx_map, test_dl = None, device="cpu"):
-    avg_acc = 0.0
-
-    for net_id, net in nets.items():
-        if net_id not in selected:
-            continue
-        dataidxs = net_dataidx_map[net_id]
-
-        logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
-        # move the model to cuda device:
-        net.to(device)
-
-        noise_level = args.noise
-        if net_id == args.n_parties - 1:
-            noise_level = 0
-
-        if args.noise_type == 'space':
-            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
-        else:
-            noise_level = args.noise / (args.n_parties - 1) * net_id
-            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
-        train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
-        n_epoch = args.epochs
-
-        trainacc, testacc = train_net_fedprox(net_id, net, global_model, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, args.mu, device=device)
-        logger.info("net %d final test acc %f" % (net_id, testacc))
-        avg_acc += testacc
-    avg_acc /= len(selected)
-    if args.alg == 'local_training':
-        logger.info("avg test acc %f" % avg_acc)
-
-    nets_list = list(nets.values())
-    return nets_list
-
-
 def train_net_scaffold(net_id, net, global_model, c_local, c_global, train_dataloader, test_dataloader, epochs, lr, args_optimizer, device="cpu"):
 
     logger.info('Training network %s' % str(net_id))
n_training: %d" % (str(net_id), len(dataidxs))) - # move the model to cuda device: - net.to(device) - - noise_level = args.noise - if net_id == args.n_parties - 1: - noise_level = 0 - - if args.noise_type == 'space': - train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1) - else: - noise_level = args.noise / (args.n_parties - 1) * net_id - train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level) - train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32) - n_epoch = args.epochs - - - trainacc, testacc = train_net(net_id, net, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, device=device) - logger.info("net %d final test acc %f" % (net_id, testacc)) - avg_acc += testacc - # saving the trained models here - # save_model(net, net_id, args) - # else: - # load_model(net, net_id, device=device) - avg_acc /= len(selected) - if args.alg == 'local_training': - logger.info("avg test acc %f" % avg_acc) - - nets_list = list(nets.values()) - return nets_list - - -def local_train_net_fedprox(nets, selected, global_model, args, net_dataidx_map, test_dl = None, device="cpu"): - avg_acc = 0.0 - - for net_id, net in nets.items(): - if net_id not in selected: - continue - dataidxs = net_dataidx_map[net_id] - - logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs))) - # move the model to cuda device: - net.to(device) - - noise_level = args.noise - if net_id == args.n_parties - 1: - noise_level = 0 - - if args.noise_type == 'space': - train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1) - else: - noise_level = args.noise / (args.n_parties - 1) * net_id - train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level) - train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32) - n_epoch = args.epochs - - trainacc, testacc = train_net_fedprox(net_id, net, global_model, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, args.mu, device=device) - logger.info("net %d final test acc %f" % (net_id, testacc)) - avg_acc += testacc - avg_acc /= len(selected) - if args.alg == 'local_training': - logger.info("avg test acc %f" % avg_acc) - - nets_list = list(nets.values()) - return nets_list - - def train_net_scaffold(net_id, net, global_model, c_local, c_global, train_dataloader, test_dataloader, epochs, lr, args_optimizer, device="cpu"): logger.info('Training network %s' % str(net_id)) @@ -487,6 +427,210 @@ def train_net_fednova(net_id, net, global_model, train_dataloader, test_dataload return train_acc, test_acc, a_i, norm_grad +def train_net_moon(net_id, net, global_net, previous_nets, train_dataloader, test_dataloader, epochs, lr, args_optimizer, mu, temperature, args, + round, device="cpu"): + + logger.info('Training network %s' % str(net_id)) + + train_acc = compute_accuracy(net, train_dataloader, device=device) + test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device) + + logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc)) + logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc)) + + # conloss = ContrastiveLoss(temperature) + + if args_optimizer == 'adam': + optimizer = 
+    for epoch in range(epochs):
+        epoch_loss_collector = []
+        epoch_loss1_collector = []
+        epoch_loss2_collector = []
+        for batch_idx, (x, target) in enumerate(train_dataloader):
+            x, target = x.to(device), target.to(device)
+
+            optimizer.zero_grad()
+            x.requires_grad = True
+            target.requires_grad = False
+            target = target.long()
+
+            _, pro1, out = net(x)
+            _, pro2, _ = global_net(x)
+            if args.loss == 'l2norm':
+                loss2 = mu * torch.mean(torch.norm(pro2 - pro1, dim=1))
+
+            elif args.loss == 'only_contrastive' or args.loss == 'contrastive':
+                posi = cos(pro1, pro2)
+                logits = posi.reshape(-1, 1)
+
+                for previous_net in previous_nets:
+                    previous_net.to(device)
+                    _, pro3, _ = previous_net(x)
+                    nega = cos(pro1, pro3)
+                    logits = torch.cat((logits, nega.reshape(-1, 1)), dim=1)
+
+                    if args.extend_nega:
+                        oppsi_net = copy.deepcopy(previous_net)
+                        oppsi_w = oppsi_net.state_dict()
+                        prev_w = previous_net.state_dict()
+                        for key in oppsi_w:
+                            oppsi_w[key] = 2 * global_w[key] - prev_w[key]
+                        oppsi_net.load_state_dict(oppsi_w)
+                        _, pro4, _ = oppsi_net(x)
+                        nega = cos(pro1, pro4)
+                        logits = torch.cat((logits, nega.reshape(-1, 1)), dim=1)
+                    previous_net.to('cpu')
+
+                logits /= temperature
+                labels = torch.zeros(x.size(0)).long().to(device)
+
+                # loss = criterion(out, target) + mu * ContraLoss(pro1, pro2, pro3)
+                loss2 = mu * criterion(logits, labels)
+
+            if args.loss == 'only_contrastive':
+                loss1 = torch.zeros_like(loss2)  # no supervised term in this mode
+                loss = loss2
+            else:
+                loss1 = criterion(out, target)
+                loss = loss1 + loss2
+
+            loss.backward()
+            optimizer.step()
+
+            cnt += 1
+            epoch_loss_collector.append(loss.item())
+            epoch_loss1_collector.append(loss1.item())
+            epoch_loss2_collector.append(loss2.item())
+
+        epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
+        epoch_loss1 = sum(epoch_loss1_collector) / len(epoch_loss1_collector)
+        epoch_loss2 = sum(epoch_loss2_collector) / len(epoch_loss2_collector)
+        logger.info('Epoch: %d Loss: %f Loss1: %f Loss2: %f' % (epoch, epoch_loss, epoch_loss1, epoch_loss2))
+
+    if args.loss != 'l2norm':
+        for previous_net in previous_nets:
+            previous_net.to('cpu')
+    train_acc, _ = compute_accuracy(net, train_dataloader, device=device)
+    test_acc, conf_matrix, _ = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
+
+    logger.info('>> Training accuracy: %f' % train_acc)
+    logger.info('>> Test accuracy: %f' % test_acc)
+    net.to('cpu')
+    logger.info(' ** Training complete **')
+    return train_acc, test_acc
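For reference, the contrastive branch of train_net_moon above implements the MOON model-contrastive objective: the cosine similarity between the local and global projections is the positive logit, similarities to projections from previous local models are the negatives, and cross-entropy with a constant label 0 pulls the local representation toward the global one. A standalone sketch of that computation, assuming pro1, pro2, and each entry of prev_pros are [batch, out_dim] tensors:

    import torch
    import torch.nn as nn

    def model_contrastive_loss(pro1, pro2, prev_pros, temperature, mu):
        # pro1: projections from the local model; pro2: from the global model;
        # prev_pros: projections from earlier local models (the negatives).
        cos = nn.CosineSimilarity(dim=-1)
        logits = cos(pro1, pro2).reshape(-1, 1)               # positive pair in column 0
        for pro3 in prev_pros:
            logits = torch.cat((logits, cos(pro1, pro3).reshape(-1, 1)), dim=1)
        logits /= temperature
        labels = torch.zeros(pro1.size(0), dtype=torch.long)  # "class" 0 = the positive
        return mu * nn.CrossEntropyLoss()(logits, labels)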
+
+
+def view_image(train_dataloader):
+    for (x, target) in train_dataloader:
+        np.save("img.npy", x)
+        print(x.shape)
+        exit(0)
+
+
+def local_train_net(nets, selected, args, net_dataidx_map, test_dl = None, device="cpu"):
+    avg_acc = 0.0
+
+    for net_id, net in nets.items():
+        if net_id not in selected:
+            continue
+        dataidxs = net_dataidx_map[net_id]
+
+        logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
+        # move the model to cuda device:
+        net.to(device)
+
+        noise_level = args.noise
+        if net_id == args.n_parties - 1:
+            noise_level = 0
+
+        if args.noise_type == 'space':
+            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
+        else:
+            noise_level = args.noise / (args.n_parties - 1) * net_id
+            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
+        train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
+        n_epoch = args.epochs
+
+
+        trainacc, testacc = train_net(net_id, net, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, device=device)
+        logger.info("net %d final test acc %f" % (net_id, testacc))
+        avg_acc += testacc
+        # saving the trained models here
+        # save_model(net, net_id, args)
+        # else:
+        #     load_model(net, net_id, device=device)
+    avg_acc /= len(selected)
+    if args.alg == 'local_training':
+        logger.info("avg test acc %f" % avg_acc)
+
+    nets_list = list(nets.values())
+    return nets_list
+
+
+def local_train_net_fedprox(nets, selected, global_model, args, net_dataidx_map, test_dl = None, device="cpu"):
+    avg_acc = 0.0
+
+    for net_id, net in nets.items():
+        if net_id not in selected:
+            continue
+        dataidxs = net_dataidx_map[net_id]
+
+        logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
+        # move the model to cuda device:
+        net.to(device)
+
+        noise_level = args.noise
+        if net_id == args.n_parties - 1:
+            noise_level = 0
+
+        if args.noise_type == 'space':
+            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
+        else:
+            noise_level = args.noise / (args.n_parties - 1) * net_id
+            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
+        train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
+        n_epoch = args.epochs
+
+        trainacc, testacc = train_net_fedprox(net_id, net, global_model, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, args.mu, device=device)
+        logger.info("net %d final test acc %f" % (net_id, testacc))
+        avg_acc += testacc
+    avg_acc /= len(selected)
+    if args.alg == 'local_training':
+        logger.info("avg test acc %f" % avg_acc)
+
+    nets_list = list(nets.values())
+    return nets_list
+
 def local_train_net_scaffold(nets, selected, global_model, c_nets, c_global, args, net_dataidx_map, test_dl = None, device="cpu"):
     avg_acc = 0.0
n_training: %d" % (str(net_id), len(dataidxs))) + # move the model to cuda device: + net.to(device) + + noise_level = args.noise + if net_id == args.n_parties - 1: + noise_level = 0 + + if args.noise_type == 'space': + train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1) + else: + noise_level = args.noise / (args.n_parties - 1) * net_id + train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level) + train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32) + n_epoch = args.epochs + + trainacc, testacc = train_net_fedprox(net_id, net, global_model, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, args.mu, device=device) + logger.info("net %d final test acc %f" % (net_id, testacc)) + avg_acc += testacc + avg_acc /= len(selected) + if args.alg == 'local_training': + logger.info("avg test acc %f" % avg_acc) + + nets_list = list(nets.values()) + return nets_list + def local_train_net_scaffold(nets, selected, global_model, c_nets, c_global, args, net_dataidx_map, test_dl = None, device="cpu"): avg_acc = 0.0 @@ -594,6 +738,46 @@ def local_train_net_fednova(nets, selected, global_model, args, net_dataidx_map, nets_list = list(nets.values()) return nets_list, a_list, d_list, n_list +def local_train_net_moon(nets, selected, args, net_dataidx_map, test_dl=None, global_model = None, prev_model_pool = None, round=None, device="cpu"): + avg_acc = 0.0 + global_model.to(device) + for net_id, net in nets.items(): + if net_id not in selected: + continue + dataidxs = net_dataidx_map[net_id] + + logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs))) + net.to(device) + + noise_level = args.noise + if net_id == args.n_parties - 1: + noise_level = 0 + + if args.noise_type == 'space': + train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1) + else: + noise_level = args.noise / (args.n_parties - 1) * net_id + train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level) + train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32) + n_epoch = args.epochs + + prev_models=[] + for i in range(len(prev_model_pool)): + prev_models.append(prev_model_pool[i][net_id]) + trainacc, testacc = train_net_moon(net_id, net, global_model, prev_models, train_dl_local, test_dl, n_epoch, args.lr, + args.optimizer, args.mu, args.temperature, args, round, device=device) + logger.info("net %d final test acc %f" % (net_id, testacc)) + avg_acc += testacc + + avg_acc /= len(selected) + if args.alg == 'local_training': + logger.info("avg test acc %f" % avg_acc) + global_model.to('cpu') + nets_list = list(nets.values()) + return nets_list + + + def get_partition_dict(dataset, partition, n_parties, init_seed=0, datadir='./data', logdir='./logs', beta=0.5): seed = init_seed np.random.seed(seed) @@ -951,6 +1135,81 @@ def get_partition_dict(dataset, partition, n_parties, init_seed=0, datadir='./da logger.info('>> Global Model Train accuracy: %f' % train_acc) logger.info('>> Global Model Test accuracy: %f' % test_acc) + elif args.alg == 'moon': + logger.info("Initializing nets") + nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args) + global_models, 
@@ -951,6 +1135,81 @@ def get_partition_dict(dataset, partition, n_parties, init_seed=0, datadir='./da
         logger.info('>> Global Model Train accuracy: %f' % train_acc)
         logger.info('>> Global Model Test accuracy: %f' % test_acc)
 
+    elif args.alg == 'moon':
+        logger.info("Initializing nets")
+        nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
+        global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args)
+        global_model = global_models[0]
+
+        global_para = global_model.state_dict()
+        if args.is_same_initial:
+            for net_id, net in nets.items():
+                net.load_state_dict(global_para)
+
+        old_nets_pool = []
+        old_nets = copy.deepcopy(nets)
+        for _, net in old_nets.items():
+            net.eval()
+            for param in net.parameters():
+                param.requires_grad = False
+
+        for round in range(args.comm_round):
+            logger.info("in comm round:" + str(round))
+
+            arr = np.arange(args.n_parties)
+            np.random.shuffle(arr)
+            selected = arr[:int(args.n_parties * args.sample)]
+
+            global_para = global_model.state_dict()
+            if round == 0:
+                if args.is_same_initial:
+                    for idx in selected:
+                        nets[idx].load_state_dict(global_para)
+            else:
+                for idx in selected:
+                    nets[idx].load_state_dict(global_para)
+
+            local_train_net_moon(nets, selected, args, net_dataidx_map, test_dl = test_dl_global, global_model=global_model,
+                                 prev_model_pool=old_nets_pool, round=round, device=device)
+            # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device)
+
+            # update global model
+            total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
+            fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]
+
+            for idx in range(len(selected)):
+                net_para = nets[selected[idx]].cpu().state_dict()
+                if idx == 0:
+                    for key in net_para:
+                        global_para[key] = net_para[key] * fed_avg_freqs[idx]
+                else:
+                    for key in net_para:
+                        global_para[key] += net_para[key] * fed_avg_freqs[idx]
+            global_model.load_state_dict(global_para)
+
+            logger.info('global n_training: %d' % len(train_dl_global))
+            logger.info('global n_test: %d' % len(test_dl_global))
+
+            train_acc = compute_accuracy(global_model, train_dl_global)
+            test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True)
+
+            logger.info('>> Global Model Train accuracy: %f' % train_acc)
+            logger.info('>> Global Model Test accuracy: %f' % test_acc)
+
+            old_nets = copy.deepcopy(nets)
+            for _, net in old_nets.items():
+                net.eval()
+                for param in net.parameters():
+                    param.requires_grad = False
+            if len(old_nets_pool) < args.model_buffer_size:
+                old_nets_pool.append(old_nets)
+            else:
+                for i in range(args.model_buffer_size - 1):
+                    old_nets_pool[i] = old_nets_pool[i + 1]
+                old_nets_pool[args.model_buffer_size - 1] = old_nets
+
     elif args.alg == 'local_training':
         logger.info("Initializing nets")
         nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
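The pool update at the end of the moon branch keeps the model_buffer_size most recent snapshots in FIFO order: everything shifts one slot toward index 0 and the newest snapshot goes at the tail. The same semantics can be written more compactly with collections.deque; a sketch with a stand-in payload:

    from collections import deque

    model_buffer_size = 2                             # illustrative value
    old_nets_pool = deque(maxlen=model_buffer_size)   # oldest entry falls out automatically

    for round in range(4):
        old_nets = {'round': round}                   # stand-in for the frozen nets dict
        old_nets_pool.append(old_nets)

    print(list(old_nets_pool))  # [{'round': 2}, {'round': 3}]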
diff --git a/model.py b/model.py
index aab9e08..a6efe15 100644
--- a/model.py
+++ b/model.py
@@ -2,6 +2,26 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import math
+import torchvision.models as models
+from resnetcifar import ResNet18_cifar10, ResNet50_cifar10
+
+
+class MLP_header(nn.Module):
+    def __init__(self,):
+        super(MLP_header, self).__init__()
+        self.fc1 = nn.Linear(28*28, 512)
+        self.fc2 = nn.Linear(512, 512)
+        self.relu = nn.ReLU()
+        # projection
+        # self.fc3 = nn.Linear(512, 10)
+
+    def forward(self, x):
+        x = x.view(-1, 28*28)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        return x
 
 class FcNet(nn.Module):
@@ -120,6 +140,30 @@ def forward(self, x):
         x = self.fc3(x)
         return x
 
+class SimpleCNN_header(nn.Module):
+    def __init__(self, input_dim, hidden_dims, output_dim=10):
+        super(SimpleCNN_header, self).__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.relu = nn.ReLU()
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+
+        # for now, we hard coded this network
+        # i.e. we fix the number of hidden layers i.e. 2 layers
+        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
+        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
+        # self.fc3 = nn.Linear(hidden_dims[1], output_dim)
+
+    def forward(self, x):
+
+        x = self.pool(self.relu(self.conv1(x)))
+        x = self.pool(self.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        # x = self.fc3(x)
+        return x
 
 class SimpleCNN(nn.Module):
     def __init__(self, input_dim, hidden_dims, output_dim=10):
@@ -157,6 +201,30 @@ def forward(self, x):
         x = self.fc1(x)
         return x
 
+class SimpleCNNMNIST_header(nn.Module):
+    def __init__(self, input_dim, hidden_dims, output_dim=10):
+        super(SimpleCNNMNIST_header, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.relu = nn.ReLU()
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+
+        # for now, we hard coded this network
+        # i.e. we fix the number of hidden layers i.e. 2 layers
+        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
+        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
+        # self.fc3 = nn.Linear(hidden_dims[1], output_dim)
+
+    def forward(self, x):
+
+        x = self.pool(self.relu(self.conv1(x)))
+        x = self.pool(self.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 4 * 4)
+
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        # x = self.fc3(x)
+        return x
 
 class SimpleCNNMNIST(nn.Module):
     def __init__(self, input_dim, hidden_dims, output_dim=10):
@@ -451,3 +519,126 @@ def forward_conv(self, x):
         return x
 
 
+class ModelFedCon(nn.Module):
+
+    def __init__(self, base_model, out_dim, n_classes, net_configs=None):
+        super(ModelFedCon, self).__init__()
+
+        # if base_model == "resnet50":
+        #     basemodel = models.resnet50(pretrained=False)
+        #     self.features = nn.Sequential(*list(basemodel.children())[:-1])
+        #     num_ftrs = basemodel.fc.in_features
+        # elif base_model == "resnet18":
+        #     basemodel = models.resnet18(pretrained=False)
+        #     self.features = nn.Sequential(*list(basemodel.children())[:-1])
+        #     num_ftrs = basemodel.fc.in_features
+        if base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel" or base_model == "resnet50":
+            basemodel = ResNet50_cifar10()
+            self.features = nn.Sequential(*list(basemodel.children())[:-1])
+            num_ftrs = basemodel.fc.in_features
+        elif base_model == "resnet18-cifar10" or base_model == "resnet18":
+            basemodel = ResNet18_cifar10()
+            self.features = nn.Sequential(*list(basemodel.children())[:-1])
+            num_ftrs = basemodel.fc.in_features
+        elif base_model == "mlp":
+            self.features = MLP_header()
+            num_ftrs = 512
+        elif base_model == 'simple-cnn':
+            self.features = SimpleCNN_header(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=n_classes)
+            num_ftrs = 84
+        elif base_model == 'simple-cnn-mnist':
+            self.features = SimpleCNNMNIST_header(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=n_classes)
+            num_ftrs = 84
+
+        # summary(self.features.to('cuda:0'), (3,32,32))
+        # print("features:", self.features)
+        # projection MLP
+        self.l1 = nn.Linear(num_ftrs, num_ftrs)
+        self.l2 = nn.Linear(num_ftrs, out_dim)
+
+        # last layer
+        self.l3 = nn.Linear(out_dim, n_classes)
+
+    def _get_basemodel(self, model_name):
+        try:
+            model = self.model_dict[model_name]
+            # print("Feature extractor:", model_name)
+            return model
+        except KeyError:
+            raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
+
+    def forward(self, x):
+        h = self.features(x)
+        # print("h before:", h)
+        # print("h size:", h.size())
+        h = h.squeeze()
+        # print("h after:", h)
+        x = self.l1(h)
+        x = F.relu(x)
+        x = self.l2(x)
+
+        y = self.l3(x)
+        return h, x, y
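ModelFedCon.forward returns the backbone feature h, the projection x, and the classification output y; train_net_moon uses the middle element for the contrastive term. A quick shape check with the simple-cnn backbone, assuming model.py is importable and using illustrative values for out_dim and n_classes:

    import torch
    from model import ModelFedCon

    net = ModelFedCon('simple-cnn', out_dim=256, n_classes=10)
    h, proj, logits = net(torch.randn(4, 3, 32, 32))    # CIFAR-10-sized batch
    print(h.shape, proj.shape, logits.shape)
    # torch.Size([4, 84]) torch.Size([4, 256]) torch.Size([4, 10])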
+
+
+class ModelFedCon_noheader(nn.Module):
+
+    def __init__(self, base_model, out_dim, n_classes, net_configs=None):
+        super(ModelFedCon_noheader, self).__init__()
+
+        if base_model == "resnet50":
+            basemodel = models.resnet50(pretrained=False)
+            self.features = nn.Sequential(*list(basemodel.children())[:-1])
+            num_ftrs = basemodel.fc.in_features
+        elif base_model == "resnet18":
+            basemodel = models.resnet18(pretrained=False)
+            self.features = nn.Sequential(*list(basemodel.children())[:-1])
+            num_ftrs = basemodel.fc.in_features
+        elif base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel":
+            basemodel = ResNet50_cifar10()
+            self.features = nn.Sequential(*list(basemodel.children())[:-1])
+            num_ftrs = basemodel.fc.in_features
+        elif base_model == "resnet18-cifar10":
+            basemodel = ResNet18_cifar10()
+            self.features = nn.Sequential(*list(basemodel.children())[:-1])
+            num_ftrs = basemodel.fc.in_features
+        elif base_model == "mlp":
+            self.features = MLP_header()
+            num_ftrs = 512
+        elif base_model == 'simple-cnn':
+            self.features = SimpleCNN_header(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=n_classes)
+            num_ftrs = 84
+        elif base_model == 'simple-cnn-mnist':
+            self.features = SimpleCNNMNIST_header(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=n_classes)
+            num_ftrs = 84
+
+        # summary(self.features.to('cuda:0'), (3,32,32))
+        # print("features:", self.features)
+        # projection MLP
+        # self.l1 = nn.Linear(num_ftrs, num_ftrs)
+        # self.l2 = nn.Linear(num_ftrs, out_dim)
+
+        # last layer
+        self.l3 = nn.Linear(num_ftrs, n_classes)
+
+    def _get_basemodel(self, model_name):
+        try:
+            model = self.model_dict[model_name]
+            # print("Feature extractor:", model_name)
+            return model
+        except KeyError:
+            raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
+
+    def forward(self, x):
+        h = self.features(x)
+        # print("h before:", h)
+        # print("h size:", h.size())
+        h = h.squeeze()
+        # print("h after:", h)
+        # x = self.l1(h)
+        # x = F.relu(x)
+        # x = self.l2(x)
+
+        y = self.l3(h)
+        return h, h, y
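With the patch applied, a MOON run would look roughly like `python experiments.py --alg=moon --model=simple-cnn --dataset=cifar10 --lr=0.01 --epochs=10 --comm_round=50 --n_parties=10 --partition=noniid-labeldir --beta=0.5 --mu=5 --temperature=0.5 --out_dim=256 --model_buffer_size=1 --loss=contrastive`. The flag values here are illustrative; --temperature, --out_dim, --model_buffer_size and --loss are read by the new code (as args.temperature etc.) and are assumed to be registered in get_args() alongside the --use_projection_head entry shown in the first hunk.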