Skip to content

Commit

Permalink
Added (very hacky) script to generate confusion matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
LVeefkind committed Oct 29, 2024
1 parent 7298817 commit b03ed3c
Showing 1 changed file with 155 additions and 47 deletions.
202 changes: 155 additions & 47 deletions neural_networks/plots/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,93 +5,201 @@

SCRIPT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(SCRIPT_DIR))
from train_nn import MultiEpochsDataLoader
from pre_processing_for_ml import FitsDataset
from pre_processing_for_ml import normalize_fits
import matplotlib.pyplot as plt
import numpy as np
import torch
from functools import lru_cache
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from astropy.io import fits


class RawFitsDataset(Dataset):
def __init__(self, root_dir, mode="train"):
"""
Args:
root_dir (string): Directory with good/bad folders in it.
"""

modes = ("train", "val")
assert mode in modes

classes = {"stop": 0, "continue": 1}

root_dir = Path(root_dir)
assert root_dir.exists(), f"'{root_dir}' doesn't exist!"

ext = ".fits"
glob_ext = "*" + ext

self.root_dir = root_dir

for folder in (
root_dir / (cls + ("" if mode == "train" else "_val")) for cls in classes
):
assert (
folder.exists()
), f"root folder doesn't exist, got: '{str(folder.resolve())}'"
assert (
len(list(folder.glob(glob_ext))) > 0
), f"no '{ext}' files were found in '{str(folder.resolve())}'"

# Yes this code is way overengineered. Yes I also derive pleasure from writing it :) - RJS
#
# Actual documentation:
# You want all 'self.x' variables to be non-python objects such as numpy arrays,
# otherwise you get memory leaks in the PyTorch dataloader
self.data_paths, self.labels = map(
np.asarray,
list(
zip(
*(
(str(file), val)
for cls, val in classes.items()
for file in (
root_dir / (cls + ("" if mode == "train" else "_val"))
).glob(glob_ext)
)
)
),
)

assert len(self.data_paths) > 0
self.sources = ", ".join(
sorted([str(elem).split("/")[-1].strip(ext) for elem in self.data_paths])
)
self.mode = mode
_, counts = np.unique(self.labels, return_counts=True)
self.label_ratio = counts[0] / counts[1]
# print(f'{mode}: using the following sources: {sources}')

def load_model(architecture_name, model_name):
StopPredictor: type(Architecture) = get_architecture(architecture_name)
predictor = StopPredictor(device="cuda", model_name=model_name)
return predictor
@staticmethod
def transform_data(image_data):
"""
Transform data for preprocessing
"""

# FIXME: this should really be a parameter
image_data = torch.from_numpy(image_data).to(torch.bfloat16)
image_data = torch.movedim(image_data, -1, 0)

def get_dataloader(data_root, mode, batch_size):
num_workers = min(12, len(os.sched_getaffinity(0)))
return image_data

prefetch_factor, persistent_workers = (
(2, True) if num_workers > 0 else (None, False)
)
@lru_cache(maxsize=1)
def __len__(self):
return len(self.data_paths)

return MultiEpochsDataLoader(
dataset=FitsDataset(
data_root,
mode=mode,
),
batch_size=batch_size,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=True,
shuffle=False,
drop_last=False,
)
def __getitem__(self, idx):

fits_path = self.data_paths[idx]
label = self.labels[idx]

def get_statistics(data_root, mode):
return FitsDataset(
data_root,
mode=mode,
).compute_statistics(1)
image_data = process_fits(fits_path)
# there is always only one array

# Pre-processing
image_data = self.transform_data(image_data)

return image_data, label


def load_model(architecture_name, model_name, device="cpu"):
StopPredictor: type(Architecture) = get_architecture(architecture_name)
predictor = StopPredictor(device=device, model_name=model_name)
return predictor


@torch.no_grad()
def get_confusion_matrix(
model_name, predictor, dataloader, mean, std, thresholds=[0.2, 0.3, 0.4, 0.5]
):
def get_confusion_matrix(predictor, dataloader, mean, std, thresholds):
confusion_matrices = np.zeros((len(thresholds), 2, 2))
thresholds = torch.tensor(thresholds)
for img, label in dataloader:
for i, (img, label) in enumerate(dataloader):
data = predictor.prepare_batch(img, mean=mean, std=std)
pred = torch.sigmoid(predictor.model(data)).to("cpu")
preds_thres = pred >= thresholds
for i, _ in enumerate(thresholds):
confusion_matrices[i] += confusion_matrix(
label, preds_thres[:, i], labels=[0, 1]
)

return confusion_matrices


def plot_conf_matrices(savedir, confusion_matrices, thresholds):
savedir = model_name.split("/")[-1]
os.makedirs(savedir, exist_ok=True)
for i, conf_matrix in enumerate(confusion_matrices):

disp = ConfusionMatrixDisplay(
# Normalization
conf_matrix / np.sum(conf_matrix, axis=1, keepdims=True),
display_labels=["continue", "stop"],
display_labels=["stop", "continue"],
)
print(conf_matrix)
# print(conf_matrix)
disp.plot()

plt.savefig(f"{savedir}/confusion_thres_{thresholds[i]:.3f}.png")


def process_fits(fits_path):
with fits.open(fits_path) as hdul:
image_data = hdul[0].data

return normalize_fits(image_data)


def get_dataloader(data_root, mode="val", batch_size=32):
dataset = RawFitsDataset(data_root, mode="val")
num_workers = min(12, len(os.sched_getaffinity(0)))

prefetch_factor, persistent_workers = (
(2, True) if num_workers > 0 else (None, False)
)
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=num_workers,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
drop_last=False,
)

return dataloader


if __name__ == "__main__":
# Latest model
model_name = "surf/dinov2_09739_rotations"
TESTING = True
architecture_name = "surf/TransferLearning"
predictor = load_model(architecture_name, model_name)
if hasattr(predictor, "args") and "dataset_mean" in predictor.args:
mean, std = predictor.args["dataset_mean"], predictor.args["dataset_std"]
# Set Device here
DEVICE = "cuda"
thresholds = [0.2, 0.3, 0.4, 0.5]
data_root = "/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data"
# Uses cached confusion matrix for testing the plotting functionalities
if model_name == "surf/dinov2_09739_rotations" and TESTING:
confusion_matrices = np.asarray(
[
[[149, 56], [2, 116]],
[[178, 27], [4, 114]],
[[190, 15], [6, 112]],
[[191, 14], [7, 111]],
]
)
else:
mean, std = get_statistics(
"/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/",
mode="train",

dataloader = get_dataloader(data_root, mode="val")

predictor = load_model(architecture_name, model_name, device=DEVICE)

mean, std = predictor.args["dataset_mean"], predictor.args["dataset_std"]

confusion_matrices = get_confusion_matrix(
predictor, dataloader, mean, std, thresholds
)

dataloader = get_dataloader(
"/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/",
mode="val",
batch_size=32,
)
get_confusion_matrix(model_name, predictor, dataloader, mean, std)
print(confusion_matrices)

plot_conf_matrices(model_name.split("/")[-1], confusion_matrices, thresholds)

0 comments on commit b03ed3c

Please sign in to comment.