Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
zdaiot committed Oct 9, 2019
2 parents 0b543d9 + 2486206 commit f42593b
Show file tree
Hide file tree
Showing 20 changed files with 265 additions and 42 deletions.
5 changes: 3 additions & 2 deletions choose_thre_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ def get_model(model_name, load_path):
config = get_seg_config()
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)

dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits)
mask_only = True

dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits, mask_only)
results = {}
# 存放权重的路径
model_path = os.path.join(config.save_path, config.model_name)
Expand Down
59 changes: 59 additions & 0 deletions classify_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,65 @@ def classify_segment_folds(self, images):
return results


class Classify_Segment_Folds_Split():
    def __init__(self, model_name, classify_folds, segment_folds, model_path, class_num=4, tta_flag=False):
        ''' Ensemble the classification folds first, then the segmentation folds,
        and finally fuse the two ensembled results.

        :param model_name: name of the current model
        :param classify_folds: fold indices of the classification models in the ensemble, a list
        :param segment_folds: fold indices of the segmentation models in the ensemble, a list
        :param model_path: directory that stores all model weights
        :param class_num: total number of classes
        :param tta_flag: whether to apply test-time augmentation
        '''
        self.model_name = model_name
        self.classify_folds = classify_folds
        self.segment_folds = segment_folds
        self.model_path = model_path
        self.class_num = class_num
        self.tta_flag = tta_flag

        self.classify_models, self.segment_models = list(), list()
        self.get_classify_segment_models()

    def get_classify_segment_models(self):
        ''' Load the classification and segmentation models for every requested fold.
        '''
        for fold in self.classify_folds:
            self.classify_models.append(Get_Classify_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))
        for fold in self.segment_folds:
            self.segment_models.append(Get_Segment_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag))

    def classify_segment_folds(self, images):
        ''' Majority-vote the per-fold classification and segmentation results of one
        batch, then fuse them: segmentation channels whose class the classifier
        ensemble voted absent are zeroed out.

        :param images: one batch of data, shape [batch, channels, height, width]
        :return: fused segmentation result, shape [batch, class_num, height, width]
        '''
        classify_results = torch.zeros(images.shape[0], self.class_num)
        segment_results = torch.zeros(images.shape[0], self.class_num, images.shape[2], images.shape[3])

        # Accumulate classification votes across folds.
        for classify_model in self.classify_models:
            classify_result_fold = classify_model.get_classify_results(images)
            # NOTE(review): bare squeeze() drops ALL size-1 dims; assumes the model
            # output is [batch, class_num] (possibly with trailing singleton dims)
            # and batch > 1 or class broadcasting stays valid — confirm upstream.
            classify_results += classify_result_fold.detach().cpu().squeeze().float()
        classify_vote_ticket = round(len(self.classify_folds) / 2.0)
        classify_results = classify_results > classify_vote_ticket

        # Accumulate segmentation votes across folds.
        for segment_model in self.segment_models:
            segment_results += segment_model.get_segment_results(images).detach().cpu()
        segment_vote_ticket = round(len(self.segment_folds) / 2.0)
        segment_results = segment_results > segment_vote_ticket

        # Fuse: zero the channels whose class was voted absent. Use an explicit
        # boolean mask instead of the original `1 - classify_result`: subtracting
        # a bool tensor raises a RuntimeError on PyTorch >= 1.2, where comparison
        # results are bool rather than uint8. `classify_result == 0` yields the
        # inverted mask on both legacy uint8 and modern bool tensors.
        for batch_index, classify_result in enumerate(classify_results):
            segment_results[batch_index, classify_result == 0, ...] = 0

        return segment_results


class Segment_Folds():
def __init__(self, model_name, n_splits, model_path, class_num=4, tta_flag=False):
''' 使用投票法处理所有fold一个batch的分割结果
Expand Down
14 changes: 10 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ def get_seg_config():
hwp: 6
'''
# parser.add_argument('--image_size', type=int, default=768, help='image size')
parser.add_argument('--batch_size', type=int, default=24, help='batch size')
parser.add_argument('--epoch', type=int, default=60, help='epoch')
parser.add_argument('--batch_size', type=int, default=4, help='batch size')
parser.add_argument('--epoch', type=int, default=65, help='epoch')

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
parser.add_argument('--mask_only_flag', type=bool, default=False, help='if true, use masked data only.')
parser.add_argument('--crop', type=bool, default=True, help='if true, crop image to [height, width].')
parser.add_argument('--height', type=int, default=256, help='the height of cropped image')
parser.add_argument('--width', type=int, default=512, help='the width of cropped image')

# model set
parser.add_argument('--model_name', type=str, default='unet_resnet34', \
Expand All @@ -33,7 +36,7 @@ def get_seg_config():
# model hyper-parameters
parser.add_argument('--class_num', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--lr', type=float, default=4e-4, help='init lr')
parser.add_argument('--lr', type=float, default=1e-4, help='init lr')
parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')

# dataset
Expand Down Expand Up @@ -67,6 +70,9 @@ def get_classify_config():

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
parser.add_argument('--crop', type=bool, default=False, help='if true, crop image to [height, width].')
parser.add_argument('--height', type=int, default=256, help='the height of cropped image')
parser.add_argument('--width', type=int, default=512, help='the width of cropped image')

# model set
parser.add_argument('--model_name', type=str, default='unet_resnet34', \
Expand All @@ -89,4 +95,4 @@ def get_classify_config():


if __name__ == '__main__':
config = get_seg_config()
config = get_seg_config()
26 changes: 15 additions & 11 deletions create_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
package_path = '/kaggle/input/sources' # add unet script dataset
import sys
sys.path.append(package_path)
from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold
from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold, Classify_Segment_Folds_Split


class TestDataset(Dataset):
Expand Down Expand Up @@ -55,10 +55,11 @@ def mask2rle(img):
return ' '.join(str(x) for x in runs)


def create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False):
def create_submission(classify_splits, seg_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False):
'''
:param n_splits: 折数,类型为list
:param classify_splits: 分类模型的折数,类型为list
:param seg_splits: 分割模型的折数,类型为list
:param model_name: 当前模型的名称
:param batch_size: batch的大小
:param num_workers: 加载数据的线程
Expand All @@ -79,10 +80,12 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
num_workers=num_workers,
pin_memory=True
)
if len(n_splits) == 1:
classify_segment = Classify_Segment_Fold(model_name, n_splits[0], model_path, tta_flag=tta_flag).classify_segment
else:
classify_segment = Classify_Segment_Folds(model_name, n_splits, model_path, tta_flag=tta_flag).classify_segment_folds
if len(classify_splits) == 1 and len(seg_splits) == 1:
classify_segment = Classify_Segment_Fold(model_name, classify_splits[0], model_path, tta_flag=tta_flag).classify_segment
elif len(classify_splits) == len(seg_splits):
classify_segment = Classify_Segment_Folds(model_name, classify_splits, model_path, tta_flag=tta_flag).classify_segment_folds
elif len(classify_splits) != len(seg_splits):
classify_segment = Classify_Segment_Folds_Split(model_name, classify_splits, seg_splits, model_path, tta_flag=tta_flag).classify_segment_folds

# start prediction
predictions = []
Expand All @@ -104,11 +107,12 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
# 设置超参数
model_name = 'unet_resnet34'
num_workers = 12
batch_size = 6
batch_size = 4
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
n_splits = [1] # [0, 1, 2, 3, 4]
tta_flag = False
classify_splits = [1] # [0, 1, 2, 3, 4]
segment_splits = [0, 1, 2, 3, 4]
tta_flag = True

if kaggle:
sample_submission_path = '/kaggle/input/severstal-steel-defect-detection/sample_submission.csv'
Expand All @@ -119,5 +123,5 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std,
test_data_folder = 'datasets/Steel_data/test_images'
model_path = './checkpoints/' + model_name

create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder,
create_submission(classify_splits, segment_splits, model_name, batch_size, num_workers, mean, std, test_data_folder,
sample_submission_path, model_path, tta_flag=tta_flag)
46 changes: 31 additions & 15 deletions datasets/steel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,24 @@

# Dataset Segmentation
class SteelDataset(Dataset):
def __init__(self, df, data_folder, mean, std, phase):
def __init__(self, df, data_folder, mean, std, phase, crop=False, height=None, width=None):
super(SteelDataset, self).__init__()
self.df = df
self.root = data_folder
self.mean = mean
self.std = std
self.phase = phase
self.transforms = get_transforms
self.crop = crop
self.height = height
self.width = width
self.fnames = self.df.index.tolist()

def __getitem__(self, idx):
image_id, mask = make_mask(idx, self.df)
image_path = os.path.join(self.root, "train_images", image_id)
img = cv2.imread(image_path)
img, mask = self.transforms(self.phase, img, mask, self.mean, self.std)
img, mask = self.transforms(self.phase, img, mask, self.mean, self.std, crop=self.crop, height=self.height, width=self.width)
mask = mask.permute(2, 0, 1)
return img, mask

Expand All @@ -47,21 +50,24 @@ def __len__(self):

# Dataset Classification
class SteelClassDataset(Dataset):
def __init__(self, df, data_folder, mean, std, phase):
def __init__(self, df, data_folder, mean, std, phase, crop=False, height=None, width=None):
super(SteelClassDataset, self).__init__()
self.df = df
self.root = data_folder
self.mean = mean
self.std = std
self.phase = phase
self.transforms = get_transforms
self.crop = crop
self.height = height
self.width = width
self.fnames = self.df.index.tolist()

def __getitem__(self, idx):
image_id, mask = make_mask(idx, self.df)
image_path = os.path.join(self.root, "train_images", image_id)
img = cv2.imread(image_path)
img, mask = self.transforms(self.phase, img, mask, self.mean, self.std)
img, mask = self.transforms(self.phase, img, mask, self.mean, self.std, crop=self.crop, height=self.height, width=self.width)
mask = mask.permute(2, 0, 1) # 4x256x1600
mask = mask.view(mask.size(0), -1)
mask = torch.sum(mask, dim=1)
Expand Down Expand Up @@ -99,7 +105,7 @@ def __len__(self):
return self.num_samples


def augmentation(image, mask):
def augmentation(image, mask, crop=False, height=None, width=None):
"""进行数据增强
Args:
image: 原始图像
Expand All @@ -108,16 +114,16 @@ def augmentation(image, mask):
image_aug: 增强后的图像,Image图像
mask: 增强后的掩膜,Image图像
"""
image_aug, mask_aug = data_augmentation(image, mask)
image_aug, mask_aug = data_augmentation(image, mask, crop=crop, height=height, width=width)
image_aug = Image.fromarray(image_aug)

return image_aug, mask_aug


def get_transforms(phase, image, mask, mean, std):
def get_transforms(phase, image, mask, mean, std, crop=False, height=None, width=None):

if phase == 'train':
image, mask = augmentation(image, mask)
image, mask = augmentation(image, mask, crop=crop, height=height, width=width)

to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean, std)
Expand Down Expand Up @@ -152,7 +158,10 @@ def provider(
batch_size=8,
num_workers=4,
n_splits=0,
mask_only=False
mask_only=False,
crop=False,
height=None,
width=None
):
"""返回数据加载器,用于分割模型
Expand Down Expand Up @@ -192,7 +201,7 @@ def provider(
# 生成dataloader
dataloaders = list()
for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)):
train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train')
train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train', crop=crop, height=height, width=width)
val_dataset = SteelDataset(val_df, data_folder, mean, std, 'val')
if mask_only:
# 只在有掩膜的样本上训练
Expand Down Expand Up @@ -242,6 +251,9 @@ def classify_provider(
batch_size=8,
num_workers=4,
n_splits=0,
crop=False,
height=None,
width=False
):
"""返回数据加载器,用于分类模型
Expand Down Expand Up @@ -281,7 +293,7 @@ def classify_provider(
# 生成dataloader
dataloaders = list()
for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)):
train_dataset = SteelClassDataset(train_df, data_folder, mean, std, 'train')
train_dataset = SteelClassDataset(train_df, data_folder, mean, std, 'train', crop=crop, height=height, width=width)
val_dataset = SteelClassDataset(val_df, data_folder, mean, std, 'val')
train_dataloader = DataLoader(
train_dataset,
Expand All @@ -303,15 +315,19 @@ def classify_provider(


if __name__ == "__main__":
data_folder = "Steel_data"
df_path = "Steel_data/train.csv"
data_folder = "datasets/Steel_data"
df_path = "datasets/Steel_data/train.csv"
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
batch_size = 8
batch_size = 12
num_workers = 4
n_splits = 1
mask_only = False
crop = True
height = 256
width = 512
# 测试分割数据集
dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits)
dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits, mask_only=mask_only, crop=crop, height=height, width=width)
for fold_index, [train_dataloader, val_dataloader] in enumerate(dataloader):
train_bar = tqdm(train_dataloader)
class_color = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [139, 0, 139]]
Expand Down
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def pred_show(images, preds, mean, std, targets=None, flag=False, auto_flag=Fals
# 是否显示真实的mask
show_truemask_flag = True
# 加载哪几折的模型进行测试,若list中有多个值,则使用投票法
n_splits = [1] # [0, 1, 2, 3, 4]
n_splits = [0, 1, 2, 3, 4] # [0, 1, 2, 3, 4]
# 是否只使用分割模型
use_segment_only = True
# 是否使用自动显示
Expand Down
Loading

0 comments on commit f42593b

Please sign in to comment.