Skip to content

Commit

Permalink
back to the milestone
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Oct 15, 2019
1 parent 181c6da commit 010c971
Show file tree
Hide file tree
Showing 15 changed files with 144 additions and 214 deletions.
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ When run `SUBMISSION=/path/to/csv/file.csv make release-csv`, If you encounter t
- [x] finish choose_threshold
- [x] finish data augmentation
- [ ] EfficientB4( w/ ASPP)
- [x] ResNet50
- [x] code review(validation dice, threshold dice)
- [ ] code review(validation dice, threshold dice)
- [ ] choose fold
- [ ] ensemble
- [x] early stopping automatically
- [ ] GN
- [ ] early stopping automatically
84 changes: 44 additions & 40 deletions choose_thre_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
class ChooseThresholdMinArea():
''' 选择每一类的像素阈值和最小连通域
'''

def __init__(self, model, model_name, valid_loader, fold, save_path, class_num=4):
''' 模型初始化
Args:
model: 使用的模型
model_name: 当前模型的名称
Expand All @@ -42,7 +42,7 @@ def __init__(self, model, model_name, valid_loader, fold, save_path, class_num=4

def choose_threshold_minarea(self):
''' 采用网格法搜索各个类别最优像素阈值和最优最小连通域,并画出各个类别搜索过程中的热力图
Return:
best_thresholds_little: 每一个类别的最优阈值
best_minareas_little: 每一个类别的最优最小连通取余
Expand All @@ -51,36 +51,28 @@ def choose_threshold_minarea(self):
init_thresholds_range, init_minarea_range = np.arange(0.50, 0.71, 0.03), np.arange(768, 2305, 256)

# 阈值列表和最小连通域列表,大小为 Nx4
thresholds_table_big = np.array([init_thresholds_range, init_thresholds_range,
thresholds_table_big = np.array([init_thresholds_range, init_thresholds_range, \
init_thresholds_range, init_thresholds_range]) # 阈值列表
minareas_table_big = np.array([init_minarea_range, init_minarea_range,
minareas_table_big = np.array([init_minarea_range, init_minarea_range, \
init_minarea_range, init_minarea_range]) # 最小连通域列表

f, axes = plt.subplots(figsize=(28.8, 18.4), nrows=2, ncols=self.class_num)
cmap = sns.cubehelix_palette(start=1.5, rot=3, gamma=0.8, as_cmap=True)

best_thresholds_big, best_minareas_big, max_dices_big = self.grid_search(thresholds_table_big,
minareas_table_big, axes[0, :], cmap)
print('best_thresholds_big:{}, best_minareas_big:{}, max_dices_big:{}'.format(best_thresholds_big,
best_minareas_big, max_dices_big))
best_thresholds_big, best_minareas_big, max_dices_big = self.grid_search(thresholds_table_big, minareas_table_big, axes[0,:], cmap)
print('best_thresholds_big:{}, best_minareas_big:{}, max_dices_big:{}'.format(best_thresholds_big, best_minareas_big, max_dices_big))

# 开始细分类
thresholds_table_little, minareas_table_little = list(), list()
for best_threshold_big, best_minarea_big in zip(best_thresholds_big, best_minareas_big):
thresholds_table_little.append(
np.arange(best_threshold_big - 0.03, best_threshold_big + 0.03, 0.015)) # 阈值列表
minareas_table_little.append(np.arange(best_minarea_big - 256, best_minarea_big + 257, 128)) # 像素阈值列表
thresholds_table_little, minareas_table_little = np.array(thresholds_table_little), np.array(
minareas_table_little)

best_thresholds_little, best_minareas_little, max_dices_little = self.grid_search(thresholds_table_little,
minareas_table_little,
axes[1, :], cmap)
print('best_thresholds_little:{}, best_minareas_little:{}, max_dices_little:{}'.format(best_thresholds_little,
best_minareas_little,
max_dices_little))

f.savefig(os.path.join(self.save_path, self.model_name + '_fold' + str(self.fold)), bbox_inches='tight')
thresholds_table_little.append(np.arange(best_threshold_big-0.03, best_threshold_big+0.03, 0.015)) # 阈值列表
minareas_table_little.append(np.arange(best_minarea_big-256, best_minarea_big+257, 128)) # 像素阈值列表
thresholds_table_little, minareas_table_little = np.array(thresholds_table_little), np.array(minareas_table_little)

best_thresholds_little, best_minareas_little, max_dices_little = self.grid_search(thresholds_table_little, minareas_table_little, axes[1,:], cmap)
print('best_thresholds_little:{}, best_minareas_little:{}, max_dices_little:{}'.format(best_thresholds_little, best_minareas_little, max_dices_little))

f.savefig(os.path.join(self.save_path, self.model_name + '_fold'+str(self.fold)), bbox_inches='tight')
# plt.show()
plt.close()

Expand All @@ -89,11 +81,13 @@ def choose_threshold_minarea(self):
def grid_search(self, thresholds_table, minareas_table, axes, cmap):
''' 给定包含各个类别搜索区间的thresholds_table和minareas_table,求的各个类别的最优像素阈值,最优最小连通域,最高dice;
并画出各个类别搜索过程中的热力图
Args:
thresholds_table: 待搜索的阈值范围,维度为[4, N],numpy类型
minareas_table: 待搜索的最小连通域范围,维度为[4, N],numpy类型
axes: 画各个类别搜索热力图时所需要的画柄,尺寸为[class_num]
cmap: 画图时所需要的cmap
return:
best_thresholds: 各个类别的最优像素阈值,尺寸为[class_num]
best_minareas: 各个类别的最优最小连通域,尺寸为[class_num]
Expand All @@ -102,12 +96,15 @@ def grid_search(self, thresholds_table, minareas_table, axes, cmap):
dices_table = np.zeros((self.class_num, np.shape(thresholds_table)[1], np.shape(minareas_table)[1]))
tbar = tqdm.tqdm(self.valid_loader)
with torch.no_grad():
for i, (images, masks) in enumerate(tbar):
for i, samples in enumerate(tbar):
if len(samples) == 0:
continue
images, masks = samples[0], samples[1]
# 完成网络的前向传播
masks_predict_allclasses = self.solver.forward(images)
dices_table += self.grid_search_batch(thresholds_table, minareas_table, masks_predict_allclasses, masks)

dices_table = dices_table / len(tbar)
dices_table = dices_table/len(tbar)
best_thresholds, best_minareas, max_dices = list(), list(), list()
# 处理每一类的预测结果
for each_class, dices_oneclass_table in enumerate(dices_table):
Expand All @@ -118,21 +115,21 @@ def grid_search(self, thresholds_table, minareas_table, axes, cmap):
best_minareas.append(minareas_table[each_class, max_location[1]])
max_dices.append(max_dice)

data = pd.DataFrame(data=dices_oneclass_table, index=np.around(thresholds_table[each_class, :], 3),
columns=minareas_table[each_class, :])
sns.heatmap(data, linewidths=0.05, ax=axes[each_class], vmax=np.max(dices_oneclass_table),
vmin=np.min(dices_oneclass_table), cmap=cmap,
data = pd.DataFrame(data=dices_oneclass_table, index=np.around(thresholds_table[each_class,:], 3), columns=minareas_table[each_class,:])
sns.heatmap(data, linewidths=0.05, ax=axes[each_class], vmax=np.max(dices_oneclass_table), vmin=np.min(dices_oneclass_table), cmap=cmap,
annot=True, fmt='.4f')
axes[each_class].set_title('search result')
return best_thresholds, best_minareas, max_dices

def grid_search_batch(self, thresholds_table, minareas_table, masks_predict_allclasses, masks_allclasses):
'''给定thresholds、minareas矩阵、一个batch的预测结果和真实标签,遍历每个类的每一个组合得到对应的dice值
Args:
thresholds_table: 待搜索的阈值范围,维度为[4, N],numpy类型
minareas_table: 待搜索的最小连通域范围,维度为[4, N],numpy类型
masks_predict_allclasses: 所有类别的某个batch的预测结果且未经过sigmoid,维度为[batch_size, class_num, height, width]
masks_allclasses: 所有类别的某个batch的真实类标,维度为[batch_size, class_num, height, width]
Return:
dices_table: 各个类别在其各自的所有搜索组合中所得到的dice值,维度为[4, M, N]
'''
Expand All @@ -151,13 +148,13 @@ def grid_search_batch(self, thresholds_table, minareas_table, masks_predict_allc

def post_process(self, thresholds_range, minareas_range, masks_predict_oneclass, masks_oneclasses):
'''给定某个类别的某个batch的数据,遍历所有搜索组合,得到每个组合的dice值
Args:
thresholds_range: 具体某个类别的像素阈值搜索区间,尺寸为[M]
minareas_range: 具体某个类别的最小连通域搜索区间,尺寸为[N]
masks_predict_oneclass: 预测出的某个类别的该batch的tensor向量且未经过sigmoid,维度为[batch_size, height, width]
masks_oneclasses: 某个类别的该batch的真实类标,维度为[batch_size, height, width]
Return:
dices_range: 某个类别的该batch的所有搜索组合得到dice矩阵,维度为[M, N]
'''
Expand Down Expand Up @@ -190,10 +187,11 @@ def post_process(self, thresholds_range, minareas_range, masks_predict_oneclass,

def get_model(model_name, load_path):
''' 加载网络模型并加载对应的权重
Args:
Args:
model_name: 当前模型的名称
load_path: 当前模型的权重路径
Return:
model: 加载出来的模型
'''
Expand All @@ -204,15 +202,16 @@ def get_model(model_name, load_path):

if __name__ == "__main__":
config = get_seg_config()
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)
mask_only = True

dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std,
config.batch_size, config.num_workers, config.n_splits, mask_only)
dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits, mask_only)
results = {}
# 存放权重的路径
model_path = os.path.join(config.save_path, config.model_name)
best_thresholds_sum, best_minareas_sum, max_dices_sum = [0 for x in range(len(dataloaders))], \
[0 for x in range(len(dataloaders))], [0 for x in range(len(dataloaders))]
for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
if fold_index != 1:
continue
Expand All @@ -221,13 +220,18 @@ def get_model(model_name, load_path):
load_path = os.path.join(model_path, '%s_fold%d_best.pth' % (config.model_name, fold_index))
# 加载模型
model = get_model(config.model_name, load_path)
mychoose_threshold_minarea = ChooseThresholdMinArea(model, config.model_name, valid_loader, fold_index,
model_path)
mychoose_threshold_minarea = ChooseThresholdMinArea(model, config.model_name, valid_loader, fold_index, model_path)
best_thresholds, best_minareas, max_dices = mychoose_threshold_minarea.choose_threshold_minarea()
result = {'best_thresholds': best_thresholds, 'best_minareas': best_minareas, 'max_dices': max_dices}

results[str(fold_index)] = result

best_thresholds_sum = [x+y for x,y in zip(best_thresholds_sum, best_thresholds)]
best_minareas_sum = [x+y for x,y in zip(best_minareas_sum, best_minareas)]
max_dices_sum = [x+y for x,y in zip(max_dices_sum, max_dices)]
best_thresholds_average, best_minareas_average, max_dices_average = [x/len(dataloaders) for x in best_thresholds_sum], \
[x/len(dataloaders) for x in best_minareas_sum], [x/len(dataloaders) for x in max_dices_sum]
results['mean'] = {'best_thresholds': best_thresholds_average, 'best_minareas': best_minareas_average, 'max_dices': max_dices_average}
with codecs.open(model_path + '/result.json', 'w', "utf-8") as json_file:
json.dump(results, json_file, ensure_ascii=False)

print('save the result')
print('save the result')
24 changes: 8 additions & 16 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,13 @@ def get_seg_config():
parser = argparse.ArgumentParser()
'''
unet_resnet34时各个电脑可以设置的最大batch size
zdaiot:12
z840:16
mxq:24
zdaiot:12 z840:16 mxq:24
unet_se_renext50
hwp: 6
MXQ: 12
unet_efficientnet_b4
MXQ: 6
hwp: 6 MXQ: 12
'''
# parser.add_argument('--image_size', type=int, default=768, help='image size')
parser.add_argument('--batch_size', type=int, default=6, help='batch size')
parser.add_argument('--epoch', type=int, default=50, help='epoch')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--epoch', type=int, default=65, help='epoch')

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
Expand All @@ -37,7 +31,7 @@ def get_seg_config():

# model set
parser.add_argument('--model_name', type=str, default='unet_efficientnet_b4', \
help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4')
help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4/unet_resnet50')

# model hyper-parameters
parser.add_argument('--class_num', type=int, default=4)
Expand Down Expand Up @@ -70,11 +64,9 @@ def get_classify_config():
zdaiot:12 z840:16 mxq:48
unet_se_renext50
hwp: 8
unet_efficientnet_b4
MXQ: 8
'''
# parser.add_argument('--image_size', type=int, default=768, help='image size')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--batch_size', type=int, default=48, help='batch size')
parser.add_argument('--epoch', type=int, default=30, help='epoch')

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
Expand All @@ -84,8 +76,8 @@ def get_classify_config():
parser.add_argument('--width', type=int, default=512, help='the width of cropped image')

# model set
parser.add_argument('--model_name', type=str, default='unet_efficientnet_b4', \
help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4')
parser.add_argument('--model_name', type=str, default='unet_resnet34', \
help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4/unet_resnet50')

# model hyper-parameters
parser.add_argument('--class_num', type=int, default=4)
Expand Down
16 changes: 6 additions & 10 deletions create_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
kaggle = 0
if kaggle:
os.system('pip install /kaggle/input/segmentation_models/pretrainedmodels-0.7.4/ > /dev/null')
os.system('pip install /kaggle/input/segmentation_models/EfficientNet-PyTorch/ > /dev/null')
os.system('pip install /kaggle/input/segmentation_models/segmentation_models.pytorch/ > /dev/null')
package_path = '/kaggle/input/sources' # add unet script dataset
package_path = '/kaggle/input/sources' # add unet script dataset
import sys
sys.path.append(package_path)
from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold, Classify_Segment_Folds_Split
Expand Down Expand Up @@ -92,10 +91,7 @@ def create_submission(classify_splits, seg_splits, model_name, batch_size, num_w
# start prediction
predictions = []
for i, (fnames, images) in enumerate(tqdm(test_loader)):
if len(classify_splits) != len(seg_splits):
results = classify_segment(images, average_strategy=average_strategy).detach().cpu().numpy()
else:
results = classify_segment(images).detach().cpu().numpy()
results = classify_segment(images, average_strategy=average_strategy).detach().cpu().numpy()

for fname, preds in zip(fnames, results):
for cls, pred in enumerate(preds):
Expand All @@ -110,13 +106,13 @@ def create_submission(classify_splits, seg_splits, model_name, batch_size, num_w

if __name__ == "__main__":
# 设置超参数
model_name = 'unet_efficientnet_b4'
model_name = 'unet_resnet34'
num_workers = 12
batch_size = 1
batch_size = 4
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
classify_splits = [1]# [0, 1, 2, 3, 4]
segment_splits = [1]
classify_splits = [1] # [0, 1, 2, 3, 4]
segment_splits = [0, 1, 2, 3, 4]
tta_flag = True
average_strategy = False

Expand Down
8 changes: 4 additions & 4 deletions datasets/steel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,15 @@ def classify_provider(


if __name__ == "__main__":
data_folder = "datasets/Steel_data"
df_path = "datasets/Steel_data/train.csv"
data_folder = "./Steel_data"
df_path = "./Steel_data/train.csv"
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
batch_size = 12
num_workers = 4
n_splits = 1
mask_only = True
crop = False
mask_only = False
crop = True
height = 256
width = 512
# 测试分割数据集
Expand Down
Binary file removed models/EfficientNet-PyTorch.tar.gz
Binary file not shown.
14 changes: 7 additions & 7 deletions models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ def create_model_cpu(self):
# Unet dpn 系列
elif self.model_name == 'unet_dpn68':
model = smp.Unet('dpn68', encoder_weights=self.encoder_weights, classes=self.class_num, activation=None)

# Unet Efficient 系列
elif self.model_name == 'unet_efficientnet_b4':
model = smp.Unet('efficientnet-b4', encoder_weights=self.encoder_weights, classes=self.class_num, activation=None)

return model

def create_model(self):
Expand Down Expand Up @@ -96,6 +95,7 @@ def __init__(self, model_name, class_num=4, training=True, encoder_weights='imag
nn.ReLU(),
nn.Conv2d(160, 32, kernel_size=1)
)

self.logit = nn.Conv2d(32, self.class_num, kernel_size=1)

self.training = training
Expand All @@ -112,15 +112,15 @@ def forward(self, x):

if __name__ == "__main__":
# test segment 模型
model_name = 'unet_resnet50'
model = Model(model_name, class_num=4).create_model_cpu()
x = torch.Tensor(5, 3, 256, 1600)
x = model.encoder(x)
# print(model)
model_name = 'unet_se_resnext50_32x4d'
model = Model(model_name).create_model()
print(model)

# test classify 模型
class_net = ClassifyResNet(model_name, 4)
x = torch.Tensor(8, 3, 256, 1600)
y = torch.ones(8, 4)
seg_output = model(x)
print(seg_output.size())
output = class_net(x)
print(output.size())
Binary file modified models/segmentation_models.pytorch.tar.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash
python train_classify.py
# python train_classify.py
python train_segment.py
python choose_thre_area.py
Loading

0 comments on commit 010c971

Please sign in to comment.