Skip to content

Commit

Permalink
add test_clasify.py
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Sep 27, 2019
1 parent 9601ca6 commit bf5b6eb
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 26 deletions.
28 changes: 2 additions & 26 deletions create_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from torch.utils.data import DataLoader, Dataset
from albumentations import (Normalize, Compose)
from albumentations.pytorch import ToTensor
from datasets.steel_dataset import TestDataset


# https://www.kaggle.com/paulorzp/rle-functions-run-lenght-encode-decode
def mask2rle(img):
Expand All @@ -31,32 +33,6 @@ def mask2rle(img):
return ' '.join(str(x) for x in runs)


class TestDataset(Dataset):
'''Dataset for test prediction'''

def __init__(self, root, df, mean, std):
self.root = root
df['ImageId'] = df['ImageId_ClassId'].apply(lambda x: x.split('_')[0])
self.fnames = df['ImageId'].unique().tolist()
self.num_samples = len(self.fnames)
self.transform = Compose(
[
Normalize(mean=mean, std=std, p=1),
ToTensor(),
]
)

def __getitem__(self, idx):
fname = self.fnames[idx]
path = os.path.join(self.root, fname)
image = cv2.imread(path)
images = self.transform(image=image)["image"]
return fname, images

def __len__(self):
return self.num_samples


def post_process(probability, threshold, min_size):
'''Post processing of each predicted mask, components with lesser number of pixels
than `min_size` are ignored'''
Expand Down
26 changes: 26 additions & 0 deletions datasets/steel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,32 @@ def __len__(self):
return len(self.fnames)


class TestDataset(Dataset):
'''Dataset for test prediction'''

def __init__(self, root, df, mean, std):
self.root = root
df['ImageId'] = df['ImageId_ClassId'].apply(lambda x: x.split('_')[0])
self.fnames = df['ImageId'].unique().tolist()
self.num_samples = len(self.fnames)
self.transform = Compose(
[
Normalize(mean=mean, std=std, p=1),
ToTensor(),
]
)

def __getitem__(self, idx):
fname = self.fnames[idx]
path = os.path.join(self.root, fname)
image = cv2.imread(path)
images = self.transform(image=image)["image"]
return fname, images

def __len__(self):
return self.num_samples


def get_transforms(phase, mean, std):
list_transforms = []
if phase == "train":
Expand Down
144 changes: 144 additions & 0 deletions test_classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import torch
from torch import nn
from albumentations import HorizontalFlip
import pandas as pd
import numpy as np
from tqdm import tqdm
import cv2

from datasets.steel_dataset import TestDataset, classify_provider
from models.classify import ClassifyResNet


class ClassifyTest():
def __init__(self, model, threshold=[0.5, 0.5, 0.5, 0.5], tta=False):
self.threshold = threshold
self.tta = tta

self.model = model
self.model.eval()

def predict_dataloader(self, dataloader):
"""对测试集进行测试
Return:
test_image_id: 测试样本的名称
predict_label: 各个样本对应的预测类标
"""
tbar = tqdm(dataloader)
test_probility = list()
test_image_id = list()
for (fnames, images) in tbar:
images = images.cuda()
probility = self.tta_pred(images)
probility = probility.data.cpu().numpy()
test_probility.append(probility)
test_image_id.extend([fname for fname in fnames])
test_probility = np.concatenate(test_probility)
predict_label = test_probility > np.array(self.threshold).reshape(1, 4, 1, 1)

return test_image_id, predict_label

def predict_image(self, images):
"""对一个batch的样本进行测试
Return:
predict_label: 各个样本对应的预测类标
"""
probility = self.tta_pred(images)
probility = probility.data.cpu().numpy()
predict_label = probility > np.array(self.threshold).reshape(1, 4, 1, 1)

return predict_label

def tta_pred(self, images):
# 水平翻转
probility_tta = 0
logit = self.model(torch.flip(images, dims=[3]))
probility = torch.sigmoid(logit)
probility_tta += probility

# 原始
logit = self.model(images)
probility = torch.sigmoid(logit)
probility_tta += probility

probility_tta /= 2

return probility_tta


if __name__ == "__main__":
data_folder = "/home/apple/program/MXQ/Competition/Kaggle/Steal-Defect/Kaggle-Steel-Defect-Detection/datasets/Steel_data"
df_path = "/home/apple/program/MXQ/Competition/Kaggle/Steal-Defect/Kaggle-Steel-Defect-Detection/datasets/Steel_data/train.csv"
test_df = pd.read_csv('./datasets/Steel_data/sample_submission.csv')
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
test_dataset = TestDataset('./datasets/Steel_data/test_images', test_df, mean, std)
dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=20,
shuffle=True,
num_workers=8,
pin_memory=True
)

model = ClassifyResNet('unet_resnet34', 4, training=False)
model = torch.nn.DataParallel(model)
model = model.cuda()
pth_path = "checkpoints/unet_resnet34/unet_resnet34_classify_fold2.pth"
checkpoint = torch.load(pth_path)
model.module.load_state_dict(checkpoint['state_dict'])

class_test = ClassifyTest(model, [0.5, 0.5, 0.5, 0.5], True)
# 直接对一整个数据集进行预测
# image_id, predict_label = class_test.predict(dataloader)
# 按照mini-batch的方式进行预测
class_dataloader = classify_provider(data_folder, df_path, mean, std, 20, 8, 5)
for fold_index, [train_dataloader, val_dataloader] in enumerate(class_dataloader):
train_bar = tqdm(val_dataloader)
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (139, 0, 139)]
number_sample = 0
num_true = 0
for (images, targets) in train_bar:
images = images.cuda()
# 预测并计算指标
predicts = class_test.predict_image(images).squeeze().astype(int)
targets_numpy = targets.data.cpu().numpy()
num_true += (predicts == targets_numpy).sum()
number_sample += targets_numpy.size
descript = 'True / Num: %d / %d' % (num_true, number_sample)
train_bar.set_description(desc=descript)

image = images[0]
for i in range(3):
image[i] = image[i] * std[i]
image[i] = image[i] + mean[i]
image = image.permute(1, 2, 0).cpu().numpy()
target = targets[0]
# 真实类别标签
position_x = 10
for i in range(target.size(0)):
color = class_color[i]
position_x += 50
position = (position_x, 50)
if target[i] != 0:
font = cv2.FONT_HERSHEY_SIMPLEX
image = cv2.putText(image, str(i), position, font, 1.2, color, 2)
# 预测类别标签
predict = predicts[0]
position_x = 10
for i in range(predict.shape[0]):
color = class_color[i]
position_x += 50
position = (position_x, 100)
if predict[i] != 0:
font = cv2.FONT_HERSHEY_SIMPLEX
image = cv2.putText(image, str(i), position, font, 1.2, color, 2)
cv2.imshow('win', image)
cv2.waitKey(30)
print("Accuracy: %.4f" % (num_true / number_sample))
pass



0 comments on commit bf5b6eb

Please sign in to comment.