diff --git a/experiments/finetune_brightspots_virtualeve.yaml b/experiments/finetune_brightspots_virtualeve.yaml new file mode 100644 index 0000000..359c5c3 --- /dev/null +++ b/experiments/finetune_brightspots_virtualeve.yaml @@ -0,0 +1,157 @@ +# finetune_brightspots_virtualeve.yaml + +# general +log_level: 'DEBUG' +experiment: + name: null # generate random name in wandb when set to null + project: "sdofm" + task: "finetune" # options: train, evaluate (not implemented) + model: "virtualeve" + resuming: false + checkpoint: null # this is the wandb run_id of the checkpoint to load + backbone: + checkpoint: "model-tk45el88:best" #"mae128-epoch=17-step=139302.ckpt" #"sdofm/runs/771lx6o3:best" Only use models inside this project; the path will fail to resolve otherwise + model: "brightspots" + seed: 0 + disable_cuda: false + wandb: + enable: true + entity: "fdlx" + group: "sdofm-phase1" + job_type: "finetune" + tags: [] + notes: "" + output_directory: "wandb_output" + log_model: "all" # can be True (final checkpoint), False (no checkpointing), or "all" (for all epochs) + gcp_storage: # this will checkpoint all epochs and upload them to a GCP bucket; W&B will store references (TODO: perhaps explain this better) + enabled: true + bucket: "sdofm-checkpoints" + fold: null + evaluate: false # skip training and only evaluate (requires checkpoint to be set) + device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available() + precision: 'bf16-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu + log_n_batches: 1000 # log every n training batches + save_results: true # save full results to file and wandb + accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu" + profiler: null # options are 'XLAProfiler' (TPU), 'PyTorchProfiler'; warning: PyTorchProfiler only works on cpu/gpu according to docs + distributed: + enabled: true # set to true to use more than one device + world_size: "auto" # The "auto" option recognizes the machine you are on and selects the appropriate number of accelerators.
+ strategy: "ddp_find_unused_parameters_true" + log_every_n_steps: 5 + +# dataset configuration +data: + min_date: '2011-10-01 00:00:00.00' # minimum is '2010-09-09 00:00:11.08' + max_date: '2011-12-31 23:59:59.99' # maximum is '2023-05-26 06:36:08.072' + month_splits: # non selected months will form training set + # train: [1,2,3,4,5,6,7,8,9,10] + val: [11] + test: [12] + holdout: [] + num_workers: 32 # set appropriately for your machine + prefetch_factor: 3 # TODO: not implemented, 2 is default + num_frames: 1 # WARNING: This is only read for FINETUNING, model num_frames overrides in BACKBONE + drop_frame_dim: False # Requires num_frames=1, for backwards compatibility + sdoml: + base_directory: "/mnt/sdoml" + sub_directory: + hmi: "HMI.zarr" + aia: "AIA.zarr" + eve: "EVE_legacy.zarr" + cache: "cache" + components: null # null for select all magnetic components ["Bx", "By", "Bz"] + wavelengths: null # null for select all wavelengths channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"] + ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"] + frequency: '12min' # smallest is 12min + mask_with_hmi_threshold: null # None/null for no mask, float for threshold + +# model configurations +model: + # PRETRAINERS + mae: + img_size: 512 + patch_size: 16 + num_frames: 1 + tubelet_size: 1 + in_chans: 9 + embed_dim: 512 + depth: 24 + num_heads: 16 + decoder_embed_dim: 512 + decoder_depth: 8 + decoder_num_heads: 16 + mlp_ratio: 4.0 + norm_layer: 'LayerNorm' + norm_pix_loss: False + masking_ratio: 0.5 + samae: + # uses all parameters as in mae plus these + masking_type: "random" # 'random' or 'solar_aware' + active_region_mu_degs: 15.73 + active_region_std_degs: 6.14 + active_region_scale: 1.0 + active_region_abs_lon_max_degs: 60 + active_region_abs_lat_max_degs: 60 + nvae: + use_se: true + res_dist: true + num_x_bits: 8 + num_latent_scales: 3 # 5 + num_groups_per_scale: 1 # 16 + num_latent_per_group: 1 # 10 + ada_groups: true + min_groups_per_scale: 1 + num_channels_enc: 30 + num_channels_dec: 30 + num_preprocess_blocks: 2 # 1 + num_preprocess_cells: 2 + num_cell_per_cond_enc: 2 + num_postprocess_blocks: 2 # 1 + num_postprocess_cells: 2 + num_cell_per_cond_dec: 2 + num_mixture_dec: 1 + num_nf: 2 + kl_anneal_portion: 0.3 + kl_const_portion: 0.0001 + kl_const_coeff: 0.0001 + # learning_rate: 1e-2 + # weight_decay: 3e-4 + weight_decay_norm_anneal: true + weight_decay_norm_init: 1. 
+ weight_decay_norm: 1e-2 + + # FINE-TUNERS + autocalibration: + num_neck_filters: 32 + output_dim: 1 # not sure why this is implemented for autocalibration, should be a scalar + loss: "mse" # options: "mse", "heteroscedastic" + freeze_encoder: true + virtualeve: + num_neck_filters: 32 + cnn_model: "efficientnet_b3" + lr_linear: 0.01 + lr_cnn: 0.0001 + cnn_dp: 0.75 + epochs_linear: 20 + + # ML optimization arguments: + opt: + loss: "mse" # options: "mae", "mse", "mape" + scheduler: "constant" # other options: "cosine", "plateau", "exp" + scheduler_warmup: 0 + batch_size: 16 + learning_rate: 0.0001 + weight_decay: 3e-4 # 0.0 + optimiser: "adam" + epochs: 50 + patience: 2 + +# hydra configuration +hydra: + mode: RUN + # run: + # dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S} + # sweep: + # dir: ${hydra.run.dir} + # subdir: ${hydra.job.num} \ No newline at end of file diff --git a/experiments/moe.yaml b/experiments/moe.yaml new file mode 100644 index 0000000..61d449f --- /dev/null +++ b/experiments/moe.yaml @@ -0,0 +1,157 @@ +# moe.yaml + +# general +log_level: 'DEBUG' +experiment: + name: null # generate random name in wandb when set to null + project: "sdofm" + task: "finetune" # options: train, evaluate (not implemented) + model: "virtualeve" + resuming: false + checkpoint: null # this is the wandb run_id of the checkpoint to load + backbone: + checkpoint: "model-tk45el88:best" #"mae128-epoch=17-step=139302.ckpt" #"sdofm/runs/771lx6o3:best" Only use models inside this project; the path will fail to resolve otherwise + model: "mae" + seed: 0 + disable_cuda: false + wandb: + enable: true + entity: "fdlx" + group: "sdofm-phase1" + job_type: "finetune" + tags: [] + notes: "" + output_directory: "wandb_output" + log_model: "all" # can be True (final checkpoint), False (no checkpointing), or "all" (for all epochs) + gcp_storage: # this will checkpoint all epochs and upload them to a GCP bucket; W&B will store references (TODO: perhaps explain this better) + enabled: true + bucket: "sdofm-checkpoints" + fold: null + evaluate: false # skip training and only evaluate (requires checkpoint to be set) + device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available() + precision: 'bf16-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu + log_n_batches: 1000 # log every n training batches + save_results: true # save full results to file and wandb + accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu" + profiler: null # options are 'XLAProfiler' (TPU), 'PyTorchProfiler'; warning: PyTorchProfiler only works on cpu/gpu according to docs + distributed: + enabled: true # set to true to use more than one device + world_size: "auto" # The "auto" option recognizes the machine you are on and selects the appropriate number of accelerators.
+ strategy: "ddp_find_unused_parameters_true" + log_every_n_steps: 5 + +# dataset configuration +data: + min_date: '2011-10-01 00:00:00.00' # minimum is '2010-09-09 00:00:11.08' + max_date: '2011-12-31 23:59:59.99' # maximum is '2023-05-26 06:36:08.072' + month_splits: # non selected months will form training set + # train: [1,2,3,4,5,6,7,8,9,10] + val: [11] + test: [12] + holdout: [] + num_workers: 32 # set appropriately for your machine + prefetch_factor: 3 # TODO: not implemented, 2 is default + num_frames: 1 # WARNING: This is only read for FINETUNING, model num_frames overrides in BACKBONE + drop_frame_dim: False # Requires num_frames=1, for backwards compatibility + sdoml: + base_directory: "/mnt/sdoml" + sub_directory: + hmi: "HMI.zarr" + aia: "AIA.zarr" + eve: "EVE_legacy.zarr" + cache: "cache" + components: null # null for select all magnetic components ["Bx", "By", "Bz"] + wavelengths: null # null for select all wavelengths channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"] + ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"] + frequency: '12min' # smallest is 12min + mask_with_hmi_threshold: null # None/null for no mask, float for threshold + +# model configurations +model: + # PRETRAINERS + mae: + img_size: 512 + patch_size: 16 + num_frames: 1 + tubelet_size: 1 + in_chans: 9 + embed_dim: 512 + depth: 24 + num_heads: 16 + decoder_embed_dim: 512 + decoder_depth: 8 + decoder_num_heads: 16 + mlp_ratio: 4.0 + norm_layer: 'LayerNorm' + norm_pix_loss: False + masking_ratio: 0.5 + samae: + # uses all parameters as in mae plus these + masking_type: "random" # 'random' or 'solar_aware' + active_region_mu_degs: 15.73 + active_region_std_degs: 6.14 + active_region_scale: 1.0 + active_region_abs_lon_max_degs: 60 + active_region_abs_lat_max_degs: 60 + nvae: + use_se: true + res_dist: true + num_x_bits: 8 + num_latent_scales: 3 # 5 + num_groups_per_scale: 1 # 16 + num_latent_per_group: 1 # 10 + ada_groups: true + min_groups_per_scale: 1 + num_channels_enc: 30 + num_channels_dec: 30 + num_preprocess_blocks: 2 # 1 + num_preprocess_cells: 2 + num_cell_per_cond_enc: 2 + num_postprocess_blocks: 2 # 1 + num_postprocess_cells: 2 + num_cell_per_cond_dec: 2 + num_mixture_dec: 1 + num_nf: 2 + kl_anneal_portion: 0.3 + kl_const_portion: 0.0001 + kl_const_coeff: 0.0001 + # learning_rate: 1e-2 + # weight_decay: 3e-4 + weight_decay_norm_anneal: true + weight_decay_norm_init: 1. 
+ weight_decay_norm: 1e-2 + + # FINE-TUNERS + autocalibration: + num_neck_filters: 32 + output_dim: 1 # not sure why this is implemented for autocalibration, should be a scalar + loss: "mse" # options: "mse", "heteroscedastic" + freeze_encoder: true + virtualeve: + num_neck_filters: 32 + cnn_model: "efficientnet_b3" + lr_linear: 0.01 + lr_cnn: 0.0001 + cnn_dp: 0.75 + epochs_linear: 20 + + # ML optimization arguments: + opt: + loss: "mse" # options: "mae", "mse", "mape" + scheduler: "constant" # other options: "cosine", "plateau", "exp" + scheduler_warmup: 0 + batch_size: 16 + learning_rate: 0.0001 + weight_decay: 3e-4 # 0.0 + optimiser: "adam" + epochs: 50 + patience: 2 + +# hydra configuration +hydra: + mode: RUN + # run: + # dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S} + # sweep: + # dir: ${hydra.run.dir} + # subdir: ${hydra.job.num} \ No newline at end of file diff --git a/experiments/pretrain_brightspots.yaml b/experiments/pretrain_brightspots.yaml new file mode 100755 index 0000000..5c566f4 --- /dev/null +++ b/experiments/pretrain_brightspots.yaml @@ -0,0 +1,163 @@ +# pretrain_brightspots.yaml + +# MODEL SUMMARY +# | Name | Type | Params +# ------------------------------------------------------- +# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M +# ------------------------------------------------------- +# 329 M Trainable params +# 4.7 M Non-trainable params +# 333 M Total params +# 1,335.838 Total estimated model params size (MB) + +# general +log_level: 'DEBUG' +experiment: + name: null # generate random name in wandb + project: "sdofm" + task: "pretrain" # options: train, evaluate (not implemented) + model: "brightspots" + backbone_checkpoint: null + resuming: false + seed: 0 + disable_cuda: false + wandb: + enable: true + entity: "fdlx" + group: "sdofm-phase1" + job_type: "pretrain" + tags: [] + notes: "" + output_directory: "wandb_output" + log_model: "all" # can be True (final checkpoint), False (no checkpointing), or "all" (for all epochs) + gcp_storage: # this will checkpoint all epochs, perhaps clean up this config + enabled: true + bucket: "sdofm-checkpoints" + fold: null + evaluate: false # skip training and only evaluate (requires checkpoint to be set) + checkpoint: null # this is the wandb run_id of the checkpoint to load + device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available() + precision: 'bf16-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu + log_n_batches: 1000 # log every n training batches + save_results: true # save full results to file and wandb + accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu" + profiler: null # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler; warning: XLA for TPUs only works with a single world size + distributed: + enabled: true + world_size: "auto" # The "auto" option recognizes the machine you are on and selects the appropriate number of accelerators.
+ strategy: "ddp" + log_every_n_steps: 5 + +# dataset configuration +data: + min_date: '2011-01-01 00:00:00.00' # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08' + max_date: '2011-12-31 23:59:59.99' # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072' + month_splits: # non selected months will form training set + # train: [1,2,3,4,5,6,7,8,9,10] + val: [11] + test: [12] + holdout: [] + num_workers: 24 # set appropriately for your machine + prefetch_factor: 3 + num_frames: 5 # WARNING: This is only read for FINETUNING, model num_frames overrides in BACKBONE + # output_directory: "wandb_output" + sdoml: + base_directory: "/mnt/sdoml" + sub_directory: + hmi: "HMI.zarr" + aia: "AIA.zarr" + eve: "EVE_legacy.zarr" + cache: "cache" + components: null # null for select all magnetic components ["Bx", "By", "Bz"] + wavelengths: null # null for select all wavelengths channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"] + ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"] + frequency: '12min' # smallest is 12min + mask_with_hmi_threshold: null # None/null for no mask, float for threshold + +# model configurations +model: + # PRETRAINERS + mae: + img_size: 512 + patch_size: 16 + num_frames: 5 + tubelet_size: 1 + in_chans: 9 + embed_dim: 128 + depth: 24 + num_heads: 16 + decoder_embed_dim: 512 + decoder_depth: 8 + decoder_num_heads: 16 + mlp_ratio: 4.0 + norm_layer: 'LayerNorm' + norm_pix_loss: False + samae: + # uses all parameters as in mae plus these + masking_type: "random" # 'random' or 'solar_aware' + active_region_mu_degs: 15.73 + active_region_std_degs: 6.14 + active_region_scale: 1.0 + active_region_abs_lon_max_degs: 60 + active_region_abs_lat_max_degs: 60 + nvae: + use_se: true + res_dist: true + num_x_bits: 8 + num_latent_scales: 3 # 5 + num_groups_per_scale: 1 # 16 + num_latent_per_group: 1 # 10 + ada_groups: true + min_groups_per_scale: 1 + num_channels_enc: 30 + num_channels_dec: 30 + num_preprocess_blocks: 2 # 1 + num_preprocess_cells: 2 + num_cell_per_cond_enc: 2 + num_postprocess_blocks: 2 # 1 + num_postprocess_cells: 2 + num_cell_per_cond_dec: 2 + num_mixture_dec: 1 + num_nf: 2 + kl_anneal_portion: 0.3 + kl_const_portion: 0.0001 + kl_const_coeff: 0.0001 + # learning_rate: 1e-2 + # weight_decay: 3e-4 + weight_decay_norm_anneal: true + weight_decay_norm_init: 1. 
+ weight_decay_norm: 1e-2 + brightspots: + n_channels: 12 + n_classes: 1 + bilinear: true + use_embeddings_block: true + size_factor: 4 + + # FINE-TUNERS + degragation: + num_neck_filters: 32 + output_dim: 1 # not sure why this is implemented for autocorrelation, should be a scalar + loss: "mse" # options: "mse", "heteroscedastic" + freeze_encoder: true + + # ML optimization arguments: + opt: + loss: "mse" # options: "mae", "mse", "mape" + scheduler: "constant" #other options: "cosine", "plateau", "exp" + scheduler_warmup: 0 + batch_size: 16 + learning_rate: 0.0001 + weight_decay: 3e-4 # 0.0 + optimiser: "adam" + epochs: 100 + patience: 2 + +# hydra configuration +hydra: + mode: RUN + # run: + # dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S} + # sweep: + # dir: ${hydra.run.dir} + # subdir: ${hydra.job.num} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9729a81..7647f49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,7 @@ sunpy matplotlib overrides tensorrt -google-cloud-storage \ No newline at end of file +google-cloud-storage +rlxutils +blosc +loguru \ No newline at end of file diff --git a/scripts/pretrain.py b/scripts/pretrain.py index 55cebbd..a12964f 100755 --- a/scripts/pretrain.py +++ b/scripts/pretrain.py @@ -11,8 +11,8 @@ import wandb from sdofm import utils -from sdofm.datasets import SDOMLDataModule -from sdofm.pretraining import MAE, NVAE, SAMAE +from sdofm.datasets import SDOMLDataModule, BrightSpotsSDOMLDataModule +from sdofm.pretraining import MAE, NVAE, SAMAE, BrightSpots class Pretrainer(object): @@ -116,6 +116,50 @@ def __init__(self, cfg, logger=None, profiler=None, is_backbone=False): weight_decay=cfg.model.opt.weight_decay, ) + case "brightspots": + self.model_class = BrightSpots + self.data_module = BrightSpotsSDOMLDataModule( + hmi_path=os.path.join( + cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.hmi + ), + aia_path=os.path.join( + cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia + ), + eve_path=None, + components=cfg.data.sdoml.components, + wavelengths=cfg.data.sdoml.wavelengths, + ions=cfg.data.sdoml.ions, + frequency=cfg.data.sdoml.frequency, + batch_size=cfg.model.opt.batch_size, + num_workers=cfg.data.num_workers, + blosc_cache = "/home/walsh/blosc_cache", + val_months=cfg.data.month_splits.val, + test_months=cfg.data.month_splits.test, + holdout_months=cfg.data.month_splits.holdout, + cache_dir=os.path.join( + cfg.data.sdoml.base_directory, + cfg.data.sdoml.sub_directory.cache, + ), + min_date=cfg.data.min_date, + max_date=cfg.data.max_date, + num_frames=cfg.model.mae.num_frames, + ) + self.data_module.setup() + + if cfg.experiment.resuming or is_backbone: + self.model = self.load_checkpoint( + cfg.experiment.checkpoint + if not is_backbone + else cfg.experiment.backbone.checkpoint + ) + else: + self.model = self.model_class( + **cfg.model.brightspots, + optimiser=cfg.model.opt.optimiser, + lr=cfg.model.opt.learning_rate, + weight_decay=cfg.model.opt.weight_decay, + ) + case "nvae": self.model_class = NVAE diff --git a/sdofm/datasets/BrightSpotsSDOML.py b/sdofm/datasets/BrightSpotsSDOML.py index a21383d..31403a9 100644 --- a/sdofm/datasets/BrightSpotsSDOML.py +++ b/sdofm/datasets/BrightSpotsSDOML.py @@ -30,20 +30,20 @@ class BrightSpotsSDOMLDataModule(SDOMLDataModule): def __init__(self, blosc_cache=None, - start_date=None, - end_date=None, + # start_date=None, + # end_date=None, *args, **kwargs): super().__init__(*args, **kwargs) self.blosc_cache = blosc_cache - 
self.start_date = start_date - self.end_date = end_date + # self.start_date = start_date + # self.end_date = end_date - if start_date is not None: - self.aligndata = self.aligndata[self.start_date:] + # if start_date is not None: + # self.aligndata = self.aligndata[self.start_date:] - if end_date is not None: - self.aligndata = self.aligndata[:self.end_date] + # if end_date is not None: + # self.aligndata = self.aligndata[:self.end_date] def setup(self, stage=None): diff --git a/sdofm/datasets/__init__.py b/sdofm/datasets/__init__.py index 401d4aa..885859b 100755 --- a/sdofm/datasets/__init__.py +++ b/sdofm/datasets/__init__.py @@ -1,3 +1,5 @@ from .DegradedSDOML import DegradedSDOMLDataModule from .SDOML import SDOMLDataModule from .SynopticSDOML import SynopticSDOMLDataModule +from .BrightSpotsSDOML import BrightSpotsSDOMLDataModule +from .RandomIntervalSDOML import RandomIntervalSDOMLDataModule \ No newline at end of file diff --git a/sdofm/finetuning/VirtualEVE.py b/sdofm/finetuning/VirtualEVE_bMSE.py similarity index 100% rename from sdofm/finetuning/VirtualEVE.py rename to sdofm/finetuning/VirtualEVE_bMSE.py diff --git a/sdofm/finetuning/VirtualEVE_bUNet.py b/sdofm/finetuning/VirtualEVE_bUNet.py new file mode 100644 index 0000000..4209bcd --- /dev/null +++ b/sdofm/finetuning/VirtualEVE_bUNet.py @@ -0,0 +1,89 @@ +# Adapted from https://github.com/FrontierDevelopmentLab/2023-FDL-X-ARD-EVE + +import sys +import lightning.pytorch as pl +import torch +import torch.nn as nn +import torchvision +from torch.nn import HuberLoss + +from ..BaseModule import BaseModule +from ..models import ( + Autocalibration13Head, + ConvTransformerTokensToEmbeddingNeck, + PrithviEncoder, + HybridIrradianceModel, +) + + +class VirtualEVE(BaseModule): + + def __init__( + self, + # Backbone parameters + # img_size: int = 512, + # patch_size: int = 16, + # embed_dim: int = 128, + # num_frames: int = 5, + # Neck parameters + num_neck_filters: int = 32, + # Head parameters + # d_input=None, + cnn_model: str = "efficientnet_b3", + lr_linear: float = 0.01, + lr_cnn: float = 0.0001, + cnn_dp: float = 0.75, + epochs_linear: int = 50, + d_output=None, + eve_norm=None, + # for finetuning + backbone: object = None, + freeze_encoder: bool = True, + # all else + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.encoder = backbone + + if freeze_encoder: + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + + # HEAD + self.head = HybridIrradianceModel( + # virtual eve + d_input=num_neck_filters, + d_output=d_output, + eve_norm=eve_norm, + # from config + cnn_model=cnn_model, + lr_linear=lr_linear, + lr_cnn=lr_cnn, + cnn_dp=cnn_dp, + epochs_linear=epochs_linear, + # general - might need to be ported over correctly + # optimiser=self.cfg.model.opt.optimiser, + # lr=self.cfg.model.opt.learning_rate, + # weight_decay=self.cfg.model.opt.weight_decay, + ) + + def training_step(self, batch, batch_idx): + image_stack, eve = batch + x = self.encoder.forward_encode(image_stack) # imgs[:, :9, :, :, :] + embeddings = self.encoder.forward_from_embeddings(x).reshape(-1) + y_hat = self.head(embeddings) + loss = self.head.loss_func(y_hat, eve[:, :38]) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + image_stack, eve = batch + x = self.encoder.forward_encode(image_stack) # imgs[:, :9, :, :, :] + embeddings = self.encoder.forward_from_embeddings(x).reshape(-1) + y_hat = self.head(embeddings) + loss = self.head.loss_func(y_hat, eve[:, 
:38])
+        self.log("val_loss", loss)
+        return loss
diff --git a/sdofm/models/__init__.py b/sdofm/models/__init__.py
index 9b39ea5..79be1a9 100755
--- a/sdofm/models/__init__.py
+++ b/sdofm/models/__init__.py
@@ -6,3 +6,4 @@
 from .prithvi_encoders import *
 from .samae3d import SolarAwareMaskedAutoencoderViT3D
 from .virtualeve import HybridIrradianceModel
+from .unet import UNet
\ No newline at end of file
diff --git a/sdofm/pretraining/BrightSpots.py b/sdofm/pretraining/BrightSpots.py
new file mode 100644
index 0000000..af029c0
--- /dev/null
+++ b/sdofm/pretraining/BrightSpots.py
@@ -0,0 +1,44 @@
+import lightning.pytorch as pl
+import torch.nn.functional as F
+import torch
+from ..BaseModule import BaseModule
+from ..models import MaskedAutoencoderViT3D
+from ..benchmarks import reconstruction as bench_recon
+from sdofm.constants import ALL_WAVELENGTHS
+
+from sdofm.models import UNet
+
+
+class BrightSpots(BaseModule):
+    def __init__(
+        self,
+        # backbone specific
+        n_channels: int = 12,
+        n_classes: int = 1,
+        bilinear: bool = True,
+        use_embeddings_block: bool = True,
+        size_factor: int = 4,
+        # pass to BaseModule
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.n_channels = n_channels
+
+        self.model = UNet(
+            n_channels=n_channels, n_classes=n_classes, bilinear=bilinear, use_embeddings_block=use_embeddings_block, size_factor=size_factor
+        )
+
+    def training_step(self, batch, batch_idx):
+        image_stack, bright_spots = batch['image_stack'], batch['bright_spots']
+        y_hat = self.model.forward(image_stack).repeat_interleave(self.n_channels, 1)
+        loss = torch.sqrt(F.mse_loss(y_hat, bright_spots))
+        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        image_stack, bright_spots = batch['image_stack'], batch['bright_spots']
+        x_hat = self.model.forward(image_stack).repeat_interleave(self.n_channels, 1)
+        loss = torch.sqrt(F.mse_loss(x_hat, bright_spots))
+        self.log("val_loss", loss, sync_dist=True)
+        return loss
diff --git a/sdofm/pretraining/__init__.py b/sdofm/pretraining/__init__.py
index 5d5db52..5c650f9 100644
--- a/sdofm/pretraining/__init__.py
+++ b/sdofm/pretraining/__init__.py
@@ -1,3 +1,4 @@
 from .MAE import MAE
 from .NVAE import NVAE
 from .SAMAE import SAMAE
+from .BrightSpots import BrightSpots
\ No newline at end of file
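
A minimal sketch of the bright-spots pretraining objective added in sdofm/pretraining/BrightSpots.py: the UNet's single-channel prediction is repeated across the input channels with repeat_interleave and scored with an RMSE loss. The snippet illustrates only that objective; the toy Conv2d standing in for sdofm.models.UNet and the random tensors standing in for SDOML batches are assumptions made purely to keep the example self-contained and runnable.

import torch
import torch.nn.functional as F

# Shapes follow model.brightspots in pretrain_brightspots.yaml (n_channels=12, n_classes=1);
# a small Conv2d is used here only as a stand-in for sdofm.models.UNet.
n_channels, n_classes, size = 12, 1, 64
net = torch.nn.Conv2d(n_channels, n_classes, kernel_size=3, padding=1)

image_stack = torch.randn(2, n_channels, size, size)   # stand-in for a batch of stacked SDO channels
bright_spots = torch.rand(2, n_channels, size, size)   # stand-in for per-channel bright-spot targets

# Single-channel prediction broadcast back across the channels and scored with RMSE,
# mirroring BrightSpots.training_step above.
y_hat = net(image_stack).repeat_interleave(n_channels, dim=1)
loss = torch.sqrt(F.mse_loss(y_hat, bright_spots))
print(loss.item())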