diff --git a/neural_networks/__init__.py b/neural_networks/__init__.py
index 9704ca81..42f110b9 100644
--- a/neural_networks/__init__.py
+++ b/neural_networks/__init__.py
@@ -82,6 +82,8 @@ def predict(self, data: torch.Tensor):
         with torch.autocast(dtype=self.dtype, device_type=self.device):
             if self.variational_dropout > 0:
                 self.model.train()
+            else:
+                self.model.eval()
 
             predictions = torch.concat(
                 [
diff --git a/neural_networks/inference.py b/neural_networks/inference.py
index 5c5a566d..ff209964 100644
--- a/neural_networks/inference.py
+++ b/neural_networks/inference.py
@@ -26,18 +26,19 @@ def variational_dropout(model, batch, variational_iters: int):
     model.cuda()
 
-    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
+    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
         sample = batch[0].cuda(non_blocking=True).expand(-1, 3, -1, -1)
 
         means, stds = variational_dropout_step(model, sample, variational_iters)
-
     return means, stds
 
 
 @torch.no_grad()
 def variational_dropout_step(model: torch.nn.Module, sample, variational_iters: int):
-    preds = torch.sigmoid(torch.concat([model(sample).clone() for _ in range(variational_iters)], dim=1))
+    preds = torch.sigmoid(
+        torch.concat([model(sample).clone() for _ in range(variational_iters)], dim=1)
+    )
 
     means = preds.mean(dim=1)
 
@@ -52,22 +53,25 @@ def variational_dropout_step(model: torch.nn.Module, sample, variational_iters:
 
 def save_images(dataset, out_path, preds, stds):
     os.makedirs(out_path, exist_ok=True)
-    for elem, in_path, pred, std in tqdm(zip(dataset, dataset.data_paths, preds, stds), total=len(dataset)):
+    for elem, in_path, pred, std in tqdm(
+        zip(dataset, dataset.data_paths, preds, stds), total=len(dataset)
+    ):
         batch, label = elem
 
-        name = in_path.strip('.npz').split('/')[-1]
+        # removesuffix() drops only the extension; strip() would also eat any
+        # leading/trailing '.', 'n', 'p', 'z' characters from the name.
+        name = in_path.removesuffix(".npz").split("/")[-1]
 
         matplotlib.image.imsave(
-            fname=f'{out_path}/{std:.3f}_{pred:.3f}_{label}_{name}.png',
-            arr=batch.to(torch.float).movedim(0, -1).numpy()
+            fname=f"{out_path}/{std:.3f}_{pred:.3f}_{label}_{name}.png",
+            arr=batch.to(torch.float).movedim(0, -1).numpy(),
         )
 
+
 def gen_output_from_dataset(
-    dataset: torch.utils.data.Dataset,
-    inference_f: Any,
-    **dataloader_kwargs
-    ):
+    dataset: torch.utils.data.Dataset, inference_f: Any, **dataloader_kwargs
+):
 
-    num_workers = dataloader_kwargs.pop('num_workers', min(18, len(os.sched_getaffinity(0))))
-    batch_size = dataloader_kwargs.pop('batch_size', 64)
+    num_workers = dataloader_kwargs.pop(
+        "num_workers", min(18, len(os.sched_getaffinity(0)))
+    )
+    batch_size = dataloader_kwargs.pop("batch_size", 64)
 
     dataloader = DataLoader(
         dataset=dataset,
@@ -76,23 +80,33 @@ def gen_output_from_dataset(
         pin_memory=True,
         shuffle=False,
         drop_last=False,
-        **dataloader_kwargs
+        **dataloader_kwargs,
     )
 
     outputs = [inference_f(batch=batch) for batch in dataloader]
 
-    return tuple(map(lambda x: torch.concat(x).to(torch.device('cpu'), dtype=torch.float).numpy(), zip(*outputs)))
+    return tuple(
+        map(
+            lambda x: torch.concat(x)
+            .to(torch.device("cpu"), dtype=torch.float)
+            .numpy(),
+            zip(*outputs),
+        )
+    )
+
 
 def dataset_inference_vi(
-    dataset: torch.utils.data.Dataset,
-    checkpoint_path: str,
-    variational_iters: int,
-    ):
-    torch.set_float32_matmul_precision('high')
+    dataset: torch.utils.data.Dataset,
+    checkpoint_path: str,
+    variational_iters: int,
+):
+    torch.set_float32_matmul_precision("high")
 
-    model = load_checkpoint(checkpoint_path)['model']
+    model = load_checkpoint(checkpoint_path)["model"]
 
-    inference_f = partial(variational_dropout, model=model, variational_iters=variational_iters)
+    inference_f = partial(
+        variational_dropout, model=model, variational_iters=variational_iters
+    )
 
     preds, stds = gen_output_from_dataset(dataset=dataset, inference_f=inference_f)
 
@@ -104,14 +118,17 @@ def is_dir(path):
         raise argparse.ArgumentTypeError(f"'{path}' is not a valid directory.")
     return path
 
+
 def is_file(path):
     if not os.path.isfile(path):
         raise argparse.ArgumentTypeError(f"'{path}' is not a valid file.")
     return path
 
+
 def positive_int(value):
     def err(value):
         raise argparse.ArgumentTypeError(f"'{value}' is not a valid positive integer.")
+
     try:
         ivalue = int(value)
     except:
@@ -122,28 +139,51 @@ def err(value):
     return ivalue
 
+
 def get_args():
-    parser = argparse.ArgumentParser(description="Argument parser for dataset and variational iterations")
+    parser = argparse.ArgumentParser(
+        description="Argument parser for dataset and variational iterations"
+    )
 
-    parser.add_argument('--dataset_root', type=is_dir, required=True, help="Path to the dataset root directory")
-    parser.add_argument('--checkpoint_path', type=is_file, required=True, help="Path to the checkpoint root directory")
-    parser.add_argument('--save_images_path', type=is_dir, default=None, help="Path to save images (optional)")
-    parser.add_argument('--variational_iters', type=positive_int, default=5, help="Number of variational iterations (must be >= 1)")
+    parser.add_argument(
+        "--dataset_root",
+        type=is_dir,
+        required=True,
+        help="Path to the dataset root directory",
+    )
+    parser.add_argument(
+        "--checkpoint_path",
+        type=is_file,
+        required=True,
+        help="Path to the checkpoint file",
+    )
+    parser.add_argument(
+        "--save_images_path",
+        type=is_dir,
+        default=None,
+        help="Path to save images (optional)",
+    )
+    parser.add_argument(
+        "--variational_iters",
+        type=positive_int,
+        default=5,
+        help="Number of variational iterations (must be >= 1)",
+    )
 
     return parser.parse_args()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     args = get_args()
 
-    dataset = FitsDataset(dataset_root, mode=mode)
-    dataset_inference(dataset, args.variational_iters, args.save_images_path)
-
-    if save_images_path is not None:
-        out_path = save_images_path + f'/preds_{mode}'
+    mode = "val"  # `mode` was previously an undefined name; "val" is an assumed default
+    dataset = FitsDataset(args.dataset_root, mode=mode)
+    preds, stds = dataset_inference_vi(
+        dataset, args.checkpoint_path, args.variational_iters
+    )
+    if args.save_images_path is not None:
+        out_path = args.save_images_path + f"/preds_{mode}"
         save_images(dataset, out_path, preds, stds)
 
 # root = f'{os.environ["TMPDIR"]}/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
 # root = f'/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
 # checkpoint_path = './gridsearch_efficientnet/version_6665871_5__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25/ckpt_step=7999.pth'
 # out_path = f'/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/preds_{mode}/'
-
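Note on the predict()/variational_dropout_step() changes above: keeping dropout active at inference time is Monte-Carlo dropout, where repeated stochastic forward passes yield a predictive mean and an uncertainty estimate. A minimal, self-contained sketch of the idea; the toy model, shapes, and iteration count are illustrative assumptions, not repo code:

import torch

# Toy stand-in for the checkpointed model: any module containing dropout.
model = torch.nn.Sequential(torch.nn.Linear(8, 1), torch.nn.Dropout(p=0.25))
model.train()  # keeps dropout stochastic, as predict() does when variational_dropout > 0

sample = torch.randn(4, 8)  # batch of 4 feature vectors

with torch.no_grad():
    # Each forward pass samples a fresh dropout mask, i.e. a different sub-network.
    preds = torch.sigmoid(torch.concat([model(sample) for _ in range(16)], dim=1))

means = preds.mean(dim=1)  # per-sample prediction
stds = preds.std(dim=1)    # per-sample uncertainty
print(means.shape, stds.shape)  # torch.Size([4]) torch.Size([4])

This is why the new else-branch in predict() matters: with variational_dropout == 0 the model must be switched back to eval() so dropout is disabled and predictions are deterministic.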
diff --git a/neural_networks/plots/dinov2_09739_rotations/pr_curve_dict.npydict b/neural_networks/plots/dinov2_09739_rotations/pr_curve_dict.npydict
index f2232079..a89702ae 100644
Binary files a/neural_networks/plots/dinov2_09739_rotations/pr_curve_dict.npydict and b/neural_networks/plots/dinov2_09739_rotations/pr_curve_dict.npydict differ
diff --git a/neural_networks/plots/dinov2_097_rotations_dropout_01/pr_curve_dict.npydict b/neural_networks/plots/dinov2_097_rotations_dropout_01/pr_curve_dict.npydict
new file mode 100644
index 00000000..3db49b66
Binary files /dev/null and b/neural_networks/plots/dinov2_097_rotations_dropout_01/pr_curve_dict.npydict differ
diff --git a/neural_networks/plots/variational_dropout.py b/neural_networks/plots/variational_dropout.py
index 3c5ae1ae..1f6d382f 100644
--- a/neural_networks/plots/variational_dropout.py
+++ b/neural_networks/plots/variational_dropout.py
@@ -13,6 +13,7 @@ from torch.utils.data import Dataset, DataLoader
 from sklearn.metrics import precision_recall_curve
 from astropy.io import fits
+import torcheval.metrics.functional as tef
 
 
 class RawFitsDataset(Dataset):
@@ -113,33 +114,41 @@ def load_model(architecture_name, model_name, device="cpu"):
 @torch.no_grad()
 def get_dropout_output(predictor, dataloader, mean, std, vi_iters_list):
     labels = []
-    vi_dict = {vi_iters: {"std": [], "pred": []} for vi_iters in vi_iters_list}
-    for i, (img, label) in enumerate(dataloader):
-        data = predictor.prepare_batch(img, mean=mean, std=std)
-        labels += label.numpy().tolist()
-        for vi_iters in vi_iters_list:
+    vi_dict = {
+        vi_iters: {"std": [], "pred": [], "labels": []} for vi_iters in vi_iters_list
+    }
+    for i, vi_iters in enumerate(vi_iters_list):
+        for img, label in dataloader:
+            data = predictor.prepare_batch(img, mean=mean, std=std)
+            # Labels are identical across passes (shuffle=False below), so
+            # collect them only on the first pass over the dataloader.
+            if not i:
+                labels += label.numpy().tolist()
             predictor.variational_dropout = vi_iters
             pred, stds = predictor.predict(data.clone())
             vi_dict[vi_iters]["pred"] += pred.cpu().to(torch.float32).numpy().tolist()
             vi_dict[vi_iters]["std"] += stds.cpu().to(torch.float32).numpy().tolist()
-    labels = np.asarray(labels)
-    for vi_iters in vi_iters_list:
         vi_dict[vi_iters]["pred"] = torch.asarray(vi_dict[vi_iters]["pred"])
         vi_dict[vi_iters]["std"] = torch.asarray(vi_dict[vi_iters]["std"])
-    return vi_dict, labels
+        vi_dict[vi_iters]["labels"] = torch.asarray(labels)
+    return vi_dict
 
 
-def plot_pr_curves(savedir, vi_dict, labels):
+def plot_pr_curves(savedir, vi_dict, vi_iters_list):
     os.makedirs(savedir, exist_ok=True)
-    for vi_iter, pred_dict in vi_dict.items():
-        preds = pred_dict["pred"]
-        # Reverse labels to compute pr curve for predicting "stop"
+    for vi_iter in sorted(vi_iters_list):
+        pred_dict = vi_dict[vi_iter]
+        preds, labels = pred_dict["pred"], pred_dict["labels"]
+        # PR curve for the positive class; labels and preds are used as-is
+        # (the old code reversed both to score the "stop" class instead).
         precision, recall, thresholds = precision_recall_curve(
-            -np.asarray(labels) + 1, -np.asarray(preds) + 1
+            np.asarray(labels), np.asarray(preds)
         )
+        auprc = tef.binary_auprc(torch.asarray(preds), torch.asarray(labels))
+        print(f"AUPRC (VI iters={vi_iter}): {auprc.item():.3f}")
 
-        plt.plot(recall, precision, label=f"VI Iters: {vi_iter}")
+        plt.plot(
+            recall, precision, label=f"VI Iters: {vi_iter} (AUPRC: {auprc.item():.3f})"
+        )
 
     plt.xlabel("Recall")
     plt.ylabel("Precision")
     plt.title("PR Curves for varying Variational Inference iteration counts")
@@ -147,6 +156,7 @@ def plot_pr_curves(savedir, vi_dict, labels):
     plt.grid(True, linestyle="--", alpha=0.7)
     plt.tight_layout()
 
+    plt.ylim(0, 1)
     plt.savefig(f"{savedir}/precision_recall_curves.png", dpi=300)
     plt.clf()
 
@@ -168,7 +178,7 @@ def get_dataloader(data_root, mode="val", batch_size=32):
     dataloader = DataLoader(
         dataset,
-        batch_size=32,
-        shuffle=True,
+        batch_size=batch_size,
+        shuffle=False,
         num_workers=num_workers,
         persistent_workers=persistent_workers,
         prefetch_factor=prefetch_factor,
@@ -181,21 +191,27 @@ def get_dataloader(data_root, mode="val", batch_size=32):
 if __name__ == "__main__":
     # Latest model
     model_name = "surf/dinov2_09739_rotations"
+    # model_name = "surf/dinov2_097_rotations_dropout_01"
     TESTING = True
     architecture_name = "surf/TransferLearning"
     # Set Device here
     DEVICE = "cuda"
-    # Thresholds to consider for classification
-    vi_iters_list = [0, 1, 2, 4, 8, 16, 32]
+    # Variational-inference iteration counts to evaluate
+    vi_iters_list = [0, 1, 2, 4, 8, 16, 32, 64, 128]
+    # vi_iters_list = [0]
     # Change to directory of files. Should have subfolders 'continue_val' and 'stop_val'
     data_root = "/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data"
 
-    # Uses cached confusion matrix for testing the plotting functionalities
+    # Uses cached results for testing the plotting functionalities
    savedir = model_name.split("/")[-1]
     dict_fname = f"{savedir}/pr_curve_dict.npydict"
 
+    VI_DICT = None
+    # Load cached results; only compute the iteration counts not yet cached
     if Path(dict_fname).exists() and TESTING:
-        pr_dict = np.load(dict_fname, allow_pickle=True)[()]
-        vi_dict, labels = pr_dict["vi_dict"], pr_dict["labels"]
-    else:
+        VI_DICT = np.load(dict_fname, allow_pickle=True)[()]
+        existing_vi_iters = list(VI_DICT.keys())
+        vi_iters_list = list(set(vi_iters_list) - set(existing_vi_iters))
+
+    if vi_iters_list:
         dataloader = get_dataloader(data_root, mode="val")
@@ -203,11 +219,15 @@ def get_dataloader(data_root, mode="val", batch_size=32):
 
         mean, std = predictor.args["dataset_mean"], predictor.args["dataset_std"]
 
-        vi_dict, labels = get_dropout_output(
-            predictor, dataloader, mean, std, vi_iters_list
-        )
+        vi_dict = get_dropout_output(predictor, dataloader, mean, std, vi_iters_list)
 
+        # Merge new results into the cached results and save
+        if VI_DICT is not None:
+            VI_DICT = vi_dict | VI_DICT
+        else:
+            VI_DICT = vi_dict
+
+        os.makedirs(savedir, exist_ok=True)
         with open(dict_fname, "wb") as f:
-            np.save(f, {"vi_dict": vi_dict, "labels": labels})
+            np.save(f, VI_DICT)
 
-    plot_pr_curves(savedir, vi_dict, labels)
+    # Plot every cached iteration count, not only the newly computed ones
+    plot_pr_curves(savedir, VI_DICT, list(VI_DICT.keys()))
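The plotting changes above pull in torcheval so that each PR curve is annotated with its AUPRC. A small sketch of the two calls working together on dummy data (the labels and scores here are illustrative):

import numpy as np
import torch
import torcheval.metrics.functional as tef
from sklearn.metrics import precision_recall_curve

labels = np.array([0, 1, 1, 0, 1])            # dummy ground truth
preds = np.array([0.1, 0.8, 0.65, 0.3, 0.9])  # dummy scores for the positive class

# Full curve for plotting...
precision, recall, thresholds = precision_recall_curve(labels, preds)
# ...and its scalar summary for the legend: torcheval expects (input, target).
auprc = tef.binary_auprc(torch.as_tensor(preds), torch.as_tensor(labels))
print(f"AUPRC: {auprc.item():.3f}")  # 1.000 here, since the toy scores separate perfectly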
diff --git a/neural_networks/scripts/plot_pr_curve.py b/neural_networks/scripts/plot_pr_curve.py
index 3a3b85fb..7d6b06a5 100644
--- a/neural_networks/scripts/plot_pr_curve.py
+++ b/neural_networks/scripts/plot_pr_curve.py
@@ -8,25 +8,29 @@ from sklearn.metrics import precision_recall_curve
 import matplotlib.pyplot as plt
 
-sys.path.append('..')  # yes this is mega ugly, but otherwise I need to restructure the whole project...
+sys.path.append(
+    ".."
+)  # yes this is mega ugly, but otherwise I need to restructure the whole project...
 
 from inference import dataset_inference_vi, is_file
 from train_nn import ImagenetTransferLearning
 from pre_processing_for_ml import FitsDataset
 
-DATASET_ROOT = '/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
-CHECKPOINT = '../grid_search/version_7944362_0__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25__use_compile_True/ckpt_step=1643.pth'
-PICKLE_DICT_FNAME = 'precision_recall_curves.npy'
+DATASET_ROOT = "/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/"
+CHECKPOINT = "../grid_search/version_7944362_0__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25__use_compile_True/ckpt_step=1643.pth"
+PICKLE_DICT_FNAME = "precision_recall_curves.npy"
 
 if os.path.isfile(PICKLE_DICT_FNAME):
-    pickle_dict = np.load(PICKLE_DICT_FNAME, allow_pickle=True)[()]  # Don't ask me why this is the syntax
+    pickle_dict = np.load(PICKLE_DICT_FNAME, allow_pickle=True)[
+        ()
+    ]  # Don't ask me why this is the syntax
 else:
     # Yes I'm aware of the existence of collections.defaultdict, but that doesn't work with np.save...
     pickle_dict = {}
 
 variational_iter_vals = (0, 1, 2, 4, 16)
-modes = ('val', 'train')
+modes = ("val", "train")
 
 # variational_iter_vals = (0, 1,)
 # modes = ('val',)
@@ -39,7 +43,7 @@
         pickle_dict[mode] = {}
 
     for variational_iters in variational_iter_vals:
-        variational_iters_str = f'variational_iters_{variational_iters}'
+        variational_iters_str = f"variational_iters_{variational_iters}"
 
         if variational_iters_str in pickle_dict[mode]:
             print(f"{mode}/{variational_iters_str} already in saved pickle; skipping.")
@@ -51,39 +55,43 @@
         pickle_dict[mode][variational_iters_str] = {}
 
         dataset = FitsDataset(DATASET_ROOT, mode=mode)
-        preds, stds = dataset_inference_vi(dataset, CHECKPOINT, variational_iters=variational_iters)
+        preds, stds = dataset_inference_vi(
+            dataset, CHECKPOINT, variational_iters=variational_iters
+        )
+
+        precision, recall, thresholds = precision_recall_curve(
+            dataset.labels, preds
+        )
 
-        precision, recall, thresholds = precision_recall_curve(dataset.labels, preds)
-
         for name, value in (
-            ('precision', precision),
-            ('recall', recall),
-            ('thresholds', thresholds),
-            ('preds', preds),
-            ('stds', stds),
-            ('labels', dataset.labels),
-            ('sources', dataset.data_paths)
+            ("precision", precision),
+            ("recall", recall),
+            ("thresholds", thresholds),
+            ("preds", preds),
+            ("stds", stds),
+            ("labels", dataset.labels),
+            ("sources", dataset.data_paths),
         ):
             pickle_dict[mode][variational_iters_str][name] = value
 
         plt.plot(
-            pickle_dict[mode][variational_iters_str]['recall'],
-            pickle_dict[mode][variational_iters_str]['precision'],
-            label=f'VI Iters: {variational_iters}'
+            pickle_dict[mode][variational_iters_str]["recall"],
+            pickle_dict[mode][variational_iters_str]["precision"],
+            label=f"VI Iters: {variational_iters}",
         )
 
-    plt.xlabel('Recall')
-    plt.ylabel('Precision')
-    plt.title('PR Curves for varying Variational Inference iteration counts')
+    plt.xlabel("Recall")
+    plt.ylabel("Precision")
+    plt.title("PR Curves for varying Variational Inference iteration counts")
     plt.legend()
-    plt.grid(True, linestyle='--', alpha=0.7)
+    plt.grid(True, linestyle="--", alpha=0.7)
     plt.tight_layout()
 
-    plt.savefig(f'precision_recall_curves_{mode}.png', dpi=300)
+    plt.savefig(f"precision_recall_curves_{mode}.png", dpi=300)
     plt.clf()
 
 if new_vals:
     print("saving new/updated pickle_dict")
     # pylance or w/e can complain, but this is valid syntax.
-    np.save(PICKLE_DICT_FNAME, pickle_dict)
\ No newline at end of file
+    np.save(PICKLE_DICT_FNAME, pickle_dict)
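Both scripts lean on the numpy pickling pattern that the "Don't ask me why this is the syntax" comment refers to. A minimal sketch of why the [()] indexing and the open() file handle are needed (the file name and contents are illustrative):

import numpy as np

cache = {"val": {"variational_iters_4": {"preds": [0.2, 0.9]}}}

# Passing an open file handle stops np.save from appending ".npy" to
# names like "pr_curve_dict.npydict" that do not already end in it.
with open("cache.npydict", "wb") as f:
    np.save(f, cache)

# A pickled dict comes back as a 0-d object array; indexing it with the
# empty tuple () extracts the dict itself.
loaded = np.load("cache.npydict", allow_pickle=True)[()]
assert loaded == cache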
diff --git a/neural_networks/train_nn.py b/neural_networks/train_nn.py
index 027ae589..bf38385a 100644
--- a/neural_networks/train_nn.py
+++ b/neural_networks/train_nn.py
@@ -13,6 +13,8 @@ from torchvision import models
 from torchvision.transforms import v2
 from tqdm import tqdm
+import torchvision.transforms.functional as TF
+from torchvision.transforms.functional import InterpolationMode
 
 import numpy as np
 import random
@@ -235,7 +237,10 @@ def eval(self):
         if self.kwargs["model_name"] == "vit_l_16":
             self.vit.heads.eval()
         elif "dinov2" in self.kwargs["model_name"]:
-            self.dino.eval()
+            if self.kwargs["use_lora"]:
+                self.dino.eval()
+            else:
+                self.dino.decoder.eval()
         else:
             self.classifier.eval()
 
@@ -243,7 +248,10 @@ def train(self):
         if self.kwargs["model_name"] == "vit_l_16":
             self.vit.heads.train()
         elif "dinov2" in self.kwargs["model_name"]:
-            self.dino.train()
+            if self.kwargs["use_lora"]:
+                self.dino.train()
+            else:
+                self.dino.decoder.train()
         else:
             self.classifier.train()
@@ -676,10 +684,6 @@ def __iter__(self):
         yield from iter(self.sampler)
 
 
-import torchvision.transforms.functional as TF
-from torchvision.transforms.functional import InterpolationMode
-
-
 class Rotate90Transform:
     def __init__(self, angles=[0, 90, 180, 270]):
         self.angles = angles
@@ -722,7 +726,7 @@ def save_checkpoint(logging_dir, model, optimizer, global_step, **kwargs):
     )
 
 
-def load_checkpoint(ckpt_path, device="gpu"):
+def load_checkpoint(ckpt_path, device="cuda"):
     if os.path.isfile(ckpt_path):
         ckpt_dict = torch.load(ckpt_path, weights_only=False)
     else:
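On the eval()/train() hunks in train_nn.py: when only a decoder head is fine-tuned (no LoRA), toggling the whole dinov2 module would also flip dropout and normalization behaviour inside the frozen backbone, so only the trainable submodule should change state. A sketch of the pattern under an assumed module layout; the class and attribute names here are hypothetical stand-ins, not the repo's real model:

import torch

class DinoClassifier(torch.nn.Module):
    """Hypothetical layout: frozen backbone plus a small trainable decoder."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(16, 16)  # stands in for the DINOv2 encoder
        self.decoder = torch.nn.Sequential(
            torch.nn.Dropout(p=0.1), torch.nn.Linear(16, 1)
        )

    def forward(self, x):
        with torch.no_grad():  # backbone is frozen
            x = self.backbone(x)
        return self.decoder(x)

model = DinoClassifier()
model.backbone.requires_grad_(False)  # no gradients for the frozen part
model.eval()           # whole model defaults to deterministic eval behaviour
model.decoder.train()  # only the trainable head flips state, as in the diff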