diff --git a/.dockerignore b/.dockerignore index 314aaa37b..24a81fcb5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -137,3 +137,5 @@ venv.bak/ # Pyre type checker .pyre/ + +src/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index 96231eaba..d7506de8d 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,5 @@ dmypy.json # Pyre type checker .pyre/ + +src/ \ No newline at end of file diff --git a/docker/Dockerfile.all.p39 b/docker/Dockerfile.all.p39 index dffa0b187..ef07f66af 100644 --- a/docker/Dockerfile.all.p39 +++ b/docker/Dockerfile.all.p39 @@ -1,10 +1,10 @@ -FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime +FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime ENV DEBIAN_FRONTEND="noninteractive" RUN apt-get update \ && apt-get install -y git build-essential wget curl vim ffmpeg libsm6 libxext6 software-properties-common unixodbc-dev \ - && pip install posix-ipc \ + && pip install --no-cache-dir posix-ipc gevent \ && apt-get --purge autoremove -y build-essential \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* \ @@ -22,6 +22,6 @@ RUN apt-get update \ ENV PYTHONPATH /opt/zetta_utils WORKDIR /opt/zetta_utils ADD pyproject.toml /opt/zetta_utils/ -RUN pip install '.[modules]' +RUN pip install --no-cache-dir '.[modules]' COPY . /opt/zetta_utils/ RUN zetta --help diff --git a/pyproject.toml b/pyproject.toml index a497841cb..d87eece91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ keywords = ["neuroscience connectomics EM"] license = { text = "MIT" } urls = { Homepage = "https://github.com/zettaai/zetta_utils" } -requires-python = ">3.8,<3.11" +requires-python = ">3.8,<3.12" dependencies = [ "attrs >= 21.3", "typeguard == 4.1.5", diff --git a/specs/examples/training/ddp.cue b/specs/examples/training/ddp.cue index 8c8af902d..b176000b4 100644 --- a/specs/examples/training/ddp.cue +++ b/specs/examples/training/ddp.cue @@ -40,7 +40,7 @@ regime: { } trainer: { "@type": "ZettaDefaultTrainer" - accelerator: "cuda" + accelerator: "gpu" devices: 1 num_nodes: 32 max_epochs: 100 diff --git a/specs/nico/training/em_encoder/train/01_m3_m3_encoder_dict.py b/specs/nico/training/em_encoder/train/01_m3_m3_encoder_dict.py new file mode 100644 index 000000000..a480b0642 --- /dev/null +++ b/specs/nico/training/em_encoder/train/01_m3_m3_encoder_dict.py @@ -0,0 +1,480 @@ +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + EXP_NAME = "general_encoder_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.33_M3_M3_unet_fm{FM}-256_conv4_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": 
{"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + 
"device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": "imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + 
"resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "activate_last": True, + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM, FM, FM, FM*2], + [FM*2, FM*2, FM*2, FM*4], + [FM*4, FM*4, FM*4, FM*8], + [FM*8, FM*8, FM*8, FM*8], + [FM*4, FM*4, FM*4, FM*4], + [FM*2, FM*2, FM*2, FM*2], + [FM, FM, FM, FM], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, + }, + "conv": { + "@type": "torch.nn.Conv2d", + "@mode": "partial", + "padding": 1, + }, + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM, + "out_channels": 1, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # 
"limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/specs/nico/training/em_encoder/train/02_m3_m4_encoder_dict.py b/specs/nico/training/em_encoder/train/02_m3_m4_encoder_dict.py new file mode 100644 index 000000000..a34581810 --- /dev/null +++ b/specs/nico/training/em_encoder/train/02_m3_m4_encoder_dict.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + DS_FACTOR = 2 + CHANNELS = 1 + EXP_NAME = "general_coarsener_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.0_M3_M4_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.10_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + 
"kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, 
+ { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": "imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + 
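# LoopIndexer cycles the finite NGL index to yield a fixed-length validation epoch. +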
"desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM, FM, FM, FM*2], + [FM*2, FM*2, FM*2, FM*4], + [FM*4, FM*4, FM*4, FM*8], + [FM*8, FM*8, FM*8, FM*8], + [FM*4, FM*4, FM*4, FM*4], + [FM*2, FM*2, FM*2, FM*2], + [FM, FM, FM, FM], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, + }, + "conv": { + "@type": "torch.nn.Conv2d", + "@mode": "partial", + "padding": 1, + }, + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { 
+ "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/specs/nico/training/em_encoder/train/02_m3_m5_encoder_dict.py b/specs/nico/training/em_encoder/train/02_m3_m5_encoder_dict.py new file mode 100644 index 000000000..65447cb0a --- /dev/null +++ b/specs/nico/training/em_encoder/train/02_m3_m5_encoder_dict.py @@ -0,0 +1,498 @@ +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + DS_FACTOR = 4 + CHANNELS = 1 + EXP_NAME = "general_coarsener_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.14_M3_M5_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], 
"num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": 
"imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + 
{"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM*2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM*2, FM*2, FM*2, FM*4], + [FM*4, FM*4, FM*4, FM*8], + [FM*8, FM*8, FM*8, FM*8], + [FM*4, FM*4, FM*4, FM*4], + [FM*2, FM*2, FM*2, FM*2], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, + }, + "conv": { + "@type": "torch.nn.Conv2d", + "@mode": "partial", + "padding": 1, + }, + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM*2, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for 
settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/specs/nico/training/em_encoder/train/02_m3_m6_encoder_dict.py b/specs/nico/training/em_encoder/train/02_m3_m6_encoder_dict.py new file mode 100644 index 000000000..9f296aed6 --- /dev/null +++ b/specs/nico/training/em_encoder/train/02_m3_m6_encoder_dict.py @@ -0,0 +1,508 @@ +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + DS_FACTOR = 8 + CHANNELS = 2 + EXP_NAME = "general_coarsener_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.15_M3_M6_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": 
{"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": "imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": 
"imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": 
{ + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM*2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM*2, FM*2, FM*2, FM*4], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM*4, FM*4, FM*4, FM*8], + [FM*8, FM*8, FM*8, FM*8], + [FM*4, FM*4, FM*4, FM*4], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, + }, + "conv": { + "@type": "torch.nn.Conv2d", + "@mode": "partial", + "padding": 1, + }, + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM*4, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": 
False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/specs/nico/training/em_encoder/train/02_m3_m7_encoder_dict.py b/specs/nico/training/em_encoder/train/02_m3_m7_encoder_dict.py new file mode 100644 index 000000000..a37e4176d --- /dev/null +++ b/specs/nico/training/em_encoder/train/02_m3_m7_encoder_dict.py @@ -0,0 +1,493 @@ +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + DS_FACTOR = 16 + CHANNELS = 3 + EXP_NAME = "general_coarsener_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.3_M3_M7_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], 
"num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": "imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, 
+ "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + 
"start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM*2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM*2, FM*2, FM*2, FM*4], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM*4, FM*4, FM*4, FM*8], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM*8, FM*8, FM*8, FM*8], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 3}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM*8, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + 
os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/zetta_utils/api/v0.py b/zetta_utils/api/v0.py index 88a96511c..81a2d37d2 100644 --- a/zetta_utils/api/v0.py +++ b/zetta_utils/api/v0.py @@ -411,26 +411,12 @@ from zetta_utils.training.datasets.sample_indexers.volumetric_strided_indexer import ( VolumetricStridedIndexer, ) +from zetta_utils.training.lightning.regimes.alignment.base_coarsener import ( + BaseCoarsenerRegime, +) from zetta_utils.training.lightning.regimes.alignment.base_encoder import ( BaseEncoderRegime, ) -from zetta_utils.training.lightning.regimes.alignment.encoding_coarsener import ( - EncodingCoarsenerRegime, -) -from zetta_utils.training.lightning.regimes.alignment.encoding_coarsener_gen_x1 import ( - EncodingCoarsenerGenX1Regime, -) -from zetta_utils.training.lightning.regimes.alignment.encoding_coarsener_highres import ( - EncodingCoarsenerHighRes, - center_crop_norm, - warp_by_px, -) -from zetta_utils.training.lightning.regimes.alignment.minima_encoder import ( - MinimaEncoderRegime, -) -from zetta_utils.training.lightning.regimes.alignment.misalignment_detector import ( - MisalignmentDetectorRegime, -) from zetta_utils.training.lightning.regimes.alignment.misalignment_detector_aced import ( MisalignmentDetectorAcedRegime, ) diff --git a/zetta_utils/training/lightning/regimes/alignment/__init__.py b/zetta_utils/training/lightning/regimes/alignment/__init__.py index 0d003796c..e645e5759 100644 --- a/zetta_utils/training/lightning/regimes/alignment/__init__.py +++ b/zetta_utils/training/lightning/regimes/alignment/__init__.py @@ -1,6 +1,2 @@ -from . import encoding_coarsener -from . import encoding_coarsener_highres -from . import encoding_coarsener_gen_x1 -from . import base_encoder -from . import misalignment_detector -from . import misalignment_detector_aced +from . 
import base_encoder, misalignment_detector_aced +from .deprecated import encoding_coarsener diff --git a/zetta_utils/training/lightning/regimes/alignment/base_encoder.py b/zetta_utils/training/lightning/regimes/alignment/base_encoder.py index 157d55b92..176ee8ed7 100644 --- a/zetta_utils/training/lightning/regimes/alignment/base_encoder.py +++ b/zetta_utils/training/lightning/regimes/alignment/base_encoder.py @@ -1,38 +1,50 @@ # pragma: no cover # pylint: disable=too-many-locals +import os +from itertools import combinations +from math import log2 from typing import Optional import attrs +import cc3d import einops +import numpy as np import pytorch_lightning as pl import torch +import torchfields +import wandb +from PIL import Image as PILImage +from pytorch_lightning import seed_everything -from zetta_utils import builder, distributions, tensor_ops +from zetta_utils import builder, distributions, tensor_ops, viz -from ..common import log_results - -@builder.register("BaseEncoderRegime") +@builder.register("BaseEncoderRegime", versions="==0.0.2") @attrs.mutable(eq=False) class BaseEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors model: torch.nn.Module lr: float train_log_row_interval: int = 200 val_log_row_interval: int = 25 - field_magn_thr: float = 1 - post_weight: float = 0.5 + max_displacement_px: float = 16.0 + l1_weight_start_val: float = 0.0 + l1_weight_end_val: float = 0.0 + l1_weight_start_epoch: int = 0 + l1_weight_end_epoch: int = 0 + locality_weight: float = 1.0 + similarity_weight: float = 0.0 zero_value: float = 0 - zero_conserve_weight: float = 0.5 + ds_factor: int = 1 worst_val_loss: float = attrs.field(init=False, default=0) worst_val_sample: dict = attrs.field(init=False, factory=dict) worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) - equivar_weight: float = 1.0 equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) + empty_tissue_threshold: float = 0.4 def __attrs_pre_init__(self): super().__init__() @@ -41,199 +53,351 @@ def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) return optimizer + def log_results(self, mode: str, title_suffix: str = "", **kwargs): + if not self.logger: + return + images = [] + for k, v in kwargs.items(): + for b in range(1): + if v.dtype in (np.uint8, torch.uint8): + img = v[b].squeeze() + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + elif v.dtype in (torch.int8, np.int8): + img = v[b].squeeze().byte() + 127 + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + elif v.dtype in (torch.bool, bool): + img = v[b].squeeze().byte() * 255 + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + else: + if v.size(1) == 2 and k != "field": + img = torch.cat([v, torch.zeros_like(v[:, :1])], dim=1) + else: + img = v + 
v_min = img[b].min().round(decimals=4) + v_max = img[b].max().round(decimals=4) + images.append( + wandb.Image( + viz.rendering.Renderer()(img[b].squeeze()), + caption=f"{k}_b{b} | min: {v_min} | max: {v_max}", + ) + ) + + self.logger.log_image(f"results/{mode}_{title_suffix}_slider", images=images) + + def on_validation_epoch_start(self): # pylint: disable=no-self-use + seed_everything(42) + def on_validation_epoch_end(self): - log_results( - "val", - "worst", - **self.worst_val_sample, - ) - self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None + env_seed = os.environ.get("PL_GLOBAL_SEED") + if env_seed is not None: + seed_everything(int(env_seed) + self.current_epoch) + else: + seed_everything(None) def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ log_row = batch_idx % self.train_log_row_interval == 0 - loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) + + with torchfields.set_identity_mapping_cache(True, clear_cache=False): + loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) + return loss - def _get_warped(self, img, field): - img_warped = field.field().from_pixels()(img) - zeros_warped = field.field().from_pixels()((img == self.zero_value).float()) > 0.1 - img_warped[zeros_warped] = 0 - return img_warped, zeros_warped + def _get_warped(self, img, field=None, zero_value=None): + if zero_value is None: + zero_value = self.zero_value + img_padded = torch.nn.functional.pad(img, (1, 1, 1, 1), value=zero_value)  # 1 px zero pad joins border resin into a single connected component (labelled at the corner below) + if field is not None: + img_warped = field.from_pixels()(img) + else: + img_warped = img + + zeros_padded = img_padded == zero_value + zeros_padded_cc = np.array( + [ + cc3d.connected_components( + x.detach().squeeze().cpu().numpy(), connectivity=4 + ).reshape(zeros_padded[0].shape) + for x in zeros_padded + ] + ) - def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): - src = batch["images"]["src"] - tgt = batch["images"]["tgt"] + non_tissue_zeros_padded = zeros_padded.clone() + non_tissue_zeros_padded[ + torch.tensor(zeros_padded_cc != zeros_padded_cc.ravel()[0], device=zeros_padded.device) + ] = False # keep masking resin, restore somas in center + + if field is not None: + zeros_warped = ( + torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate") + .from_pixels() # type: ignore[attr-defined] + .sample((~zeros_padded).float(), padding_mode="border") + <= 0.95 + ) + non_tissue_zeros_warped = ( + torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate") + .from_pixels() # type: ignore[attr-defined] + .sample((~non_tissue_zeros_padded).float(), padding_mode="border") + <= 0.95 + ) + else: + zeros_warped = zeros_padded + non_tissue_zeros_warped = non_tissue_zeros_padded + + zeros_warped = torch.nn.functional.pad(zeros_warped, (-1, -1, -1, -1)) + non_tissue_zeros_warped = torch.nn.functional.pad( + non_tissue_zeros_warped, (-1, -1, -1, -1) + ) - if ((src == self.zero_value) + (tgt == self.zero_value)).bool().sum() / src.numel() > 0.4: - return None + img_warped[zeros_warped] = zero_value + return img_warped, ~zeros_warped, ~non_tissue_zeros_warped - seed_field = batch["field"] - seed_field = ( - seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) + @staticmethod + def _down_zeros_mask(zeros_mask, count): + if count <= 0: + return zeros_mask + + scale_factor = 0.5**count + return ( + torch.nn.functional.interpolate( + zeros_mask.float(), scale_factor=scale_factor, mode="bilinear" + ) + > 
0.99 ) + @staticmethod + def compute_abs_dot_product_channel_combinations(tensor): + if tensor.size(1) == 1: + return torch.zeros_like(tensor) + + channel_combinations = list(combinations(range(tensor.size(1)), 2)) + dot_products = [ + (tensor[:, i : i + 1] * tensor[:, j : j + 1]) for i, j in channel_combinations + ] + dot_product_matrix = torch.hstack(dot_products).abs() + + return dot_product_matrix + + def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): + src = batch["images"]["src_img"] + tgt = batch["images"]["tgt_img"] + + # if ( + # (src == self.zero_value) + (tgt == self.zero_value) + # ).bool().sum() / src.numel() > self.empty_tissue_threshold: + # return None # Can't return None with DDP! + + # Get random field - combination of pregenerated Perlin noise and a random affine transform + seed_field = batch["field"].field_() + f_warp = seed_field * self.max_displacement_px f_aff = ( einops.rearrange( tensor_ops.transform.get_affine_field( size=src.shape[-1], - rot_deg=self.equivar_rot_deg_distr(), - scale=self.equivar_scale_distr(), - shear_x_deg=self.equivar_shear_deg_distr(), - shear_y_deg=self.equivar_shear_deg_distr(), - trans_x_px=self.equivar_trans_px_distr(), - trans_y_px=self.equivar_trans_px_distr(), + rot_deg=self.equivar_rot_deg_distr() if mode == "train" else 90.0, + scale=self.equivar_scale_distr() if mode == "train" else 1.0, + shear_x_deg=self.equivar_shear_deg_distr() if mode == "train" else 0.0, + shear_y_deg=self.equivar_shear_deg_distr() if mode == "train" else 0.0, + trans_x_px=self.equivar_trans_px_distr() if mode == "train" else 0.0, + trans_y_px=self.equivar_trans_px_distr() if mode == "train" else 0.0, ), "C X Y Z -> Z C X Y", ) - .field() # type: ignore - .pixels() + .pixels() # type: ignore[attr-defined] .to(seed_field.device) + ).repeat_interleave(src.size(0), dim=0) + f1_transform = f_aff.from_pixels()(f_warp.from_pixels()).pixels() + + # Warp Images and Tissue mask + src_f1, _, src_nonzeros_f1 = self._get_warped( + src, f1_transform, zero_value=self.zero_value ) - f1_trans = torch.tensor(f_aff.from_pixels()(seed_field.field().from_pixels()).pixels()) - f2_trans = torch.tensor( - seed_field.field() - .from_pixels()(f1_trans.field().from_pixels()) # type: ignore - .pixels() + tgt_f1, _, tgt_nonzeros_f1 = self._get_warped( + tgt, f1_transform, zero_value=self.zero_value ) - src_f1, src_zeros_f1 = self._get_warped(src, f1_trans) - src_f2, src_zeros_f2 = self._get_warped(src, f2_trans) - tgt_f1, tgt_zeros_f1 = self._get_warped(tgt, f1_trans) + src_zeros_f1 = ~self._down_zeros_mask(src_nonzeros_f1, count=int(log2(self.ds_factor))) + tgt_zeros_f1 = ~self._down_zeros_mask(tgt_nonzeros_f1, count=int(log2(self.ds_factor))) + # Generate encodings: src, src_f1_enc, src_enc_f1, tgt_f1_enc src_enc = self.model(src) - src_enc_f1 = f1_trans.field().from_pixels()(src_enc) # type: ignore - src_f1_enc = self.model(src_f1) - - equi_diff = (src_enc_f1 - src_f1_enc).abs() - equi_loss = equi_diff[src_zeros_f1 == 0].sum() - equi_diff_map = equi_diff.clone() - equi_diff_map[src_zeros_f1] = 0 - - src_f2_enc = self.model(src_f2) + src_enc_f1 = torch.nn.functional.pad(src_enc, (1, 1, 1, 1), mode="replicate") + src_enc_f1 = ( + torch.nn.functional.pad(f1_transform, (self.ds_factor,) * 4, mode="replicate") + .from_pixels() # type: ignore[attr-defined] + .down(int(log2(self.ds_factor))) + .sample(src_enc_f1, padding_mode="border") + ) + src_enc_f1 = torch.nn.functional.pad(src_enc_f1, (-1, -1, -1, -1)) tgt_f1_enc = self.model(tgt_f1) - pre_diff 
= (src_f1_enc - tgt_f1_enc).abs() + tissue_mask = ~(tgt_zeros_f1 | src_zeros_f1) - pre_tissue_mask = ( - tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f1, width=5) == 0 - ) - pre_loss = pre_diff[..., pre_tissue_mask].sum() - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 + crop = 256 // self.ds_factor + src_f1 = src_f1[..., 256:-256, 256:-256] + tgt_f1 = tgt_f1[..., 256:-256, 256:-256] + src_enc_f1 = src_enc_f1[..., crop:-crop, crop:-crop] + tgt_f1_enc = tgt_f1_enc[..., crop:-crop, crop:-crop] + tissue_mask = tissue_mask[..., crop:-crop, crop:-crop] + tissue_mask = tissue_mask[:, :, 1:-1, 1:-1] - post_tissue_mask = ( - tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f2, width=5) == 0 + # Alignment loss: ensure that solutions even one pixel away from the local optimum + # produce a larger error than the local optimum itself + abs_error_local_opt = ( + (src_enc_f1 - tgt_f1_enc)[:, :, 1:-1, 1:-1].pow(2).sum(dim=-3, keepdim=True) + ) + abs_error_1px_shift = torch.stack( + [ + (src_enc_f1[:, :, 2:, 1:-1] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2), + (src_enc_f1[:, :, :-2, 1:-1] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2), + (src_enc_f1[:, :, 1:-1, 2:] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2), + (src_enc_f1[:, :, 1:-1, :-2] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2), + # (tgt_f1_enc[:, :, 2:, 1:-1] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2), + # (tgt_f1_enc[:, :, :-2, 1:-1] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2), + # (tgt_f1_enc[:, :, 1:-1, 2:] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2), + # (tgt_f1_enc[:, :, 1:-1, :-2] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2), + ] + ).sum(dim=-3, keepdim=True) + + # similarity_error_map = abs_error_local_opt + locality_error_map = ((abs_error_local_opt - abs_error_1px_shift + 4.0) * 0.2).pow( + 3.0 + ).sum(dim=0) / len(abs_error_1px_shift)  # cubic penalty, averaged over the four 1 px shifts + + # similarity_error_map = similarity_error_map * tissue_mask + locality_error_map = locality_error_map * tissue_mask + + # similarity_loss = similarity_error_map.sum() / similarity_error_map.size(0) + locality_loss = ( + locality_error_map.sum() + / (locality_error_map.size(0) * locality_error_map.size(1)) + * self.ds_factor + * self.ds_factor ) - post_magn_mask = seed_field.abs().max(1)[0] > self.field_magn_thr - post_magn_mask[..., 0:10, :] = 0 - post_magn_mask[..., -10:, :] = 0 - post_magn_mask[..., :, 0:10] = 0 - post_magn_mask[..., :, -10:] = 0 - - post_diff_map = (src_f2_enc - tgt_f1_enc).abs() - post_mask = post_magn_mask * post_tissue_mask - if post_mask.sum() < 256: - return None + l1_loss_map = (tgt_f1_enc.abs() + src_enc_f1.abs())[:, :, 1:-1, 1:-1] + l1_loss = ( + l1_loss_map.sum() + / (2 * tgt_f1_enc.size(0) * tgt_f1_enc.size(1)) + * self.ds_factor + * self.ds_factor + ) - post_loss = post_diff_map[..., post_mask].sum() + l1_weight_ratio = min( + 1.0, + max(0, self.current_epoch - self.l1_weight_start_epoch) + / max(1, self.l1_weight_end_epoch - self.l1_weight_start_epoch), + ) + l1_weight = ( + l1_weight_ratio * self.l1_weight_end_val + + (1.0 - l1_weight_ratio) * self.l1_weight_start_val + ) - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_mask == 0] = 0 + abs_dot_map = ( + self.compute_abs_dot_product_channel_combinations(src_enc_f1)[:, :, 1:-1, 1:-1] + * tissue_mask + ) + abs_dot_loss = abs_dot_map.sum() / (abs_dot_map.size(0) * abs_dot_map.size(1)) + abs_dot_weight = 1.0 if self.current_epoch == 0 else 0.0 - loss = pre_loss - post_loss * self.post_weight + equi_loss * self.equivar_weight - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - 
self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) + loss = ( + # similarity_loss * self.similarity_weight + locality_loss * self.locality_weight + + l1_loss * l1_weight + ) + self.log( + f"loss/{mode}", loss, on_step=True, on_epoch=True, sync_dist=True, rank_zero_only=False + ) + # self.log( + # f"loss/{mode}_similar", + # similarity_loss, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # sync_dist=True, + # ) + # self.log( + # f"loss/{mode}_locality", + # locality_loss, + # on_step=True, + # on_epoch=True, + # prog_bar=True, sync_dist=True + # ) + self.log( + f"loss/{mode}_l1_weight", + l1_weight, + on_step=False, + on_epoch=True, + prog_bar=False, + sync_dist=False, + rank_zero_only=True, + ) + self.log( + f"loss/{mode}_l1", + l1_loss, + on_step=True, + on_epoch=True, + prog_bar=False, + sync_dist=True, + rank_zero_only=False, + ) + self.log( + f"loss/{mode}_abs_dot", + abs_dot_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + rank_zero_only=False, + ) if log_row: - log_results( + self.log_results( mode, sample_name, src=src, src_enc=src_enc, src_f1=src_f1, src_enc_f1=src_enc_f1, - src_f1_enc=src_f1_enc, - src_f2_enc=src_f2_enc, tgt_f1=tgt_f1, tgt_f1_enc=tgt_f1_enc, - field=torch.tensor(seed_field), - equi_diff_map=equi_diff_map, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, - ) - return loss - - def compute_metroem_loss_old( - self, batch: dict, mode: str, log_row: bool, sample_name: str = "" - ): - src = batch["images"]["src"] - tgt = batch["images"]["tgt"] - - field = batch["field"] - - tgt_zeros = tensor_ops.mask.kornia_dilation(tgt == self.zero_value, width=3) - src_zeros = tensor_ops.mask.kornia_dilation(src == self.zero_value, width=3) - - pre_tissue_mask = (src_zeros + tgt_zeros) == 0 - if pre_tissue_mask.sum() / src.numel() < 0.4: - return None - - zero_magns = 0 - tgt_enc = self.model(tgt) - zero_magns += tgt_enc[tgt_zeros].abs().sum() - - src_warped = field.field().from_pixels()(src) - src_warped_enc = self.model(src_warped) - src_zeros_warped = field.field().from_pixels()(src_zeros.float()) > 0.1 - - zero_magns += src_warped_enc[src_zeros_warped].abs().sum() - - # src_enc = (~(field.field().from_pixels()))(src_warped_enc) - src_enc = self.model(src) - - pre_diff = (src_enc - tgt_enc).abs() - pre_loss = pre_diff[..., pre_tissue_mask].sum() - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 - - post_tissue_mask = ( - tensor_ops.mask.kornia_dilation(src_zeros_warped + tgt_zeros, width=5) == 0 - ) - post_magn_mask = field.abs().sum(1) > self.field_magn_thr - - post_magn_mask[..., 0:10, :] = 0 - post_magn_mask[..., -10:, :] = 0 - post_magn_mask[..., :, 0:10] = 0 - post_magn_mask[..., :, -10:] = 0 - post_diff_map = (src_warped_enc - tgt_enc).abs() - post_mask = post_magn_mask * post_tissue_mask - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_tissue_mask == 0] = 0 - if post_mask.sum() < 256: - return None - - post_loss = post_diff_map[..., post_mask].sum() - loss = pre_loss - post_loss * self.post_weight + zero_magns * self.zero_conserve_weight - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_zcons", zero_magns, on_step=True, 
on_epoch=True) - if log_row: - log_results( - mode, - sample_name, - src=src, - src_enc=src_enc, - src_warped_enc=src_warped_enc, - tgt=tgt, - tgt_enc=tgt_enc, - field=field, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, + field=f_warp.tensor_(), + tissue_mask=tissue_mask, + # similarity_error_map=similarity_error_map, + locality_error_map=locality_error_map, + l1_loss_map=l1_loss_map, + abs_dot_map=abs_dot_map, + weighted_loss_map=( + # similarity_error_map / similarity_error_map.size(0) * self.similarity_weight + locality_error_map / locality_error_map.size(0) * self.locality_weight + + l1_loss_map / (2 * tgt_f1_enc.size(0)) * l1_weight + + abs_dot_map / (abs_dot_map.size(0) * abs_dot_map.size(1)) * abs_dot_weight + ), ) return loss @@ -241,7 +405,8 @@ def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ log_row = batch_idx % self.val_log_row_interval == 0 sample_name = f"{batch_idx // self.val_log_row_interval}" - loss = self.compute_metroem_loss( - batch=batch, mode="val", log_row=log_row, sample_name=sample_name - ) + with torchfields.set_identity_mapping_cache(True, clear_cache=False): + loss = self.compute_metroem_loss( + batch=batch, mode="val", log_row=log_row, sample_name=sample_name + ) return loss diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder_0.0.0.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder_0.0.0.py new file mode 100644 index 000000000..cf6502d4c --- /dev/null +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder_0.0.0.py @@ -0,0 +1,247 @@ +# pragma: no cover +# pylint: disable=too-many-locals + +from typing import Optional + +import attrs +import einops +import pytorch_lightning as pl +import torch + +from zetta_utils import builder, distributions, tensor_ops + +from ...common import log_results + + +@builder.register("BaseEncoderRegime", versions="==0.0.0") +@attrs.mutable(eq=False) +class BaseEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors + model: torch.nn.Module + lr: float + train_log_row_interval: int = 200 + val_log_row_interval: int = 25 + field_magn_thr: float = 1 + post_weight: float = 0.5 + zero_value: float = 0 + zero_conserve_weight: float = 0.5 + worst_val_loss: float = attrs.field(init=False, default=0) + worst_val_sample: dict = attrs.field(init=False, factory=dict) + worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) + + equivar_weight: float = 1.0 + equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) + equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) + equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) + equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) + + def __attrs_pre_init__(self): + super().__init__() + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + def validation_epoch_end(self, _): + log_results( + "val", + "worst", + **self.worst_val_sample, + ) + self.worst_val_loss = 0 + self.worst_val_sample = {} + self.worst_val_sample_idx = None + + def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ + log_row = batch_idx % self.train_log_row_interval == 0 + loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) + return loss + + def _get_warped(self, img, field): + img_warped = 
field.field().from_pixels()(img) + zeros_warped = field.field().from_pixels()((img == self.zero_value).float()) > 0.1 + img_warped[zeros_warped] = 0 + return img_warped, zeros_warped + + def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): + src = batch["images"]["src"] + tgt = batch["images"]["tgt"] + + if ((src == self.zero_value) + (tgt == self.zero_value)).bool().sum() / src.numel() > 0.4: + return None + + seed_field = batch["field"] + seed_field = ( + seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) + ) + + f_aff = ( + einops.rearrange( + tensor_ops.transform.get_affine_field( + size=src.shape[-1], + rot_deg=self.equivar_rot_deg_distr(), + scale=self.equivar_scale_distr(), + shear_x_deg=self.equivar_shear_deg_distr(), + shear_y_deg=self.equivar_shear_deg_distr(), + trans_x_px=self.equivar_trans_px_distr(), + trans_y_px=self.equivar_trans_px_distr(), + ), + "C X Y Z -> Z C X Y", + ) + .field() # type: ignore + .pixels() + .to(seed_field.device) + ) + f1_trans = torch.tensor(f_aff.from_pixels()(seed_field.field().from_pixels()).pixels()) + f2_trans = torch.tensor( + seed_field.field() + .from_pixels()(f1_trans.field().from_pixels()) # type: ignore + .pixels() + ) + + src_f1, src_zeros_f1 = self._get_warped(src, f1_trans) + src_f2, src_zeros_f2 = self._get_warped(src, f2_trans) + tgt_f1, tgt_zeros_f1 = self._get_warped(tgt, f1_trans) + + src_enc = self.model(src) + src_enc_f1 = f1_trans.field().from_pixels()(src_enc) # type: ignore + src_f1_enc = self.model(src_f1) + + equi_diff = (src_enc_f1 - src_f1_enc).abs() + equi_loss = equi_diff[src_zeros_f1 == 0].sum() + equi_diff_map = equi_diff.clone() + equi_diff_map[src_zeros_f1] = 0 + + src_f2_enc = self.model(src_f2) + tgt_f1_enc = self.model(tgt_f1) + + pre_diff = (src_f1_enc - tgt_f1_enc).abs() + + pre_tissue_mask = ( + tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f1, width=5) == 0 + ) + pre_loss = pre_diff[..., pre_tissue_mask].sum() + pre_diff_masked = pre_diff.clone() + pre_diff_masked[..., pre_tissue_mask == 0] = 0 + + post_tissue_mask = ( + tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f2, width=5) == 0 + ) + + post_magn_mask = seed_field.abs().max(1)[0] > self.field_magn_thr + post_magn_mask[..., 0:10, :] = 0 + post_magn_mask[..., -10:, :] = 0 + post_magn_mask[..., :, 0:10] = 0 + post_magn_mask[..., :, -10:] = 0 + + post_diff_map = (src_f2_enc - tgt_f1_enc).abs() + post_mask = post_magn_mask * post_tissue_mask + if post_mask.sum() < 256: + return None + + post_loss = post_diff_map[..., post_mask].sum() + + post_diff_masked = post_diff_map.clone() + post_diff_masked[..., post_mask == 0] = 0 + + loss = pre_loss - post_loss * self.post_weight + equi_loss * self.equivar_weight + self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) + if log_row: + log_results( + mode, + sample_name, + src=src, + src_enc=src_enc, + src_f1=src_f1, + src_enc_f1=src_enc_f1, + src_f1_enc=src_f1_enc, + src_f2_enc=src_f2_enc, + tgt_f1=tgt_f1, + tgt_f1_enc=tgt_f1_enc, + field=torch.tensor(seed_field), + equi_diff_map=equi_diff_map, + post_diff_masked=post_diff_masked, + pre_diff_masked=pre_diff_masked, + ) + return loss + + def compute_metroem_loss_old( + self, batch: dict, mode: str, log_row: bool, sample_name: str = "" + ): + src = 
batch["images"]["src"] + tgt = batch["images"]["tgt"] + + field = batch["field"] + + tgt_zeros = tensor_ops.mask.kornia_dilation(tgt == self.zero_value, width=3) + src_zeros = tensor_ops.mask.kornia_dilation(src == self.zero_value, width=3) + + pre_tissue_mask = (src_zeros + tgt_zeros) == 0 + if pre_tissue_mask.sum() / src.numel() < 0.4: + return None + + zero_magns = 0 + tgt_enc = self.model(tgt) + zero_magns += tgt_enc[tgt_zeros].abs().sum() + + src_warped = field.field().from_pixels()(src) + src_warped_enc = self.model(src_warped) + src_zeros_warped = field.field().from_pixels()(src_zeros.float()) > 0.1 + + zero_magns += src_warped_enc[src_zeros_warped].abs().sum() + + # src_enc = (~(field.field().from_pixels()))(src_warped_enc) + src_enc = self.model(src) + + pre_diff = (src_enc - tgt_enc).abs() + pre_loss = pre_diff[..., pre_tissue_mask].sum() + pre_diff_masked = pre_diff.clone() + pre_diff_masked[..., pre_tissue_mask == 0] = 0 + + post_tissue_mask = ( + tensor_ops.mask.kornia_dilation(src_zeros_warped + tgt_zeros, width=5) == 0 + ) + post_magn_mask = field.abs().sum(1) > self.field_magn_thr + + post_magn_mask[..., 0:10, :] = 0 + post_magn_mask[..., -10:, :] = 0 + post_magn_mask[..., :, 0:10] = 0 + post_magn_mask[..., :, -10:] = 0 + post_diff_map = (src_warped_enc - tgt_enc).abs() + post_mask = post_magn_mask * post_tissue_mask + post_diff_masked = post_diff_map.clone() + post_diff_masked[..., post_tissue_mask == 0] = 0 + if post_mask.sum() < 256: + return None + + post_loss = post_diff_map[..., post_mask].sum() + loss = pre_loss - post_loss * self.post_weight + zero_magns * self.zero_conserve_weight + self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_zcons", zero_magns, on_step=True, on_epoch=True) + if log_row: + log_results( + mode, + sample_name, + src=src, + src_enc=src_enc, + src_warped_enc=src_warped_enc, + tgt=tgt, + tgt_enc=tgt_enc, + field=field, + post_diff_masked=post_diff_masked, + pre_diff_masked=pre_diff_masked, + ) + return loss + + def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ + log_row = batch_idx % self.val_log_row_interval == 0 + sample_name = f"{batch_idx // self.val_log_row_interval}" + + loss = self.compute_metroem_loss( + batch=batch, mode="val", log_row=log_row, sample_name=sample_name + ) + return loss diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder_0.0.1.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder_0.0.1.py new file mode 100644 index 000000000..6854cf540 --- /dev/null +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder_0.0.1.py @@ -0,0 +1,288 @@ +# pragma: no cover +# pylint: disable=too-many-locals + +from typing import Optional + +import attrs +import cc3d +import einops +import numpy as np +import pytorch_lightning as pl +import torch +import torchfields +import wandb +from PIL import Image as PILImage +from pytorch_lightning import seed_everything + +from zetta_utils import builder, distributions, tensor_ops, viz + + +@builder.register("BaseEncoderRegime", versions="==0.0.1") +@attrs.mutable(eq=False) +class BaseEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors + model: torch.nn.Module + lr: float + train_log_row_interval: int = 200 + val_log_row_interval: int = 25 + field_magn_thr: float = 1 + max_displacement_px: float = 16.0 + 
post_weight: float = 1.5 + zero_value: float = 0 + worst_val_loss: float = attrs.field(init=False, default=0) + worst_val_sample: dict = attrs.field(init=False, factory=dict) + worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) + + equivar_weight: float = 1.0 + equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) + equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) + equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) + equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) + empty_tissue_threshold: float = 0.4 + + def __attrs_pre_init__(self): + super().__init__() + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + def log_results(self, mode: str, title_suffix: str = "", **kwargs): + if self.logger is None: + return + images = [] + for k, v in kwargs.items(): + for b in range(1): + if v.dtype in (np.uint8, torch.uint8): + img = v[b].squeeze() + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + elif v.dtype in (torch.int8, np.int8): + img = v[b].squeeze().byte() + 127 + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + elif v.dtype in (torch.bool, bool): + img = v[b].squeeze().byte() * 255 + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + else: + v_min = v[b].min().round(decimals=4) + v_max = v[b].max().round(decimals=4) + images.append( + wandb.Image( + viz.rendering.Renderer()(v[b].squeeze()), + caption=f"{k}_b{b} | min: {v_min} | max: {v_max}", + ) + ) + + self.logger.log_image(f"results/{mode}_{title_suffix}_slider", images=images) + + def on_validation_epoch_start(self, _): # pylint: disable=no-self-use + seed_everything(42) + + def on_validation_epoch_end(self): + self.log_results( + "val", + "worst", + **self.worst_val_sample, + ) + self.worst_val_loss = 0 + self.worst_val_sample = {} + self.worst_val_sample_idx = None + seed_everything(None) + + def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ + log_row = batch_idx % self.train_log_row_interval == 0 + + with torchfields.set_identity_mapping_cache(True, clear_cache=False): + loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) + + return loss + + def _get_warped(self, img, field=None): + img_padded = torch.nn.functional.pad(img, (1, 1, 1, 1), value=self.zero_value) + if field is not None: + img_warped = field.from_pixels()(img) + else: + img_warped = img + + zeros_padded = img_padded == self.zero_value + zeros_padded_cc = np.array( + [ + cc3d.connected_components( + x.detach().squeeze().cpu().numpy(), connectivity=4 + ).reshape(zeros_padded[0].shape) + for x in zeros_padded + ] + ) + + non_tissue_zeros_padded = zeros_padded.clone() + non_tissue_zeros_padded[ + torch.tensor(zeros_padded_cc != zeros_padded_cc.ravel()[0], device=zeros_padded.device) + ] = False # keep masking resin, restore somas in center + + if field is not None: + zeros_warped = ( + torch.nn.functional.pad(field, (1, 1, 1, 
1), mode="replicate") + .from_pixels() + .sample((~zeros_padded).float(), padding_mode="border") + <= 0.1 + ) + non_tissue_zeros_warped = ( + torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate") + .from_pixels() + .sample((~non_tissue_zeros_padded).float(), padding_mode="border") + <= 0.1 + ) + else: + zeros_warped = zeros_padded + non_tissue_zeros_warped = non_tissue_zeros_padded + + zeros_warped = torch.nn.functional.pad(zeros_warped, (-1, -1, -1, -1)) + non_tissue_zeros_warped = torch.nn.functional.pad( + non_tissue_zeros_warped, (-1, -1, -1, -1) + ) + + img_warped[zeros_warped] = self.zero_value + return img_warped, ~zeros_warped, ~non_tissue_zeros_warped + + def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): + src = batch["images"]["src_img"] + tgt = batch["images"]["tgt_img"] + + if ( + (src == self.zero_value) + (tgt == self.zero_value) + ).bool().sum() / src.numel() > self.empty_tissue_threshold: + return None + + seed_field = batch["field"].field_() + f_warp_large = seed_field * self.max_displacement_px + f_warp_small = ( + seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) + ) + + f_aff = ( + einops.rearrange( + tensor_ops.transform.get_affine_field( + size=src.shape[-1], + rot_deg=self.equivar_rot_deg_distr(), + scale=self.equivar_scale_distr(), + shear_x_deg=self.equivar_shear_deg_distr(), + shear_y_deg=self.equivar_shear_deg_distr(), + trans_x_px=self.equivar_trans_px_distr(), + trans_y_px=self.equivar_trans_px_distr(), + ), + "C X Y Z -> Z C X Y", + ) + .pixels() + .to(seed_field.device) + ).repeat_interleave(src.size(0), dim=0) + f1_trans = f_aff.from_pixels()(f_warp_large.from_pixels()).pixels() + f2_trans = f_warp_small.from_pixels()(f1_trans.from_pixels()).pixels() + + magn_field = f_warp_small + + src_f1, _, src_nonzeros_f1 = self._get_warped(src, f1_trans) + src_f2, _, src_nonzeros_f2 = self._get_warped(src, f2_trans) + tgt_f1, _, tgt_nonzeros_f1 = self._get_warped(tgt, f1_trans) + + src_zeros_f1 = ~src_nonzeros_f1 + src_zeros_f2 = ~src_nonzeros_f2 + tgt_zeros_f1 = ~tgt_nonzeros_f1 + + src_enc = self.model(src) + src_f1_enc = self.model(src_f1) + + src_enc_f1 = torch.nn.functional.pad(src_enc, (1, 1, 1, 1), value=0.0) + src_enc_f1 = ( + torch.nn.functional.pad(f1_trans, (1, 1, 1, 1), mode="replicate") # type: ignore + .from_pixels() + .sample(src_enc_f1, padding_mode="border") + ) + src_enc_f1 = torch.nn.functional.pad(src_enc_f1, (-1, -1, -1, -1), value=0.0) + + equi_diff = (src_enc_f1 - src_f1_enc).abs() + equi_loss = equi_diff[src_zeros_f1 != 0].sum() + equi_loss = equi_diff.sum() / equi_diff.size(0) + equi_diff_map = equi_diff.clone() + equi_diff_map[src_zeros_f1] = 0 + + src_f2_enc = self.model(src_f2) + tgt_f1_enc = self.model(tgt_f1) + + pre_diff = (src_f1_enc - tgt_f1_enc).abs() + + pre_tissue_mask = ~(tgt_zeros_f1 | src_zeros_f1) + pre_loss = pre_diff[..., pre_tissue_mask].sum() / pre_diff.size(0) + pre_diff_masked = pre_diff.clone() + pre_diff_masked[..., pre_tissue_mask == 0] = 0 + + post_tissue_mask = ~(tgt_zeros_f1 | src_zeros_f2) + post_magn_mask = (magn_field.abs().max(1, keepdim=True)[0] > self.field_magn_thr).tensor_() + + post_diff_map = (src_f2_enc - tgt_f1_enc).abs() + post_mask = post_magn_mask * post_tissue_mask + if post_mask.sum() < 256: + return None + + post_loss = post_diff_map[..., post_mask].sum() / post_diff_map.size(0) + + post_diff_masked = post_diff_map.clone() + post_diff_masked[..., post_mask == 0] = 0 + + loss = pre_loss - post_loss * self.post_weight 
+ equi_loss * self.equivar_weight + self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) + if log_row: + self.log_results( + mode, + sample_name, + src=src, + src_enc=src_enc, + src_f1=src_f1, + src_enc_f1=src_enc_f1, + src_f1_enc=src_f1_enc, + src_f2_enc=src_f2_enc, + tgt_f1=tgt_f1, + tgt_f1_enc=tgt_f1_enc, + field=seed_field.tensor_(), + equi_diff_map=equi_diff_map, + post_diff_masked=post_diff_masked, + pre_diff_masked=pre_diff_masked, + ) + return loss + + def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ + log_row = batch_idx % self.val_log_row_interval == 0 + sample_name = f"{batch_idx // self.val_log_row_interval}" + + with torchfields.set_identity_mapping_cache(True, clear_cache=False): + loss = self.compute_metroem_loss( + batch=batch, mode="val", log_row=log_row, sample_name=sample_name + ) + return loss diff --git a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py similarity index 99% rename from zetta_utils/training/lightning/regimes/alignment/encoding_coarsener.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py index 6d350aa2d..d2823c8a6 100644 --- a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py @@ -14,7 +14,7 @@ from zetta_utils.training.lightning.train import distributed_available -@builder.register("EncodingCoarsenerRegime") +@builder.register("EncodingCoarsenerRegime", versions="==0.0.0") @attrs.mutable(eq=False) class EncodingCoarsenerRegime(pl.LightningModule): # pylint: disable=too-many-ancestors encoder: torch.nn.Module diff --git a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_gen_x1.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py similarity index 98% rename from zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_gen_x1.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py index 34bcaefd8..fdbd5f1fd 100644 --- a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_gen_x1.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py @@ -13,7 +13,7 @@ from zetta_utils import builder, distributions, tensor_ops, viz -@builder.register("EncodingCoarsenerGenX1Regime") +@builder.register("EncodingCoarsenerGenX1Regime", versions="==0.0.0") @attrs.mutable(eq=False) class EncodingCoarsenerGenX1Regime(pl.LightningModule): # pylint: disable=too-many-ancestors encoder: torch.nn.Module diff --git a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_highres.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py similarity index 99% rename from zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_highres.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py index 6335ef14c..08a8f495f 100644 --- a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_highres.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py @@ -60,7 
+60,7 @@ def center_crop_norm(image): return crop(norm(image)) -@builder.register("EncodingCoarsenerHighRes") +@builder.register("EncodingCoarsenerHighRes", versions="==0.0.0") @attrs.mutable(eq=False) class EncodingCoarsenerHighRes(pl.LightningModule): # pylint: disable=too-many-ancestors encoder: torch.nn.Module diff --git a/zetta_utils/training/lightning/regimes/alignment/minima_encoder.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py similarity index 99% rename from zetta_utils/training/lightning/regimes/alignment/minima_encoder.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py index a32799476..39d9e2adc 100644 --- a/zetta_utils/training/lightning/regimes/alignment/minima_encoder.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py @@ -12,7 +12,7 @@ from zetta_utils import builder, distributions, tensor_ops, viz -@builder.register("MinimaEncoderRegime") +@builder.register("MinimaEncoderRegime", versions="==0.0.0") @attrs.mutable(eq=False) class MinimaEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors model: torch.nn.Module diff --git a/zetta_utils/training/lightning/regimes/alignment/misalignment_detector.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py similarity index 99% rename from zetta_utils/training/lightning/regimes/alignment/misalignment_detector.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py index 60cba4806..920305a8f 100644 --- a/zetta_utils/training/lightning/regimes/alignment/misalignment_detector.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py @@ -13,7 +13,7 @@ from zetta_utils import builder, convnet, tensor_ops # pylint: disable=unused-import -@builder.register("MisalignmentDetectorRegime") +@builder.register("MisalignmentDetectorRegime", versions="==0.0.0") @attrs.mutable(eq=False) class MisalignmentDetectorRegime(pl.LightningModule): # pylint: disable=too-many-ancestors detector: torch.nn.Module
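# A minimal sketch of the versioned-registration pattern applied throughout this diff,
# assuming only what the diff itself shows (`versions="==X.Y.Z"` pins a registration and
# specs pick an implementation via "@version"); the names below are hypothetical.
from zetta_utils import builder


@builder.register("MyRegime", versions="==0.0.0")
class MyRegimeV000:  # frozen copy kept under deprecated/ so old checkpoints still load
    ...


@builder.register("MyRegime", versions="==0.0.1")
class MyRegimeV001:  # current implementation registered under the same spec name
    ...


# A spec carrying {"@type": "MyRegime", "@version": "0.0.1"} builds MyRegimeV001,
# while older specs keep resolving to the frozen 0.0.0 class.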