diff --git a/experiments/pretrain_300M_samae.yaml b/experiments/pretrain_300M_samae.yaml
new file mode 100755
index 0000000..aad7071
--- /dev/null
+++ b/experiments/pretrain_300M_samae.yaml
@@ -0,0 +1,152 @@
+# pretrain_300M_samae.yaml
+
+# MODEL SUMMARY
+#   | Name        | Type                   | Params
+# -------------------------------------------------------
+# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
+# -------------------------------------------------------
+# 329 M     Trainable params
+# 4.7 M     Non-trainable params
+# 333 M     Total params
+# 1,335.838 Total estimated model params size (MB)
+
+# general
+experiment:
+  name: "pretrain_300M_samae"
+  project: "sdofm"
+  model: "samae"
+  task: "train" # options: train, evaluate (evaluate not yet implemented)
+  seed: 0
+  disable_cuda: false
+  disable_wandb: false
+  wandb:
+    entity: "fdlx"
+    group: "sdofm-phase1"
+    job_type: "pretrain"
+    tags: []
+    notes: ""
+  fold: null
+  evaluate: false # skip training and only evaluate (requires checkpoint to be set)
+  checkpoint: null # wandb run_id of the checkpoint to load
+  device: null # set automatically from the disable_cuda flag and torch.cuda.is_available()
+  precision: 64 # options: 32, 64
+  log_n_batches: 1000 # log every n training batches
+  save_results: true # save full results to file and wandb
+  accelerator: "auto" # options: "auto", "gpu", "tpu", "ipu", or "cpu"
+  distributed:
+    enabled: true
+    backend: "ddp"
+    # nproc_per_node: 1
+    # nnodes: 1
+    world_size: "auto" # "auto" recognizes the machine you are on and selects the appropriate accelerator
+    # node_rank: 0
+    # local_rank: 0
+    # master_addr: "localhost"
+    # master_port: 12345
+
+# dataset configuration
+data:
+  min_date: "0000-00-00 00:00:00" # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08'
+  max_date: "0000-00-00 00:00:00" # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072'
+  month_splits: # non-selected months form the training set
+    # train: [1,2,3,4,5,6,7,8,9,10]
+    val: [11]
+    test: [12]
+    holdout: []
+  num_workers: 16 # set appropriately for your machine
+  output_directory: "output"
+  sdoml:
+    base_directory: "/mnt/sdoml"
+    sub_directory:
+      hmi: "HMI.zarr"
+      aia: "AIA.zarr"
+      eve: "EVE_legacy.zarr"
+      cache: "cache"
+    components: null # null selects all magnetic components ["Bx", "By", "Bz"]
+    wavelengths: null # null selects all wavelength channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
+    ions: null # null selects all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
+    frequency: "12min" # smallest available cadence is 12min
+    mask_with_hmi_threshold: null # null for no mask, float for threshold
+
+# model configurations
+model:
+  # PRETRAINERS
+  mae:
+    img_size: 512
+    patch_size: 16
+    num_frames: 5
+    tubelet_size: 5
+    in_chans: 9
+    embed_dim: 4096
+    depth: 24
+    num_heads: 16
+    decoder_embed_dim: 512
+    decoder_depth: 8
+    decoder_num_heads: 16
+    mlp_ratio: 4.0
+    # norm_layer: defaults to nn.LayerNorm
+    norm_pix_loss: false
+  samae:
+    # uses all of the mae parameters above, plus the following
+    masking_type: "solar_aware" # options: "random" or "solar_aware"
+    active_region_mu_degs: 15.73
+    active_region_std_degs: 6.14
+    active_region_scale: 1.0
+    active_region_abs_lon_max_degs: 60
+    active_region_abs_lat_max_degs: 60
+
+  nvae:
+    use_se: true
+    res_dist: true
+    num_x_bits: 8
+    num_latent_scales: 3 # 5
+    num_groups_per_scale: 1 # 16
+    num_latent_per_group: 1 # 10
+    ada_groups: true
+    min_groups_per_scale: 1
+    num_channels_enc: 30
+    num_channels_dec: 30
+    num_preprocess_blocks: 2 # 1
+    num_preprocess_cells: 2
+    num_cell_per_cond_enc: 2
+    num_postprocess_blocks: 2 # 1
+    num_postprocess_cells: 2
+    num_cell_per_cond_dec: 2
+    num_mixture_dec: 1
+    num_nf: 2
+    kl_anneal_portion: 0.3
+    kl_const_portion: 0.0001
+    kl_const_coeff: 0.0001
+    # learning_rate: 1e-2
+    # weight_decay: 3e-4
+    weight_decay_norm_anneal: true
+    weight_decay_norm_init: 1.0
+    weight_decay_norm: 1e-2
+
+  # FINE-TUNERS
+  dimming:
+    num_neck_filters: 32
+    output_dim: 1 # not sure why this is implemented for autocorrelation; it should be a scalar
+    loss: "mse" # options: "mse", "heteroscedastic"
+    freeze_encoder: true
+
+  # ML optimisation arguments
+  opt:
+    loss: "mse" # options: "mae", "mse", "mape"
+    scheduler: "constant" # other options: "cosine", "plateau", "exp"
+    scheduler_warmup: 0
+    batch_size: 256
+    learning_rate: 0.0001
+    weight_decay: 3e-4 # 0.0
+    optimiser: "adam"
+    epochs: 4
+    patience: 2
+
+# hydra configuration
+hydra:
+  mode: MULTIRUN
+  run:
+    dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S}
+  sweep:
+    dir: ${hydra.run.dir}
+    subdir: ${hydra.job.num}
\ No newline at end of file
diff --git a/scripts/pretrain.py b/scripts/pretrain.py
index dbf14f8..4dc8e98 100755
--- a/scripts/pretrain.py
+++ b/scripts/pretrain.py
@@ -9,7 +9,7 @@
 from sdofm import utils
 from sdofm.datasets import SDOMLDataModule
 
-from sdofm.pretraining import MAE, NVAE
+from sdofm.pretraining import MAE, NVAE, SAMAE
 
 
 class Pretrainer(object):
@@ -53,6 +53,36 @@ def __init__(self, cfg, logger):
                     lr=cfg.model.opt.learning_rate,
                     weight_decay=cfg.model.opt.weight_decay,
                 )
+            case "samae":
+                self.data_module = SDOMLDataModule(
+                    hmi_path=None,
+                    aia_path=os.path.join(
+                        cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia
+                    ),
+                    eve_path=None,
+                    components=cfg.data.sdoml.components,
+                    wavelengths=cfg.data.sdoml.wavelengths,
+                    ions=cfg.data.sdoml.ions,
+                    frequency=cfg.data.sdoml.frequency,
+                    batch_size=cfg.model.opt.batch_size,
+                    num_workers=cfg.data.num_workers,
+                    val_months=cfg.data.month_splits.val,
+                    test_months=cfg.data.month_splits.test,
+                    holdout_months=cfg.data.month_splits.holdout,
+                    cache_dir=os.path.join(
+                        cfg.data.sdoml.base_directory,
+                        cfg.data.sdoml.sub_directory.cache,
+                    ),
+                )
+                self.data_module.setup()
+                self.model = SAMAE(
+                    **cfg.model.mae,
+                    **cfg.model.samae,
+                    optimiser=cfg.model.opt.optimiser,
+                    lr=cfg.model.opt.learning_rate,
+                    weight_decay=cfg.model.opt.weight_decay,
+                )
+
             case "nvae":
                 self.data_module = SDOMLDataModule(
                     hmi_path=os.path.join(
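
As a quick sanity check, the new config can be used to build SAMAE straight from the YAML. The following is a minimal sketch, not part of the patch: it assumes OmegaConf is available for standalone loading (the Hydra entrypoint wiring is not shown in this diff), and the constructor call simply mirrors the case "samae" branch added above.

    # sketch: smoke-test SAMAE construction from the experiment YAML
    # (assumes OmegaConf; the kwargs mirror the case "samae" branch in scripts/pretrain.py)
    from omegaconf import OmegaConf
    from sdofm.pretraining import SAMAE

    cfg = OmegaConf.load("experiments/pretrain_300M_samae.yaml")
    model = SAMAE(
        **cfg.model.mae,    # shared ViT-MAE geometry: img_size, patch_size, ...
        **cfg.model.samae,  # solar-aware masking parameters
        optimiser=cfg.model.opt.optimiser,
        lr=cfg.model.opt.learning_rate,
        weight_decay=cfg.model.opt.weight_decay,
    )

At roughly 333 M parameters this allocates real memory, so it is only worth running where the full model fits.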
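
The hydra block sets mode: MULTIRUN, so config values can be swept from the command line instead of editing the YAML. A hypothetical invocation, assuming scripts/pretrain.py is the Hydra entrypoint composed with this file:

    # hypothetical sweep: solar-aware vs. random masking at two learning rates
    python scripts/pretrain.py \
        model.samae.masking_type=solar_aware,random \
        model.opt.learning_rate=0.0001,0.00003

In multirun mode Hydra expands the comma-separated values into their cross product (four jobs here), writing each job to its own ${hydra.job.num} subdirectory under the timestamped run dir derived from ${data.output_directory}.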