Commit

Organised entry point to allow for HMI masking

dead-water committed Apr 15, 2024
1 parent 435994a commit dafde4a
Showing 4 changed files with 194 additions and 150 deletions.
82 changes: 82 additions & 0 deletions scripts/finetune.py
@@ -0,0 +1,82 @@
# Main fine-tuning and evaluation script for SDO-FM

import os
from pathlib import Path

import pytorch_lightning as pl
import torch
import wandb

from sdofm import utils
from sdofm.datasets import DimmedSDOMLDataModule
from sdofm.finetuning import Autocalibration


class Finetuner(object):
    def __init__(self, cfg, logger=None):
        self.cfg = cfg
        self.logger = logger  # main.py passes a WandbLogger; used by run()
        self.trainer = None
        self.data_module = None
        self.model = None

match cfg.experiment.model:
case "autocalibration":
self.data_module = DimmedSDOMLDataModule(
hmi_path=None,
aia_path=os.path.join(
self.cfg.data.sdoml.base_directory,
self.cfg.data.sdoml.sub_directory.aia,
),
eve_path=None,
components=self.cfg.data.sdoml.components,
wavelengths=self.cfg.data.sdoml.wavelengths,
ions=self.cfg.data.sdoml.ions,
frequency=self.cfg.data.sdoml.frequency,
batch_size=self.cfg.model.opt.batch_size,
num_workers=self.cfg.data.num_workers,
val_months=self.cfg.data.month_splits.val,
test_months=self.cfg.data.month_splits.test,
holdout_months=self.cfg.data.month_splits.holdout,
cache_dir=os.path.join(
self.cfg.data.sdoml.base_directory,
self.cfg.data.sdoml.sub_directory.cache,
),
)
self.data_module.setup()

self.model = Autocalibration(
**self.cfg.model.mae,
optimiser=self.cfg.model.opt.optimiser,
lr=self.cfg.model.opt.learning_rate,
weight_decay=self.cfg.model.opt.weight_decay,
)
case _:
raise NotImplementedError(
f"Model {cfg.experiment.model} not implemented"
)

def run(self):
print("\nFINE TUNING\n")

if self.cfg.experiment.distributed:
trainer = pl.Trainer(
devices=self.cfg.experiment.distributed.world_size,
accelerator=self.cfg.experiment.accelerator,
max_epochs=self.cfg.model.opt.epochs,
precision=self.cfg.experiment.precision,
logger=self.logger,
)
else:
trainer = pl.Trainer(
accelerator=self.cfg.experiment.accelerator,
max_epochs=self.cfg.model.opt.epochs,
logger=self.logger,
)
        self.trainer = trainer  # keep a handle for evaluate()/test_sdofm()
        trainer.fit(model=self.model, datamodule=self.data_module)
        return trainer

    def evaluate(self):
        # pl.Trainer has no evaluate(); run validation on the fitted model
        self.trainer.validate(model=self.model, datamodule=self.data_module)

def test_sdofm(self):
self.trainer.test(ckpt_path="best")
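
For orientation, here is a minimal sketch of driving the new Finetuner directly, outside Hydra. Every value below is a placeholder assumption (the project's real defaults live in its config files), and it will only actually train if the SDOML data exists at the given paths:

# Hypothetical driver for scripts/finetune.py -- illustrative values only.
from omegaconf import OmegaConf

from scripts.finetune import Finetuner

cfg = OmegaConf.create(
    {
        "experiment": {
            "model": "autocalibration",
            "distributed": False,
            "accelerator": "auto",
        },
        "model": {
            "mae": {},  # backbone kwargs forwarded to Autocalibration
            "opt": {
                "optimiser": "adam",
                "learning_rate": 1e-4,
                "weight_decay": 0.0,
                "batch_size": 16,
                "epochs": 10,
            },
        },
        "data": {
            "num_workers": 4,
            "month_splits": {"val": [10], "test": [11], "holdout": [12]},
            "sdoml": {
                "base_directory": "/mnt/sdoml",
                "sub_directory": {"aia": "AIA", "cache": ".cache"},
                "components": None,
                "wavelengths": ["171A"],
                "ions": None,
                "frequency": "12min",
            },
        },
    }
)

finetuner = Finetuner(cfg, logger=None)  # or pass a WandbLogger
trainer = finetuner.run()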
74 changes: 0 additions & 74 deletions scripts/finetune_autocalibration.py

This file was deleted.

22 changes: 8 additions & 14 deletions scripts/main.py
@@ -79,23 +79,17 @@ def main(cfg: DictConfig) -> None:
         config=flatten_dict(cfg),
     )
 
-    match cfg.experiment.model:
-        case "mae":
+    match cfg.experiment.task:
+        case "pretrain":
             from scripts.pretrain import Pretrainer
 
-            match cfg.experiment.task:
-                case "train":
-                    pretrainer = Pretrainer(cfg, logger=wandb_logger)
-                    pretrainer.run()
-
-        case "autocalibration":
-            from scripts.pretrain import AutocalibrationFinetuner
-
-            match cfg.experiment.task:
-                case "train":
-                    finetuner = AutocalibrationFinetuner(cfg, logger=wandb_logger)
-                    finetuner.run()
+            pretrainer = Pretrainer(cfg, logger=wandb_logger)
+            pretrainer.run()
+        case "finetune":
+            from scripts.finetune import Finetuner
+            finetuner = Finetuner(cfg, logger=wandb_logger)
+            finetuner.run()
         case _:
             raise NotImplementedError(
                 f"Experiment {cfg.experiment.task} not implemented"
             )
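
The dispatch is now task-first: cfg.experiment.task selects the entry point (Pretrainer or Finetuner), and each class picks the model internally from cfg.experiment.model. Assuming main() is exposed as a standard Hydra entry point (an assumption; the decorator sits outside this hunk), the two paths would be selected with overrides along the lines of:

    python scripts/main.py experiment.task=pretrain experiment.model=mae
    python scripts/main.py experiment.task=finetune experiment.model=autocalibration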
166 changes: 104 additions & 62 deletions scripts/pretrain.py
@@ -1,76 +1,118 @@
-# Main pre-training and evaluation script for SDO-FM
+# Main pretraining and evaluation script for SDO-FM
 
 import os
 from pathlib import Path
 
 import pytorch_lightning as pl
 import torch
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import LambdaCallback, ModelCheckpoint
-from pytorch_lightning.loggers.wandb import WandbLogger
-
 import wandb
-from sdofm.callback import ImagePredictionLoggerHMI
-from sdofm.datasets import ZarrIrradianceDataModuleHMI
-from sdofm.models.mae import HybridIrradianceModel
-from sdofm.utils import flatten_dict
+
+from sdofm import utils
+from sdofm.datasets import SDOMLDataModule
+from sdofm.pretraining import MAE, NVAE
 
 
-def pretrain_sdofm(cfg):
-    print("SDO-FM Model Pre-training")
-
-    # set precision of torch tensors
-    if cfg.experiment.precision == 64:
-        torch.set_default_tensor_type(torch.DoubleTensor)
-    elif cfg.experiment.precision == 32:
-        torch.set_default_tensor_type(torch.FloatTensor)
-    else:
-        raise NotImplementedError(
-            f"Precision {cfg.experiment.precision} not implemented"
-        )
-
-    output_dir = Path(cfg.data.output_directory)
-    output_dir.mkdir(exist_ok=True, parents=True)
-    print(f"Created directory for storing results: {cfg.data.output_directory}")
-    cache_dir = Path(f"{cfg.data.output_directory}/.cache")
-    cache_dir.mkdir(exist_ok=True, parents=True)
-    os.environ["WANDB_CACHE_DIR"] = f"{cfg.data.output_directory}/.cache"
-    os.environ["WANDB_MODE"] = "offline" if cfg.experiment.disable_wandb else "online"
-
-    wandb_logger = WandbLogger(
-        # WandbLogger params
-        name=cfg.experiment.name,
-        project=cfg.experiment.project,
-        dir=cfg.data.output_directory,
-        # kwargs for wandb.init
-        tags=cfg.experiment.wandb.tags,
-        notes=cfg.experimentw.wandb.notes,
-        group=cfg.experiment.wandb_group,
-        save_code=True,
-        job_type=cfg.experiment.job_type,
-        config=flatten_dict(cfg),
-    )
-
-    data_loader = ZarrIrradianceDataModuleHMI(
-        hmi_path=os.path.join(
-            cfg.data.sdoml.base_directory, cfg.data.sdoml.instrument_sub_directory.hmi
-        ),
-        aia_path=os.path.join(
-            cfg.data.sdoml.base_directory, cfg.data.sdoml.instrument_sub_directory.hmi
-        ),
-        eve_path=os.path.join(
-            cfg.data.sdoml.base_directory, cfg.data.sdoml.instrument_sub_directory.eve
-        ),
-        components=cfg.data.sdoml.components,
-        wavelengths=cfg.data.sdoml.wavelengths,
-        # ions=run_config["sci_parameters"]["eve_ions"],
-        # frequency=run_config["sci_parameters"]["frequency"],
-        batch_size=cfg.model.opt.batch_size,
-        num_workers=cfg.data.num_workers,
-        # val_months=run_config["training_parameters"]["val_months"],
-        # test_months=run_config["training_parameters"]["test_months"],
-        # holdout_months=run_config["training_parameters"]["holdout_months"],
-        cache_dir=cfg.data.sdoml.metadata,
-    )
-
-    data_loader.setup()
+class Pretrainer(object):
+    def __init__(self, cfg, logger):
+        self.cfg = cfg
+        self.logger = logger
+        self.trainer = None
+        self.data_module = None
+        self.model = None
+
+        match cfg.experiment.model:
+            case "mae":
+                self.data_module = SDOMLDataModule(
+                    # hmi_path=os.path.join(
+                    #     self.cfg.data.sdoml.base_directory, self.cfg.data.sdoml.sub_directory.hmi
+                    # ),
+                    hmi_path=None,
+                    aia_path=os.path.join(
+                        self.cfg.data.sdoml.base_directory,
+                        self.cfg.data.sdoml.sub_directory.aia,
+                    ),
+                    eve_path=None,
+                    components=self.cfg.data.sdoml.components,
+                    wavelengths=self.cfg.data.sdoml.wavelengths,
+                    ions=self.cfg.data.sdoml.ions,
+                    frequency=self.cfg.data.sdoml.frequency,
+                    batch_size=self.cfg.model.opt.batch_size,
+                    num_workers=self.cfg.data.num_workers,
+                    val_months=self.cfg.data.month_splits.val,
+                    test_months=self.cfg.data.month_splits.test,
+                    holdout_months=self.cfg.data.month_splits.holdout,
+                    cache_dir=os.path.join(
+                        self.cfg.data.sdoml.base_directory,
+                        self.cfg.data.sdoml.sub_directory.cache,
+                    ),
+                )
+                self.data_module.setup()
+
+                self.model = MAE(
+                    **cfg.model.mae,
+                    optimiser=cfg.model.opt.optimiser,
+                    lr=cfg.model.opt.learning_rate,
+                    weight_decay=cfg.model.opt.weight_decay,
+                )
+            case "nvae":
+                self.data_module = SDOMLDataModule(
+                    hmi_path=os.path.join(
+                        self.cfg.data.sdoml.base_directory,
+                        self.cfg.data.sdoml.sub_directory.hmi,
+                    ),
+                    aia_path=os.path.join(
+                        self.cfg.data.sdoml.base_directory,
+                        self.cfg.data.sdoml.sub_directory.aia,
+                    ),
+                    eve_path=None,
+                    components=self.cfg.data.sdoml.components,
+                    wavelengths=self.cfg.data.sdoml.wavelengths,
+                    ions=self.cfg.data.sdoml.ions,
+                    frequency=self.cfg.data.sdoml.frequency,
+                    batch_size=self.cfg.model.opt.batch_size,
+                    num_workers=self.cfg.data.num_workers,
+                    val_months=self.cfg.data.month_splits.val,
+                    test_months=self.cfg.data.month_splits.test,
+                    holdout_months=self.cfg.data.month_splits.holdout,
+                    cache_dir=os.path.join(
+                        self.cfg.data.sdoml.base_directory,
+                        self.cfg.data.sdoml.sub_directory.cache,
+                    ),
+                )
+                # mirror the "mae" branch so hmi_mask is materialised
+                self.data_module.setup()
+
+                self.model = NVAE(
+                    **cfg.model.nvae,
+                    optimiser=cfg.model.opt.optimiser,
+                    lr=cfg.model.opt.learning_rate,
+                    weight_decay=cfg.model.opt.weight_decay,
+                    hmi_mask=self.data_module.hmi_mask,
+                )
+            case _:
+                raise NotImplementedError(
+                    f"Model {cfg.experiment.model} not implemented"
+                )
+
+    def run(self):
+        print("\nPRE-TRAINING\n")
+
+        if self.cfg.experiment.distributed:
+            trainer = pl.Trainer(
+                devices=self.cfg.experiment.distributed.world_size,
+                accelerator=self.cfg.experiment.accelerator,
+                max_epochs=self.cfg.model.opt.epochs,
+                precision=self.cfg.experiment.precision,
+                logger=self.logger,
+            )
+        else:
+            trainer = pl.Trainer(
+                accelerator=self.cfg.experiment.accelerator,
+                max_epochs=self.cfg.model.opt.epochs,
+                logger=self.logger,
+            )
+        self.trainer = trainer  # keep a handle for evaluate()/test_sdofm()
+        trainer.fit(model=self.model, datamodule=self.data_module)
+        return trainer
+
+    def evaluate(self):
+        # pl.Trainer has no evaluate(); run validation on the fitted model
+        self.trainer.validate(model=self.model, datamodule=self.data_module)
+
+    def test_sdofm(self):
+        self.trainer.test(ckpt_path="best")
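
The commit message's stated purpose is to allow for HMI masking: the "nvae" branch above threads self.data_module.hmi_mask into the model. How the mask is consumed is not shown in this diff; purely as an illustration, a hypothetical helper (not part of sdofm) that drops off-disk HMI pixels before a reconstruction loss might look like:

# Illustrative only -- the real use of hmi_mask inside sdofm may differ.
import torch


def masked_mse(pred: torch.Tensor, target: torch.Tensor, hmi_mask: torch.Tensor) -> torch.Tensor:
    # hmi_mask: (H, W) bool, True where pixels are valid (on the solar disk)
    mask = hmi_mask.to(pred.device).expand_as(pred)
    diff = (pred - target)[mask]  # keep only valid pixels
    return (diff**2).mean()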
