Skip to content

Commit

Permalink
add tta
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Sep 30, 2019
1 parent adfebfb commit f7e2d2e
Show file tree
Hide file tree
Showing 7 changed files with 7,270 additions and 20 deletions.
36 changes: 23 additions & 13 deletions classify_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class Get_Classify_Results():
def __init__(self, model_name, fold, model_path, class_num=4):
def __init__(self, model_name, fold, model_path, class_num=4, tta_flag=False):
''' 处理当前fold一个batch的数据分类结果
:param model_name: 当前的模型名称
Expand All @@ -20,7 +20,8 @@ def __init__(self, model_name, fold, model_path, class_num=4):
self.fold = fold
self.model_path = model_path
self.class_num = class_num

self.tta_flag = tta_flag

self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载模型及其权重
self.classify_model = ClassifyResNet(model_name, encoder_weights=None)
Expand All @@ -41,13 +42,16 @@ def get_classify_results(self, images, thrshold=0.5):
:param thrshold: 分类模型的阈值
:return: predict_classes: 一个batch的数据经过分类模型后的结果,维度为[batch, class_num]
'''
predict_classes = self.solver.forward(images)
if self.tta_flag:
predict_classes = self.solver.tta(images, seg=False)
else:
predict_classes = self.solver.forward(images)
predict_classes = predict_classes > thrshold
return predict_classes


class Get_Segment_Results():
def __init__(self, model_name, fold, model_path, class_num=4):
def __init__(self, model_name, fold, model_path, class_num=4, tta_flag=False):
''' 处理当前fold一个batch的数据分割结果
:param model_name: 当前的模型名称
Expand All @@ -59,6 +63,7 @@ def __init__(self, model_name, fold, model_path, class_num=4):
self.fold = fold
self.model_path = model_path
self.class_num = class_num
self.tta_flag = tta_flag

# 加载模型及其权重
self.segment_model = Model(self.model_name, encoder_weights=None).create_model()
Expand All @@ -77,7 +82,10 @@ def get_segment_results(self, images):
:param images: 一个batch的数据,维度为[batch, channels, height, width]
:return: predict_masks: 一个batch的数据经过分割网络后得到的预测结果,维度为[batch, class_num, height, width]
'''
predict_masks = self.solver.forward(images)
if self.tta_flag:
predict_masks = self.solver.tta(images)
else:
predict_masks = self.solver.forward(images)
for index, predict_masks_classes in enumerate(predict_masks):
for each_class, pred in enumerate(predict_masks_classes):
pred_binary, _ = self.post_process(pred.detach().cpu().numpy(), self.best_thresholds[each_class], self.best_minareas[each_class])
Expand Down Expand Up @@ -113,7 +121,7 @@ def get_thresholds_minareas(self, json_path, fold):


class Classify_Segment_Fold():
def __init__(self, model_name, fold, model_path, class_num=4):
def __init__(self, model_name, fold, model_path, class_num=4, tta_flag=False):
''' 处理当前fold一个batch的分割结果和分类结果
:param model_name: 当前的模型名称
Expand All @@ -126,8 +134,8 @@ def __init__(self, model_name, fold, model_path, class_num=4):
self.model_path = model_path
self.class_num = class_num

self.classify_model = Get_Classify_Results(self.model_name, self.fold, self.model_path, self.class_num)
self.segment_model = Get_Segment_Results(self.model_name, self.fold, self.model_path, self.class_num)
self.classify_model = Get_Classify_Results(self.model_name, self.fold, self.model_path, self.class_num, tta_flag=tta_flag)
self.segment_model = Get_Segment_Results(self.model_name, self.fold, self.model_path, self.class_num, tta_flag=tta_flag)

def classify_segment(self, images):
''' 处理当前fold一个batch的分割结果和分类结果
Expand All @@ -147,7 +155,7 @@ def classify_segment(self, images):


class Classify_Segment_Folds():
def __init__(self, model_name, n_splits, model_path, class_num=4):
def __init__(self, model_name, n_splits, model_path, class_num=4, tta_flag=False):
''' 使用投票法处理所有fold一个batch的分割结果和分类结果
:param model_name: 当前的模型名称
Expand All @@ -159,6 +167,7 @@ def __init__(self, model_name, n_splits, model_path, class_num=4):
self.n_splits = n_splits
self.model_path = model_path
self.class_num = class_num
self.tta_flag = tta_flag

self.classify_models, self.segment_models = list(), list()
self.get_classify_segment_models()
Expand All @@ -168,8 +177,8 @@ def get_classify_segment_models(self):
'''

for fold in self.n_splits:
self.classify_models.append(Get_Classify_Results(self.model_name, fold, self.model_path, self.class_num))
self.segment_models.append(Get_Segment_Results(self.model_name, fold, self.model_path, self.class_num))
self.classify_models.append(Get_Classify_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))
self.segment_models.append(Get_Segment_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))

def classify_segment_folds(self, images):
''' 使用投票法处理所有fold一个batch的分割结果和分类结果
Expand All @@ -196,7 +205,7 @@ def classify_segment_folds(self, images):


class Segment_Folds():
def __init__(self, model_name, n_splits, model_path, class_num=4):
def __init__(self, model_name, n_splits, model_path, class_num=4, tta_flag=False):
''' 使用投票法处理所有fold一个batch的分割结果
:param model_name: 当前的模型名称
Expand All @@ -208,6 +217,7 @@ def __init__(self, model_name, n_splits, model_path, class_num=4):
self.n_splits = n_splits
self.model_path = model_path
self.class_num = class_num
self.tta_flag = tta_flag

self.segment_models = list()
self.get_segment_models()
Expand All @@ -217,7 +227,7 @@ def get_segment_models(self):
'''

for fold in self.n_splits:
self.segment_models.append(Get_Segment_Results(self.model_name, fold, self.model_path, self.class_num))
self.segment_models.append(Get_Segment_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))

def segment_folds(self, images):
''' 使用投票法处理所有fold一个batch的分割结果
Expand Down
12 changes: 7 additions & 5 deletions create_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def mask2rle(img):
return ' '.join(str(x) for x in runs)


def create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path):
def create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False):
'''
:param n_splits: 折数,类型为list
Expand All @@ -67,6 +67,7 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
:param test_data_folder: 测试数据存放的路径
:param sample_submission_path: 提交样例csv存放的路径
:param model_path: 当前模型权重存放的目录
:param tta_flag: 是否使用tta
:return: None
'''
# 加载数据集
Expand All @@ -79,9 +80,9 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
pin_memory=True
)
if len(n_splits) == 1:
classify_segment = Classify_Segment_Fold(model_name, n_splits[0], model_path).classify_segment
classify_segment = Classify_Segment_Fold(model_name, n_splits[0], model_path, tta_flag=tta_flag).classify_segment
else:
classify_segment = Classify_Segment_Folds(model_name, n_splits, model_path).classify_segment_folds
classify_segment = Classify_Segment_Folds(model_name, n_splits, model_path, tta_flag=tta_flag).classify_segment_folds

# start prediction
predictions = []
Expand All @@ -103,10 +104,11 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
# 设置超参数
model_name = 'unet_resnet34'
num_workers = 12
batch_size = 8
batch_size = 6
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
n_splits = [1] # [0, 1, 2, 3, 4]
tta_flag = False

if kaggle:
sample_submission_path = '/kaggle/input/severstal-steel-defect-detection/sample_submission.csv'
Expand All @@ -118,4 +120,4 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
model_path = './checkpoints/' + model_name

create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder,
sample_submission_path, model_path)
sample_submission_path, model_path, tta_flag=tta_flag)
3 changes: 3 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash
# Run the two training scripts back to back.
# NOTE: there is no `set -e`, so train_segment.py is started even when
# train_classify.py exits with a non-zero status.
python train_classify.py
python train_segment.py
30 changes: 30 additions & 0 deletions solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,36 @@ def forward(self, images):
outputs = self.model(images)
return outputs

def tta(self, images, seg=True):
    """Run test-time augmentation (TTA) and return the averaged prediction.

    The model is evaluated three times: on the original batch, on a
    horizontally flipped copy (flip along dim 3) and on a vertically
    flipped copy (flip along dim 2). The three outputs are averaged.

    Args:
        images: input batch, [batch_size, channel, height, width].
        seg: True (default) for segmentation, False for classification.
            For segmentation the flipped predictions are flipped back so
            they line up spatially with the original prediction before
            averaging; classification outputs are averaged as-is.

    Return:
        The element-wise mean of the three predictions, with the same shape
        as a single model output. The result carries no autograd history.
    """
    images = images.to(self.device)

    # TTA is inference-only: disabling autograd avoids keeping three
    # computation graphs alive at the same time.
    with torch.no_grad():
        # Original images.
        pred_origin = self.model(images)
        # Horizontal flip (width axis).
        pred_hflip = self.model(torch.flip(images, dims=[3]))
        # Vertical flip (height axis).
        pred_vflip = self.model(torch.flip(images, dims=[2]))

    if seg:
        # Segmentation outputs are spatial maps, so undo each flip.
        pred_hflip = torch.flip(pred_hflip, dims=[3])
        pred_vflip = torch.flip(pred_vflip, dims=[2])

    # Mean over the three views.
    return (pred_origin + pred_hflip + pred_vflip) / 3.0

def cal_loss(self, targets, predicts, criterion):
''' 根据真实类标和预测出的类标计算损失
Expand Down
Loading

0 comments on commit f7e2d2e

Please sign in to comment.