Skip to content

Commit

Permalink
spec: base encoder spec
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Sep 6, 2023
1 parent 73cb390 commit 315a580
Show file tree
Hide file tree
Showing 3 changed files with 323 additions and 12 deletions.
310 changes: 310 additions & 0 deletions specs/nico/training/em_encoder/train/01_m3_m3_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
from __future__ import annotations

import json
import os
from functools import partial

import torch
from imgaug import augmenters as iaa
from torch.utils.data import DataLoader as TorchDataLoader

from zetta_utils.api.v0 import *

POST_WEIGHT = 1.6
FIELD_MAGN_THR = 0.8
LR = 1e-4
EQUI_WEIGHT = 0.5
EXP_NAME = "general_encoder_debug"
TRAINING_ROOT = "gs://zetta-research-nico/training_artifacts"

EXP_VERSION = f"0.0.0_M3_M3_lr{LR}_equi{EQUI_WEIGHT}_fmt{FIELD_MAGN_THR}"

START_EXP_VERSION = None
MODEL_CKPT = None # f"gs://zetta-research-nico/training_artifacts/base_enc_zfish/{#START_EXP_VERSION}/last.ckpt"

VALIDATION_SRC_PATH = "gs://zetta-research-nico/pairs_dsets/zfish_x1/img_pairwise/-1"
VALIDATION_TGT_PATH = "gs://zetta-research-nico/pairs_dsets/zfish_x1/img_aligned"

SOURCE_PATHS = {
"microns_pinky": {"contiguous": True, "resolution": [32, 32, 40]},
"microns_basil": {"contiguous": True, "resolution": [32, 32, 40]},
"microns_minnie": {"contiguous": False, "resolution": [32, 32, 40]},
"microns_interneuron": {"contiguous": False, "resolution": [32, 32, 40]},
"aibs_v1dd": {"contiguous": False, "resolution": [38.8, 38.8, 45]},
"kim_n2da": {"contiguous": True, "resolution": [32, 32, 50]},
"kim_pfc2022": {"contiguous": True, "resolution": [32, 32, 40]},
"kronauer_cra9": {"contiguous": True, "resolution": [32, 32, 42]},
"kubota_001": {"contiguous": True, "resolution": [40, 40, 40]},
"lee_fanc": {"contiguous": False, "resolution": [34.4, 34.4, 45]},
"lee_banc": {"contiguous": False, "resolution": [32, 32, 45]},
"lee_ppc": {"contiguous": True, "resolution": [32, 32, 40]},
"lee_mosquito": {"contiguous": False, "resolution": [32, 32, 40]},
"lichtman_zebrafish": {"contiguous": False, "resolution": [32, 32, 30]},
"prieto_godino_larva": {"contiguous": True, "resolution": [32, 32, 32]},
"fafb_v15": {"contiguous": False, "resolution": [32, 32, 40]},
"lichtman_h01": {"contiguous": False, "resolution": [32, 32, 33]},
"janelia_hemibrain": {"contiguous": True, "resolution": [32, 32, 32]},
"janelia_manc": {"contiguous": False, "resolution": [32, 32, 32]},
"nguyen_thomas_2022": {"contiguous": True, "resolution": [32, 32, 40]},
"mulcahy_2022_16h": {"contiguous": True, "resolution": [32, 32, 30]},
"wildenberg_2021_vta_dat12a": {"contiguous": True, "resolution": [32, 32, 40]},
"bumbarber_2013": {"contiguous": True, "resolution": [31.2, 31.2, 50]},
"wilson_2019_p3": {"contiguous": True, "resolution": [32, 32, 30]},
"ishibashi_2021_em1": {"contiguous": True, "resolution": [32, 32, 32]},
"ishibashi_2021_em2": {"contiguous": True, "resolution": [32, 32, 32]},
"templier_2019_wafer1": {"contiguous": True, "resolution": [32, 32, 50]},
"templier_2019_wafer3": {"contiguous": True, "resolution": [32, 32, 50]},
"lichtman_octopus2022": {"contiguous": True, "resolution": [32, 32, 30]},
}

BASE_PATH = "gs://zetta-research-nico/encoder/"


val_img_aug = [
partial(rearrange, pattern="c x y 1 -> c x y"),
partial(divide, value=255.0),
]

train_img_aug = [
partial(divide, value=255.0),
partial(
square_tile_pattern_aug,
prob=0.5,
tile_size=uniform_distr(64, 1024),
tile_stride=uniform_distr(64, 1024),
max_brightness_change=uniform_distr(0.0, 0.3),
rotation_degree=uniform_distr(0, 90),
preserve_data_val=0.0,
repeats=1,
device="cpu",
),
partial(torch.clamp, min=0.0, max=1.0),
partial(multiply, value=255.0),
partial(to_uint8),
partial(
imgaug_readproc,
augmenters=[
iaa.SomeOf(
n=2,
children=[
iaa.OneOf(
children=[
iaa.OneOf(
children=[
iaa.GammaContrast(gamma=(0.5, 2.0)),
iaa.SigmoidContrast(gain=(4, 6), cutoff=(0.3, 0.7)),
iaa.LogContrast(gain=(0.7, 1.3)),
iaa.LinearContrast(alpha=(0.4, 1.6)),
]
),
iaa.AllChannelsCLAHE(clip_limit=(0.1, 8.0), tile_grid_size_px=(3, 64)),
]
),
iaa.Add((-40, 40)),
iaa.imgcorruptlike.DefocusBlur(severity=(1, 2)),
iaa.Cutout(
squared=False, nb_iterations=1, size=(0.05, 0.8), cval=(0, 255)
),
iaa.JpegCompression(compression=(0, 35)),
],
random_order=True,
)
],
),
]

shared_train_img_aug = [
partial(
imgaug_readproc,
augmenters=[
iaa.Sequential(
children=[
iaa.Rot90(k=[0, 1, 2, 3]),
iaa.Fliplr(p=0.25),
iaa.Flipud(p=0.25),
],
random_order=True,
),
],
),
partial(rearrange, pattern="c x y 1 -> c x y"),
partial(divide, value=255.0),
]


training = JointDataset(
mode="horizontal",
datasets={
"images": JointDataset(
mode="vertical",
datasets={
name: LayerDataset(
layer=build_layer_set(
{
"src_img": build_cv_layer(
BASE_PATH + "datasets/" + name, read_procs=train_img_aug
),
"tgt_img": build_cv_layer(
BASE_PATH + "pairwise_aligned/" + name + "/warped_img",
read_procs=train_img_aug,
),
},
readonly=True,
read_procs=shared_train_img_aug,
),
sample_indexer=VolumetricNGLIndexer(
path="zetta-research-nico/encoder/pairwise_aligned/" + name,
resolution=settings["resolution"],
chunk_size=[1024, 1024, 1],
),
)
for name, settings in SOURCE_PATHS.items()
},
),
"field": LayerDataset(
layer=build_cv_layer(
"gs://zetta-research-nico/perlin_noise_fields/1px",
read_procs=[
partial(rearrange, pattern="c x y 1 -> c x y"),
],
),
sample_indexer=RandomIndexer(
VolumetricStridedIndexer(
bbox=BBox3D.from_coords(
start_coord=[0, 0, 0], end_coord=[2048, 2048, 2040], resolution=[4, 4, 45]
),
stride=[128, 128, 1],
chunk_size=[1024, 1024, 1],
resolution=[4, 4, 45],
)
),
),
},
)

validation = JointDataset(
mode="horizontal",
datasets={
"images": LayerDataset(
layer=build_layer_set(
{
"src_img": build_cv_layer(VALIDATION_SRC_PATH, read_procs=val_img_aug),
"tgt_img": build_cv_layer(VALIDATION_TGT_PATH, read_procs=val_img_aug),
},
readonly=True,
),
sample_indexer=VolumetricNGLIndexer(
resolution=[32, 32, 30], chunk_size=[1024, 1024, 1], path="nkem/zfish/val"
),
),
"field": LayerDataset(
layer=build_cv_layer(
"gs://zetta-research-nico/perlin_noise_fields/1px",
read_procs=[
partial(rearrange, pattern="c x y 1 -> c x y"),
],
),
sample_indexer=RandomIndexer(
VolumetricStridedIndexer(
bbox=BBox3D.from_coords(
start_coord=[0, 0, 0], end_coord=[2048, 2048, 2040], resolution=[4, 4, 45]
),
stride=[512, 512, 1],
chunk_size=[1024, 1024, 1],
resolution=[4, 4, 45],
)
),
),
},
)


target = partial(
lightning_train,
regime=BaseEncoderRegime(
field_magn_thr=FIELD_MAGN_THR,
post_weight=POST_WEIGHT,
val_log_row_interval=4,
train_log_row_interval=500,
lr=LR,
equivar_weight=EQUI_WEIGHT,
model=load_weights_file(
model=torch.nn.Sequential(
ConvBlock(
num_channels=[1, 32],
kernel_sizes=[1, 1],
activate_last=True,
),
UNet(
list_num_channels=[
[32, 32, 32],
[32, 32, 32],
[32, 32, 32],
[32, 32, 32],
[32, 32, 32],
[32, 32, 32],
[32, 32, 32],
],
downsample=partial(torch.nn.MaxPool2d, kernel_size=2),
upsample=partial(
UpConv,
kernel_size=3,
upsampler=partial(
torch.nn.Upsample,
scale_factor=2,
mode="nearest",
align_corners=None,
),
conv=partial(torch.nn.Conv2d, padding=1),
),
activate_last=True,
kernel_sizes=[3, 3],
padding_modes="reflect",
unet_skip_mode="sum",
skips={"0": 2},
),
torch.nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1),
torch.nn.Tanh(),
)
),
),
trainer=ZettaDefaultTrainer(
accelerator="gpu",
devices=1,
max_epochs=10,
default_root_dir=TRAINING_ROOT,
experiment_name=EXP_NAME,
experiment_version=EXP_VERSION,
log_every_n_steps=100,
val_check_interval=500,
# track_grad_norm=2,
# gradient_clip_algorithm="norm",
# gradient_clip_val=CLIP,
# detect_anomaly=True,
# overfit_batches=100,
checkpointing_kwargs={"update_every_n_secs": 600, "backup_every_n_secs": 900},
),
train_dataloader=TorchDataLoader(
batch_size=1,
shuffle=True,
num_workers=8,
dataset=training,
),
val_dataloader=TorchDataLoader(
batch_size=1,
shuffle=False,
num_workers=8,
dataset=validation,
),
)


os.environ["ZETTA_RUN_SPEC"] = json.dumps("")

execute_on_gcp_with_sqs(
target=target,
worker_image="us.gcr.io/zetta-research/zetta_utils:nico_py3.9_20230905",
worker_resources={"memory": "25560Mi", "nvidia.com/gpu": "1"},
worker_replicas=1,
local_test=False,
)
5 changes: 1 addition & 4 deletions zetta_utils/training/lightning/regimes/alignment/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from . import encoding_coarsener
from . import encoding_coarsener_highres
from . import encoding_coarsener_gen_x1
from . import base_encoder
from . import misalignment_detector
from . import base_coarsener
from . import misalignment_detector_aced
20 changes: 12 additions & 8 deletions zetta_utils/training/lightning/regimes/alignment/base_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import attrs
import cc3d
import einops
import numpy as np
import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -183,14 +184,17 @@ def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_nam
)

f_aff = (
tensor_ops.transform.get_affine_field(
size=src.shape[-1],
rot_deg=self.equivar_rot_deg_distr(),
scale=self.equivar_scale_distr(),
shear_x_deg=self.equivar_shear_deg_distr(),
shear_y_deg=self.equivar_shear_deg_distr(),
trans_x_px=self.equivar_trans_px_distr(),
trans_y_px=self.equivar_trans_px_distr(),
einops.rearrange(
tensor_ops.transform.get_affine_field(
size=src.shape[-1],
rot_deg=self.equivar_rot_deg_distr(),
scale=self.equivar_scale_distr(),
shear_x_deg=self.equivar_shear_deg_distr(),
shear_y_deg=self.equivar_shear_deg_distr(),
trans_x_px=self.equivar_trans_px_distr(),
trans_y_px=self.equivar_trans_px_distr(),
),
"C X Y Z -> Z C X Y",
)
.pixels()
.to(seed_field.device)
Expand Down

0 comments on commit 315a580

Please sign in to comment.