Added preliminary embedding export script
dead-water committed Jul 4, 2024
1 parent b0f5b3a commit 1628693
Showing 7 changed files with 435 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
wandb
output
outputs
*.tar

# aux directories
.vscode
159 changes: 159 additions & 0 deletions experiments/pretrain_32.2M_mae_tpu_2048.yaml
@@ -0,0 +1,159 @@
# default.yaml

# MODEL SUMMARY
# | Name | Type | Params
# -------------------------------------------------------
# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
# -------------------------------------------------------
# 329 M Trainable params
# 4.7 M Non-trainable params
# 333 M Total params
# 1,335.838 Total estimated model params size (MB)

# general
log_level: 'DEBUG'
experiment:
name: "mae-2011" # generate random name in wandb
project: "sdofm"
task: "pretrain" # options: train, evaluate (not implemented)
model: "mae"
backbone_checkpoint: null
resuming: false
seed: 0
disable_cuda: false
wandb:
enable: true
entity: "fdlx"
group: "sdofm-phase1"
job_type: "pretrain"
tags: []
notes: ""
output_directory: "wandb_output"
log_model: "all" # can be True (final checkpoint), False (no checkpointing), or "all" (for all epochs)
gcp_storage: # this will checkpoint all epochs; perhaps clean up this config
enabled: true
bucket: "sdofm-checkpoints"
fold: null
evaluate: false # skip training and only evaluate (requires checkpoint to be set)
checkpoint: null # this is the wandb run_id of the checkpoint to load
device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
precision: 'bf16-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
log_n_batches: 1000 # log every n training batches
save_results: true # save full results to file and wandb
accelerator: "tpu" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
profiler: null #'XLAProfiler' # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler Warning: XLA for TPUs only works on single world size
distributed:
enabled: true
world_size: "auto" # The "auto" option recognizes the machine you are on, and selects the appropriate number of accelerators.
strategy: "auto"
log_every_n_steps: 25

# dataset configuration
data:
min_date: '2011-01-01 00:00:00.00' # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08'
max_date: '2011-12-31 23:59:59.99' # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072'
month_splits: # non-selected months form the training set
# train: [1,2,3,4,5,6,7,8,9,10]
val: [11]
test: [12]
holdout: []
num_workers: 8 # set appropriately for your machine
prefetch_factor: 2
num_frames: 1 # WARNING: only read for FINETUNING; in BACKBONE pretraining the model's num_frames overrides this
# output_directory: "wandb_output"
sdoml:
base_directory: "/mnt/sdoml"
sub_directory:
hmi: "HMI.zarr"
aia: "AIA.zarr"
eve: "EVE_legacy.zarr"
cache: "cache"
components: null # null to select all magnetic components ["Bx", "By", "Bz"]
wavelengths: null # null to select all wavelength channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
frequency: '12min' # smallest is 12min
mask_with_hmi_threshold: null # None/null for no mask, float for threshold
drop_frame_dim: false

# model configurations
model:
# PRETRAINERS
mae:
img_size: 512
patch_size: 16
num_frames: 1
tubelet_size: 1
in_chans: 9
embed_dim: 2048
depth: 24
num_heads: 16
decoder_embed_dim: 512
decoder_depth: 8
decoder_num_heads: 16
mlp_ratio: 4.0
norm_layer: 'LayerNorm'
norm_pix_loss: False
masking_ratio: 0.5
samae:
# uses all parameters as in mae plus these
masking_type: "random" # 'random' or 'solar_aware'
active_region_mu_degs: 15.73
active_region_std_degs: 6.14
active_region_scale: 1.0
active_region_abs_lon_max_degs: 60
active_region_abs_lat_max_degs: 60
nvae:
use_se: true
res_dist: true
num_x_bits: 8
num_latent_scales: 3 # 5
num_groups_per_scale: 1 # 16
num_latent_per_group: 1 # 10
ada_groups: true
min_groups_per_scale: 1
num_channels_enc: 30
num_channels_dec: 30
num_preprocess_blocks: 2 # 1
num_preprocess_cells: 2
num_cell_per_cond_enc: 2
num_postprocess_blocks: 2 # 1
num_postprocess_cells: 2
num_cell_per_cond_dec: 2
num_mixture_dec: 1
num_nf: 2
kl_anneal_portion: 0.3
kl_const_portion: 0.0001
kl_const_coeff: 0.0001
# learning_rate: 1e-2
# weight_decay: 3e-4
weight_decay_norm_anneal: true
weight_decay_norm_init: 1.
weight_decay_norm: 1e-2

# FINE-TUNERS
degragation:
num_neck_filters: 32
output_dim: 1 # not sure why this is implemented for autocorrelation, should be a scalar
loss: "mse" # options: "mse", "heteroscedastic"
freeze_encoder: true

# ML optimization arguments:
opt:
loss: "mse" # options: "mae", "mse", "mape"
scheduler: "constant" #other options: "cosine", "plateau", "exp"
scheduler_warmup: 0
batch_size: 2
learning_rate: 0.0001
weight_decay: 3e-4 # 0.0
optimiser: "adam"
epochs: 100
patience: 2

# hydra configuration
hydra:
mode: RUN
# run:
# dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S}
# sweep:
# dir: ${hydra.run.dir}
# subdir: ${hydra.job.num}
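For orientation, the sketch below (not part of the commit) shows how the MAE geometry in this config plays out when the file is loaded with OmegaConf, the same way scripts/export_dataset.py loads its experiment config. The 512 // 16 patch grid is where that script's num_tokens value comes from; the extra CLS token is an assumption suggested by drop_cls_token=True in the export script.

from omegaconf import OmegaConf

cfg = OmegaConf.load("experiments/pretrain_32.2M_mae_tpu_2048.yaml")
mae = cfg.model.mae

tokens_per_side = mae.img_size // mae.patch_size   # 512 // 16 = 32
frames = mae.num_frames // mae.tubelet_size        # 1 temporal slot
patch_tokens = frames * tokens_per_side ** 2       # 1024 patch tokens per sample
seq_len = patch_tokens + 1                         # assumed +1 for a CLS token
print(tokens_per_side, seq_len, mae.embed_dim)     # 32 1025 2048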
6 changes: 3 additions & 3 deletions experiments/pretrain_nvae.yaml
@@ -36,11 +36,11 @@ experiment:
evaluate: false # skip training and only evaluate (requires checkpoint to be set)
checkpoint: null # this is the wandb run_id of the checkpoint to load
device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
- precision: "bf16-true" # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
+ precision: 32 # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
log_n_batches: 1000 # log every n training batches
save_results: true # save full results to file and wandb
accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
- profiler: 'XLAProfiler' # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler Warning: XLA for TPUs only works on single world size
+ profiler: null #'XLAProfiler' # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler Warning: XLA for TPUs only works on single world size
distributed:
enabled: true
world_size: 1 # The "auto" option recognizes the machine you are on, and selects the appropriate number of accelerators.
@@ -146,7 +146,7 @@ model:
loss: "mse" # options: "mae", "mse", "mape"
scheduler: "constant" #other options: "cosine", "plateau", "exp"
scheduler_warmup: 0
- batch_size: 5
+ batch_size: 2
learning_rate: 0.0001
weight_decay: 3e-4 # 0.0
optimiser: "adam"
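The precision, accelerator, and distributed fields edited here are the kind of values that typically feed straight into a Lightning Trainer. The wiring below is a hypothetical sketch (the trainer construction is not part of this diff), shown only to indicate why the accepted precision values differ between CUDA (32, 64) and TPU ('32-true', '16-true', 'bf16-true').

import lightning.pytorch as pl

# Hypothetical mapping from the experiment config to Trainer arguments.
trainer = pl.Trainer(
    accelerator="auto",       # cfg.experiment.accelerator ("tpu" in the MAE config above)
    precision=32,             # cfg.experiment.precision; 'bf16-true' on TPU
    devices=1,                # cfg.experiment.distributed.world_size
    log_every_n_steps=25,     # cfg.experiment.distributed.log_every_n_steps
)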
175 changes: 175 additions & 0 deletions scripts/export_dataset.py
@@ -0,0 +1,175 @@
import argparse
import codecs
import datetime
import os
import pprint
import shutil
import sys
import time
import warnings
from pathlib import Path

import numpy as np
import omegaconf
import torch
import webdataset as wds
from tqdm import tqdm

from sdofm.datasets import TimestampedSDOMLDataModule
from sdofm.models import ConvTransformerTokensToEmbeddingNeck
from sdofm.pretraining import MAE, NVAE
from sdofm.utils import days_hours_mins_secs_str


def main():
# fmt: off
parser=argparse.ArgumentParser(description="SDO Latent Dataset generation script", formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument("--experiment", "-e", help="Training configuration file", default="../experiments/pretrain", type=str, )
parser.add_argument("--model", "-m", help="Pretrained model filename", required=True, type=str)
parser.add_argument("--src_id", "-i", help="Source identifier, e.g. W&B run ID", default=None, type=str, )
parser.add_argument("--out_dir", "-o", help="Output directory to save the dataset", default="", type=str, )
parser.add_argument("--batch_size", "-b", help="Batch size", default=32, type=int)
parser.add_argument("--num_workers", "-n", help="Number of workers", default=0, type=int)
parser.add_argument("--max", help="Maximum number of samples to process", default=None, type=int)
parser.add_argument("--device", help="Compute device (cpu, cuda:0, cuda:1, etc.)", default="cpu", type=str, )
opt = parser.parse_args()
# fmt: on

print("SDO-FM Dataset Export Script, modified from work by")
print("NASA FDL-X 2023 Thermospheric Drag Team")
command_line_arguments = " ".join(sys.argv[1:])
print("Arguments:\n{}\n".format(command_line_arguments))
print("Script Config:")
pprint.pprint(vars(opt), depth=2, width=50)

# Datamodule
print("\nLoading Data:")
cfg = omegaconf.OmegaConf.load(opt.experiment)

data_module = TimestampedSDOMLDataModule(
# hmi_path=os.path.join(
# cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.hmi
# ),
hmi_path=None,
aia_path=os.path.join(
cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia
),
eve_path=None,
components=cfg.data.sdoml.components,
wavelengths=cfg.data.sdoml.wavelengths,
ions=cfg.data.sdoml.ions,
frequency=cfg.data.sdoml.frequency,
batch_size=cfg.model.opt.batch_size,
num_workers=cfg.data.num_workers,
val_months=[], # cfg.data.month_splits.val,
# fmt: off
test_months=[1,2,3,4,5,6,7,8,9,10,11,12,],
# cfg.data.month_splits.test,
# fmt: on
holdout_months=[], # =cfg.data.month_splits.holdout,
cache_dir=os.path.join(
cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.cache
),
min_date=cfg.data.min_date,
max_date=cfg.data.max_date,
num_frames=cfg.data.num_frames,
drop_frame_dim=cfg.data.drop_frame_dim,
)
data_module.setup()

print("\nLoading model:")
try:
match (cfg.experiment.model):
case "nvae":
model = NVAE.load_from_checkpoint(
opt.model, map_location=torch.device(opt.device)
)
z_dim, z_shapes = model.z_dim, model.z_shapes
latent_dim = z_dim.shape[0]
case "mae":
model = MAE.load_from_checkpoint(
opt.model, map_location=torch.device(opt.device)
)
latent_dim = cfg.model.mae.embed_dim

num_tokens = 512 // 16  # patch-grid side length: img_size // patch_size from the MAE config (= 32)
emb_decoder = ConvTransformerTokensToEmbeddingNeck(
embed_dim=latent_dim,
# output_embed_dim=32,
output_embed_dim=latent_dim,
Hp=num_tokens,
Wp=num_tokens,
drop_cls_token=True,
num_frames=1,
)
case _:
raise ValueError(
f"Model export of {cfg.experiment.model} not yet supported."
)
print("Model ready.")
except Exception as e:
print(f"Could not load model at {opt.model}")
raise e

# Begin

output = f"SDOFM-{datetime.datetime.now().strftime('%Y_%m_%d_%H%M')}-embsize_{latent_dim}-period_{data_module.min_date.strftime('%Y_%m_%d_%H%M')}_{data_module.max_date.strftime('%Y_%m_%d_%H%M')}"
if opt.src_id:
output += f"-{opt.src_id}"
output = Path(opt.out_dir) / (output + ".tar")

print(f"\nBeginning webdataset creation at {output}")
count = 0
latent_dim_checked = False

sink = wds.TarWriter(str(output))  # output already ends in .tar (appended above)

dl = data_module.test_dataloader()
pbar = tqdm(total=len(dl))
for data in dl:
pbar.update(1)

x, timestamps = data["image_stack"], data["timestamps"][0]
batch_size = x.shape[0]

match (cfg.experiment.model):
case "nvae":
z = model.encode(x)
case "mae":
x_hat, _, _ = model.autoencoder.forward_encoder(x, mask_ratio=0.0)
x_hat = x_hat.squeeze(dim=2)
z = emb_decoder(x_hat)

z = z.detach().cpu().numpy()
if not latent_dim_checked:
if latent_dim != z.shape[1]:
warnings.warn(
"Latent dimension mismatch: {} vs {}".format(latent_dim, z.shape[1])
)
latent_dim_checked = True

for i in range(batch_size):
sink.write(
{
"__key__": timestamps[i],
"emb.pyd": z[i],
}
)
count += 1

if opt.max is not None and count >= opt.max:
print("Stopping after {} samples".format(opt.max))
break

sink.close()


if __name__ == "__main__":
time_start = time.time()
main()
print(
"\nTotal duration: {}".format(
days_hours_mins_secs_str(time.time() - time_start)
)
)
sys.exit(0)
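As a quick check of the exported archive, the sketch below reads it back, assuming webdataset's default decoders unpickle the emb.pyd payload written above. The path is a placeholder; the script names the tar SDOFM-<created>-embsize_<dim>-period_<start>_<end>[-<src_id>].tar.

import webdataset as wds

# Placeholder path; substitute the tar produced by the export script.
ds = wds.WebDataset("path/to/SDOFM-embeddings.tar").decode()

for sample in ds:
    key = sample["__key__"]   # SDOML timestamp used as the sample key
    emb = sample["emb.pyd"]   # latent embedding for that timestamp (numpy array)
    print(key, emb.shape)
    break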
4 changes: 3 additions & 1 deletion sdofm/BaseModule.py
@@ -10,8 +10,10 @@ def __init__(
weight_decay: float = 0.0,
hyperparam_ignore=[],
# pass to pl.LightningModule
*args,
**kwargs
):
- super().__init__()
+ super().__init__(*args, **kwargs)
self.save_hyperparameters(ignore=hyperparam_ignore)

# optimiser values