diff --git a/trailmet/algorithms/binarize/BNNBN.py b/trailmet/algorithms/binarize/BNNBN.py index 009fdea..da09bbc 100644 --- a/trailmet/algorithms/binarize/BNNBN.py +++ b/trailmet/algorithms/binarize/BNNBN.py @@ -42,7 +42,19 @@ def __init__(self, model, dataloaders, **kwargs): self.weight_decay = self.kwargs.get('weight_decay', '0') self.learning_rate = self.kwargs.get('learning_rate', '0.001') + def prepare_dirs(self): + if not os.path.exists('log'): + print('Creating Logging Directory...') + os.mkdir('log') + log_format = '%(asctime)s %(message)s' + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt='%m/%d %I:%M:%S %p') + fh = logging.FileHandler(os.path.join('log/log.txt')) + fh.setFormatter(logging.Formatter(log_format)) + logging.getLogger().addHandler(fh) + def compress_model(self): + self.prepare_dirs() if not torch.cuda.is_available(): sys.exit(1) start_t = time.time() @@ -108,6 +120,7 @@ def compress_model(self): if self.pretrained: print('* loading pretrained weight {}'.format(self.pretrained)) + logging.info(f'loading pretrained weight {self.pretrained}') pretrain_student = torch.load(args.pretrained) if 'state_dict' in pretrain_student.keys(): pretrain_student = pretrain_student['state_dict'] @@ -122,6 +135,7 @@ def compress_model(self): checkpoint_tar = os.path.join(self.save, 'checkpoint.pth.tar') if os.path.exists(checkpoint_tar): print('loading checkpoint {} ..........'.format(checkpoint_tar)) + logging.info('loading checkpoint {} ..........'.format(checkpoint_tar)) checkpoint = torch.load(checkpoint_tar) start_epoch = checkpoint['epoch'] best_top1_acc = checkpoint['best_top1_acc'] @@ -129,13 +143,16 @@ def compress_model(self): optimizer.load_state_dict(checkpoint['optimizer']) scheduler.load_state_dict(checkpoint['scheduler']) print("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch'])) + logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch'])) else: raise ValueError('no checkpoint for resume') if self.loss_type == 'kd': if not classes_in_teacher == self.num_classes: self.validate('teacher', self.val_loader, model_teacher, criterion) - + + logging.info('epoch, train accuracy, train loss, val accuracy, val loss') + # train the model epoch = start_epoch while epoch < self.epochs: @@ -150,7 +167,9 @@ def compress_model(self): raise ValueError('unsupport loss_type') valid_obj, valid_top1_acc, valid_top5_acc = self.validate(epoch, self.val_loader, self.model_student, criterion) - + + logging.info("{}, {}, {}, {}, {}".format(epoch, train_top1_acc, train_obj, valid_top1_acc.item(), valid_obj)) + is_best = False if valid_top1_acc > best_top1_acc: best_top1_acc = valid_top1_acc @@ -168,7 +187,9 @@ def compress_model(self): training_time = (time.time() - start_t) / 3600 print('total training time = {} hours'.format(training_time)) + logging.info('total training time = {} hours'.format(training_time)) print('* best acc = {}'.format(best_top1_acc)) + logging.info('* best acc = {}'.format(best_top1_acc)) def train_kd(self, epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler): diff --git a/trailmet/algorithms/binarize/ReActNet.py b/trailmet/algorithms/binarize/ReActNet.py index 5fc8820..65233d5 100644 --- a/trailmet/algorithms/binarize/ReActNet.py +++ b/trailmet/algorithms/binarize/ReActNet.py @@ -1,159 +1,37 @@ -from trailmet.algorithms.binarize.binarize import BaseBinarize -from trailmet.models.resnet import ResNetCifar -from trailmet.models.resnet import ResNet +from trailmet.models.resnet_react import ResNet, BasicBlock1, BasicBlock2 +from trailmet.models.mobilenet_1 import reactnet_1 +from trailmet.models.mobilenet_2 import reactnet_2 import torch -import numpy as np +import numpy as np import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from trailmet.algorithms.binarize.utils import * -# import matplotlib.pyplot as plt +sys.path.append('/workspace/code/kushagrabhushan/TrAIL/trailmet/trailmet/algorithms/binarize') +sys.path.append("../../../") +from utils import * +from trailmet.utils import * -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) -class BinaryActivation(nn.Module): - def __init__(self): - super(BinaryActivation, self).__init__() - - def forward(self, x): - out_forward = torch.sign(x) - #out_e1 = (x^2 + 2*x) - #out_e2 = (-x^2 + 2*x) - out_e_total = 0 - mask1 = x < -1 - mask2 = x < 0 - mask3 = x < 1 - out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) - out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) - out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) - out = out_forward.detach() - out3.detach() + out3 - - return out - -class LearnableBias(nn.Module): - def __init__(self, out_chn): - super(LearnableBias, self).__init__() - self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) - - def forward(self, x): - out = x + self.bias.expand_as(x) - return out - -class HardBinaryConv(nn.Module): - def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1): - super(HardBinaryConv, self).__init__() - self.stride = stride - self.padding = padding - self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size - self.shape = (out_chn, in_chn, kernel_size, kernel_size) - #self.weight = nn.Parameter(torch.rand((self.number_of_weights,1)) * 0.001, requires_grad=True) - self.weight = nn.Parameter(torch.rand((self.shape)) * 0.001, requires_grad=True) - - def forward(self, x): - #real_weights = self.weights.view(self.shape) - real_weights = self.weight - scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True) - #print(scaling_factor, flush=True) - scaling_factor = scaling_factor.detach() - binary_weights_no_grad = scaling_factor * torch.sign(real_weights) - cliped_weights = torch.clamp(real_weights, -1.0, 1.0) - binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights - #print(binary_weights, flush=True) - y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding) - - return y - -class BasicBlock1(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock1, self).__init__() - - self.move0 = LearnableBias(inplanes) - self.binary_activation = BinaryActivation() - self.binary_conv = conv3x3(inplanes, planes, stride=stride) - self.bn1 = nn.BatchNorm2d(planes) - self.move1 = LearnableBias(planes) - self.prelu = nn.PReLU(planes) - self.move2 = LearnableBias(planes) - - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.move0(x) - out = self.binary_activation(out) - out = self.binary_conv(out) - out = self.bn1(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.move1(out) - out = self.prelu(out) - out = self.move2(out) - - return out - -class BasicBlock2(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock2, self).__init__() - - self.move0 = LearnableBias(inplanes) - self.binary_activation = BinaryActivation() - self.binary_conv = HardBinaryConv(inplanes, planes, stride=stride) - self.bn1 = nn.BatchNorm2d(planes) - self.move1 = LearnableBias(planes) - self.prelu = nn.PReLU(planes) - self.move2 = LearnableBias(planes) - - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.move0(x) - out = self.binary_activation(out) - out = self.binary_conv(out) - out = self.bn1(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.move1(out) - out = self.prelu(out) - out = self.move2(out) - - return out - -class ReActNet(BaseBinarize): - def __init__(self,teacher,model,dataloaders,**kwargs): - super(ReActNet, self).__init__(**kwargs) +class ReActNet(): + def __init__(self, teacher, model_name, dataloaders, num_fp, **kwargs): self.teacher = teacher - self.model = model - self.layers = model.layers_size - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model_name = model_name + self.num_fp=num_fp + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.dataloaders=dataloaders - self.num_classes = model.num_classes - self.insize = model.insize self.kwargs=kwargs + self.dataset = self.kwargs['GENERAL'].get('DATASET', 'c100') + self.num_classes = self.kwargs['GENERAL'].get('num_classes',100) + self.insize = self.kwargs['GENERAL'].get('insize',32) self.batch_size1 = self.kwargs['ReActNet1_ARGS'].get('batch_size',128) self.epochs1 = self.kwargs['ReActNet1_ARGS'].get('epochs',128) self.learning_rate1 = self.kwargs['ReActNet1_ARGS'].get('learning_rate',2.5e-3) self.momentum1 = self.kwargs['ReActNet1_ARGS'].get('momentum',0.9) self.weight_decay1 = self.kwargs['ReActNet1_ARGS'].get('weight_decay',1e-5) self.label_smooth1 = self.kwargs['ReActNet1_ARGS'].get('label_smooth',0.1) + self.save1 = self.kwargs['ReActNet1_ARGS'].get('save','') self.batch_size2 = self.kwargs['ReActNet2_ARGS'].get('batch_size',128) self.epochs2 = self.kwargs['ReActNet2_ARGS'].get('epochs',128) @@ -161,36 +39,57 @@ def __init__(self,teacher,model,dataloaders,**kwargs): self.momentum2 = self.kwargs['ReActNet2_ARGS'].get('momentum',0.9) self.weight_decay2 = self.kwargs['ReActNet2_ARGS'].get('weight_decay',0) self.label_smooth2 = self.kwargs['ReActNet2_ARGS'].get('label_smooth',0.1) + self.save2 = self.kwargs['ReActNet2_ARGS'].get('save','') - def make_model1(self): - self.layers=[i*2 for i in self.layers] - - if (len(self.layers)==3): - new_model=ResNetCifar(BasicBlock1,self.layers,width=1,num_classes=self.num_classes,insize=self.insize) + def prepare_logs(self): + if not os.path.exists('log'): + print('Creating Logging Directory...') + os.mkdir('log') + log_format = '%(asctime)s %(message)s' + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt='%m/%d %I:%M:%S %p') + fh = logging.FileHandler(os.path.join('log/log.txt')) + fh.setFormatter(logging.Formatter(log_format)) + logging.getLogger().addHandler(fh) + + def make_model_actbin(self): + if self.model_name == 'resnet50': + new_model=ResNet(BasicBlock1, [3, 4, 6, 3], self.num_fp, width=1, num_classes=self.num_classes, insize=self.insize) + elif self.model_name == 'mobilenetv2': + new_model = reactnet_1(num_classes=self.num_classes) else: - new_model=ResNet(BasicBlock1,self.layers,width=1,num_classes=self.num_classes,insize=self.insize) - + print("Model Not Avaliable") return new_model - def make_model2(self): - self.layers=[i*2 for i in self.layers] - - if (len(self.layers)==3): - new_model=ResNetCifar(BasicBlock2,self.layers,width=1,num_classes=self.num_classes,insize=self.insize) + def make_model_fullbin(self): + if self.model_name == 'resnet50': + new_model=ResNet(BasicBlock2, [3, 4, 6, 3], self.num_fp, width=1, num_classes=self.num_classes, insize=self.insize) + elif self.model_name == 'mobilenetv2': + new_model = reactnet_2(num_classes=self.num_classes) else: - new_model=ResNet(BasicBlock2,self.layers,width=1,num_classes=self.num_classes,insize=self.insize) - + print("Model Not Avaliable") return new_model def train_one_epoch(self,model,teacher,scheduler,criterion,optimizer): + batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') top5 = AverageMeter('Acc@5', ':6.2f') - - scheduler.step() - + train_loader = self.dataloaders['train'] + + progress = ProgressMeter( + len(train_loader), + [batch_time, losses, top1, top5], + prefix='Train: ') + scheduler.step() + end = time.time() + + for param_group in optimizer.param_groups: + cur_lr = param_group['lr'] + print('learning_rate:', cur_lr) + for i, (images, target) in enumerate(train_loader): images = images.cuda() target = target.cuda() @@ -199,8 +98,7 @@ def train_one_epoch(self,model,teacher,scheduler,criterion,optimizer): logits_student = model(images) logits_teacher = teacher(images) loss = criterion(logits_student, logits_teacher) - - prec1, prec5 = accuracy(logits_student, target, topk=(1, 5)) + prec1, prec5 = accuracy(logits_student, target, topk1=(1, 5)) n = images.size(0) losses.update(loss.item(), n) #accumulated loss top1.update(prec1.item(), n) @@ -210,30 +108,50 @@ def train_one_epoch(self,model,teacher,scheduler,criterion,optimizer): optimizer.zero_grad() loss.backward() optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + progress.display(i) return losses.avg, top1.avg, top5.avg - def base_train(self,model,teacher,epochs,criterion,scheduler,optimizer): + def base_train(self,model,teacher,epochs,criterion,scheduler,optimizer,save): model.train() teacher = teacher.eval() for param in teacher.parameters(): - param.requires_grad = False + param.requires_grad = False epoch=0 + best_top1_acc=0 + logging.info("epoch, train acc, train loss, val acc, val loss") while(epochbest_top1_acc): + best_top1_acc=acc + is_best=True + save_checkpoint({ + 'epoch': epoch, + 'state_dict': model.state_dict(), + 'best_top1_acc': best_top1_acc, + 'optimizer' : optimizer.state_dict(), + }, is_best, save, self.dataset) + + print("Top-1 Train Accuracy-{0}\nTop-5 Train Accuracy-{1}\nValidation Accuracy-{2}".format(train_top1_acc,train_top5_acc,acc)) + logging.info("{}, {}, {}, {}, {}".format(epoch, train_top1_acc, train_obj, acc.item(), v_loss)) epoch+=1 return model - def train1(self): + def train_actbin(self): print("Step-1 Training with activations binarized for {} epochs\n".format(self.epochs1)) - model = self.make_model1() + logging.info("Step-1 Training with activations binarized for {} epochs\n".format(self.epochs1)) + model = self.make_model_actbin() model = model.to(self.device) teacher = self.teacher.to(self.device) @@ -253,13 +171,13 @@ def train1(self): lr=self.learning_rate1,) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/self.epochs1), last_epoch=-1) - model = self.base_train(model,teacher,self.epochs1,criterion,scheduler,optimizer) + model = self.base_train(model,teacher,self.epochs1,criterion,scheduler,optimizer,self.save1) return model - def train2(self,model): + def train_fullbin(self,model): print("Step-2 Training with both activations and weights binarized for {} epochs".format(self.epochs2)) - # model = model.to(self.device) + logging.info("Step-2 Training with both activations and weights binarized for {} epochs".format(self.epochs2)) teacher = self.teacher.to(self.device) all_parameters = model.parameters() @@ -278,27 +196,52 @@ def train2(self,model): lr=self.learning_rate2,) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/self.epochs2), last_epoch=-1) - model = self.base_train(model,teacher,self.epochs2,criterion,scheduler,optimizer) + model = self.base_train(model,teacher,self.epochs2,criterion,scheduler,optimizer,self.save2) return model - - def compress_model(self): - model1 = self.train1() - model2 = self.make_model2() - model2 = model2.to(self.device) - model2.load_state_dict(model1.state_dict(),strict=False) + def test(self, model, test_loader, criterion): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(test_loader), + [batch_time, losses, top1, top5], + prefix='Test: ') + + # switch to evaluation mode + model.eval() + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(test_loader): + images = images.to(device=self.device) + target = target.to(device=self.device) + + # compute output + logits = model(images) + loss = criterion(logits, target) + + # measure accuracy and record loss + pred1, pred5 = accuracy(logits, target, topk1=(1, 5)) + n = images.size(0) + losses.update(loss.item(), n) + top1.update(pred1[0], n) + top5.update(pred5[0], n) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + progress.display(i) - fin_output = self.train2(model2) + return losses.avg, top1.avg, top5.avg + + def compress_model(self): + self.prepare_logs() + model_actbin = self.train_actbin() + model_fullbin = self.make_model_fullbin() + model_fullbin = model_fullbin.to(self.device) + model_fullbin.load_state_dict(model_actbin.state_dict(),strict=False) + fin_output = self.train_fullbin(model_fullbin) return fin_output - - - - - - - - - - - diff --git a/trailmet/algorithms/binarize/birealnet.py b/trailmet/algorithms/binarize/birealnet.py index 6a469d0..26daa2f 100644 --- a/trailmet/algorithms/binarize/birealnet.py +++ b/trailmet/algorithms/binarize/birealnet.py @@ -28,7 +28,7 @@ import torch.utils.data.distributed sys.path.append('/workspace/code/kushagrabhushan/TrAIL/trailmet/trailmet/algorithms/binarize') sys.path.append("../../../") -from utils import * +from .utils import * from trailmet.algorithms.binarize.binarize import BaseBinarize from trailmet.utils import * @@ -39,7 +39,7 @@ def __init__(self, model, dataloaders, **CFG): self.dataloaders = dataloaders self.CFG = CFG self.batch_size = self.CFG['batch_size'] - self.optimizer = self.CFG['optimizer'] + self.optimizer = eval(self.CFG['optimizer']) self.epochs = self.CFG['epochs'] self.lr = self.CFG['lr'] self.momentum = self.CFG['momentum'] @@ -96,7 +96,7 @@ def train(self, epoch, train_loader, model, criterion, optimizer, scheduler): loss = criterion(logits, target) # measure accuracy and record loss - prec1, prec5 = accuracy(logits, target, topk1=(1, 5)) + prec1, prec5 = accuracy(logits, target, topk=(1, 5)) n = images.size(0) losses.update(loss.item(), n) #accumulated loss top1.update(prec1.item(), n) @@ -138,7 +138,7 @@ def validate(self, epoch, val_loader, model, criterion, CFG): loss = criterion(logits, target) # measure accuracy and record loss - pred1, pred5 = accuracy(logits, target, topk1=(1, 5)) + pred1, pred5 = accuracy(logits, target, topk=(1, 5)) n = images.size(0) losses.update(loss.item(), n) top1.update(pred1[0], n) @@ -178,7 +178,7 @@ def test(self, epoch, test_loader, model, criterion, CFG): loss = criterion(logits, target) # measure accuracy and record loss - pred1, pred5 = accuracy(logits, target, topk1=(1, 5)) + pred1, pred5 = accuracy(logits, target, topk=(1, 5)) n = images.size(0) losses.update(loss.item(), n) top1.update(pred1[0], n) @@ -207,7 +207,6 @@ def binarize(self): # load model model = self.model - logging.info(model) # model = nn.DataParallel(model).to(device=self.device) model = model.to(device=self.device) criterion = nn.CrossEntropyLoss() @@ -264,11 +263,11 @@ def binarize(self): 'state_dict': model.state_dict(), 'best_top1_acc': best_top1_acc, 'optimizer' : optimizer.state_dict(), - }, is_best, self.save_path, self.dataset) + }, is_best, self.save_path) epoch += 1 - best = torch.load(f"{self.save_path}/{self.dataset}-model_best.pth.tar") + best = torch.load(f"{self.save_path}/model_best.pth.tar") self.model.load_state_dict(best['state_dict']) self.test(epoch, self.dataloaders['test'], self.model, criterion, self.CFG) training_time = (time.time() - start_t) / 36000 diff --git a/trailmet/models/mobilenet_1.py b/trailmet/models/mobilenet_1.py new file mode 100644 index 0000000..d01e64b --- /dev/null +++ b/trailmet/models/mobilenet_1.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import torch.nn.functional as F +import numpy as np + +stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2 + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +class firstconv3x3(nn.Module): + def __init__(self, inp, oup, stride): + super(firstconv3x3, self).__init__() + + self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False) + self.bn1 = nn.BatchNorm2d(oup) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + + return out + +class BinaryActivation(nn.Module): + def __init__(self): + super(BinaryActivation, self).__init__() + + def forward(self, x): + out_forward = torch.sign(x) + mask1 = x < -1 + mask2 = x < 0 + mask3 = x < 1 + out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) + out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) + out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) + out = out_forward.detach() - out3.detach() + out3 + + return out + +class LearnableBias(nn.Module): + def __init__(self, out_chn): + super(LearnableBias, self).__init__() + self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) + + def forward(self, x): + out = x + self.bias.expand_as(x) + return out + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, stride=1): + super(BasicBlock, self).__init__() + norm_layer = nn.BatchNorm2d + + self.move11 = LearnableBias(inplanes) + self.binary_3x3= conv3x3(inplanes, inplanes, stride=stride) + self.bn1 = norm_layer(inplanes) + + self.move12 = LearnableBias(inplanes) + self.prelu1 = nn.PReLU(inplanes) + self.move13 = LearnableBias(inplanes) + + self.move21 = LearnableBias(inplanes) + + if inplanes == planes: + self.binary_pw = conv1x1(inplanes, planes) + self.bn2 = norm_layer(planes) + else: + self.binary_pw_down1 = conv1x1(inplanes, inplanes) + self.binary_pw_down2 = conv1x1(inplanes, inplanes) + self.bn2_1 = norm_layer(inplanes) + self.bn2_2 = norm_layer(inplanes) + + self.move22 = LearnableBias(planes) + self.prelu2 = nn.PReLU(planes) + self.move23 = LearnableBias(planes) + + self.binary_activation = BinaryActivation() + self.stride = stride + self.inplanes = inplanes + self.planes = planes + + if self.inplanes != self.planes: + self.pooling = nn.AvgPool2d(2,2) + + def forward(self, x): + + out1 = self.move11(x) + + out1 = self.binary_activation(out1) + out1 = self.binary_3x3(out1) + out1 = self.bn1(out1) + + if self.stride == 2: + x = self.pooling(x) + + out1 = x + out1 + + out1 = self.move12(out1) + out1 = self.prelu1(out1) + out1 = self.move13(out1) + + out2 = self.move21(out1) + out2 = self.binary_activation(out2) + + if self.inplanes == self.planes: + out2 = self.binary_pw(out2) + out2 = self.bn2(out2) + out2 += out1 + + else: + assert self.planes == self.inplanes * 2 + + out2_1 = self.binary_pw_down1(out2) + out2_2 = self.binary_pw_down2(out2) + out2_1 = self.bn2_1(out2_1) + out2_2 = self.bn2_2(out2_2) + out2_1 += out1 + out2_2 += out1 + out2 = torch.cat([out2_1, out2_2], dim=1) + + out2 = self.move22(out2) + out2 = self.prelu2(out2) + out2 = self.move23(out2) + + return out2 + + +class reactnet_1(nn.Module): + def __init__(self, num_classes=1000): + super(reactnet_1, self).__init__() + self.feature = nn.ModuleList() + for i in range(len(stage_out_channel)): + if i == 0: + self.feature.append(firstconv3x3(3, stage_out_channel[i], 2)) + elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64: + self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2)) + else: + self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1)) + self.pool1 = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(1024, num_classes) + + def forward(self, x): + for i, block in enumerate(self.feature): + x = block(x) + + x = self.pool1(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x diff --git a/trailmet/models/mobilenet_2.py b/trailmet/models/mobilenet_2.py new file mode 100644 index 0000000..51c9891 --- /dev/null +++ b/trailmet/models/mobilenet_2.py @@ -0,0 +1,190 @@ +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import torch.nn.functional as F +import numpy as np + +stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2 + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +def binaryconv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) + + +def binaryconv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0) + +class firstconv3x3(nn.Module): + def __init__(self, inp, oup, stride): + super(firstconv3x3, self).__init__() + + self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False) + self.bn1 = nn.BatchNorm2d(oup) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + + return out + +class BinaryActivation(nn.Module): + def __init__(self): + super(BinaryActivation, self).__init__() + + def forward(self, x): + out_forward = torch.sign(x) + mask1 = x < -1 + mask2 = x < 0 + mask3 = x < 1 + out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) + out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) + out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) + out = out_forward.detach() - out3.detach() + out3 + + return out + +class LearnableBias(nn.Module): + def __init__(self, out_chn): + super(LearnableBias, self).__init__() + self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) + + def forward(self, x): + out = x + self.bias.expand_as(x) + return out + +class HardBinaryConv(nn.Module): + def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1): + super(HardBinaryConv, self).__init__() + self.stride = stride + self.padding = padding + self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size + self.shape = (out_chn, in_chn, kernel_size, kernel_size) + self.weights = nn.Parameter(torch.rand((self.number_of_weights,1)) * 0.001, requires_grad=True) + + def forward(self, x): + real_weights = self.weights.view(self.shape) + scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True) + #print(scaling_factor, flush=True) + scaling_factor = scaling_factor.detach() + binary_weights_no_grad = scaling_factor * torch.sign(real_weights) + cliped_weights = torch.clamp(real_weights, -1.0, 1.0) + binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights + #print(binary_weights, flush=True) + y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding) + + return y + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, stride=1): + super(BasicBlock, self).__init__() + norm_layer = nn.BatchNorm2d + + self.move11 = LearnableBias(inplanes) + self.binary_3x3= binaryconv3x3(inplanes, inplanes, stride=stride) + self.bn1 = norm_layer(inplanes) + + self.move12 = LearnableBias(inplanes) + self.prelu1 = nn.PReLU(inplanes) + self.move13 = LearnableBias(inplanes) + + self.move21 = LearnableBias(inplanes) + + if inplanes == planes: + self.binary_pw = binaryconv1x1(inplanes, planes) + self.bn2 = norm_layer(planes) + else: + self.binary_pw_down1 = binaryconv1x1(inplanes, inplanes) + self.binary_pw_down2 = binaryconv1x1(inplanes, inplanes) + self.bn2_1 = norm_layer(inplanes) + self.bn2_2 = norm_layer(inplanes) + + self.move22 = LearnableBias(planes) + self.prelu2 = nn.PReLU(planes) + self.move23 = LearnableBias(planes) + + self.binary_activation = BinaryActivation() + self.stride = stride + self.inplanes = inplanes + self.planes = planes + + if self.inplanes != self.planes: + self.pooling = nn.AvgPool2d(2,2) + + def forward(self, x): + + out1 = self.move11(x) + + out1 = self.binary_activation(out1) + out1 = self.binary_3x3(out1) + out1 = self.bn1(out1) + + if self.stride == 2: + x = self.pooling(x) + + out1 = x + out1 + + out1 = self.move12(out1) + out1 = self.prelu1(out1) + out1 = self.move13(out1) + + out2 = self.move21(out1) + out2 = self.binary_activation(out2) + + if self.inplanes == self.planes: + out2 = self.binary_pw(out2) + out2 = self.bn2(out2) + out2 += out1 + + else: + assert self.planes == self.inplanes * 2 + + out2_1 = self.binary_pw_down1(out2) + out2_2 = self.binary_pw_down2(out2) + out2_1 = self.bn2_1(out2_1) + out2_2 = self.bn2_2(out2_2) + out2_1 += out1 + out2_2 += out1 + out2 = torch.cat([out2_1, out2_2], dim=1) + + out2 = self.move22(out2) + out2 = self.prelu2(out2) + out2 = self.move23(out2) + + return out2 + + +class reactnet_2(nn.Module): + def __init__(self, num_classes=1000): + super(reactnet_2, self).__init__() + self.feature = nn.ModuleList() + for i in range(len(stage_out_channel)): + if i == 0: + self.feature.append(firstconv3x3(3, stage_out_channel[i], 2)) + elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64: + self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2)) + else: + self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1)) + self.pool1 = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(1024, num_classes) + + def forward(self, x): + for i, block in enumerate(self.feature): + x = block(x) + + x = self.pool1(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x diff --git a/trailmet/models/resnet_react.py b/trailmet/models/resnet_react.py new file mode 100644 index 0000000..b69f231 --- /dev/null +++ b/trailmet/models/resnet_react.py @@ -0,0 +1,449 @@ +import torch.nn as nn +from torch.hub import load_state_dict_from_url +from .base_model import BaseModel +import torch + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.activ = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.activ(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.activ(out) + + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None,binarize=False): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.activ = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.activ(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.activ(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.activ(out) + + return out + +class ResNetCifar(BaseModel): + def __init__(self, block, layers, width=1, num_classes=1000, insize=32): + super(ResNetCifar, self).__init__() + self.inplanes = 16 + self.insize = insize + self.layers_size = layers + self.num_classes = num_classes + self.width = width + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(16) + self.prev_module[self.bn1]=None + self.activ = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(block, 16 * width, layers[0]) + self.layer2 = self._make_layer(block, 32 * width, layers[1], stride=2) + self.layer3 = self._make_layer(block, 64 * width, layers[2], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(output_size=1) + self.fc = nn.Linear(64 * width, num_classes) + self.init_weights() + + assert block is BasicBlock + prev = self.bn1 + for l_block in [self.layer1, self.layer2, self.layer3]: + for b in l_block: + self.prev_module[b.bn1] = prev + self.prev_module[b.bn2] = b.bn1 + if b.downsample is not None: + self.prev_module[b.downsample[1]] = prev + prev = b.bn2 + + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + conv_module = nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False) + bn_module = nn.BatchNorm2d(planes * block.expansion) + if hasattr(bn_module, 'is_imp'): + bn_module.is_imp = True + downsample = nn.Sequential(conv_module, bn_module) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.activ(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + def get_bn_layers(self): + bn_layers = [] + for l_blocks in [self.layer1, self.layer2, self.layer3]: + for b in l_blocks: + m1, m2 = b.bn1, b.bn2 + bn_layers.append([m1, m2]) + return bn_layers + + +class ResNet(BaseModel): + def __init__(self, block, layers,num_fp=10, width=1, num_classes=1000, produce_vectors=False, init_weights=True, insize=32): + super(ResNet, self).__init__() + self.layers_size = layers + self.num_classes = num_classes + self.insize = insize + self.num_fp=num_fp + self.count_fp=0 + self.produce_vectors = produce_vectors + self.block_type = block.__class__.__name__ + self.inplanes = 64 + if insize<128: + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + else: + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.activ = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64 * width, layers[0]) + self.layer2 = self._make_layer(block, 128 * width, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256 * width, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512 * width, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(output_size=1) # Global Avg Pool + self.fc = nn.Linear(512 * block.expansion * width, num_classes) + + self.init_weights() + + for l in [self.layer1, self.layer2, self.layer3, self.layer4]: + for b in l.children(): + downs = next(b.downsample.children()) if b.downsample is not None else None + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + conv_module = nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False) + bn_module = nn.BatchNorm2d(planes * block.expansion) + if hasattr(bn_module, 'is_imp'): + bn_module.is_imp = True + downsample = nn.Sequential(conv_module, bn_module) + + layers = [] + if(self.count_fp