From bad816ff63d4fb37d9e67e2a3a2a7bc59044d567 Mon Sep 17 00:00:00 2001 From: robogast Date: Thu, 26 Sep 2024 15:54:13 +0200 Subject: [PATCH] added pr curve plot script --- neural_networks/inference.py | 158 ++++++++++++++++------- neural_networks/scripts/plot_pr_curve.py | 89 +++++++++++++ neural_networks/train_nn.py | 16 +-- 3 files changed, 208 insertions(+), 55 deletions(-) create mode 100644 neural_networks/scripts/plot_pr_curve.py diff --git a/neural_networks/inference.py b/neural_networks/inference.py index f663b247..5c5a566d 100644 --- a/neural_networks/inference.py +++ b/neural_networks/inference.py @@ -1,4 +1,7 @@ +import argparse import os +from functools import partial +from typing import Callable, Any, Sequence import matplotlib.image import torch @@ -9,77 +12,138 @@ from pre_processing_for_ml import FitsDataset -@torch.no_grad() -def variational_dropout(model, dataloader, variational_iters=25): +def variational_dropout(model, batch, variational_iters: int): + model.feature_extractor.eval() - model.classifier.train() + + if variational_iters == 0: + # Disable dropout + model.classifier.eval() + variational_iters = 1 + else: + # Enable dropout + model.classifier.train() model.cuda() - for sample in tqdm(dataloader): - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - batch, labels = sample[0].cuda(non_blocking=True), sample[1].cuda(non_blocking=True) - preds = torch.sigmoid(torch.concat([model(batch).clone() for _ in range(variational_iters)], dim=1)) + with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16): + sample = batch[0].cuda(non_blocking=True).expand(-1, 3, -1, -1) - means = preds.mean(dim=1) - stds = preds.std(dim=1) + means, stds = variational_dropout_step(model, sample, variational_iters) - yield ( - means.cpu(), - stds.cpu() - ) + return means, stds + + +@torch.no_grad() +def variational_dropout_step(model: torch.nn.Module, sample, variational_iters: int): + preds = torch.sigmoid(torch.concat([model(sample).clone() for _ in range(variational_iters)], dim=1)) + + means = preds.mean(dim=1) + + if preds.shape[1] == 1: + stds = torch.zeros_like(means) + else: + stds = preds.std(dim=1) + + return means, stds + + +def save_images(dataset, out_path, preds, stds): + os.makedirs(out_path, exist_ok=True) -def save_images(dataset, out_paths, preds, stds, mode): for elem, in_path, pred, std in tqdm(zip(dataset, dataset.data_paths, preds, stds), total=len(dataset)): batch, label = elem name = in_path.strip('.npz').split('/')[-1] matplotlib.image.imsave( - fname=f'{out_paths}/{std:.3f}_{pred:.3f}_{label}_{name}.png', + fname=f'{out_path}/{std:.3f}_{pred:.3f}_{label}_{name}.png', arr=batch.to(torch.float).movedim(0, -1).numpy() ) -def main(dataset_root, checkpoint_path): +def gen_output_from_dataset( + dataset: torch.utils.data.Dataset, + inference_f: Any, + **dataloader_kwargs + ): + + num_workers = dataloader_kwargs.pop('num_workers', min(18, len(os.sched_getaffinity(0)))) + batch_size = dataloader_kwargs.pop('batch_size', 64) + + dataloader = DataLoader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + shuffle=False, + drop_last=False, + **dataloader_kwargs + ) + + outputs = [inference_f(batch=batch) for batch in dataloader] + + return tuple(map(lambda x: torch.concat(x).to(torch.device('cpu'), dtype=torch.float).numpy(), zip(*outputs))) + +def dataset_inference_vi( + dataset: torch.utils.data.Dataset, + checkpoint_path: str, + variational_iters: int, + ): torch.set_float32_matmul_precision('high') - ckpt_dict = load_checkpoint(checkpoint_path) - model = ckpt_dict['model'] - breakpoint() + model = load_checkpoint(checkpoint_path)['model'] - num_workers = min(18, len(os.sched_getaffinity(0))) - prefetch_factor, persistent_workers = ( - (2, True) if num_workers > 0 else - (None, False) - ) - batch_size = 64 - - def gen_and_save(mode): - dataset = FitsDataset(dataset_root, mode=mode) - - dataloader = DataLoader( - dataset=dataset, - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - persistent_workers=False, - pin_memory=True, - shuffle=False, - drop_last=False, - ) + inference_f = partial(variational_dropout, model=model, variational_iters=variational_iters) + + preds, stds = gen_output_from_dataset(dataset=dataset, inference_f=inference_f) + + return preds, stds - out_path = f'/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/preds_{mode}/' - preds, stds = map(torch.concat, zip(*[elem for elem in variational_dropout(model, dataloader)])) - save_images(dataset, out_path, preds, stds, mode=mode) +def is_dir(path): + if not os.path.isdir(path): + raise argparse.ArgumentTypeError(f"'{path}' is not a valid directory.") + return path - for mode in ('train', 'val'): - gen_and_save(mode) +def is_file(path): + if not os.path.isfile(path): + raise argparse.ArgumentTypeError(f"'{path}' is not a valid file.") + return path +def positive_int(value): + def err(value): + raise argparse.ArgumentTypeError(f"'{value}' is not a valid positive integer.") + try: + ivalue = int(value) + except: + err(value) + if ivalue < 1: + err(value) + + return ivalue + +def get_args(): + parser = argparse.ArgumentParser(description="Argument parser for dataset and variational iterations") + + parser.add_argument('--dataset_root', type=is_dir, required=True, help="Path to the dataset root directory") + parser.add_argument('--checkpoint_path', type=is_file, required=True, help="Path to the checkpoint root directory") + parser.add_argument('--save_images_path', type=is_dir, default=None, help="Path to save images (optional)") + parser.add_argument('--variational_iters', type=positive_int, default=5, help="Number of variational iterations (must be >= 1)") + + return parser.parse_args() if __name__ == '__main__': + args = get_args() + + dataset = FitsDataset(dataset_root, mode=mode) + dataset_inference(dataset, args.variational_iters, args.save_images_path) + + if save_images_path is not None: + out_path = save_images_path + f'/preds_{mode}' + save_images(dataset, out_path, preds, stds) + # root = f'{os.environ["TMPDIR"]}/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/' - root = f'/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/' - checkpoint_path = './gridsearch_efficientnet/version_6665871_5__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25/ckpt_step=7999.pth' + # root = f'/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/' + # checkpoint_path = './gridsearch_efficientnet/version_6665871_5__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25/ckpt_step=7999.pth' + # out_path = f'/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/preds_{mode}/' - main(root, checkpoint_path) diff --git a/neural_networks/scripts/plot_pr_curve.py b/neural_networks/scripts/plot_pr_curve.py new file mode 100644 index 00000000..3a3b85fb --- /dev/null +++ b/neural_networks/scripts/plot_pr_curve.py @@ -0,0 +1,89 @@ +from collections import defaultdict +from pickle import Pickler +import sys +import os + +import tensorboard +import numpy as np +from sklearn.metrics import precision_recall_curve +import matplotlib.pyplot as plt + +sys.path.append('..') # yes this is mega ugly, but otherwise I need to restructure the whole project... + +from inference import dataset_inference_vi, is_file +from train_nn import ImagenetTransferLearning +from pre_processing_for_ml import FitsDataset + +DATASET_ROOT = '/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/' +CHECKPOINT = '../grid_search/version_7944362_0__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25__use_compile_True/ckpt_step=1643.pth' +PICKLE_DICT_FNAME = 'precision_recall_curves.npy' + + +if os.path.isfile(PICKLE_DICT_FNAME): + pickle_dict = np.load(PICKLE_DICT_FNAME, allow_pickle=True)[()] # Don't ask me why this is the syntax +else: + # Yes I'm aware of the existence of collections.defaultdict, but that doesn't work with np.save... + pickle_dict = {} + +variational_iter_vals = (0, 1, 2, 4, 16) +modes = ('val', 'train') + +# variational_iter_vals = (0, 1,) +# modes = ('val',) + +new_vals = False + + +for mode in modes: + if mode not in pickle_dict: + pickle_dict[mode] = {} + + for variational_iters in variational_iter_vals: + variational_iters_str = f'variational_iters_{variational_iters}' + + if variational_iters_str in pickle_dict[mode]: + print(f"{mode}/{variational_iters_str} already in saved pickle; skipping.") + + else: + print(f"{mode}/{variational_iters_str} not found; calculating") + new_vals = True + + pickle_dict[mode][variational_iters_str] = {} + + dataset = FitsDataset(DATASET_ROOT, mode=mode) + preds, stds = dataset_inference_vi(dataset, CHECKPOINT, variational_iters=variational_iters) + + precision, recall, thresholds = precision_recall_curve(dataset.labels, preds) + + for name, value in ( + ('precision', precision), + ('recall', recall), + ('thresholds', thresholds), + ('preds', preds), + ('stds', stds), + ('labels', dataset.labels), + ('sources', dataset.data_paths) + ): + pickle_dict[mode][variational_iters_str][name] = value + + plt.plot( + pickle_dict[mode][variational_iters_str]['recall'], + pickle_dict[mode][variational_iters_str]['precision'], + label=f'VI Iters: {variational_iters}' + ) + + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.title('PR Curves for varying Variational Inference iteration counts') + plt.legend() + plt.grid(True, linestyle='--', alpha=0.7) + + plt.tight_layout() + plt.savefig(f'precision_recall_curves_{mode}.png', dpi=300) + plt.clf() + +if new_vals: + print("saving new/updated pickle_dict") + + # pylance or w/e can complain, but this is valid syntax. + np.save(PICKLE_DICT_FNAME, pickle_dict) \ No newline at end of file diff --git a/neural_networks/train_nn.py b/neural_networks/train_nn.py index 28af900b..f7d2ce1a 100644 --- a/neural_networks/train_nn.py +++ b/neural_networks/train_nn.py @@ -173,8 +173,8 @@ def train(self): self.classifier.train() def get_dataloaders(dataset_root, batch_size, normalize): - # num_workers = min(12, len(os.sched_getaffinity(0))) - num_workers = 0 + num_workers = min(12, len(os.sched_getaffinity(0))) + # num_workers = 0 prefetch_factor, persistent_workers = ( (2, True) if num_workers > 0 else (None, False) @@ -527,13 +527,13 @@ def load_checkpoint(ckpt_path): model = ckpt_dict['model'](model_name=model_name, dropout_p=dropout_p) model.load_state_dict(ckpt_dict['model_state_dict']) - # FIXME: add optim class and args to state dict - optim = ckpt_dict.get('optimizer', torch.optim.AdamW)( - lr=lr, - params=model.classifier.parameters() - ).load_state_dict(ckpt_dict['optimizer_state_dict']) + # # FIXME: add optim class and args to state dict + # optim = ckpt_dict.get('optimizer', torch.optim.AdamW)( + # lr=lr, + # params=model.classifier.parameters() + # ).load_state_dict(ckpt_dict['optimizer_state_dict']) - return {'model': model, 'optim': optim, 'normalize': normalize} + return {'model': model, 'normalize': normalize} def get_argparser(): """