Commit

Merge branch 'master' of github.com:jurjen93/lofar_helpers

jurjen93 committed Oct 24, 2024
2 parents cd2045f + 2890fe0 commit 4afb2bb
Showing 4 changed files with 202 additions and 64 deletions.
96 changes: 96 additions & 0 deletions neural_networks/__init__.py
@@ -0,0 +1,96 @@
import argparse
import functools

import torch
from torch.nn.functional import interpolate
import os

from cortexchange.architecture import Architecture
import __main__
from astropy.io import fits

from train_nn import ImagenetTransferLearning, load_checkpoint # noqa
from pre_processing_for_ml import normalize_fits

setattr(__main__, "ImagenetTransferLearning", ImagenetTransferLearning)


def process_fits(fits_path):
with fits.open(fits_path) as hdul:
image_data = hdul[0].data

return normalize_fits(image_data)


class TransferLearning(Architecture):
def __init__(
self,
model_name: str = None,
device: str = None,
variational_dropout: int = 0,
**kwargs
):
super().__init__(model_name, device)

self.dtype = torch.bfloat16

self.model = self.model.to(self.dtype)
self.model.eval()

assert variational_dropout >= 0
self.variational_dropout = variational_dropout

def load_checkpoint(self, path) -> torch.nn.Module:
# To avoid errors on CPU
if "gpu" not in self.device and self.device != "cuda":
os.environ["XFORMERS_DISABLED"] = "1"
(
model,
_,
args,
) = load_checkpoint(path, self.device).values()
self.resize = args["resize"]
self.lift = args["lift"]
return model

@functools.lru_cache(maxsize=1)
def prepare_data(self, input_path: str) -> torch.Tensor:
input_data: torch.Tensor = torch.from_numpy(process_fits(input_path))
input_data = input_data.to(self.dtype)
input_data = input_data.swapdims(0, 2).unsqueeze(0)
if self.resize != 0:
input_data = interpolate(
input_data, size=self.resize, mode="bilinear", align_corners=False
)
input_data = input_data.to(self.device)
return input_data

@torch.no_grad()
def predict(self, data: torch.Tensor):
with torch.autocast(dtype=self.dtype, device_type=self.device):
if self.variational_dropout > 0:
self.model.train()
# self.model.classifier.train()

predictions = torch.concat(
[
torch.sigmoid(self.model(data)).clone()
for _ in range(max(self.variational_dropout, 1))
],
dim=1,
)

mean = predictions.mean()
std = predictions.std()

print(mean, std)
return mean, std

@staticmethod
def add_argparse_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--variational_dropout",
type=int,
default=0,
help="Optional: Amount of times to run the model to obtain a variational estimate of the stdev",
)
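
For context, a minimal usage sketch of the new TransferLearning architecture follows. It is an illustration, not part of the commit: the model name and FITS path are placeholders, and it assumes the cortexchange Architecture base class resolves the checkpoint and calls load_checkpoint during construction.

# Hypothetical usage sketch of the TransferLearning class defined above.
# The model name and FITS path are placeholders, not part of this commit.
import torch
from neural_networks import TransferLearning

classifier = TransferLearning(
    model_name="dinov2_vitl14_reg",  # assumed identifier; the real registry name may differ
    device="cuda" if torch.cuda.is_available() else "cpu",
    variational_dropout=5,  # >0 re-runs the model with dropout enabled for a std estimate
)

data = classifier.prepare_data("observation.fits")  # placeholder FITS file
mean, std = classifier.predict(data)
print(f"predicted score: {mean:.3f} +/- {std:.3f}")
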
100 changes: 59 additions & 41 deletions neural_networks/parameters.txt
@@ -1,44 +1,62 @@
efficientnet_v2_l 1e-04 0 0.25 32 0 0 0
efficientnet_v2_l 1e-04 1 0.25 32 0 0 0
efficientnet_v2_l 1e-04 2 0.25 32 0 0 0
dinov2_vitl14_reg 1e-04 0 0.25 32 0 0 0
dinov2_vitl14_reg 1e-04 1 0.25 32 0 0 0
dinov2_vitl14_reg 1e-04 0 0.25 32 1 0 0
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0 0
efficientnet_v2_l 1e-04 0 0.25 32 0 0.1 1
efficientnet_v2_l 1e-04 1 0.25 32 0 0.1 1
efficientnet_v2_l 1e-04 2 0.25 32 0 0.1 1
dinov2_vitl14_reg 1e-04 0 0.25 32 0 0.1 1
dinov2_vitl14_reg 1e-04 1 0.25 32 0 0.1 1
dinov2_vitl14_reg 1e-04 0 0.25 32 1 0.1 1
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.1 1
efficientnet_v2_l 1e-04 0 0.25 32 0 0.1 1
efficientnet_v2_l 1e-04 1 0.25 32 0 0.1 1
efficientnet_v2_l 1e-04 2 0.25 32 0 0.1 1
dinov2_vitl14_reg 1e-04 0 0.25 32 0 0.1 1
dinov2_vitl14_reg 1e-04 1 0.25 32 0 0.1 1
dinov2_vitl14_reg 1e-04 0 0.25 32 1 0.1 1
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.1 1
efficientnet_v2_l 1e-04 0 0.25 32 0 0.2 0
efficientnet_v2_l 1e-04 1 0.25 32 0 0.2 0
efficientnet_v2_l 1e-04 2 0.25 32 0 0.2 0
dinov2_vitl14_reg 1e-04 0 0.25 32 0 0.2 0
dinov2_vitl14_reg 1e-04 1 0.25 32 0 0.2 0
dinov2_vitl14_reg 1e-04 0 0.25 32 1 0.2 0
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.2 0
efficientnet_v2_l 1e-04 0 0.25 32 0 0.1 0
efficientnet_v2_l 1e-04 1 0.25 32 0 0.1 0
efficientnet_v2_l 1e-04 2 0.25 32 0 0.1 0
dinov2_vitl14_reg 1e-04 0 0.25 32 0 0.1 0
dinov2_vitl14_reg 1e-04 1 0.25 32 0 0.1 0
dinov2_vitl14_reg 1e-04 0 0.25 32 1 0.1 0
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.1 0

dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.2 1 3 1 560
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.1 1 16 16 560
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.1 1 16 16 784
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.2 1 16 16 560
dinov2_vitl14_reg 1e-04 1 0.25 32 1 0.2 1 16 16 784
dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 0 0 16 16 560
dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 1 1 16 16 560
dinov2_vitl14_reg 1e-04 1 0.25 32 0.2 1 0 16 16 560
dinov2_vitl14_reg 1e-04 1 0.25 32 0.2 0 1 16 16 560

dinov2_vitl14_reg 1e-04 1 0.1 32 0.1 0 0 16 16 560
dinov2_vitl14_reg 1e-04 1 0.1 32 0.1 1 1 16 16 560
dinov2_vitl14_reg 1e-04 1 0.1 32 0.2 1 0 16 16 560
dinov2_vitl14_reg 1e-04 1 0.1 32 0.2 0 1 16 16 560

dinov2_vitl14_reg 1e-05 1 0.25 32 0.1 0 0 16 16 560
dinov2_vitl14_reg 1e-05 1 0.25 32 0.1 1 1 16 16 560
dinov2_vitl14_reg 1e-05 1 0.25 32 0.2 1 0 16 16 560
dinov2_vitl14_reg 1e-05 1 0.25 32 0.2 0 1 16 16 560

dinov2_vitl14_reg 1e-05 1 0.1 32 0.1 0 0 16 16 560
dinov2_vitl14_reg 1e-05 1 0.1 32 0.1 1 1 16 16 560
dinov2_vitl14_reg 1e-05 1 0.1 32 0.2 1 0 16 16 560
dinov2_vitl14_reg 1e-05 1 0.1 32 0.2 0 1 16 16 560

dinov2_vitl14_reg 5e-05 1 0.25 32 0.1 0 0 16 16 560
dinov2_vitl14_reg 5e-05 1 0.25 32 0.1 1 1 16 16 560
dinov2_vitl14_reg 5e-05 1 0.25 32 0.2 1 0 16 16 560
dinov2_vitl14_reg 5e-05 1 0.25 32 0.2 0 1 16 16 560

dinov2_vitl14_reg 5e-05 1 0.1 32 0.1 0 0 16 16 560
dinov2_vitl14_reg 5e-05 1 0.1 32 0.1 1 1 16 16 560
dinov2_vitl14_reg 5e-05 1 0.1 32 0.2 1 0 16 16 560
dinov2_vitl14_reg 5e-05 1 0.1 32 0.2 0 1 16 16 560

efficientnet_v2_l 1e-04 1 0.25 32 0.2 0 0 16 16 0
efficientnet_v2_l 1e-04 1 0.25 32 0.2 1 0 16 16 0
efficientnet_v2_l 1e-04 1 0.25 32 0.1 1 0 16 16 0
efficientnet_v2_l 1e-04 1 0.25 32 0.1 0 0 16 16 0

efficientnet_v2_l 1e-04 1 0.1 32 0.2 0 0 16 16 0
efficientnet_v2_l 1e-04 1 0.1 32 0.2 1 0 16 16 0
efficientnet_v2_l 1e-04 1 0.1 32 0.1 1 0 16 16 0
efficientnet_v2_l 1e-04 1 0.1 32 0.1 0 0 16 16 0

efficientnet_v2_l 5e-05 1 0.25 32 0.2 0 0 16 16 0
efficientnet_v2_l 5e-05 1 0.25 32 0.2 1 0 16 16 0
efficientnet_v2_l 5e-05 1 0.25 32 0.1 1 0 16 16 0
efficientnet_v2_l 5e-05 1 0.25 32 0.1 0 0 16 16 0

efficientnet_v2_l 5e-05 1 0.1 32 0.2 0 0 16 16 0
efficientnet_v2_l 5e-05 1 0.1 32 0.2 1 0 16 16 0
efficientnet_v2_l 5e-05 1 0.1 32 0.1 1 0 16 16 0
efficientnet_v2_l 5e-05 1 0.1 32 0.1 0 0 16 16 0

efficientnet_v2_l 1e-05 1 0.25 32 0.2 0 0 16 16 0
efficientnet_v2_l 1e-05 1 0.25 32 0.2 1 0 16 16 0
efficientnet_v2_l 1e-05 1 0.25 32 0.1 1 0 16 16 0
efficientnet_v2_l 1e-05 1 0.25 32 0.1 0 0 16 16 0

efficientnet_v2_l 1e-05 1 0.1 32 0.2 0 0 16 16 0
efficientnet_v2_l 1e-05 1 0.1 32 0.2 1 0 16 16 0
efficientnet_v2_l 1e-05 1 0.1 32 0.1 1 0 16 16 0
efficientnet_v2_l 1e-05 1 0.1 32 0.1 0 0 16 16 0



8 changes: 4 additions & 4 deletions neural_networks/train_nn.job
@@ -11,7 +11,7 @@ cd ~/projects/lofar_helpers/neural_networks


module load 2023
source venv/bin/activate
source ../../lofar_venv/bin/activate

# Read the parameter file
PARAM_FILE=parameters.txt
@@ -22,7 +22,7 @@ SLURM_ARRAY_TASK_ID=${SLURM_ARRAY_TASK_ID:=1}
PARAMS=$(sed -n "${SLURM_ARRAY_TASK_ID}p" $PARAM_FILE)

# Parse the parameters
read model lr normalize dropout_p batch_size use_lora label_smoothing stochastic_smoothing rank alpha resize <<< $PARAMS
read model lr normalize dropout_p batch_size label_smoothing stochastic_smoothing use_lora rank alpha resize <<< $PARAMS

if [ "$use_lora" -eq 1 ]; then
LORA_ARG="--use_lora"
@@ -36,7 +36,7 @@ else
STOCHASTIC_SMOOTHING=""
fi

DATA_TRAINDATA_PATH="public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/"
DATA_TRAINDATA_PATH="/scratch-shared/CORTEX/public.spider.surfsara.nl/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data/"

# Execute your Python script with the given parameters
python train_nn.py $DATA_TRAINDATA_PATH --model $model --lr $lr --normalize $normalize --dropout_p $dropout_p --batch_size $batch_size --log_path grid_search_2 --label_smoothing $label_smoothing --rank $rank --resize $resize --alpha $alpha $LORA_ARG $STOCHASTIC_SMOOTHING
python train_nn.py $DATA_TRAINDATA_PATH --model $model --lr $lr --normalize $normalize --dropout_p $dropout_p --batch_size $batch_size --log_path grid_search_2 --label_smoothing $label_smoothing --rank $rank --resize $resize --alpha $alpha $LORA_ARG $STOCHASTIC_SMOOTHING -d
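
To make the reordered read statement concrete, the sketch below shows how one sample row from parameters.txt (copied from the file above) maps onto the variables under the new column order; this is an illustration, not part of the commit.

# Sketch: mapping one parameters.txt row onto the fields read by train_nn.job.
# The column order follows the new read statement (label_smoothing and
# stochastic_smoothing now come before use_lora).
row = "dinov2_vitl14_reg 1e-04 1 0.25 32 0.1 0 0 16 16 560"
fields = [
    "model", "lr", "normalize", "dropout_p", "batch_size",
    "label_smoothing", "stochastic_smoothing", "use_lora",
    "rank", "alpha", "resize",
]
params = dict(zip(fields, row.split()))
print(params["model"], params["lr"], params["use_lora"], params["resize"])
# dinov2_vitl14_reg 1e-04 0 560
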
62 changes: 43 additions & 19 deletions neural_networks/train_nn.py
@@ -456,7 +456,11 @@ def main(
train_dataloader=train_dataloader,
optimizer=optimizer,
logging_interval=logging_interval,
smoothing_fn=partial(label_smoother, stochastic=stochastic_smoothing, smoothing_factor=label_smoothing),
smoothing_fn=partial(
label_smoother,
stochastic=stochastic_smoothing,
smoothing_factor=label_smoothing,
),
)
val_step_f = partial(val_step_f, val_dataloader=val_dataloader)

@@ -465,8 +469,21 @@
logging_dir=logging_dir,
model=model,
optimizer=optimizer,
normalize=normalize,
batch_size=batch_size,
args={
"normalize": normalize,
"batch_size": batch_size,
"use_compile": use_compile,
"label_smoothing": label_smoothing,
"stochastic_smoothing": stochastic_smoothing,
"lift": lift,
"use_lora": use_lora,
"rank": rank,
"alpha": alpha,
"resize": resize,
"lr": lr,
"dropout_p": dropout_p,
"model_name": model_name,
},
)

best_val_loss = torch.inf
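
For orientation, a rough sketch of the checkpoint layout implied by the new args dictionary; the keys are inferred from what load_checkpoint reads further below ("model", "model_state_dict", "optimizer_state_dict", "args"), and the stand-in model, values, and filename are illustrative only.

# Assumed on-disk checkpoint layout, inferred from the keys load_checkpoint reads.
# The Linear stand-in, the args values, and the filename are illustrative only.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
checkpoint = {
    "model": type(model),                    # class object, re-instantiated on load
    "model_state_dict": model.state_dict(),
    "optimizer": type(optimizer),            # e.g. torch.optim.AdamW
    "optimizer_state_dict": optimizer.state_dict(),
    "args": {"model_name": "dinov2_vitl14_reg", "lr": 1e-4, "dropout_p": 0.25, "resize": 560},
}
torch.save(checkpoint, "ckpt_example.pth")
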
@@ -555,13 +572,16 @@ def val_step(model, val_dataloader, global_step, metrics_logger, prepare_data_f)
return mean_loss, logits, targets


def label_smoother(labels: torch.tensor, smoothing_factor: float = 0.1, stochastic: bool = True):
def label_smoother(
labels: torch.tensor, smoothing_factor: float = 0.1, stochastic: bool = True
):
smoothing_factor = smoothing_factor - (
torch.rand_like(labels) * smoothing_factor * stochastic
torch.rand_like(labels) * smoothing_factor * stochastic
)
smoothed_label = (1 - smoothing_factor) * labels + 0.5 * smoothing_factor
return smoothed_label
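
A quick numeric check of the smoothing behaviour, for illustration only: with stochastic=False the full factor is applied and hard labels are pulled towards 0.5, while stochastic=True draws a per-label factor uniformly between 0 and smoothing_factor.

# Illustration (not part of the commit); assumes the module's torch import.
labels = torch.tensor([0.0, 1.0])

# Deterministic smoothing: 0.9 * label + 0.5 * 0.1, i.e. tensor([0.0500, 0.9500]).
print(label_smoother(labels, smoothing_factor=0.1, stochastic=False))

# Stochastic smoothing: each label gets its own factor in [0, 0.1], so the results
# land between the hard labels and the deterministic values above.
print(label_smoother(labels, smoothing_factor=0.1, stochastic=True))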


def train_step(
model,
optimizer,
@@ -676,33 +696,37 @@ def save_checkpoint(logging_dir, model, optimizer, global_step, **kwargs)
)


def load_checkpoint(ckpt_path):

ckpt_dict = torch.load(ckpt_path, weights_only=False)

# ugh, this is so ugly, something something hindsight something something 20-20
# FIXME: probably should do a pattern match, but this works for now
kwargs = str(Path(ckpt_path).parent).split("/")[-1].split("__")
def load_checkpoint(ckpt_path, device="gpu"):
if os.path.isfile(ckpt_path):
ckpt_dict = torch.load(ckpt_path, weights_only=False, map_location=device)
else:
files = os.listdir(ckpt_path)
possible_checkpoints = list(filter(lambda x: x.endswith(".pth"), files))
if len(possible_checkpoints) != 1:
raise ValueError(
f"Too many checkpoint files in the given checkpoint directory. Please specify the model you want to load directly."
)
ckpt_path = f"{ckpt_path}/{possible_checkpoints[0]}"
ckpt_dict = torch.load(ckpt_path, weights_only=False, map_location=device)

# strip 'model_' from the name
model_name = kwargs[1][6:]
lr = float(kwargs[2].split("_")[-1])
normalize = int(kwargs[3].split("_")[-1])
dropout_p = float(kwargs[4].split("_")[-1])
model_name = ckpt_dict["args"]["model_name"]
lr = ckpt_dict["args"]["lr"]
dropout_p = ckpt_dict["args"]["dropout_p"]

model = ckpt_dict["model"](model_name=model_name, dropout_p=dropout_p)
model = ckpt_dict["model"](model_name=model_name, dropout_p=dropout_p).to(device)
model.load_state_dict(ckpt_dict["model_state_dict"])

try:
# FIXME: add optim class and args to state dict
optim = ckpt_dict.get("optimizer", torch.optim.AdamW)(
lr=lr, params=model.classifier.parameters()
).load_state_dict(ckpt_dict["optimizer_state_dict"])
except e:
except Exception as e:
print(f"Could not load optim due to {e}; skipping.")
optim = None

return {"model": model, "optim": optim, "normalize": normalize}
return {"model": model, "optim": optim, "args": ckpt_dict["args"]}


def get_argparser():
