From f47a09abf0a4866290b9b25c993787c3174599f7 Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Fri, 27 Oct 2023 18:19:10 +0200 Subject: [PATCH 1/9] fix(ddp): ensure zu modules registered in DDP subprocesses --- zetta_utils/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/zetta_utils/__init__.py b/zetta_utils/__init__.py index b8aecdd4e..6eeb91ab2 100644 --- a/zetta_utils/__init__.py +++ b/zetta_utils/__init__.py @@ -61,3 +61,6 @@ def load_training_modules(): ) from .layer import volumetric from .layer.volumetric import cloudvol + + +try_load_train_inference() From 4265285e2b42e11b3f4c1c94c3aad9b88886dfe0 Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Tue, 17 Oct 2023 21:05:20 +0200 Subject: [PATCH 2/9] feat(training): custom SamplerWrapper to avoid PL DDP Sampler override --- tests/unit/training/test_sampler.py | 21 +++++++++++++++ zetta_utils/training/__init__.py | 2 +- zetta_utils/training/sampler.py | 42 +++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 tests/unit/training/test_sampler.py create mode 100644 zetta_utils/training/sampler.py diff --git a/tests/unit/training/test_sampler.py b/tests/unit/training/test_sampler.py new file mode 100644 index 000000000..a2579d4b6 --- /dev/null +++ b/tests/unit/training/test_sampler.py @@ -0,0 +1,21 @@ +from lightning_fabric import seed_everything +from torch.utils.data import RandomSampler + +from zetta_utils.training.sampler import SamplerWrapper + + +def test_sampler_wrapper(): + sampler = RandomSampler(list(range(100))) + wrapper = SamplerWrapper(sampler) + + assert len(wrapper) == len(sampler) + + wrapper.set_epoch(0) + seed_everything(42) + epoch_0 = list(wrapper) + seed_everything(42) + assert list(wrapper) == epoch_0 + + wrapper.set_epoch(1) + seed_everything(42) + assert list(wrapper) != epoch_0 diff --git a/zetta_utils/training/__init__.py b/zetta_utils/training/__init__.py index 3cf20deb9..8ffb51793 100644 --- a/zetta_utils/training/__init__.py 
+++ b/zetta_utils/training/__init__.py @@ -1 +1 @@ -from . import data_loader, datasets, lightning +from . import data_loader, datasets, lightning, sampler diff --git a/zetta_utils/training/sampler.py b/zetta_utils/training/sampler.py new file mode 100644 index 000000000..83b9625cf --- /dev/null +++ b/zetta_utils/training/sampler.py @@ -0,0 +1,42 @@ +from typing import Iterator + +import torch.utils.data + +from zetta_utils import builder + +builder.register("TorchRandomSampler")(torch.utils.data.RandomSampler) + + +# Needed for DDP + RandomSampler to work with pytorch-lightning, which +# overwrites Sequential and RandomSampler with DistributedSampler. +# With the wrapper below, it will apply its own DistributedSamplerWrapper instead. +@builder.register("SamplerWrapper") +class SamplerWrapper(torch.utils.data.Sampler[int]): + sampler: torch.utils.data.Sampler[int] + + def __init__(self, sampler: torch.utils.data.Sampler[int]) -> None: + super().__init__(None) + self.sampler = sampler + self.epoch = 0 + + def __iter__(self) -> Iterator[int]: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed + self.epoch) + self.sampler.generator = generator # type: ignore[attr-defined] + + return iter(self.sampler) + + def __len__(self) -> int: + return len(self.sampler) # type: ignore + + def set_epoch(self, epoch: int) -> None: + r""" + Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. 
+ """ + self.epoch = epoch From 756a57fa7358fe01d571af1779eba032126049f4 Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Thu, 17 Aug 2023 14:34:46 +0200 Subject: [PATCH 3/9] specs(training): training data generation for base encoder --- .../em_encoder/preprocess/01_gen_datasets.py | 483 ++++++++++++++++++ .../preprocess/02_pairwise_align.cue | 283 ++++++++++ .../preprocess/02_pairwise_align.py | 259 ++++++++++ .../em_encoder/preprocess/03_enc_misd.py | 278 ++++++++++ .../em_encoder/preprocess/04_mask_empty.py | 125 +++++ .../em_encoder/preprocess/05_export_annos.py | 88 ++++ 6 files changed, 1516 insertions(+) create mode 100644 specs/nico/training/em_encoder/preprocess/01_gen_datasets.py create mode 100644 specs/nico/training/em_encoder/preprocess/02_pairwise_align.cue create mode 100644 specs/nico/training/em_encoder/preprocess/02_pairwise_align.py create mode 100644 specs/nico/training/em_encoder/preprocess/03_enc_misd.py create mode 100644 specs/nico/training/em_encoder/preprocess/04_mask_empty.py create mode 100644 specs/nico/training/em_encoder/preprocess/05_export_annos.py diff --git a/specs/nico/training/em_encoder/preprocess/01_gen_datasets.py b/specs/nico/training/em_encoder/preprocess/01_gen_datasets.py new file mode 100644 index 000000000..d0a63aa8e --- /dev/null +++ b/specs/nico/training/em_encoder/preprocess/01_gen_datasets.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +import math + +from cloudvolume import CloudVolume +from cloudvolume.lib import Bbox + +from zetta_utils import mazepa +from zetta_utils.mazepa_addons.configurations.execute_on_gcp_with_sqs import execute_on_gcp_with_sqs +from zetta_utils.builder.built_in_registrations import efficient_parse_lambda_str +from zetta_utils.geometry.bbox import BBox3D +from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer +from zetta_utils.layer.volumetric.tools import VolumetricIndexTranslator +from zetta_utils.mazepa_layer_processing.common import 
build_subchunkable_apply_flow + +from zetta_utils import log + +from zetta_utils.ng.link_builder import make_ng_link + +from zetta_utils.geometry.vec import Vec3D + +logger = log.get_logger("zetta_utils") +log.set_verbosity("INFO") +log.configure_logger() + +SOURCE_PATHS = { + # "microns_pinky": { + # "src_res": [32, 32, 40], + # "path": "gs://neuroglancer/pinky100_v0/son_of_alignment_v15_rechunked", + # "bbox": [[6144, 5120, 17], [14336, 9216, 1201]], + # "n": 256, + # "stride": 1, + # # "chunk_size": [256, 256, 16], + # # Boss version downsamples are partially corrupt + # }, + # "microns_basil": { + # "src_res": [32, 32, 40], + # "path": "https://s3.amazonaws.com/bossdb-open-data/iarpa_microns/basil/em", + # "bbox": [[2048, 2048, 0], [27648, 32768, 993]], + # "n": 10, + # "stride": 1, + # # "chunk_size": [128, 128, 64], + # }, + # "microns_minnie": { + # "src_res": [32, 32, 40], + # "path": "gs://iarpa_microns/minnie/minnie65/em", + # "bbox": [[3072, 3072, 14832], [56320, 48128, 24464]], + # "n": 4, + # "stride": 2400, + # # "chunk_size": [64, 64, 64], + # # 14825-27904, S3 version has strong JPEG artifacts + # }, + # "microns_interneuron": { + # "src_res": [8, 8, 40], + # "dst_res": [32, 32, 40], + # "path": "https://s3.amazonaws.com/bossdb-open-data/iarpa_microns/interneuron/em", + # "bbox": [[20480, 24576, 4683], [110592, 114688, 17286]], + # "n": 16, + # "stride": 787, + # # "chunk_size": [2048, 2048, 1], + # # 4005-17347, stronger JPEG artifacts at 16x16 and 32x32 due to low-contrast + # }, + # "aibs_v1dd": { + # "src_res": [38.8, 38.8, 45], + # "path": "gs://v1dd_imagery/image/aligned_image", + # "bbox": [[9216, 5120, 0], [40960, 26624, 15708]], + # "n": 12, + # "stride": 1280, + # # "chunk_size": [64, 64, 64], + # # Bounding box way too large - try max: 40,960 x 32,768 + # }, + # "kim_n2da": { + # "src_res": [32, 32, 50], + # "path": "gs://zetta_jkim_001_n2da_1430/tests/corgie_tests/uint8_siftv11_newnets_onepass_m7m5m3_retry4_m55333/img/img_rendered", + # 
"bbox": [[0, 0, 1], [1024, 1024, 622]], + # "n": 8192, # all 621 sections + # "stride": 1, + # # "chunk_size": [1024, 1024, 1], + # }, + # "kim_pfc2022": { + # "src_res": [16, 16, 40], + # "dst_res": [32, 32, 40], + # "path": "gs://zetta_jkim_001_pfc2022_em/pfc/v1", + # "bbox": [[0, 0, 4], [14336, 12288, 1205]], + # "n": 183, + # "stride": 1, + # # "chunk_size": [256, 256, 16], + # # jpeg artifacts at 32x32 + # }, + # "kronauer_cra9": { + # "src_res": [4, 4, 42], + # "dst_res": [32, 32, 42], + # "path": "gs://dkronauer-ant-001-drop/cra9_inspection_4nm_sections2665-2680", + # "bbox": [[73728, 49152, 2665], [172032, 131072, 2679]], + # "n": 68, # all 14 sections + # "stride": 1, + # # "chunk_size": [1024, 1024, 1], + # # 2665-2680, 4nm only + # }, + # "kubota_001": { + # "src_res": [20, 20, 40], + # "path": "gs://zetta_kubota_001_alignment/v1", + # "bbox": [[1024, 1024, 0], [6144, 6144, 1191]], + # "n": 300, + # "stride": 1, + # # "chunk_size": [512, 512, 8], + # # suspicious of this resolution... 
20 nm looks closer to 30-40nm + # }, + # "lee_fanc": { + # "src_res": [34.4, 34.4, 45], + # "path": "gs://zetta_lee_fly_vnc_001_precomputed/vnc1_full_v3align_2/realigned_v1", + # "bbox": [[0, 0, 0], [10240, 27648, 4400]], + # "n": 30, + # "stride": 146, + # # "chunk_size": [1024, 1024, 1], + # }, + # "lee_banc": { + # "src_res": [32, 32, 45], + # "path": "gs://zetta_lee_fly_cns_001_alignment/aligned/v0", + # "bbox": [[1024, 1024, 0], [26624, 32768, 7010]], + # "n": 10, + # "stride": 701, + # # "chunk_size": [2048, 2048, 1], + # }, + # "lee_ppc": { + # "src_res": [8, 8, 40], + # "dst_res": [32, 32, 40], + # "path": "gs://zetta_lee_mouse_ppc_001_alignment/test_bbox/m7_sm2000_m5_sm2000_m3_sm2000_300iter/img/img/img_rendered", + # "bbox": [[144384, 39936, 12], [156672, 52224, 1241]], + # "n": 910, + # "stride": 1, + # # "chunk_size": [2048, 2048, 1], + # # Cutout: 144384, 39936, 12 - 156672, 52224, 1241 @ 8x8x40 only + # }, + # "lee_mosquito": { + # "src_res": [16, 16, 40], + # "dst_res": [32, 32, 40], + # "path": "gs://zetta_lee_mosquito_001_raw_image/V1_aligned/raw", + # "bbox": [[0, 0, 3000], [44032, 28160, 4747]], + # "n": 28, + # "stride": 62, + # # "chunk_size": [1024, 1024, 1], + # # 16x16x40 only + # }, + # "lichtman_zebrafish": { + # "src_res": [32, 32, 30], + # "path": "gs://zetta_jlichtman_zebrafish_001_alignment/fine_full_v2/img", + # "bbox": [[1024, 2048, 0], [10240, 14336, 4010]], + # "n": 76, + # "stride": 52, + # # "chunk_size": [2048, 2048, 1], + # }, + # # "neitz_macaque": { + # # "src_res": [10,10,50], + # # "path": "gs://zetta_neitz_macaque_retina_001_alignment_temp/13846-17051_11069-14269_5-2354/image_stitch_decay140_z1230-2250_mip1" + # # # too small and render artifacts + # # }, + # "prieto_godino_larva": { + # "src_res": [32, 32, 32], + # "path": "gs://zetta-prieto-godino-fly-larva-001-image/image-v1-iso", + # "bbox": [[0, 0, 0], [4218, 4531, 3442]], + # "n": 450, + # "stride": 1, + # # "chunk_size": [128, 128, 128], + # }, + # "fafb_v15": { + # 
"src_res": [32, 32, 40], + # "path": "https://tigerdata.princeton.edu/sseung-test1/fafb-v15-alignment-temp/fine_final/z0_7063/v1/aligned/mip1", + # "bbox": [[2048, 2048, 0], [29696, 14336, 7063]], + # "n": 25, + # "stride": 280, + # # "chunk_size": [512, 512, 8], + # }, + # "lichtman_h01": { + # "src_res": [8, 8, 33], + # "dst_res": [32, 32, 33], + # "path": "gs://h01-release/data/20210601/4nm_raw", + # "bbox": [[61440, 45056, 0], [491520, 286720, 5293]], + # "n": 3, + # "stride": 1760, + # # "chunk_size": [128, 128, 32], + # # 16x16 starts showing JPEG artifacts + # }, + # # "janelia_hemibrain": { + # # "src_res": [32, 32, 32], + # # "path": "gs://neuroglancer-janelia-flyem-hemibrain/emdata/clahe_yz/jpeg", + # # "bbox": [[0, 1024, 0], [8606, 9216, 10240]], + # # "n": 102, + # # "stride": 1, + # # # "chunk_size": [64, 64, 64], + # # # Slab interfaces - need to use yz, requires manual transpose + # # }, + # "janelia_manc": { + # "src_res": [32, 32, 32], + # # "path": "gs://flyem-vnc-2-26-213dba213ef26e094c16c860ae7f4be0/v3_emdata_clahe_xy/jpeg", + # "path": "https://storage.googleapis.com/flyem-vnc-2-26-213dba213ef26e094c16c860ae7f4be0/v3_emdata_clahe_xy/jpeg", + # "bbox": [[1024, 1024, 0], [9216, 12288, 20569]], + # "n": 93, + # "stride": 192, + # # "chunk_size": [64, 64, 64], + # # Slab interfaces - xy is good + # }, + # "nguyen_thomas_2022": { + # "src_res": [4, 4, 40], + # "dst_res": [32, 32, 40], + # "path": "https://s3.amazonaws.com/bossdb-open-data/nguyen_thomas2022/cb2/em", + # "bbox": [[0, 0, 0], [249600, 230400, 1200]], + # "n": 10, + # "stride": 1, + # # "chunk_size": [1024, 1024, 25], + # # Corrupt downsamples - use 4x4x40 + # }, + # # "maher_briegel_2023": { + # # "src_res": [5,5,75], + # # "path": "https://s3.amazonaws.com/bossdb-open-data/MaherBriegel2023/Lgn200/sbem" + # # # Sections too thick + # # }, + # # "mulcahy_2022_1h": { + # # "src_res": [16, 16, 30], + # # "path": "https://s3.amazonaws.com/bossdb-open-data/mulcahy2022/1h_L1/em", + # # # Poor 
alignment + # # }, + # "mulcahy_2022_16h": { + # "src_res": [32, 32, 30], + # "path": "https://s3.amazonaws.com/bossdb-open-data/mulcahy2022/16h_L1/em", + # "bbox": [[0, 0, 0], [7616, 2304, 1051]], + # "n": 490, + # "stride": 1, + # # "chunk_size": [256, 256, 32], + # }, + # "wildenberg_2021_vta_dat12a": { + # "src_res": [32, 32, 40], + # "path": "https://s3.amazonaws.com/bossdb-open-data/wildenberg2021/VTA_dat12a_saline_control_Dendrites_6nm_aligned/image", + # "bbox": [[0, 0, 0], [2565, 2662, 191]], + # "n": 1258, # all 191 sections + # "stride": 1, + # # "chunk_size": [512, 512, 16], + # }, + # "bumbarber_2013": { + # "src_res": [31.2, 31.2, 50], + # "path": "https://s3.amazonaws.com/bossdb-open-data/neurodata/bumbarger/bumbarger13/image", + # "bbox": [[512, 512, 0], [2560, 2560, 2762]], + # "n": 2048, + # "stride": 1, + # # "chunk_size": [512, 512, 16], + # }, + # "wilson_2019_p3": { + # "src_res": [32, 32, 30], + # "path": "https://s3.amazonaws.com/bossdb-open-data/wilson2019/P3/em", + # "bbox": [[0, 0, 0], [5120, 7168, 1657]], + # "n": 234, + # "stride": 1, + # # "chunk_size": [512, 512, 16], + # }, + # # "ishibashi_2021_em1": { + # # "src_res": [16, 16, 4], + # # "path": "https://s3.amazonaws.com/bossdb-open-data/Ishibashi2021/EM1/em", + # # "bbox": [[0,0,0], [1536, 1024, 1136]], + # # "n": 21845, # all 142 sections + # # "stride": 1, + # # # "chunk_size": [512, 512, 16], + # # # 32x32x4 downsampling is corrupt, also should take every 8th slice + # # }, + # # "ishibashi_2021_em2": { + # # "src_res": [16, 16, 4], + # # "path": "https://s3.amazonaws.com/bossdb-open-data/Ishibashi2021/EM2/em", + # # "bbox": [[0,0,0], [1664, 1152, 1344]], + # # "n": 13443, # all 168 sections + # # "stride": 1, + # # # "chunk_size": [512, 512, 16], + # # # 32x32x4 downsampling is corrupt, also should take every 8th slice + # # }, + # "templier_2019_wafer1": { + # "src_res": [32, 32, 50], + # "path": "https://s3.amazonaws.com/bossdb-open-data/neurodata/templier/Wafer1/C1_EM", + # 
"bbox": [[0, 0, 0], [9216, 7168, 514]], + # "n": 130, + # "stride": 1, + # # "chunk_size": [64, 64, 64], + # }, + # "templier_2019_wafer3": { + # "src_res": [32, 32, 50], + # "path": "https://s3.amazonaws.com/bossdb-open-data/neurodata/templier/Wafer3/EM", + # "bbox": [[0, 0, 0], [7168, 6144, 204]], + # "n": 195, + # "stride": 1, + # # "chunk_size": [64, 64, 64], + # }, + "lichtman_octopus2022": { + "src_res": [32, 32, 30], + "path": "gs://octopus-connectomes/vertical_lobe/img", + "bbox": [[2048, 1024, 0], [9216, 12288, 892]], + "n": 106, + "stride": 1, + # "chunk_size": [64, 64, 64], + } +} + + +for k, v in SOURCE_PATHS.items(): + cv = CloudVolume(v["path"], v["src_res"], use_https=True) + bbox = Bbox(v["bbox"][0], v["bbox"][1]) + total_sections = (min(v["n"], int(bbox.size3()[2])) * int(cv.chunk_size[2])) / math.ceil( + (int(cv.chunk_size[2]) / v["stride"]) + ) + total_chunks = math.ceil(total_sections / int(cv.chunk_size[2])) + print( + "Download: ", + int(total_chunks), + "chunks", + k, + int(bbox.size3()[0]) + * int(bbox.size3()[1]) + * int(cv.chunk_size[2]) + * total_chunks + / 1024 + / 1024 + / 1024, + "GiB", + ) + print(cv.chunk_size) + + +flows = [] +for k, v in SOURCE_PATHS.items(): + print(v["path"]) + if k in ["kronauer_cra9", "lichtman_zebrafish"]: + continue + # Check for src chunk size + cv = CloudVolume(v["path"], v["src_res"], use_https=True) + chunk_size_z = cv.chunk_size[2] + bbox = Bbox(*v["bbox"]) + src_res = v["src_res"] + dst_res = v.get("dst_res", src_res) + scale_factor = int(round(dst_res[0] / src_res[0])) + chunk_size_xy_adjust = 2 ** math.ceil(math.log(math.sqrt(chunk_size_z), 2)) + + if v["stride"] == 1: + # Copy continuous chunk, match processing chunk size to src chunk size + processing_chunk_size = [ + max(1024, 32768 // (scale_factor * chunk_size_xy_adjust)), + max(1024, 32768 // (scale_factor * chunk_size_xy_adjust)), + chunk_size_z, + ] + size_z = int(min(v["n"], bbox.size3()[2])) + start_z = int( + max( + bbox.minpt[2], + 
bbox.minpt[2] + math.ceil(bbox.size3()[2] / 2) - math.ceil(size_z / 2), + ) + ) + end_z = start_z + size_z + 1 + src_bboxes = [ + BBox3D.from_coords( + start_coord=[int(bbox.minpt[0]), int(bbox.minpt[1]), start_z], + end_coord=[int(bbox.maxpt[0]), int(bbox.maxpt[1]), end_z], + resolution=v["src_res"], + ) + ] + dst_bboxes = [ + BBox3D.from_coords( + start_coord=[0, 0, 0], + end_coord=[ + int(bbox.maxpt[0]) - int(bbox.minpt[0]), + int(bbox.maxpt[1]) - int(bbox.minpt[1]), + size_z, + ], + resolution=v["src_res"], + ) + ] + else: + size_z = int(2 * v["n"]) + processing_chunk_size = [ + max(1024, 32768 // (scale_factor * chunk_size_xy_adjust)), + max(1024, 32768 // (scale_factor * chunk_size_xy_adjust)), + 2 + ] + src_bboxes = [ + BBox3D.from_coords( + start_coord=[int(bbox.minpt[0]), int(bbox.minpt[1]), start_z], + end_coord=[int(bbox.maxpt[0]), int(bbox.maxpt[1]), start_z + 2], + resolution=v["src_res"], + ) + for start_z in range(int(bbox.minpt[2]), int(bbox.maxpt[2]), v["stride"]) + ] + dst_bboxes = [ + BBox3D.from_coords( + start_coord=[0, 0, start_z], + end_coord=[ + int(bbox.maxpt[0]) - int(bbox.minpt[0]), + int(bbox.maxpt[1]) - int(bbox.minpt[1]), + start_z + 2, + ], + resolution=v["src_res"], + ) + for start_z in range(0, int(bbox.maxpt[2]) - int(bbox.minpt[2]), 2) + ] + + flow = mazepa.concurrent_flow( + [ + build_subchunkable_apply_flow( + dst=build_cv_layer( + "gs://zetta-research-nico/encoder/datasets/" + k, + info_reference_path=v["path"], + on_info_exists="overwrite", + info_field_overrides={ + "type": "image", + "num_channels": 1, + "data_type": "uint8", + "scales": [ + { + "chunk_sizes": [[1024, 1024, 1]], + "resolution": dst_res, + "encoding": "raw", + "key": f"{dst_res[0]}_{dst_res[1]}_{dst_res[2]}", + "voxel_offset": [0, 0, 0], + "size": [ + int(bbox.size3()[0] // scale_factor), + int(bbox.size3()[1] // scale_factor), + size_z, + ], + } + ], + }, + cv_kwargs={"delete_black_uploads": True}, + ), + fn=efficient_parse_lambda_str(lambda_str="lambda src: 
src", name=f"Transfer {k}"), + skip_intermediaries=True, + dst_resolution=dst_res, + processing_chunk_sizes=[processing_chunk_size], + op_kwargs={ + "src": build_cv_layer( + v["path"], + data_resolution=v["src_res"], + interpolation_mode="img", + cv_kwargs={"use_https": True}, + index_procs=[ + VolumetricIndexTranslator( + offset=[ + 10 * (src_bbox.start[0] - dst_bbox.start[0]), # Hack for decimal resolutions + 10 * (src_bbox.start[1] - dst_bbox.start[1]), # Hack for decimal resolutions + src_bbox.start[2] - dst_bbox.start[2], + ], + resolution=[0.1, 0.1, 1], # Hack for decimal resolutions + ) + ], + ) + }, + bbox=dst_bbox, + ) + for (src_bbox, dst_bbox) in zip(src_bboxes, dst_bboxes) + ] + ) + flows.append(flow) + + +for k in SOURCE_PATHS.keys(): + cv = CloudVolume("precomputed://gs://zetta-research-nico/encoder/datasets/" + k) + make_ng_link( + layers=[(k, "image", "precomputed://gs://zetta-research-nico/encoder/datasets/" + k)], + title=k, + position=Vec3D(*cv.bounds.center().round()), + scale_bar_nm=5000, + ) + + +import os +import json +os.environ["ZETTA_RUN_SPEC"] = json.dumps("") +execute_on_gcp_with_sqs( + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230728_7", + worker_resources={"memory": "27560Mi"}, + worker_replicas=50, + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + checkpoint_interval_sec=60, + do_dryrun_estimation=True, + # checkpoint="gs://zetta_utils_runs/nkem/exec-nice-sepia-wren-of-jest/2023-07-29_152159_7246.zstd", + local_test=False, + batch_gap_sleep_sec=0.1, + target=mazepa.seq_flow(flows) +) diff --git a/specs/nico/training/em_encoder/preprocess/02_pairwise_align.cue b/specs/nico/training/em_encoder/preprocess/02_pairwise_align.cue new file mode 100644 index 000000000..dade53f58 --- /dev/null +++ b/specs/nico/training/em_encoder/preprocess/02_pairwise_align.cue @@ -0,0 +1,283 @@ +import "math" +import "list" + +#BASE_FOLDER: "zetta-research-nico/encoder/" 
+#IMG_SRC_PATH: "\(#BASE_FOLDER)/datasets/" +#IMG_DST_PATH: "\(#BASE_FOLDER)/pairwise_aligned/" + +#DST_INFO_CHUNK_SIZE: [2048, 2048, 1] + +#FIELD_INFO_OVERRIDE: { + "data_type": "float32", + "num_channels": 2, + "scales": [ + for i in list.Range(0, 10, 1) { + encoding: "zfpc" + zfpc_correlated_dims: [true, true, false, false] + zfpc_tolerance: 0.001953125 + } + ], + "type": "image" +} + +#ALIGN_STAGES: [ + #STAGE_TMPL & { + dst_resolution: [128, 128, 45] + + fn: { + sm: 10 + num_iter: 500 + lr: 0.05 + } + chunk_size: [2048, 2048, 1] + }, + #STAGE_TMPL & { + dst_resolution: [64, 64, 45] + + fn: { + sm: 10 + num_iter: 300 + lr: 0.1 + } + chunk_size: [2048, 2048, 1] + }, + + #STAGE_TMPL & { + dst_resolution: [32, 32, 45] + + fn: { + sm: 10 + num_iter: 200 + lr: 0.1 + } + chunk_size: [2048, 2048, 1] + }, +] + +#ENCODE_STAGES: [ + { + type: "BaseEncoder" + model: "gs://zetta-research-nico/training_artifacts/base_encodings/gamma_low0.75_high1.5_prob1.0_tile_0.0_0.2_lr0.00002_post1.8_cns_all/last.ckpt.model.spec.json" + resolution: [32, 32, 45] + res_change_mult: [1, 1, 1] + }, +] + +#STAGE_TMPL: { + "@type": "ComputeFieldStage" + dst_resolution: _ + chunk_size: _ + fn: { + "@type": "align_with_online_finetuner" + "@mode": "partial" + sm: _ + num_iter: _ + lr?: _ + } + crop_pad: [64, 64, 0] +} + + +#CF_FLOW_TMPL: { + "@type": "build_compute_field_multistage_flow" + bbox: #BCUBE_COMBINED_32NM + stages: _ + src_offset?: _ + tgt_offset?: _ + offset_resolution: [4, 4, 45] + src: { + "@type": "build_cv_layer" + path: #ENC_PATH + } + tgt: { + "@type": "build_cv_layer" + path: #ENC_PATH + } + dst: { + "@type": "build_cv_layer" + path: _ + info_field_overrides: #FIELD_INFO_OVERRIDE + // on_info_exists: "overwrite" + } + tmp_layer_dir: _ + tmp_layer_factory: { + "@type": "build_cv_layer" + "@mode": "partial" + info_field_overrides: #FIELD_INFO_OVERRIDE + // on_info_exists: "overwrite" + } + src_field: *null | { + "@type": "build_cv_layer" + path: _ + data_resolution: _ + 
interpolation_mode: "field" + } +} + + +#WARP_IMG_STAGE: { + "@type": "build_subchunkable_apply_flow" + op: { + "@type": "WarpOperation" + mode: "img" + crop_pad: [256, 256, 0] + } + dst_resolution: [32, 32, 45] + processing_chunk_sizes: [[2048, 2048, 1]] + processing_crop_pads: [[256, 256, 0]] + bbox: #BCUBE_COMBINED_32NM + src: { + "@type": "build_cv_layer" + path: _ + } + field: { + "@type": "build_cv_layer" + path: _ + } + dst: { + "@type": "build_cv_layer" + path: _ + info_reference_path: src.path + } +} + +#ENCODE_STAGE: { + "@type": "build_subchunkable_apply_flow" + op: { + "@type": "VolumetricCallableOperation" + fn: { + "@type": "BaseEncoder" + model_path: #ENCODE_STAGES[0].model + } + crop_pad: [128, 128, 0] + } + dst_resolution: [32, 32, 45] + processing_chunk_sizes: [[2048, 2048, 1]] + processing_crop_pads: [[128, 128, 0]] + bbox: #BCUBE_COMBINED_32NM + src: { + "@type": "build_cv_layer" + path: _ + } + dst: { + "@type": "build_cv_layer" + path: _ + info_reference_path: src.path + info_field_overrides: { + data_type: "int8" + } + info_chunk_size: [2048, 2048, 1] + on_info_exists: "overwrite" + } +} + + +#MASK_IMG_STAGE: { + "@type": "build_subchunkable_apply_flow" + op: { + "@type": "VolumetricCallableOperation" + fn: { + "@type": "apply_mask_fn" + "@mode": "partial" + } + } + dst_resolution: [32, 32, 45] + processing_chunk_sizes: [[2048, 2048, 1]] + processing_crop_pads: [[0, 0, 0]] + bbox: #BCUBE_COMBINED_32NM + src: { + "@type": "build_cv_layer" + path: _ + } + dst: { + "@type": "build_cv_layer" + path: _ + info_reference_path: src.path + } + masks: [ + { + "@type": "build_cv_layer" + path: _ + read_procs: [ + { + "@type": "lambda" + lambda_str: "lambda data: (data == 0)" + }, + ] + }, + ] +} + + +#JOINT_OFFSET_FLOW: { + "@type": "mazepa.execute_on_gcp_with_sqs" + worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230405" + worker_resources: { + memory: "18560Mi" + "nvidia.com/gpu": "1" + } + worker_replicas: 100 + batch_gap_sleep_sec: 
1 + do_dryrun_estimation: true + local_test: false + + target: { + "@type": "mazepa.concurrent_flow" + stages: [ + for z_offset in #Z_OFFSETS { + "@type": "mazepa.concurrent_flow" + stages: [ + { + "@type": "mazepa.seq_flow" + stages: [ + // Fine Alignment - Good coarse + #CF_FLOW_TMPL & { + dst: path: "\(#FIELDS_PATH)/fine/\(z_offset)" + tmp_layer_dir: "\(#FIELDS_PATH)/fine/\(z_offset)/tmp" + tgt_offset: [0, 0, z_offset] + stages: #ALIGN_STAGES + src_field: { + path: "\(#FIELDS_PATH)/coarse/\(z_offset)" + data_resolution: [256, 256, 45] + } + }, + + // Warp good fine alignment + #WARP_IMG_STAGE & { + src: path: #IMG_PATH + field: path: "\(#FIELDS_PATH)/fine/\(z_offset)" + dst: path: "\(#OUTPUT_IMG_PATH)/fine/\(z_offset)" + bbox: #BCUBE_COMBINED_32NM + }, + + #ENCODE_STAGE & { + src: path: "\(#OUTPUT_IMG_PATH)/fine/\(z_offset)" + dst: path: "\(#OUTPUT_ENC_PATH)/fine/\(z_offset)" + bbox: #BCUBE_COMBINED_32NM + }, + + #MASK_IMG_STAGE & { + src: path: "\(#OUTPUT_IMG_PATH)/fine/\(z_offset)" + dst: path: "\(#OUTPUT_IMG_PATH)/fine_masked/\(z_offset)" + masks: [ + { + path: "\(#OUTPUT_ENC_PATH)/fine/\(z_offset)" + } + ] + bbox: #BCUBE_COMBINED_32NM + }, + + ] + }, + ] + }, + ] + } +} + +[ + + //ALIGN + #JOINT_OFFSET_FLOW + +] \ No newline at end of file diff --git a/specs/nico/training/em_encoder/preprocess/02_pairwise_align.py b/specs/nico/training/em_encoder/preprocess/02_pairwise_align.py new file mode 100644 index 000000000..8f9b25dca --- /dev/null +++ b/specs/nico/training/em_encoder/preprocess/02_pairwise_align.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import math +from functools import partial + +from cloudvolume import CloudVolume + + +from zetta_utils.api.v0 import * + +SOURCE_PATHS = { + # # # "microns_pinky": {"contiguous": True}, + # # # "microns_basil": {"contiguous": True}, + # # # "microns_minnie": {"contiguous": False}, + # # # "microns_interneuron": {"contiguous": False}, + "aibs_v1dd": {"contiguous": False}, + # # # "kim_n2da": 
{"contiguous": True}, + # # # "kim_pfc2022": {"contiguous": True}, + # # # "kronauer_cra9": {"contiguous": True}, + # # # "kubota_001": {"contiguous": True}, + "lee_fanc": {"contiguous": False}, + # # # "lee_banc": {"contiguous": False}, + # # # "lee_ppc": {"contiguous": True}, + # # # "lee_mosquito": {"contiguous": False}, + # # # "lichtman_zebrafish": {"contiguous": False}, + # # # "prieto_godino_larva": {"contiguous": True}, + # # # "fafb_v15": {"contiguous": False}, + # # # "lichtman_h01": {"contiguous": False}, + # # # "janelia_hemibrain": {"contiguous": True}, + # # # "janelia_manc": {"contiguous": False}, + # # # "nguyen_thomas_2022": {"contiguous": True}, + # # # "mulcahy_2022_16h": {"contiguous": True}, + # # # "wildenberg_2021_vta_dat12a": {"contiguous": True}, + # # # "bumbarber_2013": {"contiguous": True}, + # # # "wilson_2019_p3": {"contiguous": True}, + # # # "ishibashi_2021_em1": {"contiguous": True}, + # # # "ishibashi_2021_em2": {"contiguous": True}, + # # # "templier_2019_wafer1": {"contiguous": True}, + # # # "templier_2019_wafer3": {"contiguous": True}, + # # # "lichtman_octopus2022": {"contiguous": True}, +} + + +BASE_PATH = "gs://zetta-research-nico/encoder/" +IMG_SRC_PATH = BASE_PATH + "datasets/" +DST_PATH = BASE_PATH + "pairwise_aligned/" +TMP_FIELD_PATH = BASE_PATH + "tmp/pairwise_aligned_fields/" + + +concurrent_cf_flows = [] +concurrent_warp_flows = [] +tasks_count = {} +for k, v in SOURCE_PATHS.items(): + src_img_path = IMG_SRC_PATH + k + dst_field_path = DST_PATH + k + "/field" + dst_img_path = DST_PATH + k + "/warped_img" + dst_enc_path = DST_PATH + k + "/warped_enc" + + cv_src_img = CloudVolume(src_img_path, progress=False) + + + bounds = cv_src_img.bounds + resolution = cv_src_img.resolution.tolist() + minpt = bounds.minpt.tolist() + maxpt = bounds.maxpt.tolist() + size = bounds.size3().tolist() + + field_ref = cv_src_img.info + field_ref["data_type"] = "float32" + field_ref["num_channels"] = 2 + for i in range(3): + 
field_ref["scales"][i].update( + { + "encoding": "zfpc", + "zfpc_correlated_dims": [True, True, False, False], + "zfpc_tolerance": 1 / 512, + } + ) + + # ds_pyramid = [] + # src_res = resolution + # dst_res = [src_res[0] * 2, src_res[1] * 2, src_res[2]] + + # for i in range(2): + # ds_pyramid.append( + # build_subchunkable_apply_flow( + # dst=build_cv_layer(IMG_SRC_PATH + k, cv_kwargs={"delete_black_uploads": True}), + # fn=efficient_parse_lambda_str( + # lambda_str="lambda src: src", name=f"Downsample {k}" + # ), + # skip_intermediaries=True, + # dst_resolution=dst_res, + # processing_chunk_sizes=[[8192, 8192, 1]], + # op_kwargs={ + # "src": build_cv_layer( + # IMG_SRC_PATH + k, + # data_resolution=src_res, + # interpolation_mode="img", + # ) + # }, + # bbox=BBox3D.from_coords( + # start_coord=[0, 0, 0], + # end_coord=[size[0], size[1], size[2]], + # resolution=dst_res, + # ), + # expand_bbox_processing=True, + # ) + # ) + # src_res = dst_res + # dst_res = [src_res[0] * 2, src_res[1] * 2, src_res[2]] + # size = [size[0] // 2, size[1] // 2, size[2]] + + # concurrent_flows.append(mazepa.seq_flow(ds_pyramid)) + + if v["contiguous"]: + z_ranges = [(minpt[2], maxpt[2] + 1)] + else: + z_ranges = [(z, z + 1) for z in range(minpt[2], maxpt[2], 2)] + + tasks_count[k] = {"cf": [0, 0, 0], "warp": 0} + for z_start, z_end in z_ranges[1:]: + # tasks_count[k]["cf"][0] += (16384 * math.ceil(0.25 * size[0] / 16384.0) * 16384 * math.ceil(0.25 * size[1] / 16384.0) * (z_end - z_start)) / (16384 * 16384) + # tasks_count[k]["cf"][1] += (16384 * math.ceil(0.5 * size[0] / 16384.0) * 16384 * math.ceil(0.5 * size[1] / 16384.0) * (z_end - z_start)) / (16384 * 16384) + # tasks_count[k]["cf"][2] += (16384 * math.ceil(size[0] / 16384.0) * 16384 * math.ceil(size[1] / 16384.0) * (z_end - z_start)) / (16384 * 16384) + # tasks_count[k]["warp"] += (16384 * math.ceil(size[0] / 16384.0) * 16384 * math.ceil(size[1] / 16384.0) * (z_end - z_start)) / (16384 * 16384) + lvl0_sizes = [ + [ + min(8192, 2048 
* math.ceil(size[0] / 2048)), + min(8192, 2048 * math.ceil(size[1] / 2048)), + 1 + ], + [ + min(8192, 2048 * math.ceil(size[0] / 2048)), + min(8192, 2048 * math.ceil(size[1] / 2048)), + 1 + ], + [ + min(8192, 2048 * math.ceil(size[0] / 2048)), + min(8192, 2048 * math.ceil(size[1] / 2048)), + 1 + ], + [ + min(16384, 2048 * math.ceil(size[0] / 2048)), + min(16384, 2048 * math.ceil(size[1] / 2048)), + 1 + ], + ] + + + compute_field_flow = build_compute_field_multistage_flow( + stages=[ + ComputeFieldStage( + fn=partial(align_with_online_finetuner, sm=10, num_iter=300, lr=0.1), + dst_resolution=resolution, + processing_chunk_sizes=[lvl0_sizes[0], [2048, 2048, 1]], + processing_crop_pads=[[0, 0, 0], [64, 64, 0]], + expand_bbox_processing=True, + ), + ComputeFieldStage( + fn=partial(align_with_online_finetuner, sm=10, num_iter=300, lr=0.1), + dst_resolution=resolution, + processing_chunk_sizes=[lvl0_sizes[1], [2048, 2048, 1]], + processing_crop_pads=[[0, 0, 0], [64, 64, 0]], + expand_bbox_processing=True, + ), + ComputeFieldStage( + fn=partial(align_with_online_finetuner, sm=10, num_iter=200, lr=0.1), + dst_resolution=resolution, + processing_chunk_sizes=[lvl0_sizes[2], [2048, 2048, 1]], + processing_crop_pads=[[0, 0, 0], [64, 64, 0]], + expand_bbox_processing=True, + ) + ], + bbox=BBox3D.from_coords( + start_coord=[minpt[0], minpt[1], z_start], + end_coord=[maxpt[0], maxpt[1], z_end], + resolution=resolution + ), + src_offset=[0, 0, 1], + tgt_offset=[0, 0, 0], + offset_resolution=resolution, + src=build_cv_layer(src_img_path), + tgt=build_cv_layer(src_img_path), + dst=build_cv_layer( + DST_PATH + k + "/field", + info_field_overrides=field_ref, + ), + tmp_layer_dir=TMP_FIELD_PATH + k, + tmp_layer_factory=partial(build_cv_layer, info_field_overrides=field_ref) + ) + concurrent_cf_flows.append(compute_field_flow) + + + warp_img_flow = build_subchunkable_apply_flow( + dst=build_cv_layer( + dst_img_path, + cv_kwargs={"delete_black_uploads": True}, + 
info_reference_path=src_img_path, + + ), + op=WarpOperation(mode="img"), + skip_intermediaries=True, + dst_resolution=resolution, + processing_chunk_sizes=[lvl0_sizes[3], [2048, 2048, 1]], + processing_crop_pads=[[0, 0, 0], [256, 256, 0]], + op_kwargs={ + "src": build_cv_layer( + src_img_path, + index_procs=[ + VolumetricIndexTranslator( + offset=[0, 0, 1], + resolution=resolution + ) + ], + ), + "field": build_cv_layer(dst_field_path) + }, + bbox=BBox3D.from_coords( + start_coord=[minpt[0], minpt[1], z_start], + end_coord=[maxpt[0], maxpt[1], z_end], + resolution=resolution, + ), + expand_bbox_processing=True, + ) + concurrent_warp_flows.append(warp_img_flow) + + +import json +import os + +os.environ["ZETTA_RUN_SPEC"] = json.dumps("") +execute_on_gcp_with_sqs( + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230809", + worker_resources={"memory": "17560Mi", "nvidia.com/gpu": 1}, + worker_replicas=300, + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + checkpoint_interval_sec=60, + do_dryrun_estimation=True, + local_test=False, + batch_gap_sleep_sec=0.1, + # checkpoint="gs://zetta_utils_runs/nkem/exec-smart-caracara-of-strange-progress/2023-08-04_235232_9768.zstd", + target=concurrent_flow(concurrent_cf_flows), +) + +execute_on_gcp_with_sqs( + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230809", #03_2 + worker_resources={"memory": "13000Mi"}, + worker_replicas=50, + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + checkpoint_interval_sec=60, + do_dryrun_estimation=True, + local_test=False, + batch_gap_sleep_sec=0.1, + target=concurrent_flow(concurrent_warp_flows), +) diff --git a/specs/nico/training/em_encoder/preprocess/03_enc_misd.py b/specs/nico/training/em_encoder/preprocess/03_enc_misd.py new file mode 100644 index 000000000..d1fe2f8b2 --- /dev/null +++ 
b/specs/nico/training/em_encoder/preprocess/03_enc_misd.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +import math +from functools import partial +import os + +from cloudvolume import CloudVolume +from copy import deepcopy +import json + + +from zetta_utils.api.v0 import * + +SOURCE_PATHS = { + # "microns_pinky": {"contiguous": True}, + # "microns_basil": {"contiguous": True}, + # "microns_minnie": {"contiguous": False}, + # "microns_interneuron": {"contiguous": False}, + # "aibs_v1dd": {"contiguous": False}, + "kim_n2da": {"contiguous": True}, + # "kim_pfc2022": {"contiguous": True}, + # "kronauer_cra9": {"contiguous": True}, + # "kubota_001": {"contiguous": True}, + # "lee_fanc": {"contiguous": False}, + # "lee_banc": {"contiguous": False}, + # "lee_ppc": {"contiguous": True}, + # "lee_mosquito": {"contiguous": False}, + # "lichtman_zebrafish": {"contiguous": False}, + # "prieto_godino_larva": {"contiguous": True}, + # "fafb_v15": {"contiguous": False}, + # "lichtman_h01": {"contiguous": False}, + # "janelia_hemibrain": {"contiguous": True}, + # "janelia_manc": {"contiguous": False}, + # "nguyen_thomas_2022": {"contiguous": True}, + "mulcahy_2022_16h": {"contiguous": True}, + # "wildenberg_2021_vta_dat12a": {"contiguous": True}, + "bumbarber_2013": {"contiguous": True}, + # "wilson_2019_p3": {"contiguous": True}, + # "ishibashi_2021_em1": {"contiguous": True}, + # "ishibashi_2021_em2": {"contiguous": True}, + # "templier_2019_wafer1": {"contiguous": True}, + # "templier_2019_wafer3": {"contiguous": True}, + # "lichtman_octopus2022": {"contiguous": True}, +} + +BASE_PATH = "gs://zetta-research-nico/encoder/" +ENCODER_MODEL = "gs://zetta-research-nico/training_artifacts/base_enc_zfish/1.4.0_M3_M3_unet4_lr0.0001_equi0.5_post1.6-1.6_fmt0.8_zfish/last.ckpt.model.spec.json" +MISD_MODEL = "gs://zetta-research-nico/training_artifacts/aced_misd_cns_zfish/thr5.0_lr0.00001_zfish_finetune_2/last.ckpt.static-2.0.0+cu117-model.jit" + +concurrent_enc_flows = [] 
+concurrent_misd_flows = [] +concurrent_img_ds_flows = [] +concurrent_mask_ds_flows = [] +for k, v in SOURCE_PATHS.items(): + img_tgt_path = BASE_PATH + "datasets/" + k + img_src_path = BASE_PATH + "pairwise_aligned/" + k + "/warped_img" + enc_tgt_path = BASE_PATH + "pairwise_aligned/" + k + "/tgt_enc" + enc_src_path = BASE_PATH + "pairwise_aligned/" + k + "/warped_enc" + misd_mask_path = BASE_PATH + "pairwise_aligned/" + k + "/misd_mask" + misd_mask_thr_path = BASE_PATH + "pairwise_aligned/" + k + "/misd_mask_thr" + + cv_src_img = CloudVolume(img_tgt_path, progress=False) + bounds = cv_src_img.bounds + resolution = cv_src_img.resolution.tolist() + minpt = bounds.minpt.tolist() + maxpt = bounds.maxpt.tolist() + size = bounds.size3().tolist() + + enc_ref = cv_src_img.info + enc_ref["data_type"] = "int8" + enc_ref["scales"] = enc_ref["scales"][:1] + + mask_ref = deepcopy(enc_ref) + mask_ref["data_type"] = "uint8" + + mask_thresh_ref = deepcopy(mask_ref) + mask_thresh_ref["scales"][0]["size"] = [math.ceil(maxpt[0] / 2.0**10), math.ceil(maxpt[1] / 2.0**10), maxpt[2]] + mask_thresh_ref["scales"][0]["chunk_sizes"] = [[math.ceil(maxpt[0] / 2.0**10), math.ceil(maxpt[1] / 2.0**10), 1]] + mask_thresh_ref["scales"][0]["resolution"] = [resolution[0] * 2**10, resolution[1] * 2**10, resolution[2]] + mask_thresh_ref["scales"][0]["key"] = f"{resolution[0] * 2**10}_{resolution[1] * 2**10}_{resolution[2]}" + + if v["contiguous"]: + z_ranges = [(minpt[2], maxpt[2] + 1)] + else: + z_ranges = [(z, z + 1) for z in range(minpt[2], maxpt[2], 2)] + + superchunk_size = [ + min(8192, 2048 * math.ceil(size[0] / 2048)), + min(8192, 2048 * math.ceil(size[1] / 2048)), + 1 + ] + + for z_start, z_end in z_ranges: + # for img_path, enc_path in [(img_src_path, enc_src_path), (img_tgt_path, enc_tgt_path)]: + # enc_flow = build_subchunkable_apply_flow( + # dst=build_cv_layer( + # enc_path, + # cv_kwargs={"delete_black_uploads": True}, + # info_field_overrides=enc_ref, + # ), + # 
op=VolumetricCallableOperation( + # fn=BaseEncoder(ENCODER_MODEL), + # crop_pad=[0, 0, 0], + # res_change_mult=[1, 1, 1], + # ), + # skip_intermediaries=True, + # dst_resolution=resolution, + # processing_chunk_sizes=[superchunk_size, [2048, 2048, 1]], + # processing_crop_pads=[[0, 0, 0], [32, 32, 0]], + # op_kwargs={ + # "src": build_cv_layer( + # img_path, + # ), + # }, + # bbox=BBox3D.from_coords( + # start_coord=[minpt[0], minpt[1], z_start], + # end_coord=[maxpt[0], maxpt[1], z_end], + # resolution=resolution, + # ), + # expand_bbox_processing=True, + # ) + # concurrent_enc_flows.append(enc_flow) + + # misd_flow = build_subchunkable_apply_flow( + # dst=build_cv_layer( + # misd_mask_path, + # cv_kwargs={"delete_black_uploads": True}, + # info_field_overrides=mask_ref, + # ), + # op=VolumetricCallableOperation( + # fn=MisalignmentDetector(MISD_MODEL), + # ), + # skip_intermediaries=True, + # dst_resolution=resolution, + # processing_chunk_sizes=[superchunk_size, [2048, 2048, 1]], + # processing_crop_pads=[[0, 0, 0], [64, 64, 0]], + # op_kwargs={ + # "src": build_cv_layer( + # enc_src_path, + # ), + # "tgt": build_cv_layer( + # enc_tgt_path, + # ), + # }, + # bbox=BBox3D.from_coords( + # start_coord=[minpt[0], minpt[1], z_start], + # end_coord=[maxpt[0], maxpt[1], z_end], + # resolution=resolution, + # ), + # expand_bbox_processing=True, + # ) + # concurrent_misd_flows.append(misd_flow) + + + # seq_ds_flows = [] + # for src_res in [[resolution[0] * 2**factor, resolution[1] * 2**factor, resolution[2]] for factor in range(0, 2)]: + # dst_res = [2 * src_res[0], 2 * src_res[1], src_res[2]] + # ds_flow = build_subchunkable_apply_flow( + # dst=build_cv_layer( + # img_src_path, + # cv_kwargs={"delete_black_uploads": True}, + # ), + # fn=efficient_parse_lambda_str(lambda_str="lambda src: src", name=f"Downsample Warped Img"), + # skip_intermediaries=True, + # dst_resolution=dst_res, + # processing_chunk_sizes=[superchunk_size, [2048, 2048, 1]], + # 
processing_crop_pads=[[0, 0, 0], [0, 0, 0]], + # op_kwargs={ + # "src": build_cv_layer( + # img_src_path, + # data_resolution=src_res, + # interpolation_mode="img", + # ), + # }, + # bbox=BBox3D.from_coords( + # start_coord=[minpt[0], minpt[1], z_start], + # end_coord=[maxpt[0], maxpt[1], z_end], + # resolution=resolution, + # ), + # expand_bbox_resolution=True, + # ) + # seq_ds_flows.append(ds_flow) + # concurrent_img_ds_flows.append(sequential_flow(seq_ds_flows)) + + seq_ds_flows = [] + dst_res = [resolution[0] * 2**10, resolution[1] * 2**10, resolution[2]] + ds_flow = build_subchunkable_apply_flow( + dst=build_cv_layer( + misd_mask_thr_path, + info_field_overrides=mask_thresh_ref, + data_resolution=dst_res, + interpolation_mode="nearest", + ), + fn=efficient_parse_lambda_str(lambda_str="lambda src: src", name=f"Downsample Warped Mask"), + skip_intermediaries=True, + dst_resolution=resolution, + processing_chunk_sizes=[[math.ceil(maxpt[0]/1024.0)*1024, math.ceil(maxpt[1]/1024.0)*1024, 1]], + processing_crop_pads=[[0, 0, 0]], + op_kwargs={ + "src": build_cv_layer( + misd_mask_path, + read_procs=[ + partial(rearrange, pattern="C X Y Z -> Z C X Y"), + partial(compare, mode=">=", value=32, binarize=True), + partial(to_float32), + partial(interpolate, scale_factor=[1.0/2**10, 1.0/2**10], mode="area", unsqueeze_input_to=4), + partial(compare, mode=">", value=0.1, binarize=True), + partial(to_uint8), + partial(interpolate, scale_factor=[2**10, 2**10], mode="nearest", unsqueeze_input_to=4), + partial(rearrange, pattern="Z C X Y -> C X Y Z"), + ], + ), + }, + bbox=BBox3D.from_coords( + start_coord=[minpt[0], minpt[1], z_start], + end_coord=[maxpt[0], maxpt[1], z_end], + resolution=resolution, + ), + expand_bbox_processing=True, + expand_bbox_resolution=True, + ) + concurrent_mask_ds_flows.append(ds_flow) + + + +os.environ["ZETTA_RUN_SPEC"] = json.dumps("") +# execute_on_gcp_with_sqs( +# worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230809", +# 
worker_resources={"memory": "17560Mi", "nvidia.com/gpu": 1}, +# worker_replicas=30, +# worker_cluster_name="zutils-x3", +# worker_cluster_region="us-east1", +# worker_cluster_project="zetta-research", +# checkpoint_interval_sec=60, +# do_dryrun_estimation=True, +# local_test=False, +# batch_gap_sleep_sec=0.1, +# target=sequential_flow([ +# concurrent_flow(concurrent_enc_flows), +# concurrent_flow(concurrent_misd_flows), + +# ]) +# ) + +# breakpoint() +# execute_on_gcp_with_sqs( +# worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230809", +# worker_resources={"memory": "13000Mi"}, +# worker_replicas=50, +# worker_cluster_name="zutils-x3", +# worker_cluster_region="us-east1", +# worker_cluster_project="zetta-research", +# checkpoint_interval_sec=60, +# do_dryrun_estimation=True, +# local_test=True, +# batch_gap_sleep_sec=0.1, +# target=concurrent_flow([ +# concurrent_mask_ds_flows +# ]), +# ) + +for k in SOURCE_PATHS.keys(): + cv = CloudVolume("precomputed://gs://zetta-research-nico/encoder/datasets/" + k) + link = make_ng_link( + layers=[ + ("tgt", "image", f"precomputed://gs://zetta-research-nico/encoder/datasets/{k}"), + ("src", "image", f"precomputed://gs://zetta-research-nico/encoder/pairwise_aligned/{k}/warped_img"), + ("misd", "image", f"precomputed://gs://zetta-research-nico/encoder/pairwise_aligned/{k}/misd_mask"), + ("bad_chunks", "segmentation", f"precomputed://gs://zetta-research-nico/encoder/pairwise_aligned/{k}/misd_mask_thr"), + (f"CREATE:zetta-research-nico/encoder/pairwise_aligned/{k}", "annotation", None), + ], + title=k, + position=Vec3D(*(cv.bounds.center().round()[:2]), 0), + scale_bar_nm=30000, + print_to_logger=False + ) + + print(f"{k}: {link}") \ No newline at end of file diff --git a/specs/nico/training/em_encoder/preprocess/04_mask_empty.py b/specs/nico/training/em_encoder/preprocess/04_mask_empty.py new file mode 100644 index 000000000..5e1c0e624 --- /dev/null +++ b/specs/nico/training/em_encoder/preprocess/04_mask_empty.py 
@@ -0,0 +1,125 @@ +from __future__ import annotations +from functools import partial + +import json +import math +import os + +from cloudvolume import CloudVolume + +from zetta_utils.api.v0 import * + +SOURCE_PATHS = { + # "microns_pinky": {"contiguous": True}, + # "microns_basil": {"contiguous": True}, + # "microns_minnie": {"contiguous": False}, + # "microns_interneuron": {"contiguous": False}, + # "aibs_v1dd": {"contiguous": False}, + "kim_n2da": {"contiguous": True}, + # "kim_pfc2022": {"contiguous": True}, + # "kronauer_cra9": {"contiguous": True}, + # "kubota_001": {"contiguous": True}, + # "lee_fanc": {"contiguous": False}, + # "lee_banc": {"contiguous": False}, + # "lee_ppc": {"contiguous": True}, + # "lee_mosquito": {"contiguous": False}, + # "lichtman_zebrafish": {"contiguous": False}, + # "prieto_godino_larva": {"contiguous": True}, + # "fafb_v15": {"contiguous": False}, + # "lichtman_h01": {"contiguous": False}, + # "janelia_hemibrain": {"contiguous": True}, + # "janelia_manc": {"contiguous": False}, + # "nguyen_thomas_2022": {"contiguous": True}, + "mulcahy_2022_16h": {"contiguous": True}, + # "wildenberg_2021_vta_dat12a": {"contiguous": True}, + "bumbarber_2013": {"contiguous": True}, + # "wilson_2019_p3": {"contiguous": True}, + # "ishibashi_2021_em1": {"contiguous": True}, + # "ishibashi_2021_em2": {"contiguous": True}, + # "templier_2019_wafer1": {"contiguous": True}, + # "templier_2019_wafer3": {"contiguous": True}, + # "lichtman_octopus2022": {"contiguous": True}, +} + +BASE_PATH = "gs://zetta-research-nico/encoder/" + +concurrent_mask_flows = [] +for k, v in SOURCE_PATHS.items(): + img_tgt_path = BASE_PATH + "datasets/" + k + img_src_path = BASE_PATH + "pairwise_aligned/" + k + "/warped_img" + misd_mask_thr_path = BASE_PATH + "pairwise_aligned/" + k + "/misd_mask_thr" + + cv_src_img = CloudVolume(img_tgt_path, progress=False) + bounds = cv_src_img.bounds + resolution = cv_src_img.resolution.tolist() + minpt = bounds.minpt.tolist() + maxpt = 
bounds.maxpt.tolist() + size = bounds.size3().tolist() + + if v["contiguous"]: + z_ranges = [(minpt[2], maxpt[2] + 1)] + else: + z_ranges = [(z, z + 1) for z in range(minpt[2], maxpt[2], 2)] + + for z_start, z_end in z_ranges: + mask_flow = build_subchunkable_apply_flow( + dst=build_cv_layer( + misd_mask_thr_path, + write_procs=[ + partial(to_uint8), + ] + ), + fn=efficient_parse_lambda_str( + lambda_str="lambda src: (src['src']==0) | (src['tgt']==0) | (src['misd']!=0)", + name=f"Downsample Warped Mask", + ), + skip_intermediaries=True, + dst_resolution=[resolution[0] * 1024, resolution[1] * 1024, resolution[2]], + processing_chunk_sizes=[ + [math.ceil(size[0] / 1024.0), math.ceil(size[1] / 1024.0), 1] + ], + processing_crop_pads=[[0, 0, 0]], + op_kwargs={ + "src": build_layer_set( + { + "src": build_cv_layer( + img_src_path, + data_resolution=[resolution[0] * 4, resolution[1] * 4, resolution[2]], + interpolation_mode="mask", + ), + "tgt": build_cv_layer( + img_tgt_path, + data_resolution=[resolution[0] * 4, resolution[1] * 4, resolution[2]], + interpolation_mode="mask", + ), + "misd": build_cv_layer( + misd_mask_thr_path, + ), + } + ) + }, + bbox=BBox3D.from_coords( + start_coord=[minpt[0], minpt[1], z_start], + end_coord=[maxpt[0], maxpt[1], z_end], + resolution=resolution, + ), + expand_bbox_processing=True, + expand_bbox_resolution=True, + ) + concurrent_mask_flows.append(mask_flow) + + +os.environ["ZETTA_RUN_SPEC"] = json.dumps("") +execute_on_gcp_with_sqs( + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230816", + worker_resources={"memory": "13000Mi"}, + worker_replicas=50, + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + checkpoint_interval_sec=60, + do_dryrun_estimation=True, + local_test=False, + batch_gap_sleep_sec=0.1, + target=concurrent_flow([concurrent_mask_flows]), +) diff --git a/specs/nico/training/em_encoder/preprocess/05_export_annos.py 
b/specs/nico/training/em_encoder/preprocess/05_export_annos.py new file mode 100644 index 000000000..df47db909 --- /dev/null +++ b/specs/nico/training/em_encoder/preprocess/05_export_annos.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from cloudvolume import CloudVolume + +from zetta_utils.api.v0 import * + +SOURCE_PATHS = { + # "microns_pinky": {"contiguous": True}, + # "microns_basil": {"contiguous": True}, + # "microns_minnie": {"contiguous": False}, + # "microns_interneuron": {"contiguous": False}, + # "aibs_v1dd": {"contiguous": False}, + "kim_n2da": {"contiguous": True}, + # "kim_pfc2022": {"contiguous": True}, + # "kronauer_cra9": {"contiguous": True}, + # "kubota_001": {"contiguous": True}, + # "lee_fanc": {"contiguous": False}, + # "lee_banc": {"contiguous": False}, + # "lee_ppc": {"contiguous": True}, + # "lee_mosquito": {"contiguous": False}, + # "lichtman_zebrafish": {"contiguous": False}, + # "prieto_godino_larva": {"contiguous": True}, + # "fafb_v15": {"contiguous": False}, + # "lichtman_h01": {"contiguous": False}, + # "janelia_hemibrain": {"contiguous": True}, + # "janelia_manc": {"contiguous": False}, + # "nguyen_thomas_2022": {"contiguous": True}, + "mulcahy_2022_16h": {"contiguous": True}, + # "wildenberg_2021_vta_dat12a": {"contiguous": True}, + "bumbarber_2013": {"contiguous": True}, + # "wilson_2019_p3": {"contiguous": True}, + # "ishibashi_2021_em1": {"contiguous": True}, + # "ishibashi_2021_em2": {"contiguous": True}, + # "templier_2019_wafer1": {"contiguous": True}, + # "templier_2019_wafer3": {"contiguous": True}, + # "lichtman_octopus2022": {"contiguous": True}, +} + +BASE_PATH = "gs://zetta-research-nico/encoder/" + +for k, v in SOURCE_PATHS.items(): + misd_mask_thr_path = BASE_PATH + "pairwise_aligned/" + k + "/misd_mask_thr" + annotation_layer_name = "zetta-research-nico/encoder/pairwise_aligned/" + k + cv = CloudVolume(misd_mask_thr_path, progress=False, fill_missing=True) + resolution = cv.resolution.tolist() + + data = 
cv[:, :, :].squeeze(-1) == 0 + if not v["contiguous"]: + data[:, :, 1::2] = False + + valid_chunks = data.nonzero() + annotations = [ + Vec3D(resolution[0] * (x + 0.5), resolution[1] * (y + 0.5), resolution[2] * z) + for (x, y, z) in zip(*valid_chunks) + ] + print(f"Writing {len(annotations)} annotations for layer {k}") + write_remote_annotations(annotation_layer_name, resolution, annotations) + + +# Writing 5019 annotations for layer microns_pinky +# Writing 2591 annotations for layer microns_basil +# Writing 2882 annotations for layer microns_minnie +# Writing 6923 annotations for layer microns_interneuron +# Writing 5805 annotations for layer aibs_v1dd +# Writing 446 annotations for layer kim_n2da +# Writing 3699 annotations for layer kim_pfc2022 +# Writing 740 annotations for layer kronauer_cra9 +# Writing 4744 annotations for layer kubota_001 +# Writing 1605 annotations for layer lee_fanc +# Writing 742 annotations for layer lee_banc +# Writing 7219 annotations for layer lee_ppc +# Writing 1964 annotations for layer lee_mosquito +# Writing 2799 annotations for layer lichtman_zebrafish +# Writing 4584 annotations for layer prieto_godino_larva +# Writing 1795 annotations for layer fafb_v15 +# Writing 6624 annotations for layer lichtman_h01 +# Writing 5304 annotations for layer janelia_hemibrain +# Writing 2398 annotations for layer janelia_manc +# Writing 1847 annotations for layer nguyen_thomas_2022 +# Writing 3379 annotations for layer mulcahy_2022_16h +# Writing 1704 annotations for layer wildenberg_2021_vta_dat12a +# Writing 7325 annotations for layer bumbarber_2013 +# Writing 2092 annotations for layer wilson_2019_p3 +# Writing 141 annotations for layer ishibashi_2021_em1 +# Writing 166 annotations for layer ishibashi_2021_em2 +# Writing 5401 annotations for layer templier_2019_wafer1 +# Writing 3577 annotations for layer templier_2019_wafer3 +# Writing 5673 annotations for layer lichtman_octopus2022 From 092c9bf8cf370a468d8110a57fa256edb9df88c2 Mon Sep 17 
00:00:00 2001 From: Nico Kemnitz Date: Tue, 5 Sep 2023 13:35:41 +0200 Subject: [PATCH 4/9] specs+regimes(training): update base encoder, deprecate old regimes --- .../em_encoder/preprocess/01_gen_datasets.py | 17 +- .../preprocess/02_pairwise_align.py | 3 +- .../em_encoder/preprocess/03_enc_misd.py | 9 +- .../em_encoder/preprocess/04_mask_empty.py | 4 +- .../em_encoder/preprocess/05_export_annos.py | 2 + .../em_encoder/train/m3_m3_encoder_dict.py | 481 ++++++++++++++++ .../em_encoder/train/m3_m4_encoder_dict.py | 489 ++++++++++++++++ .../em_encoder/train/m3_m5_encoder_dict.py | 499 +++++++++++++++++ .../em_encoder/train/m3_m6_encoder_dict.py | 509 +++++++++++++++++ .../em_encoder/train/m3_m7_encoder_dict.py | 494 +++++++++++++++++ zetta_utils/api/v0.py | 17 - .../lightning/regimes/alignment/__init__.py | 8 +- .../regimes/alignment/base_encoder.py | 384 +++++++------ .../alignment/deprecated/base_encoder.py | 524 ++++++++++++++++++ .../{ => deprecated}/encoding_coarsener.py | 2 +- .../encoding_coarsener_gen_x1.py | 2 +- .../encoding_coarsener_highres.py | 2 +- .../{ => deprecated}/minima_encoder.py | 2 +- .../{ => deprecated}/misalignment_detector.py | 2 +- 19 files changed, 3244 insertions(+), 206 deletions(-) create mode 100644 specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py create mode 100644 specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py create mode 100644 specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py create mode 100644 specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py create mode 100644 specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py create mode 100644 zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder.py rename zetta_utils/training/lightning/regimes/alignment/{ => deprecated}/encoding_coarsener.py (99%) rename zetta_utils/training/lightning/regimes/alignment/{ => deprecated}/encoding_coarsener_gen_x1.py (98%) rename zetta_utils/training/lightning/regimes/alignment/{ => 
deprecated}/encoding_coarsener_highres.py (99%) rename zetta_utils/training/lightning/regimes/alignment/{ => deprecated}/minima_encoder.py (99%) rename zetta_utils/training/lightning/regimes/alignment/{ => deprecated}/misalignment_detector.py (99%) diff --git a/specs/nico/training/em_encoder/preprocess/01_gen_datasets.py b/specs/nico/training/em_encoder/preprocess/01_gen_datasets.py index d0a63aa8e..3891d7dd7 100644 --- a/specs/nico/training/em_encoder/preprocess/01_gen_datasets.py +++ b/specs/nico/training/em_encoder/preprocess/01_gen_datasets.py @@ -1,3 +1,5 @@ +# type: ignore +# pylint: skip-file from __future__ import annotations import math @@ -5,20 +7,18 @@ from cloudvolume import CloudVolume from cloudvolume.lib import Bbox -from zetta_utils import mazepa -from zetta_utils.mazepa_addons.configurations.execute_on_gcp_with_sqs import execute_on_gcp_with_sqs +from zetta_utils import log, mazepa from zetta_utils.builder.built_in_registrations import efficient_parse_lambda_str from zetta_utils.geometry.bbox import BBox3D +from zetta_utils.geometry.vec import Vec3D from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer from zetta_utils.layer.volumetric.tools import VolumetricIndexTranslator +from zetta_utils.mazepa_addons.configurations.execute_on_gcp_with_sqs import ( + execute_on_gcp_with_sqs, +) from zetta_utils.mazepa_layer_processing.common import build_subchunkable_apply_flow - -from zetta_utils import log - from zetta_utils.ng.link_builder import make_ng_link -from zetta_utils.geometry.vec import Vec3D - logger = log.get_logger("zetta_utils") log.set_verbosity("INFO") log.configure_logger() @@ -464,8 +464,9 @@ ) -import os import json +import os + os.environ["ZETTA_RUN_SPEC"] = json.dumps("") execute_on_gcp_with_sqs( worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230728_7", diff --git a/specs/nico/training/em_encoder/preprocess/02_pairwise_align.py b/specs/nico/training/em_encoder/preprocess/02_pairwise_align.py index 
8f9b25dca..e335f989f 100644 --- a/specs/nico/training/em_encoder/preprocess/02_pairwise_align.py +++ b/specs/nico/training/em_encoder/preprocess/02_pairwise_align.py @@ -1,3 +1,5 @@ +# type: ignore +# pylint: skip-file from __future__ import annotations import math @@ -5,7 +7,6 @@ from cloudvolume import CloudVolume - from zetta_utils.api.v0 import * SOURCE_PATHS = { diff --git a/specs/nico/training/em_encoder/preprocess/03_enc_misd.py b/specs/nico/training/em_encoder/preprocess/03_enc_misd.py index d1fe2f8b2..15f051f29 100644 --- a/specs/nico/training/em_encoder/preprocess/03_enc_misd.py +++ b/specs/nico/training/em_encoder/preprocess/03_enc_misd.py @@ -1,13 +1,14 @@ +# type: ignore +# pylint: skip-file from __future__ import annotations +import json import math -from functools import partial import os - -from cloudvolume import CloudVolume from copy import deepcopy -import json +from functools import partial +from cloudvolume import CloudVolume from zetta_utils.api.v0 import * diff --git a/specs/nico/training/em_encoder/preprocess/04_mask_empty.py b/specs/nico/training/em_encoder/preprocess/04_mask_empty.py index 5e1c0e624..5801ec3aa 100644 --- a/specs/nico/training/em_encoder/preprocess/04_mask_empty.py +++ b/specs/nico/training/em_encoder/preprocess/04_mask_empty.py @@ -1,9 +1,11 @@ +# type: ignore +# pylint: skip-file from __future__ import annotations -from functools import partial import json import math import os +from functools import partial from cloudvolume import CloudVolume diff --git a/specs/nico/training/em_encoder/preprocess/05_export_annos.py b/specs/nico/training/em_encoder/preprocess/05_export_annos.py index df47db909..b58d01aef 100644 --- a/specs/nico/training/em_encoder/preprocess/05_export_annos.py +++ b/specs/nico/training/em_encoder/preprocess/05_export_annos.py @@ -1,3 +1,5 @@ +# type: ignore +# pylint: skip-file from __future__ import annotations from cloudvolume import CloudVolume diff --git 
a/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py new file mode 100644 index 000000000..8fde87463 --- /dev/null +++ b/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py @@ -0,0 +1,481 @@ +# pylint: skip-file +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + EXP_NAME = "general_encoder_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.33_M3_M3_unet_fm{FM}-256_conv4_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], 
"num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, 
"resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": 
"imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + 
}, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + 
"resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "activate_last": True, + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM, FM, FM, FM*2], + [FM*2, FM*2, FM*2, FM*4], + [FM*4, FM*4, FM*4, FM*8], + [FM*8, FM*8, FM*8, FM*8], + [FM*4, FM*4, FM*4, FM*4], + [FM*2, FM*2, FM*2, FM*2], + [FM, FM, FM, FM], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, + }, + "conv": { + "@type": "torch.nn.Conv2d", + "@mode": "partial", + "padding": 1, + }, + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM, + "out_channels": 1, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + 
"default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py new file mode 100644 index 000000000..098fe2f58 --- /dev/null +++ b/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py @@ -0,0 +1,489 @@ +# pylint: skip-file +from __future__ import 
annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + DS_FACTOR = 2 + CHANNELS = 1 + EXP_NAME = "general_coarsener_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.0_M3_M4_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.10_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + 
"lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # 
"augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": "imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 
0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + 
"sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + 
"l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM, FM, FM, FM*2], + [FM*2, FM*2, FM*2, FM*4], + [FM*4, FM*4, FM*4, FM*8], + [FM*8, FM*8, FM*8, FM*8], + [FM*4, FM*4, FM*4, FM*4], + [FM*2, FM*2, FM*2, FM*2], + [FM, FM, FM, FM], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, + }, + "conv": { + "@type": "torch.nn.Conv2d", + "@mode": "partial", + "padding": 1, + }, + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # 
"limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py new file mode 100644 index 000000000..5a17fa794 --- /dev/null +++ b/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py @@ -0,0 +1,499 @@ +# pylint: skip-file +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train 
import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + DS_FACTOR = 4 + CHANNELS = 1 + EXP_NAME = "general_coarsener_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.14_M3_M5_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, + 
"lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": 
"c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": "imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] 
+ + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + 
"end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + 
"locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM*2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM*2, FM*2, FM*2, FM*4], + [FM*4, FM*4, FM*4, FM*8], + [FM*8, FM*8, FM*8, FM*8], + [FM*4, FM*4, FM*4, FM*4], + [FM*2, FM*2, FM*2, FM*2], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, + }, + "conv": { + "@type": "torch.nn.Conv2d", + "@mode": "partial", + "padding": 1, + }, + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM*2, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 
500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py new file mode 100644 index 000000000..df88b3b4b --- /dev/null +++ b/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py @@ -0,0 +1,509 @@ +# pylint: skip-file +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from 
zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + DS_FACTOR = 8 + CHANNELS = 2 + EXP_NAME = "general_coarsener_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.15_M3_M6_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": 
[32, 32, 45], "num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = [ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": 
"rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": "imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + 
], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": 
"BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": "lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, 
+ "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM*2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM*2, FM*2, FM*2, FM*4], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM*4, FM*4, FM*4, FM*8], + [FM*8, FM*8, FM*8, FM*8], + [FM*4, FM*4, FM*4, FM*4], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, + }, + "conv": { + "@type": "torch.nn.Conv2d", + "@mode": "partial", + "padding": 1, + }, + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM*4, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + 
# "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py new file mode 100644 index 000000000..5f5cce385 --- /dev/null +++ 
b/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py @@ -0,0 +1,494 @@ +# pylint: skip-file +from __future__ import annotations + +if __name__ == "__main__": + import os + + from zetta_utils.api.v0 import * + from zetta_utils.parsing import json + from zetta_utils.training.lightning.train import _parse_spec_and_train + + LR = 2e-4 + L1_WEIGHT_START_VAL = 0.0 + L1_WEIGHT_END_VAL = 0.05 + L1_WEIGHT_START_EPOCH = 1 + L1_WEIGHT_END_EPOCH = 6 + LOCALITY_WEIGHT = 1.0 + SIMILARITY_WEIGHT = 0.0 + CHUNK_SIZE = 1024 + FM = 32 + DS_FACTOR = 16 + CHANNELS = 3 + EXP_NAME = "general_coarsener_loss" + TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" + + EXP_VERSION = f"1.0.3_M3_M7_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + + START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + + BASE_PATH = "gs://zetta-research-nico/encoder/" + + VAL_DSET_NAME = "microns_basil" + VALIDATION_SRC_PATH = BASE_PATH + "datasets/" + VAL_DSET_NAME + VALIDATION_TGT_PATH = BASE_PATH + "pairwise_aligned/" + VAL_DSET_NAME + "/warped_img" + + SOURCE_PATHS = { + "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, + # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, + "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, + "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, + "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"contiguous": True, "resolution": [32, 
32, 42], "num_samples": 740}, + "kubota_001": {"contiguous": True, "resolution": [40, 40, 40], "num_samples": 4744}, + "lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45], "num_samples": 1605}, + "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, + "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, + "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, + "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, + "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, + "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, + "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, + "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, + "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, + # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + } + + val_img_aug = 
[ + # {"@type": "to_uint8", "@mode": "partial"}, + # { + # "@type": "imgaug_readproc", + # "@mode": "partial", + # "augmenters": [ + # { + # "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + # "severity": 1, + # } + # ], + # }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + train_img_aug = [ + {"@type": "divide", "@mode": "partial", "value": 255.0}, + { + "@type": "square_tile_pattern_aug", + "@mode": "partial", + "prob": 0.5, + "tile_size": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "tile_stride": {"@type": "uniform_distr", "low": 64, "high": 1024}, + "max_brightness_change": {"@type": "uniform_distr", "low": 0.0, "high": 0.3}, + "rotation_degree": {"@type": "uniform_distr", "low": 0, "high": 90}, + "preserve_data_val": 0.0, + "repeats": 1, + "device": "cpu", + }, + {"@type": "torch.clamp", "@mode": "partial", "min": 0.0, "max": 1.0}, + {"@type": "multiply", "@mode": "partial", "value": 255.0}, + {"@type": "to_uint8", "@mode": "partial"}, + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.SomeOf", + "n": 3, + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.OneOf", + "children": [ + { + "@type": "imgaug.augmenters.GammaContrast", + "gamma": (0.5, 2.0), + }, + { + "@type": "imgaug.augmenters.SigmoidContrast", + "gain": (4, 6), + "cutoff": (0.3, 0.7), + }, + { + "@type": "imgaug.augmenters.LogContrast", + "gain": (0.7, 1.3), + }, + { + "@type": "imgaug.augmenters.LinearContrast", + "alpha": (0.4, 1.6), + }, + ], + }, + { + "@type": "imgaug.augmenters.AllChannelsCLAHE", + "clip_limit": (0.1, 8.0), + "tile_grid_size_px": (3, 64), + }, + ], + }, + { + "@type": "imgaug.augmenters.Add", + "value": (-40, 40), + }, + { + "@type": "imgaug.augmenters.Sometimes", + "p": 1.0, + "then_list": [{ + "@type": 
"imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + }] + }, + { + "@type": "imgaug.augmenters.Cutout", + "nb_iterations": 1, + "size": (0.02, 0.8), + "cval": (0, 255), + "squared": False, + }, + { + "@type": "imgaug.augmenters.JpegCompression", + "compression": (0, 35), + }, + ], + "random_order": True, + }, + ], + }, + ] + + shared_train_img_aug = [ + { + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + { + "@type": "imgaug.augmenters.Rot90", + "k": [0, 1, 2, 3], + }, + { + "@type": "imgaug.augmenters.Fliplr", + "p": 0.25, + }, + { + "@type": "imgaug.augmenters.Flipud", + "p": 0.25, + }, + ], + "random_order": True, + } + ], + }, + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + {"@type": "divide", "@mode": "partial", "value": 255.0}, + ] + + + training = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "JointDataset", + "mode": "vertical", + "datasets": { + name: { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "datasets/" + name, + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": BASE_PATH + "pairwise_aligned/" + name + "/warped_img", + "read_procs": train_img_aug, + "cv_kwargs": {"cache": False}, + }, + }, + "readonly": True, + "read_procs": shared_train_img_aug, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": settings["resolution"], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, + } + } + } + for name, settings in SOURCE_PATHS.items() + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": 
"gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + "cv_kwargs": {"cache": False}, + }, + "sample_indexer": { + "@type": "RandomIndexer", + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [128, 128, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + validation = { + "@type": "JointDataset", + "mode": "horizontal", + "datasets": { + "images": { + "@type": "LayerDataset", + "layer": { + "@type": "build_layer_set", + "layers": { + "src_img": { + "@type": "build_cv_layer", + "path": VALIDATION_SRC_PATH, + "read_procs": val_img_aug, + }, + "tgt_img": { + "@type": "build_cv_layer", + "path": VALIDATION_TGT_PATH, + "read_procs": val_img_aug, + }, + }, + "readonly": True, + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricNGLIndexer", + "resolution": [32,32,40], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, + } + }, + }, + "field": { + "@type": "LayerDataset", + "layer": { + "@type": "build_cv_layer", + "path": "gs://zetta-research-nico/perlin_noise_fields/1px", + "read_procs": [ + {"@type": "rearrange", "@mode": "partial", "pattern": "c x y 1 -> c x y"}, + ], + }, + "sample_indexer": { + "@type": "LoopIndexer", + "desired_num_samples": 100, + "inner_indexer": { + "@type": "VolumetricStridedIndexer", + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [2048, 2048, 2040], + "resolution": [4, 4, 45], + }, + "stride": [512, 512, 1], + "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], + "resolution": [4, 4, 45], + }, + }, + }, + }, + } + + + target = BuilderPartial( + { + "@type": 
"lightning_train", + "regime": { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM*2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM*2, FM*2, FM*2, FM*4], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM*4, FM*4, FM*4, FM*8], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + { + "@type": "torch.nn.MaxPool2d", + "kernel_size": 2 + }, + { + "@type": "ConvBlock", + "num_channels": [FM*8, FM*8, FM*8, FM*8], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 3}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM*8, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], + }, + }, + }, + "trainer": { + "@type": "ZettaDefaultTrainer", + 
"accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": 4, + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + }, + "train_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + }, + "val_dataloader": { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + }, + } + ) + + + os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + + # _parse_spec_and_train() + + lightning_train_remote( + worker_cluster_name="zutils-x3", + worker_cluster_region="us-east1", + worker_cluster_project="zetta-research", + worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + worker_resources={"nvidia.com/gpu": "4"}, + worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + num_nodes=1, + spec_path=target, + follow_logs=False, + env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + ) diff --git a/zetta_utils/api/v0.py b/zetta_utils/api/v0.py index ece3516a9..050049e69 100644 --- a/zetta_utils/api/v0.py +++ b/zetta_utils/api/v0.py @@ -414,23 +414,6 @@ from 
zetta_utils.training.lightning.regimes.alignment.base_encoder import ( BaseEncoderRegime, ) -from zetta_utils.training.lightning.regimes.alignment.encoding_coarsener import ( - EncodingCoarsenerRegime, -) -from zetta_utils.training.lightning.regimes.alignment.encoding_coarsener_gen_x1 import ( - EncodingCoarsenerGenX1Regime, -) -from zetta_utils.training.lightning.regimes.alignment.encoding_coarsener_highres import ( - EncodingCoarsenerHighRes, - center_crop_norm, - warp_by_px, -) -from zetta_utils.training.lightning.regimes.alignment.minima_encoder import ( - MinimaEncoderRegime, -) -from zetta_utils.training.lightning.regimes.alignment.misalignment_detector import ( - MisalignmentDetectorRegime, -) from zetta_utils.training.lightning.regimes.alignment.misalignment_detector_aced import ( MisalignmentDetectorAcedRegime, ) diff --git a/zetta_utils/training/lightning/regimes/alignment/__init__.py b/zetta_utils/training/lightning/regimes/alignment/__init__.py index 0d003796c..e645e5759 100644 --- a/zetta_utils/training/lightning/regimes/alignment/__init__.py +++ b/zetta_utils/training/lightning/regimes/alignment/__init__.py @@ -1,6 +1,2 @@ -from . import encoding_coarsener -from . import encoding_coarsener_highres -from . import encoding_coarsener_gen_x1 -from . import base_encoder -from . import misalignment_detector -from . import misalignment_detector_aced +from . 
import base_encoder, misalignment_detector_aced +from .deprecated import encoding_coarsener diff --git a/zetta_utils/training/lightning/regimes/alignment/base_encoder.py b/zetta_utils/training/lightning/regimes/alignment/base_encoder.py index 157d55b92..65823095d 100644 --- a/zetta_utils/training/lightning/regimes/alignment/base_encoder.py +++ b/zetta_utils/training/lightning/regimes/alignment/base_encoder.py @@ -1,38 +1,48 @@ # pragma: no cover # pylint: disable=too-many-locals +import os +from math import log2 from typing import Optional import attrs import einops +import numpy as np import pytorch_lightning as pl import torch +import torchfields +import wandb +from PIL import Image as PILImage +from pytorch_lightning import seed_everything -from zetta_utils import builder, distributions, tensor_ops +from zetta_utils import builder, distributions, tensor_ops, viz -from ..common import log_results - -@builder.register("BaseEncoderRegime") +@builder.register("BaseEncoderRegime", versions="==0.0.2") @attrs.mutable(eq=False) class BaseEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors model: torch.nn.Module lr: float train_log_row_interval: int = 200 val_log_row_interval: int = 25 - field_magn_thr: float = 1 - post_weight: float = 0.5 + max_displacement_px: float = 16.0 + l1_weight_start_val: float = 0.0 + l1_weight_end_val: float = 0.0 + l1_weight_start_epoch: int = 0 + l1_weight_end_epoch: int = 0 + locality_weight: float = 1.0 + similarity_weight: float = 0.0 zero_value: float = 0 - zero_conserve_weight: float = 0.5 + ds_factor: int = 1 worst_val_loss: float = attrs.field(init=False, default=0) worst_val_sample: dict = attrs.field(init=False, factory=dict) worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) - equivar_weight: float = 1.0 equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) 
equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) + empty_tissue_threshold: float = 0.4 def __attrs_pre_init__(self): super().__init__() @@ -41,199 +51,244 @@ def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) return optimizer + def log_results(self, mode: str, title_suffix: str = "", **kwargs): + if not self.logger: + return + images = [] + for k, v in kwargs.items(): + for b in range(1): + if v.dtype in (np.uint8, torch.uint8): + img = v[b].squeeze() + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + elif v.dtype in (torch.int8, np.int8): + img = v[b].squeeze().byte() + 127 + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + elif v.dtype in (torch.bool, bool): + img = v[b].squeeze().byte() * 255 + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + else: + if v.size(1) == 2 and k != "field": + img = torch.cat([v, torch.zeros_like(v[:, :1])], dim=1) + else: + img = v + v_min = img[b].min().round(decimals=4) + v_max = img[b].max().round(decimals=4) + images.append( + wandb.Image( + viz.rendering.Renderer()(img[b].squeeze()), + caption=f"{k}_b{b} | min: {v_min} | max: {v_max}", + ) + ) + + self.logger.log_image(f"results/{mode}_{title_suffix}_slider", images=images) + + def validation_epoch_start(self, _): # pylint: disable=no-self-use + seed_everything(42) + def on_validation_epoch_end(self): - log_results( - "val", - "worst", - 
**self.worst_val_sample, - ) - self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None + env_seed = os.environ.get("PL_GLOBAL_SEED") + if env_seed is not None: + seed_everything(int(env_seed) + self.current_epoch) + else: + seed_everything(None) def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ log_row = batch_idx % self.train_log_row_interval == 0 - loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) + + with torchfields.set_identity_mapping_cache(True, clear_cache=False): + loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) + return loss - def _get_warped(self, img, field): - img_warped = field.field().from_pixels()(img) - zeros_warped = field.field().from_pixels()((img == self.zero_value).float()) > 0.1 - img_warped[zeros_warped] = 0 - return img_warped, zeros_warped + @staticmethod + def _get_warped(img, field=None): + if field is not None: + img_warped = field.from_pixels()(img) + else: + img_warped = img - def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): - src = batch["images"]["src"] - tgt = batch["images"]["tgt"] + return img_warped - if ((src == self.zero_value) + (tgt == self.zero_value)).bool().sum() / src.numel() > 0.4: - return None + @staticmethod + def _down_zeros_mask(zeros_mask, count): + if count <= 0: + return zeros_mask - seed_field = batch["field"] - seed_field = ( - seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) + scale_factor = 0.5**count + return ( + torch.nn.functional.interpolate( + zeros_mask.float(), scale_factor=scale_factor, mode="bilinear" + ) + > 0.99 ) + def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): + src = batch["images"]["src_img"] + tgt = batch["images"]["tgt_img"] + + # if ( + # (src == self.zero_value) + (tgt == self.zero_value) + # ).bool().sum() / src.numel() > self.empty_tissue_threshold: 
+ # return None # Can't return None with DDP! + + # Get random field - combination of pregenerated Perlin noise and a random affine transform + seed_field = batch["field"].field_() + f_warp = seed_field * self.max_displacement_px f_aff = ( einops.rearrange( tensor_ops.transform.get_affine_field( size=src.shape[-1], - rot_deg=self.equivar_rot_deg_distr(), - scale=self.equivar_scale_distr(), - shear_x_deg=self.equivar_shear_deg_distr(), - shear_y_deg=self.equivar_shear_deg_distr(), - trans_x_px=self.equivar_trans_px_distr(), - trans_y_px=self.equivar_trans_px_distr(), + rot_deg=self.equivar_rot_deg_distr() if mode == "train" else 90.0, + scale=self.equivar_scale_distr() if mode == "train" else 1.0, + shear_x_deg=self.equivar_shear_deg_distr() if mode == "train" else 0.0, + shear_y_deg=self.equivar_shear_deg_distr() if mode == "train" else 0.0, + trans_x_px=self.equivar_trans_px_distr() if mode == "train" else 0.0, + trans_y_px=self.equivar_trans_px_distr() if mode == "train" else 0.0, ), "C X Y Z -> Z C X Y", ) - .field() # type: ignore - .pixels() + .pixels() # type: ignore[attr-defined] .to(seed_field.device) - ) - f1_trans = torch.tensor(f_aff.from_pixels()(seed_field.field().from_pixels()).pixels()) - f2_trans = torch.tensor( - seed_field.field() - .from_pixels()(f1_trans.field().from_pixels()) # type: ignore - .pixels() - ) + ).repeat_interleave(src.size(0), dim=0) + f1_transform = f_aff.from_pixels()(f_warp.from_pixels()).pixels() - src_f1, src_zeros_f1 = self._get_warped(src, f1_trans) - src_f2, src_zeros_f2 = self._get_warped(src, f2_trans) - tgt_f1, tgt_zeros_f1 = self._get_warped(tgt, f1_trans) + # Warp Images and Tissue mask + src_f1 = self._get_warped(src, field=f1_transform) + tgt_f1 = self._get_warped(tgt, field=f1_transform) + # Generate encodings: src, src_f1_enc, src_enc_f1, tgt_f1_enc src_enc = self.model(src) - src_enc_f1 = f1_trans.field().from_pixels()(src_enc) # type: ignore - src_f1_enc = self.model(src_f1) - - equi_diff = (src_enc_f1 - 
src_f1_enc).abs() - equi_loss = equi_diff[src_zeros_f1 == 0].sum() - equi_diff_map = equi_diff.clone() - equi_diff_map[src_zeros_f1] = 0 - - src_f2_enc = self.model(src_f2) + src_enc_f1 = torch.nn.functional.pad(src_enc, (1, 1, 1, 1), mode="replicate") + src_enc_f1 = ( + torch.nn.functional.pad(f1_transform, (self.ds_factor,) * 4, mode="replicate") + .from_pixels() # type: ignore[attr-defined] + .down(int(log2(self.ds_factor))) + .sample(src_enc_f1, padding_mode="border") + ) + src_enc_f1 = torch.nn.functional.pad(src_enc_f1, (-1, -1, -1, -1)) tgt_f1_enc = self.model(tgt_f1) - pre_diff = (src_f1_enc - tgt_f1_enc).abs() + crop = 256 // self.ds_factor + src_f1 = src_f1[..., 256:-256, 256:-256] + tgt_f1 = tgt_f1[..., 256:-256, 256:-256] + src_enc_f1 = src_enc_f1[..., crop:-crop, crop:-crop] + tgt_f1_enc = tgt_f1_enc[..., crop:-crop, crop:-crop] - pre_tissue_mask = ( - tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f1, width=5) == 0 + # Alignment loss: Ensure even close to local optima solutions produce larger errors + # than the local optimum solution + abs_error_local_opt = ( + (src_enc_f1 - tgt_f1_enc)[:, :, 1:-1, 1:-1].pow(2).mean(dim=-3, keepdim=True) ) - pre_loss = pre_diff[..., pre_tissue_mask].sum() - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 - - post_tissue_mask = ( - tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f2, width=5) == 0 + abs_error_1px_shift = torch.stack( + [ + (src_enc_f1[:, :, 2:, 1:-1] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2), + (src_enc_f1[:, :, :-2, 1:-1] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2), + (src_enc_f1[:, :, 1:-1, 2:] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2), + (src_enc_f1[:, :, 1:-1, :-2] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2), + (tgt_f1_enc[:, :, 2:, 2:] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2), + (tgt_f1_enc[:, :, :-2, 2:] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2), + (tgt_f1_enc[:, :, 2:, :-2] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2), + (tgt_f1_enc[:, :, :-2, :-2] - src_enc_f1[:, :, 
1:-1, 1:-1]).pow(2), + ] + ).mean(dim=-3, keepdim=True) + + locality_error_map = ( + ((abs_error_local_opt - abs_error_1px_shift + 4.0) * 0.2) + .pow( + 8.0 # increase to put more focus on locations where bad alignment + # still produces similar encodings - try 8? -> 42 + ) + .logsumexp(dim=0) ) - post_magn_mask = seed_field.abs().max(1)[0] > self.field_magn_thr - post_magn_mask[..., 0:10, :] = 0 - post_magn_mask[..., -10:, :] = 0 - post_magn_mask[..., :, 0:10] = 0 - post_magn_mask[..., :, -10:] = 0 - - post_diff_map = (src_f2_enc - tgt_f1_enc).abs() - post_mask = post_magn_mask * post_tissue_mask - if post_mask.sum() < 256: - return None + locality_loss = ( + locality_error_map.sum() / locality_error_map.size(0) * self.ds_factor * self.ds_factor + ) - post_loss = post_diff_map[..., post_mask].sum() + l1_loss_map = (tgt_f1_enc.abs() + src_enc_f1.abs())[:, :, 1:-1, 1:-1] + l1_loss = ( + l1_loss_map.sum() + / (2 * tgt_f1_enc.size(0) * tgt_f1_enc.size(1)) + * self.ds_factor + * self.ds_factor + ) - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_mask == 0] = 0 + l1_weight_ratio = min( + 1.0, + max(0, self.current_epoch - self.l1_weight_start_epoch) + / max(1, self.l1_weight_end_epoch - self.l1_weight_start_epoch), + ) + l1_weight = ( + l1_weight_ratio * self.l1_weight_end_val + + (1.0 - l1_weight_ratio) * self.l1_weight_start_val + ) - loss = pre_loss - post_loss * self.post_weight + equi_loss * self.equivar_weight - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) + loss = locality_loss * self.locality_weight + l1_loss * l1_weight + self.log( + f"loss/{mode}", loss, on_step=True, on_epoch=True, sync_dist=True, rank_zero_only=False + ) + self.log( + f"loss/{mode}_l1_weight", + l1_weight, + on_step=False, + on_epoch=True, + 
prog_bar=False, + sync_dist=False, + rank_zero_only=True, + ) + self.log( + f"loss/{mode}_l1", + l1_loss, + on_step=True, + on_epoch=True, + prog_bar=False, + sync_dist=True, + rank_zero_only=False, + ) if log_row: - log_results( + self.log_results( mode, sample_name, src=src, src_enc=src_enc, src_f1=src_f1, src_enc_f1=src_enc_f1, - src_f1_enc=src_f1_enc, - src_f2_enc=src_f2_enc, tgt_f1=tgt_f1, tgt_f1_enc=tgt_f1_enc, - field=torch.tensor(seed_field), - equi_diff_map=equi_diff_map, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, - ) - return loss - - def compute_metroem_loss_old( - self, batch: dict, mode: str, log_row: bool, sample_name: str = "" - ): - src = batch["images"]["src"] - tgt = batch["images"]["tgt"] - - field = batch["field"] - - tgt_zeros = tensor_ops.mask.kornia_dilation(tgt == self.zero_value, width=3) - src_zeros = tensor_ops.mask.kornia_dilation(src == self.zero_value, width=3) - - pre_tissue_mask = (src_zeros + tgt_zeros) == 0 - if pre_tissue_mask.sum() / src.numel() < 0.4: - return None - - zero_magns = 0 - tgt_enc = self.model(tgt) - zero_magns += tgt_enc[tgt_zeros].abs().sum() - - src_warped = field.field().from_pixels()(src) - src_warped_enc = self.model(src_warped) - src_zeros_warped = field.field().from_pixels()(src_zeros.float()) > 0.1 - - zero_magns += src_warped_enc[src_zeros_warped].abs().sum() - - # src_enc = (~(field.field().from_pixels()))(src_warped_enc) - src_enc = self.model(src) - - pre_diff = (src_enc - tgt_enc).abs() - pre_loss = pre_diff[..., pre_tissue_mask].sum() - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 - - post_tissue_mask = ( - tensor_ops.mask.kornia_dilation(src_zeros_warped + tgt_zeros, width=5) == 0 - ) - post_magn_mask = field.abs().sum(1) > self.field_magn_thr - - post_magn_mask[..., 0:10, :] = 0 - post_magn_mask[..., -10:, :] = 0 - post_magn_mask[..., :, 0:10] = 0 - post_magn_mask[..., :, -10:] = 0 - post_diff_map = (src_warped_enc - tgt_enc).abs() 
- post_mask = post_magn_mask * post_tissue_mask - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_tissue_mask == 0] = 0 - if post_mask.sum() < 256: - return None - - post_loss = post_diff_map[..., post_mask].sum() - loss = pre_loss - post_loss * self.post_weight + zero_magns * self.zero_conserve_weight - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_zcons", zero_magns, on_step=True, on_epoch=True) - if log_row: - log_results( - mode, - sample_name, - src=src, - src_enc=src_enc, - src_warped_enc=src_warped_enc, - tgt=tgt, - tgt_enc=tgt_enc, - field=field, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, + field=f_warp.tensor_(), + locality_error_map=locality_error_map, + l1_loss_map=l1_loss_map, + weighted_loss_map=( + locality_error_map / locality_error_map.size(0) * self.locality_weight + + l1_loss_map / (2 * tgt_f1_enc.size(0)) * l1_weight + ), ) return loss @@ -241,7 +296,8 @@ def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ log_row = batch_idx % self.val_log_row_interval == 0 sample_name = f"{batch_idx // self.val_log_row_interval}" - loss = self.compute_metroem_loss( - batch=batch, mode="val", log_row=log_row, sample_name=sample_name - ) + with torchfields.set_identity_mapping_cache(True, clear_cache=False): + loss = self.compute_metroem_loss( + batch=batch, mode="val", log_row=log_row, sample_name=sample_name + ) return loss diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder.py new file mode 100644 index 000000000..c9b12baef --- /dev/null +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder.py @@ -0,0 +1,524 @@ +# type: ignore +# pragma: no cover +# pylint: 
disable=too-many-locals, function-redefined + +from typing import Optional + +import attrs +import cc3d +import einops +import numpy as np +import pytorch_lightning as pl +import torch +import torchfields +import wandb +from PIL import Image as PILImage +from pytorch_lightning import seed_everything + +from zetta_utils import builder, distributions, tensor_ops, viz +from zetta_utils.training.lightning.regimes.common import log_results + + +@builder.register("BaseEncoderRegime", versions="==0.0.1") +@attrs.mutable(eq=False) +class BaseEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors + model: torch.nn.Module + lr: float + train_log_row_interval: int = 200 + val_log_row_interval: int = 25 + field_magn_thr: float = 1 + max_displacement_px: float = 16.0 + post_weight: float = 1.5 + zero_value: float = 0 + worst_val_loss: float = attrs.field(init=False, default=0) + worst_val_sample: dict = attrs.field(init=False, factory=dict) + worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) + + equivar_weight: float = 1.0 + equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) + equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) + equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) + equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) + empty_tissue_threshold: float = 0.4 + + def __attrs_pre_init__(self): + super().__init__() + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + def log_results(self, mode: str, title_suffix: str = "", **kwargs): + if self.logger is None: + return + images = [] + for k, v in kwargs.items(): + for b in range(1): + if v.dtype in (np.uint8, torch.uint8): + img = v[b].squeeze() + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + 
PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + elif v.dtype in (torch.int8, np.int8): + img = v[b].squeeze().byte() + 127 + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + elif v.dtype in (torch.bool, bool): + img = v[b].squeeze().byte() * 255 + img[-1, -1] = 255 + img[-2, -2] = 255 + img[-1, -2] = 0 + img[-2, -1] = 0 + images.append( + wandb.Image( + PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), + caption=f"{k}_b{b}", + ) + ) + else: + v_min = v[b].min().round(decimals=4) + v_max = v[b].max().round(decimals=4) + images.append( + wandb.Image( + viz.rendering.Renderer()(v[b].squeeze()), + caption=f"{k}_b{b} | min: {v_min} | max: {v_max}", + ) + ) + + self.logger.log_image(f"results/{mode}_{title_suffix}_slider", images=images) + + def validation_epoch_start(self, _): # pylint: disable=no-self-use + seed_everything(42) + + def on_validation_epoch_end(self): + self.log_results( + "val", + "worst", + **self.worst_val_sample, + ) + self.worst_val_loss = 0 + self.worst_val_sample = {} + self.worst_val_sample_idx = None + seed_everything(None) + + def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ + log_row = batch_idx % self.train_log_row_interval == 0 + + with torchfields.set_identity_mapping_cache(True, clear_cache=False): + loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) + + return loss + + def _get_warped(self, img, field=None): + img_padded = torch.nn.functional.pad(img, (1, 1, 1, 1), value=self.zero_value) + if field is not None: + img_warped = field.from_pixels()(img) + else: + img_warped = img + + zeros_padded = img_padded == self.zero_value + zeros_padded_cc = np.array( + [ + cc3d.connected_components( + x.detach().squeeze().cpu().numpy(), connectivity=4 + ).reshape(zeros_padded[0].shape) 
+ for x in zeros_padded + ] + ) + + non_tissue_zeros_padded = zeros_padded.clone() + non_tissue_zeros_padded[ + torch.tensor(zeros_padded_cc != zeros_padded_cc.ravel()[0], device=zeros_padded.device) + ] = False # keep masking resin, restore somas in center + + if field is not None: + zeros_warped = ( + torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate") + .from_pixels() # type: ignore[attr-defined] + .sample((~zeros_padded).float(), padding_mode="border") + <= 0.1 + ) + non_tissue_zeros_warped = ( + torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate") + .from_pixels() # type: ignore[attr-defined] + .sample((~non_tissue_zeros_padded).float(), padding_mode="border") + <= 0.1 + ) + else: + zeros_warped = zeros_padded + non_tissue_zeros_warped = non_tissue_zeros_padded + + zeros_warped = torch.nn.functional.pad(zeros_warped, (-1, -1, -1, -1)) + non_tissue_zeros_warped = torch.nn.functional.pad( + non_tissue_zeros_warped, (-1, -1, -1, -1) + ) + + img_warped[zeros_warped] = self.zero_value + return img_warped, ~zeros_warped, ~non_tissue_zeros_warped + + def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): + src = batch["images"]["src_img"] + tgt = batch["images"]["tgt_img"] + + if ( + (src == self.zero_value) + (tgt == self.zero_value) + ).bool().sum() / src.numel() > self.empty_tissue_threshold: + return None + + seed_field = batch["field"].field_() + f_warp_large = seed_field * self.max_displacement_px + f_warp_small = ( + seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) + ) + + f_aff = ( + einops.rearrange( + tensor_ops.transform.get_affine_field( + size=src.shape[-1], + rot_deg=self.equivar_rot_deg_distr(), + scale=self.equivar_scale_distr(), + shear_x_deg=self.equivar_shear_deg_distr(), + shear_y_deg=self.equivar_shear_deg_distr(), + trans_x_px=self.equivar_trans_px_distr(), + trans_y_px=self.equivar_trans_px_distr(), + ), + "C X Y Z -> Z C X Y", + ) + .pixels() # type: 
ignore[attr-defined] + .to(seed_field.device) + ).repeat_interleave(src.size(0), dim=0) + f1_trans = f_aff.from_pixels()(f_warp_large.from_pixels()).pixels() + f2_trans = f_warp_small.from_pixels()(f1_trans.from_pixels()).pixels() + + magn_field = f_warp_small + + src_f1, _, src_nonzeros_f1 = self._get_warped(src, f1_trans) + src_f2, _, src_nonzeros_f2 = self._get_warped(src, f2_trans) + tgt_f1, _, tgt_nonzeros_f1 = self._get_warped(tgt, f1_trans) + + src_zeros_f1 = ~src_nonzeros_f1 + src_zeros_f2 = ~src_nonzeros_f2 + tgt_zeros_f1 = ~tgt_nonzeros_f1 + + src_enc = self.model(src) + src_f1_enc = self.model(src_f1) + + src_enc_f1 = torch.nn.functional.pad(src_enc, (1, 1, 1, 1), value=0.0) + src_enc_f1 = ( + torch.nn.functional.pad(f1_trans, (1, 1, 1, 1), mode="replicate") # type: ignore + .from_pixels() + .sample(src_enc_f1, padding_mode="border") + ) + src_enc_f1 = torch.nn.functional.pad(src_enc_f1, (-1, -1, -1, -1), value=0.0) + + equi_diff = (src_enc_f1 - src_f1_enc).abs() + equi_loss = equi_diff[src_zeros_f1 != 0].sum() + equi_loss = equi_diff.sum() / equi_diff.size(0) + equi_diff_map = equi_diff.clone() + equi_diff_map[src_zeros_f1] = 0 + + src_f2_enc = self.model(src_f2) + tgt_f1_enc = self.model(tgt_f1) + + pre_diff = (src_f1_enc - tgt_f1_enc).abs() + + pre_tissue_mask = ~(tgt_zeros_f1 | src_zeros_f1) + pre_loss = pre_diff[..., pre_tissue_mask].sum() / pre_diff.size(0) + pre_diff_masked = pre_diff.clone() + pre_diff_masked[..., pre_tissue_mask == 0] = 0 + + post_tissue_mask = ~(tgt_zeros_f1 | src_zeros_f2) + post_magn_mask = (magn_field.abs().max(1, keepdim=True)[0] > self.field_magn_thr).tensor_() + + post_diff_map = (src_f2_enc - tgt_f1_enc).abs() + post_mask = post_magn_mask * post_tissue_mask + if post_mask.sum() < 256: + return None + + post_loss = post_diff_map[..., post_mask].sum() / post_diff_map.size(0) + + post_diff_masked = post_diff_map.clone() + post_diff_masked[..., post_mask == 0] = 0 + + loss = pre_loss - post_loss * self.post_weight + 
equi_loss * self.equivar_weight + self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) + if log_row: + self.log_results( + mode, + sample_name, + src=src, + src_enc=src_enc, + src_f1=src_f1, + src_enc_f1=src_enc_f1, + src_f1_enc=src_f1_enc, + src_f2_enc=src_f2_enc, + tgt_f1=tgt_f1, + tgt_f1_enc=tgt_f1_enc, + field=seed_field.tensor_(), + equi_diff_map=equi_diff_map, + post_diff_masked=post_diff_masked, + pre_diff_masked=pre_diff_masked, + ) + return loss + + def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ + log_row = batch_idx % self.val_log_row_interval == 0 + sample_name = f"{batch_idx // self.val_log_row_interval}" + + with torchfields.set_identity_mapping_cache(True, clear_cache=False): + loss = self.compute_metroem_loss( + batch=batch, mode="val", log_row=log_row, sample_name=sample_name + ) + return loss + + +@builder.register("BaseEncoderRegime", versions="==0.0.0") +@attrs.mutable(eq=False) +class BaseEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors + model: torch.nn.Module + lr: float + train_log_row_interval: int = 200 + val_log_row_interval: int = 25 + field_magn_thr: float = 1 + post_weight: float = 0.5 + zero_value: float = 0 + zero_conserve_weight: float = 0.5 + worst_val_loss: float = attrs.field(init=False, default=0) + worst_val_sample: dict = attrs.field(init=False, factory=dict) + worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) + + equivar_weight: float = 1.0 + equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) + equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) + equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) + 
equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) + + def __attrs_pre_init__(self): + super().__init__() + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + def validation_epoch_end(self, _): + log_results( + "val", + "worst", + **self.worst_val_sample, + ) + self.worst_val_loss = 0 + self.worst_val_sample = {} + self.worst_val_sample_idx = None + + def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ + log_row = batch_idx % self.train_log_row_interval == 0 + loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) + return loss + + def _get_warped(self, img, field): + img_warped = field.field().from_pixels()(img) + zeros_warped = field.field().from_pixels()((img == self.zero_value).float()) > 0.1 + img_warped[zeros_warped] = 0 + return img_warped, zeros_warped + + def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): + src = batch["images"]["src"] + tgt = batch["images"]["tgt"] + + if ((src == self.zero_value) + (tgt == self.zero_value)).bool().sum() / src.numel() > 0.4: + return None + + seed_field = batch["field"] + seed_field = ( + seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) + ) + + f_aff = ( + einops.rearrange( + tensor_ops.transform.get_affine_field( + size=src.shape[-1], + rot_deg=self.equivar_rot_deg_distr(), + scale=self.equivar_scale_distr(), + shear_x_deg=self.equivar_shear_deg_distr(), + shear_y_deg=self.equivar_shear_deg_distr(), + trans_x_px=self.equivar_trans_px_distr(), + trans_y_px=self.equivar_trans_px_distr(), + ), + "C X Y Z -> Z C X Y", + ) + .field() # type: ignore + .pixels() + .to(seed_field.device) + ) + f1_trans = torch.tensor(f_aff.from_pixels()(seed_field.field().from_pixels()).pixels()) + f2_trans = torch.tensor( + seed_field.field() + .from_pixels()(f1_trans.field().from_pixels()) # type: ignore + .pixels() + 
) + + src_f1, src_zeros_f1 = self._get_warped(src, f1_trans) + src_f2, src_zeros_f2 = self._get_warped(src, f2_trans) + tgt_f1, tgt_zeros_f1 = self._get_warped(tgt, f1_trans) + + src_enc = self.model(src) + src_enc_f1 = f1_trans.field().from_pixels()(src_enc) # type: ignore + src_f1_enc = self.model(src_f1) + + equi_diff = (src_enc_f1 - src_f1_enc).abs() + equi_loss = equi_diff[src_zeros_f1 == 0].sum() + equi_diff_map = equi_diff.clone() + equi_diff_map[src_zeros_f1] = 0 + + src_f2_enc = self.model(src_f2) + tgt_f1_enc = self.model(tgt_f1) + + pre_diff = (src_f1_enc - tgt_f1_enc).abs() + + pre_tissue_mask = ( + tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f1, width=5) == 0 + ) + pre_loss = pre_diff[..., pre_tissue_mask].sum() + pre_diff_masked = pre_diff.clone() + pre_diff_masked[..., pre_tissue_mask == 0] = 0 + + post_tissue_mask = ( + tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f2, width=5) == 0 + ) + + post_magn_mask = seed_field.abs().max(1)[0] > self.field_magn_thr + post_magn_mask[..., 0:10, :] = 0 + post_magn_mask[..., -10:, :] = 0 + post_magn_mask[..., :, 0:10] = 0 + post_magn_mask[..., :, -10:] = 0 + + post_diff_map = (src_f2_enc - tgt_f1_enc).abs() + post_mask = post_magn_mask * post_tissue_mask + if post_mask.sum() < 256: + return None + + post_loss = post_diff_map[..., post_mask].sum() + + post_diff_masked = post_diff_map.clone() + post_diff_masked[..., post_mask == 0] = 0 + + loss = pre_loss - post_loss * self.post_weight + equi_loss * self.equivar_weight + self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) + if log_row: + log_results( + mode, + sample_name, + src=src, + src_enc=src_enc, + src_f1=src_f1, + src_enc_f1=src_enc_f1, + src_f1_enc=src_f1_enc, + src_f2_enc=src_f2_enc, + tgt_f1=tgt_f1, + 
tgt_f1_enc=tgt_f1_enc, + field=torch.tensor(seed_field), + equi_diff_map=equi_diff_map, + post_diff_masked=post_diff_masked, + pre_diff_masked=pre_diff_masked, + ) + return loss + + def compute_metroem_loss_old( + self, batch: dict, mode: str, log_row: bool, sample_name: str = "" + ): + src = batch["images"]["src"] + tgt = batch["images"]["tgt"] + + field = batch["field"] + + tgt_zeros = tensor_ops.mask.kornia_dilation(tgt == self.zero_value, width=3) + src_zeros = tensor_ops.mask.kornia_dilation(src == self.zero_value, width=3) + + pre_tissue_mask = (src_zeros + tgt_zeros) == 0 + if pre_tissue_mask.sum() / src.numel() < 0.4: + return None + + zero_magns = 0 + tgt_enc = self.model(tgt) + zero_magns += tgt_enc[tgt_zeros].abs().sum() + + src_warped = field.field().from_pixels()(src) + src_warped_enc = self.model(src_warped) + src_zeros_warped = field.field().from_pixels()(src_zeros.float()) > 0.1 + + zero_magns += src_warped_enc[src_zeros_warped].abs().sum() + + # src_enc = (~(field.field().from_pixels()))(src_warped_enc) + src_enc = self.model(src) + + pre_diff = (src_enc - tgt_enc).abs() + pre_loss = pre_diff[..., pre_tissue_mask].sum() + pre_diff_masked = pre_diff.clone() + pre_diff_masked[..., pre_tissue_mask == 0] = 0 + + post_tissue_mask = ( + tensor_ops.mask.kornia_dilation(src_zeros_warped + tgt_zeros, width=5) == 0 + ) + post_magn_mask = field.abs().sum(1) > self.field_magn_thr + + post_magn_mask[..., 0:10, :] = 0 + post_magn_mask[..., -10:, :] = 0 + post_magn_mask[..., :, 0:10] = 0 + post_magn_mask[..., :, -10:] = 0 + post_diff_map = (src_warped_enc - tgt_enc).abs() + post_mask = post_magn_mask * post_tissue_mask + post_diff_masked = post_diff_map.clone() + post_diff_masked[..., post_tissue_mask == 0] = 0 + if post_mask.sum() < 256: + return None + + post_loss = post_diff_map[..., post_mask].sum() + loss = pre_loss - post_loss * self.post_weight + zero_magns * self.zero_conserve_weight + self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) + 
self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) + self.log(f"loss/{mode}_zcons", zero_magns, on_step=True, on_epoch=True) + if log_row: + log_results( + mode, + sample_name, + src=src, + src_enc=src_enc, + src_warped_enc=src_warped_enc, + tgt=tgt, + tgt_enc=tgt_enc, + field=field, + post_diff_masked=post_diff_masked, + pre_diff_masked=pre_diff_masked, + ) + return loss + + def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ + log_row = batch_idx % self.val_log_row_interval == 0 + sample_name = f"{batch_idx // self.val_log_row_interval}" + + loss = self.compute_metroem_loss( + batch=batch, mode="val", log_row=log_row, sample_name=sample_name + ) + return loss diff --git a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py similarity index 99% rename from zetta_utils/training/lightning/regimes/alignment/encoding_coarsener.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py index 6d350aa2d..d2823c8a6 100644 --- a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py @@ -14,7 +14,7 @@ from zetta_utils.training.lightning.train import distributed_available -@builder.register("EncodingCoarsenerRegime") +@builder.register("EncodingCoarsenerRegime", versions="==0.0.0") @attrs.mutable(eq=False) class EncodingCoarsenerRegime(pl.LightningModule): # pylint: disable=too-many-ancestors encoder: torch.nn.Module diff --git a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_gen_x1.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py similarity index 98% rename from zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_gen_x1.py rename to 
zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py index 34bcaefd8..fdbd5f1fd 100644 --- a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_gen_x1.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py @@ -13,7 +13,7 @@ from zetta_utils import builder, distributions, tensor_ops, viz -@builder.register("EncodingCoarsenerGenX1Regime") +@builder.register("EncodingCoarsenerGenX1Regime", versions="==0.0.0") @attrs.mutable(eq=False) class EncodingCoarsenerGenX1Regime(pl.LightningModule): # pylint: disable=too-many-ancestors encoder: torch.nn.Module diff --git a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_highres.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py similarity index 99% rename from zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_highres.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py index 6335ef14c..08a8f495f 100644 --- a/zetta_utils/training/lightning/regimes/alignment/encoding_coarsener_highres.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py @@ -60,7 +60,7 @@ def center_crop_norm(image): return crop(norm(image)) -@builder.register("EncodingCoarsenerHighRes") +@builder.register("EncodingCoarsenerHighRes", versions="==0.0.0") @attrs.mutable(eq=False) class EncodingCoarsenerHighRes(pl.LightningModule): # pylint: disable=too-many-ancestors encoder: torch.nn.Module diff --git a/zetta_utils/training/lightning/regimes/alignment/minima_encoder.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py similarity index 99% rename from zetta_utils/training/lightning/regimes/alignment/minima_encoder.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py index a32799476..39d9e2adc 100644 --- 
a/zetta_utils/training/lightning/regimes/alignment/minima_encoder.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py @@ -12,7 +12,7 @@ from zetta_utils import builder, distributions, tensor_ops, viz -@builder.register("MinimaEncoderRegime") +@builder.register("MinimaEncoderRegime", versions="==0.0.0") @attrs.mutable(eq=False) class MinimaEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors model: torch.nn.Module diff --git a/zetta_utils/training/lightning/regimes/alignment/misalignment_detector.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py similarity index 99% rename from zetta_utils/training/lightning/regimes/alignment/misalignment_detector.py rename to zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py index 60cba4806..920305a8f 100644 --- a/zetta_utils/training/lightning/regimes/alignment/misalignment_detector.py +++ b/zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py @@ -13,7 +13,7 @@ from zetta_utils import builder, convnet, tensor_ops # pylint: disable=unused-import -@builder.register("MisalignmentDetectorRegime") +@builder.register("MisalignmentDetectorRegime", versions="==0.0.0") @attrs.mutable(eq=False) class MisalignmentDetectorRegime(pl.LightningModule): # pylint: disable=too-many-ancestors detector: torch.nn.Module From 5e1af0242fe25bcbf23f8024c16cf9b61bd27a0a Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Mon, 6 Nov 2023 23:51:45 +0100 Subject: [PATCH 5/9] feat(inference): mixed precision support for base encoder/coarsener --- zetta_utils/alignment/base_coarsener.py | 3 ++- zetta_utils/alignment/base_encoder.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/zetta_utils/alignment/base_coarsener.py b/zetta_utils/alignment/base_coarsener.py index 25654069a..1d114e0f6 100644 --- a/zetta_utils/alignment/base_coarsener.py +++ b/zetta_utils/alignment/base_coarsener.py 
@@ -56,7 +56,8 @@ def __call__(self, src: torch.Tensor) -> torch.Tensor: y_end = y + self.tile_size + self.tile_pad_in tile = data_in[:, :, x_start:x_end, y_start:y_end] if (tile != 0).sum() > 0.0: - tile_result = model(tile) + with torch.autocast(device_type=device): + tile_result = model(tile) if tile_pad_out > 0: tile_result = tile_result[ :, :, tile_pad_out:-tile_pad_out, tile_pad_out:-tile_pad_out diff --git a/zetta_utils/alignment/base_encoder.py b/zetta_utils/alignment/base_encoder.py index 4d6d76ad3..31231204d 100644 --- a/zetta_utils/alignment/base_encoder.py +++ b/zetta_utils/alignment/base_encoder.py @@ -36,7 +36,8 @@ def __call__(self, src: torch.Tensor) -> torch.Tensor: raise ValueError(f"Unsupported src dtype: {src.dtype}") data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y") - result = model(data_in.to(device)) + with torch.autocast(device_type=device): + result = model(data_in.to(device)) result = einops.rearrange(result, "Z C X Y -> C X Y Z") # Final layer assumed to be tanh From 5feb436a5a38f058d8574551d3c61a2f16cdcada Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Wed, 8 Nov 2023 10:06:03 +0100 Subject: [PATCH 6/9] feat(inference): base coarsener with output channel support --- zetta_utils/alignment/base_coarsener.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/zetta_utils/alignment/base_coarsener.py b/zetta_utils/alignment/base_coarsener.py index 1d114e0f6..4dd495619 100644 --- a/zetta_utils/alignment/base_coarsener.py +++ b/zetta_utils/alignment/base_coarsener.py @@ -17,6 +17,7 @@ class BaseCoarsener: model_path: str abs_val_thr: float = 0.005 ds_factor: int = 1 + output_channels: int = 1 tile_pad_in: int = 128 tile_size: int = 1024 @@ -36,14 +37,15 @@ def __call__(self, src: torch.Tensor) -> torch.Tensor: raise ValueError(f"Unsupported src dtype: {src.dtype}") data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y").to(device) - result = torch.zeros_like( - data_in[ - ..., - : data_in.shape[-2] 
// self.ds_factor, - : data_in.shape[-1] // self.ds_factor, - ] - ).float() - + result = torch.zeros( + data_in.shape[0], + self.output_channels, + data_in.shape[-2] // self.ds_factor, + data_in.shape[-1] // self.ds_factor, + dtype=torch.float32, + layout=data_in.layout, + device=data_in.device + ) tile_pad_out = self.tile_pad_in // self.ds_factor for x in range(self.tile_pad_in, data_in.shape[-2] - self.tile_pad_in, self.tile_size): From 31a864a6ffbb5dc92332110dde598ed5d1a180bf Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Wed, 8 Nov 2023 11:11:36 +0100 Subject: [PATCH 7/9] chore: minor version updates --- .dockerignore | 3 +++ .gitignore | 3 +++ docker/Dockerfile.all.p39 | 6 +++--- pyproject.toml | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.dockerignore b/.dockerignore index 314aaa37b..b9304474d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -137,3 +137,6 @@ venv.bak/ # Pyre type checker .pyre/ + +# editable installs +src/ diff --git a/.gitignore b/.gitignore index 96231eaba..da35471d9 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,6 @@ dmypy.json # Pyre type checker .pyre/ + +# editable installs +src/ diff --git a/docker/Dockerfile.all.p39 b/docker/Dockerfile.all.p39 index dffa0b187..ef07f66af 100644 --- a/docker/Dockerfile.all.p39 +++ b/docker/Dockerfile.all.p39 @@ -1,10 +1,10 @@ -FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime +FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime ENV DEBIAN_FRONTEND="noninteractive" RUN apt-get update \ && apt-get install -y git build-essential wget curl vim ffmpeg libsm6 libxext6 software-properties-common unixodbc-dev \ - && pip install posix-ipc \ + && pip install --no-cache-dir posix-ipc gevent \ && apt-get --purge autoremove -y build-essential \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* \ @@ -22,6 +22,6 @@ RUN apt-get update \ ENV PYTHONPATH /opt/zetta_utils WORKDIR /opt/zetta_utils ADD pyproject.toml /opt/zetta_utils/ -RUN pip install '.[modules]' +RUN pip install 
--no-cache-dir '.[modules]' COPY . /opt/zetta_utils/ RUN zetta --help diff --git a/pyproject.toml b/pyproject.toml index a497841cb..d87eece91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ keywords = ["neuroscience connectomics EM"] license = { text = "MIT" } urls = { Homepage = "https://github.com/zettaai/zetta_utils" } -requires-python = ">3.8,<3.11" +requires-python = ">3.8,<3.12" dependencies = [ "attrs >= 21.3", "typeguard == 4.1.5", From 46f534a306c0fefe4f26ef9296e23a976b4b503c Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Mon, 20 Nov 2023 17:20:44 +0100 Subject: [PATCH 8/9] specs+regimes(training): misd prep + training specs + regime --- .../preprocess/01_gen_warp_fields.cue | 260 ++++++ .../preprocess/02_encode_aligned.cue | 322 +++++++ .../preprocess/03_optimize_warp_fields.cue | 847 ++++++++++++++++++ .../aced_misd/train/z1z2_enc_misd.cue | 351 ++++++++ .../training/aced_misd_cns/z1z2_enc_misd.cue | 757 ---------------- .../alignment/misalignment_detector_aced.py | 17 +- 6 files changed, 1787 insertions(+), 767 deletions(-) create mode 100644 specs/nico/training/aced_misd/preprocess/01_gen_warp_fields.cue create mode 100644 specs/nico/training/aced_misd/preprocess/02_encode_aligned.cue create mode 100644 specs/nico/training/aced_misd/preprocess/03_optimize_warp_fields.cue create mode 100644 specs/nico/training/aced_misd/train/z1z2_enc_misd.cue delete mode 100644 specs/nico/training/aced_misd_cns/z1z2_enc_misd.cue diff --git a/specs/nico/training/aced_misd/preprocess/01_gen_warp_fields.cue b/specs/nico/training/aced_misd/preprocess/01_gen_warp_fields.cue new file mode 100644 index 000000000..52f114b8a --- /dev/null +++ b/specs/nico/training/aced_misd/preprocess/01_gen_warp_fields.cue @@ -0,0 +1,260 @@ +import "math" +import "list" + +#BASE_PATH: "gs://zetta-research-nico/encoder/" +#TGT_IMG_PATH: #BASE_PATH + "datasets/" +#WARPED_SRC_IMG_PATH: #BASE_PATH + "pairwise_aligned/" // + k + "warped_enc/" 
+#PERLIN_FIELD_PATH: #BASE_PATH + "misd/misalignment_fields/" + +#DATASETS: { + "microns_pinky": { + "contiguous": true + "bounds": [[0, 262144], [0, 131072], [0, 10240]] + "resolution": [32, 32, 40] + } + "microns_basil": { + "contiguous": true + "bounds": [[0, 819200], [0, 983040], [0, 400]] + "resolution": [32, 32, 40] + }, + "microns_minnie": { + "contiguous": false + "bounds": [[0, 1703936], [0, 1441792], [0, 320]] + "resolution": [32, 32, 40] + }, + "microns_interneuron": { + "contiguous": false + "bounds": [[0, 720896], [0, 720896], [0, 1280]] + "resolution": [32, 32, 40] + }, + "aibs_v1dd": { + "contiguous": false + "bounds": [[0.0, 1231667.2], [0.0, 834355.2], [0.0, 1080.0]] + "resolution": [38.8, 38.8, 45.0] + }, + "kim_n2da": { + "contiguous": true + "bounds": [[0, 32768], [0, 32768], [0, 31050]] + "resolution": [32, 32, 50] + }, + "kim_pfc2022": { + "contiguous": true + "bounds": [[0, 229376], [0, 196608], [0, 7320]] + "resolution": [32, 32, 40] + }, + "kronauer_cra9": { + "contiguous": true + "bounds": [[0, 393216], [0, 327680], [0, 588]] + "resolution": [32, 32, 42] + }, + "kubota_001": { + "contiguous": true + "bounds": [[0, 204800], [0, 204800], [0, 12000]] + "resolution": [40, 40, 40] + }, + "lee_fanc": { + "contiguous": false + "bounds": [[0.0, 352256.0], [0.0, 951091.2], [0.0, 2700.0]] + "resolution": [34.4, 34.4, 45.0] + }, + "lee_banc": { + "contiguous": false + "bounds": [[0, 819200], [0, 1015808], [0, 900]] + "resolution": [32, 32, 45] + }, + "lee_ppc": { + "contiguous": true + "bounds": [[0, 98304], [0, 98304], [0, 36400]] + "resolution": [32, 32, 40] + }, + "lee_mosquito": { + "contiguous": false + "bounds": [[0, 704512], [0, 450560], [0, 2240]] + "resolution": [32, 32, 40] + }, + "lichtman_zebrafish": { + "contiguous": false + "bounds": [[0, 294912], [0, 393216], [0, 4560]] + "resolution": [32, 32, 30] + }, + "prieto_godino_larva": { + "contiguous": true + "bounds": [[0, 134976], [0, 144992], [0, 14400]] + "resolution": [32, 32, 32] + }, + 
"fafb_v15": { + "contiguous": false + "bounds": [[0, 884736], [0, 393216], [0, 2000]] + "resolution": [32, 32, 40] + }, + "lichtman_h01": { + "contiguous": false + "bounds": [[0, 3440640], [0, 1933312], [0, 198]] + "resolution": [32, 32, 33] + }, + "janelia_hemibrain": { + "contiguous": true + "bounds": [[0, 317824], [0, 331168], [0, 3296]] + "resolution": [32, 32, 32] + }, + "janelia_manc": { + "contiguous": false + "bounds": [[0, 262144], [0, 360448], [0, 5952]] + "resolution": [32, 32, 32] + }, + "nguyen_thomas_2022": { + "contiguous": true + "bounds": [[0, 998400], [0, 921600], [0, 400]] + "resolution": [32, 32, 40] + }, + "mulcahy_2022_16h": { + "contiguous": true + "bounds": [[0, 243712], [0, 73728], [0, 14700]] + "resolution": [32, 32, 30] + }, + "wildenberg_2021_vta_dat12a": { + "contiguous": true + "bounds": [[0, 82080], [0, 85184], [0, 7640]] + "resolution": [32, 32, 40] + }, + "bumbarber_2013": { + "contiguous": true + "bounds": [[0.0, 63897.6], [0.0, 63897.6], [0.0, 102400.0]] + "resolution": [31.2, 31.2, 50.0] + }, + "wilson_2019_p3": { + "contiguous": true + "bounds": [[0, 163840], [0, 229376], [0, 7020]] + "resolution": [32, 32, 30] + }, + "ishibashi_2021_em1": { + "contiguous": true + "bounds": [[0, 24576], [0, 16384], [0, 4544]] + "resolution": [32, 32, 32] + }, + "ishibashi_2021_em2": { + "contiguous": true + "bounds": [[0, 26624], [0, 18432], [0, 5376]] + "resolution": [32, 32, 32] + }, + "templier_2019_wafer1": { + "contiguous": true + "bounds": [[0, 294912], [0, 229376], [0, 6500]] + "resolution": [32, 32, 50] + }, + "templier_2019_wafer3": { + "contiguous": true + "bounds": [[0, 229376], [0, 196608], [0, 9750]] + "resolution": [32, 32, 50] + }, + "lichtman_octopus2022": { + "contiguous": true + "bounds": [[0, 229376], [0, 360448], [0, 3180]] + "resolution": [32, 32, 30] + } +} + + +#DST_INFO_CHUNK_SIZE: [2048, 2048, 1] +#PERLIN_FIELD_DS_FACTOR: math.Pow(2, 3) +#FIELD_INFO_OVERRIDE: { + _dataset_bounds: _ + _dst_resolution: _ + type: "image" + 
data_type: "float32", + num_channels: 2, + scales: [ + { + let vx_res = _dst_resolution + let ds_offset = [ for j in [0, 1, 2] { + _dataset_bounds[j][0] / _dst_resolution[j] // technically should be floor + }] + let ds_size = [ for j in [0, 1, 2] { + math.Ceil((_dataset_bounds[j][1] - _dataset_bounds[j][0]) / _dst_resolution[j]) + }] + + chunk_sizes: [[ for j in [0, 1, 2] {list.Min([#DST_INFO_CHUNK_SIZE[j], ds_size[j]])}]] + resolution: vx_res + encoding: "zfpc" + zfpc_correlated_dims: [true, true, false, false] + zfpc_tolerance: 0.001953125 + key: "\(vx_res[0])_\(vx_res[1])_\(vx_res[2])" + voxel_offset: ds_offset + size: ds_size + } + ], + +} + + + +#MAX_DISP: 20 +#MEDIAN_DISP: 7.5 +#PERLIN_NOISE_TEMPLATE: { + _bounds: _ + let vx_res = dst_resolution + let x_mult = math.Ceil(((_bounds[0][1] - _bounds[0][0]) / vx_res[0]) / 2048) + let y_mult = math.Ceil(((_bounds[1][1] - _bounds[1][0]) / vx_res[1]) / 2048) + "@type": "build_subchunkable_apply_flow" + op: { + "@type": "VolumetricCallableOperation" + fn: { + "@type": "gen_biased_perlin_noise_field" + "@mode": "partial" + shape: [2, x_mult * 2048, y_mult * 2048, 1] + res: [ x_mult * 2, y_mult * 2 ] + max_displacement_px: #MAX_DISP / #PERLIN_FIELD_DS_FACTOR + field_magn_thr_px: #MEDIAN_DISP / #PERLIN_FIELD_DS_FACTOR + octaves: 8 + device: "cpu" + } + crop_pad: [0, 0, 0] + } + dst_resolution: _ + skip_intermediaries: true + processing_chunk_sizes: [[x_mult * 2048, y_mult * 2048, 1]] + processing_crop_pads: [[0, 0, 0]] + expand_bbox_resolution: true + bbox: { + "@type": "BBox3D.from_coords", + start_coord: [_bounds[0][0], _bounds[1][0], _bounds[2][0]] + end_coord: [_bounds[0][1], _bounds[1][1], _bounds[2][1]] + } + dst: { + "@type": "build_cv_layer" + path: _ + info_field_overrides: #FIELD_INFO_OVERRIDE & { + _dataset_bounds: _bounds + _dst_resolution: dst_resolution + } + } +} + + +"@type": "mazepa.execute_on_gcp_with_sqs" +worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231118" +worker_resources: { + 
memory: "10560Mi" +} +worker_replicas: 100 +batch_gap_sleep_sec: 0.1 +do_dryrun_estimation: true +local_test: false +worker_cluster_project: "zetta-research" +worker_cluster_region: "us-east1" +worker_cluster_name: "zutils-x3" +target: { + "@type": "mazepa.concurrent_flow" + stages: [ + for key, dataset in #DATASETS { + #PERLIN_NOISE_TEMPLATE & { + _bounds: dataset.bounds, + dst: path: #PERLIN_FIELD_PATH + key + "/raw_perlin" + + let ds_factor = [#PERLIN_FIELD_DS_FACTOR, #PERLIN_FIELD_DS_FACTOR, 1] + let res = [ for j in [0, 1, 2] {dataset.resolution[j] * ds_factor[j]} ] + dst_resolution: res + } + } + ] +} \ No newline at end of file diff --git a/specs/nico/training/aced_misd/preprocess/02_encode_aligned.cue b/specs/nico/training/aced_misd/preprocess/02_encode_aligned.cue new file mode 100644 index 000000000..9b7fe08f8 --- /dev/null +++ b/specs/nico/training/aced_misd/preprocess/02_encode_aligned.cue @@ -0,0 +1,322 @@ +import "math" +import "list" + +#BASE_PATH: "gs://zetta-research-nico/encoder/" +#TGT_IMG_PATH: #BASE_PATH + "datasets/" // + k +#WARPED_SRC_IMG_PATH: #BASE_PATH + "pairwise_aligned/" // + k + "/warped_img" +#DST_TGT_ENC_PATH: #BASE_PATH + "pairwise_aligned/" // + k + "/tgt_enc_2023" +#DST_WARPED_SRC_ENC_PATH: #BASE_PATH + "pairwise_aligned/" // + k + "/warped_enc_2023" + +#DATASETS: { + "microns_pinky": { + "contiguous": true + "bounds": [[0, 262144], [0, 131072], [0, 10240]] + "resolution": [32, 32, 40] + } + "microns_basil": { + "contiguous": true + "bounds": [[0, 819200], [0, 983040], [0, 400]] + "resolution": [32, 32, 40] + }, + "microns_minnie": { + "contiguous": false + "bounds": [[0, 1703936], [0, 1441792], [0, 320]] + "resolution": [32, 32, 40] + }, + "microns_interneuron": { + "contiguous": false + "bounds": [[0, 720896], [0, 720896], [0, 1280]] + "resolution": [32, 32, 40] + }, + "aibs_v1dd": { + "contiguous": false + "bounds": [[0.0, 1231667.2], [0.0, 834355.2], [0.0, 1080.0]] + "resolution": [38.8, 38.8, 45.0] + }, + "kim_n2da": { + 
"contiguous": true + "bounds": [[0, 32768], [0, 32768], [0, 31050]] + "resolution": [32, 32, 50] + }, + "kim_pfc2022": { + "contiguous": true + "bounds": [[0, 229376], [0, 196608], [0, 7320]] + "resolution": [32, 32, 40] + }, + "kronauer_cra9": { + "contiguous": true + "bounds": [[0, 393216], [0, 327680], [0, 588]] + "resolution": [32, 32, 42] + }, + "kubota_001": { + "contiguous": true + "bounds": [[0, 204800], [0, 204800], [0, 12000]] + "resolution": [40, 40, 40] + }, + "lee_fanc": { + "contiguous": false + "bounds": [[0.0, 352256.0], [0.0, 951091.2], [0.0, 2700.0]] + "resolution": [34.4, 34.4, 45.0] + }, + "lee_banc": { + "contiguous": false + "bounds": [[0, 819200], [0, 1015808], [0, 900]] + "resolution": [32, 32, 45] + }, + "lee_ppc": { + "contiguous": true + "bounds": [[0, 98304], [0, 98304], [0, 36400]] + "resolution": [32, 32, 40] + }, + "lee_mosquito": { + "contiguous": false + "bounds": [[0, 704512], [0, 450560], [0, 2240]] + "resolution": [32, 32, 40] + }, + "lichtman_zebrafish": { + "contiguous": false + "bounds": [[0, 294912], [0, 393216], [0, 4560]] + "resolution": [32, 32, 30] + }, + "prieto_godino_larva": { + "contiguous": true + "bounds": [[0, 134976], [0, 144992], [0, 14400]] + "resolution": [32, 32, 32] + }, + "fafb_v15": { + "contiguous": false + "bounds": [[0, 884736], [0, 393216], [0, 2000]] + "resolution": [32, 32, 40] + }, + "lichtman_h01": { + "contiguous": false + "bounds": [[0, 3440640], [0, 1933312], [0, 198]] + "resolution": [32, 32, 33] + }, + "janelia_hemibrain": { + "contiguous": true + "bounds": [[0, 317824], [0, 331168], [0, 3296]] + "resolution": [32, 32, 32] + }, + "janelia_manc": { + "contiguous": false + "bounds": [[0, 262144], [0, 360448], [0, 5952]] + "resolution": [32, 32, 32] + }, + "nguyen_thomas_2022": { + "contiguous": true + "bounds": [[0, 998400], [0, 921600], [0, 400]] + "resolution": [32, 32, 40] + }, + "mulcahy_2022_16h": { + "contiguous": true + "bounds": [[0, 243712], [0, 73728], [0, 14700]] + "resolution": [32, 
32, 30] + }, + "wildenberg_2021_vta_dat12a": { + "contiguous": true + "bounds": [[0, 82080], [0, 85184], [0, 7640]] + "resolution": [32, 32, 40] + }, + "bumbarber_2013": { + "contiguous": true + "bounds": [[0.0, 63897.6], [0.0, 63897.6], [0.0, 102400.0]] + "resolution": [31.2, 31.2, 50.0] + }, + "wilson_2019_p3": { + "contiguous": true + "bounds": [[0, 163840], [0, 229376], [0, 7020]] + "resolution": [32, 32, 30] + }, + "ishibashi_2021_em1": { + "contiguous": true + "bounds": [[0, 24576], [0, 16384], [0, 4544]] + "resolution": [32, 32, 32] + }, + "ishibashi_2021_em2": { + "contiguous": true + "bounds": [[0, 26624], [0, 18432], [0, 5376]] + "resolution": [32, 32, 32] + }, + "templier_2019_wafer1": { + "contiguous": true + "bounds": [[0, 294912], [0, 229376], [0, 6500]] + "resolution": [32, 32, 50] + }, + "templier_2019_wafer3": { + "contiguous": true + "bounds": [[0, 229376], [0, 196608], [0, 9750]] + "resolution": [32, 32, 50] + }, + "lichtman_octopus2022": { + "contiguous": true + "bounds": [[0, 229376], [0, 360448], [0, 3180]] + "resolution": [32, 32, 30] + } +} + + +#DST_INFO_CHUNK_SIZE: [1024, 1024, 1] +#MAX_TASK_SIZE: [8192, 8192, 1] + +#ENC_INFO_OVERRIDE: { + _dataset_bounds: _ + _highest_resolution: _ + type: "image" + data_type: "int8" + num_channels: 1 + scales: [ + for i in list.Range(0, 4, 1) { + let res_factor = [math.Pow(2, i), math.Pow(2, i), 1] + let vx_res = [ for j in [0, 1, 2] {_highest_resolution[j] * res_factor[j]}] + let ds_offset = [ for j in [0, 1, 2] { + _dataset_bounds[j][0] / vx_res[j] // technically should be floor, but it's 0 anyway + }] + let ds_size = [ for j in [0, 1, 2] { + math.Ceil((_dataset_bounds[j][1] - _dataset_bounds[j][0]) / vx_res[j]) + }] + + chunk_sizes: [[ for j in [0, 1, 2] {list.Min([#DST_INFO_CHUNK_SIZE[j], ds_size[j]])}]] + resolution: vx_res + encoding: "raw" + key: "\(vx_res[0])_\(vx_res[1])_\(vx_res[2])" + voxel_offset: ds_offset + size: ds_size + }, + ] +} + +#MODELS: [ + { + path:
"gs://alignment_models/general_encoders_2023/32_32_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [1, 1, 1] + }, + { + path: "gs://alignment_models/general_encoders_2023/32_64_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [2, 2, 1] + }, + { + path: "gs://alignment_models/general_encoders_2023/32_128_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [4, 4, 1] + }, + { + path: "gs://alignment_models/general_encoders_2023/32_256_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [8, 8, 1] + } +] + +#ENCODE_TEMPLATE: { + _bounds: _ + _high_resolution: [number, number, number] + _layer_name: _ + let max_chunk_size = [ + list.Min([#MAX_TASK_SIZE[0], math.Ceil((_bounds[0][1] - _bounds[0][0]) / #DST_INFO_CHUNK_SIZE[0] / dst_resolution[0]) * #DST_INFO_CHUNK_SIZE[0]]), + list.Min([#MAX_TASK_SIZE[1], math.Ceil((_bounds[1][1] - _bounds[1][0]) / #DST_INFO_CHUNK_SIZE[1] / dst_resolution[1]) * #DST_INFO_CHUNK_SIZE[1]]), + 1 + ] + + "@type": "build_subchunkable_apply_flow" + op: { + "@type": "VolumetricCallableOperation" + operation_name: _layer_name + fn: { + "@type": "BaseEncoder" + model_path: string + } | { + "@type": "BaseCoarsener" + model_path: string + tile_pad_in: int + tile_size: int + ds_factor: int + output_channels: 1 + } + crop_pad: [16, 16, 0] + res_change_mult: [int, int, int] + } + dst_resolution: [number, number, number] + processing_chunk_sizes: [max_chunk_size, [1024, 1024, 1]] + processing_crop_pads: [[0, 0, 0], [16,16,0]] + expand_bbox_resolution: true + skip_intermediaries: true + bbox: { + "@type": "BBox3D.from_coords", + start_coord: [_bounds[0][0], _bounds[1][0], _bounds[2][0]] + end_coord: [_bounds[0][1], _bounds[1][1], _bounds[2][1]] + } + op_kwargs: { + src: { + "@type": "build_cv_layer" + path: _ + } + } + dst: { + "@type": "build_cv_layer" + path: _ + info_field_overrides: #ENC_INFO_OVERRIDE & { + _dataset_bounds: _bounds + _highest_resolution: _high_resolution + } + on_info_exists: "overwrite" + } +} + + + +"@type": 
"mazepa.execute_on_gcp_with_sqs" +worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231118" +worker_resources: { + "nvidia.com/gpu": 1 +} +worker_replicas: 200 +batch_gap_sleep_sec: 0.1 +do_dryrun_estimation: true +local_test: false +worker_cluster_project: "zetta-research" +worker_cluster_region: "us-east1" +worker_cluster_name: "zutils-x3" +target: { + "@type": "mazepa.concurrent_flow" + stages: [ + for img_source in ["tgt", "warped_src"] { + "@type": "mazepa.concurrent_flow" + stages: [ + for key, dataset in #DATASETS { + "@type": "mazepa.concurrent_flow" + stages: [ + for i in list.Range(0, 4, 1) { + #ENCODE_TEMPLATE & { + _bounds: dataset.bounds, + _high_resolution: dataset.resolution + _layer_name: "Model \(i)" + let _ds_factor = #MODELS[i].res_change_mult + let res = [ for j in [0, 1, 2] {dataset.resolution[j] * _ds_factor[j]} ] + if i == 0 { + op: fn: "@type": "BaseEncoder" + } + if i > 0 { + op: fn: { + "@type": "BaseCoarsener" + tile_pad_in: #ENCODE_TEMPLATE.op.crop_pad[0] * _ds_factor[0] + tile_size: 1024 + ds_factor: _ds_factor[0] + } + } + op: fn: model_path: #MODELS[i].path + op: res_change_mult: _ds_factor + dst_resolution: res + if img_source == "tgt" { + op_kwargs: src: path: #TGT_IMG_PATH + key + dst: path: #DST_TGT_ENC_PATH + key + "/tgt_enc_2023" + } + if img_source == "warped_src" { + op_kwargs: src: path: #WARPED_SRC_IMG_PATH + key + "/warped_img" + dst: path: #DST_WARPED_SRC_ENC_PATH + key + "/warped_enc_2023" + } + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/specs/nico/training/aced_misd/preprocess/03_optimize_warp_fields.cue b/specs/nico/training/aced_misd/preprocess/03_optimize_warp_fields.cue new file mode 100644 index 000000000..b2f2bd060 --- /dev/null +++ b/specs/nico/training/aced_misd/preprocess/03_optimize_warp_fields.cue @@ -0,0 +1,847 @@ +import "math" +import "list" + +#BASE_PATH: "gs://zetta-research-nico/encoder/" +// #TGT_IMG_PATH: #BASE_PATH + "datasets/" // + k 
+#ORIGINAL_WARPED_SRC_IMG_PATH: #BASE_PATH + "pairwise_aligned/" // + k + "/warped_img" +#TGT_ENC_PATH: #BASE_PATH + "pairwise_aligned/" // + k + "/tgt_enc_2023" +#WARPED_SRC_ENC_PATH: #BASE_PATH + "pairwise_aligned/" // + k + "/warped_enc_2023" +#PERLIN_FIELD_PATH: #BASE_PATH + "misd/misalignment_fields/" // + k + "/raw_perlin" +#DST_FIELD_PATH: #BASE_PATH + "misd/misalignment_fields/" // + k + "/optimized_perlin" | "/no_perlin" + "/z\(_z_offset)" + +#DST_WARPED_SRC_IMG_PATH: #BASE_PATH + "misd/img/" // + k + "/good_alignment" | "/bad_alignment" + "/z\(_z_offset)" +#DST_WARPED_SRC_ENC_PATH: #BASE_PATH + "misd/enc/" // + k + "/good_alignment" | "/bad_alignment" + "/z\(_z_offset)" + + +#DATASETS: { + "microns_pinky": { + "contiguous": true + "bounds": [[0, 262144], [0, 131072], [0, 10240]] + "resolution": [32, 32, 40] + } + "microns_basil": { + "contiguous": true + "bounds": [[0, 819200], [0, 983040], [0, 400]] + "resolution": [32, 32, 40] + }, + // // "microns_minnie": { + // // "contiguous": false + // // "bounds": [[0, 1703936], [0, 1441792], [0, 320]] + // // "resolution": [32, 32, 40] + // // }, + // // "microns_interneuron": { + // // "contiguous": false + // // "bounds": [[0, 720896], [0, 720896], [0, 1280]] + // // "resolution": [32, 32, 40] + // // }, + // // "aibs_v1dd": { + // // "contiguous": false + // // "bounds": [[0.0, 1231667.2], [0.0, 834355.2], [0.0, 1080.0]] + // // "resolution": [38.8, 38.8, 45.0] + // // }, + "kim_n2da": { + "contiguous": true + "bounds": [[0, 32768], [0, 32768], [0, 31050]] + "resolution": [32, 32, 50] + }, + "kim_pfc2022": { + "contiguous": true + "bounds": [[0, 229376], [0, 196608], [0, 7320]] + "resolution": [32, 32, 40] + }, + "kronauer_cra9": { + "contiguous": true + "bounds": [[0, 393216], [0, 327680], [0, 588]] + "resolution": [32, 32, 42] + }, + "kubota_001": { + "contiguous": true + "bounds": [[0, 204800], [0, 204800], [0, 12000]] + "resolution": [40, 40, 40] + }, + // // "lee_fanc": { + // // "contiguous": false + // 
// "bounds": [[0.0, 352256.0], [0.0, 951091.2], [0.0, 2700.0]] + // // "resolution": [34.4, 34.4, 45.0] + // // }, + // // "lee_banc": { + // // "contiguous": false + // // "bounds": [[0, 819200], [0, 1015808], [0, 900]] + // // "resolution": [32, 32, 45] + // // }, + "lee_ppc": { + "contiguous": true + "bounds": [[0, 98304], [0, 98304], [0, 36400]] + "resolution": [32, 32, 40] + }, + // // "lee_mosquito": { + // // "contiguous": false + // // "bounds": [[0, 704512], [0, 450560], [0, 2240]] + // // "resolution": [32, 32, 40] + // // }, + // // "lichtman_zebrafish": { + // // "contiguous": false + // // "bounds": [[0, 294912], [0, 393216], [0, 4560]] + // // "resolution": [32, 32, 30] + // // }, + "prieto_godino_larva": { + "contiguous": true + "bounds": [[0, 134976], [0, 144992], [0, 14400]] + "resolution": [32, 32, 32] + }, + // // "fafb_v15": { + // // "contiguous": false + // // "bounds": [[0, 884736], [0, 393216], [0, 2000]] + // // "resolution": [32, 32, 40] + // // }, + // // "lichtman_h01": { + // // "contiguous": false + // // "bounds": [[0, 3440640], [0, 1933312], [0, 198]] + // // "resolution": [32, 32, 33] + // // }, + "janelia_hemibrain": { + "contiguous": true + "bounds": [[0, 317824], [0, 331168], [0, 3296]] + "resolution": [32, 32, 32] + }, + // // "janelia_manc": { + // // "contiguous": false + // // "bounds": [[0, 262144], [0, 360448], [0, 5952]] + // // "resolution": [32, 32, 32] + // // }, + "nguyen_thomas_2022": { + "contiguous": true + "bounds": [[0, 998400], [0, 921600], [0, 400]] + "resolution": [32, 32, 40] + }, + "mulcahy_2022_16h": { + "contiguous": true + "bounds": [[0, 243712], [0, 73728], [0, 14700]] + "resolution": [32, 32, 30] + }, + "wildenberg_2021_vta_dat12a": { + "contiguous": true + "bounds": [[0, 82080], [0, 85184], [0, 7640]] + "resolution": [32, 32, 40] + }, + "bumbarber_2013": { + "contiguous": true + "bounds": [[0.0, 63897.6], [0.0, 63897.6], [0.0, 102400.0]] + "resolution": [31.2, 31.2, 50.0] + }, + "wilson_2019_p3": { + 
"contiguous": true + "bounds": [[0, 163840], [0, 229376], [0, 7020]] + "resolution": [32, 32, 30] + }, + "ishibashi_2021_em1": { + "contiguous": true + "bounds": [[0, 24576], [0, 16384], [0, 4544]] + "resolution": [32, 32, 32] + }, + "ishibashi_2021_em2": { + "contiguous": true + "bounds": [[0, 26624], [0, 18432], [0, 5376]] + "resolution": [32, 32, 32] + }, + "templier_2019_wafer1": { + "contiguous": true + "bounds": [[0, 294912], [0, 229376], [0, 6500]] + "resolution": [32, 32, 50] + }, + "templier_2019_wafer3": { + "contiguous": true + "bounds": [[0, 229376], [0, 196608], [0, 9750]] + "resolution": [32, 32, 50] + }, + "lichtman_octopus2022": { + "contiguous": true + "bounds": [[0, 229376], [0, 360448], [0, 3180]] + "resolution": [32, 32, 30] + } +} + +#MODELS: [ + { + path: "gs://alignment_models/general_encoders_2023/32_32_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [1, 1, 1] + }, + { + path: "gs://alignment_models/general_encoders_2023/32_64_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [2, 2, 1] + }, + { + path: "gs://alignment_models/general_encoders_2023/32_128_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [4, 4, 1] + }, + { + path: "gs://alignment_models/general_encoders_2023/32_256_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [8, 8, 1] + } +] + + +#DST_INFO_CHUNK_SIZE: [2048, 2048, 1] +#MAX_TASK_SIZE: [8192, 8192, 1] +#PERLIN_FIELD_DS_FACTOR: math.Pow(2, 3) + +#STAGE_TMPL: { + _stage_bounds: _ + let max_chunk_size = [ + list.Min([#MAX_TASK_SIZE[0], math.Ceil((_stage_bounds[0][1] - _stage_bounds[0][0]) / #DST_INFO_CHUNK_SIZE[0] / dst_resolution[0]) * #DST_INFO_CHUNK_SIZE[0]]), + list.Min([#MAX_TASK_SIZE[1], math.Ceil((_stage_bounds[1][1] - _stage_bounds[1][0]) / #DST_INFO_CHUNK_SIZE[1] / dst_resolution[1]) * #DST_INFO_CHUNK_SIZE[1]]), + 1 + ] + "@type": "ComputeFieldStage" + dst_resolution: _ + processing_chunk_sizes: [max_chunk_size, [2048, 2048, 1]] + processing_crop_pads: [[0, 0, 0], [64, 64, 0]] + 
expand_bbox_processing: true + expand_bbox_resolution: true + fn: { + "@type": "align_with_online_finetuner" + "@mode": "partial" + sm: int + num_iter: int + lr: float + } +} + + +#FIELD_INFO_OVERRIDE: { + _dataset_bounds: _ + _highest_resolution: _ + type: "image" + data_type: "float32", + num_channels: 2, + scales: [ + for i in list.Range(0, 3, 1) { + let res_factor = [math.Pow(2, i), math.Pow(2, i), 1] + let vx_res = [ for j in [0, 1, 2] {_highest_resolution[j] * res_factor[j]}] + let ds_offset = [ for j in [0, 1, 2] { + _dataset_bounds[j][0] / vx_res[j] // technically should be floor, but it's 0 anyway + }] + let ds_size = [ for j in [0, 1, 2] { + math.Ceil((_dataset_bounds[j][1] - _dataset_bounds[j][0]) / vx_res[j]) + }] + + chunk_sizes: [[ for j in [0, 1, 2] {list.Min([#DST_INFO_CHUNK_SIZE[j], ds_size[j]])}]] + resolution: vx_res + encoding: "zfpc" + zfpc_correlated_dims: [true, true, false, false] + zfpc_tolerance: 0.001953125 + key: "\(vx_res[0])_\(vx_res[1])_\(vx_res[2])" + voxel_offset: ds_offset + size: ds_size + } + ] +} + +#COMPUTE_FIELD_TEMPLATE: { + _bounds: _ + _dst_resolution: [number, number, number] + _layer_name: _ + _z_offset: int + _use_perlin_field: *false | true + + "@type": "build_compute_field_multistage_flow" + bbox: { + "@type": "BBox3D.from_coords", + start_coord: [_bounds[0][0], _bounds[1][0], _bounds[2][0]] + end_coord: [_bounds[0][1], _bounds[1][1], _bounds[2][1]] + } + stages: [ + #STAGE_TMPL & { + _stage_bounds: _bounds + dst_resolution: [_dst_resolution[0] * 4, _dst_resolution[1] * 4, _dst_resolution[2]] + fn: { + sm: 25 + num_iter: 500 + lr: 0.05 + } + }, + #STAGE_TMPL & { + _stage_bounds: _bounds + dst_resolution: [_dst_resolution[0] * 2, _dst_resolution[1] * 2, _dst_resolution[2]] + fn: { + sm: 25 + num_iter: 300 + lr: 0.1 + } + }, + #STAGE_TMPL & { + _stage_bounds: _bounds + dst_resolution: _dst_resolution + fn: { + sm: 25 + num_iter: 200 + lr: 0.1 + } + }, + ] + + if _z_offset == 2 { + src_offset: [0, 0, 1] // src is already 
offset by 1 + offset_resolution: _dst_resolution + } + src: { + "@type": "build_cv_layer" + path: #WARPED_SRC_ENC_PATH + _layer_name + "/warped_enc_2023" + } + tgt: { + "@type": "build_cv_layer" + path: #TGT_ENC_PATH + _layer_name + "/tgt_enc_2023" + } + dst: { + "@type": "build_cv_layer" + if _use_perlin_field == true { + path: #PERLIN_FIELD_PATH + _layer_name + "/optimized_perlin/z\(_z_offset)" + } + if _use_perlin_field == false { + path: #PERLIN_FIELD_PATH + _layer_name + "/no_perlin/z\(_z_offset)" + } + info_field_overrides: #FIELD_INFO_OVERRIDE & { + _dataset_bounds: _bounds + _highest_resolution: _dst_resolution + } + on_info_exists: "overwrite" + } + tmp_layer_dir: dst.path + "/tmp" + tmp_layer_factory: { + "@type": "build_cv_layer" + "@mode": "partial" + info_field_overrides: #FIELD_INFO_OVERRIDE & { + _dataset_bounds: _bounds + _highest_resolution: _dst_resolution + } + on_info_exists: "overwrite" + } + if _use_perlin_field { + src_field: { + let ds_factor = [#PERLIN_FIELD_DS_FACTOR, #PERLIN_FIELD_DS_FACTOR, 1] + "@type": "build_cv_layer" + path: #PERLIN_FIELD_PATH + _layer_name + "/raw_perlin" + data_resolution: [ for j in [0, 1, 2] {_dst_resolution[j] * ds_factor[j]} ] + interpolation_mode: "field" + } + } +} + + +#WARP_IMG_TEMPLATE: { + _bounds: _ + _layer_name: _ + _z_offset: int + _use_perlin_field: *false | true + + _src_field_path: _ + _dst_img_path: _ + let max_chunk_size = [ + list.Min([#MAX_TASK_SIZE[0], math.Ceil((_bounds[0][1] - _bounds[0][0]) / #DST_INFO_CHUNK_SIZE[0] / dst_resolution[0]) * #DST_INFO_CHUNK_SIZE[0]]), + list.Min([#MAX_TASK_SIZE[1], math.Ceil((_bounds[1][1] - _bounds[1][0]) / #DST_INFO_CHUNK_SIZE[1] / dst_resolution[1]) * #DST_INFO_CHUNK_SIZE[1]]), + 1 + ] + if _use_perlin_field == true { + _src_field_path: #DST_FIELD_PATH + _layer_name + "/optimized_perlin/z\(_z_offset)" + _dst_img_path: #DST_WARPED_SRC_IMG_PATH + _layer_name + "/bad_alignment/z\(_z_offset)" + } + if _use_perlin_field == false { + _src_field_path: 
#DST_FIELD_PATH + _layer_name + "/no_perlin/z\(_z_offset)" + _dst_img_path: #DST_WARPED_SRC_IMG_PATH + _layer_name + "/good_alignment/z\(_z_offset)" + } + + "@type": "build_subchunkable_apply_flow" + op: { + "@type": "WarpOperation" + mode: "img" + crop_pad: [256, 256, 0] + } + dst_resolution: _ + processing_chunk_sizes: [max_chunk_size, [2048, 2048, 1]] + processing_crop_pads: [[0, 0, 0], [256, 256, 0]] + skip_intermediaries: true + expand_bbox_processing: true + bbox: { + "@type": "BBox3D.from_coords", + start_coord: [_bounds[0][0], _bounds[1][0], _bounds[2][0]] + end_coord: [_bounds[0][1], _bounds[1][1], _bounds[2][1]] + } + op_kwargs: { + src: { + "@type": "build_cv_layer" + path: #ORIGINAL_WARPED_SRC_IMG_PATH + _layer_name + "/warped_img" + index_procs: [{ + "@type": "VolumetricIndexTranslator" + offset: [0, 0, _z_offset - 1] // src is already offset by 1 + resolution: dst_resolution + }] + } + field: { + "@type": "build_cv_layer" + path: _src_field_path + } + } + dst: { + "@type": "build_cv_layer" + path: _dst_img_path + info_reference_path: op_kwargs.src.path + } +} + +#DOWNSAMPLE_FIELD_TEMPLATE: { + _bounds: _ + _layer_name: _ + _z_offset: int + _use_perlin_field: *false | true + + let max_chunk_size = [ + list.Min([#MAX_TASK_SIZE[0], math.Ceil((_bounds[0][1] - _bounds[0][0]) / #DST_INFO_CHUNK_SIZE[0] / dst_resolution[0]) * #DST_INFO_CHUNK_SIZE[0]]), + list.Min([#MAX_TASK_SIZE[1], math.Ceil((_bounds[1][1] - _bounds[1][0]) / #DST_INFO_CHUNK_SIZE[1] / dst_resolution[1]) * #DST_INFO_CHUNK_SIZE[1]]), + 1 + ] + + "@type": "build_interpolate_flow" + mode: "field" + src_resolution: [number, number, number] + dst_resolution: [src_resolution[0] * 2, src_resolution[1] * 2, src_resolution[2]] + chunk_size: max_chunk_size + bbox: { + "@type": "BBox3D.from_coords", + start_coord: [_bounds[0][0], _bounds[1][0], _bounds[2][0]] + end_coord: [_bounds[0][1], _bounds[1][1], _bounds[2][1]] + } + + _path: _ + if _use_perlin_field == true { + _path: #DST_FIELD_PATH + _layer_name 
+ "/optimized_perlin/z\(_z_offset)" + } + if _use_perlin_field == false { + _path: #DST_FIELD_PATH + _layer_name + "/no_perlin/z\(_z_offset)" + } + src: { + "@type": "build_cv_layer" + path: _path + } + dst: { + "@type": "build_cv_layer" + path: _path + } + +} + +#FIELD_DIFF_TEMPLATE: { + _bounds: _ + _layer_name: _ + _z_offset: int + + let max_chunk_size = [ + list.Min([#MAX_TASK_SIZE[0], math.Ceil((_bounds[0][1] - _bounds[0][0]) / #DST_INFO_CHUNK_SIZE[0] / dst_resolution[0]) * #DST_INFO_CHUNK_SIZE[0]]), + list.Min([#MAX_TASK_SIZE[1], math.Ceil((_bounds[1][1] - _bounds[1][0]) / #DST_INFO_CHUNK_SIZE[1] / dst_resolution[1]) * #DST_INFO_CHUNK_SIZE[1]]), + 1 + ] + + "@type": "build_subchunkable_apply_flow" + fn: { + "@type": "torch.sub", "@mode": "partial" + } + processing_chunk_sizes: [max_chunk_size] + processing_crop_pads: [[0, 0, 0]] + dst_resolution: _ + expand_bbox_resolution: true + skip_intermediaries: true + bbox: { + "@type": "BBox3D.from_coords", + start_coord: [_bounds[0][0], _bounds[1][0], _bounds[2][0]] + end_coord: [_bounds[0][1], _bounds[1][1], _bounds[2][1]] + } + op_kwargs: { + input: { + "@type": "build_cv_layer" + path: #DST_FIELD_PATH + _layer_name + "/optimized_perlin/z\(_z_offset)" + } + other: { + "@type": "build_cv_layer" + path: #DST_FIELD_PATH + _layer_name + "/no_perlin/z\(_z_offset)" + } + } + dst: { + "@type": "build_cv_layer" + path: #DST_FIELD_PATH + _layer_name + "/displacements/z\(_z_offset)" + info_reference_path: #TGT_ENC_PATH + _layer_name + "/tgt_enc_2023" + info_field_overrides: { + data_type: "uint8" + } + on_info_exists: "overwrite" + write_procs: [ + { + "@type": "lambda" + lambda_str: "lambda data: (data.norm(dim=0, keepdim=True)*10.0).round().clamp(0, 255).byte()" + } + ] + } +} + + +#ENCODE_IMG_TEMPLATE: { + _bounds: _ + _high_resolution: [number, number, number] + _layer_name: _ + _z_offset: int + _use_perlin_field: *false | true + _model: { + path: _ + res_change_mult: [int, int, int] + } + + let max_chunk_size = [ + 
list.Min([#MAX_TASK_SIZE[0], math.Ceil((_bounds[0][1] - _bounds[0][0]) / #DST_INFO_CHUNK_SIZE[0] / dst_resolution[0]) * #DST_INFO_CHUNK_SIZE[0]]), + list.Min([#MAX_TASK_SIZE[1], math.Ceil((_bounds[1][1] - _bounds[1][0]) / #DST_INFO_CHUNK_SIZE[1] / dst_resolution[1]) * #DST_INFO_CHUNK_SIZE[1]]), + 1 + ] + + _src_img_path: _ + _dst_enc_path: _ + if _use_perlin_field == true { + _src_img_path: #DST_WARPED_SRC_IMG_PATH + _layer_name + "/bad_alignment/z\(_z_offset)" + _dst_enc_path: #DST_WARPED_SRC_ENC_PATH + _layer_name + "/bad_alignment/z\(_z_offset)" + } + if _use_perlin_field == false { + _src_img_path: #DST_WARPED_SRC_IMG_PATH + _layer_name + "/good_alignment/z\(_z_offset)" + _dst_enc_path: #DST_WARPED_SRC_ENC_PATH + _layer_name + "/good_alignment/z\(_z_offset)" + } + + "@type": "build_subchunkable_apply_flow" + op: { + "@type": "VolumetricCallableOperation" + operation_name: _layer_name + fn: { + if _model.res_change_mult[0] == 1 { + "@type": "BaseEncoder" + } + if _model.res_change_mult[1] > 1 { + "@type": "BaseCoarsener" + tile_pad_in: op.crop_pad[0] + tile_size: 1024 + ds_factor: _model.res_change_mult[0] + } + model_path: _model.path + } + crop_pad: [16, 16, 0] + res_change_mult: _model.res_change_mult + } + dst_resolution: [ for j in [0, 1, 2] {_high_resolution[j] * _model.res_change_mult[j]} ] + processing_chunk_sizes: [max_chunk_size, [1024, 1024, 1]] + processing_crop_pads: [[0, 0, 0], [16,16,0]] + expand_bbox_resolution: true + skip_intermediaries: true + bbox: { + "@type": "BBox3D.from_coords", + start_coord: [_bounds[0][0], _bounds[1][0], _bounds[2][0]] + end_coord: [_bounds[0][1], _bounds[1][1], _bounds[2][1]] + } + op_kwargs: { + src: { + "@type": "build_cv_layer" + path: _src_img_path + } + } + dst: { + "@type": "build_cv_layer" + path: _dst_enc_path + info_reference_path: #TGT_ENC_PATH + _layer_name + "/tgt_enc_2023" + } +} + + +#COMPUTE_FIELD_STAGE: { + "@type": "mazepa.execute_on_gcp_with_sqs" + worker_image: 
"us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231118" + worker_resources: { + "nvidia.com/gpu": "1" + } + worker_replicas: 300 + batch_gap_sleep_sec: 0.1 + do_dryrun_estimation: true + local_test: false + worker_cluster_project: "zetta-research" + worker_cluster_region: "us-east1" + worker_cluster_name: "zutils-x3" + target: { + "@type": "mazepa.concurrent_flow" + stages: [ + for key, dataset in #DATASETS { + "@type": "mazepa.concurrent_flow" + stages: [ + #COMPUTE_FIELD_TEMPLATE & { + _bounds: dataset.bounds, + _dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + }, + #COMPUTE_FIELD_TEMPLATE & { + _bounds: dataset.bounds, + _dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + _use_perlin_field: true + }, + if dataset.contiguous { + #COMPUTE_FIELD_TEMPLATE & { + _bounds: dataset.bounds, + _dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + } + }, + if dataset.contiguous { + #COMPUTE_FIELD_TEMPLATE & { + _bounds: dataset.bounds, + _dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + _use_perlin_field: true + }, + } + ] + } + ] + } +} + + +#WARP_IMAGE_STAGE: { + "@type": "mazepa.execute_on_gcp_with_sqs" + worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231118" + worker_resources: { + "memory": "8Gi" + } + worker_replicas: 100 + batch_gap_sleep_sec: 0.1 + do_dryrun_estimation: true + local_test: false + worker_cluster_project: "zetta-research" + worker_cluster_region: "us-east1" + worker_cluster_name: "zutils-x3" + target: { + "@type": "mazepa.concurrent_flow" + stages: [ + for key, dataset in #DATASETS { + "@type": "mazepa.concurrent_flow" + stages: [ + #WARP_IMG_TEMPLATE & { + _bounds: dataset.bounds, + dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + }, + #WARP_IMG_TEMPLATE & { + _bounds: dataset.bounds, + dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + _use_perlin_field: true + }, + if dataset.contiguous { + 
#WARP_IMG_TEMPLATE & { + _bounds: dataset.bounds, + dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + } + } + if dataset.contiguous { + #WARP_IMG_TEMPLATE & { + _bounds: dataset.bounds, + dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + _use_perlin_field: true + }, + } + ] + } + ] + } +} + +#DOWNSAMPLE_FIELD_STAGE: { + "@type": "mazepa.execute_on_gcp_with_sqs" + worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231118" + worker_resources: { + "memory": "8Gi" + } + worker_replicas: 100 + batch_gap_sleep_sec: 0.1 + do_dryrun_estimation: true + local_test: false + worker_cluster_project: "zetta-research" + worker_cluster_region: "us-east1" + worker_cluster_name: "zutils-x3" + target: { + "@type": "mazepa.concurrent_flow" + stages: [ + for key, dataset in #DATASETS { + "@type": "mazepa.concurrent_flow" + stages: [ + #DOWNSAMPLE_FIELD_TEMPLATE & { + _bounds: dataset.bounds, + src_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + }, + #DOWNSAMPLE_FIELD_TEMPLATE & { + _bounds: dataset.bounds, + src_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + _use_perlin_field: true + }, + if dataset.contiguous { + #DOWNSAMPLE_FIELD_TEMPLATE & { + _bounds: dataset.bounds, + src_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + } + } + if dataset.contiguous { + #DOWNSAMPLE_FIELD_TEMPLATE & { + _bounds: dataset.bounds, + src_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + _use_perlin_field: true + } + } + ] + } + ] + } +} + +#EXTRACT_DISPLACEMENT_STAGE: { + "@type": "mazepa.execute_on_gcp_with_sqs" + worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231118" + worker_resources: { + "memory": "8Gi" + } + worker_replicas: 100 + batch_gap_sleep_sec: 0.1 + do_dryrun_estimation: true + local_test: false + worker_cluster_project: "zetta-research" + worker_cluster_region: "us-east1" + worker_cluster_name: "zutils-x3" + target: { + "@type": 
"mazepa.concurrent_flow" + stages: [ + for key, dataset in #DATASETS { + "@type": "mazepa.concurrent_flow" + stages: [ + #FIELD_DIFF_TEMPLATE & { + _bounds: dataset.bounds, + dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + }, + #FIELD_DIFF_TEMPLATE & { + _bounds: dataset.bounds, + dst_resolution: [dataset.resolution[0] * 2, dataset.resolution[1] * 2, dataset.resolution[2]] + _layer_name: key, + _z_offset: 1 + }, + if dataset.contiguous { + #FIELD_DIFF_TEMPLATE & { + _bounds: dataset.bounds, + dst_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + } + } + if dataset.contiguous { + #FIELD_DIFF_TEMPLATE & { + _bounds: dataset.bounds, + dst_resolution: [dataset.resolution[0] * 2, dataset.resolution[1] * 2, dataset.resolution[2]] + _layer_name: key, + _z_offset: 2 + } + } + ] + } + ] + } +} + + +#ENCODE_IMAGE_STAGE: { + "@type": "mazepa.execute_on_gcp_with_sqs" + worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231118" + worker_resources: { + "nvidia.com/gpu": "1" + } + worker_replicas: 300 + batch_gap_sleep_sec: 0.1 + do_dryrun_estimation: true + local_test: false + worker_cluster_project: "zetta-research" + worker_cluster_region: "us-east1" + worker_cluster_name: "zutils-x3" + target: { + "@type": "mazepa.concurrent_flow" + stages: [ + for key, dataset in #DATASETS { + "@type": "mazepa.concurrent_flow" + stages: [ + for i in list.Range(0, 2, 1) { + "@type": "mazepa.concurrent_flow" + stages: [ + #ENCODE_IMG_TEMPLATE & { + _bounds: dataset.bounds, + _high_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + _model: #MODELS[i] + }, + #ENCODE_IMG_TEMPLATE & { + _bounds: dataset.bounds, + _high_resolution: dataset.resolution + _layer_name: key, + _z_offset: 1 + _use_perlin_field: true + _model: #MODELS[i] + }, + if dataset.contiguous { + #ENCODE_IMG_TEMPLATE & { + _bounds: dataset.bounds, + _high_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + _model: #MODELS[i] + } + } + if 
dataset.contiguous { + #ENCODE_IMG_TEMPLATE & { + _bounds: dataset.bounds, + _high_resolution: dataset.resolution + _layer_name: key, + _z_offset: 2 + _use_perlin_field: true + _model: #MODELS[i] + }, + } + ] + } + ] + } + ] + } +} + + +[ + #COMPUTE_FIELD_STAGE, + #WARP_IMAGE_STAGE, + #DOWNSAMPLE_FIELD_STAGE, + #EXTRACT_DISPLACEMENT_STAGE, + #ENCODE_IMAGE_STAGE, +] \ No newline at end of file diff --git a/specs/nico/training/aced_misd/train/z1z2_enc_misd.cue b/specs/nico/training/aced_misd/train/z1z2_enc_misd.cue new file mode 100644 index 000000000..4a9a9f17c --- /dev/null +++ b/specs/nico/training/aced_misd/train/z1z2_enc_misd.cue @@ -0,0 +1,351 @@ +import "strings" +import "strconv" +import "list" + +#EXP_NAME: "aced_misd_general" +#TRAINING_ROOT: "gs://zetta-research-nico/training_artifacts" +#LR: 2e-4 +#K: 3 +#CHUNK_XY: 1024 +#FM: 32 + +#FIELD_MAGN_THR: 5.0 +#Z_OFFSETS: [1, 2] +#DS_FACTOR: 1 + + +#EXP_VERSION: "1.0.0_dsfactor\(#DS_FACTOR)_thr\(#FIELD_MAGN_THR)_lr\(#LR)_z" + strings.Join([for z in #Z_OFFSETS {strconv.FormatInt(z, 10)}], "_") +#MODEL_CKPT: null // "gs://zetta-research-nico/training_artifacts/aced_misd_cns/thr5.0_lr0.00005_z1z2_400-500_2910-2920_more_aligned_unet5_32/last.ckpt" + +#BASE_PATH: "gs://zetta-research-nico/encoder/" +#SRC_ENC_PATH: #BASE_PATH + "misd/enc/" // + k + ["/good_alignment"|"/bad_alignment"] + "/z\(_z_offset)" +#TGT_ENC_PATH: #BASE_PATH + "pairwise_aligned/" // + k + "/tgt_enc_2023" +#DISP_PATH: #BASE_PATH + "misd/misalignment_fields/" // + k + "/displacements/z\(_z_offset)" + +#VAL_DATASETS: { + "microns_basil": {"resolution": [32, 32, 40], "num_samples": 2591}, +} + +#TRAIN_DATASETS: { + "microns_pinky": {"resolution": [32, 32, 40], "num_samples": 5019}, + // "microns_basil": {"resolution": [32, 32, 40], "num_samples": 2591}, + "kim_n2da": {"resolution": [32, 32, 50], "num_samples": 446}, + "kim_pfc2022": {"resolution": [32, 32, 40], "num_samples": 3699}, + "kronauer_cra9": {"resolution": [32, 32, 42], "num_samples": 740}, 
+ "kubota_001": {"resolution": [40, 40, 40], "num_samples": 4744}, + "lee_ppc": {"resolution": [32, 32, 40], "num_samples": 7219}, + "prieto_godino_larva": {"resolution": [32, 32, 32], "num_samples": 4584}, + "janelia_hemibrain": {"resolution": [32, 32, 32], "num_samples": 5304}, + "nguyen_thomas_2022": {"resolution": [32, 32, 40], "num_samples": 1847}, + "mulcahy_2022_16h": {"resolution": [32, 32, 30], "num_samples": 3379}, + "wildenberg_2021_vta_dat12a": {"resolution": [32, 32, 40], "num_samples": 1704}, + "bumbarber_2013": {"resolution": [31.2, 31.2, 50.0], "num_samples": 7325}, + "wilson_2019_p3": {"resolution": [32, 32, 30], "num_samples": 2092}, + "ishibashi_2021_em1": {"resolution": [32, 32, 32], "num_samples": 141}, + "ishibashi_2021_em2": {"resolution": [32, 32, 32], "num_samples": 166}, + "templier_2019_wafer1": {"resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer3": {"resolution": [32, 32, 50], "num_samples": 3577}, + "lichtman_octopus2022": {"resolution": [32, 32, 30], "num_samples": 5673}, +} + +#UNET_DOWNSAMPLE: { + "@type": "torch.nn.MaxPool2d" + "@mode": "partial" + kernel_size: 2 +} + +#UNET_UPSAMPLE: { + { + "@type": "UpConv" + "@mode": "partial" + kernel_size: #K + upsampler: { + "@type": "torch.nn.Upsample" + "@mode": "partial" + scale_factor: 2 + mode: "nearest" + align_corners: null + }, + conv: { + "@type": "torch.nn.Conv2d" + "@mode": "partial" + padding: 1 + } + } +} + +#TARGET: { + "@type": "lightning_train" + regime: { + "@type": "MisalignmentDetectorAcedRegime" + output_mode: "binary" + encoder_path: null + max_shared_displacement_px: 0.0 + max_src_displacement_px: { + "@type": "uniform_distr" + low: 0.0 + high: 0.0 + } + equivar_rot_deg_distr: { + "@type": "uniform_distr" + low: 0.0 + high: 0.0 + } + equivar_trans_px_distr: { + "@type": "uniform_distr" + low: 0.0 + high: 0.0 + } + + field_magn_thr: #FIELD_MAGN_THR + val_log_row_interval: 4 + train_log_row_interval: 200 + lr: #LR + model: { + "@type": 
"load_weights_file" + model: { + "@type": "torch.nn.Sequential" + modules: [ + { + "@type": "ConvBlock", + "@version": "0.0.2" + num_channels: [2, #FM], + kernel_sizes: [5, 5], + activate_last: true, + }, + { + "@type": "UNet" + "@version": "0.0.2" + list_num_channels: [ + [#FM, #FM, #FM], + [#FM, #FM, #FM], + [#FM, #FM, #FM], + [#FM, #FM, #FM], + [#FM, #FM, #FM], + + [#FM, #FM, #FM], + + [#FM, #FM, #FM], + [#FM, #FM, #FM], + [#FM, #FM, #FM], + [#FM, #FM, #FM], + [#FM, #FM, #FM], + ] + downsample: #UNET_DOWNSAMPLE + upsample: #UNET_UPSAMPLE + activate_last: true + kernel_sizes: [#K, #K] + padding_modes: "zeros" + unet_skip_mode: "sum" + skips: {"0": 2} + }, + { + "@type": "torch.nn.Conv2d" + in_channels: #FM + out_channels: 1 + kernel_size: 1 + }, + // { + // "@type": "torch.nn.Sigmoid" # Regime applies binary_cross_entropy_with_logits + // } + ] + }, + ckpt_path: #MODEL_CKPT + component_names: [ + "model", + ] + } + } + trainer: { + "@type": "ZettaDefaultTrainer" + accelerator: "gpu" + precision: "16-mixed", + devices: 1 + max_epochs: 100 + default_root_dir: #TRAINING_ROOT + experiment_name: #EXP_NAME + experiment_version: #EXP_VERSION + log_every_n_steps: 10 + val_check_interval: 500 + num_sanity_val_steps: -1 + reload_dataloaders_every_n_epochs: 1, + checkpointing_kwargs: { + update_every_n_secs: 1700 + backup_every_n_secs: 3700 + } + } + + train_dataloader: { + "@type": "TorchDataLoader" + batch_size: 8 + //shuffle: true + sampler: { + "@type": "SamplerWrapper", + sampler: { + "@type": "TorchRandomSampler" + data_source: { + "@type": "torch.arange" + "end": list.Sum([for dataset in #TRAIN_DATASETS {dataset.num_samples}]) + }, + replacement: false, + num_samples: 4000, + }, + }, + num_workers: 19 + dataset: #TRAINING + pin_memory: true + } + val_dataloader: { + "@type": "TorchDataLoader" + batch_size: 4 + shuffle: false + num_workers: 19 + dataset: #VALIDATION + pin_memory: true + } +} + + +#ENC_PROCS: [ + { + "@mode": "partial" + "@type": "rearrange" + 
"pattern": "c x y 1 -> c x y" + }, + { + "@type": "divide" + "@mode": "partial" + value: 127.0 + }, +] + +#DISP_PROCS: [ + { + "@mode": "partial" + "@type": "rearrange" + "pattern": "c x y 1 -> c x y" + }, + { + "@type": "divide" + "@mode": "partial" + value: 10.0 + }, +] + + +#TRAINING: { + "@type": "JointDataset" + mode: "horizontal" + datasets: { + images: { + "@type": "JointDataset" + mode: "vertical" + datasets: { + for key, dataset in #TRAIN_DATASETS { + for z_offset in #Z_OFFSETS { + "\(key)_z\(z_offset)": { + "@type": "LayerDataset" + layer: { + "@type": "build_layer_set" + layers: { + src: { + "@type": "build_cv_layer" + path: #SRC_ENC_PATH + key + "/bad_alignment/z\(z_offset)" + read_procs: #ENC_PROCS + } + tgt: { + "@type": "build_cv_layer" + path: #TGT_ENC_PATH + key + "/tgt_enc_2023" + read_procs: #ENC_PROCS + } + displacement: { + "@type": "build_cv_layer" + path: #DISP_PATH + key + "/displacements/z\(z_offset)" + read_procs: #DISP_PROCS + } + } + } + sample_indexer: { + "@type": "RandomIndexer", + inner_indexer: { + "@type": "VolumetricNGLIndexer", + resolution: [dataset.resolution[0] * #DS_FACTOR, dataset.resolution[1] * #DS_FACTOR, dataset.resolution[2]], + chunk_size: [#CHUNK_XY, #CHUNK_XY, 1], + path: "zetta-research-nico/encoder/pairwise_aligned/" + key, + } + } + }, + } + } + } + } + } +} + + +#VALIDATION: { + "@type": "JointDataset" + mode: "horizontal" + datasets: { + images: { + "@type": "JointDataset" + mode: "vertical" + datasets: { + for key, dataset in #VAL_DATASETS { + for z_offset in #Z_OFFSETS { + "\(key)_z\(z_offset)": { + "@type": "LayerDataset" + layer: { + "@type": "build_layer_set" + layers: { + src: { + "@type": "build_cv_layer" + path: #SRC_ENC_PATH + key + "/bad_alignment/z\(z_offset)" + read_procs: #ENC_PROCS + } + tgt: { + "@type": "build_cv_layer" + path: #TGT_ENC_PATH + key + "/tgt_enc_2023" + read_procs: #ENC_PROCS + } + displacement: { + "@type": "build_cv_layer" + path: #DISP_PATH + key + "/displacements/z\(z_offset)" + 
read_procs: #DISP_PROCS + } + } + } + sample_indexer: { + "@type": "LoopIndexer", + desired_num_samples: 100 + inner_indexer: { + "@type": "VolumetricNGLIndexer", + resolution: [dataset.resolution[0] * #DS_FACTOR, dataset.resolution[1] * #DS_FACTOR, dataset.resolution[2]], + chunk_size: [#CHUNK_XY, #CHUNK_XY, 1], + path: "zetta-research-nico/encoder/pairwise_aligned/" + key, + } + } + }, + } + } + } + } + } +} + + + +// "@type": "lightning_train_remote" +// "@mode": "partial" +// worker_cluster_name: "zutils-x3" +// worker_cluster_region: "us-east1" +// worker_cluster_project: "zetta-research" +// worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231113" +// worker_resources: {"nvidia.com/gpu": "4"}, +// worker_resource_requests: {"memory": "27560Mi", "cpu": 28}, +// num_nodes: 1 +// spec_path: #TARGET +// follow_logs: true +// env_vars: {"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + +[#TARGET] \ No newline at end of file diff --git a/specs/nico/training/aced_misd_cns/z1z2_enc_misd.cue b/specs/nico/training/aced_misd_cns/z1z2_enc_misd.cue deleted file mode 100644 index 6a60887c8..000000000 --- a/specs/nico/training/aced_misd_cns/z1z2_enc_misd.cue +++ /dev/null @@ -1,757 +0,0 @@ -#EXP_NAME: "aced_misd_cns" -#TRAINING_ROOT: "gs://zetta-research-nico/training_artifacts" -#LR: 1e-5 -#CLIP: 0e-5 -#K: 3 -#CHUNK_XY: 1024 -#FIELD_MAGN_THR: 5.0 - - -#EXP_VERSION: "thr\(#FIELD_MAGN_THR)_lr\(#LR)_z1z2_400-500_2910-2920_more_aligned_unet5_32_finetune_2" -#MODEL_CKPT: "gs://zetta-research-nico/training_artifacts/aced_misd_cns/thr5.0_lr0.00005_z1z2_400-500_2910-2920_more_aligned_unet5_32/last.ckpt" - -// #TGT_CV: "gs://zetta-research-nico/pairs_dsets/cns_x0_400-500/encs_warped/0" -// #SRC_Z2_PREFIX: "gs://zetta-research-nico/misd/enc/local_optima_400-500/enc_z2/med_7.5px_max_" -// #DISP_Z2_PREFIX: "gs://zetta-research-nico/misd/cns/local_optima_400-500/vec_length10x_z2/med_7.5px_max_" -// #MAX_DISP: 20 - -"@type": "mazepa.execute_on_gcp_with_sqs" 
-worker_image: "us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230405" -worker_resources: { - memory: "38560Mi" - "nvidia.com/gpu": "1" -} -worker_replicas: 1 - -local_test: false - -#UNET_DOWNSAMPLE: { - "@type": "torch.nn.MaxPool2d" - "@mode": "partial" - kernel_size: 2 -} - -#UNET_UPSAMPLE: { - { - "@type": "UpConv" - "@mode": "partial" - kernel_size: #K - upsampler: { - "@type": "torch.nn.Upsample" - "@mode": "partial" - scale_factor: 2 - mode: "nearest" - align_corners: null - }, - conv: { - "@type": "torch.nn.Conv2d" - "@mode": "partial" - padding: "same" - } - } -} - -target: { - "@type": "lightning_train" - "@mode": "partial" - - regime: { - "@type": "MisalignmentDetectorAcedRegime" - output_mode: "binary" - encoder_path: null - max_shared_displacement_px: 0.0 - max_src_displacement_px: { - "@type": "uniform_distr" - low: 0.0 - high: 0.0 - } - equivar_rot_deg_distr: { - "@type": "uniform_distr" - low: 0.0 - high: 0.0 - } - equivar_trans_px_distr: { - "@type": "uniform_distr" - low: 0.0 - high: 0.0 - } - - field_magn_thr: #FIELD_MAGN_THR - val_log_row_interval: 4 - train_log_row_interval: 200 - lr: #LR - model: { - "@type": "load_weights_file" - model: { - "@type": "torch.nn.Sequential" - modules: [ - { - "@type": "UNet" - "@version": "0.0.2" - list_num_channels: [ - [2, 32, 32], - [32, 32, 32], - [32, 32, 32], - [32, 32, 32], - [32, 32, 32], - - [32, 32, 32], - - [32, 32, 32], - [32, 32, 32], - [32, 32, 32], - [32, 32, 32], - [32, 32, 32], - ] - downsample: #UNET_DOWNSAMPLE - upsample: #UNET_UPSAMPLE - activate_last: true - kernel_sizes: [#K, #K] - padding_modes: "zeros" - unet_skip_mode: "sum" - skips: {"1": 2} - }, - { - "@type": "torch.nn.Conv2d" - in_channels: 32 - out_channels: 1 - kernel_size: 1 - }, - { - "@type": "torch.nn.Sigmoid" - } - ] - }, - ckpt_path: #MODEL_CKPT - component_names: [ - "model", - ] - } - } - trainer: { - "@type": "ZettaDefaultTrainer" - accelerator: "gpu" - devices: 1 - max_epochs: 100 - default_root_dir: #TRAINING_ROOT - 
experiment_name: #EXP_NAME - experiment_version: #EXP_VERSION - log_every_n_steps: 10 - val_check_interval: 1000 - gradient_clip_algorithm: "norm" - gradient_clip_val: #CLIP - checkpointing_kwargs: { - update_every_n_secs: 60 - backup_every_n_secs: 900 - } - } - - train_dataloader: { - "@type": "TorchDataLoader" - batch_size: 8 - shuffle: true - num_workers: 12 - dataset: #TRAINING_DSET - } - val_dataloader: { - "@type": "TorchDataLoader" - batch_size: 4 - shuffle: false - num_workers: 8 - dataset: #VAL_DSET - } -} - - -#IMG_PROCS: [ - { - "@mode": "partial" - "@type": "rearrange" - "pattern": "c x y 1 -> c x y" - }, - { - "@type": "divide" - "@mode": "partial" - value: 127.0 - }, -] - -#DISP_PROCS: [ - { - "@mode": "partial" - "@type": "rearrange" - "pattern": "c x y 1 -> c x y" - }, - { - "@type": "divide" - "@mode": "partial" - value: 10.0 - }, -] - - -#TRAINING_DSET: { - "@type": "JointDataset" - mode: "horizontal" - datasets: { - images: { - "@type": "JointDataset" - mode: "vertical" - datasets: { - for z_offset in [1, 2] { - "z400_500_\(z_offset)": { - "@type": "LayerDataset" - layer: { - "@type": "build_layer_set" - layers: { - src: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_enc_400-500/fine_misaligned/-\(z_offset)" - read_procs: #IMG_PROCS - } - tgt: { - "@type": "build_cv_layer" - path: "gs://zetta_lee_fly_cns_001_alignment_temp/aced/coarse_x0/encodings_masked" - read_procs: #IMG_PROCS - index_procs: [ - { - "@type": "VolumetricIndexTranslator" - offset: [0, 0, -z_offset] - resolution: [32, 32, 45] - } - ] - } - displacement: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_fields_400-500/fine_diff3/-\(z_offset)" - read_procs: #DISP_PROCS - } - } - } - sample_indexer: { - "@type": "RandomIndexer" - inner_indexer: { - "@type": "ChainIndexer" - inner_indexer: [ - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: 
[#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [1 * 2048, 1 * 2048, 400] - end_coord: [4 * 2048, 4 * 2048, 498] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [6 * 2048, 1 * 2048, 400] - end_coord: [9 * 2048, 4 * 2048, 498] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [11 * 2048, 1 * 2048, 400] - end_coord: [15 * 2048, 4 * 2048, 498] - resolution: [32, 32, 45] - } - } - ] - } - } - }, - "z400_500_\(z_offset)_aligned": { - "@type": "LayerDataset" - layer: { - "@type": "build_layer_set" - layers: { - src: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_enc_400-500/fine/-\(z_offset)" - read_procs: #IMG_PROCS - } - tgt: { - "@type": "build_cv_layer" - path: "gs://zetta_lee_fly_cns_001_alignment_temp/aced/coarse_x0/encodings_masked" - read_procs: #IMG_PROCS - index_procs: [ - { - "@type": "VolumetricIndexTranslator" - offset: [0, 0, -z_offset] - resolution: [32, 32, 45] - } - ] - } - displacement: { - "@type": "build_cv_layer" - path: "file:///tmp/placeholder_400-500" - cv_kwargs: { - fill_missing: true - } - info_reference_path: "gs://zetta-research-nico/misd/cns/pairwise_fields_400-500/fine_diff3/-\(z_offset)" - read_procs: [ - { - "@mode": "partial" - "@type": "rearrange" - "pattern": "c x y 1 -> c x y" - }, - { - "@type": "torch.zeros_like" - "@mode": "partial" - }, - { - "@type": "torch.add" - "@mode": "partial" - other: 0.0 - } - ] - } - } - } - sample_indexer: { - "@type": "RandomIndexer" - inner_indexer: { - "@type": "ChainIndexer" - inner_indexer: [ - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 
45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [1 * 2048, 1 * 2048, 400] - end_coord: [4 * 2048, 4 * 2048, 498] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [6 * 2048, 1 * 2048, 400] - end_coord: [9 * 2048, 4 * 2048, 498] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [11 * 2048, 1 * 2048, 400] - end_coord: [15 * 2048, 4 * 2048, 498] - resolution: [32, 32, 45] - } - } - ] - } - } - }, - "z2910_2920_\(z_offset)": { - "@type": "LayerDataset" - layer: { - "@type": "build_layer_set" - layers: { - src: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_enc_2908-2921/fine_misaligned/-\(z_offset)" - read_procs: #IMG_PROCS - } - tgt: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/pairs_dsets/cns_x0_2910-2920_masked" - read_procs: #IMG_PROCS - index_procs: [ - { - "@type": "VolumetricIndexTranslator" - offset: [0, 0, -z_offset] - resolution: [32, 32, 45] - } - ] - } - displacement: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_fields_2908-2921/fine_diff3/-\(z_offset)" - read_procs: #DISP_PROCS - } - } - } - sample_indexer: { - "@type": "RandomIndexer" - inner_indexer: { - "@type": "ChainIndexer" - inner_indexer: [ - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [3 * 1024, 2 * 1024, 2910] - end_coord: [27 * 1024, 8 * 1024, 2921] - resolution: [32, 32, 45] - } - }, - { - 
"@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [13 * 1024, 8 * 1024, 2910] - end_coord: [16 * 1024, 16 * 1024, 2921] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [12 * 1024, 16 * 1024, 2910] - end_coord: [21 * 1024, 20 * 1024, 2921] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [12 * 1024, 21 * 1024, 2910] - end_coord: [17 * 1024, 25 * 1024, 2921] - resolution: [32, 32, 45] - } - } - ] - } - } - }, - "z2910_2920_\(z_offset)_aligned": { - "@type": "LayerDataset" - layer: { - "@type": "build_layer_set" - layers: { - src: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_enc_2908-2921/fine/-\(z_offset)" - read_procs: #IMG_PROCS - } - tgt: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/pairs_dsets/cns_x0_2910-2920_masked" - read_procs: #IMG_PROCS - index_procs: [ - { - "@type": "VolumetricIndexTranslator" - offset: [0, 0, -z_offset] - resolution: [32, 32, 45] - } - ] - } - displacement: { - "@type": "build_cv_layer" - path: "file:///tmp/placeholder_2908-2921" - cv_kwargs: { - fill_missing: true - } - info_reference_path: "gs://zetta-research-nico/misd/cns/pairwise_fields_2908-2921/fine_diff3/-\(z_offset)" - read_procs: [ - { - "@mode": "partial" - "@type": "rearrange" - "pattern": "c x y 1 -> c x y" - }, - { - "@type": "torch.zeros_like" - "@mode": "partial" - }, - { - "@type": "torch.add" - "@mode": "partial" - other: 0.0 - } - ] - } - } - } - sample_indexer: { - "@type": "RandomIndexer" - 
inner_indexer: { - "@type": "ChainIndexer" - inner_indexer: [ - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [3 * 1024, 2 * 1024, 2910] - end_coord: [27 * 1024, 8 * 1024, 2921] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [13 * 1024, 8 * 1024, 2910] - end_coord: [16 * 1024, 16 * 1024, 2921] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [12 * 1024, 16 * 1024, 2910] - end_coord: [21 * 1024, 20 * 1024, 2921] - resolution: [32, 32, 45] - } - }, - { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [12 * 1024, 21 * 1024, 2910] - end_coord: [17 * 1024, 25 * 1024, 2921] - resolution: [32, 32, 45] - } - } - ] - } - } - }, - "false_neg_z\(z_offset)": { - "@type": "LayerDataset" - layer: { - "@type": "build_layer_set" - layers: { - src: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_enc_3406-3410/fine/-\(z_offset)" - read_procs: #IMG_PROCS - } - tgt: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/pairs_dsets/cns_x0_3406-3410_masked" - read_procs: #IMG_PROCS - index_procs: [ - { - "@type": "VolumetricIndexTranslator" - offset: [0, 0, -z_offset] - resolution: [32, 32, 45] - } - ] - } - displacement: { - "@type": "build_cv_layer" - path: "file:///tmp/placeholder_3406-3410" - cv_kwargs: { - fill_missing: true - } - info_reference_path: 
"gs://zetta-research-nico/misd/cns/pairwise_fields_2908-2921/fine_diff3/-\(z_offset)" - read_procs: [ - { - "@mode": "partial" - "@type": "rearrange" - "pattern": "c x y 1 -> c x y" - }, - { - "@type": "torch.full_like" - "@mode": "partial" - fill_value: 255.0 - }, - { - "@type": "torch.add" - "@mode": "partial" - other: 0.0 - } - ] - } - } - } - sample_indexer: { - "@type": "RandomIndexer" - inner_indexer: { - "@type": "LoopIndexer" - if z_offset == 1 { - desired_num_samples: 12500 - } - if z_offset == 2 { - desired_num_samples: 8000 - } - inner_indexer: { - "@type": "VolumetricNGLIndexer" - resolution: [32, 32, 45] - chunk_size: [1024, 1024, 1] - path: "nkem/cns/false_neg_z\(z_offset)" - } - } - } - } - }, - } - } - } -} - - -#VAL_DSET: { - "@type": "JointDataset" - mode: "horizontal" - datasets: { - images: { - "@type": "JointDataset" - mode: "vertical" - datasets: { - for z_offset in [2] { - "z2000_2001_\(z_offset)": { - "@type": "LayerDataset" - layer: { - "@type": "build_layer_set" - layers: { - src: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_enc_1998-2001/fine_misaligned/-\(z_offset)" - read_procs: #IMG_PROCS - } - tgt: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/pairs_dsets/cns_x0_1998-2001_masked" - read_procs: #IMG_PROCS - index_procs: [ - { - "@type": "VolumetricIndexTranslator" - offset: [0, 0, -z_offset] - resolution: [32, 32, 45] - } - ] - } - displacement: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_fields_1998-2001/fine_diff3/-\(z_offset)" - read_procs: #DISP_PROCS - } - } - } - sample_indexer: { - "@type": "RandomIndexer" - inner_indexer: { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [3 * 1024, 3 * 1024, 2000] - end_coord: [14 * 1024, 7 * 1024, 2001] - resolution: [32, 32, 45] - } - }, - } - }, - 
"z2000_2001_\(z_offset)_aligned": { - "@type": "LayerDataset" - layer: { - "@type": "build_layer_set" - layers: { - src: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/misd/cns/pairwise_enc_1998-2001/fine/-\(z_offset)" - read_procs: #IMG_PROCS - } - tgt: { - "@type": "build_cv_layer" - path: "gs://zetta-research-nico/pairs_dsets/cns_x0_1998-2001_masked" - read_procs: #IMG_PROCS - index_procs: [ - { - "@type": "VolumetricIndexTranslator" - offset: [0, 0, -z_offset] - resolution: [32, 32, 45] - } - ] - } - displacement: { - "@type": "build_cv_layer" - path: "file:///tmp/placeholder_1998-2001" - cv_kwargs: { - fill_missing: true - } - info_reference_path: "gs://zetta-research-nico/misd/cns/pairwise_fields_1998-2001/fine_diff3/-\(z_offset)" - read_procs: [ - { - "@mode": "partial" - "@type": "rearrange" - "pattern": "c x y 1 -> c x y" - }, - { - "@type": "torch.zeros_like" - "@mode": "partial" - }, - { - "@type": "torch.add" - "@mode": "partial" - other: 0.0 - } - ] - } - } - } - sample_indexer: { - "@type": "RandomIndexer" - inner_indexer: { - "@type": "VolumetricStridedIndexer" - resolution: [32, 32, 45] - stride: [#CHUNK_XY, #CHUNK_XY, 1] - chunk_size: [#CHUNK_XY, #CHUNK_XY, 1] - bbox: { - "@type": "BBox3D.from_coords" - start_coord: [3 * 1024, 3 * 1024, 2000] - end_coord: [14 * 1024, 7 * 1024, 2001] - resolution: [32, 32, 45] - } - }, - } - }, - } - } - } - } -} diff --git a/zetta_utils/training/lightning/regimes/alignment/misalignment_detector_aced.py b/zetta_utils/training/lightning/regimes/alignment/misalignment_detector_aced.py index 756ce0ee8..f374426d2 100644 --- a/zetta_utils/training/lightning/regimes/alignment/misalignment_detector_aced.py +++ b/zetta_utils/training/lightning/regimes/alignment/misalignment_detector_aced.py @@ -1,4 +1,5 @@ # pylint: disable=too-many-locals +import os from typing import Literal, Optional import attrs @@ -107,15 +108,11 @@ def validation_epoch_start(self, _): # pylint: disable=no-self-use seed_everything(42) 
def on_validation_epoch_end(self): - self.log_results( - "val", - "worst", - **self.worst_val_sample, - ) - self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None - seed_everything(None) + env_seed = os.environ.get("PL_GLOBAL_SEED") + if env_seed is not None: + seed_everything(int(env_seed) + self.current_epoch) + else: + seed_everything(None) def _get_warped(self, img, field=None): img_padded = torch.nn.functional.pad(img, (1, 1, 1, 1), value=self.zero_value) @@ -236,7 +233,7 @@ def compute_misd_loss(self, batch: dict, mode: str, log_row: bool, sample_name: weight = torch.ones_like(gt_labels, dtype=torch.float32) weight[intersect_tissue == 0] = 0.0 - loss_map = torch.nn.functional.binary_cross_entropy( + loss_map = torch.nn.functional.binary_cross_entropy_with_logits( prediction, gt_labels.float(), weight=weight, reduction="none" ) From d8dda04ff11e1e546554ef1eabf944f7257266f0 Mon Sep 17 00:00:00 2001 From: Sergiy Popovych Date: Tue, 28 Nov 2023 23:02:19 +0000 Subject: [PATCH 9/9] fix: update training code to work with python specs --- .../em_encoder/train/m3_m3_encoder_dict.py | 357 +++++++++------- .../em_encoder/train/m3_m4_encoder_dict.py | 361 +++++++++------- .../em_encoder/train/m3_m5_encoder_dict.py | 378 +++++++++-------- .../em_encoder/train/m3_m6_encoder_dict.py | 400 ++++++++++-------- .../em_encoder/train/m3_m7_encoder_dict.py | 372 ++++++++-------- zetta_utils/training/lightning/train.py | 120 ++++-- 6 files changed, 1089 insertions(+), 899 deletions(-) diff --git a/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py index 8fde87463..b471fdc2f 100644 --- a/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m3_encoder_dict.py @@ -1,12 +1,10 @@ # pylint: skip-file from __future__ import annotations -if __name__ == "__main__": - import os +from typing import cast +if __name__ == "__main__": from 
zetta_utils.api.v0 import * - from zetta_utils.parsing import json - from zetta_utils.training.lightning.train import _parse_spec_and_train LR = 2e-4 L1_WEIGHT_START_VAL = 0.0 @@ -22,8 +20,10 @@ EXP_VERSION = f"1.0.33_M3_M3_unet_fm{FM}-256_conv4_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" - START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -35,7 +35,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -45,22 +49,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - 
"prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, 
"resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -143,10 +175,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -195,7 +229,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -232,8 +265,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -294,10 +327,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -329,153 +362,153 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - 
"activate_last": True, - }, - { - "@type": "UNet", - "list_num_channels": [ - [FM, FM, FM, FM*2], - [FM*2, FM*2, FM*2, FM*4], - [FM*4, FM*4, FM*4, FM*8], - [FM*8, FM*8, FM*8, FM*8], - [FM*4, FM*4, FM*4, FM*4], - [FM*2, FM*2, FM*2, FM*2], - [FM, FM, FM, FM], - ], - "downsample": { - "@type": "torch.nn.MaxPool2d", - "@mode": "partial", - "kernel_size": 2, - }, - "upsample": { - "@type": "UpConv", - "@mode": "partial", - "kernel_size": 3, - "upsampler": { - "@type": "torch.nn.Upsample", - "@mode": "partial", - "scale_factor": 2, - "mode": "nearest", - "align_corners": None, - }, - "conv": { - "@type": "torch.nn.Conv2d", - "@mode": "partial", - "padding": 1, - }, - }, - "activate_last": True, - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "unet_skip_mode": "sum", - "skips": {"0": 2}, + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "activate_last": True, + }, + { + "@type": "UNet", + "list_num_channels": [ + [FM, FM, FM, FM * 2], + [FM * 2, FM * 2, FM * 2, FM * 4], + [FM * 4, FM * 4, FM * 4, FM * 8], + [FM * 8, FM * 8, FM * 8, FM * 8], + [FM * 4, FM * 4, FM * 4, FM * 4], + [FM * 2, FM * 2, FM * 2, FM * 2], + [FM, FM, FM, FM], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + 
"kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, }, - { + "conv": { "@type": "torch.nn.Conv2d", - "in_channels": FM, - "out_channels": 1, - "kernel_size": 1, + "@mode": "partial", + "padding": 1, }, - {"@type": "torch.nn.Tanh"}, - ], + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM, + "out_channels": 1, + "kernel_size": 1, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, - }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + {"@type": "torch.nn.Tanh"}, + ], }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = 
json.dumps(target.spec) + }, + } - # _parse_spec_and_train() + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 4, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", 
+ cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x213", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py index 098fe2f58..5a82a46d1 100644 --- a/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m4_encoder_dict.py @@ -1,6 +1,8 @@ # pylint: skip-file from __future__ import annotations +from typing import cast + if __name__ == "__main__": import os @@ -24,8 +26,10 @@ EXP_VERSION = f"1.0.0_M3_M4_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" - START_EXP_VERSION = f"1.0.10_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.10_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -37,7 +41,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": 
{"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -47,22 +55,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, 
"wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -145,10 +181,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -197,7 +235,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -234,8 +271,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -296,10 +333,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -331,159 +368,155 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - 
"val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "ds_factor": DS_FACTOR, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "padding_modes": "reflect", - "activate_last": True, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "UNet", - "list_num_channels": [ - [FM, FM, FM, FM*2], - [FM*2, FM*2, FM*2, FM*4], - [FM*4, FM*4, FM*4, FM*8], - [FM*8, FM*8, FM*8, FM*8], - [FM*4, FM*4, FM*4, FM*4], - [FM*2, FM*2, FM*2, FM*2], - [FM, FM, FM, FM], - ], - "downsample": { - "@type": "torch.nn.MaxPool2d", - "@mode": "partial", - "kernel_size": 2, - }, - "upsample": { - "@type": "UpConv", - "@mode": "partial", - "kernel_size": 3, - "upsampler": { - "@type": "torch.nn.Upsample", - "@mode": "partial", - "scale_factor": 2, - "mode": "nearest", - "align_corners": None, - }, - "conv": { - "@type": "torch.nn.Conv2d", - "@mode": "partial", - "padding": 1, - }, - }, - "activate_last": True, - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "unet_skip_mode": "sum", - "skips": {"0": 2}, + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + 
"ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "UNet", + "list_num_channels": [ + [FM, FM, FM, FM * 2], + [FM * 2, FM * 2, FM * 2, FM * 4], + [FM * 4, FM * 4, FM * 4, FM * 8], + [FM * 8, FM * 8, FM * 8, FM * 8], + [FM * 4, FM * 4, FM * 4, FM * 4], + [FM * 2, FM * 2, FM * 2, FM * 2], + [FM, FM, FM, FM], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, }, - { + "conv": { "@type": "torch.nn.Conv2d", - "in_channels": FM, - "out_channels": CHANNELS, - "kernel_size": 1, + "@mode": "partial", + "padding": 1, }, - {"@type": "torch.nn.Tanh"}, - ], + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, 
"backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM, + "out_channels": CHANNELS, + "kernel_size": 1, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, + {"@type": "torch.nn.Tanh"}, + ], }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + }, + } + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 4, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = 
json.dumps(target.spec) - - # _parse_spec_and_train() + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x21xi3", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py index 5a17fa794..5a90c252c 100644 --- a/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m5_encoder_dict.py @@ -1,6 +1,8 @@ # pylint: skip-file from __future__ import annotations +from typing import cast + if __name__ == "__main__": import os @@ -24,8 +26,10 @@ EXP_VERSION = f"1.0.14_M3_M5_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" - START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # 
f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -37,7 +41,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -47,22 +55,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": 
{"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -145,10 +181,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -197,7 
+235,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -234,8 +271,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -296,10 +333,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -331,169 +368,162 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "ds_factor": DS_FACTOR, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "padding_modes": "reflect", - "activate_last": True, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM, FM, FM, FM*2], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "UNet", - "list_num_channels": [ - [FM*2, FM*2, FM*2, FM*4], - [FM*4, FM*4, FM*4, FM*8], - [FM*8, FM*8, FM*8, FM*8], - 
[FM*4, FM*4, FM*4, FM*4], - [FM*2, FM*2, FM*2, FM*2], - ], - "downsample": { - "@type": "torch.nn.MaxPool2d", - "@mode": "partial", - "kernel_size": 2, - }, - "upsample": { - "@type": "UpConv", - "@mode": "partial", - "kernel_size": 3, - "upsampler": { - "@type": "torch.nn.Upsample", - "@mode": "partial", - "scale_factor": 2, - "mode": "nearest", - "align_corners": None, - }, - "conv": { - "@type": "torch.nn.Conv2d", - "@mode": "partial", - "padding": 1, - }, - }, - "activate_last": True, - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "unet_skip_mode": "sum", - "skips": {"0": 2}, + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM * 2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "UNet", + "list_num_channels": [ + [FM * 2, FM * 2, FM * 2, FM * 4], + [FM * 4, FM * 4, FM * 4, FM * 8], + [FM * 8, FM * 8, FM * 8, FM * 8], + [FM * 4, FM * 4, FM * 4, FM * 4], + [FM * 2, FM * 2, FM * 2, FM * 2], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": 
"partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + "align_corners": None, }, - { + "conv": { "@type": "torch.nn.Conv2d", - "in_channels": FM*2, - "out_channels": CHANNELS, - "kernel_size": 1, + "@mode": "partial", + "padding": 1, }, - {"@type": "torch.nn.Tanh"}, - ], + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM * 2, + "out_channels": CHANNELS, + "kernel_size": 1, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, + {"@type": "torch.nn.Tanh"}, + ], }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 
28, - "dataset": validation, - "pin_memory": True, + }, + } + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 4, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) - - # _parse_spec_and_train() + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231103", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - 
spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x213", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py index df88b3b4b..7ede037d8 100644 --- a/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m6_encoder_dict.py @@ -1,12 +1,10 @@ # pylint: skip-file from __future__ import annotations -if __name__ == "__main__": - import os +from typing import cast +if __name__ == "__main__": from zetta_utils.api.v0 import * - from zetta_utils.parsing import json - from zetta_utils.training.lightning.train import _parse_spec_and_train LR = 2e-4 L1_WEIGHT_START_VAL = 0.0 @@ -24,8 +22,10 @@ EXP_VERSION = f"1.0.15_M3_M6_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" - START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -37,7 +37,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 2591}, 
"microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -47,22 +51,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1704}, - "bumbarber_2013": 
{"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -145,10 +177,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -197,7 +231,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -234,8 +267,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -296,10 +329,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": 
"zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -331,179 +364,170 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "ds_factor": DS_FACTOR, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "padding_modes": "reflect", - "activate_last": True, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM, FM, FM, FM*2], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM*2, FM*2, FM*2, FM*4], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "UNet", - "list_num_channels": [ - [FM*4, FM*4, FM*4, FM*8], - [FM*8, FM*8, FM*8, FM*8], - [FM*4, FM*4, FM*4, FM*4], - ], - "downsample": { - "@type": "torch.nn.MaxPool2d", - "@mode": "partial", - "kernel_size": 2, - }, - "upsample": { - "@type": "UpConv", - "@mode": "partial", - "kernel_size": 3, - "upsampler": { - "@type": "torch.nn.Upsample", - "@mode": "partial", - "scale_factor": 2, - "mode": "nearest", - "align_corners": None, - }, - "conv": { 
- "@type": "torch.nn.Conv2d", - "@mode": "partial", - "padding": 1, - }, - }, - "activate_last": True, - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "unet_skip_mode": "sum", - "skips": {"0": 2}, + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM * 2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM * 2, FM * 2, FM * 2, FM * 4], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "UNet", + "list_num_channels": [ + [FM * 4, FM * 4, FM * 4, FM * 8], + [FM * 8, FM * 8, FM * 8, FM * 8], + [FM * 4, FM * 4, FM * 4, FM * 4], + ], + "downsample": { + "@type": "torch.nn.MaxPool2d", + "@mode": "partial", + "kernel_size": 2, + }, + "upsample": { + "@type": "UpConv", + "@mode": "partial", + "kernel_size": 3, + "upsampler": { + "@type": "torch.nn.Upsample", + "@mode": "partial", + "scale_factor": 2, + "mode": "nearest", + 
"align_corners": None, }, - { + "conv": { "@type": "torch.nn.Conv2d", - "in_channels": FM*4, - "out_channels": CHANNELS, - "kernel_size": 1, + "@mode": "partial", + "padding": 1, }, - {"@type": "torch.nn.Tanh"}, - ], + }, + "activate_last": True, + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "unet_skip_mode": "sum", + "skips": {"0": 2}, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, - "replacement": False, - "num_samples": 8000, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM * 4, + "out_channels": CHANNELS, + "kernel_size": 1, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, - }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + {"@type": "torch.nn.Tanh"}, + ], }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) + }, + } - # _parse_spec_and_train() + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": 
"gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 2, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, + }, + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + 
image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x213", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py index 5f5cce385..f9684b900 100644 --- a/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py +++ b/specs/nico/training/em_encoder/train/m3_m7_encoder_dict.py @@ -1,12 +1,12 @@ # pylint: skip-file from __future__ import annotations +from typing import cast + if __name__ == "__main__": import os from zetta_utils.api.v0 import * - from zetta_utils.parsing import json - from zetta_utils.training.lightning.train import _parse_spec_and_train LR = 2e-4 L1_WEIGHT_START_VAL = 0.0 @@ -22,10 +22,12 @@ EXP_NAME = "general_coarsener_loss" TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts" - EXP_VERSION = f"1.0.3_M3_M7_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4" + EXP_VERSION = f"1.0.3_M3_M7_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4_repro_x2_4xt4_auto_2nodes" - START_EXP_VERSION = f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" - MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" + START_EXP_VERSION = ( + f"1.0.9_M3_M3_unet_fm32-256_lr0.0004_locality1.0_similarity0.0_l10.0-0.05_N4xB2" + ) + MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/general_encoder_loss/{START_EXP_VERSION}/last.ckpt" BASE_PATH = "gs://zetta-research-nico/encoder/" @@ -37,7 +39,11 @@ "microns_pinky": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 5019}, # "microns_basil": {"contiguous": True, "resolution": 
[32, 32, 40], "num_samples": 2591}, "microns_minnie": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 2882}, - "microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 6923}, + "microns_interneuron": { + "contiguous": False, + "resolution": [32, 32, 40], + "num_samples": 6923, + }, "aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45], "num_samples": 5805}, "kim_n2da": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 446}, "kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 3699}, @@ -47,22 +53,50 @@ "lee_banc": {"contiguous": False, "resolution": [32, 32, 45], "num_samples": 742}, "lee_ppc": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 7219}, "lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1964}, - "lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30], "num_samples": 2799}, - "prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 4584}, + "lichtman_zebrafish": { + "contiguous": False, + "resolution": [32, 32, 30], + "num_samples": 2799, + }, + "prieto_godino_larva": { + "contiguous": True, + "resolution": [32, 32, 32], + "num_samples": 4584, + }, "fafb_v15": {"contiguous": False, "resolution": [32, 32, 40], "num_samples": 1795}, "lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33], "num_samples": 6624}, "janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 5304}, "janelia_manc": {"contiguous": False, "resolution": [32, 32, 32], "num_samples": 2398}, - "nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40], "num_samples": 1847}, + "nguyen_thomas_2022": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1847, + }, "mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 3379}, - "wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40], 
"num_samples": 1704}, - "bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50], "num_samples": 7325}, + "wildenberg_2021_vta_dat12a": { + "contiguous": True, + "resolution": [32, 32, 40], + "num_samples": 1704, + }, + "bumbarber_2013": { + "contiguous": True, + "resolution": [31.2, 31.2, 50], + "num_samples": 7325, + }, "wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 2092}, "ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 141}, # "ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32], "num_samples": 166}, - "templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 5401}, + "templier_2019_wafer1": { + "contiguous": True, + "resolution": [32, 32, 50], + "num_samples": 5401, + }, # "templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50], "num_samples": 3577}, - "lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30], "num_samples": 5673}, + "lichtman_octopus2022": { + "contiguous": True, + "resolution": [32, 32, 30], + "num_samples": 5673, + }, } val_img_aug = [ @@ -145,10 +179,12 @@ { "@type": "imgaug.augmenters.Sometimes", "p": 1.0, - "then_list": [{ - "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", - "severity": 1, - }] + "then_list": [ + { + "@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", + "severity": 1, + } + ], }, { "@type": "imgaug.augmenters.Cutout", @@ -197,7 +233,6 @@ {"@type": "divide", "@mode": "partial", "value": 255.0}, ] - training = { "@type": "JointDataset", "mode": "horizontal", @@ -234,8 +269,8 @@ "resolution": settings["resolution"], "chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + name, - } - } + }, + }, } for name, settings in SOURCE_PATHS.items() }, @@ -296,10 +331,10 @@ "desired_num_samples": 100, "inner_indexer": { "@type": "VolumetricNGLIndexer", - "resolution": [32,32,40], + "resolution": [32, 32, 40], 
"chunk_size": [CHUNK_SIZE, CHUNK_SIZE, 1], "path": "zetta-research-nico/encoder/pairwise_aligned/" + VAL_DSET_NAME, - } + }, }, }, "field": { @@ -331,164 +366,151 @@ }, } - - target = BuilderPartial( - { - "@type": "lightning_train", - "regime": { - "@type": "BaseEncoderRegime", - "@version": "0.0.2", - "max_displacement_px": 32.0, - "val_log_row_interval": 20, - "train_log_row_interval": 100, - "lr": LR, - "l1_weight_start_val": L1_WEIGHT_START_VAL, - "l1_weight_end_val": L1_WEIGHT_END_VAL, - "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, - "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, - "locality_weight": LOCALITY_WEIGHT, - "similarity_weight": SIMILARITY_WEIGHT, - "ds_factor": DS_FACTOR, - "empty_tissue_threshold": 0.6, - "model": { - "@type": "load_weights_file", - "ckpt_path": MODEL_CKPT, - "component_names": [ - "model", - ], - "model": { - "@type": "torch.nn.Sequential", - "modules": [ - { - "@type": "ConvBlock", - "num_channels": [1, FM], - "kernel_sizes": [5, 5], - "padding_modes": "reflect", - "activate_last": True, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM, FM, FM, FM*2], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM*2, FM*2, FM*2, FM*4], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM*4, FM*4, FM*4, FM*8], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 2}, - }, - { - "@type": "torch.nn.MaxPool2d", - "kernel_size": 2 - }, - { - "@type": "ConvBlock", - "num_channels": [FM*8, FM*8, FM*8, FM*8], - "kernel_sizes": [3, 3], - "padding_modes": "reflect", - "activate_last": True, - "skips": {"0": 3}, - }, 
- { - "@type": "torch.nn.Conv2d", - "in_channels": FM*8, - "out_channels": CHANNELS, - "kernel_size": 1, - }, - {"@type": "torch.nn.Tanh"}, - ], + regime_spec = { + "@type": "BaseEncoderRegime", + "@version": "0.0.2", + "max_displacement_px": 32.0, + "val_log_row_interval": 20, + "train_log_row_interval": 100, + "lr": LR, + "l1_weight_start_val": L1_WEIGHT_START_VAL, + "l1_weight_end_val": L1_WEIGHT_END_VAL, + "l1_weight_start_epoch": L1_WEIGHT_START_EPOCH, + "l1_weight_end_epoch": L1_WEIGHT_END_EPOCH, + "locality_weight": LOCALITY_WEIGHT, + "similarity_weight": SIMILARITY_WEIGHT, + "ds_factor": DS_FACTOR, + "empty_tissue_threshold": 0.6, + "model": { + "@type": "load_weights_file", + "ckpt_path": MODEL_CKPT, + "component_names": [ + "model", + ], + "model": { + "@type": "torch.nn.Sequential", + "modules": [ + { + "@type": "ConvBlock", + "num_channels": [1, FM], + "kernel_sizes": [5, 5], + "padding_modes": "reflect", + "activate_last": True, }, - }, - }, - "trainer": { - "@type": "ZettaDefaultTrainer", - "accelerator": "gpu", - "precision": "16-mixed", - "strategy": "auto", - # "use_distributed_sampler": False, - "devices": 4, - "max_epochs": 100, - "default_root_dir": TRAINING_ROOT, - "experiment_name": EXP_NAME, - "experiment_version": EXP_VERSION, - "log_every_n_steps": 100, - "val_check_interval": 500, - # "limit_val_batches": 0, - # "track_grad_norm": 2, - # "gradient_clip_algorithm": "norm", - # "gradient_clip_val": CLIP, - # "detect_anomaly": True, - # "overfit_batches": 100, - "reload_dataloaders_every_n_epochs": 1, - "checkpointing_kwargs": {"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, - }, - "train_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 2, - # "shuffle": True, - "sampler": { - "@type": "SamplerWrapper", - "sampler": { - "@type": "TorchRandomSampler", # Random order across all samples and all datasets - "data_source": {"@type": "torch.arange", "end": sum([settings["num_samples"] for settings in SOURCE_PATHS.values()])}, 
- "replacement": False, - "num_samples": 8000, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM, FM, FM, FM * 2], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, }, - }, - "num_workers": 28, - "dataset": training, - "pin_memory": True, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM * 2, FM * 2, FM * 2, FM * 4], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM * 4, FM * 4, FM * 4, FM * 8], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 2}, + }, + {"@type": "torch.nn.MaxPool2d", "kernel_size": 2}, + { + "@type": "ConvBlock", + "num_channels": [FM * 8, FM * 8, FM * 8, FM * 8], + "kernel_sizes": [3, 3], + "padding_modes": "reflect", + "activate_last": True, + "skips": {"0": 3}, + }, + { + "@type": "torch.nn.Conv2d", + "in_channels": FM * 8, + "out_channels": CHANNELS, + "kernel_size": 1, + }, + {"@type": "torch.nn.Tanh"}, + ], }, - "val_dataloader": { - "@type": "TorchDataLoader", - "batch_size": 1, - "shuffle": False, - "num_workers": 28, - "dataset": validation, - "pin_memory": True, + }, + } + trainer_spec = { + "@type": "ZettaDefaultTrainer", + "accelerator": "gpu", + "precision": "16-mixed", + "strategy": "auto", + # "use_distributed_sampler": False, + "devices": "auto", + "max_epochs": 100, + "default_root_dir": TRAINING_ROOT, + "experiment_name": EXP_NAME, + "experiment_version": EXP_VERSION, + "log_every_n_steps": 100, + "val_check_interval": 500, + # "limit_val_batches": 0, + # "track_grad_norm": 2, + # "gradient_clip_algorithm": "norm", + # "gradient_clip_val": CLIP, + # "detect_anomaly": True, + # "overfit_batches": 100, + "reload_dataloaders_every_n_epochs": 1, + "checkpointing_kwargs": 
{"update_every_n_secs": 1700, "backup_every_n_secs": 3700}, + } + train_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 4, + # "shuffle": True, + "sampler": { + "@type": "SamplerWrapper", + "sampler": { + "@type": "TorchRandomSampler", # Random order across all samples and all datasets + "data_source": { + "@type": "torch.arange", + "end": sum( + [cast(int, settings["num_samples"]) for settings in SOURCE_PATHS.values()] + ), + }, + "replacement": False, + "num_samples": 8000, }, - } - ) - - - os.environ["ZETTA_RUN_SPEC"] = json.dumps(target.spec) - - # _parse_spec_and_train() + }, + "num_workers": 28, + "dataset": training, + "pin_memory": True, + } + val_dataloader_spec = { + "@type": "TorchDataLoader", + "batch_size": 1, + "shuffle": False, + "num_workers": 28, + "dataset": validation, + "pin_memory": True, + } - lightning_train_remote( - worker_cluster_name="zutils-x3", - worker_cluster_region="us-east1", - worker_cluster_project="zetta-research", - worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", - worker_resources={"nvidia.com/gpu": "4"}, - worker_resource_requests={"memory": "27560Mi", "cpu": 28}, + lightning_train( + regime=regime_spec, + trainer=trainer_spec, + train_dataloader=train_dataloader_spec, + val_dataloader=val_dataloader_spec, num_nodes=1, - spec_path=target, - follow_logs=False, + cluster_name="zutils-x3", + cluster_region="us-east1", + cluster_project="zetta-research", + # image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20231106", + image="us.gcr.io/zetta-research/zetta_utils:sergiy_all_p39_x213", + resource_limits={"nvidia.com/gpu": "4"}, + resource_requests={"memory": "27560Mi", "cpu": 28}, + follow_logs=True, env_vars={"LOGLEVEL": "INFO", "NCCL_SOCKET_IFNAME": "eth0"}, + local_run=False, ) diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py index 9866eb564..ae7b18840 100644 --- a/zetta_utils/training/lightning/train.py +++ 
b/zetta_utils/training/lightning/train.py @@ -36,13 +36,12 @@ def distributed_available() -> bool: @builder.register("lightning_train") @typeguard.typechecked def lightning_train( - regime: pl.LightningModule, - trainer: pl.Trainer, - train_dataloader: torch.utils.data.DataLoader, - val_dataloader: Optional[torch.utils.data.DataLoader] = None, + regime: pl.LightningModule | dict[str, Any], + trainer: pl.Trainer | dict[str, Any], + train_dataloader: torch.utils.data.DataLoader | dict[str, Any], + val_dataloader: Optional[torch.utils.data.DataLoader | dict[str, Any]] = None, full_state_ckpt_path: str = "last", num_nodes: int = 1, - nproc_per_node: int = 1, retry_count: int = 3, local_run: bool = True, follow_logs: bool = True, @@ -70,7 +69,6 @@ def lightning_train( than a model checkpoint. If ``full_state_ckpt_path=="last"``, the latest checkpoint for the given experiment will be identified and loaded. :param num_nodes: Number of GPU nodes for distributed training. - :param nproc_per_node: Number of GPU workers per node. :param retry_count: Max retry count for the master train job; excludes failures due to pod distruptions. :param local_run: If True run the training locally. @@ -84,13 +82,31 @@ def lightning_train( :param resource_limits: K8s reource limits per pod. :param resource_requests: K8s resource requests per pod. 
""" + args_mapping = { + "regime": regime, + "trainer": trainer, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + } if local_run: - _lightning_train_local(regime, trainer, train_dataloader, val_dataloader=val_dataloader) + _lightning_train_local( + regime=regime if not isinstance(regime, dict) else builder.build(regime), + trainer=trainer if not isinstance(trainer, dict) else builder.build(trainer), + train_dataloader=train_dataloader + if not isinstance(train_dataloader, dict) + else builder.build(train_dataloader), + val_dataloader=val_dataloader + if not isinstance(val_dataloader, dict) + else builder.build(val_dataloader), + full_state_ckpt_path=full_state_ckpt_path, + ) return - assert image is not None, "Must provide a container image for remote training." - assert resource_limits is not None, "Must provide resource limits for remote training." + if image is None: + raise ValueError("Must provide a container image for remote training.") + if resource_limits is None: + raise ValueError("Must provide resource limits for remote training.") execution_id = mazepa.id_generation.get_unique_id( prefix="exec", slug_len=4, add_uuid=False, max_len=50 @@ -102,26 +118,36 @@ def lightning_train( cluster_project=cluster_project, ) - train_spec = { - "@type": "lightning_train", - "regime": builder.get_initial_builder_spec(regime), - "trainer": builder.get_initial_builder_spec(trainer), - "train_dataloader": builder.get_initial_builder_spec(train_dataloader), - "val_dataloader": builder.get_initial_builder_spec(val_dataloader), + args_mapping = { + "regime": regime, + "trainer": trainer, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + } + + train_args: dict = { "full_state_ckpt_path": full_state_ckpt_path, } - for _key in ["regime", "trainer", "train_dataloader"]: - assert train_spec[_key] is not None, f"{_key} requires builder compatible spec." 
+ for k, v in args_mapping.items(): + if isinstance(v, dict): + # argument given as spec, use it directly + train_args[k] = v + else: + arg_spec = builder.get_initial_builder_spec(v) + if arg_spec is None: + raise RuntimeError( + f"No builder spec found for {k}. Remote training requires arguments to " + ) + train_args[k] = arg_spec _lightning_train_remote( execution_id, cluster_info=cluster_info, image=image, num_nodes=num_nodes, - nproc_per_node=nproc_per_node, retry_count=retry_count, - train_spec=train_spec, + train_args=train_args, env_vars=env_vars, follow_logs=follow_logs, host_network=num_nodes > 1, @@ -130,9 +156,9 @@ def lightning_train( ) -@builder.register("multinode_train_launch") +@builder.register("_multinode_train_launch") @typeguard.typechecked -def multinode_train_launch( +def _multinode_train_launch( execution_id: str, num_nodes: int, nproc_per_node: int, @@ -152,6 +178,8 @@ def multinode_train_launch( torch_launcher_api.elastic_launch(config, _parse_spec_and_train)() +@builder.register("_lightning_train_local") +@typeguard.typechecked def _lightning_train_local( regime: pl.LightningModule, trainer: pl.Trainer, @@ -184,19 +212,20 @@ def _lightning_train_local( def _parse_spec_and_train(): load_all_modules() - train_spec = None + train_args = None with open(os.environ["ZETTA_RUN_SPEC_PATH"], "r", encoding="utf-8") as f: - train_spec = json.load(f) + train_args = json.load(f) + logger.info(train_args) - regime = builder.build(spec=train_spec["regime"]) - trainer = builder.build(spec=train_spec["trainer"]) - train_dataloader = builder.build(spec=train_spec["train_dataloader"]) + regime = builder.build(spec=train_args["regime"]) + trainer = builder.build(spec=train_args["trainer"]) + train_dataloader = builder.build(spec=train_args["train_dataloader"]) try: - val_dataloader = builder.build(spec=train_spec["val_dataloader"]) + val_dataloader = builder.build(spec=train_args["val_dataloader"]) except KeyError: val_dataloader = None try: - 
full_state_ckpt_path = builder.build(spec=train_spec["full_state_ckpt_path"]) + full_state_ckpt_path = train_args["full_state_ckpt_path"] except KeyError: full_state_ckpt_path = "last" _lightning_train_local(regime, trainer, train_dataloader, val_dataloader, full_state_ckpt_path) @@ -251,9 +280,8 @@ def _lightning_train_remote( cluster_info: resource_allocation.k8s.ClusterInfo, image: str, num_nodes: int, - nproc_per_node: int, retry_count: int, - train_spec: dict, + train_args: dict, env_vars: Optional[Dict[str, str]] = None, follow_logs: Optional[bool] = False, host_network: Optional[bool] = False, @@ -265,13 +293,31 @@ def _lightning_train_remote( Creates a volume mount for `train.cue` in `/opt/zetta_utils/specs`. Runs the command `zetta run specs/train.cue` on one or more worker pods. """ + if train_args["trainer"]["accelerator"] == "gpu": + num_devices = int(resource_limits["nvidia.com/gpu"]) # type: ignore + trainer_devices = train_args["trainer"]["devices"] + if ( + isinstance(trainer_devices, int) + and trainer_devices != -1 + and trainer_devices != num_devices + ): + raise ValueError( + f"Trainer specification uses {trainer_devices} devices, " + f"while `nvidia.com/gpu` limit is {num_devices}." 
+ ) + else: + raise NotImplementedError() if num_nodes > 1: - train_spec["@type"] = "multinode_train_launch" - train_spec["execution_id"] = execution_id - train_spec["num_nodes"] = num_nodes - train_spec["nproc_per_node"] = nproc_per_node - train_spec["trainer"]["num_nodes"] = num_nodes + train_args["execution_id"] = execution_id + train_args["num_nodes"] = num_nodes + train_args["nproc_per_node"] = num_devices + + train_args["trainer"]["num_nodes"] = num_nodes + train_spec = {"@type": "_multinode_train_launch", **train_args} + else: + train_spec = {"@type": "_lightning_train_local", **train_args} + specs = {"train": train_spec} vol, mount, spec_ctx = _spec_configmap_vol_and_ctx(execution_id, cluster_info, specs) secrets, env_secret_mapping = resource_allocation.k8s.get_secrets_and_mapping( @@ -319,7 +365,9 @@ def _lightning_train_remote( # that remains the same for the duration of training # these pools must have `master-pool=true` taint # not ncessary for single node ddp so it can be scheduled on preemptibles - tolerations=_get_tolerations(role="master" if num_nodes > 1 else "worker"), + tolerations=_get_tolerations( + "worker" + ), # _get_tolerations(role="master" if num_nodes > 1 else "worker"), volumes=volumes, volume_mounts=mounts, resource_requests=resource_requests,