updated variational dropout plots
LVeefkind committed Oct 30, 2024
1 parent 26e6eff commit 480f800
Showing 7 changed files with 162 additions and 88 deletions.
2 changes: 2 additions & 0 deletions neural_networks/__init__.py
@@ -82,6 +82,8 @@ def predict(self, data: torch.Tensor):
with torch.autocast(dtype=self.dtype, device_type=self.device):
if self.variational_dropout > 0:
self.model.train()
else:
self.model.eval()

predictions = torch.concat(
[
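Context for the change above: calling `model.train()` when `variational_dropout > 0` keeps the dropout layers sampling at inference time, which is the mechanism behind variational (Monte Carlo) dropout. A minimal, self-contained sketch of the effect (toy model, not from this repo):

import torch

# A toy classifier with dropout; any model containing nn.Dropout behaves the same.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.25),
    torch.nn.Linear(32, 1),
)
x = torch.randn(4, 16)

model.eval()  # dropout disabled: repeated passes are identical
assert torch.equal(model(x), model(x))

model.train()  # dropout active: each pass samples a different mask
with torch.no_grad():
    a, b = model(x), model(x)
# a and b now differ (with overwhelming probability), which is what
# variational dropout averages over.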
102 changes: 71 additions & 31 deletions neural_networks/inference.py
@@ -26,18 +26,19 @@ def variational_dropout(model, batch, variational_iters: int):

model.cuda()

with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
sample = batch[0].cuda(non_blocking=True).expand(-1, 3, -1, -1)

means, stds = variational_dropout_step(model, sample, variational_iters)


return means, stds


@torch.no_grad()
def variational_dropout_step(model: torch.nn.Module, sample, variational_iters: int):
preds = torch.sigmoid(
torch.concat([model(sample).clone() for _ in range(variational_iters)], dim=1)
)

means = preds.mean(dim=1)
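For reference, `variational_dropout_step` stacks `variational_iters` stochastic forward passes along dim 1 and reduces them per sample; the folded lines below presumably compute the matching std. A standalone sketch of that reduction with made-up shapes:

import torch

# Stand-in for `variational_iters` sigmoid outputs concatenated along dim=1:
# a batch of 8 samples with 16 stochastic forward passes each.
preds = torch.sigmoid(torch.randn(8, 16))

means = preds.mean(dim=1)  # MC-dropout prediction per sample, shape (8,)
stds = preds.std(dim=1)    # predictive uncertainty per sample, shape (8,)
print(means.shape, stds.shape)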

@@ -52,22 +53,25 @@ def variational_dropout_step(model: torch.nn.Module, sample, variational_iters:
def save_images(dataset, out_path, preds, stds):
os.makedirs(out_path, exist_ok=True)

for elem, in_path, pred, std in tqdm(
zip(dataset, dataset.data_paths, preds, stds), total=len(dataset)
):
batch, label = elem
        name = os.path.basename(in_path).removesuffix(".npz")  # drop directory and ".npz" suffix
matplotlib.image.imsave(
fname=f"{out_path}/{std:.3f}_{pred:.3f}_{label}_{name}.png",
arr=batch.to(torch.float).movedim(0, -1).numpy(),
)


def gen_output_from_dataset(
dataset: torch.utils.data.Dataset, inference_f: Any, **dataloader_kwargs
):

num_workers = dataloader_kwargs.pop(
"num_workers", min(18, len(os.sched_getaffinity(0)))
)
batch_size = dataloader_kwargs.pop("batch_size", 64)

dataloader = DataLoader(
dataset=dataset,
@@ -76,23 +80,33 @@ def gen_output_from_dataset(
pin_memory=True,
shuffle=False,
drop_last=False,
**dataloader_kwargs,
)

outputs = [inference_f(batch=batch) for batch in dataloader]

return tuple(
map(
lambda x: torch.concat(x)
.to(torch.device("cpu"), dtype=torch.float)
.numpy(),
zip(*outputs),
)
)
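The return expression transposes the per-batch `(means, stds)` tuples into one concatenated CPU array per field. A small self-contained sketch of that zip/concat pattern (shapes made up):

import torch

# Two batches of (means, stds), as inference_f would yield them.
outputs = [
    (torch.rand(4), torch.rand(4)),
    (torch.rand(4), torch.rand(4)),
]

preds, stds = tuple(
    map(
        lambda x: torch.concat(x).to(torch.device("cpu"), dtype=torch.float).numpy(),
        zip(*outputs),
    )
)
print(preds.shape, stds.shape)  # (8,) (8,)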


def dataset_inference_vi(
dataset: torch.utils.data.Dataset,
checkpoint_path: str,
variational_iters: int,
):
torch.set_float32_matmul_precision("high")

model = load_checkpoint(checkpoint_path)["model"]

inference_f = partial(
variational_dropout, model=model, variational_iters=variational_iters
)

preds, stds = gen_output_from_dataset(dataset=dataset, inference_f=inference_f)
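Assuming the folded tail of `dataset_inference_vi` returns the `(preds, stds)` pair computed here, and that `FitsDataset` is importable from this module as its use under `__main__` suggests, a typical call would look like this (paths are placeholders):

from neural_networks.inference import FitsDataset, dataset_inference_vi

dataset = FitsDataset("/path/to/cnn_data", mode="val")
preds, stds = dataset_inference_vi(
    dataset,
    checkpoint_path="/path/to/ckpt.pth",
    variational_iters=8,
)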

@@ -104,14 +118,17 @@ def is_dir(path):
raise argparse.ArgumentTypeError(f"'{path}' is not a valid directory.")
return path


def is_file(path):
if not os.path.isfile(path):
raise argparse.ArgumentTypeError(f"'{path}' is not a valid file.")
return path


def positive_int(value):
def err(value):
raise argparse.ArgumentTypeError(f"'{value}' is not a valid positive integer.")

try:
ivalue = int(value)
except:
@@ -122,28 +139,51 @@ def err(value):

return ivalue


def get_args():
parser = argparse.ArgumentParser(
description="Argument parser for dataset and variational iterations"
)

parser.add_argument(
"--dataset_root",
type=is_dir,
required=True,
help="Path to the dataset root directory",
)
parser.add_argument(
"--checkpoint_path",
type=is_file,
required=True,
help="Path to the checkpoint root directory",
)
parser.add_argument(
"--save_images_path",
type=is_dir,
default=None,
help="Path to save images (optional)",
)
parser.add_argument(
"--variational_iters",
type=positive_int,
default=5,
help="Number of variational iterations (must be >= 1)",
)

return parser.parse_args()
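A quick way to exercise this parser without a shell, e.g. in a test (note that the `is_dir`/`is_file` validators require the paths to exist, hence the temporary files; `get_args` is the function above):

import sys
import tempfile

tmpdir = tempfile.mkdtemp()
tmpfile = tempfile.NamedTemporaryFile(delete=False)

sys.argv = [
    "inference.py",
    "--dataset_root", tmpdir,
    "--checkpoint_path", tmpfile.name,
    "--variational_iters", "8",
]
args = get_args()
print(args.variational_iters)  # 8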


if __name__ == "__main__":
args = get_args()

    mode = "val"  # assumed split name; adjust to the split being evaluated
    dataset = FitsDataset(args.dataset_root, mode=mode)
    preds, stds = dataset_inference_vi(
        dataset, args.checkpoint_path, args.variational_iters
    )

    if args.save_images_path is not None:
        out_path = args.save_images_path + f"/preds_{mode}"
        save_images(dataset, out_path, preds, stds)

# root = f'{os.environ["TMPDIR"]}/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
# root = f'/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/'
# checkpoint_path = './gridsearch_efficientnet/version_6665871_5__model_efficientnet_v2_l__lr_0.0001__normalize_0__dropout_p_0.25/ckpt_step=7999.pth'
# out_path = f'/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/preds_{mode}/'

Binary file not shown.
Binary file not shown.
68 changes: 44 additions & 24 deletions neural_networks/plots/variational_dropout.py
@@ -13,6 +13,7 @@
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_recall_curve
from astropy.io import fits
import torcheval.metrics.functional as tef


class RawFitsDataset(Dataset):
@@ -113,40 +114,49 @@ def load_model(architecture_name, model_name, device="cpu"):
@torch.no_grad()
def get_dropout_output(predictor, dataloader, mean, std, vi_iters_list):
labels = []
vi_dict = {vi_iters: {"std": [], "pred": []} for vi_iters in vi_iters_list}
for i, (img, label) in enumerate(dataloader):
data = predictor.prepare_batch(img, mean=mean, std=std)
labels += label.numpy().tolist()
for vi_iters in vi_iters_list:
vi_dict = {
vi_iters: {"std": [], "pred": [], "labels": []} for vi_iters in vi_iters_list
}
for i, vi_iters in enumerate(vi_iters_list):
for img, label in dataloader:
data = predictor.prepare_batch(img, mean=mean, std=std)
if not i:
labels += label.numpy().tolist()

predictor.variational_dropout = vi_iters
pred, stds = predictor.predict(data.clone())
vi_dict[vi_iters]["pred"] += pred.cpu().to(torch.float32).numpy().tolist()
vi_dict[vi_iters]["std"] += stds.cpu().to(torch.float32).numpy().tolist()
labels = np.asarray(labels)
for vi_iters in vi_iters_list:
vi_dict[vi_iters]["pred"] = torch.asarray(vi_dict[vi_iters]["pred"])
vi_dict[vi_iters]["std"] = torch.asarray(vi_dict[vi_iters]["std"])
vi_dict[vi_iters]["labels"] = torch.asarray(labels)
return vi_dict
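After this change each entry of the returned dict is self-contained, carrying its labels alongside predictions and uncertainties, which is what lets the caching logic below merge runs. Illustratively (made-up numbers):

import torch

vi_dict = {
    8: {
        "pred": torch.asarray([0.91, 0.12]),   # MC-dropout mean per sample
        "std": torch.asarray([0.03, 0.08]),    # MC-dropout std per sample
        "labels": torch.asarray([1.0, 0.0]),   # ground truth, aligned by index
    },
    # ...one entry per value in vi_iters_list
}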


def plot_pr_curves(savedir, vi_dict, vi_iters_list):
os.makedirs(savedir, exist_ok=True)
for vi_iter in sorted(vi_iters_list):
pred_dict = vi_dict[vi_iter]
preds, labels = pred_dict["pred"], pred_dict["labels"]
precision, recall, thresholds = precision_recall_curve(
np.asarray(labels), np.asarray(preds)
)
auprc = tef.binary_auprc(torch.asarray(preds), torch.asarray(labels))
print(f"auprc vi_iters {vi_iter}", auprc)

plt.plot(recall, precision, label=f"VI Iters: {vi_iter}")
plt.plot(
recall, precision, label=f"VI Iters: {vi_iter} auprc: {auprc.item():.3f}"
)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("PR Curves for varying Variational Inference iteration counts")
plt.legend()
plt.grid(True, linestyle="--", alpha=0.7)

plt.tight_layout()
plt.ylim(0, 1)
plt.savefig(f"{savedir}/precision_recall_curves.png", dpi=300)
plt.clf()
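As a sanity check on the printed AUPRC values, a tiny hand-checkable example, independent of the diff:

import torch
import torcheval.metrics.functional as tef

preds = torch.tensor([0.9, 0.8, 0.3, 0.2])
labels = torch.tensor([1, 1, 0, 1])

# The ranking places one negative above the lowest-scored positive,
# so the AUPRC is high but below 1.0.
print(tef.binary_auprc(preds, labels))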

@@ -168,7 +178,7 @@ def get_dataloader(data_root, mode="val", batch_size=32):
dataloader = DataLoader(
dataset,
        batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
@@ -181,33 +191,43 @@
if __name__ == "__main__":
# Latest model
model_name = "surf/dinov2_09739_rotations"
# model_name = "surf/dinov2_097_rotations_dropout_01"
TESTING = True
architecture_name = "surf/TransferLearning"
# Set Device here
DEVICE = "cuda"
    # Variational-dropout iteration counts to evaluate
    vi_iters_list = [0, 1, 2, 4, 8, 16, 32, 64, 128]
# vi_iters_list = [0]
# Change to directory of files. Should have subfolders 'continue_val' and 'stop_val'
data_root = "/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data"
# Uses cached results for testing the plotting functionalities
savedir = model_name.split("/")[-1]
dict_fname = f"{savedir}/pr_curve_dict.npydict"
    VI_DICT = None
# Load cached results
if Path(dict_fname).exists() and TESTING:
VI_DICT = np.load(dict_fname, allow_pickle=True)[()]
existing_vi_iters = list(VI_DICT.keys())
vi_iters_list = list(set(vi_iters_list) - set(existing_vi_iters))

    if vi_iters_list:
dataloader = get_dataloader(data_root, mode="val")

predictor = load_model(architecture_name, model_name, device=DEVICE)

mean, std = predictor.args["dataset_mean"], predictor.args["dataset_std"]

vi_dict = get_dropout_output(predictor, dataloader, mean, std, vi_iters_list)
# Add new results to cached results and save
if VI_DICT is not None:
VI_DICT = vi_dict | VI_DICT
else:
VI_DICT = vi_dict

os.makedirs(savedir, exist_ok=True)
with open(dict_fname, "wb") as f:
np.save(f, {"vi_dict": vi_dict, "labels": labels})
np.save(f, VI_DICT)

    # Plot every cached iteration count, not only the newly computed ones
    plot_pr_curves(savedir, VI_DICT, list(VI_DICT.keys()))
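One detail of the merge above: dict union (`|`, Python 3.9+) lets the right-hand operand win on duplicate keys, so cached entries in VI_DICT would take precedence over freshly computed ones; collisions cannot occur here anyway because vi_iters_list was already reduced to the uncached counts. A minimal illustration:

fresh = {8: "fresh"}
cached = {8: "cached", 16: "cached"}
print(fresh | cached)  # {8: 'cached', 16: 'cached'}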
