-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
125 lines (108 loc) · 5.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from dataset import ECGDataset
from resnet import resnet34
from utils import cal_f1s, cal_aucs, split_data
def parse_args():
    """Parse the command-line options for the train/test script.

    Options are read from ``sys.argv``. Dashes in option names become
    underscores on the returned namespace (``--data-dir`` -> ``data_dir``).

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='data/CPSC', help='Directory for data dir')
    parser.add_argument('--leads', type=str, default='all', help='ECG leads to use')
    parser.add_argument('--seed', type=int, default=42, help='Seed to split data')
    # The default used to be the builtin `int` type object, which is not a
    # usable class count. NOTE(review): 9 is presumed to match the label set
    # of the default CPSC dataset -- confirm before relying on it elsewhere.
    parser.add_argument('--num-classes', type=int, default=9, help='Num of diagnostic classes')
    parser.add_argument('--lr', '--learning-rate', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('--num-workers', type=int, default=4, help='Num of workers to load data')
    parser.add_argument('--phase', type=str, default='train', help='Phase: train or test')
    parser.add_argument('--epochs', type=int, default=40, help='Training epochs')
    parser.add_argument('--resume', default=False, action='store_true', help='Resume')
    parser.add_argument('--use-gpu', default=False, action='store_true', help='Use GPU')
    # An empty model path means "derive one from the data set, leads and seed"
    # (done by the script entry point).
    parser.add_argument('--model-path', type=str, default='', help='Path to saved model')
    return parser.parse_args()
def train(dataloader, net, args, criterion, epoch, scheduler, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and print the summed loss.

    Args:
        dataloader: iterable yielding ``(data, labels)`` batches.
        net: model to optimise; it is switched to training mode.
        args: parsed command-line namespace (not read by this function).
        criterion: loss callable applied as ``criterion(output, labels)``.
        epoch: epoch index, used only in the progress message.
        scheduler: learning-rate scheduler; accepted but not stepped here.
        optimizer: optimizer whose gradients are reset and stepped per batch.
        device: device the batches are moved to before the forward pass.
    """
    print('Training epoch %d:' % epoch)
    net.train()
    running_loss = 0
    # Per-batch outputs/labels are deliberately not collected: nothing reads
    # them during training, and gathering them forced a device-to-host copy
    # on every step.
    for data, labels in tqdm(dataloader):
        data, labels = data.to(device), labels.to(device)
        output = net(data)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # NOTE(review): the scheduler is intentionally left unstepped, matching the
    # existing behaviour (the call was commented out). Confirm whether the
    # learning-rate decay is meant to be active.
    # The reported loss is the sum over batches, not the mean.
    print('Loss: %.4f' % running_loss)
def evaluate(dataloader, net, args, criterion, device):
    """Score ``net`` on ``dataloader`` and print loss and per-class metrics.

    The summed loss and the F1 scores are always printed. When
    ``args.phase == 'train'`` and the mean F1 beats ``args.best_metric``, the
    new best is stored on ``args`` and the model weights are written to
    ``args.model_path``. In every other case (including a training epoch that
    did not improve) the AUCs are computed and printed instead.

    Args:
        dataloader: iterable yielding ``(data, labels)`` batches.
        net: model to evaluate; it is switched to eval mode.
        args: namespace providing ``phase``, ``best_metric`` and
            ``model_path``; ``best_metric`` is updated in place.
        criterion: loss callable applied to the raw (pre-sigmoid) outputs.
        device: device the batches are moved to before the forward pass.
    """
    print('Validating...')
    net.eval()
    running_loss = 0
    output_list, labels_list = [], []
    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every batch and reduces memory use.
    with torch.no_grad():
        for data, labels in tqdm(dataloader):
            data, labels = data.to(device), labels.to(device)
            output = net(data)
            loss = criterion(output, labels)
            running_loss += loss.item()
            # The criterion consumes raw outputs; the metrics get sigmoid scores.
            scores = torch.sigmoid(output)
            output_list.append(scores.detach().cpu().numpy())
            labels_list.append(labels.detach().cpu().numpy())
    print('Loss: %.4f' % running_loss)
    y_trues = np.vstack(labels_list)
    y_scores = np.vstack(output_list)
    f1s = cal_f1s(y_trues, y_scores)
    avg_f1 = np.mean(f1s)
    print('F1s:', f1s)
    print('Avg F1: %.4f' % avg_f1)
    if args.phase == 'train' and avg_f1 > args.best_metric:
        # New best validation F1: remember it and checkpoint the weights.
        args.best_metric = avg_f1
        torch.save(net.state_dict(), args.model_path)
    else:
        aucs = cal_aucs(y_trues, y_scores)
        avg_auc = np.mean(aucs)
        print('AUCs:', aucs)
        print('Avg AUC: %.4f' % avg_auc)
if __name__ == "__main__":
    # Script entry point: build the data loaders and model, then either train
    # (validating after every epoch) or evaluate a saved model on the test
    # split.
    args = parse_args()
    # Best validation F1 seen so far; evaluate() updates it while training.
    args.best_metric = 0
    data_dir = os.path.normpath(args.data_dir)
    database = os.path.basename(data_dir)
    if not args.model_path:
        args.model_path = f'models/resnet34_{database}_{args.leads}_{args.seed}.pth'
    if args.phase == 'train':
        # Make sure the checkpoint directory exists so the first torch.save()
        # does not fail on a fresh checkout. A bare filename has no directory
        # part, hence the guard.
        model_dir = os.path.dirname(args.model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

    # Always a torch.device, so the CPU and GPU paths have the same type.
    use_cuda = args.use_gpu and torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    if args.leads == 'all':
        leads = 'all'
        nleads = 12
    else:
        # Comma-separated list of lead names.
        leads = args.leads.split(',')
        nleads = len(leads)

    label_csv = os.path.join(data_dir, 'labels.csv')
    train_folds, val_folds, test_folds = split_data(seed=args.seed)

    # Pinned host memory is only requested when batches go to a GPU.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=use_cuda,
    )
    train_dataset = ECGDataset('train', data_dir, label_csv, train_folds, leads)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_dataset = ECGDataset('val', data_dir, label_csv, val_folds, leads)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_dataset = ECGDataset('test', data_dir, label_csv, test_folds, leads)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    net = resnet34(input_channels=nleads).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.1)
    criterion = nn.BCEWithLogitsLoss()

    if args.phase == 'train':
        if args.resume:
            net.load_state_dict(torch.load(args.model_path, map_location=device))
        for epoch in range(args.epochs):
            train(train_loader, net, args, criterion, epoch, scheduler, optimizer, device)
            evaluate(val_loader, net, args, criterion, device)
    else:
        # Any phase other than 'train' evaluates the saved model on the test split.
        net.load_state_dict(torch.load(args.model_path, map_location=device))
        evaluate(test_loader, net, args, criterion, device)