Skip to content

Commit

Permalink
fix dice
Browse files Browse the repository at this point in the history
  • Loading branch information
zdaiot committed Oct 10, 2019
1 parent 3e963ae commit 8f29eca
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions train_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import codecs, json
import time
import pickle
import random
import numpy as np
from models.model import Model
from utils.cal_dice_iou import Meter
from utils.cal_dice_iou import Meter, compute_dice_class
from datasets.steel_dataset import provider
from utils.set_seed import seed_torch
from config import get_seg_config
Expand All @@ -34,6 +34,7 @@ def __init__(self, config, fold):
self.weight_decay = config.weight_decay
self.epoch = config.epoch
self.fold = fold
self.class_num = config.class_num

# 创建保存权重的路径
self.model_path = os.path.join(config.save_path, config.model_name)
Expand All @@ -49,7 +50,7 @@ def __init__(self, config, fold):

# 加载损失函数
# self.criterion = torch.nn.BCEWithLogitsLoss()
self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[1.0, 1.0])
self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=self.class_num, size_average=True, weight=[1.0, 1.0])

# 保存json文件和初始化tensorboard
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S-%d}".format(datetime.datetime.now(), fold)
Expand Down Expand Up @@ -116,7 +117,7 @@ def train(self, train_loader, valid_loader):
break

# 验证模型
loss_valid, dice_valid, iou_valid = self.validation(valid_loader)
loss_valid, dice_valid = self.validation(valid_loader)
if dice_valid > self.max_dice_valid:
is_best = True
self.max_dice_valid = dice_valid
Expand All @@ -137,11 +138,13 @@ def validation(self, valid_loader):
Args:
valid_loader: 验证数据的Dataloader
:return loss_mean: 验证集上的loss平均值
:return dice_mean: 验证集上的各类dice平均值
'''
self.model.eval()
meter = Meter()
tbar = tqdm.tqdm(valid_loader)
loss_sum = 0
loss_sum, dice_sum = 0, 0

with torch.no_grad():
for i, samples in enumerate(tbar):
Expand All @@ -154,17 +157,18 @@ def validation(self, valid_loader):
loss_sum += loss.item()

# 注意,损失函数中包含sigmoid函数,meter.update中也包含了sigmoid函数
# masks_predict_binary = torch.sigmoid(masks_predict) > 0.5
meter.update(masks, masks_predict.detach().cpu())

masks_predict_binary = torch.sigmoid(masks_predict) > 0.5
for each_class in range(self.class_num):
masks_predict_oneclass = masks_predict_binary[:, each_class, ...]
masks_oneclasses = masks[:, each_class, ...]
dice = compute_dice_class(masks_predict_oneclass.float(), masks_oneclasses)
dice_sum += dice
descript = "Val Loss: {:.7f}".format(loss.item())
tbar.set_description(desc=descript)
loss_mean = loss_sum/len(tbar)

dices, iou = meter.get_metrics()
dice, dice_neg, dice_pos = dices
print("IoU: %0.4f | dice: %0.4f | dice_neg: %0.4f | dice_pos: %0.4f" % (iou, dice, dice_neg, dice_pos))
return loss_mean, dice, iou
dice_mean = dice_sum/len(tbar)/self.class_num
print("loss_mean: %0.4f, dice_mean: %0.4f" % (loss_mean, dice_mean))
return loss_mean, dice_mean

def load_weight(self, weight_path):
"""加载权重
Expand Down

0 comments on commit 8f29eca

Please sign in to comment.