Skip to content

Commit

Permalink
changes to model checkpointing/loading
Browse files — browse the repository at this point in the history
  • Loading branch information
LVeefkind committed Oct 21, 2024
1 parent 4a12696 commit 95340f7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 33 deletions.
32 changes: 24 additions & 8 deletions neural_networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import functools

import torch
from torch.nn.functional import interpolate
import os

from cortexchange.architecture import Architecture
import __main__
from astropy.io import fits

from .train_nn import ImagenetTransferLearning, load_checkpoint # noqa
from .pre_processing_for_ml import normalize_fits
from train_nn import ImagenetTransferLearning, load_checkpoint # noqa
from pre_processing_for_ml import normalize_fits

setattr(__main__, "ImagenetTransferLearning", ImagenetTransferLearning)

Expand All @@ -30,7 +32,7 @@ def __init__(
):
super().__init__(model_name, device)

self.dtype = torch.float32
self.dtype = torch.bfloat16

self.model = self.model.to(self.dtype)
self.model.eval()
Expand All @@ -39,27 +41,41 @@ def __init__(
self.variational_dropout = variational_dropout

def load_checkpoint(self, path) -> torch.nn.Module:
    """Load a trained model from *path* and record its preprocessing args.

    Side effects: sets ``self.resize`` and ``self.lift`` from the
    checkpoint's stored training arguments, and may set the
    ``XFORMERS_DISABLED`` environment variable when not running on a GPU.

    Returns the loaded ``torch.nn.Module``.
    """
    # Disable xformers when not on a GPU to avoid CPU-only import/runtime
    # errors. Match "cuda" as a substring: device strings like "cuda:0"
    # are GPUs too, but would fail an exact `!= "cuda"` comparison and
    # wrongly disable xformers.
    # NOTE(review): setting this env var only helps if xformers has not
    # been imported yet — confirm import order in the entry point.
    if "gpu" not in self.device and "cuda" not in self.device:
        os.environ["XFORMERS_DISABLED"] = "1"
    checkpoint = load_checkpoint(path, self.device)
    # Index by key instead of positionally unpacking .values(): the
    # dict's key order is an implementation detail of the loader and
    # silently breaks the unpack if it changes.
    model = checkpoint["model"]
    args = checkpoint["args"]
    self.resize = args["resize"]
    self.lift = args["lift"]
    return model

def prepare_data(self, input_path: str) -> torch.Tensor:
    """Read a FITS file and return a model-ready 4D tensor on ``self.device``.

    Results for the most recent path are cached per instance. This replaces
    the previous ``functools.lru_cache(maxsize=1)`` decorator: lru_cache on
    an instance method keys on ``self`` and keeps the instance alive for the
    cache's lifetime (ruff B019); a per-instance single-entry cache has the
    same effective behavior without the leak.
    """
    cached = getattr(self, "_prepare_data_cache", None)
    if cached is not None and cached[0] == input_path:
        return cached[1]

    input_data: torch.Tensor = torch.from_numpy(process_fits(input_path))
    input_data = input_data.to(self.dtype)
    # assumes process_fits yields an (H, W, C)-like array so that
    # swapdims(0, 2) produces channels-first — TODO confirm against
    # process_fits' output layout.
    input_data = input_data.swapdims(0, 2).unsqueeze(0)
    # self.resize comes from the checkpoint args (set in load_checkpoint);
    # 0 means "no resizing was used during training".
    if self.resize != 0:
        input_data = interpolate(
            input_data, size=self.resize, mode="bilinear", align_corners=False
        )
    input_data = input_data.to(self.device)

    self._prepare_data_cache = (input_path, input_data)
    return input_data

@torch.no_grad()
def predict(self, data: torch.Tensor):
with torch.autocast(dtype=self.dtype, device_type=self.device):
if self.variational_dropout > 0:
self.model.feature_extractor.eval()
self.model.classifier.train()
self.model.train()
# self.model.classifier.train()

predictions = torch.concat(
[
torch.sigmoid(self.model(data)).clone()
for _ in range(self.variational_dropout)
for _ in range(max(self.variational_dropout, 1))
],
dim=1,
)
Expand All @@ -75,6 +91,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--variational_dropout",
type=int,
default=None,
default=0,
help="Optional: Amount of times to run the model to obtain a variational estimate of the stdev",
)
47 changes: 22 additions & 25 deletions neural_networks/train_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import numpy as np
import random

from pre_processing_for_ml import FitsDataset
from dino_model import DINOV2FeatureExtractor
from .pre_processing_for_ml import FitsDataset
from .dino_model import DINOV2FeatureExtractor

PROFILE = False
SEED = None
Expand Down Expand Up @@ -469,18 +469,21 @@ def main(
logging_dir=logging_dir,
model=model,
optimizer=optimizer,
normalize=normalize,
batch_size=batch_size,
use_compile=use_compile,
label_smoothing=label_smoothing,
stochastic_smoothing=stochastic_smoothing,
lift=lift,
use_lora=use_lora,
rank=rank,
alpha=alpha,
resize=resize,
lr=lr,
dropout_p=dropout_p,
args={
"normalize": normalize,
"batch_size": batch_size,
"use_compile": use_compile,
"label_smoothing": label_smoothing,
"stochastic_smoothing": stochastic_smoothing,
"lift": lift,
"use_lora": use_lora,
"rank": rank,
"alpha": alpha,
"resize": resize,
"lr": lr,
"dropout_p": dropout_p,
"model_name": model_name,
},
)

best_val_loss = torch.inf
Expand Down Expand Up @@ -706,18 +709,12 @@ def load_checkpoint(ckpt_path, device="gpu"):
ckpt_path = f"{ckpt_path}/{possible_checkpoints[0]}"
ckpt_dict = torch.load(ckpt_path, weights_only=False, map_location=device)

# ugh, this is so ugly, something something hindsight something something 20-20
# FIXME: probably should do a pattern match, but this works for now
kwargs = str(Path(ckpt_path).parent).split("/")[-1].split("__")
print(ckpt_dict.keys())

# strip 'model_' from the name
model_name = kwargs[1][6:]
lr = float(kwargs[2].split("_")[-1])
normalize = int(kwargs[3].split("_")[-1])
dropout_p = float(kwargs[4].split("_")[-1])
model_name = ckpt_dict["args"]["model_name"]
lr = ckpt_dict["args"]["lr"]
dropout_p = ckpt_dict["args"]["dropout_p"]

model = ckpt_dict["model"](model_name=model_name, dropout_p=dropout_p)
model = ckpt_dict["model"](model_name=model_name, dropout_p=dropout_p).to(device)
model.load_state_dict(ckpt_dict["model_state_dict"])

try:
Expand All @@ -729,7 +726,7 @@ def load_checkpoint(ckpt_path, device="gpu"):
print(f"Could not load optim due to {e}; skipping.")
optim = None

return {"model": model, "optim": optim, "normalize": normalize}
return {"model": model, "optim": optim, "args": ckpt_dict["args"]}


def get_argparser():
Expand Down

0 comments on commit 95340f7

Please sign in to comment.