Commit

a little milestone
XiangqianMa committed Oct 9, 2019
1 parent cf0247e commit 4e9f025
Showing 7 changed files with 32 additions and 19 deletions.
README.md: 13 changes (9 additions, 4 deletions)
@@ -76,7 +76,12 @@ When run `SUBMISSION=/path/to/csv/file.csv make release-csv`, If you encounter t
 ## TODO
 - [x] finish classify + segment model
 - [x] finish create_submission.py
-- [ ] finish demo.py
-- [ ] finish loss.py
-- [ ] finish choose_threshold
-- [ ] finish data enhancement
+- [x] finish demo.py
+- [x] finish loss.py
+- [x] finish choose_threshold
+- [x] finish data augmentation
+- [ ] EfficientB4 (w/ ASPP)
+- [ ] code review (validation dice, threshold dice)
+- [ ] choose fold
+- [ ] ensemble
+- [ ] early stopping automatically
config.py: 8 changes (4 additions, 4 deletions)
@@ -19,8 +19,8 @@ def get_seg_config():
     hwp: 6 MXQ: 12
     '''
     # parser.add_argument('--image_size', type=int, default=768, help='image size')
-    parser.add_argument('--batch_size', type=int, default=12, help='batch size')
-    parser.add_argument('--epoch', type=int, default=50, help='epoch')
+    parser.add_argument('--batch_size', type=int, default=24, help='batch size')
+    parser.add_argument('--epoch', type=int, default=65, help='epoch')
 
     parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
     parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
@@ -30,14 +30,14 @@ def get_seg_config():
     parser.add_argument('--width', type=int, default=None, help='the width of cropped image')
 
     # model set
-    parser.add_argument('--model_name', type=str, default='unet_se_resnext50_32x4d', \
+    parser.add_argument('--model_name', type=str, default='unet_resnet34', \
                         help='unet_resnet34/unet_se_resnext50_32x4d')
 
     # model hyper-parameters
     parser.add_argument('--class_num', type=int, default=4)
     parser.add_argument('--resume', type=str, default=0, help='Resuming from specified weight')
     parser.add_argument('--num_workers', type=int, default=8)
-    parser.add_argument('--lr', type=float, default=4e-5, help='init lr')
+    parser.add_argument('--lr', type=float, default=5e-5, help='init lr')
     parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')
 
     # dataset
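For orientation, here is a minimal sketch of how an argparse-based `get_seg_config` like the one above is typically defined and consumed. Only the four options changed in this commit are reproduced, and the `__main__` wiring is illustrative rather than code from this repo:

```python
import argparse

def get_seg_config():
    """Subset of the segmentation config above, with the post-commit defaults."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=24, help='batch size')
    parser.add_argument('--epoch', type=int, default=65, help='epoch')
    parser.add_argument('--model_name', type=str, default='unet_resnet34',
                        help='unet_resnet34/unet_se_resnext50_32x4d')
    parser.add_argument('--lr', type=float, default=5e-5, help='init lr')
    return parser.parse_args()

if __name__ == '__main__':
    config = get_seg_config()  # e.g. python config.py --batch_size 12
    print(config.model_name, config.batch_size, config.epoch, config.lr)
```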
demo.py: 10 changes (8 additions, 2 deletions)
@@ -36,11 +36,17 @@ def demo(n_splits, use_segment_only, model_name, mean, std, show_truemask_flag,
 
     # start prediction
     if show_truemask_flag:
-        for images, masks in tqdm(dataloader):
+        for samples in tqdm(dataloader):
+            if len(samples) == 0:
+                continue
+            images, masks = samples[0], samples[1]
             results = model(images).detach().cpu().numpy()
             pred_show(images, results, mean, std, targets=masks, flag=show_truemask_flag, auto_flag=auto_flag)
     else:
-        for fnames, images in tqdm(dataloader):
+        for samples in tqdm(dataloader):
+            if len(samples) == 0:
+                continue
+            fnames, images = samples[0], samples[1]
             results = model(images).detach().cpu().numpy()
             pred_show(images, results, mean, std, targets=None, flag=show_truemask_flag, auto_flag=auto_flag)
 
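The new `len(samples) == 0` guard implies the dataloader can yield empty batches, which happens when a collate function silently drops unreadable samples. A minimal sketch of that pattern; the `DemoDataset` and `skip_none_collate` names here are hypothetical stand-ins, not this repo's dataset code:

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

class DemoDataset(Dataset):
    """Toy dataset where some samples fail to load and come back as None."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        if idx in (4, 5):  # pretend these two images are corrupt
            return None
        image = torch.randn(3, 64, 64)
        mask = torch.zeros(4, 64, 64)
        return image, mask

def skip_none_collate(batch):
    """Drop failed samples; an all-bad batch collates to an empty list."""
    batch = [s for s in batch if s is not None]
    return default_collate(batch) if batch else []

dataloader = DataLoader(DemoDataset(), batch_size=2, collate_fn=skip_none_collate)
for samples in dataloader:
    if len(samples) == 0:  # same guard as in the diff above
        continue
    images, masks = samples[0], samples[1]
```

Under this shape, `samples[0]` and `samples[1]` are the stacked image and mask tensors, matching the unpacking added in the diff.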
Binary file modified models/segmentation_models.pytorch.tar.gz
train_classify.py: 7 changes (4 additions, 3 deletions)
@@ -83,13 +83,14 @@ def train(self, train_loader, valid_loader):
 
             # log to tensorboard, one entry per step
             self.writer.add_scalar('train_loss', loss.item(), global_step+i)
-
-            descript = "Fold: %d, Train Loss: %.7f, lr: %.7f" % (self.fold, loss.item(), self.lr)
+            params_groups_lr = str()
+            for group_ind, param_group in enumerate(optimizer.param_groups):
+                params_groups_lr = params_groups_lr + 'params_group_%d' % (group_ind) + ': %.12f, ' % (param_group['lr'])
+            descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (self.fold, loss.item(), params_groups_lr)
             tbar.set_description(desc=descript)
 
         # apply learning-rate decay after each epoch
         lr_scheduler.step()
-        self.lr = lr_scheduler.get_lr()
         global_step += len(train_loader)
 
         # Print the log info
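Instead of the single cached `self.lr`, the progress-bar description now walks `optimizer.param_groups`, so each group's current learning rate is visible; the identical change appears in train_segment.py below. A self-contained sketch of the pattern, with a placeholder model, optimizer, and scheduler standing in for the repo's real ones:

```python
import torch

# hypothetical two-group optimizer for illustration only
model = torch.nn.Linear(10, 4)
optimizer = torch.optim.Adam([
    {'params': model.weight, 'lr': 5e-5},
    {'params': model.bias, 'lr': 5e-4},
])
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# same formatting loop as the diff above
params_groups_lr = str()
for group_ind, param_group in enumerate(optimizer.param_groups):
    params_groups_lr = params_groups_lr + 'params_group_%d' % (group_ind) + ': %.12f, ' % (param_group['lr'])
print(params_groups_lr)  # params_group_0: 0.000050000000, params_group_1: 0.000500000000,

lr_scheduler.step()  # stepped once per epoch, as in both train() loops
```

Reading `param_group['lr']` directly also explains why the `self.lr = lr_scheduler.get_lr()` line could be dropped; newer PyTorch versions recommend `get_last_lr()` over calling `get_lr()` outside the scheduler.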
train_segment.py: 11 changes (6 additions, 5 deletions)
@@ -47,8 +47,8 @@ def __init__(self, config, fold):
         self.solver = Solver(self.model)
 
         # load the loss function
-        self.criterion = torch.nn.BCEWithLogitsLoss()
-        # self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25])
+        # self.criterion = torch.nn.BCEWithLogitsLoss()
+        self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[1.0, 1.0])
 
         # save the json config and initialize tensorboard
         TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S-%d}".format(datetime.datetime.now(), fold)
@@ -94,13 +94,14 @@ def train(self, train_loader, valid_loader):
 
             # log to tensorboard, one entry per step
             self.writer.add_scalar('train_loss', loss.item(), global_step+i)
-
-            descript = "Fold: %d, Train Loss: %.7f, lr: %.7f" % (self.fold, loss.item(), self.lr)
+            params_groups_lr = str()
+            for group_ind, param_group in enumerate(optimizer.param_groups):
+                params_groups_lr = params_groups_lr + 'params_group_%d' % (group_ind) + ': %.12f, ' % (param_group['lr'])
+            descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (self.fold, loss.item(), params_groups_lr)
             tbar.set_description(desc=descript)
 
         # apply learning-rate decay after each epoch
         lr_scheduler.step()
-        self.lr = lr_scheduler.get_lr()
         global_step += len(train_loader)
 
         # Print the log info
utils/loss.py: 2 changes (1 addition, 1 deletion)
@@ -123,7 +123,7 @@ def __init__(self, size_average=True, weight=[0.2, 0.8]):
     def forward(self, input, target):
         soft_bce_loss = self.bce_loss(input, target)
         soft_dice_loss = self.softdiceloss(input, target)
-        loss = 5.0 * soft_bce_loss + soft_dice_loss
+        loss = 0.7 * soft_bce_loss + 0.3 * soft_dice_loss
 
         return loss
 
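The reweighting shifts the objective from a BCE-dominated sum (5.0 * BCE + Dice) to a convex 0.7/0.3 blend. For illustration, a minimal sketch of such a combined BCE + soft-Dice loss on logits; it mirrors the weighting above but is an assumption-laden stand-in, not the repo's MultiClassesSoftBCEDiceLoss implementation:

```python
import torch
import torch.nn as nn

class SoftBCEDiceLoss(nn.Module):
    """0.7 * BCE-with-logits + 0.3 * (1 - soft Dice), per the new weights above."""
    def __init__(self, bce_weight=0.7, dice_weight=0.3, smooth=1.0):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth

    def forward(self, input, target):
        # input: logits of shape (N, C, H, W); target: binary masks, same shape
        soft_bce_loss = self.bce_loss(input, target)
        probs = torch.sigmoid(input)
        dims = (0, 2, 3)  # reduce over batch and spatial dims, keep classes
        intersection = (probs * target).sum(dims)
        cardinality = probs.sum(dims) + target.sum(dims)
        dice = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        soft_dice_loss = 1.0 - dice.mean()
        return self.bce_weight * soft_bce_loss + self.dice_weight * soft_dice_loss

# usage sketch
criterion = SoftBCEDiceLoss()
logits = torch.randn(2, 4, 32, 32)
masks = torch.randint(0, 2, (2, 4, 32, 32)).float()
loss = criterion(logits, masks)
```

The `smooth` term keeps the Dice ratio defined when a class is absent from both prediction and target.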
