Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
zdaiot committed Oct 12, 2019
2 parents 66e3c94 + 222af63 commit 64dff89
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 15 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ When run `SUBMISSION=/path/to/csv/file.csv make release-csv`, If you encounter t
- [x] finish choose_threshold
- [x] finish data augmentation
- [ ] EfficientB4( w/ ASPP)
- [ ] code review(validation dice, threshold dice)
- [x] ResNet50
- [x] code review(validation dice, threshold dice)
- [ ] choose fold
- [ ] ensemble
- [ ] early stopping automaticly
- [x] early stopping automaticly
- [ ] GN
24 changes: 16 additions & 8 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@ def get_seg_config():
parser = argparse.ArgumentParser()
'''
unet_resnet34时各个电脑可以设置的最大batch size
zdaiot:12 z840:16 mxq:24
zdaiot:12
z840:16
mxq:24
unet_se_renext50
hwp: 6 MXQ: 12
hwp: 6
MXQ: 12
unet_efficientnet_b4
MXQ:
'''
# parser.add_argument('--image_size', type=int, default=768, help='image size')
parser.add_argument('--batch_size', type=int, default=24, help='batch size')
parser.add_argument('--batch_size', type=int, default=6, help='batch size')
parser.add_argument('--epoch', type=int, default=65, help='epoch')

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
Expand All @@ -30,8 +36,8 @@ def get_seg_config():
parser.add_argument('--width', type=int, default=None, help='the width of cropped image')

# model set
parser.add_argument('--model_name', type=str, default='unet_resnet34', \
help='unet_resnet34/unet_se_resnext50_32x4d')
parser.add_argument('--model_name', type=str, default='unet_efficientnet_b4', \
help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4')

# model hyper-parameters
parser.add_argument('--class_num', type=int, default=4)
Expand Down Expand Up @@ -64,9 +70,11 @@ def get_classify_config():
zdaiot:12 z840:16 mxq:48
unet_se_renext50
hwp: 8
unet_efficientnet_b4
MXQ: 8
'''
# parser.add_argument('--image_size', type=int, default=768, help='image size')
parser.add_argument('--batch_size', type=int, default=48, help='batch size')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--epoch', type=int, default=30, help='epoch')

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
Expand All @@ -76,8 +84,8 @@ def get_classify_config():
parser.add_argument('--width', type=int, default=512, help='the width of cropped image')

# model set
parser.add_argument('--model_name', type=str, default='unet_resnet34', \
help='unet_resnet34/unet_se_resnext50_32x4d')
parser.add_argument('--model_name', type=str, default='unet_efficientnet_b4', \
help='unet_resnet34/unet_se_resnext50_32x4d/unet_efficientnet_b4')

# model hyper-parameters
parser.add_argument('--class_num', type=int, default=4)
Expand Down
8 changes: 7 additions & 1 deletion models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ def __init__(self, model_name, class_num=4, training=True, encoder_weights='imag
self.encoder = model.encoder
if model_name == 'unet_resnet34':
self.feature = nn.Conv2d(512, 32, kernel_size=1)
elif model_name == 'unet_resnet50':
self.feature = nn.Sequential(
nn.Conv2d(2048, 512, kernel_size=1),
nn.ReLU(),
nn.Conv2d(512, 32, kernel_size=1)
)
elif model_name == 'unet_se_resnext50_32x4d':
self.feature = nn.Sequential(
nn.Conv2d(2048, 512, kernel_size=1),
Expand Down Expand Up @@ -106,7 +112,7 @@ def forward(self, x):

if __name__ == "__main__":
# test segment 模型
model_name = 'unet_efficientnet_b4'
model_name = 'unet_resnet50'
model = Model(model_name, class_num=4).create_model_cpu()
x = torch.Tensor(5, 3, 256, 1600)
x = model.encoder(x)
Expand Down
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash
# python train_classify.py
python train_classify.py
python train_segment.py
python choose_thre_area.py
4 changes: 2 additions & 2 deletions train_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def __init__(self, config, fold):
self.solver = Solver(self.model)

# 加载损失函数
# self.criterion = torch.nn.BCEWithLogitsLoss()
self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=self.class_num, size_average=True, weight=[1.0, 1.0])
self.criterion = torch.nn.BCEWithLogitsLoss()
# self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=self.class_num, size_average=True, weight=[1.0, 1.0])

# 保存json文件和初始化tensorboard
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S-%d}".format(datetime.datetime.now(), fold)
Expand Down
2 changes: 1 addition & 1 deletion utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(self, size_average=True, weight=[0.2, 0.8]):
def forward(self, input, target):
soft_bce_loss = self.bce_loss(input, target)
soft_dice_loss = self.softdiceloss(input, target)
loss = 0.7 * soft_bce_loss + 0.3 * soft_dice_loss
loss = 0.85 * soft_bce_loss + 0.15 * soft_dice_loss

return loss

Expand Down

0 comments on commit 64dff89

Please sign in to comment.