Add NVAE configuration and rename MAE notebook
dead-water committed Apr 15, 2024
1 parent dafde4a commit 421bb00
Showing 3 changed files with 177 additions and 3 deletions.
45 changes: 42 additions & 3 deletions experiments/default.yaml
@@ -1,10 +1,20 @@
# default.yaml

# MODEL SUMMARY
#   | Name        | Type                   | Params
# -------------------------------------------------------
# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
# -------------------------------------------------------
# 329 M     Trainable params
# 4.7 M     Non-trainable params
# 333 M     Total params
# 1,335.838 Total estimated model params size (MB)

# general
experiment:
  name: "default"
  project: "sdofm"
-  model: "mse"
+  model: "mae"
  task: "train" # options: train, evaluate (not implemented)
  seed: 0
  disable_cuda: false
@@ -56,12 +66,13 @@ data:
    wavelengths: null # null to select all wavelength channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
    ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
    frequency: '12min' # smallest is 12min
+    mask_with_hmi_threshold: null # None/null for no mask, float for threshold

# model configurations
model:
  # PRETRAINERS
  mae:
-    img_size: 512
+    img_size: 224
    patch_size: 16
    num_frames: 3
    tubelet_size: 1
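A note on the new mask_with_hmi_threshold key above: when set to a float rather than null, the loader can mask image pixels against the HMI magnetic-field magnitude. A minimal sketch of what such a threshold mask could look like, assuming a (3, H, W) array of stacked Bx/By/Bz components; the function name and the |B| formulation are illustrative assumptions, not code from this commit:

import numpy as np

def apply_hmi_mask(frame: np.ndarray, hmi_b: np.ndarray, threshold: float) -> np.ndarray:
    """Zero out pixels whose HMI field magnitude |B| falls below threshold."""
    b_mag = np.sqrt((hmi_b ** 2).sum(axis=0))  # |B| from stacked (Bx, By, Bz)
    return np.where(b_mag >= threshold, frame, 0.0)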
@@ -75,6 +86,34 @@ model:
    mlp_ratio: 4.0
    # norm_layer: defaults to nn.LayerNorm
    norm_pix_loss: False
+  nvae:
+    use_se: true
+    res_dist: true
+    num_x_bits: 8
+    num_latent_scales: 3 # 5
+    num_groups_per_scale: 1 # 16
+    num_latent_per_group: 1 # 10
+    ada_groups: true
+    min_groups_per_scale: 1
+    num_channels_enc: 30
+    num_channels_dec: 30
+    num_preprocess_blocks: 2 # 1
+    num_preprocess_cells: 2
+    num_cell_per_cond_enc: 2
+    num_postprocess_blocks: 2 # 1
+    num_postprocess_cells: 2
+    num_cell_per_cond_dec: 2
+    num_mixture_dec: 1
+    num_nf: 2
+    kl_anneal_portion: 0.3
+    kl_const_portion: 0.0001
+    kl_const_coeff: 0.0001
+    # learning_rate: 1e-2
+    # weight_decay: 3e-4
+    weight_decay_norm_anneal: true
+    weight_decay_norm_init: 1.
+    weight_decay_norm: 1e-2

  # FINE-TUNERS
  dimming:
    num_neck_filters: 32
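The three kl_* keys in the new nvae block control warm-up of the KL term in the VAE objective. A minimal sketch of the schedule they conventionally drive, following the pattern of the reference NVAE implementation (the standalone function below is illustrative, not this repo's code):

def kl_coeff(step: int, total_steps: int,
             kl_anneal_portion: float = 0.3,
             kl_const_portion: float = 0.0001,
             kl_const_coeff: float = 0.0001) -> float:
    """Hold the KL weight at kl_const_coeff for the first kl_const_portion
    of training, then ramp it linearly up to 1.0 over kl_anneal_portion."""
    const_steps = kl_const_portion * total_steps
    anneal_steps = kl_anneal_portion * total_steps
    return max(min((step - const_steps) / anneal_steps, 1.0), kl_const_coeff)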
@@ -90,7 +129,7 @@ model:
  scheduler_warmup: 0
  batch_size: 256
  learning_rate: 0.0001
-  weight_decay: 0.0
+  weight_decay: 3e-4 # 0.0
  optimiser: "adam"
  epochs: 4
  patience: 2
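For reference, a minimal sketch of how this opt block maps onto a torch optimiser; the builder function and the cfg access pattern are assumptions about the surrounding code, not part of this commit:

import torch

def build_optimiser(model: torch.nn.Module, opt_cfg) -> torch.optim.Optimizer:
    # opt_cfg mirrors the opt: block above, e.g. learning_rate: 0.0001, weight_decay: 3e-4
    if opt_cfg.optimiser == "adam":
        return torch.optim.Adam(model.parameters(),
                                lr=opt_cfg.learning_rate,
                                weight_decay=opt_cfg.weight_decay)
    raise ValueError(f"unsupported optimiser: {opt_cfg.optimiser}")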
135 changes: 135 additions & 0 deletions experiments/pretrain_tiny_nvae.yaml
@@ -0,0 +1,135 @@
# pretrain_tiny_nvae.yaml


# general
experiment:
  name: "pretrain_tiny_nvae"
  project: "sdofm"
  model: "nvae"
  task: "train" # options: train, evaluate (not implemented)
  seed: 0
  disable_cuda: false
  disable_wandb: false
  wandb:
    entity: "fdlx"
    group: "sdofm-phase1"
    job_type: "pretrain"
    tags: []
    notes: ""
  fold: null
  evaluate: false # skip training and only evaluate (requires checkpoint to be set)
  checkpoint: null # this is the wandb run_id of the checkpoint to load
  device: null # this is set automatically using the disable_cuda flag and torch.cuda.is_available()
  precision: 64 # 32, 64
  log_n_batches: 1000 # log every n training batches
  save_results: true # save full results to file and wandb
  accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
  distributed:
    enabled: true
    backend: "ddp"
    # nproc_per_node: 1
    # nnodes: 1
    world_size: "auto" # the "auto" option recognises the machine you are on and selects the appropriate accelerator
    # node_rank: 0
    # local_rank: 0
    # master_addr: "localhost"
    # master_port: 12345

# dataset configuration
data:
  min_date: "0000-00-00 00:00:00" # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08'
  max_date: "0000-00-00 00:00:00" # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072'
  month_splits: # non-selected months will form the training set
    # train: [1,2,3,4,5,6,7,8,9,10]
    val: [11]
    test: [12]
    holdout: []
  num_workers: 16 # set appropriately for your machine
  output_directory: "output"
  sdoml:
    base_directory: "/mnt/sdoml"
    sub_directory:
      hmi: "HMI.zarr"
      aia: "AIA.zarr"
eve: EVE_legacy.zarr"
cache: "cache"
components: null # null for select all magnetic components ["Bx", "By", "Bz"]
wavelengths: null # null for select all wavelengths channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
frequency: '12min' # smallest is 12min
mask_with_hmi_threshold: null # None/null for no mask, float for threshold

# model configurations
model:
  # PRETRAINERS
  mae:
    img_size: 224
    patch_size: 16
    num_frames: 3
    tubelet_size: 1
    in_chans: 3
    embed_dim: 1024
    depth: 24
    num_heads: 16
    decoder_embed_dim: 512
    decoder_depth: 8
    decoder_num_heads: 16
    mlp_ratio: 4.0
    # norm_layer: defaults to nn.LayerNorm
    norm_pix_loss: False
  nvae:
    use_se: true
    res_dist: true
    num_x_bits: 8
    num_latent_scales: 3 # 5
    num_groups_per_scale: 1 # 16
    num_latent_per_group: 1 # 10
    ada_groups: true
    min_groups_per_scale: 1
    num_channels_enc: 30
    num_channels_dec: 30
    num_preprocess_blocks: 2 # 1
    num_preprocess_cells: 2
    num_cell_per_cond_enc: 2
    num_postprocess_blocks: 2 # 1
    num_postprocess_cells: 2
    num_cell_per_cond_dec: 2
    num_mixture_dec: 1
    num_nf: 2
    kl_anneal_portion: 0.3
    kl_const_portion: 0.0001
    kl_const_coeff: 0.0001
    # learning_rate: 1e-2
    # weight_decay: 3e-4
    weight_decay_norm_anneal: true
    weight_decay_norm_init: 1.
    weight_decay_norm: 1e-2

  # FINE-TUNERS
  dimming:
    num_neck_filters: 32
    output_dim: 1 # not sure why this is implemented for autocorrelation, should be a scalar
    loss: "mse" # options: "mse", "heteroscedastic"
    freeze_encoder: true


# ML optimization arguments:
opt:
  loss: "mse" # options: "mae", "mse", "mape"
  scheduler: "constant" # other options: "cosine", "plateau", "exp"
  scheduler_warmup: 0
  batch_size: 1
  learning_rate: 0.0001
  weight_decay: 3e-4 # 0.0
  optimiser: "adam"
  epochs: 4
  patience: 2

# hydra configuration
hydra:
  mode: MULTIRUN
  run:
    dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S}
  sweep:
    dir: ${hydra.run.dir}
    subdir: ${hydra.job.num}
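With hydra.mode set to MULTIRUN, a config like this is normally consumed by a @hydra.main entry point, and comma-separated overrides fan out into one job per value, each written to its own ${hydra.job.num} subdirectory of the sweep dir configured above. A minimal sketch, where the script name, config_path, and body are assumptions about the repo layout:

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="experiments", config_name="pretrain_tiny_nvae", version_base=None)
def main(cfg: DictConfig) -> None:
    # the YAML above is now available as attributes, e.g.:
    print(cfg.experiment.name, cfg.model.nvae.num_latent_scales, cfg.opt.batch_size)

if __name__ == "__main__":
    main()

Invoked as, say, python main.py experiment.seed=0,1,2 this would launch three runs, one per seed.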
File renamed without changes.
