diff --git a/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py index 8fde87463..b471fdc2f 100644 --- a/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py @@ -1,12 +1,10 @@ # pylint: skip-file from __future__ import annotations -if __name__ == "__main__": - import os +from typing import cast +if __name__ == "__main__": from zetta_utils.api.v0 import * - from zetta_utils.parsing import json - from zetta_utils.training.lightning.train import _parse_spec_and_train LR = 2e-4 L1_WEIGHT_START_VAL = 0.0 @@ -22,8 +20,10 @@ EXP_VERSION = f"1.0.33_M3_M3_unet_fm{FM}-256_conv4_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" - START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -35,7 +35,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -45,22 +49,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -143,10 +175,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -195,7 +229,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -232,8 +265,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -294,10 +327,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -329,153 +362,153 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "activate_last": True, - }, - { - "@type": "UNet", - "list_num_channels": [ - [FM, FM, FM, FM*2], - [FM*2, FM*2, FM*2, FM*4], - [FM*4, FM*4, FM*4, FM*8], - [FM*8, FM*8, FM*8, FM*8], - [FM*4, FM*4, FM*4, FM*4], - [FM*2, FM*2, FM*2, FM*2], - [FM, FM, FM, FM], - ], - "downsample": { - "@type": "torch.nn.MaxPool2d", - "@mode": "partial", - "kernel_size": 2, - }, - "upsample": { - "@type": "UpConv", - "@mode": "partial", - "kernel_size": 3, - "upsampler": { - "@type": "torch.nn.Upsample", - "@mode": "partial", - "scale_factor": 2, - "mode": "nearest", - "align_corners": None, - }, - "conv": { - "@type": "torch.nn.Conv2d", - "@mode": "partial", - "padding": 1, - }, - }, - "activate_last": True, - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "unet_skip_mode": "sum", - "skips": {"0": 2}, + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "activate_last": True, + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM, FM, FM, FM * 2], + [FM * 2, FM * 2, FM * 2, FM * 4], + [FM * 4, FM * 4, FM * 4, FM * 8], + [FM * 8, FM * 8, FM * 8, FM * 8], + [FM * 4, FM * 4, FM * 4, FM * 4], + [FM * 2, FM * 2, FM * 2, FM * 2], + [FM, FM, FM, FM], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, }, - { + "conv": { "@type": "torch.nn.Conv2d", - "in_channels": FM, - "out_channels": 1, - "kernel_size": 1, + "@mode": "partial", + "padding": 1, }, - {"@type": "torch.nn.Tanh"}, - ], + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM, + "out_channels": 1, + "kernel_size": 1, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, - }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + {"@type": "torch.nn.Tanh"}, + ], }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + }, + } - # _parse_spec_and_train() + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 4, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x213", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py index 098fe2f58..5a82a46d1 100644 --- a/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py @@ -1,6 +1,8 @@ # pylint: skip-file from __future__ import annotations +from typing import cast + if __name__ == "__main__": import os @@ -24,8 +26,10 @@ EXP_VERSION = f"1.0.0_M3_M4_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" - START_EXP_VERSION = f"1.0.10_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.10_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -37,7 +41,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -47,22 +55,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -145,10 +181,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -197,7 +235,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -234,8 +271,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -296,10 +333,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -331,159 +368,155 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "ds_factor": DS_FACTOR, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "padding_modes": "reflect", - "activate_last": True, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "UNet", - "list_num_channels": [ - [FM, FM, FM, FM*2], - [FM*2, FM*2, FM*2, FM*4], - [FM*4, FM*4, FM*4, FM*8], - [FM*8, FM*8, FM*8, FM*8], - [FM*4, FM*4, FM*4, FM*4], - [FM*2, FM*2, FM*2, FM*2], - [FM, FM, FM, FM], - ], - "downsample": { - "@type": "torch.nn.MaxPool2d", - "@mode": "partial", - "kernel_size": 2, - }, - "upsample": { - "@type": "UpConv", - "@mode": "partial", - "kernel_size": 3, - "upsampler": { - "@type": "torch.nn.Upsample", - "@mode": "partial", - "scale_factor": 2, - "mode": "nearest", - "align_corners": None, - }, - "conv": { - "@type": "torch.nn.Conv2d", - "@mode": "partial", - "padding": 1, - }, - }, - "activate_last": True, - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "unet_skip_mode": "sum", - "skips": {"0": 2}, + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "UNet", + "list_num_channels": [ + [FM, FM, FM, FM * 2], + [FM * 2, FM * 2, FM * 2, FM * 4], + [FM * 4, FM * 4, FM * 4, FM * 8], + [FM * 8, FM * 8, FM * 8, FM * 8], + [FM * 4, FM * 4, FM * 4, FM * 4], + [FM * 2, FM * 2, FM * 2, FM * 2], + [FM, FM, FM, FM], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, }, - { + "conv": { "@type": "torch.nn.Conv2d", - "in_channels": FM, - "out_channels": CHANNELS, - "kernel_size": 1, + "@mode": "partial", + "padding": 1, }, - {"@type": "torch.nn.Tanh"}, - ], + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM, + "out_channels": CHANNELS, + "kernel_size": 1, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, + {"@type": "torch.nn.Tanh"}, + ], }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + }, + } + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 4, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) - - # _parse_spec_and_train() + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x21xi3", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py index 5a17fa794..5a90c252c 100644 --- a/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py @@ -1,6 +1,8 @@ # pylint: skip-file from __future__ import annotations +from typing import cast + if __name__ == "__main__": import os @@ -24,8 +26,10 @@ EXP_VERSION = f"1.0.14_M3_M5_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" - START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -37,7 +41,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -47,22 +55,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -145,10 +181,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -197,7 +235,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -234,8 +271,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -296,10 +333,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -331,169 +368,162 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "ds_factor": DS_FACTOR, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "padding_modes": "reflect", - "activate_last": True, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM, FM, FM, FM*2], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "UNet", - "list_num_channels": [ - [FM*2, FM*2, FM*2, FM*4], - [FM*4, FM*4, FM*4, FM*8], - [FM*8, FM*8, FM*8, FM*8], - [FM*4, FM*4, FM*4, FM*4], - [FM*2, FM*2, FM*2, FM*2], - ], - "downsample": { - "@type": "torch.nn.MaxPool2d", - "@mode": "partial", - "kernel_size": 2, - }, - "upsample": { - "@type": "UpConv", - "@mode": "partial", - "kernel_size": 3, - "upsampler": { - "@type": "torch.nn.Upsample", - "@mode": "partial", - "scale_factor": 2, - "mode": "nearest", - "align_corners": None, - }, - "conv": { - "@type": "torch.nn.Conv2d", - "@mode": "partial", - "padding": 1, - }, - }, - "activate_last": True, - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "unet_skip_mode": "sum", - "skips": {"0": 2}, + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM * 2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "UNet", + "list_num_channels": [ + [FM * 2, FM * 2, FM * 2, FM * 4], + [FM * 4, FM * 4, FM * 4, FM * 8], + [FM * 8, FM * 8, FM * 8, FM * 8], + [FM * 4, FM * 4, FM * 4, FM * 4], + [FM * 2, FM * 2, FM * 2, FM * 2], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, }, - { + "conv": { "@type": "torch.nn.Conv2d", - "in_channels": FM*2, - "out_channels": CHANNELS, - "kernel_size": 1, + "@mode": "partial", + "padding": 1, }, - {"@type": "torch.nn.Tanh"}, - ], + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM * 2, + "out_channels": CHANNELS, + "kernel_size": 1, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, + {"@type": "torch.nn.Tanh"}, + ], }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + }, + } + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 4, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) - - # _parse_spec_and_train() + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x213", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py index df88b3b4b..7ede037d8 100644 --- a/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py @@ -1,12 +1,10 @@ # pylint: skip-file from __future__ import annotations -if __name__ == "__main__": - import os +from typing import cast +if __name__ == "__main__": from zetta_utils.api.v0 import * - from zetta_utils.parsing import json - from zetta_utils.training.lightning.train import _parse_spec_and_train LR = 2e-4 L1_WEIGHT_START_VAL = 0.0 @@ -24,8 +22,10 @@ EXP_VERSION = f"1.0.15_M3_M6_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" - START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -37,7 +37,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -47,22 +51,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -145,10 +177,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -197,7 +231,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -234,8 +267,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -296,10 +329,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -331,179 +364,170 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "ds_factor": DS_FACTOR, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "padding_modes": "reflect", - "activate_last": True, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM, FM, FM, FM*2], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM*2, FM*2, FM*2, FM*4], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "UNet", - "list_num_channels": [ - [FM*4, FM*4, FM*4, FM*8], - [FM*8, FM*8, FM*8, FM*8], - [FM*4, FM*4, FM*4, FM*4], - ], - "downsample": { - "@type": "torch.nn.MaxPool2d", - "@mode": "partial", - "kernel_size": 2, - }, - "upsample": { - "@type": "UpConv", - "@mode": "partial", - "kernel_size": 3, - "upsampler": { - "@type": "torch.nn.Upsample", - "@mode": "partial", - "scale_factor": 2, - "mode": "nearest", - "align_corners": None, - }, - "conv": { - "@type": "torch.nn.Conv2d", - "@mode": "partial", - "padding": 1, - }, - }, - "activate_last": True, - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "unet_skip_mode": "sum", - "skips": {"0": 2}, + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM * 2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM * 2, FM * 2, FM * 2, FM * 4], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "UNet", + "list_num_channels": [ + [FM * 4, FM * 4, FM * 4, FM * 8], + [FM * 8, FM * 8, FM * 8, FM * 8], + [FM * 4, FM * 4, FM * 4, FM * 4], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, }, - { + "conv": { "@type": "torch.nn.Conv2d", - "in_channels": FM*4, - "out_channels": CHANNELS, - "kernel_size": 1, + "@mode": "partial", + "padding": 1, }, - {"@type": "torch.nn.Tanh"}, - ], + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM * 4, + "out_channels": CHANNELS, + "kernel_size": 1, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, - }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + {"@type": "torch.nn.Tanh"}, + ], }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + }, + } - # _parse_spec_and_train() + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x213", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py index 5f5cce385..f9684b900 100644 --- a/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py @@ -1,12 +1,12 @@ # pylint: skip-file from __future__ import annotations +from typing import cast + if __name__ == "__main__": import os from zetta_utils.api.v0 import * - from zetta_utils.parsing import json - from zetta_utils.training.lightning.train import _parse_spec_and_train LR = 2e-4 L1_WEIGHT_START_VAL = 0.0 @@ -22,10 +22,12 @@ EXP_NAME = "general_coarsener_loss" TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" - EXP_VERSION = f"1.0.3_M3_M7_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + EXP_VERSION = f"1.0.3_M3_M7_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4_repro_x2_4xt4_auto_2nodes" - START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -37,7 +39,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -47,22 +53,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -145,10 +179,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -197,7 +233,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -234,8 +269,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -296,10 +331,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -331,164 +366,151 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "ds_factor": DS_FACTOR, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "padding_modes": "reflect", - "activate_last": True, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM, FM, FM, FM*2], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM*2, FM*2, FM*2, FM*4], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM*4, FM*4, FM*4, FM*8], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM*8, FM*8, FM*8, FM*8], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 3}, - }, - { - "@type": "torch.nn.Conv2d", - "in_channels": FM*8, - "out_channels": CHANNELS, - "kernel_size": 1, - }, - {"@type": "torch.nn.Tanh"}, - ], + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM * 2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM * 2, FM * 2, FM * 2, FM * 4], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM * 4, FM * 4, FM * 4, FM * 8], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM * 8, FM * 8, FM * 8, FM * 8], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 3}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM * 8, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + }, + } + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 4, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) - - # _parse_spec_and_train() + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x213", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py index 9866eb564..ae7b18840 100644 --- a/zetta_utils/training/lightning/train.py +++ b/zetta_utils/training/lightning/train.py @@ -36,13 +36,12 @@ def distributed_available() -> bool: @builder.register("lightning_train") @typeguard.typechecked def lightning_train( - regime: pl.LightningModule, - trainer: pl.Trainer, - train_dataloader: torch.utils.data.DataLoader, - val_dataloader: Optional[torch.utils.data.DataLoader] = None, + regime: pl.LightningModule | dict[str, Any], + trainer: pl.Trainer | dict[str, Any], + train_dataloader: torch.utils.data.DataLoader | dict[str, Any], + val_dataloader: Optional[torch.utils.data.DataLoader | dict[str, Any]] = None, full_state_ckpt_path: str = "last", num_nodes: int = 1, - nproc_per_node: int = 1, retry_count: int = 3, local_run: bool = True, follow_logs: bool = True, @@ -70,7 +69,6 @@ def lightning_train( than a model checkpoint. If ``full_state_ckpt_path=="last"``, the latest checkpoint for the given experiment will be identified and loaded. :param num_nodes: Number of GPU nodes for distributed training. - :param nproc_per_node: Number of GPU workers per node. :param retry_count: Max retry count for the master train job; excludes failures due to pod distruptions. :param local_run: If True run the training locally. @@ -84,13 +82,31 @@ def lightning_train( :param resource_limits: K8s reource limits per pod. :param resource_requests: K8s resource requests per pod. """ + args_mapping = { + "regime": regime, + "trainer": trainer, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + } if local_run: - _lightning_train_local(regime, trainer, train_dataloader, val_dataloader=val_dataloader) + _lightning_train_local( + regime=regime if not isinstance(regime, dict) else builder.build(regime), + trainer=trainer if not isinstance(trainer, dict) else builder.build(trainer), + train_dataloader=train_dataloader + if not isinstance(train_dataloader, dict) + else builder.build(train_dataloader), + val_dataloader=val_dataloader + if not isinstance(val_dataloader, dict) + else builder.build(val_dataloader), + full_state_ckpt_path=full_state_ckpt_path, + ) return - assert image is not None, "Must provide a container image for remote training." - assert resource_limits is not None, "Must provide resource limits for remote training." + if image is None: + raise ValueError("Must provide a container image for remote training.") + if resource_limits is None: + raise ValueError("Must provide resource limits for remote training.") execution_id = mazepa.id_generation.get_unique_id( prefix="exec", slug_len=4, add_uuid=False, max_len=50 @@ -102,26 +118,36 @@ def lightning_train( cluster_project=cluster_project, ) - train_spec = { - "@type": "lightning_train", - "regime": builder.get_initial_builder_spec(regime), - "trainer": builder.get_initial_builder_spec(trainer), - "train_dataloader": builder.get_initial_builder_spec(train_dataloader), - "val_dataloader": builder.get_initial_builder_spec(val_dataloader), + args_mapping = { + "regime": regime, + "trainer": trainer, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + } + + train_args: dict = { "full_state_ckpt_path": full_state_ckpt_path, } - for _key in ["regime", "trainer", "train_dataloader"]: - assert train_spec[_key] is not None, f"{_key} requires builder compatible spec." + for k, v in args_mapping.items(): + if isinstance(v, dict): + # argument given as spec, use it directly + train_args[k] = v + else: + arg_spec = builder.get_initial_builder_spec(v) + if arg_spec is None: + raise RuntimeError( + f"No builder spec found for {k}. Remote training requires arguments to " + ) + train_args[k] = arg_spec _lightning_train_remote( execution_id, cluster_info=cluster_info, image=image, num_nodes=num_nodes, - nproc_per_node=nproc_per_node, retry_count=retry_count, - train_spec=train_spec, + train_args=train_args, env_vars=env_vars, follow_logs=follow_logs, host_network=num_nodes > 1, @@ -130,9 +156,9 @@ def lightning_train( ) -@builder.register("multinode_train_launch") +@builder.register("_multinode_train_launch") @typeguard.typechecked -def multinode_train_launch( +def _multinode_train_launch( execution_id: str, num_nodes: int, nproc_per_node: int, @@ -152,6 +178,8 @@ def multinode_train_launch( torch_launcher_api.elastic_launch(config, _parse_spec_and_train)() +@builder.register("_lightning_train_local") +@typeguard.typechecked def _lightning_train_local( regime: pl.LightningModule, trainer: pl.Trainer, @@ -184,19 +212,20 @@ def _lightning_train_local( def _parse_spec_and_train(): load_all_modules() - train_spec = None + train_args = None with open(os.environ["ZETTA_RUN_SPEC_PATH"], "r", encoding="utf-8") as f: - train_spec = json.load(f) + train_args = json.load(f) + logger.info(train_args) - regime = builder.build(spec=train_spec["regime"]) - trainer = builder.build(spec=train_spec["trainer"]) - train_dataloader = builder.build(spec=train_spec["train_dataloader"]) + regime = builder.build(spec=train_args["regime"]) + trainer = builder.build(spec=train_args["trainer"]) + train_dataloader = builder.build(spec=train_args["train_dataloader"]) try: - val_dataloader = builder.build(spec=train_spec["val_dataloader"]) + val_dataloader = builder.build(spec=train_args["val_dataloader"]) except KeyError: val_dataloader = None try: - full_state_ckpt_path = builder.build(spec=train_spec["full_state_ckpt_path"]) + full_state_ckpt_path = train_args["full_state_ckpt_path"] except KeyError: full_state_ckpt_path = "last" _lightning_train_local(regime, trainer, train_dataloader, val_dataloader, full_state_ckpt_path) @@ -251,9 +280,8 @@ def _lightning_train_remote( cluster_info: resource_allocation.k8s.ClusterInfo, image: str, num_nodes: int, - nproc_per_node: int, retry_count: int, - train_spec: dict, + train_args: dict, env_vars: Optional[Dict[str, str]] = None, follow_logs: Optional[bool] = False, host_network: Optional[bool] = False, @@ -265,13 +293,31 @@ def _lightning_train_remote( Creates a volume mount for `train.cue` in `/opt/zetta_utils/specs`. Runs the command `zetta run specs/train.cue` on one or more worker pods. """ + if train_args["trainer"]["accelerator"] == "gpu": + num_devices = int(resource_limits["nvidia.com/gpu"]) # type: ignore + trainer_devices = train_args["trainer"]["devices"] + if ( + isinstance(trainer_devices, int) + and trainer_devices != -1 + and trainer_devices != num_devices + ): + raise ValueError( + f"Trainer specification uses {trainer_devices} devices, " + f"while `nvidia.com/gpu` limit is {num_devices}." + ) + else: + raise NotImplementedError() if num_nodes > 1: - train_spec["@type"] = "multinode_train_launch" - train_spec["execution_id"] = execution_id - train_spec["num_nodes"] = num_nodes - train_spec["nproc_per_node"] = nproc_per_node - train_spec["trainer"]["num_nodes"] = num_nodes + train_args["execution_id"] = execution_id + train_args["num_nodes"] = num_nodes + train_args["nproc_per_node"] = num_devices + + train_args["trainer"]["num_nodes"] = num_nodes + train_spec = {"@type": "_multinode_train_launch", **train_args} + else: + train_spec = {"@type": "_lightning_train_local", **train_args} + specs = {"train": train_spec} vol, mount, spec_ctx = _spec_configmap_vol_and_ctx(execution_id, cluster_info, specs) secrets, env_secret_mapping = resource_allocation.k8s.get_secrets_and_mapping( @@ -319,7 +365,9 @@ def _lightning_train_remote( # that remains the same for the duration of training # these pools must have `master-pool=true` taint # not ncessary for single node ddp so it can be scheduled on preemptibles - tolerations=_get_tolerations(role="master" if num_nodes > 1 else "worker"), + tolerations=_get_tolerations( + "worker" + ), # _get_tolerations(role="master" if num_nodes > 1 else "worker"), volumes=volumes, volume_mounts=mounts, resource_requests=resource_requests,