train_cv.py
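# Example usage (the paths below are the script defaults; adjust them to your dataset layout):
#   python train_cv.py --train_path datasets/train --valid_path datasets/val \
#       --network_name DenseNet --num_folds 5 --max_epochs 50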
import torch
import argparse
import os
import sys
import time
from loguru import logger
import numpy as np
from data_prep.dataset import Dataset
from data_prep.dataset_loader import LoadData
# importing utilities
from utils.utils import seeding
from networks.NetworkController import getNetwork
from networks.VGG16 import VGG16_BN_Attention
from experiments.ClassifierController import getExperiment
from sklearn.model_selection import StratifiedKFold
from utils.loggers import log_to_file
# Custom log format
fmt = "{message}"
config = {
    "handlers": [
        {"sink": sys.stderr, "format": fmt},
    ],
}
logger.configure(**config)
# importing experiments
# from experiments.ClassifierExperiment import ClassifierExperiment
if __name__ == "__main__":
    # optional arguments from the command line
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path', type=str, default='datasets/train', help='root dir for the training data')
    parser.add_argument('--valid_path', type=str, default='datasets/val', help='root dir for the validation data')
    parser.add_argument('--train_masks_path', type=str, default=None, help='(optional) root dir for the training masks. Default: None')
    parser.add_argument('--valid_masks_path', type=str, default=None, help='(optional) root dir for the validation masks. Required when --train_masks_path is used. Default: None')
    parser.add_argument('--output', type=str, default='outputs', help='output dir for saving results')
    parser.add_argument('--experiment_name', type=str, default='exp0001', help='experiment name')
    parser.add_argument('--network_name', type=str, default='DenseNet', help='network name')
    parser.add_argument('--max_epochs', type=int, default=50, help='maximum number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size per GPU; increase if memory allows')
    parser.add_argument('--base_lr', type=float, default=0.0001, help='network learning rate (e.g. 0.00001 for a smaller step)')
    parser.add_argument('--patience', type=int, default=5, help='patience for the LR scheduler and early stopping')
    parser.add_argument('--img_size', type=int, default=224, help='input image size for the network')
    parser.add_argument('--seed', type=int, default=42, help='random seed value')
    parser.add_argument('--verbose', type=int, default=1, help='verbosity level [0:2]')
    parser.add_argument('--normalize_attn', action='store_true', help='if set, attention maps are normalized with softmax instead of sigmoid (VGG16_BN_Attention only)')
    parser.add_argument('--num_folds', type=int, default=5, help='number of cross-validation folds; a model is trained for each fold')
    parser.add_argument('--focal_loss', action='store_true', help='if set, focal loss is used instead of cross-entropy (multi-class classification only)')
    parser.add_argument('--multi', action='store_true', help='if set, the 3-class labels (mel/bcc/scc) are used when loading the data')
    # get cmd args from the parser
    args = parser.parse_args()
    logger.info(f"Executing training pipeline with {args.max_epochs} epochs and {args.num_folds} folds.")
    # set paths and dirs
    args.exp = args.experiment_name + '_' + str(args.img_size)
    output_path = os.path.join(os.getcwd(), args.output, args.exp)  # 'outputs/exp0001_224'
    snapshot_path = output_path + f'_epo{args.max_epochs}_bs{args.batch_size}_lr{args.base_lr}_s{args.seed}'  # 'outputs/exp0001_224_epo50_bs32_lr0.0001_s42'
    checkpoint_file = f'{args.exp}_{args.network_name}_epo{args.max_epochs}_bs{args.batch_size}_lr{args.base_lr}_seed{args.seed}'
    output_path = os.path.join(snapshot_path, f'{time.strftime("%Y-%m-%d_%H%M", time.gmtime())}_{args.network_name}')
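    # note: the final run outputs land in a timestamped subdirectory of snapshot_path,
    # e.g. 'outputs/exp0001_224_epo50_bs32_lr0.0001_s42/2024-01-01_1200_DenseNet'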
    # set seed value
    seeding(args.seed)
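    # seeding() presumably fixes the Python, NumPy and torch random seeds so that runs are reproducible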
    # load the data from the disk
    # ch1 -> class_labels = {'nevus': 0, 'others': 1}
    # ch2 -> class_labels = {'mel': 0, 'bcc': 1, 'scc': 2}
    if args.multi:
        logger.info("Loading data with 3 class labels...")
        _labels = {'mel': 0, 'bcc': 1, 'scc': 2}
    else:
        logger.info("Loading data with 2 class labels...")
        _labels = {'nevus': 0, 'others': 1}
    # dataset_df, images, masks, labels, n_classes
    _, train_images, train_masks, train_labels, n_classes = LoadData(
        dataset_path=args.train_path,
        masks_path=args.train_masks_path,
        class_labels=_labels)
    _, val_images, val_masks, val_labels, n_classes = LoadData(
        dataset_path=args.valid_path,
        masks_path=args.valid_masks_path,
        class_labels=_labels)
    if args.train_masks_path and args.valid_masks_path:
        logger.info("Using segmentation masks for training...")
        logger.info(f"train_images: {len(train_images)}, train_masks: {len(train_masks)}, train_labels: {len(train_labels)}")
        logger.info(f"val_images: {len(val_images)}, val_masks: {len(val_masks)}, val_labels: {len(val_labels)}")
        # asserting
        assert len(train_images) == len(train_masks), "Number of training images and masks should be the same."
        assert len(val_images) == len(val_masks), "Number of validation images and masks should be the same."
    # Use StratifiedKFold for cross-validation
    skf = StratifiedKFold(n_splits=args.num_folds, shuffle=True, random_state=args.seed)
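    # StratifiedKFold keeps the class proportions roughly constant across folds;
    # shuffling with a fixed random_state makes the splits reproducible between runs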
    # Concatenate the training and validation datasets
    all_images = train_images + val_images
    all_labels = train_labels + val_labels
    all_masks = train_masks + val_masks  # we need the val masks as we combine the train and val when we split!
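    # the original train/val split is discarded here: each fold draws its own validation subset from the pooled data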
    # Create a new instance of the experiment
    experiment = getExperiment(args.experiment_name)
    network = getNetwork(args.network_name)
    # args.normalize_attn is only possible when the network is VGG16_BN_Attention
    assert not (args.normalize_attn and network != VGG16_BN_Attention), "normalize_attn is expected to be used with args.network_name='VGG16_BN_Attention' only."
    # args.focal_loss is only possible when n_classes > 2
    assert not (args.focal_loss and n_classes == 2), "focal_loss is expected to be used with multi-class classification only."
    # Initialize lists to store the metrics from each fold
    all_fold_metrics = []
    all_val_targets = []
    # Iterate over folds
    for fold, (train_index, val_index) in enumerate(skf.split(all_images, all_labels)):
        # Create dataset and loaders for current fold
        fold_train_dataset = Dataset(
            images_path=[all_images[i] for i in train_index],
            labels=[all_labels[i] for i in train_index],
            masks_path=[all_masks[i] for i in train_index] if args.train_masks_path and args.valid_masks_path else None,
            transform=True,
            split="train",
            input_size=(args.img_size, args.img_size))
        fold_val_dataset = Dataset(
            images_path=[all_images[i] for i in val_index],
            labels=[all_labels[i] for i in val_index],
            transform=True,
            split="val",
            input_size=(args.img_size, args.img_size))
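        # note: only the training-fold Dataset receives masks here; the validation-fold Dataset is built without them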
        fold_train_loader = torch.utils.data.DataLoader(
            fold_train_dataset, batch_size=args.batch_size, shuffle=True)
        fold_val_loader = torch.utils.data.DataLoader(
            fold_val_dataset, batch_size=args.batch_size, shuffle=True)
        # Add fold number to the checkpoint file name
        fold_ckp_file = checkpoint_file + f'_fold{fold + 1}.pth'
        # Create a new directory for the current fold
        fold_path = os.path.join(output_path, f'fold_{fold + 1}')
        # get the class weights
        class_weights = fold_train_dataset.get_class_weight()
        logger.info(f"Class weights: {class_weights}")
        # Initialize the experiment
        exp = experiment(
            args,
            fold_train_loader,
            fold_val_loader,
            n_classes,
            checkpoint_file=fold_ckp_file,
            Network=network,
            fold_path=fold_path,
            fold_num=fold + 1,
            c_weights=class_weights)
        # run training and validation for current fold
        exp.run()
        # run the val/test report generation
        metrics_dict = exp.run_test(test_loader=fold_val_loader, save_path=fold_path)
        # append the metrics to the list
        all_fold_metrics.append(metrics_dict)
    # compute the average metrics ('accuracy', 'auc', 'kappa') across all folds
    avg_accuracy = np.mean([fmet['accuracy'] for fmet in all_fold_metrics])
    avg_auc = np.mean([fmet['auc'] for fmet in all_fold_metrics])
    avg_kappa = np.mean([fmet['kappa'] for fmet in all_fold_metrics])
    logger.info(f"Average accuracy: {avg_accuracy:.5f}. Average AUC: {avg_auc:.5f}. Average Kappa: {avg_kappa:.5f}")