different model different fold ensemble
XiangqianMa committed Oct 21, 2019
1 parent 7627b66 commit ec975d1
Showing 10 changed files with 80 additions and 54 deletions.
4 changes: 2 additions & 2 deletions choose_thre_area.py
@@ -8,6 +8,7 @@
 import codecs
 import json
 import matplotlib.pyplot as plt
+plt.switch_backend('agg')
 from solver import Solver
 from models.model import Model
 from datasets.steel_dataset import provider
@@ -213,9 +214,8 @@ def get_model(model_name, load_path):
     best_thresholds_sum, best_minareas_sum, max_dices_sum = [0 for x in range(len(dataloaders))], \
         [0 for x in range(len(dataloaders))], [0 for x in range(len(dataloaders))]
     for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
-        if fold_index != 1:
+        if fold_index != 2:
             continue
-
         # Path + filename where the weights are stored
         load_path = os.path.join(model_path, '%s_fold%d_best.pth' % (config.model_name, fold_index))
         # Load the model
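
Note: the new `plt.switch_backend('agg')` line right after the pyplot import lets choose_thre_area.py run on a headless training server; Agg renders figures off-screen, so `savefig` still works without a display. A minimal sketch of the pattern (the plotted values and output filename are illustrative, not from this repo):

```python
import matplotlib.pyplot as plt
plt.switch_backend('agg')  # must run before any figure is created

# Off-screen rendering: no $DISPLAY is needed, savefig still works.
plt.plot([0.3, 0.4, 0.5], [0.90, 0.92, 0.91])
plt.xlabel('threshold')
plt.ylabel('dice')
plt.savefig('dice_vs_threshold.png')  # illustrative output name
```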
51 changes: 34 additions & 17 deletions classify_segment.py
@@ -134,21 +134,31 @@ def get_segment_results(self, images, process_flag=True):


 class Classify_Segment_Fold():
-    def __init__(self, model_name, fold, model_path, class_num=4, tta_flag=False):
+    def __init__(self, classify_fold, seg_fold, model_path, class_num=4, tta_flag=False, kaggle=0):
         ''' Process the segmentation and classification results of one batch for the current fold
-        :param model_name: name of the current model
-        :param fold: current fold index
+        :param classify_fold: dict, classification model {'model_name': fold_index}
+        :param seg_fold: dict, segmentation model {'model_name': fold_index}
         :param model_path: directory holding all the models
         :param class_num: total number of classes
         '''
-        self.model_name = model_name
-        self.fold = fold
+        self.classify_fold = classify_fold
+        self.seg_fold = seg_fold
         self.model_path = model_path
         self.class_num = class_num
 
-        self.classify_model = Get_Classify_Results(self.model_name, self.fold, self.model_path, self.class_num, tta_flag=tta_flag)
-        self.segment_model = Get_Segment_Results(self.model_name, self.fold, self.model_path, self.class_num, tta_flag=tta_flag)
+        for (model_name, fold) in self.classify_fold.items():
+            if kaggle == 0:
+                pth_path = self.model_path
+            else:
+                pth_path = os.path.join(self.model_path, model_name)
+            self.classify_model = Get_Classify_Results(model_name, fold, pth_path, self.class_num, tta_flag=tta_flag)
+        for (model_name, fold) in self.seg_fold.items():
+            if kaggle == 0:
+                pth_path = self.model_path
+            else:
+                pth_path = os.path.join(self.model_path, model_name)
+            self.segment_model = Get_Segment_Results(model_name, fold, pth_path, self.class_num, tta_flag=tta_flag)
 
     def classify_segment(self, images):
         ''' Process the segmentation and classification results of one batch for the current fold
@@ -218,32 +228,39 @@ def classify_segment_folds(self, images):


 class Classify_Segment_Folds_Split():
-    def __init__(self, model_name, classify_folds, segment_folds, model_path, class_num=4, tta_flag=False):
+    def __init__(self, classify_folds, segment_folds, model_path, class_num=4, tta_flag=False, kaggle=0):
         ''' First get the ensemble result of the classification models, then the ensemble result of the segmentation models, and finally fuse the two
-        :param model_name: name of the current model
-        :param classify_folds: fold indices of the classification models in the ensemble, a list
-        :param segment_folds: fold indices of the segmentation models in the ensemble, a list
-        :param model_path: directory holding all the models
+        :param classify_folds: dict, {'model_name': fold_index}
+        :param segment_folds: dict, {'model_name': fold_index}
+        :param model_path: directory holding all the models, checkpoints/
        :param class_num: total number of classes
         '''
-        self.model_name = model_name
         self.classify_folds = classify_folds
         self.segment_folds = segment_folds
         self.model_path = model_path
         self.class_num = class_num
         self.tta_flag = tta_flag
+        self.kaggle = kaggle
 
         self.classify_models, self.segment_models = list(), list()
         self.get_classify_segment_models()
 
     def get_classify_segment_models(self):
         ''' Load the segmentation and classification models of all folds
         '''
-        for fold in self.classify_folds:
-            self.classify_models.append(Get_Classify_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))
-        for fold in self.segment_folds:
-            self.segment_models.append(Get_Segment_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))
+        for (model_name, fold) in self.classify_folds.items():
+            if self.kaggle == 0:
+                pth_path = self.model_path
+            else:
+                pth_path = os.path.join(self.model_path, model_name)
+            self.classify_models.append(Get_Classify_Results(model_name, fold, pth_path, self.class_num, tta_flag=self.tta_flag))
+        for (model_name, fold) in self.segment_folds.items():
+            if self.kaggle == 0:
+                pth_path = self.model_path
+            else:
+                pth_path = os.path.join(self.model_path, model_name)
+            self.segment_models.append(Get_Segment_Results(model_name, fold, pth_path, self.class_num, tta_flag=self.tta_flag))
 
     def classify_segment_folds(self, images, average_strategy=False):
         ''' Combine the segmentation and classification results of one batch across all folds by voting or averaging
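
Note: keying the fold dicts by model name is what enables the "different model, different fold" ensemble, and the new `kaggle` flag only decides whether each checkpoint sits directly under `model_path` or in a per-model sub-directory (Kaggle datasets unpack one directory per model). Since dict keys are unique, this scheme carries at most one fold per architecture. A standalone sketch of the path logic, assuming the `%s_fold%d_best.pth` naming used in choose_thre_area.py; the fold values are illustrative:

```python
import os

def resolve_pth(model_path, model_name, fold, kaggle=0):
    # kaggle == 0: every .pth sits directly under model_path;
    # kaggle == 1: each model gets its own sub-directory under model_path.
    base = model_path if kaggle == 0 else os.path.join(model_path, model_name)
    return os.path.join(base, '%s_fold%d_best.pth' % (model_name, fold))

segment_folds = {'unet_resnet34': 1, 'unet_resnet50': 2}  # illustrative folds
for name, fold in segment_folds.items():
    print(resolve_pth('checkpoints/', name, fold, kaggle=1))
# checkpoints/unet_resnet34/unet_resnet34_fold1_best.pth
# checkpoints/unet_resnet50/unet_resnet50_fold2_best.pth
```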
20 changes: 12 additions & 8 deletions config.py
@@ -17,10 +17,12 @@ def get_seg_config():
     zdaiot:12 z840:16 mxq:24
     unet_se_renext50
     hwp: 6 MXQ: 12
+    unet_resnet50
+    MXQ: 16
     '''
     # parser.add_argument('--image_size', type=int, default=768, help='image size')
-    parser.add_argument('--batch_size', type=int, default=24, help='batch size')
-    parser.add_argument('--epoch', type=int, default=65, help='epoch')
+    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
+    parser.add_argument('--epoch', type=int, default=50, help='epoch')
 
     parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
     parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
@@ -30,8 +32,8 @@ def get_seg_config():
     parser.add_argument('--width', type=int, default=None, help='the width of cropped image')
 
     # model set
-    parser.add_argument('--model_name', type=str, default='unet_resnet34', \
-                        help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4/unet_resnet50')
+    parser.add_argument('--model_name', type=str, default='unet_resnet50', \
+                        help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4/unet_resnet50/unet_efficientnet_b3')
 
     # model hyper-parameters
     parser.add_argument('--class_num', type=int, default=4)
@@ -64,9 +66,11 @@ def get_classify_config():
     zdaiot:12 z840:16 mxq:48
     unet_se_renext50
     hwp: 8
+    unet_resnet50:
+    MXQ: 24
     '''
     # parser.add_argument('--image_size', type=int, default=768, help='image size')
-    parser.add_argument('--batch_size', type=int, default=48, help='batch size')
+    parser.add_argument('--batch_size', type=int, default=24, help='batch size')
     parser.add_argument('--epoch', type=int, default=30, help='epoch')
 
     parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
@@ -76,13 +80,13 @@ def get_classify_config():
     parser.add_argument('--width', type=int, default=None, help='the width of cropped image')
 
     # model set
-    parser.add_argument('--model_name', type=str, default='unet_resnet34', \
-                        help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4/unet_resnet50')
+    parser.add_argument('--model_name', type=str, default='unet_resnet50', \
+                        help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4/unet_resnet50/unet_efficientnet_b3')
 
     # model hyper-parameters
     parser.add_argument('--class_num', type=int, default=4)
     parser.add_argument('--num_workers', type=int, default=8)
-    parser.add_argument('--lr', type=float, default=5e-4, help='init lr')
+    parser.add_argument('--lr', type=float, default=5e-5, help='init lr')
     parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')
 
     # dataset
34 changes: 18 additions & 16 deletions create_submission.py
@@ -10,6 +10,7 @@
 kaggle = 0
 if kaggle:
     os.system('pip install /kaggle/input/segmentation_models/pretrainedmodels-0.7.4/ > /dev/null')
+    os.system('pip install /kaggle/input/segmentation_models/EfficientNet-PyTorch/ > /dev/null')
     os.system('pip install /kaggle/input/segmentation_models/segmentation_models.pytorch/ > /dev/null')
     package_path = '/kaggle/input/sources'  # add unet script dataset
     import sys
@@ -55,12 +56,11 @@ def mask2rle(img):
     return ' '.join(str(x) for x in runs)
 
 
-def create_submission(classify_splits, seg_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False, average_strategy=False):
+def create_submission(classify_splits, seg_splits, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False, average_strategy=False, kaggle=0):
     '''
-    :param classify_splits: folds of the classification models, a list
-    :param seg_splits: folds of the segmentation models, a list
-    :param model_name: name of the current model
+    :param classify_splits: folds of the classification models, a dict
+    :param seg_splits: folds of the segmentation models, a dict
     :param batch_size: batch size
     :param num_workers: number of data-loading workers
     :param mean: mean
@@ -70,6 +70,7 @@ def create_submission(classify_splits, seg_splits, model_name, batch_size, num_w
     :param model_path: directory where the current model weights are stored
     :param tta_flag: whether to use TTA
     :param average_strategy: whether to use the averaging strategy
+    :param kaggle: run online (Kaggle) or offline
     :return: None
     '''
     # Load the dataset
@@ -82,16 +83,17 @@ def create_submission(classify_splits, seg_splits, model_name, batch_size, num_w
         pin_memory=True
     )
     if len(classify_splits) == 1 and len(seg_splits) == 1:
-        classify_segment = Classify_Segment_Fold(model_name, classify_splits[0], model_path, tta_flag=tta_flag).classify_segment
-    elif len(classify_splits) == len(seg_splits):
-        classify_segment = Classify_Segment_Folds(model_name, classify_splits, model_path, tta_flag=tta_flag).classify_segment_folds
-    elif len(classify_splits) != len(seg_splits):
-        classify_segment = Classify_Segment_Folds_Split(model_name, classify_splits, seg_splits, model_path, tta_flag=tta_flag).classify_segment_folds
+        classify_segment = Classify_Segment_Fold(classify_splits, seg_splits, model_path, tta_flag=tta_flag, kaggle=kaggle).classify_segment
+    else:
+        classify_segment = Classify_Segment_Folds_Split(classify_splits, seg_splits, model_path, tta_flag=tta_flag, kaggle=kaggle).classify_segment_folds
 
     # start prediction
     predictions = []
     for i, (fnames, images) in enumerate(tqdm(test_loader)):
-        results = classify_segment(images, average_strategy=average_strategy).detach().cpu().numpy()
+        if len(classify_splits) == 1 and len(seg_splits) == 1:
+            results = classify_segment(images).detach().cpu().numpy()
+        else:
+            results = classify_segment(images, average_strategy=average_strategy).detach().cpu().numpy()
 
         for fname, preds in zip(fnames, results):
             for cls, pred in enumerate(preds):
@@ -106,13 +108,13 @@ def create_submission(classify_splits, seg_splits, model_name, batch_size, num_w

 if __name__ == "__main__":
     # Set the hyper-parameters
-    model_name = 'unet_resnet34'
+    model_name = 'unet_efficientnet_b4'
     num_workers = 12
     batch_size = 4
     mean = (0.485, 0.456, 0.406)
     std = (0.229, 0.224, 0.225)
-    classify_splits = [1]  # [0, 1, 2, 3, 4]
-    segment_splits = [0, 1, 2, 3, 4]
+    classify_splits = {'unet_resnet34': 1}  # [0, 1, 2, 3, 4]
+    segment_splits = {'unet_resnet34': 1}
     tta_flag = True
     average_strategy = False
 
@@ -123,7 +125,7 @@ def create_submission(classify_splits, seg_splits, model_name, batch_size, num_w
     else:
         sample_submission_path = 'datasets/Steel_data/sample_submission.csv'
         test_data_folder = 'datasets/Steel_data/test_images'
-    model_path = './checkpoints/' + model_name
+    model_path = './checkpoints/'
 
-    create_submission(classify_splits, segment_splits, model_name, batch_size, num_workers, mean, std, test_data_folder,
-                      sample_submission_path, model_path, tta_flag=tta_flag, average_strategy=average_strategy)
+    create_submission(classify_splits, segment_splits, batch_size, num_workers, mean, std, test_data_folder,
+                      sample_submission_path, model_path, tta_flag=tta_flag, average_strategy=average_strategy, kaggle=kaggle)
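
Note: the dispatch now has only two arms: exactly one classification fold plus one segmentation fold selects the single-fold pipeline (whose `classify_segment` takes no `average_strategy` argument, hence the matching branch inside the prediction loop), and every other combination goes through the split ensemble. A minimal sketch of the rule, with strings standing in for the real callables:

```python
def pick_pipeline(classify_splits, seg_splits):
    # One model/fold on each side -> per-fold pipeline, no fusion needed;
    # anything else -> split ensemble that fuses folds by voting or averaging.
    if len(classify_splits) == 1 and len(seg_splits) == 1:
        return 'Classify_Segment_Fold.classify_segment'
    return 'Classify_Segment_Folds_Split.classify_segment_folds'

print(pick_pipeline({'unet_resnet34': 1}, {'unet_resnet34': 1}))
print(pick_pipeline({'unet_resnet34': 1}, {'unet_resnet34': 1, 'unet_resnet50': 0}))
```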
Binary file added models/EfficientNet-PyTorch.tar.gz
15 changes: 9 additions & 6 deletions models/model.py
@@ -50,6 +50,8 @@ def create_model_cpu(self):
         # Unet EfficientNet family
         elif self.model_name == 'unet_efficientnet_b4':
             model = smp.Unet('efficientnet-b4', encoder_weights=self.encoder_weights, classes=self.class_num, activation=None)
+        elif self.model_name == 'unet_efficientnet_b3':
+            model = smp.Unet('efficientnet-b3', encoder_weights=self.encoder_weights, classes=self.class_num, activation=None)
 
         return model
 
@@ -95,6 +97,8 @@ def __init__(self, model_name, class_num=4, training=True, encoder_weights='imag
                 nn.ReLU(),
                 nn.Conv2d(160, 32, kernel_size=1)
             )
+        elif model_name == 'unet_efficientnet_b3':
+            self.feature = nn.Conv2d(384, 32, kernel_size=1)
 
         self.logit = nn.Conv2d(32, self.class_num, kernel_size=1)
 
@@ -111,15 +115,14 @@ def forward(self, x):


 if __name__ == "__main__":
-    x = torch.Tensor(1, 3, 256, 1600)
-    y = torch.ones(1, 4)
     # test the segmentation model
-    model_name = 'unet_se_resnext50_32x4d'
-    model = Model(model_name).create_model()
-    print(model)
+    model_name = 'unet_efficientnet_b3'
+    model = Model(model_name, encoder_weights=None).create_model()
 
     # test the classification model
-    class_net = ClassifyResNet(model_name, 4)
+    x = torch.Tensor(8, 3, 256, 1600)
+    y = torch.ones(8, 4)
+    class_net = ClassifyResNet(model_name, 4, encoder_weights=None)
     seg_output = model(x)
     print(seg_output.size())
     output = class_net(x)
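
Note: the `nn.Conv2d(384, 32, kernel_size=1)` head matches efficientnet-b3, whose deepest encoder stage emits 384 feature channels in segmentation_models.pytorch. A quick smoke test of the new branch, assuming segmentation_models_pytorch and efficientnet-pytorch are installed (standalone version of the `__main__` block above):

```python
import torch
import segmentation_models_pytorch as smp

# efficientnet-b3's last encoder stage has 384 channels, which is why the
# classification head above is Conv2d(384, 32, kernel_size=1).
model = smp.Unet('efficientnet-b3', encoder_weights=None, classes=4, activation=None)
x = torch.randn(1, 3, 256, 1600)
with torch.no_grad():
    print(model(x).shape)  # expected: torch.Size([1, 4, 256, 1600])
```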
Binary file modified models/segmentation_models.pytorch.tar.gz
2 changes: 1 addition & 1 deletion run.sh
@@ -1,4 +1,4 @@
 #!/bin/bash
 # python train_classify.py
 python train_segment.py
-python choose_thre_area.py
\ No newline at end of file
+python choose_thre_area.py
4 changes: 2 additions & 2 deletions train_classify.py
@@ -52,7 +52,7 @@ def __init__(self, config, fold):

         self.max_accuracy_valid = 0
         # Set the random seed; keep it fixed so the cross-validation train/valid split stays the same
-        self.seed = int(time.time())
+        self.seed = 1570421136
         seed_torch(self.seed)
         with open(self.model_path + '/'+ TIMESTAMP + '.pkl','wb') as f:
             pickle.dump({'seed': self.seed}, f, -1)
@@ -169,7 +169,7 @@ def validation(self, valid_loader):
         width=config.width
     )
     for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
-        if fold_index != 1:
+        if fold_index != 2:
             continue
         train_val = TrainVal(config, fold_index)
         train_val.train(train_loader, valid_loader)
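
Note: pinning `self.seed` to the constant 1570421136 instead of `int(time.time())` makes the KFold train/valid split reproducible across runs, so a given fold index always refers to the same data split; the same constant is set in train_segment.py below. `seed_torch` itself is not part of this diff; a typical implementation, offered only as a hypothetical reconstruction:

```python
import os
import random

import numpy as np
import torch

def seed_torch(seed=1570421136):
    # Seed every RNG the training pipeline touches so cross-validation
    # splits and augmentations replay identically between scripts.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```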
4 changes: 2 additions & 2 deletions train_segment.py
@@ -58,7 +58,7 @@ def __init__(self, config, fold):
         self.max_dice_valid = 0
 
         # Set the random seed; keep it fixed so the cross-validation train/valid split stays the same
-        self.seed = int(time.time())
+        self.seed = 1570421136
         seed_torch(self.seed)
         with open(self.model_path + '/'+ TIMESTAMP + '.pkl','wb') as f:
             pickle.dump({'seed': self.seed}, f, -1)
@@ -186,7 +186,7 @@ def load_weight(self, weight_path):
         width=config.width
     )
     for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
-        if fold_index != 1:
+        if fold_index != 2:
             continue
         train_val = TrainVal(config, fold_index)
         train_val.train(train_loader, valid_loader)
