Skip to content

Commit

Permalink
back dice
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Oct 15, 2019
1 parent 05c3fc9 commit 181c6da
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 145 deletions.
85 changes: 40 additions & 45 deletions choose_thre_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import codecs
import json
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from solver import Solver
from models.model import Model
from datasets.steel_dataset import provider
Expand All @@ -20,9 +19,9 @@
class ChooseThresholdMinArea():
''' 选择每一类的像素阈值和最小连通域
'''

def __init__(self, model, model_name, valid_loader, fold, save_path, class_num=4):
''' 模型初始化
Args:
model: 使用的模型
model_name: 当前模型的名称
Expand All @@ -43,7 +42,7 @@ def __init__(self, model, model_name, valid_loader, fold, save_path, class_num=4

def choose_threshold_minarea(self):
''' 采用网格法搜索各个类别最优像素阈值和最优最小连通域,并画出各个类别搜索过程中的热力图
Return:
best_thresholds_little: 每一个类别的最优阈值
best_minareas_little: 每一个类别的最优最小连通域
Expand All @@ -52,28 +51,36 @@ def choose_threshold_minarea(self):
init_thresholds_range, init_minarea_range = np.arange(0.50, 0.71, 0.03), np.arange(768, 2305, 256)

# 阈值列表和最小连通域列表,大小为 4xN
thresholds_table_big = np.array([init_thresholds_range, init_thresholds_range, \
thresholds_table_big = np.array([init_thresholds_range, init_thresholds_range,
init_thresholds_range, init_thresholds_range]) # 阈值列表
minareas_table_big = np.array([init_minarea_range, init_minarea_range, \
minareas_table_big = np.array([init_minarea_range, init_minarea_range,
init_minarea_range, init_minarea_range]) # 最小连通域列表

f, axes = plt.subplots(figsize=(28.8, 18.4), nrows=2, ncols=self.class_num)
cmap = sns.cubehelix_palette(start=1.5, rot=3, gamma=0.8, as_cmap=True)

best_thresholds_big, best_minareas_big, max_dices_big = self.grid_search(thresholds_table_big, minareas_table_big, axes[0,:], cmap)
print('best_thresholds_big:{}, best_minareas_big:{}, max_dices_big:{}'.format(best_thresholds_big, best_minareas_big, max_dices_big))
best_thresholds_big, best_minareas_big, max_dices_big = self.grid_search(thresholds_table_big,
minareas_table_big, axes[0, :], cmap)
print('best_thresholds_big:{}, best_minareas_big:{}, max_dices_big:{}'.format(best_thresholds_big,
best_minareas_big, max_dices_big))

# 开始细分类
thresholds_table_little, minareas_table_little = list(), list()
for best_threshold_big, best_minarea_big in zip(best_thresholds_big, best_minareas_big):
thresholds_table_little.append(np.arange(best_threshold_big-0.03, best_threshold_big+0.03, 0.015)) # 阈值列表
minareas_table_little.append(np.arange(best_minarea_big-256, best_minarea_big+257, 128)) # 像素阈值列表
thresholds_table_little, minareas_table_little = np.array(thresholds_table_little), np.array(minareas_table_little)

best_thresholds_little, best_minareas_little, max_dices_little = self.grid_search(thresholds_table_little, minareas_table_little, axes[1,:], cmap)
print('best_thresholds_little:{}, best_minareas_little:{}, max_dices_little:{}'.format(best_thresholds_little, best_minareas_little, max_dices_little))

f.savefig(os.path.join(self.save_path, self.model_name + '_fold'+str(self.fold)), bbox_inches='tight')
thresholds_table_little.append(
np.arange(best_threshold_big - 0.03, best_threshold_big + 0.03, 0.015)) # 阈值列表
minareas_table_little.append(np.arange(best_minarea_big - 256, best_minarea_big + 257, 128)) # 像素阈值列表
thresholds_table_little, minareas_table_little = np.array(thresholds_table_little), np.array(
minareas_table_little)

best_thresholds_little, best_minareas_little, max_dices_little = self.grid_search(thresholds_table_little,
minareas_table_little,
axes[1, :], cmap)
print('best_thresholds_little:{}, best_minareas_little:{}, max_dices_little:{}'.format(best_thresholds_little,
best_minareas_little,
max_dices_little))

f.savefig(os.path.join(self.save_path, self.model_name + '_fold' + str(self.fold)), bbox_inches='tight')
# plt.show()
plt.close()

Expand All @@ -82,13 +89,11 @@ def choose_threshold_minarea(self):
def grid_search(self, thresholds_table, minareas_table, axes, cmap):
''' 给定包含各个类别搜索区间的thresholds_table和minareas_table,求得各个类别的最优像素阈值,最优最小连通域,最高dice;
并画出各个类别搜索过程中的热力图
Args:
thresholds_table: 待搜索的阈值范围,维度为[4, N],numpy类型
minareas_table: 待搜索的最小连通域范围,维度为[4, N],numpy类型
axes: 画各个类别搜索热力图时所需要的画柄,尺寸为[class_num]
cmap: 画图时所需要的cmap
return:
best_thresholds: 各个类别的最优像素阈值,尺寸为[class_num]
best_minareas: 各个类别的最优最小连通域,尺寸为[class_num]
Expand All @@ -97,15 +102,12 @@ def grid_search(self, thresholds_table, minareas_table, axes, cmap):
dices_table = np.zeros((self.class_num, np.shape(thresholds_table)[1], np.shape(minareas_table)[1]))
tbar = tqdm.tqdm(self.valid_loader)
with torch.no_grad():
for i, samples in enumerate(tbar):
if len(samples) == 0:
continue
images, masks = samples[0], samples[1]
for i, (images, masks) in enumerate(tbar):
# 完成网络的前向传播
masks_predict_allclasses = self.solver.forward(images)
dices_table += self.grid_search_batch(thresholds_table, minareas_table, masks_predict_allclasses, masks)

dices_table = dices_table/len(tbar)
dices_table = dices_table / len(tbar)
best_thresholds, best_minareas, max_dices = list(), list(), list()
# 处理每一类的预测结果
for each_class, dices_oneclass_table in enumerate(dices_table):
Expand All @@ -116,21 +118,21 @@ def grid_search(self, thresholds_table, minareas_table, axes, cmap):
best_minareas.append(minareas_table[each_class, max_location[1]])
max_dices.append(max_dice)

data = pd.DataFrame(data=dices_oneclass_table, index=np.around(thresholds_table[each_class,:], 3), columns=minareas_table[each_class,:])
sns.heatmap(data, linewidths=0.05, ax=axes[each_class], vmax=np.max(dices_oneclass_table), vmin=np.min(dices_oneclass_table), cmap=cmap,
data = pd.DataFrame(data=dices_oneclass_table, index=np.around(thresholds_table[each_class, :], 3),
columns=minareas_table[each_class, :])
sns.heatmap(data, linewidths=0.05, ax=axes[each_class], vmax=np.max(dices_oneclass_table),
vmin=np.min(dices_oneclass_table), cmap=cmap,
annot=True, fmt='.4f')
axes[each_class].set_title('search result')
return best_thresholds, best_minareas, max_dices

def grid_search_batch(self, thresholds_table, minareas_table, masks_predict_allclasses, masks_allclasses):
'''给定thresholds、minareas矩阵、一个batch的预测结果和真实标签,遍历每个类的每一个组合得到对应的dice值
Args:
thresholds_table: 待搜索的阈值范围,维度为[4, N],numpy类型
minareas_table: 待搜索的最小连通域范围,维度为[4, N],numpy类型
masks_predict_allclasses: 所有类别的某个batch的预测结果且未经过sigmoid,维度为[batch_size, class_num, height, width]
masks_allclasses: 所有类别的某个batch的真实类标,维度为[batch_size, class_num, height, width]
Return:
dices_table: 各个类别在其各自的所有搜索组合中所得到的dice值,维度为[4, M, N]
'''
Expand All @@ -149,13 +151,13 @@ def grid_search_batch(self, thresholds_table, minareas_table, masks_predict_allc

def post_process(self, thresholds_range, minareas_range, masks_predict_oneclass, masks_oneclasses):
'''给定某个类别的某个batch的数据,遍历所有搜索组合,得到每个组合的dice值
Args:
thresholds_range: 具体某个类别的像素阈值搜索区间,尺寸为[M]
minareas_range: 具体某个类别的最小连通域搜索区间,尺寸为[N]
masks_predict_oneclass: 预测出的某个类别的该batch的tensor向量且未经过sigmoid,维度为[batch_size, height, width]
masks_oneclasses: 某个类别的该batch的真实类标,维度为[batch_size, height, width]
Return:
dices_range: 某个类别的该batch的所有搜索组合得到dice矩阵,维度为[M, N]
'''
Expand Down Expand Up @@ -188,11 +190,10 @@ def post_process(self, thresholds_range, minareas_range, masks_predict_oneclass,

def get_model(model_name, load_path):
''' 加载网络模型并加载对应的权重
Args:
Args:
model_name: 当前模型的名称
load_path: 当前模型的权重路径
Return:
model: 加载出来的模型
'''
Expand All @@ -203,16 +204,15 @@ def get_model(model_name, load_path):

if __name__ == "__main__":
config = get_seg_config()
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
mask_only = True

dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits, mask_only)
dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std,
config.batch_size, config.num_workers, config.n_splits, mask_only)
results = {}
# 存放权重的路径
model_path = os.path.join(config.save_path, config.model_name)
best_thresholds_sum, best_minareas_sum, max_dices_sum = [0 for x in range(len(dataloaders))], \
[0 for x in range(len(dataloaders))], [0 for x in range(len(dataloaders))]
for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
if fold_index != 1:
continue
Expand All @@ -221,18 +221,13 @@ def get_model(model_name, load_path):
load_path = os.path.join(model_path, '%s_fold%d_best.pth' % (config.model_name, fold_index))
# 加载模型
model = get_model(config.model_name, load_path)
mychoose_threshold_minarea = ChooseThresholdMinArea(model, config.model_name, valid_loader, fold_index, model_path)
mychoose_threshold_minarea = ChooseThresholdMinArea(model, config.model_name, valid_loader, fold_index,
model_path)
best_thresholds, best_minareas, max_dices = mychoose_threshold_minarea.choose_threshold_minarea()
result = {'best_thresholds': best_thresholds, 'best_minareas': best_minareas, 'max_dices': max_dices}
results[str(fold_index)] = result

best_thresholds_sum = [x+y for x,y in zip(best_thresholds_sum, best_thresholds)]
best_minareas_sum = [x+y for x,y in zip(best_minareas_sum, best_minareas)]
max_dices_sum = [x+y for x,y in zip(max_dices_sum, max_dices)]
best_thresholds_average, best_minareas_average, max_dices_average = [x/len(dataloaders) for x in best_thresholds_sum], \
[x/len(dataloaders) for x in best_minareas_sum], [x/len(dataloaders) for x in max_dices_sum]
results['mean'] = {'best_thresholds': best_thresholds_average, 'best_minareas': best_minareas_average, 'max_dices': max_dices_average}
results[str(fold_index)] = result
with codecs.open(model_path + '/result.json', 'w', "utf-8") as json_file:
json.dump(results, json_file, ensure_ascii=False)

print('save the result')
print('save the result')
27 changes: 15 additions & 12 deletions create_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
os.system('pip install /kaggle/input/segmentation_models/pretrainedmodels-0.7.4/ > /dev/null')
os.system('pip install /kaggle/input/segmentation_models/EfficientNet-PyTorch/ > /dev/null')
os.system('pip install /kaggle/input/segmentation_models/segmentation_models.pytorch/ > /dev/null')
package_path = '/kaggle/input/sources' # add unet script dataset
package_path = '/kaggle/input/sources' # add unet script dataset
import sys
sys.path.append(package_path)
from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold, Classify_Segment_Folds_Split
Expand Down Expand Up @@ -82,17 +82,20 @@ def create_submission(classify_splits, seg_splits, model_name, batch_size, num_w
num_workers=num_workers,
pin_memory=True
)
# if len(classify_splits) == 1 and len(seg_splits) == 1:
# classify_segment = Classify_Segment_Fold(model_name, classify_splits[0], model_path, tta_flag=tta_flag).classify_segment
# elif len(classify_splits) == len(seg_splits):
# classify_segment = Classify_Segment_Folds(model_name, classify_splits, model_path, tta_flag=tta_flag).classify_segment_folds
# elif len(classify_splits) != len(seg_splits):
classify_segment = Classify_Segment_Folds_Split(model_name, classify_splits, seg_splits, model_path, tta_flag=tta_flag).classify_segment_folds
if len(classify_splits) == 1 and len(seg_splits) == 1:
classify_segment = Classify_Segment_Fold(model_name, classify_splits[0], model_path, tta_flag=tta_flag).classify_segment
elif len(classify_splits) == len(seg_splits):
classify_segment = Classify_Segment_Folds(model_name, classify_splits, model_path, tta_flag=tta_flag).classify_segment_folds
elif len(classify_splits) != len(seg_splits):
classify_segment = Classify_Segment_Folds_Split(model_name, classify_splits, seg_splits, model_path, tta_flag=tta_flag).classify_segment_folds

# start prediction
predictions = []
for i, (fnames, images) in enumerate(tqdm(test_loader)):
results = classify_segment(images, average_strategy=average_strategy).detach().cpu().numpy()
if len(classify_splits) != len(seg_splits):
results = classify_segment(images, average_strategy=average_strategy).detach().cpu().numpy()
else:
results = classify_segment(images).detach().cpu().numpy()

for fname, preds in zip(fnames, results):
for cls, pred in enumerate(preds):
Expand All @@ -107,13 +110,13 @@ def create_submission(classify_splits, seg_splits, model_name, batch_size, num_w

if __name__ == "__main__":
# 设置超参数
model_name = 'unet_resnet34'
model_name = 'unet_efficientnet_b4'
num_workers = 12
batch_size = 4
batch_size = 1
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
classify_splits = [1] # [0, 1, 2, 3, 4]
segment_splits = [0, 1, 2, 3, 4]
classify_splits = [1]# [0, 1, 2, 3, 4]
segment_splits = [1]
tta_flag = True
average_strategy = False

Expand Down
8 changes: 4 additions & 4 deletions datasets/steel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,15 @@ def classify_provider(


if __name__ == "__main__":
data_folder = "./Steel_data"
df_path = "./Steel_data/train.csv"
data_folder = "datasets/Steel_data"
df_path = "datasets/Steel_data/train.csv"
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
batch_size = 12
num_workers = 4
n_splits = 1
mask_only = False
crop = True
mask_only = True
crop = False
height = 256
width = 512
# 测试分割数据集
Expand Down
9 changes: 2 additions & 7 deletions train_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def train(self, train_loader, valid_loader):
optimizer = optim.Adam(self.model.module.parameters(), self.lr, weight_decay=self.weight_decay)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epoch+10)
global_step = 0
es = EarlyStopping(mode='min', patience=10)

for epoch in range(self.epoch):
epoch += 1
Expand Down Expand Up @@ -98,10 +97,6 @@ def train(self, train_loader, valid_loader):
average_loss = epoch_loss / len(tbar)
print('Finish Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch, average_loss))

# 提前终止
if es.step(average_loss):
break

# 验证模型
class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_valid = \
self.validation(valid_loader)
Expand Down Expand Up @@ -160,8 +155,8 @@ def validation(self, valid_loader):

if __name__ == "__main__":
config = get_classify_config()
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
dataloaders = classify_provider(
config.dataset_root,
os.path.join(config.dataset_root, 'train.csv'),
Expand Down
Loading

0 comments on commit 181c6da

Please sign in to comment.