From 3c87940d87750993798778dae9d9fb054d11d5b0 Mon Sep 17 00:00:00 2001 From: z1069614715 <1069614715@qq.com> Date: Sat, 12 Nov 2022 13:28:44 +0800 Subject: [PATCH] 'v1.1' --- Knowledge_Distillation.md | 80 +++++++ README.md | 28 ++- config/__pycache__/config.cpython-38.pyc | Bin 1093 -> 1099 bytes config/config.py | 4 +- config/sgd_config.py | 2 +- main.py | 88 +++---- metrice.py | 17 +- predict.py | 7 +- processing.py | 53 ++--- requirements.txt | 4 +- utils/__pycache__/utils.cpython-38.pyc | Bin 25731 -> 28307 bytes utils/__pycache__/utils_aug.cpython-38.pyc | Bin 5588 -> 6328 bytes utils/__pycache__/utils_aug.cpython-39.pyc | Bin 5586 -> 6330 bytes utils/__pycache__/utils_fit.cpython-38.pyc | Bin 3141 -> 3414 bytes utils/__pycache__/utils_fit.cpython-39.pyc | Bin 3681 -> 3463 bytes utils/__pycache__/utils_loss.cpython-38.pyc | Bin 2731 -> 3812 bytes utils/__pycache__/utils_model.cpython-38.pyc | Bin 2988 -> 2988 bytes utils/utils.py | 227 ++++++++++++------- utils/utils_aug.py | 16 +- utils/utils_fit.py | 123 ++++++---- utils/utils_loss.py | 35 ++- v1.1-update_log.md | 200 ++++++++++++++++ 22 files changed, 656 insertions(+), 228 deletions(-) create mode 100644 v1.1-update_log.md diff --git a/Knowledge_Distillation.md b/Knowledge_Distillation.md index 75ac60e..c9d4f83 100644 --- a/Knowledge_Distillation.md +++ b/Knowledge_Distillation.md @@ -254,6 +254,86 @@ | ghostnet | 0.77709 | 0.77756 | 0.76367 | 0.76277 | 0.78046 | 0.77958 | | teacher->ghostnet
student->ghostnet
AT | 0.78046 | 0.78080 | 0.77142 | 0.77069 | 0.78820 | 0.78742 | +### 在V1.1版本的测试中发现efficientnet_v2网络作为teacher网络效果还不错. + +普通训练mobilenetv2: + + python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \ + --pretrained --amp --warmup --imagenet_meanstd + +计算mobilenetv2指标: + + python metrice.py --task val --save_path runs/mobilenetv2 + python metrice.py --task test --save_path runs/mobilenetv2 + python metrice.py --task test --save_path runs/mobilenetv2 --test_tta + +普通训练efficientnet_v2_s: + + python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \ + --pretrained --amp --warmup --imagenet_meanstd + +计算efficientnet_v2_s指标: + + python metrice.py --task val --save_path runs/efficientnet_v2_s + python metrice.py --task test --save_path runs/efficientnet_v2_s + python metrice.py --task test --save_path runs/efficientnet_v2_s --test_tta + +知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏: + + python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \ + --pretrained --amp --warmup --imagenet_meanstd \ + --kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s + +知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏: + + python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \ + --pretrained --amp --warmup --imagenet_meanstd \ + --kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s + + python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA --lr 1e-4 --Augment AutoAugment --epoch 150 \ + --pretrained --amp --warmup --imagenet_meanstd --ema \ + --kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s + + python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \ + --pretrained --amp --warmup --imagenet_meanstd --rdrop \ + --kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s + + python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \ + --pretrained --amp --warmup --imagenet_meanstd --rdrop --ema \ + --kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s + +计算通过efficientnet_v2_s蒸馏mobilenetv2指标: + + python metrice.py --task val --save_path runs/mobilenetv2_ST + python metrice.py --task test --save_path runs/mobilenetv2_ST + python metrice.py --task test --save_path runs/mobilenetv2_ST --test_tta + + python metrice.py --task val --save_path runs/mobilenetv2_MGD + python metrice.py --task test --save_path runs/mobilenetv2_MGD + python metrice.py --task test --save_path runs/mobilenetv2_MGD --test_tta + + python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA + python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA + python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA --test_tta + + python metrice.py --task val --save_path runs/mobilenetv2_MGD_RDROP + python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP + python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP --test_tta + + python metrice.py --task val --save_path 
runs/mobilenetv2_MGD_EMA_RDROP + python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP + python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP --test_tta + +| model | val accuracy | val mpa | test accuracy | test mpa | test accuracy(TTA) | test mpa(TTA) | +| :----: | :----: | :----: | :----: | :----: | :----: | :----: | +| mobilenetv2 | 0.74116 | 0.74200 | 0.73483 | 0.73452 | 0.77012 | 0.76979 | +| efficientnet_v2_s | 0.84166 | 0.84191 | 0.84460 | 0.84441 | 0.86483 | 0.86484 | +| teacher->efficientnet_v2_s
student->mobilenetv2
ST | 0.76137 | 0.76209 | 0.75161 | 0.75088 | 0.77830 | 0.77715 | +| teacher->efficientnet_v2_s
student->mobilenetv2
MGD | 0.77204 | 0.77288 | 0.77529 | 0.77464 | 0.79337 | 0.79261 | +| teacher->efficientnet_v2_s
student->mobilenetv2
MGD(EMA) | 0.77204 | 0.77267 | 0.77744 | 0.77671 | 0.80284 | 0.80201 | +| teacher->efficientnet_v2_s
student->mobilenetv2
MGD(RDrop) | 0.77204 | 0.77288 | 0.77529 | 0.77464 | 0.79337 | 0.79261 | +| teacher->efficientnet_v2_s
student->mobilenetv2
MGD(EMA,RDrop) | 0.77204 | 0.77267 | 0.77744 | 0.77671 | 0.80284 | 0.80201 | + ## 关于Knowledge Distillation的一些解释 实验解释: diff --git a/README.md b/README.md index d37c73d..1d49d4b 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,8 @@ image classifier implement in pytoch. 3. 热力图可视化. 4. TSNE可视化. 5. 数据集识别情况可视化.(metrice.py文件中--visual参数,开启可以自动把识别正确和错误的文件路径,类别,概率保存到csv中,方便后续分析) - 6. 类别精度可视化.(可视化训练集,验证集,测试集中的总精度,混淆矩阵,每个类别的precision,recall,kappa,accuracy) + 6. 类别精度可视化.(可视化训练集,验证集,测试集中的总精度,混淆矩阵,每个类别的precision,recall,accuracy,f0.5,f1,f2,auc,aupr) + 7. 总体精度可视化.(kappa,precision,recll,f1,accuracy,mpa) - **丰富的模型库** 1. 由作者整合的丰富模型库,主流的模型基本全部支持,支持的模型个数高达50+,其全部支持ImageNet的预训练权重,[详细请看Model Zoo.(变形金刚系列后续更新)](#3) @@ -136,7 +137,7 @@ image classifier implement in pytoch. 根据metrice选择的指标来进行保存best.pt. - **patience** type: int, default:30 - 早停法中的patience. + 早停法中的patience.(设置为0即为不使用早停法) - **imagenet_meanstd** default:False 是否采用imagenet的均值和方差,False则使用当前训练集的均值和方差. @@ -161,6 +162,9 @@ image classifier implement in pytoch. - **teacher_path** type: string, default: '' 知识蒸馏中老师模型的路径. + - **rdrop** + default: False + 是否采用R-Drop.(不支持知识蒸馏) - **metrice.py** 实现计算指标的主要程序. 参数解释: @@ -443,7 +447,7 @@ image classifier implement in pytoch. 1. 关于尺寸问题. 假如我设置image_size为224,对于训练集,其会把短边先resize到(224+224\*0.1)的大小,然后随机裁剪224大小区域.对于验证集或者测试集,如果不采用test_tta,其会把短边先resize到224大小,然后再进行中心裁剪,如果采用test_tta,其会把短边先resize到(224+224\*0.1)的大小,然后随机裁剪224大小区域. 2. 关于test_tta的问题. - 这里采用的是随机裁剪10次图像进行预测,最后把预测结果求平均,相当于是一个集成预测的结果. + 这里采用的是torchvision中的TenCrop函数,其内部的原理是先(左上,右上,左下,右下,中心)进行裁剪,然后再把这五张图做翻转,作为最终的10张图返回,最后把预测结果求平均,相当于是一个集成预测的结果. 3. 关于数据增强的问题. 本程序采用的是在线数据增强,自带的是torchvision.transforms中支持的数据增强策略(RandAugment,AutoAugment,TrivialAugmentWide,AugMix),在main.py中还有一个mixup的数据增强参数,其与前面的数据增强策略不冲突,假设我使用了AutoAugment+MixUp,那么程序会先对数据做AutoAugment然后再做MixUp.当然(RandAugment,AutoAugment,TrivialAugmentWide,AugMix)这些数据策略可能对某些数据并不合适,所以我们在utils/config.py中定义了一个名为custom_augment参数,这个参数默认为transforms.Compose([]),以下会有一个示例,如何制定自己的自定义数据增强. @@ -452,6 +456,8 @@ image classifier implement in pytoch. transforms.RandomRotation(degrees=20), ]) + 在v1.1版本中更新了支持使用albumentations库的函数,具体可以看一下[Some explanation中的第十七点](#5) + 其实很简单,也就是当列表里面为空的时候,其会使用main.py中的--Augment参数,如果列表不为空的话,其会代替--Augment参数(无论--Augment设置了什么),也就是说两个参数只会生效一个. 当然自定义的数据增强与mixup也不冲突,也就是先做自定义数据增强,然后再做mixup. @@ -508,7 +514,7 @@ image classifier implement in pytoch. def __str__(self): return 'CutOut' - 实现好后,在config/config.py中进行导入,然后添加到自定义数据增强的list中即可. + 使用者可以在utils/utils_aug.py中实现好后,在config/config.py中进行导入,然后添加到自定义数据增强的list中即可. 15. 关于resume的问题. @@ -518,11 +524,20 @@ image classifier implement in pytoch. 默认保存最后的模型last.pt和验证集上精度最高(可以在main.py中的--metrice参数中进行修改)的模型best.pt. +

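+        A minimal sketch of how to load the saved checkpoint by hand, e.g. to reuse the weights outside this project (assuming runs/exp as the --save_path; the keys follow the save_model calls in main.py, and from v1.1 the weights in best.pt are stored in FP16):
+
+            import torch
+
+            ckpt = torch.load('runs/exp/best.pt')  # dict with 'model', 'opt', 'best_metrice'
+            model = ckpt['model'].float().eval()   # cast the FP16 weights back to FP32 for inference
+            print(ckpt['best_metrice'])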
+ + 17. 关于如何使用albumentations的数据增强问题. + + 我们可以在[albumentations的github](https://github.com/albumentations-team/albumentations)或者[albumentations的官方网站](https://albumentations.ai/docs/api_reference/augmentations/)中找到自己需要的数据增强的名字,比如[RandomGridShuffle](https://github.com/albumentations-team/albumentations#:~:text=%E2%9C%93-,RandomGridShuffle,-%E2%9C%93)的方法,我们可以在config/config.py中进行创建: + Create_Albumentations_From_Name('RandomGridShuffle') + 还有些使用者可能需要修改其默认参数,参数可以在其api文档中找到,我们的函数也是支持修改参数的,比如这个RandomGridShuffle函数有一个grid的参数,具体方法如下: + Create_Albumentations_From_Name('RandomGridShuffle', grid=(3, 3)) + 不止一个参数的话直接也是在后面加即可,但是需要指定其参数的名字. ## TODO - [x] Knowledge Distillation - [ ] EMA -- [ ] R-Drop +- [x] R-Drop - [ ] SWA - [ ] DDP Mode - [ ] Export Model(onnx, tensorrt, torchscript) @@ -530,7 +545,6 @@ image classifier implement in pytoch. - [ ] Accumulation Gradient - [ ] Model Ensembling - [ ] Freeze Training -- [ ] Customize Evaluation Function - [x] Early Stop ## Reference @@ -549,4 +563,4 @@ image classifier implement in pytoch. https://github.com/clovaai/CutMix-PyTorch https://github.com/AberHu/Knowledge-Distillation-Zoo https://github.com/yoshitomo-matsubara/torchdistill - + https://github.com/albumentations-team/albumentations diff --git a/config/__pycache__/config.cpython-38.pyc b/config/__pycache__/config.cpython-38.pyc index d8b31aa452daf0676175698cb7edf123268bb48e..0f17312ddf438585b273cd79e0c4722352ee45bd 100644 GIT binary patch delta 326 zcmX@gahiiSl$V!_0SGc=vy*R57`5XF_soyxU?9{EF}Y~- z1Ko6s8=@c{Bv&K=6e;2Z5rQB&j`;M{lKA|B5{bzZnJ!C-f+WC75ClkF5$ohSW@!8Ml$V!_0SI*d#w2S@SttS z=@(R%6%GK}(*Pc!TJih?BBAhz9NEhx&&D=ETd;w=uF-29Z%oK!nT Vpsr$&Lpd0EK#++~goBZd5daNBS7iVI diff --git a/config/config.py b/config/config.py index 7c9b27d..c443443 100644 --- a/config/config.py +++ b/config/config.py @@ -1,7 +1,7 @@ import torch import torchvision.transforms as transforms from argparse import Namespace -from utils.utils_aug import CutOut +from utils.utils_aug import CutOut, Create_Albumentations_From_Name class Config: lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR @@ -17,6 +17,8 @@ class Config: # transforms.RandomVerticalFlip(p=0.5), # ]), # transforms.RandomRotation(45), + # Create_Albumentations_From_Name('PixelDropout', p=1.0), + # Create_Albumentations_From_Name('RandomGridShuffle', grid=(16, 16)) ]) def _get_opt(self): diff --git a/config/sgd_config.py b/config/sgd_config.py index a1b5c7a..8b7d741 100644 --- a/config/sgd_config.py +++ b/config/sgd_config.py @@ -1,7 +1,7 @@ import torch import torchvision.transforms as transforms from argparse import Namespace -from utils.utils_aug import CutOut +from utils.utils_aug import CutOut, Create_Albumentations_From_Name class Config: lr_scheduler = torch.optim.lr_scheduler.StepLR diff --git a/main.py b/main.py index 9e8e15d..d8bf934 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ from utils.utils_model import select_model from utils import utils_aug from utils.utils import save_model, plot_train_batch, WarmUpLR, show_config, setting_optimizer, check_batch_size, \ - plot_log, update_opt, load_weights, get_channels, dict_to_PrettyTable + plot_log, update_opt, load_weights, get_channels, dict_to_PrettyTable, ModelEMA from utils.utils_distill import * from utils.utils_loss import * @@ -74,6 +74,10 @@ def parse_opt(): parser.add_argument('--kd_method', type=str, choices=['SoftTarget', 'MGD', 'SP', 'AT'], default='SoftTarget', help='Knowledge Distillation Method') parser.add_argument('--teacher_path', type=str, default='', help='teacher model path') + # 
Tricks parameters + parser.add_argument('--rdrop', action="store_true", help='using R-Drop') + parser.add_argument('--ema', action="store_true", help='using EMA(Exponential Moving Average)') + opt = parser.parse_known_args()[0] if opt.resume: opt.resume = True @@ -103,7 +107,7 @@ def parse_opt(): train_dataset = torchvision.datasets.ImageFolder(opt.train_path, transform=train_transform) test_dataset = torchvision.datasets.ImageFolder(opt.val_path, transform=test_transform) if opt.resume: - model = ckpt['model'].to(DEVICE) + model = ckpt['model'].to(DEVICE).float() else: model = select_model(opt.model_name, CLASS_NUM, (opt.image_size, opt.image_size), opt.image_channel, opt.pretrained) @@ -113,40 +117,40 @@ def parse_opt(): batch_size = opt.batch_size if opt.batch_size != -1 else check_batch_size(model, opt.image_size, amp=opt.amp) if opt.class_balance: - class_weight = np.sqrt( - compute_class_weight('balanced', classes=np.unique(train_dataset.targets), y=train_dataset.targets)) + class_weight = np.sqrt(compute_class_weight('balanced', classes=np.unique(train_dataset.targets), y=train_dataset.targets)) else: class_weight = np.ones_like(np.unique(train_dataset.targets)) print('class weight: {}'.format(class_weight)) - # try: - # with open(opt.label_path) as f: - # label = list(map(lambda x: x.strip(), f.readlines())) - # except Exception as e: - # with open(opt.label_path, encoding='gbk') as f: - # label = list(map(lambda x: x.strip(), f.readlines())) - # print(dict_to_PrettyTable({label[i]:class_weight[i] for i in range(len(label))}, 'Class Weight')) train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=opt.workers) test_dataset = torch.utils.data.DataLoader(test_dataset, max(batch_size // (10 if opt.test_tta else 1), 1), shuffle=False, num_workers=(0 if opt.test_tta else opt.workers)) scaler = torch.cuda.amp.GradScaler(enabled=(opt.amp if torch.cuda.is_available() else False)) + ema = None + if opt.ema: + ema = ModelEMA(model) + optimizer = setting_optimizer(opt, model) + lr_scheduler = WarmUpLR(optimizer, opt) if opt.resume: - optimizer = ckpt['optimizer'] - lr_scheduler = ckpt['lr_scheduler'] - loss = ckpt['loss'] + optimizer.load_state_dict(ckpt['optimizer']) + lr_scheduler.load_state_dict(ckpt['lr_scheduler']) + loss = ckpt['loss'].to(DEVICE) scaler.load_state_dict(ckpt['scaler']) + if opt.ema: + ema.ema = ckpt['ema'].to(DEVICE).float() + ema.updates = ckpt['updates'] else: - optimizer = setting_optimizer(opt, model) - lr_scheduler = WarmUpLR(optimizer, opt) loss = eval(opt.loss)(label_smoothing=opt.label_smoothing, weight=torch.from_numpy(class_weight).to(DEVICE).float()) - return opt, model, train_dataset, test_dataset, optimizer, scaler, lr_scheduler, loss, DEVICE, CLASS_NUM, ( + if opt.rdrop: + loss = RDropLoss(loss) + return opt, model, ema, train_dataset, test_dataset, optimizer, scaler, lr_scheduler, loss, DEVICE, CLASS_NUM, ( ckpt['epoch'] if opt.resume else 0), (ckpt['best_metrice'] if opt.resume else None) if __name__ == '__main__': - opt, model, train_dataset, test_dataset, optimizer, scaler, lr_scheduler, loss, DEVICE, CLASS_NUM, begin_epoch, best_metrice = parse_opt() - + opt, model, ema, train_dataset, test_dataset, optimizer, scaler, lr_scheduler, loss, DEVICE, CLASS_NUM, begin_epoch, best_metrice = parse_opt() + if not opt.resume: save_epoch = 0 with open(os.path.join(opt.save_path, 'train.log'), 'w+') as f: @@ -161,36 +165,39 @@ def parse_opt(): if not os.path.exists(os.path.join(opt.teacher_path, 'best.pt')): raise 
Exception('teacher best.pt not found. please check your --teacher_path folder') teacher_ckpt = torch.load(os.path.join(opt.teacher_path, 'best.pt')) - teacher_model = teacher_ckpt['model'] + teacher_model = teacher_ckpt['model'].float().to(DEVICE).eval() print('found teacher checkpoint from {}, model type:{}\n{}'.format(opt.teacher_path, teacher_model.name, dict_to_PrettyTable(teacher_ckpt['best_metrice'], 'Best Metrice'))) - if opt.kd_method == 'SoftTarget': - kd_loss = SoftTarget().to(DEVICE) - elif opt.kd_method == 'MGD': - kd_loss = MGD(get_channels(model, opt), get_channels(teacher_model, opt)).to(DEVICE) - optimizer.add_param_group({'params': kd_loss.parameters(), 'weight_decay': opt.weight_decay}) - elif opt.kd_method == 'SP': - kd_loss = SP().to(DEVICE) - elif opt.kd_method == 'AT': - kd_loss = AT().to(DEVICE) + if opt.resume: + kd_loss = torch.load(os.path.join(opt.save_path, 'last.pt'))['kd_loss'].to(DEVICE) + else: + if opt.kd_method == 'SoftTarget': + kd_loss = SoftTarget().to(DEVICE) + elif opt.kd_method == 'MGD': + kd_loss = MGD(get_channels(model, opt), get_channels(teacher_model, opt)).to(DEVICE) + optimizer.add_param_group({'params': kd_loss.parameters(), 'weight_decay': opt.weight_decay}) + elif opt.kd_method == 'SP': + kd_loss = SP().to(DEVICE) + elif opt.kd_method == 'AT': + kd_loss = AT().to(DEVICE) print('{} begin train on {}!'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), DEVICE)) for epoch in range(begin_epoch, opt.epoch): - if epoch > (save_epoch + opt.patience): + if epoch > (save_epoch + opt.patience) and opt.patience != 0: print('No Improve from {} to {}, EarlyStopping.'.format(save_epoch + 1, epoch)) break begin = time.time() if opt.kd: - model, metrice = fitting_distill(teacher_model, model, loss, kd_loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler, '{}/{}'.format(epoch + 1,opt.epoch), opt) + metrice = fitting_distill(teacher_model, model, ema, loss, kd_loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler, '{}/{}'.format(epoch + 1,opt.epoch), opt) else: - model, metrice = fitting(model, loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler,'{}/{}'.format(epoch + 1, opt.epoch), opt) + metrice = fitting(model, ema, loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler,'{}/{}'.format(epoch + 1, opt.epoch), opt) - with open(os.path.join(opt.save_path, 'train.log'), 'a+') as f: f.write( '\n{},{:.10f},{}'.format(epoch + 1, optimizer.param_groups[2]['lr'], metrice[1])) + n_lr = optimizer.param_groups[2]['lr'] lr_scheduler.step() if best_metrice is None: @@ -201,7 +208,7 @@ def parse_opt(): save_model( os.path.join(opt.save_path, 'best.pt'), **{ - 'model': model.to('cpu'), + 'model': (deepcopy(ema.ema).to('cpu').half() if opt.ema else deepcopy(model).to('cpu').half()), 'opt': opt, 'best_metrice': best_metrice, } @@ -211,13 +218,16 @@ def parse_opt(): save_model( os.path.join(opt.save_path, 'last.pt'), **{ - 'model': model.to('cpu'), + 'model': deepcopy(model).to('cpu').half(), + 'ema': (deepcopy(ema.ema).to('cpu').half() if opt.ema else None), + 'updates': (ema.updates if opt.ema else None), 'opt': opt, 'epoch': epoch + 1, - 'optimizer' : optimizer, - 'lr_scheduler': lr_scheduler, + 'optimizer' : optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), 'best_metrice': best_metrice, - 'loss': loss, + 'loss': deepcopy(loss).to('cpu'), + 'kd_loss': (deepcopy(kd_loss).to('cpu') if opt.kd else None), 'scaler': scaler.state_dict(), 'best_epoch': save_epoch, } @@ -225,7 
+235,7 @@ def parse_opt(): print(dict_to_PrettyTable(metrice[0], '{} epoch:{}/{}, best_epoch:{}, time:{:.2f}s, lr:{:.8f}'.format( datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - epoch + 1, opt.epoch, save_epoch + 1, time.time() - begin, optimizer.param_groups[2]['lr'], + epoch + 1, opt.epoch, save_epoch + 1, time.time() - begin, n_lr, ))) plot_log(opt) diff --git a/metrice.py b/metrice.py index 2df7f76..5314325 100644 --- a/metrice.py +++ b/metrice.py @@ -21,13 +21,14 @@ def parse_opt(): parser.add_argument('--val_path', type=str, default=r'dataset/val', help='val data path') parser.add_argument('--test_path', type=str, default=r'dataset/test', help='test data path') parser.add_argument('--label_path', type=str, default=r'dataset/label.txt', help='label path') - parser.add_argument('--task', type=str, choices=['train', 'val', 'test', 'fps'], default='test', help='train, val, test, fps') + parser.add_argument('--task', type=str, choices=['train', 'val', 'test', 'fps'], default='val', help='train, val, test, fps') parser.add_argument('--workers', type=int, default=4, help='dataloader workers') parser.add_argument('--batch_size', type=int, default=64, help='batch size') - parser.add_argument('--save_path', type=str, default=r'runs/exp', help='save path for model and log') + parser.add_argument('--save_path', type=str, default=r'runs/mobilenetv2_ST', help='save path for model and log') parser.add_argument('--test_tta', action="store_true", help='using TTA Tricks') parser.add_argument('--visual', action="store_true", help='visual dataset identification') parser.add_argument('--tsne', action="store_true", help='visual tsne') + parser.add_argument('--half', action="store_true", help='use FP16 half-precision inference') opt = parser.parse_known_args()[0] @@ -35,7 +36,7 @@ def parse_opt(): raise Exception('best.pt not found. 
please check your --save_path folder') ckpt = torch.load(os.path.join(opt.save_path, 'best.pt')) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = ckpt['model'] + model = (ckpt['model'] if opt.half else ckpt['model'].float()) model.to(DEVICE) model.eval() train_opt = ckpt['opt'] @@ -47,6 +48,8 @@ def parse_opt(): if opt.task == 'fps': inputs = torch.rand((opt.batch_size, train_opt.image_channel, train_opt.image_size, train_opt.image_size)).to(DEVICE) + if opt.half: + inputs = inputs.half() warm_up, test_time = 100, 300 fps_arr = [] for i in tqdm.tqdm(range(test_time + warm_up)): @@ -55,7 +58,7 @@ def parse_opt(): if i > warm_up: fps_arr.append(time.time() - since) fps = np.mean(fps_arr) - print('{:.6f} seconds, {:.4f} fps, @batch_size {}'.format(fps, 1 / fps, opt.batch_size)) + print('{:.6f} seconds, {:.2f} fps, @batch_size {}'.format(fps, 1 / fps, opt.batch_size)) sys.exit(0) else: save_path = os.path.join(opt.save_path, opt.task, datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')) @@ -83,7 +86,7 @@ def parse_opt(): model.eval() with torch.no_grad(): for x, y, path in tqdm.tqdm(test_dataset, desc='Test Stage'): - x = x.to(DEVICE) + x = (x.half().to(DEVICE) if opt.half else x.to(DEVICE)) if opt.test_tta: bs, ncrops, c, h, w = x.size() pred = model(x.view(-1, c, h, w)) @@ -93,10 +96,10 @@ def parse_opt(): pred_feature = model.forward_features(x.view(-1, c, h, w)) pred_feature = pred_feature.view(bs, ncrops, -1).mean(1) else: - pred = model(x.float()) + pred = model(x) if opt.tsne: - pred_feature = model.forward_features(x.float()) + pred_feature = model.forward_features(x) pred = torch.softmax(pred, 1) y_true.extend(list(y.cpu().detach().numpy())) y_pred.extend(list(pred.argmax(-1).cpu().detach().numpy())) diff --git a/predict.py b/predict.py index 059cd6e..f39120a 100644 --- a/predict.py +++ b/predict.py @@ -23,6 +23,7 @@ def parse_opt(): parser.add_argument('--test_tta', action="store_true", help='using TTA Tricks') parser.add_argument('--cam_visual', action="store_true", help='visual cam') parser.add_argument('--cam_type', type=str, choices=['GradCAM', 'HiResCAM', 'ScoreCAM', 'GradCAMPlusPlus', 'AblationCAM', 'XGradCAM', 'EigenCAM', 'FullGrad'], default='FullGrad', help='cam type') + parser.add_argument('--half', action="store_true", help='use FP16 half-precision inference') opt = parser.parse_known_args()[0] @@ -30,7 +31,7 @@ def parse_opt(): raise Exception('best.pt not found. 
please check your --save_path folder') ckpt = torch.load(os.path.join(opt.save_path, 'best.pt')) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = ckpt['model'] + model = (ckpt['model'] if opt.half else ckpt['model'].float()) model.to(DEVICE) model.eval() train_opt = ckpt['opt'] @@ -59,7 +60,7 @@ def parse_opt(): os.makedirs(os.path.join(save_path)) result = [] for file in tqdm.tqdm(os.listdir(opt.source)): - pred, pred_result = predict_single_image(os.path.join(opt.source, file), model, test_transform, DEVICE) + pred, pred_result = predict_single_image(os.path.join(opt.source, file), model, test_transform, DEVICE, half=opt.half) result.append('{},{},{}'.format(os.path.join(opt.source, file), label[pred], pred_result[pred])) plt.figure(figsize=(6, 6)) @@ -77,7 +78,7 @@ def parse_opt(): f.write('img_path,pred_class,pred_class_probability\n') f.write('\n'.join(result)) elif os.path.isfile(opt.source): - pred, pred_result = predict_single_image(opt.source, model, test_transform, DEVICE) + pred, pred_result = predict_single_image(opt.source, model, test_transform, DEVICE, half=opt.half) plt.figure(figsize=(6, 6)) if opt.cam_visual: diff --git a/processing.py b/processing.py index 09f8aff..90e0641 100644 --- a/processing.py +++ b/processing.py @@ -33,40 +33,31 @@ def parse_opt(): if __name__ == '__main__': opt = parse_opt() - with open(opt.label_path, 'w+', encoding='utf-8') as f: - f.write('\n'.join(os.listdir(opt.data_path))) + # with open(opt.label_path, 'w+', encoding='utf-8') as f: + # f.write('\n'.join(os.listdir(opt.data_path))) str_len = len(str(len(os.listdir(opt.data_path)))) - for idx, i in enumerate(os.listdir(opt.data_path)): - os.rename(r'{}/{}'.format(opt.data_path, i), r'{}/{}'.format(opt.data_path, str(idx).zfill(str_len))) + # for idx, i in enumerate(os.listdir(opt.data_path)): + # os.rename(r'{}/{}'.format(opt.data_path, i), r'{}/{}'.format(opt.data_path, str(idx).zfill(str_len))) os.chdir(opt.data_path) - for i in range(len(os.listdir(os.getcwd()))): - base_path = os.path.join(os.getcwd(), str(i).zfill(str_len)) - end_path = base_path.replace('train', 'test') - if not os.path.exists(end_path): - os.makedirs(end_path) - len_arr = os.listdir(base_path) - need_copy = np.random.choice(np.arange(len(len_arr)), int(len(len_arr) * opt.test_size), replace=False) - for j in need_copy: - a = os.path.join(base_path, len_arr[j]) - b = os.path.join(end_path, len_arr[j]) - shutil.copy(a, b) - for j in need_copy: - os.remove(os.path.join(base_path, len_arr[j])) - - for i in range(len(os.listdir(os.getcwd()))): - base_path = os.path.join(os.getcwd(), str(i).zfill(str_len)) - end_path = base_path.replace('train', 'val') - if not os.path.exists(end_path): - os.makedirs(end_path) - len_arr = os.listdir(base_path) - need_copy = np.random.choice(np.arange(len(len_arr)), int(len(len_arr) * opt.val_size), replace=False) - for j in need_copy: - a = os.path.join(base_path, len_arr[j]) - b = os.path.join(end_path, len_arr[j]) - shutil.copy(a, b) - for j in need_copy: - os.remove(os.path.join(base_path, len_arr[j])) + for i in os.listdir(os.getcwd()): + base_path = os.path.join(os.getcwd(), i) + base_arr = os.listdir(base_path) + np.random.shuffle(base_arr) + + val_path = base_path.replace('train', 'val') + if not os.path.exists(val_path): + os.makedirs(val_path) + val_need_copy = base_arr[int(len(base_arr) * (1 - opt.val_size - opt.test_size)):int(len(base_arr) * (1 - opt.test_size))] + for j in val_need_copy: + shutil.copy(os.path.join(base_path, j), 
os.path.join(val_path, j)) + + test_path = base_path.replace('train', 'test') + if not os.path.exists(test_path): + os.makedirs(test_path) + test_need_copy = base_arr[int(len(base_arr) * (1 - opt.test_size)):] + for j in test_need_copy: + shutil.move(os.path.join(base_path, j), os.path.join(test_path, j)) diff --git a/requirements.txt b/requirements.txt index 19c84bb..53a9071 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,6 @@ matplotlib prettytable pillow thop -rfconv \ No newline at end of file +rfconv +albumentations +pycm \ No newline at end of file diff --git a/utils/__pycache__/utils.cpython-38.pyc b/utils/__pycache__/utils.cpython-38.pyc index 411dc2cd10c1bb4a4dcb959d1bd4ae5c1e0b71d3..d6793a359eafa1d3e0c3710e0f3982bf5320917a 100644 GIT binary patch delta 12251 zcmb7K33wdEmG0`fdJf5wJeFlyw&mMnY{?e3!NwT*k}eZ`PRqwreRl{Fr?hje@!K$j5g1_f3E@a{-ZmDi!-?-(R&axim3Rm+nE3U_x zBGU=0O0TM6-j{YlFUl(u;U){ygh^{d{EV_{{u9Ym>f7A|>+|Sey0DOmUB~<$)D^ zwmZV?z+t}CtyfQ!`R%uTK);ZW@KYxheGC5t|0KQ_@mKk$_@_@Yy`A5|KZBC3{51b8 zzT0*w{B!)ylM26cSk<@lyZGHG-NEnS_d<9(dpr1j-9dgoA4S_;`~m)X)LzWLz`uy^ zZnXUp|1#Qs8Er2??m=FD2-$~_-NPT|Ujfoy{_p$|d^`E0{0zQb{5Ad`{Hq{$DSw2aCz09@Vo&j>LF;K#?Pa}}_a5jyI2_Qg;Lq@Hqs^84S^gZp zhxi-(JN&z7aTQv8k3YW`O`b>hI{!ZZ0T8a{9)AJfYxs-&hxlI0U*a$0dmVp;{|MhS z|1tjwzSpDqPx;T#`e$fuaP@K}^`?7oU^k1nZwEG?>&++$Gn2D(#X{Q76ou&}kT>%A zN>1xuTS-+LnM|KBGX3tF;2U)*=FQk%EOhr#%TB`)xzRbHmt*cVYpTjgEP6*7REC+u zj%cEWv&)e~$7i41`+}52h8qTLd z(^jchChm8CbVbtruNfDvkX`xUK2&)z+c;*Xts*z`rzbV|(&+nYS{Q3`zf^lk;{$}K z38TPErm)0=$clLcs@$f!uDQhgWQtbqL|-XQBfHamPu*hmTZ-FsRnzDbbx9)l!puhY zto!?!nQ9^{h(U1=fZLO7tCyr7rlPL^ctPq@e9OHzd3@(%R1hPo7#uf+$8zEfm9ezP zdc8o-%1$3!O(|b=(vC24h4f*=$@Dzx=4UNuzkx6sM(><;H>>$Jb@d#;S;;lN-o9XIl51GeVIJi%d?P%LC6#c$ zw&02C0HeP1B5fHb#Gl>Un){bPpDMD~zv(ZX&J}V_I{g>L-MEl%IIVVxCUlAw88vus zpY5dcLe!%9X(pgkD>EwRmKKVrM_JT;Zs99sY9@Y5fEY-{Wh*g@+RPwOO`whdwGjcr z3=*iJEOF>1N`^43bXFAmN_HyXcPNJwpHwq#45ATDtC-ui_}p>^$?-o*Yxr%)&t3^I zsPxj0;iuvkz%PhjXe4}!$|LmU?6!#9oT-SkKj7DMCfJ;RUvax@SX@44+d8Hw&gQdOuTPA2_Xkm{1QJN&V z9sMiz6##>bvpE=3no_DvF{A=Ar6LuSsQ^+TnX1SE`S2X&6k`>6ARqDboK%+9HD$O{ zMoZ!-%lVxql)u69X3$nmPYlLyrTWuNvY0NP1t^0ZR_tnwO<}R|{~9*XFfN^Zv1SU{ zu9W8XEMD5tKosSgBNn1{_J zOuNrAr5+fF$)(xSYZnV|tH;j~-$lc>*`~MIDr}-~uE5QLgJob68uYZCsuc;K_$nd} zZJp=of?=7q7d&PN+lv^Sr$zC&xC>ZbM3|+#kukkcN#qJ;#|s`Y`)zS2)dh|uMFoDm5bv|FlS0162|32$Jd6w|2)HHCk=>JIAoYIS?kOHmXJE^6 z*%qMKw5t^s#(#`y%>CWcnQYVOTMa*BNz#vANZJ&sGN2%S=KivA>#}WVm!)R-YZu7) zYQYO_|h%}p~Eoak< z7K*|$^0@(1b|JR9w=}h`{RMUTD*{t@t~X2SYm%3={ZP3u{wKCF_cu*z%GZL24C%SC zzX%dl_(_fQ0VJ}@Pyo&WG}5rTL%Fu?8g(c*6dDYSVE>M&gF&v21e~A~;sG8!sW`#G z&@dZThuJB1N~K+Tv|8eZMpVKaq;mW-#R)r+-nklzR31L1^2o6Iq57flV0a|TsWjj| zv$r80x?yP54MSoj*mP&k->{Ud6K3%YKGVmkZO$9kVQy?W(_79*Dy&R>YZe9m?-qee-bOV|@r05@2HdObv?R}mn4sz>0WqzxgkPRNf*>w%1w z78|L4y3pOr7R;J`5H&XKBn9fSj;(^7tY;$YYRmp1E^W~kYWk%6@nwtI!|s`7jbTm& zFS0Y_QOI1s*JQ_F_#AwjY}t$J$98bMnOuCG+B`O0{uN{Vf!g_J2;r7gF(dsb~IOKlJ& z+H!s?Il1;vNBD2G*ismQu+Gwco!kaP6Juf8A8r( z1Y5MgYVco;-y-&pt5~X)5>&>kvKmUm(m+>t zx%YK6cGZFB98CDNu_2iKoZ_lSgO^WcsXUd#ik!(wk#hbB@fZe?u%aredsEA*vL4vC z?+QI|VBdZ{u=U_}J-GECB|Fxv*VT@-__cY_ePf3wFT7tKn+Ad%Yg*Q$!S?Og@8gTw zU+*u1zIZo!@*=WkdYVh$85Ie$>7463(%NOi~|49naasO5;GK?z3*XAmybDc}vlfdW)| zW0b--PBojAqbrX#$x_x$M@qG8Ae6CMZWt=35@J3Lb7JUn)N--De|$vu1H}uP$BewM zK1nFz0{~+pI79`~{76GIRRW%tPRlcD`Zl0RzeT}Pt;4UpMje<%>LcChOSokWwd-_9 zywf7SnR=R_Ke2RgM1@QA$W_c&P<3hNuGnuv5{GW5%Lwr>*~kr2y^z7T*H*y zoGH1+DY>~*a^)t^fd;e`%KZio!pLrnFwb14hT8< zd8Z{~ACvCRK&a2@UVT2OXl5Z(&`K{kML&_An3#cB zCGDJWQuC!sl7|KH9`LGbPLUi+qrWW36&G0oq*o^nU}808r}?MkB8YVaq_%FQ6#2$t 
zEu~0X>+uRxX_z)~l8VV@>QNZU%J`x%m3~rul+10VP7cYhqolZ%+TI0{sp?5q;sllc zp1LI0SV&VrMnKp66#_}u$iv)PoYg8Wj2`w&q|02XZu6_lbiSvV_h&FQ97`#3`b4HGwL&6mXj0Y9V3=YAY>!r6cb5-OQ#|_Mq!-O8aM< zAt3|254)egVClLEN-`y*-j@)3_y=g*E-Z$^mIbLU-jEj!}g zaM7xWed)3_F&0I@pNt!*FJK`wb0f+Y77eK{}T17^W4 z3UMcuTn=#BCnV00Tqd}dikT=ix=*)nUGDeng$f7`<;649_#JTPy60_8m3>`8DAc_; z0^n5%^JrfVDh(k}gNys9LQ3mtsts`&5sAo<%i0a%8&nv?O+d+uV7YYX5Zj6FJdrW6 zriuk$cS(`aG04*n8<``=4Z;1hTGzuw{4uJTVVPD@^pjdO#W9nMyQnnWS>%17=!dsx z1>16K%606;O{?Ve!|%l0r|QJl+|}FWu`}+TZFAT!U31%l`bmjgeDTwC@JPZn$ zzhb_>JLd>_GDuSS6GV9*k&u-*9HeFm+KO_Bp4N@qCppiL6M--;uDWw4#LLt`%^vm` z&tU4HL`>o25ZwAHRWKnwP82Ap;>G3NhupBX_%-$R3e~+v;8g-=36NMPc_P20td#5{ zNG);yaPe;Th`V+7yT*3in;?WKnN4_)kJO-dv7QztmP} zlzIm|PMy=qEUd(oRQP{)Xbyu40$3#?b`-#i^!qoOegt{k<$&)b`pk5W+hSuoAxpkw z&opAM?pV(5L%0mfj-m$8H}*BegCC$LOB`d{;1fV=$JNomrCsb?+GcQ`69xDrR;gHk zIi+~nn?zlL=V<)M`GF)<{oXyfZ(;ZhmAvYHZr}NAsr%D?`E^oFkT`HWeg{#C98sJq z$A0(ZWos@bMeBv3*Rc%|Wt2g883^}%M!H)RPZOvC@T#yM>6WlycOxY1uTwu#?f%zg z4OeQ^%d3PPrPLJ!UZ2y2w#R5TC%A8ZiqB79%i>a}h;+ zQ5GHyij_EF2Tv&loWOAsr}H%;9$Y=5iknaq!4FZ7Q%ru*CaK6&uu>ALe8Z<$H_AXE zIugSPA2deFacYR4Iut1+Co1SB$}e>igOQP{Qxrmp$UsS?%b#zRhB~4BNma{h6!^sB zd)IT8D^T>~`n|^dyx_m0Cawi(g9ni|&F0 zO%XDoBBZ((9%yQpD`tGzc*)ArhJq|)^VkZOku?WtjnnF&OQ@s_39e%eaBamr)m^)M zru%`_QAF}CE^hz|W`Hi>5LHSu2eZ)|z#D+zAmT_PA;dTk$8f^9bkyk5@q>hgw=n+I zV?Zi;#Fh;K-IUH7{jk1>`PybW)5nceOf-TKF_*vxF$5e7aBWK7fsl^dh0?}I&tbyM zM+rx|i%<|!KTB&eL1_VzdyE7)RTNSK-h@iu7jxe{_!d*gj+X9UuGqbDA+@AK&gqKA zT|xy-1jYp_NS}S!J$mH^|E72mYEw)+htezwRC%S?jw8Arp;3zLVOYEIx%Y?LK=kC) z%Yn}!PIsfZbkKZvM^3-wMw-<)tN|I_0)BEGc!A8ZwwMeull%Kqcg>-O1`6@Qsz{N@ zwefqRj8$>{p*G)-SxP-iO@A}`yV?Epp|;iHL)7(zq$UPBAJ{kIHx?=ji}|B-uiC-b zE$;RDuJB4~_Z#;veQ(*ni@b*lh+lsoQcDHks{40zF`n!0HU&ZsraxNghREY6om0KA ze;SfEnpoU&PUoXSkIA`~H_v+fxS7lLIQ}*C2=zxt1^4||FNjZ!;6C^BS1)8Acb~sH z#adi;&Ckn3ELG(nTE;b+7dCMV=h*&fSTZ6v9Nsg)XC);>A$VMO6BlAGjjILHfH9W5 zJw^qxYEsYCC?Ri!eMaqjNq%nr+J&smz4_W3*UDwwLj*}p9Prh~lSty7fn4Y0_=xev zE?#mwu3OL}7mM5=HPGSY{H7~<*OrMy zT*`HtAmce(9%QNUGsHE1INON}@O3X;(|A?gI9s9`@$|||DW4%9%?o2y<5(g#q0RJ( zigR`|!J}9r+c=OH^zSoj7^39#aV=xQNTPY%yHJeueIn~@PidJ)v;p3$1|G^+Vn=w7x8trEp|}F3shb`<>;TDVxlG1PD0&9;4JWczP3|V+9f)qyo8*un$$Wd z0=lmpsGVT5T;X%d(;*TS-Qh^B(2K+%UJ~^}War_;1LsA$lfC3ns&HXwu5djwBwtQZ ziX0it%bLo&1zgMYM#Q>7IG-C8nk$aWp9XoIGe;?tr;}99L&ec`GQnqkbRZJB+n^o zHA^6tSr09)v4I)mqVi{_3l{4lVxye^h*NjX_)eS+>3E8*l1}@FM4~>aOR4?j>})~B z_sam417a`~(*za50s1l#R>eagd;s5Pk)@s~FWw;V6d|yV)M123kK|256#Unb((GME z7}7L~1#y>q*Y!<1p9eu<5t4M0!X>B-!xll5c(vFJ z^rR)9JB(+e{*y-kE8jcfxo3vdzh%2gW92R(Lo`vMMW>uacOP30-8jDI1zWV!bl$yzash}RRCO`ySj z%vj0hxW6zKmC1SY;OD`n_{UJ`ap15(4`9HCkf$Wt2&^YSR$Q7#A-4uCnW+nGbLG8)wi!kNc5w>`EB1Q< z6RaKnKtir<`o^FTmdFyyCD;m<7+sM`u}xYamei73s(P(9vp%UswQzN7b*pG8*K)m7Y5_tQP+ylIkV zyP0p!+;{H1=bn4+x##ai{+taz!z!<-sE8`?cUb98eJp!-9e7o6brgl8Lz+Px9q;^Gik=*S}{;y?pmBA;YVU<>b(>(&0!eIGa0@z3y)n-ya{|2h9G|J=>Ycn|+){&~P`;P>+f z;P-O=1wIBl-Pm&l|6+HLKgb_~QJeTL_?P$zXl>r9@Q3-yn-zX?ST(lrNBEav)>b~w zABErdLhrxuQ_y<~dbh#TSNK=q>8tQm+RnemzYc?T@W=Q!;P*=YIR7U6?&PoWZ}BI9 z`1|;i{3*cf;%E4G_;&%bo4?MV;m<x{D&}S zAAg?z2!8kT7x-!Ty&A^o@U0Pf`gae~?roXRX@}!&mf_$iUA$wx{N3|>2 zoJQ;n=JGi!4K4UPvw2K?6Yqd{T3%A0s={RiuI~kLvb>|dy@A$x1UtS=9ms@uLjI)w z_>Qk2B8nrk`6Rc5$1>s+;+W~N9xsrw(>_`J<+Bah2ZG%Qx% zP-NA*nd7G$PO$3l;#%JW@FsF18s%M$r}JY-HrSKT2wn!@1(B?$Wv#RY z+`Wn@tv{RWva;eAIOsJ5KL?Od1#unB=F`q}=5%5XdR~w(H0@$+{M!qD!|IR~FVqcN zbOlOqjESGg#}{qAG{O|6EJ$UGfCokV7RPESE9Y85d=@dya#70y7Lhw!cI=v9+6&m3 zoEOMiIZsPE?m)rv0#2VTP|RL9XB|)WXB_bg&iD6{1gWPjL+f`lsb0sBFSRtY)AGMt z=I{C&4w~HKU+@uyCH@D17j%x~k0;`unzFqxsi5VM0KG^qFYIJCGicEgLYIdxUB2>H 
zIP*6Mrm0jti6)iLb@w@$e9p8%7@7Waiu~54D@qGs@2McnI{&9bqXNF$;OjgJpcoKK zipq$>*(fV2J&db7a698*CqJyZYEPhBZBk|{Lm2N*jS_Iw{o)Vu z%Gm>pz=tSweWnt{VY+E%av9e&XEFJ!*?et6?I0m0R4ENp?P2_|2?WL2Oep~>0vTh?zzlC%tVPMI)a&%`R|Xo9b!0S`LU zV|q}m8oy+Ge9?&$S0cFZB)k|}3==PUMoq7ipI*?yj87y!8l!W;J0YHw|BzVNOr}ap ziZrMPvgjO10zW7P(Tbx^%U2T3ZKSHxR?t$(teLWR!zc$gtFRTU0Y0J|^si2CTso%T z4_Hw2N95Z}*Q|aU9=rX zM9`m~LM#X1X&I-}0%=dER|0kN>{)W9(|2?bCsVB9M0aeGh-*-v9dP;3gyDZ#$Vvi9cfRh33hq^t){< znKN^!Af#6C6bS3lHI*MQr!Q`1y3(1N! zbnJsKX&BtgNK}hPM0&!)om??j9&U|*VSjlA=&Rh6YB+>ugz%6fTM`-~DUM^3jB8o)ZZ+F=HcbUSQB9SOiQ zIM-g4g99E741wz14||% zg0|7iO0XGHn$YnhmawL5Hhg}cB)apPMO8H70}v1PStdCyd@2Zw`%lxQTScaW92~8D z%cLkOSskmJrw-1+Yi#bdAUsJ~ZPUif4^_|VsDwo^78z9-2?$P=?zJjPK#mSWY{k}v zvPUNeS6oR-$?PX%A8z_c85-0|T0}%5BC)>pkgirwp zzCt}35Ax7N@D=XS%fVO34Ua^Mp;2}lL`%2`23fZUZewl?pJ2JbuyUKqbsphS9^-Ld zk=BMZH$GfahQhpZq~artS2QU=rt%}oCVOO~XYmf{RVyRDW&>NW~|R3X&dWX^rB+Z@9I^@GwSzb#8r6!^q5HkXJ(l zI0Lic+XHvj4MmEPkwzc{qj1Itlr_qJPl6>V=Zdk-vLO2)*U)F1QGj)0%wON zfsKF>fhW|&vXP)8xDXe-2re{%0|Qi^t4-iQ)v#SkgC>rucazyogvCPGs?Xl7*cwKF z7FK*10SW>ly@@z_(m;;|4c!lk3rQpHw;-U(3)iTz`oVwtmd5mP;w?NGyrV7Tp2RUA z{4ji-8o+6~dbUypJ42uNU(~ky#9of7ulhB$OAUIR%6 z>JzD?OZlB|K<|bD&UzT2K!B#RID8hdDEQ|}wOZ9!BY2{xEnb%pN=0v=0xtV?_ld6{N%qvdcZUvR&MocLZ3+&PAe zcZ|=z`~+*oDRA@-Qf9p#J2st2WFr&>tp0-7^n`q8IM#~`-Zi9CN&p1NX7=@Zx&*@u~j zg~>mKhZ6fKYkB|{^+K?Zj4l2ivAB*;Fnsmi0|k*Uh;|z+2koYvR)u|&^s@E^xm=2KrDGc)6p<2zvNeHnvRnT#V1nycq zDBs%LvU-vVwMgkAAPviO26Cxakp>1YiW&uV5|84-Pmk}|QW;=|?A&&U-5|fbtyTBW z4l<<~JJz<+B8*~BI15q$bs#g1Ux6z8E|c%uKDZqMJ*5L;uu-H=C!%#)IVUf~KO&|P zz?e_yV);E$Vp7|&b+Ny)7s^3imlfZ{$&>rTtDPQ9~Wyw4k1PG0QHv0}$T8*(PelDZ~Zx1;=CH zb+~m5La7!?^1yL8*eQn_190P{g@EHNjC8**}i+nnm@ra9Itgg)d_B_2C96TCk@pTyJX%G~2XZpntaDtkKBjED;dn&Fyjn7Pok0Z&I z_zoChg2 zSS2rrMi8o*`P?x}xZ;O6v|QvtDS`kqS@R;L=a^|?uIXB~X>OEXc>k)M+mVlHx#fQ# z?oSsoNJ1gydK)^{Kw0{H%pTAK;_dMT`_{2w0=ndit9Q!>_Sdzyz{6OGY?3e6-{4N} z1gysf#qaQ8k|^rjPWkr!J4)|Bi&Ay*XKY<$+*PA6picxyj#e*Sf6v!aA{)ziR?!+nGLCsr98zrcmJjb2-b zZ(}PCz^j0Zhc6&6d+}|87qQT$3%I%vroatGSN>P*rLc^O?X!`$DHV4N zk0wj*{)sHdF?0kmFWhY<-M;hLFiHOqcAt@-y|(SrX%QKw2Yy5rD+}Ira@Jn?=WCah zf^ha1l-a|uGFr-Q@OAD3;3^|fIv-_2LCE2Q!)ydH{ZR%vK4gAFA}j^af4szX|?ZW7NaH$W_; zjoL_jR2gPPtsCE*g*$vTI18NmHFmPaXr`$Qny8q(@Lr@Z=N#>`EC^fBWx_eF!&%tu zLhvZgBBu~`^4%^LsVRIWYnp@+U_MCJ_sibv=BR&U@{a4~Y$D|*K+0WL)@rbixpdY7 ztC37wDZFA>Bv2@@tz@eqoULObs>-7)Hp|2ZTdvvyo%oW%R6x#($%;@}F+t)0A^!## zbVHbF!`4|bm}V|y*&~rrNSbcPk$xpb?Z9YD-ub~xN-^Kmhhb%*BZwpLUsPG(3m}Do zf;LrrDNLC3_7Uu$s1S?+*7gGq_Yq%i&hzN(A>xwIsR3ct<^}$a&XWORJCv zJh8E|b*(`}1i^#^DL)#LC$3*xLD#n$dZ9e_g#5$xvy(9(>t6?ACxkQUENP4N*h>zO z4%H8D5J&5hpNxTxM=zR#FP*LzNF8gZ$|Z)gp0p1(He#_3&YJd42TrC9&ikJZw)>$- zEiR1bEwC1IT>k1{`*w%|FJ}F5z!X7E?lT57Z1|Gx#d@NMsq%%|IpgM`#f*JO-g?7M zZ4pj>O}==;o)Wy&g%ro4OkR8%@nniGH?rtcS5mNR0uaSSNFo&FSx?}?45EGcGBgtIO$~s>} zX0YU@Kx#V|(o9H2T7a#4mw-}J-2wfRe<_&w5fYwM5p?6HwXXhiT0x`}Uf{3M-7UJ9&;hh#HvEIoo;#fhu9e@%fqu^6zx)d|2SSSPsEF9I$6 zeHt)mTJTxCjNmC8a~(Le-mH}rIeNqI3$g*b>A=TnAcAWV)B|--r+;18(3ZO`ru`*T}V-GjNx0-JtE)eS-RVZ zf%$da0eUS(We_iz$`=L%+9N1yi8SQFki_7Dc!7Pp_K|DGvy-_@cRtHST;_Wl z7U2tLfp%M9=}O=$PB8-kmTm-=vP7L6>s_j$<16e0Yme6Jb$X(qE*h(x mtJmsU5C1Q*6{t)A diff --git a/utils/__pycache__/utils_aug.cpython-38.pyc b/utils/__pycache__/utils_aug.cpython-38.pyc index beb69354819a70d4d3a1ab02dea006007ca90b0e..466a95f9141d65bd8e0322122c9e669d95ff232f 100644 GIT binary patch delta 1016 zcma)4&rcIU6rQ)cUA7A?+jd)8APD#aH&r4Wybu!z31UckF~+0^*Oj*rpe<&0iw27o zOEAVjgiSQwAS6a22TVN>kNO`NPraFV@#5d$n}Q~uaFX}EA3Ja6n{VH@?9ov4RwSY` zT5t4)>AuWXv>Oh#vZGaoP|)+Os9+`#_3BO6KGN;ah+ijQ)>)Uu33}wNy0myKYT6A 
zpaql4C+Monp>2TP`YV-7FdT%#M=-hfISepW|Ds=$zzZO<1YZ4NqzB-3{dKeqPzxeP ziLhcE!sFm_>=$g-yW>Lu`#~=85w-)_yxY4Eyg--o%wOoFsL^8Ep3iN2FQ}OV)bQTS zs#URNHF%lqsPb8gH^@AdKm%RTV$Iq#rT;96vQaYc9D&R7fGjT|phT+3D5FB2^hK3? z2wO-6t#M|YyN>VLV}*O=qFeGEe{Qkl*%N%RXy=@wJ6#p0{#h~w3fUBy*Go`=3bUCH z2s=^*cA*9{tRiiSs*H?okaEk)*x;iTOKB)2Ic`eeDM_W!IgNb$zLcpE?O@U_-qFR*A?)k9oc8| z#CBe(&>J>Ls|nTmv+Vlaf-N^g~&fe$AB!FbZy zY43%#2)4d}g_T5b6_wvqfF~9q5^_*W!F62=dsL3Jom2}FU9wBM+ n5@93qFk2+yA+zBN)G)Bdiec0-NIHTXL!_ga!ww9uK{19AjFKmgF^rQ=U^gZ( zNhe86VVXRt2MjZqU1OMa(yMAQL;FtbqJ39;pGA1W%${Geh=K{Nz@v&i^Yvs&1@0HL zW9Vsy0^0!n&5ycJ;IF50iAj*_fy`iDo)k2z1jJyYfXsxidC>H&QwE2Jzpvr8p$L{J4+qv%wirO>?(EG zgH@Pjv#hRsf=zj%BBLxyp{9)uFD;u|ORG}MmgX$lWf7+!t&k_|syjU|DyDkM&O7PW z=U&-j#wDumof9Ux5*30kB(AK2fey8smo(rHC$sGyBzxoX=`8|}?Sia#89ZfI^;n6j zEopV?V}tn0(RnF=>V~o>XGLp|$t;bDBRb}RCtO9c+$vnm4}TX9*gJx0Ah~qLI3~& delta 359 zcmY+8yGjFL5QX>T`e(Dd*_*i`39(t0KwvFwL=Xi77B-?_Y!pWY3yT#+0xHBbf)>ed zW2;XPHie|KwGwL!^#ObaXYCxA@64Rxz)XVorm?H*s>1blyuBBc2F5h*2hGK;6s7x0 zN2Ls9S@R?)$Es0@@}%|A(MczR9fJyFvKFXJ7TJ6;$sv~=tF6!!75fSm_wDOs$04jM zer5D=lSEb?q$a3Fjr0(xN8?PGKp?%`Guls|Ibb#VQWp{kWnB*uMnk;@Y(^vF6zECM zJj(a@<8Xt~ZOOoF$yf6n7joJ9!Cf?K2Y@fh8Dk(T?hemgw}wag;Cej2xXq=@IDQqQ z@*mZBJ!IT+pJW&UYD!pto{zA>j*m6pmEWZmIbWWWXXTk%oHC(7aBB(Sh_tZPqDTr& H^+e$puK`Zp diff --git a/utils/__pycache__/utils_fit.cpython-38.pyc b/utils/__pycache__/utils_fit.cpython-38.pyc index b13f168673eeeaa692f0c0156c2b5c51bc379087..c00582e604daca74871ad891d6f2cc6ea4e9dd9b 100644 GIT binary patch literal 3414 zcmbVP&5t8T74PZ~yWMtsJRZ-d=X;kB1KC+N2@nWLmYq!q$*w|X0?3-h>g{UJbdTNb ztuAMGLiK47%4M^2=F~n$LWsYBBSJZG;8X{M1QHBqPJ4m(%HA2HNQlrQ|EgZqdsWr1 zs(!zo4_w#L;Q8nKvH#+Nru_wt=|>0S4fv%8092!-r?tl=X%l7ZZ39ZZXZEeOMKo$q zvsdUB+r_@!wv}!5N`0s8Kpo@D?J_OU;?K2qh1#?PrAx`4)^g5a2JLQ(NSg-1z)uDz z=LA|>x^(*Cqv4=KJ?_bwaouMj4_KJQ;P37;FN`~P0?tA|@XI*Pqlg>z~gm{~QO$!giG z&_os1QJ<7^d{*e9F6It(F(YbX_SDK6&{sj9c>P&zQqEoYx~F>9i0QML&j`CdJ-4#| z%j}a`KR+X`OUQM!u{{!mbLEF9{8MA3z<}`fEX=m25t?VYiwiEt}`Xe08Kf`77)j z?w3Wa>HWI*`~PFNx@x!9&UDs>U0xq)2UfO_EoMvEa<(Gop>2Z}#FAJRDWBu!!Ac!%vt(}}cF-|Gd^zVG?R z4?RYuneqU}8av3+M*&MZhs>kW2>PBhgAcu)tcGznU_tB$oqj?gn!~GgI1Fs2;YWcq zKMaG1((DIbERDl}x0l^KI)9qW1Ms#fhV2& z(&%}Rp#ckM+d{k{Ne}1^7uQkmrr%FxA@(7+sU*H6N0L1JRAb+Vg+G0B=P2n1J4aEp zvm2zxJQ?f^PWVw0cLpqp0-vLaTlRwA^U^fzh5_3d^01eVvC|FtwE?UPN)C&*VbJua z02;C3Z=I~`by6Tpq@nBvadd32>8q-xMjXQ?4WmXXdR2Fnj|uexvBA$Ft18|mb?|OL zJNW7iXk$Ns8GIM#nLJmZykO;z;fw+UJY!`W?vSf)c|oyq`GS@8d+)(L2i_D_U}9ig z#mY02`k0lW_kxux$jWEr$tQ5vDp6T6@@Kpfxp{TM$e(c-3wXAo7#Y|&t7mh9h#IwM z;jw;TWQ}Y-TL2EcY*7pWOl|-d%py|*L$5##vUoGThPBT{U$&@d6*~iry#l-o+(L`T z1z^3pFk#GPXkU##p7hwN=OyTA(9*I7>;tGgx)Tpac^oN^vhn~t1x?#m8LkL zCzethXQDDzzM^_`P#Zm87D}%`ZFCs>7ebGk*&3(~L6ynYKyB8h)TT&hL2WM1p;DVK zcnY7O+Ek+TGlSiL)4QS6Wb$j*D(f9+-2KzZdEQ-ZXuU-l}k5L`BOHkjsD_ue^vGM=OGpnY?{Ww3X^ zV0RGSMHu(B%`kn3vI@3T0a{mB4eO{(Y!;!8Fo)1Um`8wbRE;enEFml-;71-?MOZ^v zN4V4?V%)299Uq z8rk>IzJ;)ha2w%Ggtrh-i^lLa)a1-WO*%A8L2G*0nDziPr5>m%4W%h{PzKNt(3m=? 
zMU_!*74M!IjWnbGTJ5fOwf`K9iJ;;Qfdm2}vGJOa0=ojc0ZL`p)L1>*)vPp< zds?WnC*Gh45W@>)julx1u?l-UB zuRkf5T>{UeFIjK4LdZW+Sbhu;Zon7+5`Ym#hoqa8v`e*YbWLV3b7+ln-Q39T+LVB7 z4fCTyw*d7Vv;Rc84$HFwlrE!(q~lzGHFEcY0(K1Ljs{Pr6OZ}ASJkZE;{y@!!8qzr zRX*bWAoA`6f)9E@PcIn9!_$JV2;VLE;;XP&Nm3&1Ii*;mb3%zQWNt@Xh|Qb#%=jC#hwt`oY`WMKLU; z{D%CN{DDX}t&6H`gpC>1yK~bOn5{lHU`6?~DI0T|uHt?{T9by>$*1eO z#YkINKD-V-+0Z_*)6F?~_Al@bdc|V7uX*L~-_R?e_DcJl^1s2Jw&vtsE?G&M$!fBe ztm~a^Gh438HMu^w&=<0vZp#(fWcJ)lcUXRoZRj1n>sRqccpHXG8AnqwE8b`suA)4U%>VxNHR10E} zX|_){_N`LR$r9rMbFCd%f1#Kt3KVaZLjmg^74zTg(ti_;05lE52|;}}+q2d1!S zI#ap0=MMu_c;NR=AN!msD;5Em%oLRKX~4(cG54V(c;pXNK3a$c0sHU-7_;KRr-8B_ z4T8tY8U=o&%;P|GTCl`13x>ECnAjOl#9#!PT$S~<_0f$3p-S2kt_vmS)~|0K9(s2_ zxufiZ+rRth)@^0$HFD*~C*#MSI2nLzOfaYNqwIJq`b5z)m4D)8PBG6s@Yk95K$$}y z{4(JIQ*@||2eGoF9v@F)MSF^#DEjy_!Vkbp&)?ZU8IOYf>4Rw`ru!#hxIa15?th~< z^y7HY9|U}VDh9(iL$5y&*C#M3lmb>{gOM055$e(cyjyhBXwVjI(h{xFicz9Xqee?c z#lU{ofN!BwpbgLhJ+0HEI|ly+*y)wf+Tyta?P3;bi9v>?4e5HFIJq>zZR7MdFR z$eJ<_N2esL(2_YwV>#)*NMq;5{UbQCkaQs3=y9@=m8q>08T1H9W1*9}vKTsZdQP8x z0kajuBBnApxUMLL<+PF|Gc8wT8In~|r!z=pNgYxamDRKcI*`+!8u!fXki!B0dMh0I zq$TRofFoanL*0OMw;>wRoKcoPwfS>tLQe}4+$LBoqea*mq>t#@3;F+M&{~(RdIoL3 z0#=r6!^*oatw(=e%vw0Lp+`c`3UfAyTtme9qy9p})f%5`jiS~7oEcg!X&G>4tP>dX z)?!RqkNL-|j9HARzzBo4AWG#d%sw^BDZiY!vn^rL|Jp z1Y+3;R~NKU3Q^uf8N5}~iLnHPP?4*eDx7p(Zph7b0(7BCK|1M%v~=t_ND2@YR@&Bd zaRZ_-qtHAibxju-eKxBvN8f#^ zNAAi!O&4w1PIu*&+-45Y#U85yU9{ovx2693Ko?6<^s;@~OWG^iOW>78=9Lh=($+H4 zh4Ju`S5__ zp@+J0`0)jj#s#Kxly4q&O#UGV{38UsT4`1Ri`F3i72}v4cp0Gc8n0sg+XyuTBp6;t zXdu9yjqoPID#9AVIs)E|_y)ozLK^|+=UWJxOw3sDZsC-9+!uo}-^Pv|gk6L^1WZ-@ zHH2>=Tt&Es@EwHL5jqIpMfe^-=L%3l;P*}f?rF;CmSZtxK_oJfzX3`r50fI3@as5k zAK^`e?<2g0@B@S&BD{_84#JNRevI%g!cP!>itrvl*NyQhRbg24WfFs^buV zWE7!kWCv)Z2}IGPYw!Yr0BzI^th=-VQU!V{^cs9ujS77oWuO@(lg6tM&BriXrD{0?PCuthFjkj^rw0ShKYrA>XM$p~1sbE8$6-T?8m1blP zWqVT&P+)V&#qb_;93Ary^w4t;J?6|qQM7=Q1$yksm-hQO6YVY>6mBKxM}Fi-a`^G% z?>BNQ701BO+Cg_TXBhv`#u=bczXHE#2Y?xj42*VKk~YycZJSWifi<+-HZhpVtU+#= zZ|8@Fc0t?rpg45e4%Blj|EbX~u>vbXS!QJ0XgMb^vvI3MRDpTks2h$BPYATMsB(t= z2jfwPxx!U5X}!z)!sGog0RPGkcl$x-rYCs6>vhwO;ADoe1V6N|!7q9Yn~+9g$lQbw ztdWU93_+!RL}Evnz9~&c5=#~)gwcoex^YbI{NXl@OL1A2(s7ABF=RQb=OzXiI0_#0=}*7ko{ zJzMqZ73;Co~eBdEUu(Ta9|01?ISZ{a7Mq%LgP@Eb=&N8Cx8u8PgW+d#v=>rcvcp(&z{wtzy2R5 zHPR=wdQAE6;DpyE#%()p#&hv}ybv$qj+1qklk;*xE>0|cIy2=W|!21gB1APAW&XWqdd67uHPW}xaZbH`%49N8(>}p+5RD{YF zc+GeYI+Z2vvZ2cLLOyhbvU`IN3TqGsd*BxG?!L0RW9BN;9gbAN9gDE*M&QU94@b(@ zcUaM2q;lL1SU6M`mfV85xs3~y#v{1yUS|+S5j;5F{6f}rsO+fg4m?%Zb-M@mUCxvh z2@mE<9aQlXkB6N-?m~xm&mE|0Kj?Yf3%Xus7%~_rU{@Uhb6QdV6Hi(9`rdtI4Lvtd z=AI{7OYn%`IKVBz`o(Z0`a|&Ks-lCSsVMS@z5$>`^DAm?ar-V-Bh{jH-7#8 zwHqp@AC)U7+7It{#C{(lWWsbRKTPRLkq3$#s{Dga8Wi)e17SPt>?(8MLR?1NV{IGp zvLXXTcO#Vxx)8-kk**^9iroLg;6I0$J$Y?wKOA~nr$VJ*DuTuUV%-t>H3fboZlv95G67o*-Gw?6mfsCIe z0k%zBNJ2Fem@*G{u`Dd#k~T25EuAT2(>osn4*;jiN;)n9mX_75p4W^Gy#n~!caoAU z`7jn(<#SlEl(F^Wvy90f16x;sgG-vNfvMx!82NUF*(~>v-ZtX~at5%)X^Z9$;8H5* zWDQvy*n0`aAj2;Qm$3HK=+l;zeQ<0Q*nCOUrFleH{vPlEza&j)UxfCR;Ipi!pnE=o zo(3x}8FC&`d-O67M|*s#JxbaG@Q7-=tZl#}x?sSpm$NYyJ?4+kGA0{Q)g%5XS0IuN zHgkZ_`}eetxM-{xfrYi#ah=68>rif`qp;^&m~QzT)0LB^$jK#tF4L8=kM;`M5Y?Jy z(CjRM*>^#&=i&U8QFnrJX3c38DuV7T$(02Ilql7mVzMMHoqxqtccg{714?uSGBJH) zC?u;>CBhucvDvdZsFmn{%)m#eL?wUi*yJBT=B{fcT8-D@^>`!R#C!p@Dav)ZAvd)W zEz9L(L$1m-Rs<#5WHX>d%aCu&Y0iKWW!ZMx_EcQWh@X$Ej-!q%#Idg9SoU@7w2ewc z@BaB)=K0gPc?iV#LZ(EGTO4_X7ZC8O^Af;Li?j=Y=dn)DbKw=v*$#VR$K^=AM430= zz0ND(26ejq(FwkMo}eh9eRZd0^7p~uA0m8!koL8$e)J*Z$S=GK(0Y#7u>M_ySp?KB zUPovkG!fu!*WmLA3kZt{_}a)<5LOY^5Y`b+15#f6X-$!^E0ED(8Jh!{=aV=DXb|n&H<^>`(`4%{-JS>W`1ZM~1 
zUO~8w@I!=G5q^a58p4keeuD5i!W#&0BD{t0Q-rqx+D?QAA;#Q`V9ja0)9bc#QAcZ# zBc{#m%4uD@2<^hJ;C%1kd=@?M+XqCJKC5jSA52?~Jtv`!2{SsDdRnkPj z3{iG49*z!m0>H}*{2SqzhN`FqwBwHV&Lo4a8`JbcjpI1mnO*Ar>a7;uf)|VNzfq=XUq) z?sFP$v6wUP{Ok7(@4wTA@oy@OKPC#-@P_XLe8XoQqn@O!&UD$VTS(20-A&a~%AXJGbpEg=XXXu>!*y{HhU03kd z-eZQAHB;rft-ax(;k&|B6G`6VEfH{-dKP_t^w>xpyw~xDzXnJn#&}08W+NjuM~1N4 zwqz&keABlcnx+vM5rciow?_;%I!$VYBMr$!T4bbozz&$7dT2$Kv@m*SLs{Nn$Y~Uq zg6;0L!pHx^gM!m>w}MW%mQ|+cE5{Fn>+PsiZ`d8|E3-FHj?1^Z?w(5XAlz{W0jEeS zyVDAVN^!Th9jw`PXTKr%Fi=PbJn&Wa`nzv!ZZy90jqKT*+O=87L@x;hDU{OVsC6g67A&C*>6ORfXn_(v&oHB+~J;t}){oo#x`I zn2=NL*#qoCSmwnXMp}`n1M3YV%HO>r%=nbdjaZaJZeHf$b72HKLf@G+7K~{6E8;p6 zVcc^}`31yjd1Ty4MYGXdbSj#U7Id^$ z{ER#$=jFo4rufQAydr1ioSzw4@hKmxg*uA3j^x*haOJXbXoj~>#CYbp*i}!&u6oGs z{^`UjB|bhK_o1UZgYnGh=vI%STOEsT4KvEgX*nZn%@j^_ZsRkiQu;hTDg7o|xy9X9 zZ)H=s+W{JEKL|Z#_xioy3py2s@%fuyy#3d|efs4;A6$J>B>Nx~i8-eOl`z99r=z1% zf}HW9&OZovKUB83>vxsY>~nA)&VIkQ4Kq5q%Jzo7t1P!WP#Je9`kqU?p!rWZpa;L- zRW_xZcu%F8ojzzYrSl=qb)Yg^uD5&N<-W2*5rA-p1CVH-veb8Do4YNCzDn;kl9*Wg4TNjIv8Akzi&1O3WF3`FK5#t`SsJ=_lh#>yF}FhH^gP}l zgo=5J?I?C%C3D0lnqCvA(ooYNPJwY1PzW~w2FtMw7;}lun-w<4rdffNS;;K0m(7w% zZ5gx7D#*>7Rg?;>rY$q90$W(a@&YVgg`atoe*}*nIe+@Rgv>GT{sB6pc^6EIQ+mX^ zq-eygs9~gX#KhnX(z>)vCtPeDv1(KjsWuoIeC!bK9P{xX!4hCjIg$91s5$uc5bMTuK;OmnSse<~1izMwMtPng%CK#%1C)KmE|W zV?|)GXcp)1WUXczF)2-ON<~h|X>j&}s7MRSo4E(3<0WaKZ5}*bgU<=_2_9PPZ9FZ# z(1W%)uGKwg|7G|%>AU^R_Y3U8!rzXwi>&$CX#-=(g8vs}ZPoz4B>bO^XQZuHBpX+? z64>#qP7CoI^k|`d>PQiax?DRhm&81DYEh=OLX_^_5>~t*3nQk};)qn0)X|D-vV{J? z{ZKlfpVMlw^!&=hemPpf&d?eyGV#)}TF_2#{QUS%&}#8b_5rCyzP${!_#7*-qSazG zIvt&f&PM0-daU{dc}AX<=d@ZZ%Efq9uFBJX0cx?}V@A;;)?!iWwSZcTt?0A*7uLcG z3n$iM`MJ2RoQUJfp?UX{uki8yxDOrSWsGN8M|b5Yx+`PRJ&)1RU1Lcu%kxbKdQsZo zb2y@Wo?wB1Zc2O!aBGd#Grb`28_mE)KDXI#id(q!1!DUzf2PO0#+S)@h2WW7oYr!| zzelyJ1f&@JG{G5yvjpb|&Jz$t@RteJ2rdwOi{RS?Bo6!)0xb$X_FPo!||E?-Trh;D-cP0BbLS-veB11K!a5UoVDY=m$NK6!}#csr0U| zS(m>_-(4eki-6dl|A>IDmHY<5O@bc-o(c!ibz|4(KcR-V3Em<2DZ$SO-X(YsP|t-F zYB9ux8sj+5H=17E3HANCo)gE__2RR2{s9C5qv;! zo8Uu&UlGuCd`vGUj`X78<5u75bZFY)7QoQjK>AQS(vBL`fbQN<4`@LNYEh*W>Hw`c zRua;Z>ethfPteO+A=x8(-;l28y;PqFcRPX0dl$RV;FcFAe?7ALA40<&ZYR5jHcYY+ lNR~MH8=Ie=sTRe|6)gR*6I`V~Y$5eh)=F14Gfw8U{{zCgni2p2 diff --git a/utils/__pycache__/utils_loss.cpython-38.pyc b/utils/__pycache__/utils_loss.cpython-38.pyc index 1dedf35f6309fd39f254f1f43d36e4058a917e08..ffa1b61e1ef8722200c6ba76c5a13b54386b10dd 100644 GIT binary patch delta 1948 zcmZuy&2QXP5P$Fa`D=HxX_8H&v`MK@3Yd`4rV=2vq(Y^9HSD1g94sVjwa=SwvTJXj z9SV_Gdq|~nphCenKBVLZLJDx_g7_19sN@SAkWi7haN`6s&rUWhJn4N{g&B?9Z``K{*fmA&d5{djNj%DckY;I8>P5vI`Y5uy2HlWkhNnrnm15*%dQ z&m|?0*g^-1UP#IyDG384Mj@#Px=x((-d8V{GyB%7GVHB|QIvjXocaE*n{*x{kYyka z&v@i_TJl&T$I}K|qt*0~eM#+fOJN-ltmnx`A|7Ed^r_Uof5j?&wq>IMl4VA&y zw!DC%r;t32#6j{b5{BeBH@S$fLr+CUpBw$D0t+~YU<)i4&9`r=n2bCz{n*(skrlMy zqv@ek{wK$uOouL|f0#cGrqPOKccqHv9IE48WxR7b3^wEwz&myx zc@A@W-M%&ZB&x7h=LAeOG+js!?UnOGnB*mtj$-*~X__|oCTasXB+*qMK5?laE)%8u zq(udT&aqw1=bJii!M|WahpynRA?6TvQDJBz@jqISzejE1D%%9#!+Ryj3=u|MsE#dmtjX}OK6oE zbZ7FoSC>7n9kgT5b68%b7eVPLN`(PI#ISI#!;V}aF-hnyjp>#p-%99vtgrQXLiSmr zCG3DHjw3fi1g(+24iXK{_yOx13A}Cc#Hf)O$W7@|)Lkc-tHbIW@BvuATyyknz~c{l zK3sP@iX|qPjW@$?Z#(wAt&WEPsEi%+WDTI2(GEcCQt@1I?QWwTD1*#AvK(G~UL9cf z66^@!CbU8=IyZnvmF-Ln(Q1YAZbOQC6Gz0+LVyrl1<*&p#AK7gW(gHgFAo^-08P#d zP4EW**s)F_S2~4$p`#BHO^_{H-cI;F2VOs5lqC8Ip6}}i1kl&jQ%`^c)Qz(QkTw8M zvo2qNfK(@DJO~3Hn%0)S*lxz{FpwDfV;nJ#p&BzY03OkPKPyA;d4niy#n{NI9V^{N zwB-x0)rPSuv0e+{Guq4O9oOnLgoiq=(NkEv#-03CtrgVYMruC9@MK$3t9mQb{W!BYhD7HKPDn)Vw?? 
delta 886 zcmZWnO=}ZD7|!g>ekHpb%{C>iNv*_+V@#PM?W6%XbTVs;y?`%z{M7NJ|7>Rkjq2p*&te?agk>`}Z3o;-Nfd3Tp=wR?D&dFFlId7gLnd*f}@x?`Hfg7`lE zyxU*bMphkdk6ITl4RA-jQg9RpKQB1gEo@4g+B4kI1*bYhaKufSqF^*f7L1%RB?sLu zP<^y_o_|qhKh)3~YUtty2KP`BlgRZp*o0vg-<7T-laHlWNaGK2TS?FOYh1T2p^0%c z^kd??y@1NB%*S}HRDp&G*!(wMQWdek$yDAZm#GP7O%R&hZr|%gQQ!5vZa1w@M2Dcn z<}b-krquYnd~F3BCeFk?2CNAH^oyk+U^8MtF)IU^`%ikMR+4J|P7d28C`JD$+HkrE zu+Kbfn|xJitt1KF=31E_qoq}!NDNFQ8*|D=(h;v~!o8_u{V=eKd zq`~A#vpmQdP>}!Q0m}|#{H9WLnj6~Zr&FLpaVZ%3>9F01Re_UY>^}$u^tRcwAO=_8 z@4fi-^~v&&KjNchR)fy707n2IWGsvK^0(5|iL*Me*b4_S-U=VH#7zUr;jLogO^3mN z&51SDQ|IZJ`91yC%rPjDHoSdfvO(J9WBu%WJ|?yV?15PxEmm!6q`tIAA?x2?6hA?W qcdP|e^55e&`H6cNvLY1plvw3{*jeRFtL;SG2!mJEw5~ diff --git a/utils/__pycache__/utils_model.cpython-38.pyc b/utils/__pycache__/utils_model.cpython-38.pyc index 7f25fea22c7383e2254140e87a520ffa8e859e25..3110170af2317adcc3d22e65c166353e174c5eb8 100644 GIT binary patch delta 24 ecmZ1@zDArgl$V!_0SGKMa{6*J-kKcF9SQ(Ix&?Ru delta 24 ecmZ1@zDArgl$V!_0SL@Da{6*JMokXq4g~-{F9i1h diff --git a/utils/utils.py b/utils/utils.py index c8340af..5ed5ee0 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,5 +1,5 @@ from sklearn import utils -import torch, itertools, os, time, thop, json, cv2 +import torch, itertools, os, time, thop, json, cv2, math import torch.nn as nn import torchvision.transforms as transforms import numpy as np @@ -8,7 +8,7 @@ plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False from math import cos, pi -from sklearn.metrics import classification_report, confusion_matrix, cohen_kappa_score +from sklearn.metrics import classification_report, confusion_matrix, cohen_kappa_score, precision_score, recall_score, f1_score, accuracy_score from prettytable import PrettyTable from copy import deepcopy from argparse import Namespace @@ -19,6 +19,7 @@ from pytorch_grad_cam.utils.image import show_cam_on_image from collections import OrderedDict from .utils_aug import rand_bbox +from pycm import ConfusionMatrix cnames = { 'aliceblue': '#F0F8FF', @@ -162,6 +163,9 @@ 'yellow': '#FFFF00', 'yellowgreen': '#9ACD32'} +def str2float(data): + return (0.0 if type(data) is str else data) + def save_model(path, **ckpt): torch.save(ckpt, path) @@ -185,7 +189,6 @@ def mixup_data(x, opt, alpha=1.0): raise 'Unsupported MixUp Methods.' return mixed_x - def plot_train_batch(dataset, opt): dataset.transform.transforms[-1] = transforms.ToTensor() dataloader = iter(torch.utils.data.DataLoader(dataset, 16, shuffle=True)) @@ -247,37 +250,6 @@ def plot_log(opt): plt.savefig(r'{}/learning_rate_curve.png'.format(opt.save_path)) -def plot_confusion_matrix(cm, classes, save_path, normalize=True, title='Confusion matrix', cmap=plt.cm.Blues, name='test'): - plt.figure(figsize=(min(len(classes), 30), min(len(classes), 30))) - if normalize: - cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] - trained_classes = classes - plt.imshow(cm, interpolation='nearest', cmap=cmap) - plt.title(name + title, fontsize=min(len(classes), 30)) # title font size - tick_marks = np.arange(len(classes)) - plt.xticks(np.arange(len(trained_classes)), classes, rotation=90, fontsize=min(len(classes), 30)) # X tricks font size - plt.yticks(tick_marks, classes, fontsize=min(len(classes), 30)) # Y tricks font size - thresh = cm.max() / 2. 
- for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): - plt.text(j, i, np.round(cm[i, j], 2), horizontalalignment="center", - color="white" if cm[i, j] > thresh else "black", fontsize=min(len(classes), 30)) # confusion_matrix font size - plt.ylabel('True label', fontsize=min(len(classes), 30)) # True label font size - plt.xlabel('Predicted label', fontsize=min(len(classes), 30)) # Predicted label font size - plt.tight_layout() - plt.savefig(os.path.join(save_path, 'confusion_matrix.png'), dpi=150) - plt.show() - -def save_confusion_matrix(cm, classes, save_path, normalize=True): - if normalize: - cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] - str_arr = [] - for class_, cm_ in zip(classes, cm): - str_arr.append('{},{}'.format(class_, ','.join(list(map(lambda x:'{:.4f}'.format(x), list(cm_)))))) - str_arr.append(' ,{}'.format(','.join(classes))) - - with open(os.path.join(save_path, 'confusion_matrix.csv'), 'w+') as f: - f.write('\n'.join(str_arr)) - class WarmUpLR: def __init__(self, optimizer, opt): self.optimizer = optimizer @@ -306,7 +278,26 @@ def adjust_lr(self): 1 + cos(pi * (self.current_epoch - self.warmup_epoch) / (self.max_epoch - self.warmup_epoch))) / 2 for param_group in self.optimizer.param_groups: param_group['lr'] = lr - + + def state_dict(self): + return { + 'lr_min': self.lr_min, + 'lr_max': self.lr_max, + 'max_epoch': self.max_epoch, + 'current_epoch': self.current_epoch, + 'warmup_epoch': self.warmup_epoch, + 'lr_scheduler': self.lr_scheduler.state_dict(), + 'optimizer': self.optimizer.state_dict() + } + + def load_state_dict(self, state_dict): + self.lr_min = state_dict['lr_min'] + self.lr_max = state_dict['lr_max'] + self.max_epoch = state_dict['max_epoch'] + self.current_epoch = state_dict['current_epoch'] + self.warmup_epoch = state_dict['warmup_epoch'] + self.optimizer.load_state_dict(state_dict['optimizer']) + self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) def show_config(opt): table = PrettyTable() @@ -327,6 +318,35 @@ def show_config(opt): with open(os.path.join(opt['save_path'], 'param.json'), 'w+') as f: f.write(json.dumps(opt, indent=4, separators={':', ','})) +def plot_confusion_matrix(cm, classes, save_path, normalize=True, title='Confusion matrix', cmap=plt.cm.Blues, name='test'): + plt.figure(figsize=(min(len(classes), 30), min(len(classes), 30))) + if normalize: + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + trained_classes = classes + plt.imshow(cm, interpolation='nearest', cmap=cmap) + plt.title(name + title, fontsize=min(len(classes), 30)) # title font size + tick_marks = np.arange(len(classes)) + plt.xticks(np.arange(len(trained_classes)), classes, rotation=90, fontsize=min(len(classes), 30)) # X tricks font size + plt.yticks(tick_marks, classes, fontsize=min(len(classes), 30)) # Y tricks font size + thresh = cm.max() / 2. 
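+    # annotate each cell; values above the midpoint get white text so they
+    # stay readable on the dark end of the plt.cm.Blues colormap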
+ for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + plt.text(j, i, np.round(cm[i, j], 2), horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black", fontsize=min(len(classes), 30)) # confusion_matrix font size + plt.ylabel('True label', fontsize=min(len(classes), 30)) # True label font size + plt.xlabel('Predicted label', fontsize=min(len(classes), 30)) # Predicted label font size + plt.tight_layout() + plt.savefig(os.path.join(save_path, 'confusion_matrix.png'), dpi=150) + plt.show() + +def save_confusion_matrix(cm, classes, save_path): + str_arr = [] + for class_, cm_ in zip(classes, cm): + str_arr.append('{},{}'.format(class_, ','.join(list(map(lambda x:'{:.4f}'.format(x), list(cm_)))))) + str_arr.append(' ,{}'.format(','.join(classes))) + + with open(os.path.join(save_path, 'confusion_matrix.csv'), 'w+') as f: + f.write('\n'.join(str_arr)) + def cal_cm(y_true, y_pred, CLASS_NUM): y_true, y_pred = y_true.to('cpu').detach().numpy(), np.argmax(y_pred.to('cpu').detach().numpy(), axis=1) y_true, y_pred = y_true.reshape((-1)), y_pred.reshape((-1)) @@ -385,58 +405,63 @@ def __init__(self, y_true, y_pred, class_num): self.y_true = y_true self.y_pred = y_pred self.class_num = class_num - self.result = {str(i):{} for i in range(self.class_num)} - - def cal_class_kappa(self): - for i in range(self.class_num): - y_true_class = np.where(self.y_true == i, 1, 0) - y_pred_class = np.where(self.y_pred == i, 1, 0) - - self.result[str(i)]['kappa'] = cohen_kappa_score(y_true_class, y_pred_class) + self.result = {i:{} for i in range(self.class_num)} + self.metrice = ['PPV', 'TPR', 'AUC', 'AUPR', 'F05', 'F1', 'F2'] + self.metrice_name = ['Precision', 'Recall', 'AUC', 'AUPR', 'F0.5', 'F1', 'F2', 'ACC'] def __call__(self): - self.cal_class_kappa() - return self.result + cm = ConfusionMatrix(self.y_true, self.y_pred) + for j in range(len(self.metrice)): + for i in range(self.class_num): + self.result[i][self.metrice_name[j]] = str2float(eval('cm.{}'.format(self.metrice[j]))[i]) + + return self.result, cm def classification_metrice(y_true, y_pred, class_num, label, save_path): - cm = confusion_matrix(y_true, y_pred, labels=list(range(class_num))) + metrice = Test_Metrice(y_true, y_pred, class_num) + class_report, cm = metrice() + class_pa = np.diag(cm.to_array(normalized=True)) # mean class accuracy if class_num <= 50: - plot_confusion_matrix(cm, label, save_path) - save_confusion_matrix(cm, label, save_path) - cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] - class_pa = np.diag(cm) # mean class accuracy - class_report = classification_report(y_true, y_pred, output_dict=True) - extra_class_report = Test_Metrice(y_true, y_pred, class_num)() - - cols_name = ['class', 'precision', 'recall', 'f1-score', 'kappa', 'accuracy'] - - table = PrettyTable() - table.title = 'Accuracy:{:.5f} MPA:{:.5f}'.format(class_report['accuracy'], np.mean(class_pa)) - table.field_names = cols_name - for i in range(class_num): - table.add_row([label[i], - '{:.5f}'.format(class_report[str(i)]['precision']), - '{:.5f}'.format(class_report[str(i)]['recall']), - '{:.5f}'.format(class_report[str(i)]['f1-score']), - '{:.5f}'.format(extra_class_report[str(i)]['kappa']), - '{:.5f}'.format(class_pa[i]) - ]) + plot_confusion_matrix(cm.to_array(), label, save_path) + save_confusion_matrix(cm.to_array(normalized=True), label, save_path) + + table1_cols_name = ['class'] + metrice.metrice_name + table1 = PrettyTable() + table1.title = 'Per Class' + table1.field_names = table1_cols_name + with 
+        f.write(','.join(table1_cols_name) + '\n')
+        for i in range(class_num):
+            table1.add_row([label[i]] + ['{:.5f}'.format(class_report[i][j]) for j in table1_cols_name[1:-1]] + ['{:.5f}'.format(class_pa[i])])
+            f.write(','.join([label[i]] + ['{:.5f}'.format(class_report[i][j]) for j in table1_cols_name[1:-1]] + ['{:.5f}'.format(class_pa[i])]) + '\n')
+    print(table1)
+
+    table2_cols_name = ['Accuracy', 'MPA', 'Kappa', 'Precision_Micro', 'Recall_Micro', 'F1_Micro', 'Precision_Macro', 'Recall_Macro', 'F1_Macro']
+    table2 = PrettyTable()
+    table2.title = 'Overall'
+    table2.field_names = table2_cols_name
+    with open(os.path.join(save_path, 'overall_result.csv'), 'w+', encoding='utf-8') as f:
+        data = ['{:.5f}'.format(str2float(cm.Overall_ACC)),
+                '{:.5f}'.format(np.mean(class_pa)),
+                '{:.5f}'.format(str2float(cm.Kappa)),
+                '{:.5f}'.format(str2float(cm.PPV_Micro)),
+                '{:.5f}'.format(str2float(cm.TPR_Micro)),
+                '{:.5f}'.format(str2float(cm.F1_Micro)),
+                '{:.5f}'.format(str2float(cm.PPV_Macro)),
+                '{:.5f}'.format(str2float(cm.TPR_Macro)),
+                '{:.5f}'.format(str2float(cm.F1_Macro)),
+                ]
+
+        table2.add_row(data)
+
+        f.write(','.join(table2_cols_name) + '\n')
+        f.write(','.join(data))
+    print(table2)
 
-    print(table)
     with open(os.path.join(save_path, 'result.txt'), 'w+', encoding='utf-8') as f:
-        f.write(str(table))
-
-    with open(os.path.join(save_path, 'result.csv'), 'w+', encoding='utf-8') as f:
-        f.write(','.join(cols_name) + '\n')
-        f.write('\n'.join(['{},{}'.format(label[i], ','.join(
-            [
-                '{:.5f}'.format(class_report[str(i)]['precision']),
-                '{:.5f}'.format(class_report[str(i)]['recall']),
-                '{:.5f}'.format(class_report[str(i)]['f1-score']),
-                '{:.5f}'.format(extra_class_report[str(i)]['kappa']),
-                '{:.5f}'.format(class_pa[i])
-            ]
-        )) for i in range(class_num)]))
+        f.write(str(table1))
+        f.write('\n')
+        f.write(str(table2))
 
 def update_opt(a, b):
     b = vars(b)
@@ -608,9 +633,10 @@ def visual_tsne(feature, y_true, path, labels, save_path):
         f.write('\n'.join(['{},{},{:.0f},{:.0f}'.format(i, labels[j], k[0], k[1]) for i, j, k in zip(path, y_true, feature_tsne)]))
 
-def predict_single_image(path, model, test_transform, DEVICE):
+def predict_single_image(path, model, test_transform, DEVICE, half=False):
     pil_img = Image.open(path)
     tensor_img = test_transform(pil_img).unsqueeze(0).to(DEVICE)
+    tensor_img = (tensor_img.half() if half else tensor_img)
     if len(tensor_img.shape) == 5:
         tensor_img = tensor_img.reshape((tensor_img.size(0) * tensor_img.size(1), tensor_img.size(2), tensor_img.size(3), tensor_img.size(4)))
     pred_result = torch.softmax(model(tensor_img).mean(0), 0)
@@ -620,13 +646,10 @@
 class cam_visual:
     def __init__(self, model, test_transform, DEVICE, target_layers, opt):
-        self.model = model
         self.test_transform = test_transform
         self.DEVICE = DEVICE
-        self.target_layers = target_layers
-        self.opt = opt
-        self.cam_model = eval(opt.cam_type)(model=model, target_layers=[target_layers], use_cuda=torch.cuda.is_available())
+        self.cam_model = eval(opt.cam_type)(model=deepcopy(model).float(), target_layers=[target_layers], use_cuda=torch.cuda.is_available())
 
     def __call__(self, path, label):
         pil_img = Image.open(path)
@@ -693,3 +716,37 @@ def dict_to_PrettyTable(data, name):
     table.field_names = data_keys
     table.add_row(['{:.5f}'.format(data[i]) for i in data_keys])
     return str(table)
+
+def is_parallel(model):
+    # Returns True if model is of type DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+def de_parallel(model):
+    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+    return model.module if is_parallel(model) else model
+
+class ModelEMA:
+    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
+    Keeps a moving average of everything in the model state_dict (parameters and buffers)
+    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    """
+
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        # Create EMA
+        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        # Update EMA parameters
+        self.updates += 1
+        d = self.decay(self.updates)
+
+        msd = de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
\ No newline at end of file
diff --git a/utils/utils_aug.py b/utils/utils_aug.py
index 5564e6d..8e9e6b7 100644
--- a/utils/utils_aug.py
+++ b/utils/utils_aug.py
@@ -3,6 +3,7 @@
 import numpy as np
 from PIL import Image
 from copy import deepcopy
+import albumentations as A
 
 def get_mean_and_std(dataset, opt):
     '''Compute the mean and std value of dataset.'''
@@ -160,4 +161,17 @@ def __call__(self, img):
         return Image.fromarray(np.array(img * mask, dtype=np.uint8))
 
     def __str__(self):
-        return 'CutOut'
\ No newline at end of file
+        return 'CutOut'
+
+class Create_Albumentations_From_Name(object):
+    # https://albumentations.ai/docs/api_reference/augmentations/transforms/
+    def __init__(self, name, **kwargs):
+        self.name = name
+        self.transform = eval('A.{}'.format(name))(**kwargs)
+
+    def __call__(self, img):
+        img = np.array(img)
+        return Image.fromarray(np.array(self.transform(image=img)['image'], dtype=np.uint8))
+
+    def __str__(self):
+        return self.name
\ No newline at end of file
diff --git a/utils/utils_fit.py b/utils/utils_fit.py
index de4e910..6c2ae7f 100644
--- a/utils/utils_fit.py
+++ b/utils/utils_fit.py
@@ -1,118 +1,141 @@
 import torch, tqdm
 import numpy as np
+from copy import deepcopy
 from .utils_aug import mixup_data, mixup_criterion
 from .utils import Train_Metrice
+import time
 
-def fitting(model, loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler, show_thing, opt):
-    model.to(DEVICE)
+def fitting(model, ema, loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler, show_thing, opt):
     model.train()
 
     metrice = Train_Metrice(CLASS_NUM)
     for x, y in tqdm.tqdm(train_dataset, desc='{} Train Stage'.format(show_thing)):
-        x, y = x.to(DEVICE), y.to(DEVICE).long()
+        x, y = x.to(DEVICE).float(), y.to(DEVICE).long()
 
         with torch.cuda.amp.autocast(opt.amp):
-            if opt.mixup != 'none' and np.random.rand() > 0.5:
-                x_mixup, y_a, y_b, lam = mixup_data(x, y, opt)
-                pred = model(x_mixup.float())
-                l = mixup_criterion(loss, pred, y_a, y_b, lam)
-                pred = model(x.float())
+            if opt.rdrop:
+                if opt.mixup != 'none' and np.random.rand() > 0.5:
+                    x_mixup, y_a, y_b, lam = mixup_data(x, y, opt)
+                    pred = model(x_mixup)
+                    pred2 = model(x_mixup)
+                    l = mixup_criterion(loss, [pred, pred2], y_a, y_b, lam)
+                    pred = model(x)
+                else:
+                    pred = model(x)
+                    pred2 = model(x)
+                    l = loss([pred, pred2], y)
             else:
-                pred = model(x.float())
-                l = loss(pred, y)
+                if opt.mixup != 'none' and np.random.rand() > 0.5:
+                    x_mixup, y_a, y_b, lam = mixup_data(x, y, opt)
+                    pred = model(x_mixup)
+                    l = mixup_criterion(loss, pred, y_a, y_b, lam)
+                    pred = model(x)
+                else:
+                    pred = model(x)
+                    l = loss(pred, y)
+
         metrice.update_loss(float(l.data))
         metrice.update_y(y, pred)
-
+
         scaler.scale(l).backward()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
-
-    model.eval()
-    with torch.no_grad():
+        if ema:
+            ema.update(model)
+
+    if ema:
+        model_eval = ema.ema
+    else:
+        model_eval = model.eval()
+    with torch.inference_mode():
         for x, y in tqdm.tqdm(test_dataset, desc='{} Test Stage'.format(show_thing)):
-            x, y = x.to(DEVICE), y.to(DEVICE).long()
+            x, y = x.to(DEVICE).float(), y.to(DEVICE).long()
 
             with torch.cuda.amp.autocast(opt.amp):
                 if opt.test_tta:
                     bs, ncrops, c, h, w = x.size()
-                    pred = model(x.view(-1, c, h, w))
+                    pred = model_eval(x.view(-1, c, h, w))
                     pred = pred.view(bs, ncrops, -1).mean(1)
                     l = loss(pred, y)
                 else:
-                    pred = model(x.float())
+                    pred = model_eval(x)
                     l = loss(pred, y)
 
             metrice.update_loss(float(l.data), isTest=True)
             metrice.update_y(y, pred, isTest=True)
 
-    return model, metrice.get()
+    return metrice.get()
 
-def fitting_distill(teacher_model, student_model, loss, kd_loss, optimizer, train_dataset, test_dataset, CLASS_NUM,
+def fitting_distill(teacher_model, student_model, ema, loss, kd_loss, optimizer, train_dataset, test_dataset, CLASS_NUM,
                     DEVICE, scaler, show_thing, opt):
-    teacher_model.to(DEVICE)
-    teacher_model.eval()
-    student_model.to(DEVICE)
     student_model.train()
 
     metrice = Train_Metrice(CLASS_NUM)
     for x, y in tqdm.tqdm(train_dataset, desc='{} Train Stage'.format(show_thing)):
-        x, y = x.to(DEVICE), y.to(DEVICE).long()
+        x, y = x.to(DEVICE).float(), y.to(DEVICE).long()
 
         with torch.cuda.amp.autocast(opt.amp):
             if opt.mixup != 'none' and np.random.rand() > 0.5:
                 x_mixup, y_a, y_b, lam = mixup_data(x, y, opt)
-                s_features, s_features_fc, s_pred = student_model(x_mixup.float(), need_fea=True)
-                t_features, t_features_fc, t_pred = teacher_model(x_mixup.float(), need_fea=True)
+                s_features, s_features_fc, s_pred = student_model(x_mixup, need_fea=True)
+                t_features, t_features_fc, t_pred = teacher_model(x_mixup, need_fea=True)
                 l = mixup_criterion(loss, s_pred, y_a, y_b, lam)
-                if str(kd_loss) in ['SoftTarget']:
-                    kd_l = kd_loss(s_pred, t_pred)
-                pred = student_model(x.float())
+                pred = student_model(x)
             else:
-                s_features, s_features_fc, s_pred = student_model(x.float(), need_fea=True)
-                t_features, t_features_fc, t_pred = teacher_model(x.float(), need_fea=True)
+                s_features, s_features_fc, s_pred = student_model(x, need_fea=True)
+                t_features, t_features_fc, t_pred = teacher_model(x, need_fea=True)
                 l = loss(s_pred, y)
+                pred = s_pred  # keep the clean-input prediction around for the metrics below
 
-                if str(kd_loss) in ['SoftTarget']:
-                    kd_l = kd_loss(s_pred, t_pred)
-                elif str(kd_loss) in ['MGD']:
-                    kd_l = kd_loss(s_features[-1], t_features[-1])
-                elif str(kd_loss) in ['SP']:
-                    kd_l = kd_loss(s_features[2], t_features[2]) + kd_loss(s_features[3], t_features[3])
-                elif str(kd_loss) in ['AT']:
-                    kd_l = kd_loss(s_features[2], t_features[2]) + kd_loss(s_features[3], t_features[3])
+            if str(kd_loss) in ['SoftTarget']:
+                kd_l = kd_loss(s_pred, t_pred)
+            elif str(kd_loss) in ['MGD']:
+                kd_l = kd_loss(s_features[-1], t_features[-1])
+            elif str(kd_loss) in ['SP']:
+                kd_l = kd_loss(s_features[2], t_features[2]) + kd_loss(s_features[3], t_features[3])
+            elif str(kd_loss) in ['AT']:
+                kd_l = kd_loss(s_features[2], t_features[2]) + kd_loss(s_features[3], t_features[3])
 
-                if str(kd_loss) in ['SoftTarget', 'SP', 'MGD']:
-                    kd_l *= (opt.kd_ratio / (1 - opt.kd_ratio)) if opt.kd_ratio < 1 else opt.kd_ratio
-                elif str(kd_loss) in ['AT']:
-                    kd_l *= opt.kd_ratio
+            if str(kd_loss) in ['SoftTarget', 'SP', 'MGD']:
+                kd_l *= (opt.kd_ratio / (1 - opt.kd_ratio)) if opt.kd_ratio < 1 else opt.kd_ratio
+            elif str(kd_loss) in ['AT']:
+                kd_l *= opt.kd_ratio
 
         metrice.update_loss(float(l.data))
         metrice.update_loss(float(kd_l.data), isKd=True)
-        metrice.update_y(y, s_pred)
+        metrice.update_y(y, pred)
 
         scaler.scale(l + kd_l).backward()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
-
-    student_model.eval()
-    with torch.no_grad():
+        if ema:
+            ema.update(student_model)
+
+    if ema:
+        model_eval = ema.ema
+    else:
+        model_eval = student_model.eval()
+    with torch.inference_mode():
         for x, y in tqdm.tqdm(test_dataset, desc='{} Test Stage'.format(show_thing)):
-            x, y = x.to(DEVICE), y.to(DEVICE).long()
+            x, y = x.to(DEVICE).float(), y.to(DEVICE).long()
 
             with torch.cuda.amp.autocast(opt.amp):
                 if opt.test_tta:
                     bs, ncrops, c, h, w = x.size()
-                    pred = student_model(x.view(-1, c, h, w))
+                    pred = model_eval(x.view(-1, c, h, w))
                     pred = pred.view(bs, ncrops, -1).mean(1)
                     l = loss(pred, y)
                 else:
-                    pred = student_model(x.float())
+                    pred = model_eval(x)
                     l = loss(pred, y)
 
             metrice.update_loss(float(l.data), isTest=True)
             metrice.update_y(y, pred, isTest=True)
 
-    return student_model, metrice.get()
\ No newline at end of file
+    return metrice.get()
\ No newline at end of file
diff --git a/utils/utils_loss.py b/utils/utils_loss.py
index f728e8e..655a40b 100644
--- a/utils/utils_loss.py
+++ b/utils/utils_loss.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from torch import Tensor
 
-__all__ = ['PolyLoss', 'CrossEntropyLoss', 'FocalLoss']
+__all__ = ['PolyLoss', 'CrossEntropyLoss', 'FocalLoss', 'RDropLoss']
 
 class PolyLoss(torch.nn.Module):
     """
@@ -45,4 +45,35 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         ce = -1 * input_logsoftmax * target_onehot_labelsmoothing
         fl = torch.pow((1 - input_softmax), self.gamma) * ce
         fl = fl.sum(1) * self.weight[target.long()]
-        return fl.mean()
\ No newline at end of file
+        return fl.mean()
+
+class RDropLoss(torch.nn.Module):
+    def __init__(self, loss, a=0.3):
+        super(RDropLoss, self).__init__()
+        self.loss = loss
+        self.a = a
+
+    def forward(self, input, target: torch.Tensor) -> torch.Tensor:
+        if type(input) is list:
+            input1, input2 = input
+            main_loss = (self.loss(input1, target) + self.loss(input2, target)) * 0.5
+            kl_loss = self.compute_kl_loss(input1, input2)
+            return main_loss + self.a * kl_loss
+        else:
+            return self.loss(input, target)
+
+    def compute_kl_loss(self, p, q, pad_mask=None):
+        p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
+        q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
+
+        # pad_mask is for seq-level tasks
+        if pad_mask is not None:
+            p_loss.masked_fill_(pad_mask, 0.)
+            q_loss.masked_fill_(pad_mask, 0.)
+
+        # choose the "sum" or "mean" reduction depending on your task
+        p_loss = p_loss.sum()
+        q_loss = q_loss.sum()
+
+        loss = (p_loss + q_loss) / 2
+        return loss
\ No newline at end of file
diff --git a/v1.1-update_log.md b/v1.1-update_log.md
new file mode 100644
index 0000000..52f4f14
--- /dev/null
+++ b/v1.1-update_log.md
@@ -0,0 +1,200 @@
+# pytorch-classifier v1.1 Update Log
+
+- **2022.11.8**
+    1. Reworked the dataset-splitting logic in processing.py. Previously, a test_size share of the data was split off as the test set first, and then a val_size share of the remainder became the validation set; with val_size=0.2 and test_size=0.2, the resulting split was not strictly 6:2:2. The split is now done proportionally, so the resulting datasets are strictly 6:2:2.
+    2. Following yolov5, model checkpoints saved during training are now stored in FP16. (Accuracy is essentially unchanged, and the saved model is about half the size of FP32.)
+    3. metrice.py and predict.py now support FP16 inference. (Accuracy is essentially unchanged, and inference is noticeably faster.)
+
+- **2022.11.9**
+    1. Added support for data augmentation from the [albumentations](https://github.com/albumentations-team/albumentations) library.
+    2. Added [R-Drop](https://github.com/dropreg/R-Drop) to the training pipeline; enable it by passing --rdrop to main.py.
+
+- **2022.11.10**
+    1. Reworked the visualization in metrice.py with the pycm library and added more metric types.
+
+- **2022.11.11**
+    1. Added EMA (Exponential Moving Average) support; enable it by passing --ema to main.py.
+    2. Changed the --patience mechanism of early stopping: setting --patience to 0 now disables early stopping.
+    3. Added more experimental results for knowledge distillation.
+    4. Fixed several bugs.
+
+### FP16 Inference Experiments
+
+Experiment environment:
+
+| System | CPU | GPU | RAM |
+| :----: | :----: | :----: | :----: |
+| Ubuntu | i9-12900KF | RTX-3090 | 32G |
+
+Train mobilenetv2:
+
+    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+Train resnext50:
+
+    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+Train RepVGG-A0:
+
+    python main.py --model_name RepVGG-A0 --config config/config.py --save_path runs/RepVGG-A0 --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+Train densenet121:
+
+    python main.py --model_name densenet121 --config config/config.py --save_path runs/densenet121 --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+Compute the metrics for each model:
+
+    python metrice.py --task val --save_path runs/mobilenetv2
+    python metrice.py --task val --save_path runs/resnext50
+    python metrice.py --task val --save_path runs/RepVGG-A0
+    python metrice.py --task val --save_path runs/densenet121
+
+    python metrice.py --task val --save_path runs/mobilenetv2 --half
+    python metrice.py --task val --save_path runs/resnext50 --half
+    python metrice.py --task val --save_path runs/RepVGG-A0 --half
+    python metrice.py --task val --save_path runs/densenet121 --half
+
+Compute the FPS of each model:
+
+    python metrice.py --task fps --save_path runs/mobilenetv2
+    python metrice.py --task fps --save_path runs/resnext50
+    python metrice.py --task fps --save_path runs/RepVGG-A0
+    python metrice.py --task fps --save_path runs/densenet121
+
+    python metrice.py --task fps --save_path runs/mobilenetv2 --half
+    python metrice.py --task fps --save_path runs/resnext50 --half
+    python metrice.py --task fps --save_path runs/RepVGG-A0 --half
+    python metrice.py --task fps --save_path runs/densenet121 --half
+
+| model | val accuracy(train stage) | val accuracy(test stage) | val accuracy half(test stage) | FP32 FPS(batch_size=64) | FP16 FPS(batch_size=64) |
+| :----: | :----: | :----: | :----: | :----: | :----: |
+| mobilenetv2 | 0.74284 | 0.74340 | 0.74396 | 52.43 | 92.80 |
+| resnext50 | 0.80966 | 0.80966 | 0.80966 | 19.48 | 30.28 |
+| RepVGG-A0 | 0.73666 | 0.73666 | 0.73666 | 54.74 | 98.87 |
+| densenet121 | 0.77035 | 0.77148 | 0.77035 | 18.87 | 32.75 |
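+
+As a minimal sketch of what the FP16 (--half) path boils down to (illustrative only, not the repo's exact metrice.py code; the tiny `nn.Sequential` stands in for a checkpoint loaded from --save_path, and FP16 inference in this style assumes a CUDA device):
+
+    import torch
+    from torch import nn
+
+    DEVICE = torch.device('cuda')
+
+    model = nn.Sequential(nn.Flatten(), nn.Linear(8, 2)).to(DEVICE)  # stand-in for a loaded checkpoint
+    model.half().eval()  # FP16 weights: roughly half the file size and memory of FP32
+
+    x = torch.randn(4, 1, 2, 4).to(DEVICE).half()  # the input dtype must match the weight dtype
+
+    with torch.inference_mode():
+        prob = torch.softmax(model(x), dim=1)  # outputs come back in FP16 as well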
+
+### R-Drop Experiments
+
+Train mobilenetv2:
+
+    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd --rdrop
+
+Train resnext50:
+
+    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd --rdrop
+
+Train ghostnet:
+
+    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd --rdrop
+
+Train efficientnet_v2_s:
+
+    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd --rdrop
+
+Compute the metrics for each model:
+
+    python metrice.py --task val --save_path runs/mobilenetv2
+    python metrice.py --task val --save_path runs/mobilenetv2_rdrop
+    python metrice.py --task val --save_path runs/resnext50
+    python metrice.py --task val --save_path runs/resnext50_rdrop
+    python metrice.py --task val --save_path runs/ghostnet
+    python metrice.py --task val --save_path runs/ghostnet_rdrop
+    python metrice.py --task val --save_path runs/efficientnet_v2_s
+    python metrice.py --task val --save_path runs/efficientnet_v2_s_rdrop
+
+    python metrice.py --task test --save_path runs/mobilenetv2
+    python metrice.py --task test --save_path runs/mobilenetv2_rdrop
+    python metrice.py --task test --save_path runs/resnext50
+    python metrice.py --task test --save_path runs/resnext50_rdrop
+    python metrice.py --task test --save_path runs/ghostnet
+    python metrice.py --task test --save_path runs/ghostnet_rdrop
+    python metrice.py --task test --save_path runs/efficientnet_v2_s
+    python metrice.py --task test --save_path runs/efficientnet_v2_s_rdrop
+
+| model | val accuracy | val accuracy(r-drop) | test accuracy | test accuracy(r-drop) |
+| :----: | :----: | :----: | :----: | :----: |
+| mobilenetv2 | 0.74340 | 0.75126 | 0.73784 | 0.73741 |
+| resnext50 | 0.80966 | 0.81134 | 0.82437 | 0.82092 |
+| ghostnet | 0.77597 | 0.76698 | 0.76625 | 0.77012 |
+| efficientnet_v2_s | 0.84166 | 0.85289 | 0.84460 | 0.85837 |
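+
+For reference, the core of what --rdrop adds per training batch, condensed from the RDropLoss and fitting() changes in this patch (the toy model and tensors are placeholders):
+
+    import torch
+    from torch import nn
+    from utils.utils_loss import RDropLoss
+
+    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 2))
+    criterion = RDropLoss(nn.CrossEntropyLoss(), a=0.3)  # a weights the symmetric-KL term
+
+    x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))
+    pred1, pred2 = model(x), model(x)  # two forward passes see different dropout masks
+    # 0.5 * (CE(pred1, y) + CE(pred2, y)) + a * symmetric KL between the two predictions
+    loss = criterion([pred1, pred2], y)
+    loss.backward()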
+
+### EMA Experiments
+
+Train mobilenetv2:
+
+    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd --ema
+
+Train resnext50:
+
+    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd --ema
+
+Train ghostnet:
+
+    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd --ema
+
+Train efficientnet_v2_s:
+
+    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd
+
+    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
+    --pretrained --amp --warmup --imagenet_meanstd --ema
+
+Compute the metrics for each model:
+
+    python metrice.py --task val --save_path runs/mobilenetv2
+    python metrice.py --task val --save_path runs/mobilenetv2_ema
+    python metrice.py --task val --save_path runs/resnext50
+    python metrice.py --task val --save_path runs/resnext50_ema
+    python metrice.py --task val --save_path runs/ghostnet
+    python metrice.py --task val --save_path runs/ghostnet_ema
+    python metrice.py --task val --save_path runs/efficientnet_v2_s
+    python metrice.py --task val --save_path runs/efficientnet_v2_s_ema
+
+    python metrice.py --task test --save_path runs/mobilenetv2
+    python metrice.py --task test --save_path runs/mobilenetv2_ema
+    python metrice.py --task test --save_path runs/resnext50
+    python metrice.py --task test --save_path runs/resnext50_ema
+    python metrice.py --task test --save_path runs/ghostnet
+    python metrice.py --task test --save_path runs/ghostnet_ema
+    python metrice.py --task test --save_path runs/efficientnet_v2_s
+    python metrice.py --task test --save_path runs/efficientnet_v2_s_ema
+
+| model | val accuracy | val accuracy(ema) | test accuracy | test accuracy(ema) |
+| :----: | :----: | :----: | :----: | :----: |
+| mobilenetv2 | 0.74340 | 0.74958 | 0.73784 | 0.73870 |
+| resnext50 | 0.80966 | 0.81246 | 0.82437 | 0.82307 |
+| ghostnet | 0.77597 | 0.77765 | 0.76625 | 0.77142 |
+| efficientnet_v2_s | 0.84166 | 0.83998 | 0.84460 | 0.83986 |
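+
+And the shape of what --ema does, condensed from the ModelEMA class and fitting loop added in this patch (the optimizer, criterion, and data here are stand-ins):
+
+    import torch
+    from torch import nn
+    from utils.utils import ModelEMA
+
+    model = nn.Sequential(nn.Flatten(), nn.Linear(8, 2))
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    criterion = nn.CrossEntropyLoss()
+    ema = ModelEMA(model, decay=0.9999, tau=2000)  # FP32 shadow copy of the weights
+
+    for _ in range(10):
+        x, y = torch.randn(4, 1, 2, 4), torch.randint(0, 2, (4,))
+        loss = criterion(model(x), y)
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        ema.update(model)  # v = d * v + (1 - d) * w, with d ramping up over the first updates
+
+    model_eval = ema.ema  # validate and save these EMA weights, as fitting() does
\ No newline at end of file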