Skip to content

Commit

Permalink
Added current experiment, updated distributed logging, and fixed typos
Browse files Browse the repository at this point in the history
  • Loading branch information
dead-water committed Apr 17, 2024
1 parent b485bb2 commit 0d9240f
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
experiment:
name: "default"
project: "sdofm"
model: "mae"
task: "train" # options: train, evaluate (not implemented)
model: "samae"
task: "pretrain" # options: train, evaluate (not implemented)
seed: 0
disable_cuda: false
disable_wandb: false
Expand Down Expand Up @@ -74,10 +74,10 @@ model:
mae:
img_size: 512
patch_size: 16
num_frames: 5
tubelet_size: 5
num_frames: 1
tubelet_size: 1
in_chans: 9
embed_dim: 4096
embed_dim: 128
depth: 24
num_heads: 16
decoder_embed_dim: 512
Expand Down Expand Up @@ -135,7 +135,7 @@ model:
loss: "mse" # options: "mae", "mse", "mape"
scheduler: "constant" #other options: "cosine", "plateau", "exp"
scheduler_warmup: 0
batch_size: 256
batch_size: 3
learning_rate: 0.0001
weight_decay: 3e-4 # 0.0
optimiser: "adam"
Expand Down
2 changes: 1 addition & 1 deletion scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def main(cfg: DictConfig) -> None:
config=flatten_dict(cfg),
)

match cfg.experient.task:
match cfg.experiment.task:
case "pretrain":
from scripts.pretrain import Pretrainer

Expand Down
6 changes: 3 additions & 3 deletions scripts/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, cfg, logger):
weight_decay=cfg.model.opt.weight_decay,
)
case "samae":
data_module = SDOMLDataModule(
self.data_module = SDOMLDataModule(
hmi_path=None,
aia_path=os.path.join(
cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia
Expand All @@ -74,8 +74,8 @@ def __init__(self, cfg, logger):
cfg.data.sdoml.sub_directory.cache,
),
)
data_module.setup()
model = SAMAE(
self.data_module.setup()
self.model = SAMAE(
**cfg.model.mae,
**cfg.model.samae,
optimiser=cfg.model.opt.optimiser,
Expand Down
4 changes: 2 additions & 2 deletions sdofm/pretraining/SAMAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ def training_step(self, batch, batch_idx):
loss, x_hat, mask = self.autoencoder(x)
x_hat = self.autoencoder.unpatchify(x_hat)
loss = F.mse_loss(x_hat, x)
self.log("train_loss", loss)
self.log("train_loss", loss, sync_dist=True)
return loss

def validation_step(self, batch, batch_idx):
x = batch
loss, x_hat, mask = self.autoencoder(x)
x_hat = self.autoencoder.unpatchify(x_hat)
loss = F.mse_loss(x_hat, x)
self.log("val_loss", loss)
self.log("val_loss", loss, sync_dist=True)

0 comments on commit 0d9240f

Please sign in to comment.