Skip to content

Commit

Permalink
Added confusion matrix plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
LVeefkind committed Oct 25, 2024
1 parent 1355a77 commit 2d4a5bd
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 29 deletions.
37 changes: 24 additions & 13 deletions neural_networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
import __main__
from astropy.io import fits

from train_nn import ImagenetTransferLearning, load_checkpoint # noqa
from pre_processing_for_ml import normalize_fits
from .train_nn import (
ImagenetTransferLearning,
load_checkpoint,
normalize_inputs,
) # noqa
from .pre_processing_for_ml import normalize_fits

setattr(__main__, "ImagenetTransferLearning", ImagenetTransferLearning)

Expand All @@ -28,7 +32,7 @@ def __init__(
model_name: str = None,
device: str = None,
variational_dropout: int = 0,
**kwargs
**kwargs,
):
super().__init__(model_name, device)

Expand All @@ -47,30 +51,37 @@ def load_checkpoint(self, path) -> torch.nn.Module:
(
model,
_,
args,
self.args,
) = load_checkpoint(path, self.device).values()
self.resize = args["resize"]
self.lift = args["lift"]
self.resize = self.args["resize"]
self.lift = self.args["lift"]
return model

@functools.lru_cache(maxsize=1)
def prepare_data(self, input_path: str) -> torch.Tensor:
input_data: torch.Tensor = torch.from_numpy(process_fits(input_path))
input_data = input_data.to(self.dtype)
input_data = input_data.swapdims(0, 2).unsqueeze(0)
return self.prepare_batch(input_data)

def prepare_batch(self, batch: torch.Tensor, mean=None, std=None) -> torch.Tensor:
batch = batch.to(self.dtype).to(self.device)
if self.resize != 0:
input_data = interpolate(
input_data, size=self.resize, mode="bilinear", align_corners=False
batch = interpolate(
batch, size=self.resize, mode="bilinear", align_corners=False
)
input_data = input_data.to(self.device)
return input_data
if mean is None:
mean = self.mean
if std is None:
std = self.std
batch = normalize_inputs(batch, mean, std, normalize=1)
return batch

@torch.no_grad()
def predict(self, data: torch.Tensor):
with torch.autocast(dtype=self.dtype, device_type=self.device):
if self.variational_dropout > 0:
self.model.train()
# self.model.classifier.train()

predictions = torch.concat(
[
Expand All @@ -80,8 +91,8 @@ def predict(self, data: torch.Tensor):
dim=1,
)

mean = predictions.mean()
std = predictions.std()
mean = predictions.mean(dim=1)
std = predictions.std(dim=1)

print(mean, std)
return mean, std
Expand Down
15 changes: 15 additions & 0 deletions neural_networks/parameters.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ efficientnet_v2_l 1e-05 1 0.1 32 0.2 1 0 16 16 0
efficientnet_v2_l 1e-05 1 0.1 32 0.1 1 0 16 16 0
efficientnet_v2_l 1e-05 1 0.1 32 0.1 0 0 16 16 0

dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 0 0 16 16 560 conv 0
dinov2_vitl14_reg 1e-04 1 0.1 32 0.1 0 0 16 16 560 conv 0
dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 0 0 16 16 560 conv 1
dinov2_vitl14_reg 1e-04 1 0.1 32 0.1 0 0 16 16 560 conv 1
dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 0 0 16 16 560 stack 0
dinov2_vitl14_reg 1e-04 1 0.1 32 0.1 0 0 16 16 560 stack 0
efficientnet_v2_l 1e-04 1 0.1 32 0.2 0 0 16 16 0 stack 0
dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 0 0 16 16 560 stack 1
dinov2_vitl14_reg 1e-04 1 0.1 32 0.1 0 0 16 16 560 stack 1
efficientnet_v2_l 1e-04 1 0.1 32 0.2 0 0 16 16 0 stack 1
dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 0 1 16 16 560 conv 0
dinov2_vitl14_reg 1e-04 1 0.1 32 0.1 0 1 16 16 560 conv 0
dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 0 1 16 16 560 conv 1
dinov2_vitl14_reg 1e-04 1 0.1 32 0.1 0 1 16 16 560 conv 1




Expand Down
95 changes: 95 additions & 0 deletions neural_networks/plots/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from cortexchange.architecture import get_architecture, Architecture
from pathlib import Path
import sys
import os

SCRIPT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(SCRIPT_DIR))
from train_nn import MultiEpochsDataLoader
from pre_processing_for_ml import FitsDataset
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


def load_model(architecture_name, model_name):
StopPredictor: type(Architecture) = get_architecture(architecture_name)
predictor = StopPredictor(device="cuda", model_name=model_name)
return predictor


def get_dataloader(data_root, mode, batch_size):
num_workers = min(12, len(os.sched_getaffinity(0)))

prefetch_factor, persistent_workers = (
(2, True) if num_workers > 0 else (None, False)
)

return MultiEpochsDataLoader(
dataset=FitsDataset(
data_root,
mode=mode,
),
batch_size=batch_size,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=True,
shuffle=False,
drop_last=False,
)


def get_statistics(data_root, mode):
return FitsDataset(
data_root,
mode=mode,
).compute_statistics(1)


@torch.no_grad()
def get_confusion_matrix(
predictor, dataloader, mean, std, thresholds=[0.2, 0.3, 0.4, 0.5]
):
confusion_matrices = np.zeros((len(thresholds), 2, 2))
thresholds = torch.tensor(thresholds)
for img, label in dataloader:
data = predictor.prepare_batch(img, mean=mean, std=std)
pred = torch.sigmoid(predictor.model(data)).to("cpu")
preds_thres = pred >= thresholds
for i, _ in enumerate(thresholds):
confusion_matrices[i] += confusion_matrix(
label, preds_thres[:, i], labels=[0, 1]
)

for i, conf_matrix in enumerate(confusion_matrices):

disp = ConfusionMatrixDisplay(
# Normalization
conf_matrix / np.sum(conf_matrix, axis=1, keepdims=True),
display_labels=["continue", "stop"],
)
disp.plot()

plt.savefig(f"confusion_thres_{thresholds[i]:.3f}.png")


if __name__ == "__main__":
model_name = "surf/dinov2_09814"
architecture_name = "surf/TransferLearning"
predictor = load_model(architecture_name, model_name)
if hasattr(predictor, "args") and "dataset_mean" in predictor.args:
mean, std = predictor.args["dataset_mean"], predictor.args["dataset_std"]
else:
mean, std = get_statistics(
"/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/",
mode="train",
)

dataloader = get_dataloader(
"/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/",
mode="val",
batch_size=32,
)
get_confusion_matrix(predictor, dataloader, mean, std)
9 changes: 7 additions & 2 deletions neural_networks/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
matplotlib
torch
torch>=2.1.2
torchvision
torcheval
tqdm
matplotlib
joblib
astropy
astropy>6.0.0
xformers
tensorboard
dino-finetune @ git+https://github.com/sara-nl/dinov2-finetune.git
scikit-learn


21 changes: 19 additions & 2 deletions neural_networks/train_nn.job
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ SLURM_ARRAY_TASK_ID=${SLURM_ARRAY_TASK_ID:=1}
PARAMS=$(sed -n "${SLURM_ARRAY_TASK_ID}p" $PARAM_FILE)

# Parse the parameters
read model lr normalize dropout_p batch_size label_smoothing stochastic_smoothing use_lora rank alpha resize <<< $PARAMS
read model lr normalize dropout_p batch_size label_smoothing stochastic_smoothing use_lora rank alpha resize lift flip_augmentations <<< $PARAMS

if [ "$use_lora" -eq 1 ]; then
LORA_ARG="--use_lora"
Expand All @@ -36,7 +36,24 @@ else
STOCHASTIC_SMOOTHING=""
fi

if [ "$flip_augmentations" -eq 1 ]; then
FLIP_AUGMENTATIONS="--flip_augmentations"
else
FLIP_AUGMENTATIONS=""
fi

# Scale up by 1e6 to convert to integers for comparison
scaled_lr=$(echo "$lr * 1000000" | awk '{printf("%d", $1)}')
scaled_threshold=$(echo "4e-05 * 1000000" | awk '{printf("%d", $1)}')

if [ "$scaled_lr" -le "$scaled_threshold" ]; then
EPOCHS="250"
else
EPOCHS="120"
fi

DATA_TRAINDATA_PATH="/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/"

# Execute your Python script with the given parameters
python train_nn.py $DATA_TRAINDATA_PATH --model $model --lr $lr --normalize $normalize --dropout_p $dropout_p --batch_size $batch_size --log_path grid_search_2 --label_smoothing $label_smoothing --rank $rank --resize $resize --alpha $alpha $LORA_ARG $STOCHASTIC_SMOOTHING -d
echo $DATA_TRAINDATA_PATH --model $model --lr $lr --normalize $normalize --dropout_p $dropout_p --batch_size $batch_size --log_path grid_search_2 --label_smoothing $label_smoothing --rank $rank --resize $resize --alpha $alpha $LORA_ARG $STOCHASTIC_SMOOTHING -d --epochs $EPOCHS --lift $lift $FLIP_AUGMENTATIONS
python train_nn.py $DATA_TRAINDATA_PATH --model $model --lr $lr --normalize $normalize --dropout_p $dropout_p --batch_size $batch_size --log_path grid_search_2 --label_smoothing $label_smoothing --rank $rank --resize $resize --alpha $alpha $LORA_ARG $STOCHASTIC_SMOOTHING -d --epochs $EPOCHS --lift $lift $FLIP_AUGMENTATIONS
Loading

0 comments on commit 2d4a5bd

Please sign in to comment.