Skip to content

Commit

Permalink
different fold between classify and segment
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Oct 8, 2019
1 parent a34e54c commit 2486206
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 17 deletions.
5 changes: 3 additions & 2 deletions choose_thre_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ def get_model(model_name, load_path):
config = get_seg_config()
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)

dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits)
mask_only = True

dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits, mask_only)
results = {}
# 存放权重的路径
model_path = os.path.join(config.save_path, config.model_name)
Expand Down
59 changes: 59 additions & 0 deletions classify_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,65 @@ def classify_segment_folds(self, images):
return results


class Classify_Segment_Folds_Split():
def __init__(self, model_name, classify_folds, segment_folds, model_path, class_num=4, tta_flag=False):
''' 首先的到分类模型的集成结果,再得到分割模型的集成结果,最后将两个结果进行融合
:param model_name: 当前的模型名称
:param classify_folds: 参与集成的分类模型的折序号,为list列表
:param segment_folds: 参与集成的分割模型的折序号,为list列表
:param model_path: 存放所有模型的路径
:param class_num: 类别总数
'''
self.model_name = model_name
self.classify_folds = classify_folds
self.segment_folds = segment_folds
self.model_path = model_path
self.class_num = class_num
self.tta_flag = tta_flag

self.classify_models, self.segment_models = list(), list()
self.get_classify_segment_models()

def get_classify_segment_models(self):
''' 加载所有折的分割模型和分类模型
'''
for fold in self.classify_folds:
self.classify_models.append(Get_Classify_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))
for fold in self.segment_folds:
self.segment_models.append(Get_Segment_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))

def classify_segment_folds(self, images):
''' 使用投票法处理所有fold一个batch的分割结果和分类结果
:param images: 一个batch的数据,维度为[batch, channels, height, width]
:return: results,使用投票法处理所有fold一个batch的分割结果和分类结果,维度为[batch, class_num, height, width]
'''
classify_results = torch.zeros(images.shape[0], self.class_num)
segment_results = torch.zeros(images.shape[0], self.class_num, images.shape[2], images.shape[3])
# 得到分类结果
for classify_index, classify_model in enumerate(self.classify_models):
classify_result_fold = classify_model.get_classify_results(images)
classify_results += classify_result_fold.detach().cpu().squeeze().float()
classify_vote_model_num = len(self.classify_folds)
classify_vote_ticket = round(classify_vote_model_num / 2.0)
classify_results = classify_results > classify_vote_ticket

# 得到分割结果
for segment_index, segment_model in enumerate(self.segment_models):
segment_result_fold = segment_model.get_segment_results(images)
segment_results += segment_result_fold.detach().cpu()
segment_vote_model_num = len(self.segment_folds)
segment_vote_ticket = round(segment_vote_model_num / 2.0)
segment_results = segment_results > segment_vote_ticket

# 将分类结果和分割结果进行融合
for batch_index, classify_result in enumerate(classify_results):
segment_results[batch_index, 1-classify_result, ...] = 0

return segment_results


class Segment_Folds():
def __init__(self, model_name, n_splits, model_path, class_num=4, tta_flag=False):
''' 使用投票法处理所有fold一个batch的分割结果
Expand Down
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_seg_config():
hwp: 6
'''
# parser.add_argument('--image_size', type=int, default=768, help='image size')
parser.add_argument('--batch_size', type=int, default=48, help='batch size')
parser.add_argument('--batch_size', type=int, default=4, help='batch size')
parser.add_argument('--epoch', type=int, default=65, help='epoch')

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
Expand Down
26 changes: 15 additions & 11 deletions create_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
package_path = '/kaggle/input/sources' # add unet script dataset
import sys
sys.path.append(package_path)
from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold
from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold, Classify_Segment_Folds_Split


class TestDataset(Dataset):
Expand Down Expand Up @@ -55,10 +55,11 @@ def mask2rle(img):
return ' '.join(str(x) for x in runs)


def create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False):
def create_submission(classify_splits, seg_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False):
'''
:param n_splits: 折数,类型为list
:param classify_splits: 分类模型的折数,类型为list
:param seg_splits: 分割模型的折数,类型为list
:param model_name: 当前模型的名称
:param batch_size: batch的大小
:param num_workers: 加载数据的线程
Expand All @@ -79,10 +80,12 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
num_workers=num_workers,
pin_memory=True
)
if len(n_splits) == 1:
classify_segment = Classify_Segment_Fold(model_name, n_splits[0], model_path, tta_flag=tta_flag).classify_segment
else:
classify_segment = Classify_Segment_Folds(model_name, n_splits, model_path, tta_flag=tta_flag).classify_segment_folds
if len(classify_splits) == 1 and len(seg_splits) == 1:
classify_segment = Classify_Segment_Fold(model_name, classify_splits[0], model_path, tta_flag=tta_flag).classify_segment
elif len(classify_splits) == len(seg_splits):
classify_segment = Classify_Segment_Folds(model_name, classify_splits, model_path, tta_flag=tta_flag).classify_segment_folds
elif len(classify_splits) != len(seg_splits):
classify_segment = Classify_Segment_Folds_Split(model_name, classify_splits, seg_splits, model_path, tta_flag=tta_flag).classify_segment_folds

# start prediction
predictions = []
Expand All @@ -104,11 +107,12 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
# 设置超参数
model_name = 'unet_resnet34'
num_workers = 12
batch_size = 6
batch_size = 4
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
n_splits = [1] # [0, 1, 2, 3, 4]
tta_flag = False
classify_splits = [1] # [0, 1, 2, 3, 4]
segment_splits = [0, 1, 2, 3, 4]
tta_flag = True

if kaggle:
sample_submission_path = '/kaggle/input/severstal-steel-defect-detection/sample_submission.csv'
Expand All @@ -119,5 +123,5 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
test_data_folder = 'datasets/Steel_data/test_images'
model_path = './checkpoints/' + model_name

create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder,
create_submission(classify_splits, segment_splits, model_name, batch_size, num_workers, mean, std, test_data_folder,
sample_submission_path, model_path, tta_flag=tta_flag)
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def pred_show(images, preds, mean, std, targets=None, flag=False, auto_flag=Fals
# 是否显示真实的mask
show_truemask_flag = True
# 加载哪几折的模型进行测试,若list中有多个值,则使用投票法
n_splits = [1] # [0, 1, 2, 3, 4]
n_splits = [0, 1, 2, 3, 4] # [0, 1, 2, 3, 4]
# 是否只使用分割模型
use_segment_only = True
# 是否使用自动显示
Expand Down
5 changes: 3 additions & 2 deletions run.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/bin/bash
python train_classify.py
python train_segment.py
# python train_classify.py
python train_segment.py
python choose_thre_area.py

0 comments on commit 2486206

Please sign in to comment.