Commit b485bb2 (1 parent: 77fd6ad)
Showing 2 changed files with 183 additions and 1 deletion.
@@ -0,0 +1,152 @@
# default.yaml

# MODEL SUMMARY
#   | Name        | Type                   | Params
# -------------------------------------------------------
# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
# -------------------------------------------------------
# 329 M     Trainable params
# 4.7 M     Non-trainable params
# 333 M     Total params
# 1,335.838 Total estimated model params size (MB)

# general
experiment:
  name: "default"
  project: "sdofm"
  model: "mae"
  task: "train" # options: train, evaluate (not implemented)
  seed: 0
  disable_cuda: false
  disable_wandb: false
  wandb:
    entity: "fdlx"
    group: "sdofm-phase1"
    job_type: "pretrain"
    tags: []
    notes: ""
  fold: null
  evaluate: false # skip training and only evaluate (requires checkpoint to be set)
  checkpoint: null # the wandb run_id of the checkpoint to load
  device: null # set automatically from the disable_cuda flag and torch.cuda.is_available()
  precision: 64 # options: 32, 64
  log_n_batches: 1000 # log every n training batches
  save_results: true # save full results to file and wandb
  accelerator: "auto" # options: "auto", "gpu", "tpu", "ipu", "cpu"
  distributed:
    enabled: true
    backend: "ddp"
    # nproc_per_node: 1
    # nnodes: 1
    world_size: "auto" # "auto" detects the machine you are on and selects the appropriate accelerator
    # node_rank: 0
    # local_rank: 0
    # master_addr: "localhost"
    # master_port: 12345

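One way these fields might be wired up, as a minimal sketch using PyTorch Lightning and its WandbLogger (the Trainer assembly here is an assumption, not the repository's actual training script):

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

def build_trainer(cfg):
    # Seed everything for reproducibility.
    pl.seed_everything(cfg.experiment.seed)
    # Resolve the device choice from disable_cuda and actual CUDA availability.
    use_cuda = torch.cuda.is_available() and not cfg.experiment.disable_cuda
    logger = False
    if not cfg.experiment.disable_wandb:
        logger = WandbLogger(
            project=cfg.experiment.project,
            entity=cfg.experiment.wandb.entity,
            group=cfg.experiment.wandb.group,
            job_type=cfg.experiment.wandb.job_type,
            tags=list(cfg.experiment.wandb.tags),
            notes=cfg.experiment.wandb.notes,
        )
    return pl.Trainer(
        accelerator=cfg.experiment.accelerator if use_cuda else "cpu",
        strategy=cfg.experiment.distributed.backend
        if cfg.experiment.distributed.enabled else "auto",
        precision=cfg.experiment.precision,  # 32 or 64, per the option above
        logger=logger,
        max_epochs=cfg.opt.epochs,
    )
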
# dataset configuration
data:
  min_date: "0000-00-00 00:00:00" # NOT IMPLEMENTED # minimum is '2010-09-09 00:00:11.08'
  max_date: "0000-00-00 00:00:00" # NOT IMPLEMENTED # maximum is '2023-05-26 06:36:08.072'
  month_splits: # non-selected months form the training set
    # train: [1,2,3,4,5,6,7,8,9,10]
    val: [11]
    test: [12]
    holdout: []
  num_workers: 16 # set appropriately for your machine
  output_directory: "output"
  sdoml:
    base_directory: "/mnt/sdoml"
    sub_directory:
      hmi: "HMI.zarr"
      aia: "AIA.zarr"
      eve: "EVE_legacy.zarr"
      cache: "cache"
    components: null # null to select all magnetic components ["Bx", "By", "Bz"]
    wavelengths: null # null to select all wavelength channels ["131A","1600A","1700A","171A","193A","211A","304A","335A","94A"]
    ions: null # null to select all ion channels ["C III", "Fe IX", "Fe VIII", "Fe X", "Fe XI", "Fe XII", "Fe XIII", "Fe XIV", "Fe XIX", "Fe XV", "Fe XVI", "Fe XVIII", "Fe XVI_2", "Fe XX", "Fe XX_2", "Fe XX_3", "H I", "H I_2", "H I_3", "He I", "He II", "He II_2", "He I_2", "Mg IX", "Mg X", "Mg X_2", "Ne VII", "Ne VIII", "O II", "O III", "O III_2", "O II_2", "O IV", "O IV_2", "O V", "O VI", "S XIV", "Si XII", "Si XII_2"]
    frequency: '12min' # smallest is 12min
    mask_with_hmi_threshold: null # None/null for no mask, float for threshold

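A minimal sketch of opening the three SDO-ML Zarr stores configured above (joining base_directory with the sub_directory entries is an assumption about the on-disk layout):

import os
import zarr

def open_sdoml_stores(cfg):
    base = cfg.data.sdoml.base_directory
    sub = cfg.data.sdoml.sub_directory
    # Open each instrument's store read-only, e.g. /mnt/sdoml/AIA.zarr.
    hmi = zarr.open_group(os.path.join(base, sub.hmi), mode="r")
    aia = zarr.open_group(os.path.join(base, sub.aia), mode="r")
    eve = zarr.open_group(os.path.join(base, sub.eve), mode="r")
    return hmi, aia, eve
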
# model configurations
model:
  # PRETRAINERS (an instantiation sketch for the mae settings follows this section)
  mae:
    img_size: 512
    patch_size: 16
    num_frames: 5
    tubelet_size: 5
    in_chans: 9
    embed_dim: 4096
    depth: 24
    num_heads: 16
    decoder_embed_dim: 512
    decoder_depth: 8
    decoder_num_heads: 16
    mlp_ratio: 4.0
    # norm_layer: defaults to nn.LayerNorm
    norm_pix_loss: false
  samae:
    # uses all parameters as in mae, plus these
    masking_type: "solar_aware" # 'random' or 'solar_aware'
    active_region_mu_degs: 15.73
    active_region_std_degs: 6.14
    active_region_scale: 1.0
    active_region_abs_lon_max_degs: 60
    active_region_abs_lat_max_degs: 60
  nvae:
    use_se: true
    res_dist: true
    num_x_bits: 8
    num_latent_scales: 3 # 5
    num_groups_per_scale: 1 # 16
    num_latent_per_group: 1 # 10
    ada_groups: true
    min_groups_per_scale: 1
    num_channels_enc: 30
    num_channels_dec: 30
    num_preprocess_blocks: 2 # 1
    num_preprocess_cells: 2
    num_cell_per_cond_enc: 2
    num_postprocess_blocks: 2 # 1
    num_postprocess_cells: 2
    num_cell_per_cond_dec: 2
    num_mixture_dec: 1
    num_nf: 2
    kl_anneal_portion: 0.3
    kl_const_portion: 0.0001
    kl_const_coeff: 0.0001
    # learning_rate: 1e-2
    # weight_decay: 3e-4
    weight_decay_norm_anneal: true
    weight_decay_norm_init: 1.0
    weight_decay_norm: 1e-2

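For scale: with patch_size 16 on a 512 by 512 frame the encoder sees (512/16)^2 = 1024 spatial patches, and tubelet_size 5 folds all 5 frames into a single temporal slot, so each sample is 1024 tokens at embed_dim 4096. A minimal sketch of feeding the mae block above into the MaskedAutoencoderViT3D named in the model summary (the import path and keyword names mirror the config keys and are assumptions, not the repository's confirmed API):

import torch.nn as nn
from sdofm.models import MaskedAutoencoderViT3D  # hypothetical import path

m = cfg.model.mae  # cfg as loaded by Hydra; see the entry point sketch at the end
autoencoder = MaskedAutoencoderViT3D(
    img_size=m.img_size,          # 512x512 input frames
    patch_size=m.patch_size,      # 16 -> 32x32 = 1024 spatial patches
    num_frames=m.num_frames,      # 5 frames per sample
    tubelet_size=m.tubelet_size,  # 5 -> one temporal slot
    in_chans=m.in_chans,          # 9 AIA wavelength channels
    embed_dim=m.embed_dim,
    depth=m.depth,
    num_heads=m.num_heads,
    decoder_embed_dim=m.decoder_embed_dim,
    decoder_depth=m.decoder_depth,
    decoder_num_heads=m.decoder_num_heads,
    mlp_ratio=m.mlp_ratio,
    norm_layer=nn.LayerNorm,      # the commented default above
    norm_pix_loss=m.norm_pix_loss,
)
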
  # FINE-TUNERS
  dimming:
    num_neck_filters: 32
    output_dim: 1 # not sure why this is implemented for autocorrelation, should be a scalar
    loss: "mse" # options: "mse", "heteroscedastic"
    freeze_encoder: true

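The dimming head's loss key names two variants; a minimal sketch, assuming the heteroscedastic case is a Gaussian negative log-likelihood where the head also predicts a log-variance (an assumption about the repository's definition):

import torch.nn.functional as F

def heteroscedastic_nll(mean, log_var, target):
    # Gaussian NLL per element: 0.5 * (log sigma^2 + (y - mu)^2 / sigma^2)
    return 0.5 * (log_var + (target - mean) ** 2 / log_var.exp()).mean()

def dimming_loss(cfg, pred, target, log_var=None):
    if cfg.model.dimming.loss == "mse":
        return F.mse_loss(pred, target)
    return heteroscedastic_nll(pred, log_var, target)
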
# ML optimization arguments:
opt:
  loss: "mse" # options: "mae", "mse", "mape"
  scheduler: "constant" # other options: "cosine", "plateau", "exp"
  scheduler_warmup: 0
  batch_size: 256
  learning_rate: 0.0001
  weight_decay: 3e-4 # 0.0
  optimiser: "adam"
  epochs: 4
  patience: 2

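A minimal sketch of materialising the opt block with torch (the string-to-class mapping and the exp gamma value are assumptions; only optimiser: "adam" appears above):

import torch

def build_optimizer(cfg, params):
    opt = torch.optim.Adam(
        params,
        lr=cfg.opt.learning_rate,
        weight_decay=cfg.opt.weight_decay,
    )
    if cfg.opt.scheduler == "cosine":
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.opt.epochs)
    elif cfg.opt.scheduler == "plateau":
        # patience here mirrors cfg.opt.patience for LR drops on a stalled metric
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=cfg.opt.patience)
    elif cfg.opt.scheduler == "exp":
        sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)
    else:  # "constant"
        sched = None
    return opt, sched
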
# hydra configuration
hydra:
  mode: MULTIRUN
  run:
    dir: ${data.output_directory}/${now:%Y-%m-%d-%H-%M-%S}
  sweep:
    dir: ${hydra.run.dir}
    subdir: ${hydra.job.num}
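
Finally, a minimal sketch of an entry point that consumes this file through Hydra (the script name, config_path, and config_name are assumptions):

import hydra
from omegaconf import DictConfig

@hydra.main(version_base=None, config_path=".", config_name="default")
def main(cfg: DictConfig) -> None:
    # All blocks above arrive as one nested DictConfig.
    print(cfg.experiment.name, cfg.model.mae.embed_dim, cfg.opt.learning_rate)

if __name__ == "__main__":
    main()

Because mode: MULTIRUN is set, recent Hydra versions treat comma-separated overrides (for example, python train.py opt.learning_rate=0.0001,0.001) as a sweep, writing each job to ${hydra.run.dir}/${hydra.job.num} per the sweep block above.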