Added preliminary embedding export script

1 parent b0f5b3a, commit 1628693
Showing 7 changed files with 435 additions and 5 deletions.
@@ -1,6 +1,7 @@
wandb
output
outputs
*.tar

# aux directories
.vscode

@@ -0,0 +1,159 @@
# default.yaml

# MODEL SUMMARY
#   | Name        | Type                   | Params
# -------------------------------------------------------
# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
# -------------------------------------------------------
# 329 M     Trainable params
# 4.7 M     Non-trainable params
# 333 M     Total params
# 1,335.838 Total estimated model params size (MB)

# general
log_level: 'DEBUG'
experiment:
  name: "mae-2011" # generate random name in wandb
  project: "sdofm"
  task: "pretrain" # options: train, evaluate (not implemented)
  model: "mae"
  backbone_checkpoint: null
  resuming: false
  seed: 0
  disable_cuda: false
  wandb:
    enable: true
    entity: "fdlx"
    group: "sdofm-phase1"
    job_type: "pretrain"
    tags: []
    notes: ""
    output_directory: "wandb_output"
    log_model: "all" # can be True (final checkpoint), False (no checkpointing), or "all" (for all epochs)
  gcp_storage: # this will checkpoint all epochs; perhaps clean up this config
    enabled: true
    bucket: "sdofm-checkpoints"
    fold: null
  evaluate: false # skip training and only evaluate (requires checkpoint to be set)
  checkpoint: null # this is the wandb run_id of the checkpoint to load
  device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
  precision: 'bf16-true' # (32, 64) for cuda, ('32-true', '16-true', 'bf16-true') for tpu
  log_n_batches: 1000 # log every n training batches
  save_results: true # save full results to file and wandb
  accelerator: "tpu" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
  profiler: null # 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler; warning: XLA for TPUs only works on single world size
  distributed:
    enabled: true
    world_size: "auto" # "auto" recognizes the machine you are on and selects the appropriate number of accelerators
    strategy: "auto"
  log_every_n_steps: 25

# dataset configuration
data:
  min_date: '2011-01-01 00:00:00.00' # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08'
  max_date: '2011-12-31 23:59:59.99' # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072'
  month_splits: # non-selected months form the training set
    # train: [1,2,3,4,5,6,7,8,9,10]
    val: [11]
    test: [12]
    holdout: []
  num_workers: 8 # set appropriately for your machine
  prefetch_factor: 2
  num_frames: 1 # WARNING: this is only read for FINETUNING; the model's num_frames overrides it for the BACKBONE
  # output_directory: "wandb_output"
  sdoml:
    base_directory: "/mnt/sdoml"
    sub_directory:
      hmi: "HMI.zarr"
      aia: "AIA.zarr"
      eve: "EVE_legacy.zarr"
      cache: "cache"
    components: null # null to select all magnetic components ["Bx", "By", "Bz"]
    wavelengths: null # null to select all wavelength channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
    ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
    frequency: '12min' # smallest is 12min
    mask_with_hmi_threshold: null # None/null for no mask, float for threshold
    drop_frame_dim: false

# model configurations
model:
  # PRETRAINERS
  mae:
    img_size: 512
    patch_size: 16
    num_frames: 1
    tubelet_size: 1
    in_chans: 9
    embed_dim: 2048
    depth: 24
    num_heads: 16
    decoder_embed_dim: 512
    decoder_depth: 8
    decoder_num_heads: 16
    mlp_ratio: 4.0
    norm_layer: 'LayerNorm'
    norm_pix_loss: False
    masking_ratio: 0.5
  samae:
    # uses all parameters as in mae plus these
    masking_type: "random" # 'random' or 'solar_aware'
    active_region_mu_degs: 15.73
    active_region_std_degs: 6.14
    active_region_scale: 1.0
    active_region_abs_lon_max_degs: 60
    active_region_abs_lat_max_degs: 60
  nvae:
    use_se: true
    res_dist: true
    num_x_bits: 8
    num_latent_scales: 3 # 5
    num_groups_per_scale: 1 # 16
    num_latent_per_group: 1 # 10
    ada_groups: true
    min_groups_per_scale: 1
    num_channels_enc: 30
    num_channels_dec: 30
    num_preprocess_blocks: 2 # 1
    num_preprocess_cells: 2
    num_cell_per_cond_enc: 2
    num_postprocess_blocks: 2 # 1
    num_postprocess_cells: 2
    num_cell_per_cond_dec: 2
    num_mixture_dec: 1
    num_nf: 2
    kl_anneal_portion: 0.3
    kl_const_portion: 0.0001
    kl_const_coeff: 0.0001
    # learning_rate: 1e-2
    # weight_decay: 3e-4
    weight_decay_norm_anneal: true
    weight_decay_norm_init: 1.
    weight_decay_norm: 1e-2

  # FINE-TUNERS
  degragation:
    num_neck_filters: 32
    output_dim: 1 # not sure why this is implemented for autocorrelation, should be a scalar
    loss: "mse" # options: "mse", "heteroscedastic"
    freeze_encoder: true

  # ML optimization arguments
  opt:
    loss: "mse" # options: "mae", "mse", "mape"
    scheduler: "constant" # other options: "cosine", "plateau", "exp"
    scheduler_warmup: 0
    batch_size: 2
    learning_rate: 0.0001
    weight_decay: 3e-4 # 0.0
    optimiser: "adam"
    epochs: 100
    patience: 2

# hydra configuration
hydra:
  mode: RUN
  # run:
  #   dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S}
  # sweep:
  #   dir: ${hydra.run.dir}
  #   subdir: ${hydra.job.num}
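For orientation (not part of this commit's diff): the export script below loads this file directly with OmegaConf rather than through the Hydra launcher, so the fields it relies on can be inspected or overridden in a few lines. A minimal sketch, with a placeholder config path:

```python
# Sketch only: load the experiment config the way the export script does, via OmegaConf.
# The path below is a placeholder for illustration, not the repository layout.
import omegaconf

cfg = omegaconf.OmegaConf.load("experiments/pretrain/default.yaml")
print(cfg.experiment.model)           # "mae"
print(cfg.model.mae.embed_dim)        # 2048, used as the latent dimension on export
print(cfg.data.sdoml.base_directory)  # "/mnt/sdoml"

# Individual values can be overridden without editing the file, e.g. a smaller batch size:
cfg.merge_with(omegaconf.OmegaConf.from_dotlist(["model.opt.batch_size=1"]))
```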
@@ -0,0 +1,175 @@
import argparse
import codecs
import datetime
import os
import pprint
import shutil
import sys
import time
import warnings
from pathlib import Path

import numpy as np
import omegaconf
import torch
import webdataset as wds
from tqdm import tqdm

from sdofm.datasets import TimestampedSDOMLDataModule
from sdofm.models import ConvTransformerTokensToEmbeddingNeck
from sdofm.pretraining import MAE, NVAE
from sdofm.utils import days_hours_mins_secs_str


def main():
    # fmt: off
    parser = argparse.ArgumentParser(description="SDO Latent Dataset generation script", formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument("--experiment", "-e", help="Training configuration file", default="../experiments/pretrain", type=str, )
    parser.add_argument("--model", "-m", help="Pretrained model filename", required=True, type=str)
    parser.add_argument("--src_id", "-i", help="Source identifier, e.g. W&B run ID", default=None, type=str, )
    parser.add_argument("--out_dir", "-o", help="Output directory to save the dataset", default="", type=str, )
    parser.add_argument("--batch_size", "-b", help="Batch size", default=32, type=int)
    parser.add_argument("--num_workers", "-n", help="Number of workers", default=0, type=int)
    parser.add_argument("--max", help="Maximum number of samples to process", default=None, type=int)
    parser.add_argument("--device", help="Compute device (cpu, cuda:0, cuda:1, etc.)", default="cpu", type=str, )
    opt = parser.parse_args()
    # fmt: on

    print("SDO-FM Dataset Export Script, modified from work by")
    print("NASA FDL-X 2023 Thermospheric Drag Team")
    command_line_arguments = " ".join(sys.argv[1:])
    print("Arguments:\n{}\n".format(command_line_arguments))
    print("Script Config:")
    pprint.pprint(vars(opt), depth=2, width=50)

    # Datamodule
    print("\nLoading Data:")
    cfg = omegaconf.OmegaConf.load(opt.experiment)

    data_module = TimestampedSDOMLDataModule(
        # hmi_path=os.path.join(
        #     cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.hmi
        # ),
        hmi_path=None,
        aia_path=os.path.join(
            cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia
        ),
        eve_path=None,
        components=cfg.data.sdoml.components,
        wavelengths=cfg.data.sdoml.wavelengths,
        ions=cfg.data.sdoml.ions,
        frequency=cfg.data.sdoml.frequency,
        batch_size=cfg.model.opt.batch_size,
        num_workers=cfg.data.num_workers,
        val_months=[],  # cfg.data.month_splits.val,
        # fmt: off
        test_months=[1,2,3,4,5,6,7,8,9,10,11,12,],
        # cfg.data.month_splits.test,
        # fmt: on
        holdout_months=[],  # cfg.data.month_splits.holdout,
        cache_dir=os.path.join(
            cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.cache
        ),
        min_date=cfg.data.min_date,
        max_date=cfg.data.max_date,
        num_frames=cfg.data.num_frames,
        drop_frame_dim=cfg.data.drop_frame_dim,
    )
    data_module.setup()

    print("\nLoading model:")
    try:
        match (cfg.experiment.model):
            case "nvae":
                model = NVAE.load_from_checkpoint(
                    opt.model, map_location=torch.device(opt.device)
                )
                z_dim, z_shapes = model.z_dim, model.z_shapes
                latent_dim = z_dim.shape[0]
            case "mae":
                model = MAE.load_from_checkpoint(
                    opt.model, map_location=torch.device(opt.device)
                )
                latent_dim = cfg.model.mae.embed_dim

                num_tokens = 512 // 16
                emb_decoder = ConvTransformerTokensToEmbeddingNeck(
                    embed_dim=latent_dim,
                    # output_embed_dim=32,
                    output_embed_dim=latent_dim,
                    Hp=num_tokens,
                    Wp=num_tokens,
                    drop_cls_token=True,
                    num_frames=1,
                )
            case _:
                raise ValueError(
                    f"Model export of {cfg.experiment.model} not yet supported."
                )
        print("Model ready.")
    except Exception as e:
        print(f"Could not load model at {opt.model}")
        raise e

    # Begin

    output = f"SDOFM-{datetime.datetime.now().strftime('%Y_%m_%d_%H%M')}-embsize_{latent_dim}-period_{data_module.min_date.strftime('%Y_%m_%d_%H%M')}_{data_module.max_date.strftime('%Y_%m_%d_%H%M')}"
    if opt.src_id:
        output += f"-{opt.src_id}"
    output = Path(opt.out_dir) / (output + ".tar")

    print(f"\nBeginning webdataset creation at {output}")
    count = 0
    latent_dim_checked = False

    sink = wds.TarWriter(str(output))  # output already carries the .tar suffix

    dl = data_module.test_dataloader()
    pbar = tqdm(total=len(dl))
    for data in dl:
        pbar.update(1)

        x, timestamps = data["image_stack"], data["timestamps"][0]
        batch_size = x.shape[0]

        match (cfg.experiment.model):
            case "nvae":
                z = model.encode(x)
            case "mae":
                x_hat, _, _ = model.autoencoder.forward_encoder(x, mask_ratio=0.0)
                x_hat = x_hat.squeeze(dim=2)
                z = emb_decoder(x_hat)

        z = z.detach().cpu().numpy()
        if not latent_dim_checked:
            if latent_dim != z.shape[1]:
                warnings.warn(
                    "Latent dimension mismatch: {} vs {}".format(latent_dim, z.shape[1])
                )
            latent_dim_checked = True

        for i in range(batch_size):
            sink.write(
                {
                    "__key__": timestamps[i],
                    "emb.pyd": z[i],
                }
            )
            count += 1

        if opt.max is not None and count >= opt.max:
            print("Stopping after {} samples".format(opt.max))
            break

    sink.close()


if __name__ == "__main__":
    time_start = time.time()
    main()
    print(
        "\nTotal duration: {}".format(
            days_hours_mins_secs_str(time.time() - time_start)
        )
    )
    sys.exit(0)
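As a usage sketch (not part of the commit): the resulting archive can be read back with webdataset, whose default decoders unpickle `.pyd` entries, so each sample yields the timestamp key plus its embedding array. The filename below is hypothetical:

```python
# Sketch only: iterate over an exported embedding archive with webdataset.
# The filename is a placeholder; real exports follow the
# SDOFM-<datetime>-embsize_<dim>-period_<start>_<end>.tar pattern built above.
import webdataset as wds

dataset = wds.WebDataset("SDOFM-export.tar").decode()  # .decode() unpickles "emb.pyd"
for sample in dataset:
    timestamp = sample["__key__"]  # timestamp string written as the sample key
    embedding = sample["emb.pyd"]  # numpy array produced by the encoder/neck
    print(timestamp, embedding.shape)
    break  # just show the first sample
```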