From 315a580464390cae871246afb59c5109f83abdea Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Tue, 5 Sep 2023 14:04:19 +0200 Subject: [PATCH] spec: base encoder spec --- .../em_encoder/train/01_m3_m3_encoder.py | 310 ++++++++++++++++++ .../lightning/regimes/alignment/__init__.py | 5 +- .../regimes/alignment/base_encoder.py | 20 +- 3 files changed, 323 insertions(+), 12 deletions(-) create mode 100644 specs/nico/training/em_encoder/train/01_m3_m3_encoder.py diff --git a/specs/nico/training/em_encoder/train/01_m3_m3_encoder.py b/specs/nico/training/em_encoder/train/01_m3_m3_encoder.py new file mode 100644 index 000000000..c72cc3a63 --- /dev/null +++ b/specs/nico/training/em_encoder/train/01_m3_m3_encoder.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import json +import os +from functools import partial + +import torch +from imgaug import augmenters as iaa +from torch.utils.data import DataLoader as TorchDataLoader + +from zetta_utils.api.v0 import * + +POST_WEIGHT = 1.6 +FIELD_MAGN_THR = 0.8 +LR = 1e-4 +EQUI_WEIGHT = 0.5 +EXP_NAME = "general_encoder_debug" +TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + +EXP_VERSION = f"0.0.0_M3_M3_lr{LR}_equi{EQUI_WEIGHT}_fmt{FIELD_MAGN_THR}" + +START_EXP_VERSION = None +MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/base_enc_zfish/{#START_EXP_VERSION}/last.ckpt" + +VALIDATION_SRC_PATH = "gs://zetta-research-nico/pairs_dsets/zfish_x1/img_pairwise/-1" +VALIDATION_TGT_PATH = "gs://zetta-research-nico/pairs_dsets/zfish_x1/img_aligned" + +SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40]}, + "microns_basil": {"contiguous": True, "resolution": [32, 32, 40]}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40]}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40]}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45]}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50]}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40]}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42]}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40]}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45]}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45]}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40]}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40]}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30]}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32]}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40]}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33]}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32]}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32]}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40]}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30]}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40]}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50]}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30]}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32]}, + "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32]}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50]}, + "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50]}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30]}, +} + +BASE_PATH = "gs://zetta-research-nico/encoder/" + + +val_img_aug = [ + partial(rearrange, pattern="c x y 1 -> c x y"), + partial(divide, value=255.0), +] + +train_img_aug = [ + partial(divide, value=255.0), + partial( + square_tile_pattern_aug, + prob=0.5, + tile_size=uniform_distr(64, 1024), + tile_stride=uniform_distr(64, 1024), + max_brightness_change=uniform_distr(0.0, 0.3), + rotation_degree=uniform_distr(0, 90), + preserve_data_val=0.0, + repeats=1, + device="cpu", + ), + partial(torch.clamp, min=0.0, max=1.0), + partial(multiply, value=255.0), + partial(to_uint8), + partial( + imgaug_readproc, + augmenters=[ + iaa.SomeOf( + n=2, + children=[ + iaa.OneOf( + children=[ + iaa.OneOf( + children=[ + iaa.GammaContrast(gamma=(0.5, 2.0)), + iaa.SigmoidContrast(gain=(4, 6), cutoff=(0.3, 0.7)), + iaa.LogContrast(gain=(0.7, 1.3)), + iaa.LinearContrast(alpha=(0.4, 1.6)), + ] + ), + iaa.AllChannelsCLAHE(clip_limit=(0.1, 8.0), tile_grid_size_px=(3, 64)), + ] + ), + iaa.Add((-40, 40)), + iaa.imgcorruptlike.DefocusBlur(severity=(1, 2)), + iaa.Cutout( + squared=False, nb_iterations=1, size=(0.05, 0.8), cval=(0, 255) + ), + iaa.JpegCompression(compression=(0, 35)), + ], + random_order=True, + ) + ], + ), +] + +shared_train_img_aug = [ + partial( + imgaug_readproc, + augmenters=[ + iaa.Sequential( + children=[ + iaa.Rot90(k=[0, 1, 2, 3]), + iaa.Fliplr(p=0.25), + iaa.Flipud(p=0.25), + ], + random_order=True, + ), + ], + ), + partial(rearrange, pattern="c x y 1 -> c x y"), + partial(divide, value=255.0), +] + + +training = JointDataset( + mode="horizontal", + datasets={ + "images": JointDataset( + mode="vertical", + datasets={ + name: LayerDataset( + layer=build_layer_set( + { + "src_img": build_cv_layer( + BASE_PATH + "datasets/" + name, read_procs=train_img_aug + ), + "tgt_img": build_cv_layer( + BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + read_procs=train_img_aug, + ), + }, + readonly=True, + read_procs=shared_train_img_aug, + ), + sample_indexer=VolumetricNGLIndexer( + path="zetta-research-nico/encoder/pairwise_aligned/" + name, + resolution=settings["resolution"], + chunk_size=[1024, 1024, 1], + ), + ) + for name, settings in SOURCE_PATHS.items() + }, + ), + "field": LayerDataset( + layer=build_cv_layer( + "gs://zetta-research-nico/perlin_noise_fields/1px", + read_procs=[ + partial(rearrange, pattern="c x y 1 -> c x y"), + ], + ), + sample_indexer=RandomIndexer( + VolumetricStridedIndexer( + bbox=BBox3D.from_coords( + start_coord=[0, 0, 0], end_coord=[2048, 2048, 2040], resolution=[4, 4, 45] + ), + stride=[128, 128, 1], + chunk_size=[1024, 1024, 1], + resolution=[4, 4, 45], + ) + ), + ), + }, +) + +validation = JointDataset( + mode="horizontal", + datasets={ + "images": LayerDataset( + layer=build_layer_set( + { + "src_img": build_cv_layer(VALIDATION_SRC_PATH, read_procs=val_img_aug), + "tgt_img": build_cv_layer(VALIDATION_TGT_PATH, read_procs=val_img_aug), + }, + readonly=True, + ), + sample_indexer=VolumetricNGLIndexer( + resolution=[32, 32, 30], chunk_size=[1024, 1024, 1], path="nkem/zfish/val" + ), + ), + "field": LayerDataset( + layer=build_cv_layer( + "gs://zetta-research-nico/perlin_noise_fields/1px", + read_procs=[ + partial(rearrange, pattern="c x y 1 -> c x y"), + ], + ), + sample_indexer=RandomIndexer( + VolumetricStridedIndexer( + bbox=BBox3D.from_coords( + start_coord=[0, 0, 0], end_coord=[2048, 2048, 2040], resolution=[4, 4, 45] + ), + stride=[512, 512, 1], + chunk_size=[1024, 1024, 1], + resolution=[4, 4, 45], + ) + ), + ), + }, +) + + +target = partial( + lightning_train, + regime=BaseEncoderRegime( + field_magn_thr=FIELD_MAGN_THR, + post_weight=POST_WEIGHT, + val_log_row_interval=4, + train_log_row_interval=500, + lr=LR, + equivar_weight=EQUI_WEIGHT, + model=load_weights_file( + model=torch.nn.Sequential( + ConvBlock( + num_channels=[1, 32], + kernel_sizes=[1, 1], + activate_last=True, + ), + UNet( + list_num_channels=[ + [32, 32, 32], + [32, 32, 32], + [32, 32, 32], + [32, 32, 32], + [32, 32, 32], + [32, 32, 32], + [32, 32, 32], + ], + downsample=partial(torch.nn.MaxPool2d, kernel_size=2), + upsample=partial( + UpConv, + kernel_size=3, + upsampler=partial( + torch.nn.Upsample, + scale_factor=2, + mode="nearest", + align_corners=None, + ), + conv=partial(torch.nn.Conv2d, padding=1), + ), + activate_last=True, + kernel_sizes=[3, 3], + padding_modes="reflect", + unet_skip_mode="sum", + skips={"0": 2}, + ), + torch.nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1), + torch.nn.Tanh(), + ) + ), + ), + trainer=ZettaDefaultTrainer( + accelerator="gpu", + devices=1, + max_epochs=10, + default_root_dir=TRAINING_ROOT, + experiment_name=EXP_NAME, + experiment_version=EXP_VERSION, + log_every_n_steps=100, + val_check_interval=500, + # track_grad_norm=2, + # gradient_clip_algorithm="norm", + # gradient_clip_val=CLIP, + # detect_anomaly=True, + # overfit_batches=100, + checkpointing_kwargs={"update_every_n_secs": 600, "backup_every_n_secs": 900}, + ), + train_dataloader=TorchDataLoader( + batch_size=1, + shuffle=True, + num_workers=8, + dataset=training, + ), + val_dataloader=TorchDataLoader( + batch_size=1, + shuffle=False, + num_workers=8, + dataset=validation, + ), +) + + +os.environ["ZETTA_RUN_SPEC"] = json.dumps("") + +execute_on_gcp_with_sqs( + target=target, + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230905", + worker_resources={"memory": "25560Mi", "nvidia.com/gpu": "1"}, + worker_replicas=1, + local_test=False, +) diff --git a/zetta_utils/training/lightning/regimes/alignment/__init__.py b/zetta_utils/training/lightning/regimes/alignment/__init__.py index 0d003796c..e07703c46 100644 --- a/zetta_utils/training/lightning/regimes/alignment/__init__.py +++ b/zetta_utils/training/lightning/regimes/alignment/__init__.py @@ -1,6 +1,3 @@ -from . import encoding_coarsener -from . import encoding_coarsener_highres -from . import encoding_coarsener_gen_x1 from . import base_encoder -from . import misalignment_detector +from . import base_coarsener from . import misalignment_detector_aced diff --git a/zetta_utils/training/lightning/regimes/alignment/base_encoder.py b/zetta_utils/training/lightning/regimes/alignment/base_encoder.py index ce7732863..9d6319b12 100644 --- a/zetta_utils/training/lightning/regimes/alignment/base_encoder.py +++ b/zetta_utils/training/lightning/regimes/alignment/base_encoder.py @@ -5,6 +5,7 @@ import attrs import cc3d +import einops import numpy as np import pytorch_lightning as pl import torch @@ -183,14 +184,17 @@ def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_nam ) f_aff = ( - tensor_ops.transform.get_affine_field( - size=src.shape[-1], - rot_deg=self.equivar_rot_deg_distr(), - scale=self.equivar_scale_distr(), - shear_x_deg=self.equivar_shear_deg_distr(), - shear_y_deg=self.equivar_shear_deg_distr(), - trans_x_px=self.equivar_trans_px_distr(), - trans_y_px=self.equivar_trans_px_distr(), + einops.rearrange( + tensor_ops.transform.get_affine_field( + size=src.shape[-1], + rot_deg=self.equivar_rot_deg_distr(), + scale=self.equivar_scale_distr(), + shear_x_deg=self.equivar_shear_deg_distr(), + shear_y_deg=self.equivar_shear_deg_distr(), + trans_x_px=self.equivar_trans_px_distr(), + trans_y_px=self.equivar_trans_px_distr(), + ), + "C X Y Z -> Z C X Y", ) .pixels() .to(seed_field.device)