add classify unet-serenext50
XiangqianMa committed Oct 2, 2019
1 parent f7e2d2e commit 2213fb0
Showing 9 changed files with 724 additions and 673 deletions.
6 changes: 3 additions & 3 deletions choose_thre_area.py
@@ -13,7 +13,7 @@
 from datasets.steel_dataset import provider
 from utils.set_seed import seed_torch
 from utils.cal_dice_iou import compute_dice_class
-from config import get_config
+from config import get_seg_config
 
 
 class ChooseThresholdMinArea():
@@ -198,7 +198,7 @@ def get_model(model_name, load_path):
 
 
 if __name__ == "__main__":
-    config = get_config()
+    config = get_seg_config()
     mean=(0.485, 0.456, 0.406)
     std=(0.229, 0.224, 0.225)
 
@@ -207,7 +207,7 @@ def get_model(model_name, load_path):
     # path where the trained weights are stored
     model_path = os.path.join(config.save_path, config.model_name)
     for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
-        if fold_index != 0:
+        if fold_index != 1:
             continue
 
         # path + filename of the weights file
52 changes: 48 additions & 4 deletions config.py
100644 → 100755
@@ -3,7 +3,7 @@
 from argparse import Namespace
 
 
-def get_config():
+def get_seg_config():
     use_paras = False
     if use_paras:
         with open('./checkpoints/unet_resnet34/' + "params.json", 'r', encoding='utf-8') as json_file:
@@ -15,10 +15,12 @@ def get_config():
         '''
         Maximum batch size each machine can use with unet_resnet34:
         zdaiot:12 z840:16 mxq:24
+        unet_se_renext50
+        hwp: 6
         '''
         # parser.add_argument('--image_size', type=int, default=768, help='image size')
-        parser.add_argument('--batch_size', type=int, default=12, help='batch size')
-        parser.add_argument('--epoch', type=int, default=40, help='epoch')
+        parser.add_argument('--batch_size', type=int, default=24, help='batch size')
+        parser.add_argument('--epoch', type=int, default=60, help='epoch')
 
         parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
         parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
@@ -43,5 +45,47 @@ def get_config():
     return config
 
 
+def get_classify_config():
+    use_paras = False
+    if use_paras:
+        with open('./checkpoints/unet_resnet34/' + "params.json", 'r', encoding='utf-8') as json_file:
+            config = json.load(json_file)
+        # dict to namespace
+        config = Namespace(**config)
+    else:
+        parser = argparse.ArgumentParser()
+        '''
+        Maximum batch size each machine can use with unet_resnet34:
+        zdaiot:12 z840:16 mxq:48
+        unet_se_renext50
+        hwp: 8
+        '''
+        # parser.add_argument('--image_size', type=int, default=768, help='image size')
+        parser.add_argument('--batch_size', type=int, default=48, help='batch size')
+        parser.add_argument('--epoch', type=int, default=30, help='epoch')
+
+        parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
+        parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
+
+        # model set
+        parser.add_argument('--model_name', type=str, default='unet_resnet34', \
+                            help='unet_resnet34/unet_se_resnext50_32x4d')
+
+        # model hyper-parameters
+        parser.add_argument('--class_num', type=int, default=4)
+        parser.add_argument('--num_workers', type=int, default=8)
+        parser.add_argument('--lr', type=float, default=5e-4, help='init lr')
+        parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')
+
+        # dataset
+        parser.add_argument('--save_path', type=str, default='./checkpoints')
+        parser.add_argument('--dataset_root', type=str, default='./datasets/Steel_data')
+
+        config = parser.parse_args()
+        # config = {k: v for k, v in args._get_kwargs()}
+
+    return config
+
+
 if __name__ == '__main__':
-    config = get_config()
+    config = get_seg_config()
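For orientation (not part of this commit): both helpers return an argparse Namespace, so a training script reads its hyperparameters as attributes. The segmentation scripts call get_seg_config(), while the new classification scripts call get_classify_config(). A minimal usage sketch, assuming it is run from the repository root; the script itself and its print line are illustrative only:

# illustrative sketch, not a file in this repository
from config import get_classify_config

if __name__ == '__main__':
    config = get_classify_config()
    # defaults introduced by this commit: batch_size=48, epoch=30, lr=5e-4
    print(config.model_name, config.batch_size, config.epoch, config.lr)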
6 changes: 3 additions & 3 deletions demo.py
@@ -5,7 +5,7 @@
 from tqdm import tqdm
 import cv2
 import os
-from config import get_config
+from config import get_seg_config
 from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold, Segment_Folds, Get_Segment_Results
 from datasets.steel_dataset import TestDataset, provider

@@ -94,7 +94,7 @@ def pred_show(images, preds, mean, std, targets=None, flag=False, auto_flag=Fals
 
 
 if __name__ == "__main__":
-    config = get_config()
+    config = get_seg_config()
     # set hyperparameters
     mean = (0.485, 0.456, 0.406)
     std = (0.229, 0.224, 0.225)
@@ -108,7 +108,7 @@ def pred_show(images, preds, mean, std, targets=None, flag=False, auto_flag=Fals
     # whether to use the segmentation model only
     use_segment_only = True
     # whether to advance the display automatically
-    auto_flag = True
+    auto_flag = False
 
     # dataloader for the test dataset
     sample_submission_path = 'datasets/Steel_data/sample_submission.csv'
11 changes: 9 additions & 2 deletions models/model.py
100644 → 100755
@@ -72,7 +72,14 @@ def __init__(self, model_name, class_num=4, training=True, encoder_weights='imag
         model = Model(model_name, encoder_weights=encoder_weights, class_num=class_num).create_model_cpu()
         # note: the model must contain an encoder module
         self.encoder = model.encoder
-        self.feature = nn.Conv2d(512, 32, kernel_size=1)
+        if model_name == 'unet_resnet34':
+            self.feature = nn.Conv2d(512, 32, kernel_size=1)
+        elif model_name == 'unet_se_resnext50_32x4d':
+            self.feature = nn.Sequential(
+                nn.Conv2d(2048, 512, kernel_size=1),
+                nn.ReLU(),
+                nn.Conv2d(512, 32, kernel_size=1)
+            )
         self.logit = nn.Conv2d(32, self.class_num, kernel_size=1)
 
         self.training = training
@@ -89,7 +96,7 @@ def forward(self, x):
 
 
 if __name__ == "__main__":
     # test the segmentation model
-    model_name = 'unet_resnet34'
+    model_name = 'unet_se_resnext50_32x4d'
     model = Model(model_name).create_model()
     print(model)
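The widened head above follows from the two backbones' output widths: resnet34's final encoder stage emits 512 feature channels while se_resnext50_32x4d emits 2048, so the SE-ResNeXt branch first projects 2048 down to 512 before the shared 512→32→class_num 1x1 convolutions. A standalone sketch of just that head-selection logic (build_feature_head is a hypothetical helper, not code from this repository; the real ClassifyResNet also wraps the encoder and its own forward pass):

import torch
from torch import nn

def build_feature_head(model_name, class_num=4):
    # resnet34 encoders end at 512 channels, se_resnext50_32x4d at 2048,
    # so the wider backbone gets an extra 1x1 projection first
    if model_name == 'unet_resnet34':
        feature = nn.Conv2d(512, 32, kernel_size=1)
    elif model_name == 'unet_se_resnext50_32x4d':
        feature = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(512, 32, kernel_size=1),
        )
    else:
        raise ValueError('unsupported model_name: %s' % model_name)
    logit = nn.Conv2d(32, class_num, kernel_size=1)
    return nn.Sequential(feature, logit)

if __name__ == "__main__":
    # dummy feature maps with the channel counts the two encoders produce
    head34 = build_feature_head('unet_resnet34')
    head50 = build_feature_head('unet_se_resnext50_32x4d')
    print(head34(torch.randn(2, 512, 8, 50)).shape)    # torch.Size([2, 4, 8, 50])
    print(head50(torch.randn(2, 2048, 8, 50)).shape)   # torch.Size([2, 4, 8, 50])

Both branches end in the same 32→class_num logit convolution, so downstream code that consumes the classifier's output is unchanged.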
1,298 changes: 649 additions & 649 deletions submission_tta.csv

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions test_classify.py
@@ -86,7 +86,7 @@ def tta_pred(self, images):
     model = ClassifyResNet('unet_resnet34', 4, training=False)
     model = torch.nn.DataParallel(model)
     model = model.cuda()
-    pth_path = "checkpoints/unet_resnet34/unet_resnet34_classify_fold2.pth"
+    pth_path = "checkpoints/unet_resnet34/unet_resnet34_classify_fold1.pth"
     checkpoint = torch.load(pth_path)
     model.module.load_state_dict(checkpoint['state_dict'])
 
@@ -136,6 +136,6 @@ def tta_pred(self, images):
         font = cv2.FONT_HERSHEY_SIMPLEX
         image = cv2.putText(image, str(i), position, font, 1.2, color, 2)
         cv2.imshow('win', image)
-        cv2.waitKey(30)
+        cv2.waitKey(240)
     print("Accuracy: %.4f" % (num_true / number_sample))
     pass
4 changes: 2 additions & 2 deletions train_classify.py
100644 → 100755
@@ -1,7 +1,7 @@
 from torch import optim
 import torch
 import tqdm
-from config import get_config
+from config import get_classify_config
 from solver import Solver
 from torch.utils.tensorboard import SummaryWriter
 import datetime
@@ -151,7 +151,7 @@ def validation(self, valid_loader):
 
 
 if __name__ == "__main__":
-    config = get_config()
+    config = get_classify_config()
     mean=(0.485, 0.456, 0.406)
     std=(0.229, 0.224, 0.225)
     dataloaders = classify_provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits)
4 changes: 2 additions & 2 deletions train_segment.py
@@ -12,7 +12,7 @@
 from utils.cal_dice_iou import Meter
 from datasets.steel_dataset import provider
 from utils.set_seed import seed_torch
-from config import get_config
+from config import get_seg_config
 from solver import Solver
 
 class TrainVal():
@@ -144,7 +144,7 @@ def validation(self, valid_loader):
 
 
 if __name__ == "__main__":
-    config = get_config()
+    config = get_seg_config()
     mean=(0.485, 0.456, 0.406)
     std=(0.229, 0.224, 0.225)
     dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits)
12 changes: 6 additions & 6 deletions utils/data_augmentation.py
@@ -14,8 +14,8 @@
 )
 
 sys.path.append('.')
-from .visualize import image_with_mask_torch, image_with_mask_numpy
-from .rle_parse import make_mask
+from utils.visualize import image_with_mask_torch, image_with_mask_numpy
+from utils.rle_parse import make_mask
 
 
 def visualize(image, mask, original_image=None, original_mask=None):
@@ -57,9 +57,9 @@ def data_augmentation(original_image, original_mask):
     augmentations = Compose([
         HorizontalFlip(p=0.4),
         VerticalFlip(p=0.4),
-        ShiftScaleRotate(shift_limit=0, rotate_limit=8, p=0.4),
+        ShiftScaleRotate(shift_limit=0.07, rotate_limit=0, p=0.4),
         # histogram equalization
-        CLAHE(p=0.4),
+        CLAHE(p=0.3),
 
         # brightness and contrast
         RandomGamma(gamma_limit=(80, 120), p=0.1),
@@ -86,8 +86,8 @@
 
 
 if __name__ == "__main__":
-    data_folder = "../datasets/Steel_data"
-    df_path = "../datasets/Steel_data/train.csv"
+    data_folder = "datasets/Steel_data"
+    df_path = "datasets/Steel_data/train.csv"
     mean = (0.485, 0.456, 0.406)
     std = (0.229, 0.224, 0.225)
     df = pd.read_csv(df_path)
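The augmentation tweak above swaps the small rotation for a positional shift and applies CLAHE slightly less often. A minimal sketch of the adjusted albumentations pipeline, assuming albumentations is installed; the 256x1600 image and mask below are dummy uint8 arrays standing in for a real training sample:

import numpy as np
from albumentations import CLAHE, Compose, HorizontalFlip, ShiftScaleRotate, VerticalFlip

augmentations = Compose([
    HorizontalFlip(p=0.4),
    VerticalFlip(p=0.4),
    # shift up to 7% of the image size, no rotation (was rotate_limit=8 before this commit)
    ShiftScaleRotate(shift_limit=0.07, rotate_limit=0, p=0.4),
    # histogram equalization, now applied a little less often
    CLAHE(p=0.3),
])

image = np.random.randint(0, 256, (256, 1600, 3), dtype=np.uint8)
mask = np.zeros((256, 1600), dtype=np.uint8)
augmented = augmentations(image=image, mask=mask)
image_aug, mask_aug = augmented['image'], augmented['mask']
print(image_aug.shape, mask_aug.shape)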
