-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Organised entry point to allow for HMI masking
- Loading branch information
1 parent
435994a
commit dafde4a
Showing
4 changed files
with
194 additions
and
150 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Main pretraining and evaluation script for SDO-FM | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
import pytorch_lightning as pl | ||
import torch | ||
import wandb | ||
|
||
from sdofm import utils | ||
from sdofm.datasets import DimmedSDOMLDataModule | ||
from sdofm.finetuning import Autocalibration | ||
|
||
|
||
class Finetuner(object): | ||
def __init__(self, cfg): | ||
self.cfg = cfg | ||
self.trainer = None | ||
self.data_module = None | ||
self.model = None | ||
|
||
match cfg.experiment.model: | ||
case "autocalibration": | ||
self.data_module = DimmedSDOMLDataModule( | ||
hmi_path=None, | ||
aia_path=os.path.join( | ||
self.cfg.data.sdoml.base_directory, | ||
self.cfg.data.sdoml.sub_directory.aia, | ||
), | ||
eve_path=None, | ||
components=self.cfg.data.sdoml.components, | ||
wavelengths=self.cfg.data.sdoml.wavelengths, | ||
ions=self.cfg.data.sdoml.ions, | ||
frequency=self.cfg.data.sdoml.frequency, | ||
batch_size=self.cfg.model.opt.batch_size, | ||
num_workers=self.cfg.data.num_workers, | ||
val_months=self.cfg.data.month_splits.val, | ||
test_months=self.cfg.data.month_splits.test, | ||
holdout_months=self.cfg.data.month_splits.holdout, | ||
cache_dir=os.path.join( | ||
self.cfg.data.sdoml.base_directory, | ||
self.cfg.data.sdoml.sub_directory.cache, | ||
), | ||
) | ||
self.data_module.setup() | ||
|
||
self.model = Autocalibration( | ||
**self.cfg.model.mae, | ||
optimiser=self.cfg.model.opt.optimiser, | ||
lr=self.cfg.model.opt.learning_rate, | ||
weight_decay=self.cfg.model.opt.weight_decay, | ||
) | ||
case _: | ||
raise NotImplementedError( | ||
f"Model {cfg.experiment.model} not implemented" | ||
) | ||
|
||
def run(self): | ||
print("\nFINE TUNING\n") | ||
|
||
if self.cfg.experiment.distributed: | ||
trainer = pl.Trainer( | ||
devices=self.cfg.experiment.distributed.world_size, | ||
accelerator=self.cfg.experiment.accelerator, | ||
max_epochs=self.cfg.model.opt.epochs, | ||
precision=self.cfg.experiment.precision, | ||
logger=self.logger, | ||
) | ||
else: | ||
trainer = pl.Trainer( | ||
accelerator=self.cfg.experiment.accelerator, | ||
max_epochs=self.cfg.model.opt.epochs, | ||
logger=self.logger, | ||
) | ||
trainer.fit(model=self.model, datamodule=self.data_module) | ||
return trainer | ||
|
||
def evaluate(self): | ||
self.trainer.evaluate() | ||
|
||
def test_sdofm(self): | ||
self.trainer.test(ckpt_path="best") |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,76 +1,118 @@ | ||
# Main pre-training and evaluation script for SDO-FM | ||
# Main pretraining and evaluation script for SDO-FM | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
import pytorch_lightning as pl | ||
import torch | ||
from pytorch_lightning import Trainer | ||
from pytorch_lightning.callbacks import LambdaCallback, ModelCheckpoint | ||
from pytorch_lightning.loggers.wandb import WandbLogger | ||
|
||
import wandb | ||
from sdofm.callback import ImagePredictionLoggerHMI | ||
from sdofm.datasets import ZarrIrradianceDataModuleHMI | ||
from sdofm.models.mae import HybridIrradianceModel | ||
from sdofm.utils import flatten_dict | ||
|
||
from sdofm import utils | ||
from sdofm.datasets import SDOMLDataModule | ||
from sdofm.pretraining import MAE, NVAE | ||
|
||
|
||
class Pretrainer(object): | ||
def __init__(self, cfg, logger): | ||
self.cfg = cfg | ||
self.logger = logger | ||
self.data_module = None | ||
self.model = None | ||
|
||
match cfg.experiment.model: | ||
case "mae": | ||
self.data_module = SDOMLDataModule( | ||
# hmi_path=os.path.join( | ||
# self.cfg.data.sdoml.base_directory, self.cfg.data.sdoml.sub_directory.hmi | ||
# ), | ||
hmi_path=None, | ||
aia_path=os.path.join( | ||
self.cfg.data.sdoml.base_directory, | ||
self.cfg.data.sdoml.sub_directory.aia, | ||
), | ||
eve_path=None, | ||
components=self.cfg.data.sdoml.components, | ||
wavelengths=self.cfg.data.sdoml.wavelengths, | ||
ions=self.cfg.data.sdoml.ions, | ||
frequency=self.cfg.data.sdoml.frequency, | ||
batch_size=self.cfg.model.opt.batch_size, | ||
num_workers=self.cfg.data.num_workers, | ||
val_months=self.cfg.data.month_splits.val, | ||
test_months=self.cfg.data.month_splits.test, | ||
holdout_months=self.cfg.data.month_splits.holdout, | ||
cache_dir=os.path.join( | ||
self.cfg.data.sdoml.base_directory, | ||
self.cfg.data.sdoml.sub_directory.cache, | ||
), | ||
) | ||
self.data_module.setup() | ||
|
||
def pretrain_sdofm(cfg): | ||
print("SDO-FM Model Pre-training") | ||
self.model = MAE( | ||
**cfg.model.mae, | ||
optimiser=cfg.model.opt.optimiser, | ||
lr=cfg.model.opt.learning_rate, | ||
weight_decay=cfg.model.opt.weight_decay, | ||
) | ||
case "nvae": | ||
self.data_module = SDOMLDataModule( | ||
hmi_path=os.path.join( | ||
self.cfg.data.sdoml.base_directory, | ||
self.cfg.data.sdoml.sub_directory.hmi, | ||
), | ||
aia_path=os.path.join( | ||
self.cfg.data.sdoml.base_directory, | ||
self.cfg.data.sdoml.sub_directory.aia, | ||
), | ||
eve_path=None, | ||
components=self.cfg.data.sdoml.components, | ||
wavelengths=self.cfg.data.sdoml.wavelengths, | ||
ions=self.cfg.data.sdoml.ions, | ||
frequency=self.cfg.data.sdoml.frequency, | ||
batch_size=self.cfg.model.opt.batch_size, | ||
num_workers=self.cfg.data.num_workers, | ||
val_months=self.cfg.data.month_splits.val, | ||
test_months=self.cfg.data.month_splits.test, | ||
holdout_months=self.cfg.data.month_splits.holdout, | ||
cache_dir=os.path.join( | ||
self.cfg.data.sdoml.base_directory, | ||
self.cfg.data.sdoml.sub_directory.cache, | ||
), | ||
) | ||
|
||
# set precision of torch tensors | ||
if cfg.experiment.precision == 64: | ||
torch.set_default_tensor_type(torch.DoubleTensor) | ||
elif cfg.experiment.precision == 32: | ||
torch.set_default_tensor_type(torch.FloatTensor) | ||
else: | ||
raise NotImplementedError( | ||
f"Precision {cfg.experiment.precision} not implemented" | ||
) | ||
self.model = NVAE( | ||
**cfg.model.nvae, | ||
optimiser=cfg.model.opt.optimiser, | ||
lr=cfg.model.opt.learning_rate, | ||
weight_decay=cfg.model.opt.weight_decay, | ||
hmi_mask=self.data_module.hmi_mask, | ||
) | ||
case _: | ||
raise NotImplementedError( | ||
f"Model {cfg.experiment.model} not implemented" | ||
) | ||
|
||
output_dir = Path(cfg.data.output_directory) | ||
output_dir.mkdir(exist_ok=True, parents=True) | ||
print(f"Created directory for storing results: {cfg.data.output_directory}") | ||
cache_dir = Path(f"{cfg.data.output_directory}/.cache") | ||
cache_dir.mkdir(exist_ok=True, parents=True) | ||
os.environ["WANDB_CACHE_DIR"] = f"{cfg.data.output_directory}/.cache" | ||
os.environ["WANDB_MODE"] = "offline" if cfg.experiment.disable_wandb else "online" | ||
def run(self): | ||
print("\nPRE-TRAINING\n") | ||
|
||
wandb_logger = WandbLogger( | ||
# WandbLogger params | ||
name=cfg.experiment.name, | ||
project=cfg.experiment.project, | ||
dir=cfg.data.output_directory, | ||
# kwargs for wandb.init | ||
tags=cfg.experiment.wandb.tags, | ||
notes=cfg.experimentw.wandb.notes, | ||
group=cfg.experiment.wandb_group, | ||
save_code=True, | ||
job_type=cfg.experiment.job_type, | ||
config=flatten_dict(cfg), | ||
) | ||
if self.cfg.experiment.distributed: | ||
trainer = pl.Trainer( | ||
devices=self.cfg.experiment.distributed.world_size, | ||
accelerator=self.cfg.experiment.accelerator, | ||
max_epochs=self.cfg.model.opt.epochs, | ||
precision=self.cfg.experiment.precision, | ||
logger=self.logger, | ||
) | ||
else: | ||
trainer = pl.Trainer( | ||
accelerator=self.cfg.experiment.accelerator, | ||
max_epochs=self.cfg.model.opt.epochs, | ||
logger=self.logger, | ||
) | ||
trainer.fit(model=self.model, datamodule=self.data_module) | ||
return trainer | ||
|
||
data_loader = ZarrIrradianceDataModuleHMI( | ||
hmi_path=os.path.join( | ||
cfg.data.sdoml.base_directory, cfg.data.sdoml.instrument_sub_directory.hmi | ||
), | ||
aia_path=os.path.join( | ||
cfg.data.sdoml.base_directory, cfg.data.sdoml.instrument_sub_directory.hmi | ||
), | ||
eve_path=os.path.join( | ||
cfg.data.sdoml.base_directory, cfg.data.sdoml.instrument_sub_directory.eve | ||
), | ||
components=cfg.data.sdoml.components, | ||
wavelengths=cfg.data.sdoml.wavelengths, | ||
# ions=run_config["sci_parameters"]["eve_ions"], | ||
# frequency=run_config["sci_parameters"]["frequency"], | ||
batch_size=cfg.model.opt.batch_size, | ||
num_workers=cfg.data.num_workers, | ||
# val_months=run_config["training_parameters"]["val_months"], | ||
# test_months=run_config["training_parameters"]["test_months"], | ||
# holdout_months=run_config["training_parameters"]["holdout_months"], | ||
cache_dir=cfg.data.sdoml.metadata, | ||
) | ||
def evaluate(self): | ||
self.trainer.evaluate() | ||
|
||
data_loader.setup() | ||
def test_sdofm(self): | ||
self.trainer.test(ckpt_path="best") |