From b03ed3c1a23ec8b7886455c391559492c37ce718 Mon Sep 17 00:00:00 2001 From: LVeefkind Date: Tue, 29 Oct 2024 13:22:53 +0100 Subject: [PATCH] Added (very hacky) script to generate confusion matrices --- neural_networks/plots/confusion_matrix.py | 202 +++++++++++++++++----- 1 file changed, 155 insertions(+), 47 deletions(-) diff --git a/neural_networks/plots/confusion_matrix.py b/neural_networks/plots/confusion_matrix.py index 235c61c..ac5a9ea 100644 --- a/neural_networks/plots/confusion_matrix.py +++ b/neural_networks/plots/confusion_matrix.py @@ -5,56 +5,116 @@ SCRIPT_DIR = Path(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(SCRIPT_DIR)) -from train_nn import MultiEpochsDataLoader -from pre_processing_for_ml import FitsDataset +from pre_processing_for_ml import normalize_fits import matplotlib.pyplot as plt import numpy as np import torch +from functools import lru_cache +from torch.utils.data import Dataset, DataLoader from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay +from astropy.io import fits + + +class RawFitsDataset(Dataset): + def __init__(self, root_dir, mode="train"): + """ + Args: + root_dir (string): Directory with good/bad folders in it. + """ + + modes = ("train", "val") + assert mode in modes + + classes = {"stop": 0, "continue": 1} + + root_dir = Path(root_dir) + assert root_dir.exists(), f"'{root_dir}' doesn't exist!" + + ext = ".fits" + glob_ext = "*" + ext + + self.root_dir = root_dir + + for folder in ( + root_dir / (cls + ("" if mode == "train" else "_val")) for cls in classes + ): + assert ( + folder.exists() + ), f"root folder doesn't exist, got: '{str(folder.resolve())}'" + assert ( + len(list(folder.glob(glob_ext))) > 0 + ), f"no '{ext}' files were found in '{str(folder.resolve())}'" + + # Yes this code is way overengineered. Yes I also derive pleasure from writing it :) - RJS + # + # Actual documentation: + # You want all 'self.x' variables to be non-python objects such as numpy arrays, + # otherwise you get memory leaks in the PyTorch dataloader + self.data_paths, self.labels = map( + np.asarray, + list( + zip( + *( + (str(file), val) + for cls, val in classes.items() + for file in ( + root_dir / (cls + ("" if mode == "train" else "_val")) + ).glob(glob_ext) + ) + ) + ), + ) + assert len(self.data_paths) > 0 + self.sources = ", ".join( + sorted([str(elem).split("/")[-1].strip(ext) for elem in self.data_paths]) + ) + self.mode = mode + _, counts = np.unique(self.labels, return_counts=True) + self.label_ratio = counts[0] / counts[1] + # print(f'{mode}: using the following sources: {sources}') -def load_model(architecture_name, model_name): - StopPredictor: type(Architecture) = get_architecture(architecture_name) - predictor = StopPredictor(device="cuda", model_name=model_name) - return predictor + @staticmethod + def transform_data(image_data): + """ + Transform data for preprocessing + """ + # FIXME: this should really be a parameter + image_data = torch.from_numpy(image_data).to(torch.bfloat16) + image_data = torch.movedim(image_data, -1, 0) -def get_dataloader(data_root, mode, batch_size): - num_workers = min(12, len(os.sched_getaffinity(0))) + return image_data - prefetch_factor, persistent_workers = ( - (2, True) if num_workers > 0 else (None, False) - ) + @lru_cache(maxsize=1) + def __len__(self): + return len(self.data_paths) - return MultiEpochsDataLoader( - dataset=FitsDataset( - data_root, - mode=mode, - ), - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers, - pin_memory=True, - shuffle=False, - drop_last=False, - ) + def __getitem__(self, idx): + fits_path = self.data_paths[idx] + label = self.labels[idx] -def get_statistics(data_root, mode): - return FitsDataset( - data_root, - mode=mode, - ).compute_statistics(1) + image_data = process_fits(fits_path) + # there is always only one array + + # Pre-processing + image_data = self.transform_data(image_data) + + return image_data, label + + +def load_model(architecture_name, model_name, device="cpu"): + StopPredictor: type(Architecture) = get_architecture(architecture_name) + predictor = StopPredictor(device=device, model_name=model_name) + return predictor @torch.no_grad() -def get_confusion_matrix( - model_name, predictor, dataloader, mean, std, thresholds=[0.2, 0.3, 0.4, 0.5] -): +def get_confusion_matrix(predictor, dataloader, mean, std, thresholds): confusion_matrices = np.zeros((len(thresholds), 2, 2)) thresholds = torch.tensor(thresholds) - for img, label in dataloader: + for i, (img, label) in enumerate(dataloader): data = predictor.prepare_batch(img, mean=mean, std=std) pred = torch.sigmoid(predictor.model(data)).to("cpu") preds_thres = pred >= thresholds @@ -62,6 +122,11 @@ def get_confusion_matrix( confusion_matrices[i] += confusion_matrix( label, preds_thres[:, i], labels=[0, 1] ) + + return confusion_matrices + + +def plot_conf_matrices(savedir, confusion_matrices, thresholds): savedir = model_name.split("/")[-1] os.makedirs(savedir, exist_ok=True) for i, conf_matrix in enumerate(confusion_matrices): @@ -69,29 +134,72 @@ def get_confusion_matrix( disp = ConfusionMatrixDisplay( # Normalization conf_matrix / np.sum(conf_matrix, axis=1, keepdims=True), - display_labels=["continue", "stop"], + display_labels=["stop", "continue"], ) - print(conf_matrix) + # print(conf_matrix) disp.plot() plt.savefig(f"{savedir}/confusion_thres_{thresholds[i]:.3f}.png") +def process_fits(fits_path): + with fits.open(fits_path) as hdul: + image_data = hdul[0].data + + return normalize_fits(image_data) + + +def get_dataloader(data_root, mode="val", batch_size=32): + dataset = RawFitsDataset(data_root, mode="val") + num_workers = min(12, len(os.sched_getaffinity(0))) + + prefetch_factor, persistent_workers = ( + (2, True) if num_workers > 0 else (None, False) + ) + dataloader = DataLoader( + dataset, + batch_size=32, + shuffle=True, + num_workers=num_workers, + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + drop_last=False, + ) + + return dataloader + + if __name__ == "__main__": + # Latest model model_name = "surf/dinov2_09739_rotations" + TESTING = True architecture_name = "surf/TransferLearning" - predictor = load_model(architecture_name, model_name) - if hasattr(predictor, "args") and "dataset_mean" in predictor.args: - mean, std = predictor.args["dataset_mean"], predictor.args["dataset_std"] + # Set Device here + DEVICE = "cuda" + thresholds = [0.2, 0.3, 0.4, 0.5] + data_root = "/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data" + # Uses cached confusion matrix for testing the plotting functionalities + if model_name == "surf/dinov2_09739_rotations" and TESTING: + confusion_matrices = np.asarray( + [ + [[149, 56], [2, 116]], + [[178, 27], [4, 114]], + [[190, 15], [6, 112]], + [[191, 14], [7, 111]], + ] + ) else: - mean, std = get_statistics( - "/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/", - mode="train", + + dataloader = get_dataloader(data_root, mode="val") + + predictor = load_model(architecture_name, model_name, device=DEVICE) + + mean, std = predictor.args["dataset_mean"], predictor.args["dataset_std"] + + confusion_matrices = get_confusion_matrix( + predictor, dataloader, mean, std, thresholds ) - dataloader = get_dataloader( - "/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/", - mode="val", - batch_size=32, - ) - get_confusion_matrix(model_name, predictor, dataloader, mean, std) + print(confusion_matrices) + + plot_conf_matrices(model_name.split("/")[-1], confusion_matrices, thresholds)