From 1628693b1a38fc8a1479ea3c8e9fd9322d2f1672 Mon Sep 17 00:00:00 2001
From: dead-water
Date: Thu, 4 Jul 2024 00:41:24 +0000
Subject: [PATCH] Added preliminary embedding export script

---
 .gitignore                                   |   1 +
 experiments/pretrain_32.2M_mae_tpu_2048.yaml | 159 +++++++++++++++++
 experiments/pretrain_nvae.yaml               |   6 +-
 scripts/export_dataset.py                    | 175 +++++++++++++++++++
 sdofm/BaseModule.py                          |   4 +-
 sdofm/datasets/TimestampedSDOML.py           |  92 ++++++++++
 sdofm/datasets/__init__.py                   |   3 +-
 7 files changed, 435 insertions(+), 5 deletions(-)
 create mode 100755 experiments/pretrain_32.2M_mae_tpu_2048.yaml
 create mode 100644 scripts/export_dataset.py
 create mode 100644 sdofm/datasets/TimestampedSDOML.py

diff --git a/.gitignore b/.gitignore
index e9f9067..e2ae6b8 100755
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 wandb
 output
 outputs
+*.tar
 
 # aux directories
 .vscode
diff --git a/experiments/pretrain_32.2M_mae_tpu_2048.yaml b/experiments/pretrain_32.2M_mae_tpu_2048.yaml
new file mode 100755
index 0000000..81e59d4
--- /dev/null
+++ b/experiments/pretrain_32.2M_mae_tpu_2048.yaml
@@ -0,0 +1,159 @@
+# default.yaml
+
+# MODEL SUMMARY
+#   | Name        | Type                   | Params
+# -------------------------------------------------------
+# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
+# -------------------------------------------------------
+# 329 M     Trainable params
+# 4.7 M     Non-trainable params
+# 333 M     Total params
+# 1,335.838 Total estimated model params size (MB)
+
+# general
+log_level: 'DEBUG'
+experiment:
+  name: "mae-2011" # generate random name in wandb
+  project: "sdofm"
+  task: "pretrain" # options: train, evaluate (not implemented)
+  model: "mae"
+  backbone_checkpoint: null
+  resuming: false
+  seed: 0
+  disable_cuda: false
+  wandb:
+    enable: true
+    entity: "fdlx"
+    group: "sdofm-phase1"
+    job_type: "pretrain"
+    tags: []
+    notes: ""
+    output_directory: "wandb_output"
+    log_model: "all" # can be True (final checkpoint), False (no checkpointing), or "all" (for all epochs)
+  gcp_storage: # this will checkpoint all epochs, perhaps clean up this config
+    enabled: true
+    bucket: "sdofm-checkpoints"
+    fold: null
+  evaluate: false # skip training and only evaluate (requires checkpoint to be set)
+  checkpoint: null # this is the wandb run_id of the checkpoint to load
+  device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
+  precision: 'bf16-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
+  log_n_batches: 1000 # log every n training batches
+  save_results: true # save full results to file and wandb
+  accelerator: "tpu" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
+  profiler: null #'XLAProfiler' # options are XLAProfiler/PyTorchProfiler. Warning: XLA for TPUs only works on single world size
+  distributed:
+    enabled: true
+    world_size: "auto" # The "auto" option recognizes the machine you are on, and selects the appropriate number of accelerators.
+    strategy: "auto"
+  log_every_n_steps: 25
+
+# dataset configuration
+data:
+  min_date: '2011-01-01 00:00:00.00' # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08'
+  max_date: '2011-12-31 23:59:59.99' # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072'
+  month_splits: # non-selected months will form the training set
+    # train: [1,2,3,4,5,6,7,8,9,10]
+    val: [11]
+    test: [12]
+    holdout: []
+  num_workers: 8 # set appropriately for your machine
+  prefetch_factor: 2
+  num_frames: 1 # WARNING: This is only read for FINETUNING, model num_frames overrides in BACKBONE
+  # output_directory: "wandb_output"
+  sdoml:
+    base_directory: "/mnt/sdoml"
+    sub_directory:
+      hmi: "HMI.zarr"
+      aia: "AIA.zarr"
+      eve: "EVE_legacy.zarr"
+      cache: "cache"
+    components: null # null to select all magnetic components ["Bx", "By", "Bz"]
+    wavelengths: null # null to select all wavelength channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
+    ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
+    frequency: '12min' # smallest is 12min
+    mask_with_hmi_threshold: null # None/null for no mask, float for threshold
+    drop_frame_dim: false
+
+# model configurations
+model:
+  # PRETRAINERS
+  mae:
+    img_size: 512
+    patch_size: 16
+    num_frames: 1
+    tubelet_size: 1
+    in_chans: 9
+    embed_dim: 2048
+    depth: 24
+    num_heads: 16
+    decoder_embed_dim: 512
+    decoder_depth: 8
+    decoder_num_heads: 16
+    mlp_ratio: 4.0
+    norm_layer: 'LayerNorm'
+    norm_pix_loss: False
+    masking_ratio: 0.5
+  samae:
+    # uses all parameters as in mae plus these
+    masking_type: "random" # 'random' or 'solar_aware'
+    active_region_mu_degs: 15.73
+    active_region_std_degs: 6.14
+    active_region_scale: 1.0
+    active_region_abs_lon_max_degs: 60
+    active_region_abs_lat_max_degs: 60
+  nvae:
+    use_se: true
+    res_dist: true
+    num_x_bits: 8
+    num_latent_scales: 3 # 5
+    num_groups_per_scale: 1 # 16
+    num_latent_per_group: 1 # 10
+    ada_groups: true
+    min_groups_per_scale: 1
+    num_channels_enc: 30
+    num_channels_dec: 30
+    num_preprocess_blocks: 2 # 1
+    num_preprocess_cells: 2
+    num_cell_per_cond_enc: 2
+    num_postprocess_blocks: 2 # 1
+    num_postprocess_cells: 2
+    num_cell_per_cond_dec: 2
+    num_mixture_dec: 1
+    num_nf: 2
+    kl_anneal_portion: 0.3
+    kl_const_portion: 0.0001
+    kl_const_coeff: 0.0001
+    # learning_rate: 1e-2
+    # weight_decay: 3e-4
+    weight_decay_norm_anneal: true
+    weight_decay_norm_init: 1.
+    weight_decay_norm: 1e-2
+
+  # FINE-TUNERS
+  degragation:
+    num_neck_filters: 32
+    output_dim: 1 # not sure why this is implemented for autocorrelation, should be a scalar
+    loss: "mse" # options: "mse", "heteroscedastic"
+    freeze_encoder: true
+
+  # ML optimization arguments:
+  opt:
+    loss: "mse" # options: "mae", "mse", "mape"
+    scheduler: "constant" # other options: "cosine", "plateau", "exp"
+    scheduler_warmup: 0
+    batch_size: 2
+    learning_rate: 0.0001
+    weight_decay: 3e-4 # 0.0
+    optimiser: "adam"
+    epochs: 100
+    patience: 2
+
+# hydra configuration
+hydra:
+  mode: RUN
+  # run:
+  #   dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S}
+  # sweep:
+  #   dir: ${hydra.run.dir}
+  #   subdir: ${hydra.job.num}
\ No newline at end of file
diff --git a/experiments/pretrain_nvae.yaml b/experiments/pretrain_nvae.yaml
index 982a691..32089bf 100755
--- a/experiments/pretrain_nvae.yaml
+++ b/experiments/pretrain_nvae.yaml
@@ -36,11 +36,11 @@ experiment:
   evaluate: false # skip training and only evaluate (requires checkpoint to be set)
   checkpoint: null # this is the wandb run_id of the checkpoint to load
   device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
-  precision: "bf16-true" # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
+  precision: 32 # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
   log_n_batches: 1000 # log every n training batches
   save_results: true # save full results to file and wandb
   accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
-  profiler: 'XLAProfiler' # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler Warning: XLA for TPUs only works on single world size
+  profiler: null #'XLAProfiler' # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler Warning: XLA for TPUs only works on single world size
   distributed:
     enabled: true
     world_size: 1 # The "auto" option recognizes the machine you are on, and selects the appropriate number of accelerators.
@@ -146,7 +146,7 @@ model:
     loss: "mse" # options: "mae", "mse", "mape"
     scheduler: "constant" #other options: "cosine", "plateau", "exp"
     scheduler_warmup: 0
-    batch_size: 5
+    batch_size: 2
     learning_rate: 0.0001
     weight_decay: 3e-4 # 0.0
     optimiser: "adam"
diff --git a/scripts/export_dataset.py b/scripts/export_dataset.py
new file mode 100644
index 0000000..814d982
--- /dev/null
+++ b/scripts/export_dataset.py
@@ -0,0 +1,175 @@
+import argparse
+import codecs
+import datetime
+import os
+import pprint
+import shutil
+import sys
+import time
+import warnings
+from pathlib import Path
+
+import numpy as np
+import omegaconf
+import torch
+import webdataset as wds
+from tqdm import tqdm
+
+from sdofm.datasets import TimestampedSDOMLDataModule
+from sdofm.models import ConvTransformerTokensToEmbeddingNeck
+from sdofm.pretraining import MAE, NVAE
+from sdofm.utils import days_hours_mins_secs_str
+
+
+def main():
+    # fmt: off
+    parser = argparse.ArgumentParser(description="SDO Latent Dataset generation script", formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
+    parser.add_argument("--experiment", "-e", help="Training configuration file", default="../experiments/pretrain", type=str, )
+    parser.add_argument("--model", "-m", help="Pretrained model filename", required=True, type=str)
+    parser.add_argument("--src_id", "-i", help="Source identifier, e.g. W&B run ID", default=None, type=str, )
+    parser.add_argument("--out_dir", "-o", help="Output directory to save the dataset", default="", type=str, )
+    parser.add_argument("--batch_size", "-b", help="Batch size", default=32, type=int)
+    parser.add_argument("--num_workers", "-n", help="Number of workers", default=0, type=int)
+    parser.add_argument("--max", help="Maximum number of samples to process", default=None, type=int)
+    parser.add_argument("--device", help="Compute device (cpu, cuda:0, cuda:1, etc.)", default="cpu", type=str, )
+    opt = parser.parse_args()
+    # fmt: on
+
+    print("SDO-FM Dataset Export Script, modified from work by")
+    print("NASA FDL-X 2023 Thermospheric Drag Team")
+    command_line_arguments = " ".join(sys.argv[1:])
+    print("Arguments:\n{}\n".format(command_line_arguments))
+    print("Script Config:")
+    pprint.pprint(vars(opt), depth=2, width=50)
+
+    # Datamodule
+    print("\nLoading Data:")
+    cfg = omegaconf.OmegaConf.load(opt.experiment)
+
+    data_module = TimestampedSDOMLDataModule(
+        # hmi_path=os.path.join(
+        #     cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.hmi
+        # ),
+        hmi_path=None,
+        aia_path=os.path.join(
+            cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia
+        ),
+        eve_path=None,
+        components=cfg.data.sdoml.components,
+        wavelengths=cfg.data.sdoml.wavelengths,
+        ions=cfg.data.sdoml.ions,
+        frequency=cfg.data.sdoml.frequency,
+        batch_size=cfg.model.opt.batch_size,
+        num_workers=cfg.data.num_workers,
+        val_months=[],  # cfg.data.month_splits.val,
+        # fmt: off
+        test_months=[1,2,3,4,5,6,7,8,9,10,11,12,],
+        # cfg.data.month_splits.test,
+        # fmt: on
+        holdout_months=[],  # =cfg.data.month_splits.holdout,
+        cache_dir=os.path.join(
+            cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.cache
+        ),
+        min_date=cfg.data.min_date,
+        max_date=cfg.data.max_date,
+        num_frames=cfg.data.num_frames,
+        drop_frame_dim=cfg.data.drop_frame_dim,
+    )
+    data_module.setup()
+
+    print("\nLoading model:")
+    try:
+        match (cfg.experiment.model):
+            case "nvae":
+                model = NVAE.load_from_checkpoint(
+                    opt.model, map_location=torch.device(opt.device)
+                )
+                z_dim, z_shapes = model.z_dim, model.z_shapes
+                latent_dim = z_dim.shape[0]
+            case "mae":
+                model = MAE.load_from_checkpoint(
+                    opt.model, map_location=torch.device(opt.device)
+                )
+                latent_dim = cfg.model.mae.embed_dim
+
+                num_tokens = 512 // 16
+                emb_decoder = ConvTransformerTokensToEmbeddingNeck(
+                    embed_dim=latent_dim,
+                    # output_embed_dim=32,
+                    output_embed_dim=latent_dim,
+                    Hp=num_tokens,
+                    Wp=num_tokens,
+                    drop_cls_token=True,
+                    num_frames=1,
+                )
+            case _:
+                raise ValueError(
+                    f"Model export of {cfg.experiment.model} not yet supported."
+                )
+        print("Model ready.")
+    except Exception as e:
+        print(f"Could not load model at {opt.model}")
+        raise e
+
+    # Begin
+
+    output = f"SDOFM-{datetime.datetime.now().strftime('%Y_%m_%d_%H%M')}-embsize_{latent_dim}-period_{data_module.min_date.strftime('%Y_%m_%d_%H%M')}_{data_module.max_date.strftime('%Y_%m_%d_%H%M')}"
+    if opt.src_id:
+        output += f"-{opt.src_id}"
+    output = Path(opt.out_dir) / (output + ".tar")
+
+    print(f"\nBeginning webdataset creation at {output}")
+    count = 0
+    latent_dim_checked = False
+
+    sink = wds.TarWriter(str(output))  # output already carries the .tar suffix
+
+    dl = data_module.test_dataloader()
+    pbar = tqdm(total=len(dl))
+    for data in dl:
+        pbar.update(1)
+
+        x, timestamps = data["image_stack"], data["timestamps"][0]
+        batch_size = x.shape[0]
+
+        match (cfg.experiment.model):
+            case "nvae":
+                z = model.encode(x)
+            case "mae":
+                x_hat, _, _ = model.autoencoder.forward_encoder(x, mask_ratio=0.0)
+                x_hat = x_hat.squeeze(dim=2)
+                z = emb_decoder(x_hat)
+
+        z = z.detach().cpu().numpy()
+        if not latent_dim_checked:
+            if latent_dim != z.shape[1]:
+                warnings.warn(
+                    "Latent dimension mismatch: {} vs {}".format(latent_dim, z.shape[1])
+                )
+            latent_dim_checked = True
+
+        for i in range(batch_size):
+            sink.write(
+                {
+                    "__key__": timestamps[i],
+                    "emb.pyd": z[i],
+                }
+            )
+            count += 1
+
+        if opt.max is not None and count >= opt.max:
+            print("Stopping after {} samples".format(opt.max))
+            break
+
+    sink.close()
+
+
+if __name__ == "__main__":
+    time_start = time.time()
+    main()
+    print(
+        "\nTotal duration: {}".format(
+            days_hours_mins_secs_str(time.time() - time_start)
+        )
+    )
+    sys.exit(0)
diff --git a/sdofm/BaseModule.py b/sdofm/BaseModule.py
index d0a4f83..af69e73 100644
--- a/sdofm/BaseModule.py
+++ b/sdofm/BaseModule.py
@@ -10,8 +10,10 @@ def __init__(
         weight_decay: float = 0.0,
         hyperparam_ignore=[],
         # pass to pl.LightningModule
+        *args,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(*args, **kwargs)
         self.save_hyperparameters(ignore=hyperparam_ignore)
 
         # optimiser values
diff --git a/sdofm/datasets/TimestampedSDOML.py b/sdofm/datasets/TimestampedSDOML.py
new file mode 100644
index 0000000..a6b9701
--- /dev/null
+++ b/sdofm/datasets/TimestampedSDOML.py
@@ -0,0 +1,92 @@
+from .SDOML import SDOMLDataModule, SDOMLDataset
+from ..io import io
+import pandas as pd
+import numpy as np
+
+
+class TimestampedSDOMLDataModule(SDOMLDataModule):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def setup(self, stage=None):
+
+        self.train_ds = TimestampedSDOMLDataset(
+            self.aligndata,
+            self.hmi_data,
+            self.aia_data,
+            self.eve_data,
+            self.components,
+            self.wavelengths,
+            self.ions,
+            self.cadence,
+            self.test_months,
+            normalizations=self.normalizations,
+            mask=self.hmi_mask.numpy(),
+            num_frames=self.num_frames,
+            drop_frame_dim=self.drop_frame_dim,
+            min_date=self.min_date,
+            max_date=self.max_date,
+        )
+
+        self.valid_ds = TimestampedSDOMLDataset(
+            self.aligndata,
+            self.hmi_data,
+            self.aia_data,
+            self.eve_data,
+            self.components,
+            self.wavelengths,
+            self.ions,
+            self.cadence,
+            self.test_months,
+            normalizations=self.normalizations,
+            mask=self.hmi_mask.numpy(),
+            num_frames=self.num_frames,
+            drop_frame_dim=self.drop_frame_dim,
+            min_date=self.min_date,
+            max_date=self.max_date,
+        )
+
+        self.test_ds = TimestampedSDOMLDataset(
+            self.aligndata,
+            self.hmi_data,
+            self.aia_data,
+            self.eve_data,
+            self.components,
+            self.wavelengths,
+            self.ions,
+            self.cadence,
+            self.test_months,
+            normalizations=self.normalizations,
+            mask=self.hmi_mask.numpy(),
+            num_frames=self.num_frames,
+            drop_frame_dim=self.drop_frame_dim,
+            min_date=self.min_date,
+            max_date=self.max_date,
+        )
+
+
+class TimestampedSDOMLDataset(SDOMLDataset):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def __getitem__(self, idx):
+
+        # sample num_frames between idx and idx - sampling_period
+        items = self.aligndata.iloc[idx]
+        # print(items.index, idx, pd.DataFrame([items]))
+        timestamps = [i.strftime("%Y-%m-%d %H:%M:%S") for i in pd.DataFrame([items]).index]
+
+        r = {'timestamps': timestamps}
+
+        if self.eve_data:
+            image_stack, eve_data = super().__getitem__(idx)
+            r['eve_data'] = eve_data
+        else:
+            image_stack = super().__getitem__(idx)
+
+        r['image_stack'] = image_stack
+
+        return r
+
diff --git a/sdofm/datasets/__init__.py b/sdofm/datasets/__init__.py
index 885859b..62f7c1a 100755
--- a/sdofm/datasets/__init__.py
+++ b/sdofm/datasets/__init__.py
@@ -2,4 +2,5 @@
 from .SDOML import SDOMLDataModule
 from .SynopticSDOML import SynopticSDOMLDataModule
 from .BrightSpotsSDOML import BrightSpotsSDOMLDataModule
-from .RandomIntervalSDOML import RandomIntervalSDOMLDataModule
\ No newline at end of file
+from .RandomIntervalSDOML import RandomIntervalSDOMLDataModule
+from .TimestampedSDOML import TimestampedSDOMLDataModule
\ No newline at end of file
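
Usage sketch: the invocation and reader below are illustrative only. The config path, checkpoint name, and output directory are placeholders, and reading back assumes webdataset's default pickle handling of the "pyd" extension that wds.TarWriter uses when writing "emb.pyd" above.

    # export embeddings (paths are hypothetical)
    # python scripts/export_dataset.py -e experiments/pretrain_32.2M_mae_tpu_2048.yaml -m checkpoints/mae.ckpt -o exports/

    # minimal reader for the resulting archive
    import pickle
    import webdataset as wds

    for sample in wds.WebDataset("exports/SDOFM-export.tar"):  # placeholder filename
        timestamp = sample["__key__"]                # timestamp string used as the sample key
        embedding = pickle.loads(sample["emb.pyd"])  # numpy array written per sample
        print(timestamp, embedding.shape)
        break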