Commit

Merge branch 'pr-curve-plots' of https://github.com/jurjen93/lofar_he…

robogast committed Sep 26, 2024
2 parents d015421 + bad816f commit 6e78896
Showing 3 changed files with 209 additions and 52 deletions.
158 changes: 111 additions & 47 deletions neural_networks/inference.py
@@ -1,4 +1,7 @@
 import argparse
+import os
+from functools import partial
+from typing import Callable, Any, Sequence
 
 import matplotlib.image
 import torch
@@ -9,77 +12,138 @@
 from pre_processing_for_ml import FitsDataset
 
 
 @torch.no_grad()
-def variational_dropout(model, dataloader, variational_iters=25):
+def variational_dropout(model, batch, variational_iters: int):
 
     model.feature_extractor.eval()
-    model.classifier.train()
+
+    if variational_iters == 0:
+        # Disable dropout
+        model.classifier.eval()
+        variational_iters = 1
+    else:
+        # Enable dropout
+        model.classifier.train()
 
     model.cuda()
 
-    for sample in tqdm(dataloader):
-        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-            batch, labels = sample[0].cuda(non_blocking=True), sample[1].cuda(non_blocking=True)
-            preds = torch.sigmoid(torch.concat([model(batch).clone() for _ in range(variational_iters)], dim=1))
+    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
+        sample = batch[0].cuda(non_blocking=True).expand(-1, 3, -1, -1)
 
-        means = preds.mean(dim=1)
-        stds = preds.std(dim=1)
+        means, stds = variational_dropout_step(model, sample, variational_iters)
 
-        yield (
-            means.cpu(),
-            stds.cpu()
-        )
+    return means, stds
+
+
+@torch.no_grad()
+def variational_dropout_step(model: torch.nn.Module, sample, variational_iters: int):
+    preds = torch.sigmoid(torch.concat([model(sample).clone() for _ in range(variational_iters)], dim=1))
+
+    means = preds.mean(dim=1)
+
+    if preds.shape[1] == 1:
+        stds = torch.zeros_like(means)
+    else:
+        stds = preds.std(dim=1)
+
+    return means, stds


-def save_images(dataset, out_paths, preds, stds, mode):
+def save_images(dataset, out_path, preds, stds):
+    os.makedirs(out_path, exist_ok=True)
+
     for elem, in_path, pred, std in tqdm(zip(dataset, dataset.data_paths, preds, stds), total=len(dataset)):
         batch, label = elem
         name = in_path.strip('.npz').split('/')[-1]
         matplotlib.image.imsave(
-            fname=f'{out_paths}/{std:.3f}_{pred:.3f}_{label}_{name}.png',
+            fname=f'{out_path}/{std:.3f}_{pred:.3f}_{label}_{name}.png',
             arr=batch.to(torch.float).movedim(0, -1).numpy()
         )

-def main(dataset_root, checkpoint_path):
-    torch.set_float32_matmul_precision('high')
-
-    ckpt_dict = load_checkpoint(checkpoint_path)
-    model = ckpt_dict['model']
-    breakpoint()
-
-    num_workers = min(18, len(os.sched_getaffinity(0)))
-    prefetch_factor, persistent_workers = (
-        (2, True) if num_workers > 0 else
-        (None, False)
-    )
-    batch_size = 64
-
-    def gen_and_save(mode):
-        dataset = FitsDataset(dataset_root, mode=mode)
-
-        dataloader = DataLoader(
-            dataset=dataset,
-            batch_size=batch_size,
-            num_workers=num_workers,
-            prefetch_factor=prefetch_factor,
-            persistent_workers=False,
-            pin_memory=True,
-            shuffle=False,
-            drop_last=False,
-        )
-
-        out_path = f'/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/preds_{mode}/'
-
-        preds, stds = map(torch.concat, zip(*[elem for elem in variational_dropout(model, dataloader)]))
-        save_images(dataset, out_path, preds, stds, mode=mode)
-
-    for mode in ('train', 'val'):
-        gen_and_save(mode)
+def gen_output_from_dataset(
+    dataset: torch.utils.data.Dataset,
+    inference_f: Any,
+    **dataloader_kwargs
+):
+    num_workers = dataloader_kwargs.pop('num_workers', min(18, len(os.sched_getaffinity(0))))
+    batch_size = dataloader_kwargs.pop('batch_size', 64)
+
+    dataloader = DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=True,
+        shuffle=False,
+        drop_last=False,
+        **dataloader_kwargs
+    )
+
+    outputs = [inference_f(batch=batch) for batch in dataloader]
+
+    return tuple(map(lambda x: torch.concat(x).to(torch.device('cpu'), dtype=torch.float).numpy(), zip(*outputs)))
+
+
+def dataset_inference_vi(
+    dataset: torch.utils.data.Dataset,
+    checkpoint_path: str,
+    variational_iters: int,
+):
+    torch.set_float32_matmul_precision('high')
+
+    model = load_checkpoint(checkpoint_path)['model']
+
+    inference_f = partial(variational_dropout, model=model, variational_iters=variational_iters)
+
+    preds, stds = gen_output_from_dataset(dataset=dataset, inference_f=inference_f)
+
+    return preds, stds
+
+
+def is_dir(path):
+    if not os.path.isdir(path):
+        raise argparse.ArgumentTypeError(f"'{path}' is not a valid directory.")
+    return path
+
+
+def is_file(path):
+    if not os.path.isfile(path):
+        raise argparse.ArgumentTypeError(f"'{path}' is not a valid file.")
+    return path
+
+
+def positive_int(value):
+    def err(value):
+        raise argparse.ArgumentTypeError(f"'{value}' is not a valid positive integer.")
+
+    try:
+        ivalue = int(value)
+    except ValueError:
+        err(value)
+
+    if ivalue < 1:
+        err(value)
+
+    return ivalue
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Argument parser for dataset and variational iterations")
+
+    parser.add_argument('--dataset_root', type=is_dir, required=True, help="Path to the dataset root directory")
+    parser.add_argument('--checkpoint_path', type=is_file, required=True, help="Path to the model checkpoint file")
+    parser.add_argument('--save_images_path', type=is_dir, default=None, help="Path to save images (optional)")
+    parser.add_argument('--variational_iters', type=positive_int, default=5, help="Number of variational iterations (must be >= 1)")
+
+    return parser.parse_args()
+
 
 if __name__ == '__main__':
-    # root = f'{os.environ["TMPDIR"]}/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
-    root = f'/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
-    checkpoint_path = './gridsearch_efficientnet/version_6665871_5__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25/ckpt_step=7999.pth'
-
-    main(root, checkpoint_path)
+    args = get_args()
+
+    for mode in ('train', 'val'):
+        dataset = FitsDataset(args.dataset_root, mode=mode)
+        preds, stds = dataset_inference_vi(dataset, args.checkpoint_path, args.variational_iters)
+
+        if args.save_images_path is not None:
+            out_path = args.save_images_path + f'/preds_{mode}'
+            save_images(dataset, out_path, preds, stds)
+
+    # root = f'/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
+    # checkpoint_path = './gridsearch_efficientnet/version_6665871_5__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25/ckpt_step=7999.pth'
+    # out_path = f'/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/preds_{mode}/'
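For reference, a minimal usage sketch of the inference API introduced above; the dataset root and checkpoint paths are placeholders, and mode='val' is only an example:

    from inference import dataset_inference_vi
    from pre_processing_for_ml import FitsDataset

    # Placeholder paths; point these at a real cnn_data root and checkpoint.
    dataset = FitsDataset('/path/to/cnn_data', mode='val')
    preds, stds = dataset_inference_vi(
        dataset,
        checkpoint_path='/path/to/ckpt.pth',
        variational_iters=16,
    )

    # preds and stds are per-sample numpy arrays: the MC-dropout mean and
    # standard deviation of the sigmoid outputs.
    print(preds.shape, stds.shape)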
89 changes: 89 additions & 0 deletions neural_networks/scripts/plot_pr_curve.py
@@ -0,0 +1,89 @@
import sys
import os

import numpy as np
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

sys.path.append('..')  # yes this is mega ugly, but otherwise I need to restructure the whole project...

from inference import dataset_inference_vi, is_file
from train_nn import ImagenetTransferLearning
from pre_processing_for_ml import FitsDataset

DATASET_ROOT = '/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
CHECKPOINT = '../grid_search/version_7944362_0__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25__use_compile_True/ckpt_step=1643.pth'
PICKLE_DICT_FNAME = 'precision_recall_curves.npy'


if os.path.isfile(PICKLE_DICT_FNAME):
    pickle_dict = np.load(PICKLE_DICT_FNAME, allow_pickle=True)[()]  # Don't ask me why this is the syntax
else:
    # Yes I'm aware of the existence of collections.defaultdict, but that doesn't work with np.save...
    pickle_dict = {}

variational_iter_vals = (0, 1, 2, 4, 16)
modes = ('val', 'train')

# variational_iter_vals = (0, 1,)
# modes = ('val',)

new_vals = False


for mode in modes:
    if mode not in pickle_dict:
        pickle_dict[mode] = {}

    for variational_iters in variational_iter_vals:
        variational_iters_str = f'variational_iters_{variational_iters}'

        if variational_iters_str in pickle_dict[mode]:
            print(f"{mode}/{variational_iters_str} already in saved pickle; skipping.")
        else:
            print(f"{mode}/{variational_iters_str} not found; calculating")
            new_vals = True

            pickle_dict[mode][variational_iters_str] = {}

            dataset = FitsDataset(DATASET_ROOT, mode=mode)
            preds, stds = dataset_inference_vi(dataset, CHECKPOINT, variational_iters=variational_iters)

            precision, recall, thresholds = precision_recall_curve(dataset.labels, preds)

            for name, value in (
                ('precision', precision),
                ('recall', recall),
                ('thresholds', thresholds),
                ('preds', preds),
                ('stds', stds),
                ('labels', dataset.labels),
                ('sources', dataset.data_paths)
            ):
                pickle_dict[mode][variational_iters_str][name] = value

        plt.plot(
            pickle_dict[mode][variational_iters_str]['recall'],
            pickle_dict[mode][variational_iters_str]['precision'],
            label=f'VI Iters: {variational_iters}'
        )

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('PR Curves for varying Variational Inference iteration counts')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig(f'precision_recall_curves_{mode}.png', dpi=300)
    plt.clf()

if new_vals:
    print("saving new/updated pickle_dict")

    # pylance or w/e can complain, but this is valid syntax.
    np.save(PICKLE_DICT_FNAME, pickle_dict)
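A minimal sketch of reloading the cached curves this script writes, assuming the precision_recall_curves.npy produced above:

    import numpy as np

    # np.save stores the dict as a 0-d object array; indexing with [()] unwraps it.
    pickle_dict = np.load('precision_recall_curves.npy', allow_pickle=True)[()]

    for mode, runs in pickle_dict.items():
        for iters_key, curves in runs.items():
            print(mode, iters_key, curves['precision'].shape, curves['recall'].shape)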
14 changes: 9 additions & 5 deletions neural_networks/train_nn.py
@@ -259,7 +259,7 @@ def train(self):

 def get_dataloaders(dataset_root, batch_size):
     num_workers = min(12, len(os.sched_getaffinity(0)))
-    # num_workers = 0
+
     prefetch_factor, persistent_workers = (
         (2, True) if num_workers > 0 else (None, False)
     )
@@ -692,10 +692,14 @@ def load_checkpoint(ckpt_path):
     model = ckpt_dict["model"](model_name=model_name, dropout_p=dropout_p)
     model.load_state_dict(ckpt_dict["model_state_dict"])
 
-    # FIXME: add optim class and args to state dict
-    optim = ckpt_dict.get("optimizer", torch.optim.AdamW)(
-        lr=lr, params=model.classifier.parameters()
-    ).load_state_dict(ckpt_dict["optimizer_state_dict"])
+    try:
+        # FIXME: add optim class and args to state dict
+        optim = ckpt_dict.get("optimizer", torch.optim.AdamW)(
+            lr=lr, params=model.classifier.parameters()
+        )
+        optim.load_state_dict(ckpt_dict["optimizer_state_dict"])
+    except Exception as e:
+        print(f"Could not load optim due to {e}; skipping.")
+        optim = None
 
     return {"model": model, "optim": optim, "normalize": normalize}
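For context, a hedged sketch of the checkpoint layout this load_checkpoint hunk appears to expect. The keys are inferred from the diff alone; the nn.Linear stand-in is not the project's model, and real checkpoints also carry entries (model class, hyperparameters, normalize) not shown in this hunk:

    import torch
    import torch.nn as nn

    # Stand-in module purely for illustration.
    model = nn.Linear(4, 1)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

    ckpt_dict = {
        "model_state_dict": model.state_dict(),
        "optimizer": torch.optim.AdamW,  # read back via ckpt_dict.get("optimizer", torch.optim.AdamW)
        "optimizer_state_dict": optim.state_dict(),
    }
    torch.save(ckpt_dict, "ckpt_demo.pth")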

