diff --git a/neural_networks/__init__.py b/neural_networks/__init__.py
index b857ec6..c8e1eef 100644
--- a/neural_networks/__init__.py
+++ b/neural_networks/__init__.py
@@ -2,13 +2,15 @@
 import functools
 
 import torch
+from torch.nn.functional import interpolate
+import os
 from cortexchange.architecture import Architecture
 import __main__
 from astropy.io import fits
 
-from .train_nn import ImagenetTransferLearning, load_checkpoint  # noqa
-from .pre_processing_for_ml import normalize_fits
+from train_nn import ImagenetTransferLearning, load_checkpoint  # noqa
+from pre_processing_for_ml import normalize_fits
 
 setattr(__main__, "ImagenetTransferLearning", ImagenetTransferLearning)
 
 
@@ -30,7 +32,7 @@ def __init__(
     ):
         super().__init__(model_name, device)
 
-        self.dtype = torch.float32
+        self.dtype = torch.bfloat16
         self.model = self.model.to(self.dtype)
         self.model.eval()
 
@@ -39,7 +41,16 @@ def __init__(
         self.variational_dropout = variational_dropout
 
     def load_checkpoint(self, path) -> torch.nn.Module:
-        model, _, _, resize = load_checkpoint(path, self.device).values()
+        # To avoid errors on CPU
+        if "gpu" not in self.device and self.device != "cuda":
+            os.environ["XFORMERS_DISABLED"] = "1"
+        (
+            model,
+            _,
+            args,
+        ) = load_checkpoint(path, self.device).values()
+        self.resize = args["resize"]
+        self.lift = args["lift"]
         return model
 
     @functools.lru_cache(maxsize=1)
@@ -47,19 +58,24 @@ def prepare_data(self, input_path: str) -> torch.Tensor:
         input_data: torch.Tensor = torch.from_numpy(process_fits(input_path))
         input_data = input_data.to(self.dtype)
         input_data = input_data.swapdims(0, 2).unsqueeze(0)
+        if self.resize != 0:
+            input_data = interpolate(
+                input_data, size=self.resize, mode="bilinear", align_corners=False
+            )
+        input_data = input_data.to(self.device)
         return input_data
 
     @torch.no_grad()
     def predict(self, data: torch.Tensor):
         with torch.autocast(dtype=self.dtype, device_type=self.device):
             if self.variational_dropout > 0:
-                self.model.feature_extractor.eval()
-                self.model.classifier.train()
+                self.model.train()
+                # self.model.classifier.train()
 
             predictions = torch.concat(
                 [
                     torch.sigmoid(self.model(data)).clone()
-                    for _ in range(self.variational_dropout)
+                    for _ in range(max(self.variational_dropout, 1))
                 ],
                 dim=1,
             )
@@ -75,6 +91,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--variational_dropout",
         type=int,
-        default=None,
+        default=0,
         help="Optional: Amount of times to run the model to obtain a variational estimate of the stdev",
     )
diff --git a/neural_networks/train_nn.py b/neural_networks/train_nn.py
index 0a7321f..cb90f45 100644
--- a/neural_networks/train_nn.py
+++ b/neural_networks/train_nn.py
@@ -18,8 +18,8 @@
 import numpy as np
 import random
 
-from pre_processing_for_ml import FitsDataset
-from dino_model import DINOV2FeatureExtractor
+from .pre_processing_for_ml import FitsDataset
+from .dino_model import DINOV2FeatureExtractor
 
 PROFILE = False
 SEED = None
@@ -469,18 +469,21 @@ def main(
         logging_dir=logging_dir,
         model=model,
         optimizer=optimizer,
-        normalize=normalize,
-        batch_size=batch_size,
-        use_compile=use_compile,
-        label_smoothing=label_smoothing,
-        stochastic_smoothing=stochastic_smoothing,
-        lift=lift,
-        use_lora=use_lora,
-        rank=rank,
-        alpha=alpha,
-        resize=resize,
-        lr=lr,
-        dropout_p=dropout_p,
+        args={
+            "normalize": normalize,
+            "batch_size": batch_size,
+            "use_compile": use_compile,
+            "label_smoothing": label_smoothing,
+            "stochastic_smoothing": stochastic_smoothing,
+            "lift": lift,
+            "use_lora": use_lora,
+            "rank": rank,
+            "alpha": alpha,
+            "resize": resize,
+            "lr": lr,
+            "dropout_p": dropout_p,
+            "model_name": model_name,
+        },
     )
 
     best_val_loss = torch.inf
@@ -706,18 +709,12 @@ def load_checkpoint(ckpt_path, device="gpu"):
         ckpt_path = f"{ckpt_path}/{possible_checkpoints[0]}"
 
     ckpt_dict = torch.load(ckpt_path, weights_only=False, map_location=device)
 
-    # ugh, this is so ugly, something something hindsight something something 20-20
-    # FIXME: probably should do a pattern match, but this works for now
-    kwargs = str(Path(ckpt_path).parent).split("/")[-1].split("__")
-    print(ckpt_dict.keys())
-    # strip 'model_' from the name
-    model_name = kwargs[1][6:]
-    lr = float(kwargs[2].split("_")[-1])
-    normalize = int(kwargs[3].split("_")[-1])
-    dropout_p = float(kwargs[4].split("_")[-1])
+    model_name = ckpt_dict["args"]["model_name"]
+    lr = ckpt_dict["args"]["lr"]
+    dropout_p = ckpt_dict["args"]["dropout_p"]
 
-    model = ckpt_dict["model"](model_name=model_name, dropout_p=dropout_p)
+    model = ckpt_dict["model"](model_name=model_name, dropout_p=dropout_p).to(device)
     model.load_state_dict(ckpt_dict["model_state_dict"])
 
     try:
@@ -729,7 +726,7 @@ def load_checkpoint(ckpt_path, device="gpu"):
         print(f"Could not load optim due to {e}; skipping.")
         optim = None
 
-    return {"model": model, "optim": optim, "normalize": normalize}
+    return {"model": model, "optim": optim, "args": ckpt_dict["args"]}
 
 
 def get_argparser():