NVAE TPU compliance
dead-water committed Jun 17, 2024
1 parent 3ec606c commit b0f5b3a
Showing 4 changed files with 727 additions and 13 deletions.
20 changes: 10 additions & 10 deletions experiments/pretrain_nvae.yaml
@@ -1,14 +1,13 @@
# default.yaml

# MODEL SUMMARY
-# | Name        | Type                   | Params
-# -------------------------------------------------------
-# 0 | autoencoder | MaskedAutoencoderViT3D | 333 M
-# -------------------------------------------------------
-# 329 M     Trainable params
-# 4.7 M     Non-trainable params
-# 333 M     Total params
-# 1,335.838 Total estimated model params size (MB)
+# | Name        | Type        | Params | Mode
+# ----------------------------------------------------
+# 0 | autoencoder | AutoEncoder | 21.4 M | train
+# ----------------------------------------------------
+# 21.4 M    Trainable params
+# 2.5 K     Non-trainable params
+# 21.4 M    Total params
+# 85.652    Total estimated model params size (MB)

# general
log_level: 'DEBUG'
@@ -41,7 +40,7 @@ experiment:
log_n_batches: 1000 # log every n training batches
save_results: true # save full results to file and wandb
accelerator: "auto" # options are "auto", "gpu", "tpu", "ipu", or "cpu"
-profiler: null # options are XLAProfiler/PyTorchProfiler. Warning: XLA profiling on TPUs only works with world size 1
+profiler: 'XLAProfiler' # options are XLAProfiler/PyTorchProfiler. Warning: XLA profiling on TPUs only works with world size 1
distributed:
enabled: true
world_size: 1 # The "auto" option recognizes the machine you are on and selects the appropriate number of accelerators.
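
Note: the XLAProfiler option enabled above maps onto Lightning's built-in profiler of the same name. A minimal sketch of wiring the YAML string into a Trainer, assuming a small illustrative helper (build_profiler is not the repository's code):

    from lightning.pytorch import Trainer
    from lightning.pytorch.profilers import PyTorchProfiler, XLAProfiler

    def build_profiler(name):
        # Map the YAML string onto a Lightning profiler instance.
        if name == "XLAProfiler":
            return XLAProfiler(port=9012)  # Lightning's default capture port
        if name == "PyTorchProfiler":
            return PyTorchProfiler()
        return None  # profiler: null disables profiling

    trainer = Trainer(
        accelerator="auto",
        devices=1,  # XLA profiling on TPUs only works with world size 1
        profiler=build_profiler("XLAProfiler"),
    )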
@@ -60,6 +59,7 @@ data:
num_workers: 16 # set appropriately for your machine
prefetch_factor: 3
num_frames: 1 # WARNING: only read for FINETUNING; in BACKBONE mode the model's num_frames overrides this
+drop_frame_dim: True
# output_directory: "wandb_output"
sdoml:
base_directory: "/mnt/sdoml"
2 changes: 2 additions & 0 deletions scripts/pretrain.py
@@ -188,6 +188,8 @@ def __init__(self, cfg, logger=None, profiler=None, is_backbone=False):
),
min_date=cfg.data.min_date,
max_date=cfg.data.max_date,
+num_frames=cfg.data.num_frames,
+drop_frame_dim=cfg.data.drop_frame_dim,
)

if cfg.experiment.resuming or is_backbone:
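
Note: drop_frame_dim in the config is forwarded to the dataset together with num_frames. A hedged sketch of what a dataset honouring these kwargs might do; the class name, shapes, and loader below are illustrative assumptions, not the repository's implementation:

    import torch
    from torch.utils.data import Dataset

    class FrameDataset(Dataset):
        def __init__(self, num_frames=1, drop_frame_dim=False):
            self.num_frames = num_frames
            # Dropping the frame dimension only makes sense for single-frame samples.
            self.drop_frame_dim = drop_frame_dim and num_frames == 1

        def __len__(self):
            return 1

        def _load(self, idx):
            # Stand-in for the real loader; (channels, frames, height, width).
            return torch.zeros(12, self.num_frames, 512, 512)

        def __getitem__(self, idx):
            x = self._load(idx)
            if self.drop_frame_dim:
                x = x.squeeze(1)  # (C, 1, H, W) -> (C, H, W)
            return x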
6 changes: 3 additions & 3 deletions sdofm/models/nvae/model.py
@@ -719,7 +719,7 @@ def forward(self, x):
def sample(self, num_samples, t):
scale_ind = 0
z0_size = [num_samples] + self.z0_size
-device = next(self.parameters()).device
+device = next(self.parameters())  # a parameter tensor; .to() below matches its device
mu = torch.zeros(z0_size).to(device)
log_sigma = torch.zeros(z0_size).to(device)
dist = Normal(mu=mu, log_sigma=log_sigma, temp=t)
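
Note: the hunk above swaps a torch.device lookup for the parameter tensor itself. Tensor.to(other) places the result on other's device (and dtype), which works unchanged on CPU, CUDA, and XLA/TPU backends, unlike a hard-coded .cuda(). A minimal, self-contained demonstration:

    import torch

    weight = torch.nn.Linear(4, 4).weight  # lives wherever the model lives
    mu = torch.zeros(2, 4).to(weight)      # same device and dtype as weight
    assert mu.device == weight.device and mu.dtype == weight.dtype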
@@ -802,10 +802,10 @@ def spectral_norm_parallel(self):
if i not in self.sr_u:
num_w, row, col = weights[i].shape
self.sr_u[i] = F.normalize(
-torch.ones(num_w, row).normal_(0, 1).cuda(), dim=1, eps=1e-3
+torch.ones(num_w, row).normal_(0, 1).to(weights[i]), dim=1, eps=1e-3
)
self.sr_v[i] = F.normalize(
-torch.ones(num_w, col).normal_(0, 1).cuda(), dim=1, eps=1e-3
+torch.ones(num_w, col).normal_(0, 1).to(weights[i]), dim=1, eps=1e-3
)
# increase the number of iterations for the first time
num_iter = 10 * self.num_power_iter
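
Note: sr_u and sr_v are the power-iteration vectors used to estimate each weight's spectral norm, and the hunk above initializes them on the weights' own device via .to(weights[i]) rather than .cuda(). A hedged sketch of one device-agnostic power-iteration step; spectral_norm_step is illustrative, not the repository's code:

    import torch
    import torch.nn.functional as F

    def spectral_norm_step(w, u):
        # w: (row, col) weight matrix; u: (row,) left singular vector estimate.
        v = F.normalize(torch.mv(w.t(), u), dim=0, eps=1e-3)  # v ~ W^T u
        u = F.normalize(torch.mv(w, v), dim=0, eps=1e-3)      # u ~ W v
        sigma = torch.dot(u, torch.mv(w, v))  # estimate of the largest singular value
        return sigma, u, v

    w = torch.randn(8, 16)
    u = F.normalize(torch.randn(8).to(w), dim=0, eps=1e-3)  # .to(w), not .cuda()
    sigma, u, v = spectral_norm_step(w, u)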
