Commit
regimes: update encoder+coarsener, deprecate old regimes
nkemnitz committed Sep 6, 2023
1 parent e9e52a3 commit 73cb390
Showing 8 changed files with 755 additions and 137 deletions.
335 changes: 335 additions & 0 deletions zetta_utils/training/lightning/regimes/alignment/base_coarsener.py
@@ -0,0 +1,335 @@
# pragma: no cover
# pylint: disable=too-many-locals

from math import log2
from typing import Optional

import attrs
import cc3d
import numpy as np
import pytorch_lightning as pl
import torch
import torchfields
import wandb
from PIL import Image as PILImage
from pytorch_lightning import seed_everything

from zetta_utils import builder, distributions, tensor_ops, viz


@builder.register("BaseCoarsenerRegime")
@attrs.mutable(eq=False)
class BaseCoarsenerRegime(pl.LightningModule): # pylint: disable=too-many-ancestors
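    """Trains an encoder that downsamples images by ``ds_factor`` while
    producing embeddings suitable for alignment.

    The MetroEM-style loss combines a ``pre`` term that pulls encodings of
    identically warped source/target pairs together, a ``post`` term
    (subtracted) that pushes encodings apart under a known small residual
    displacement, and an ``equi`` term that enforces equivariance of the
    encoder under random affine warps.
    """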
model: torch.nn.Module
lr: float
train_log_row_interval: int = 200
val_log_row_interval: int = 25
field_magn_thr: float = 1
max_displacement_px: float = 16.0
post_weight_start_step: int = 0
post_weight_end_step: int = 0
post_weight_start_val: float = 1.5
post_weight_end_val: float = 1.5
zero_value: float = 0
worst_val_loss: float = attrs.field(init=False, default=0)
worst_val_sample: dict = attrs.field(init=False, factory=dict)
worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None)

equivar_weight: float = 1.0
equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360)
equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10)
equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10)
equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1)
ds_factor: int = 2
empty_tissue_threshold: float = 0.4
_training_step: int = attrs.field(init=False, default=0)

def __attrs_pre_init__(self):
super().__init__()

def __attrs_post_init__(self):
# Maybe figure out ds_factor by running the model
pass

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer

@staticmethod
def log_results(mode: str, title_suffix: str = "", **kwargs):
images = []
        for k, v in kwargs.items():
            for b in range(1):  # log only the first sample of the batch
                if v.dtype in (np.uint8, torch.uint8):
                    img = v[b].squeeze()
                elif v.dtype in (torch.int8, np.int8):
                    img = v[b].squeeze().byte() + 127
                elif v.dtype in (torch.bool, bool):
                    img = v[b].squeeze().byte() * 255
                else:
                    v_min = v[b].min().round(decimals=4)
                    v_max = v[b].max().round(decimals=4)
                    images.append(
                        wandb.Image(
                            viz.rendering.Renderer()(v[b].squeeze()),
                            caption=f"{k}_b{b} | min: {v_min} | max: {v_max}",
                        )
                    )
                    continue
                # Pin the corner pixels to both extremes so the renderer
                # normalizes every integer/bool image to the same range.
                img[-1, -1] = 255
                img[-2, -2] = 255
                img[-1, -2] = 0
                img[-2, -1] = 0
                images.append(
                    wandb.Image(
                        PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"),
                        caption=f"{k}_b{b}",
                    )
                )

wandb.log({f"results/{mode}_{title_suffix}_slider": images})

    def on_validation_epoch_start(self):  # pylint: disable=no-self-use
seed_everything(42)

def validation_epoch_end(self, _):
self.log_results(
"val",
"worst",
**self.worst_val_sample,
)
self.worst_val_loss = 0
self.worst_val_sample = {}
self.worst_val_sample_idx = None
seed_everything(None)

def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ
log_row = batch_idx % self.train_log_row_interval == 0

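        # Reuse torchfields' cached identity grids for the field ops below,
        # instead of rebuilding them on every call.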
with torchfields.set_identity_mapping_cache(True, clear_cache=False):
loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row)

return loss

def _get_warped(self, img, field=None):
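        """Warp ``img`` by ``field`` (given in pixels) and recompute its zero
        masks; if ``field`` is None, the image is left unwarped.

        Connected components on the padded zero mask separate border-connected
        zeros (resin) from enclosed zeros (somas). Returns the warped image, a
        mask of all non-zero pixels, and a looser tissue mask that keeps
        masking resin but treats enclosed zero regions as tissue.
        """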
img_padded = torch.nn.functional.pad(img, (1, 1, 1, 1), value=self.zero_value)
if field is not None:
img_warped = field.from_pixels()(img)
else:
img_warped = img

zeros_padded = img_padded == self.zero_value
zeros_padded_cc = np.array(
[
cc3d.connected_components(
x.detach().squeeze().cpu().numpy(), connectivity=4
).reshape(zeros_padded[0].shape)
for x in zeros_padded
]
)

non_tissue_zeros_padded = zeros_padded.clone()
non_tissue_zeros_padded[
torch.tensor(zeros_padded_cc != zeros_padded_cc.ravel()[0], device=zeros_padded.device)
] = False # keep masking resin, restore somas in center

if field is not None:
zeros_warped = (
torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate")
.from_pixels()
.sample((~zeros_padded).float(), padding_mode="border")
<= 0.1
)
non_tissue_zeros_warped = (
torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate")
.from_pixels()
.sample((~non_tissue_zeros_padded).float(), padding_mode="border")
<= 0.1
)
else:
zeros_warped = zeros_padded
non_tissue_zeros_warped = non_tissue_zeros_padded

zeros_warped = torch.nn.functional.pad(zeros_warped, (-1, -1, -1, -1))
non_tissue_zeros_warped = torch.nn.functional.pad(
non_tissue_zeros_warped, (-1, -1, -1, -1)
)

img_warped[zeros_warped] = self.zero_value
return img_warped, ~zeros_warped, ~non_tissue_zeros_warped

def _down_zeros_mask(self, zeros_mask, count=1):
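        """Downsample a boolean mask ``count`` times, keeping a pixel True
        only if it is (nearly) fully covered by True inputs."""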
scale_factor = 0.5 ** count
return (
torch.nn.functional.interpolate(
zeros_mask.float(), scale_factor=scale_factor, mode="bilinear"
)
> 0.99
        )

def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""):
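        """Compute the combined pre/post/equivariance loss for one batch.

        Returns ``None`` when the sample exceeds ``empty_tissue_threshold``
        or when too few pixels survive the post-loss masks.
        """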
src = batch["images"]["src"]
tgt = batch["images"]["tgt"]

if (
(src == self.zero_value) + (tgt == self.zero_value)
).bool().sum() / src.numel() > self.empty_tissue_threshold:
return None

seed_field = batch["field"].field_()
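        # Two warps derived from the same random field: a large one that
        # simulates gross misalignment, and a small residual one whose median
        # per-pixel magnitude equals `field_magn_thr` at the downsampled
        # output resolution.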
f_warp_large = seed_field * self.max_displacement_px
f_warp_small = (
seed_field
* self.field_magn_thr
* self.ds_factor
/ torch.quantile(seed_field.abs().max(1)[0], 0.5)
)

f_aff = (
tensor_ops.transform.get_affine_field(
size=src.shape[-1],
rot_deg=self.equivar_rot_deg_distr(),
scale=self.equivar_scale_distr(),
shear_x_deg=self.equivar_shear_deg_distr(),
shear_y_deg=self.equivar_shear_deg_distr(),
trans_x_px=self.equivar_trans_px_distr(),
trans_y_px=self.equivar_trans_px_distr(),
)
.pixels()
.to(seed_field.device)
).repeat_interleave(src.size(0), dim=0)
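        # f1 combines the random affine with the large warp; f2 additionally
        # folds in the small residual warp.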
f1_trans = f_aff.from_pixels()(f_warp_large.from_pixels()).pixels()
f2_trans = f_warp_small.from_pixels()(f1_trans.from_pixels()).pixels()

magn_field = f_warp_small

src_f1, _, src_nonzeros_f1 = self._get_warped(src, f1_trans)
src_f2, _, src_nonzeros_f2 = self._get_warped(src, f2_trans)
tgt_f1, _, tgt_nonzeros_f1 = self._get_warped(tgt, f1_trans)

src_zeros_f1 = ~self._down_zeros_mask(src_nonzeros_f1, count=int(log2(self.ds_factor)))
src_zeros_f2 = ~self._down_zeros_mask(src_nonzeros_f2, count=int(log2(self.ds_factor)))
tgt_zeros_f1 = ~self._down_zeros_mask(tgt_nonzeros_f1, count=int(log2(self.ds_factor)))

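        # Encode both the raw and the warped source; the equivariance term
        # compares warping-then-encoding against encoding-then-warping.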
src_enc = self.model(src)
src_f1_enc = self.model(src_f1)

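        # Warp the raw encoding by f1 at output resolution: pad, downsample
        # the field by `ds_factor`, sample the zero-padded encoding, and crop.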
f_pad = self.ds_factor
src_enc_f1 = torch.nn.functional.pad(
src_enc, (1, 1, 1, 1), value=0.0
) # TanH! - fill with output zero value
src_enc_f1 = (
torch.nn.functional.pad(
f1_trans, (f_pad, f_pad, f_pad, f_pad), mode="replicate" # type: ignore
)
.from_pixels()
.down(int(log2(self.ds_factor)))
.sample(src_enc_f1, padding_mode="border")
)
src_enc_f1 = torch.nn.functional.pad(src_enc_f1, (-1, -1, -1, -1), value=0.0)

        equi_diff = (src_enc_f1 - src_f1_enc).abs()
        equi_loss = equi_diff.sum() / equi_diff.size(0)
equi_diff_map = equi_diff.clone()
equi_diff_map[src_zeros_f1] = 0

src_f2_enc = self.model(src_f2)
tgt_f1_enc = self.model(tgt_f1)

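        # Pre-alignment term: encodings of src and tgt under the same warp
        # should agree wherever both contain tissue.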
pre_diff = (src_f1_enc - tgt_f1_enc).abs()

pre_tissue_mask = ~(tgt_zeros_f1 | src_zeros_f1)
pre_loss = pre_diff[..., pre_tissue_mask].sum() / pre_diff.size(0)
pre_diff_masked = pre_diff.clone()
pre_diff_masked[..., pre_tissue_mask == 0] = 0

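        # Post-displacement term: where the residual warp exceeds
        # `field_magn_thr`, encodings of the displaced src and tgt should
        # differ; this term is subtracted from the total loss.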
post_tissue_mask = ~(tgt_zeros_f1 | src_zeros_f2)
post_magn_mask = (
(
magn_field.from_pixels()
.down(int(log2(self.ds_factor)))
.pixels()
.abs()
.max(1, keepdim=True)[0]
)
> self.field_magn_thr
).tensor_()

post_diff_map = (src_f2_enc - tgt_f1_enc).abs()
post_mask = post_magn_mask * post_tissue_mask
if post_mask.sum() < (256 // (2 * self.ds_factor)):
return None

post_loss = post_diff_map[..., post_mask].sum() / post_diff_map.size(0)

post_diff_masked = post_diff_map.clone()
post_diff_masked[..., post_mask == 0] = 0

if mode == "train":
self._training_step += 1

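        # Linearly ramp the post-loss weight from `post_weight_start_val` to
        # `post_weight_end_val` between the configured start and end steps.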
post_weight_ratio = min(
1,
max(0, self._training_step - self.post_weight_start_step)
/ max(1, self.post_weight_end_step - self.post_weight_start_step),
)
post_weight = (
post_weight_ratio * self.post_weight_end_val
+ (1.0 - post_weight_ratio) * self.post_weight_start_val
)

loss = pre_loss - post_loss * post_weight + equi_loss * self.equivar_weight
        self.log("param/post_weight", post_weight, on_step=True, on_epoch=True)
self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True)
self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True)
self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True)
self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True)
if log_row:
self.log_results(
mode,
sample_name,
src=src,
src_enc=src_enc,
src_f1=src_f1,
src_enc_f1=src_enc_f1,
src_f1_enc=src_f1_enc,
src_f2_enc=src_f2_enc,
tgt_f1=tgt_f1,
tgt_f1_enc=tgt_f1_enc,
field=seed_field.tensor_(),
equi_diff_map=equi_diff_map,
post_diff_masked=post_diff_masked,
pre_diff_masked=pre_diff_masked,
)

return loss

def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ
log_row = batch_idx % self.val_log_row_interval == 0
sample_name = f"{batch_idx // self.val_log_row_interval}"

with torchfields.set_identity_mapping_cache(True, clear_cache=False):
loss = self.compute_metroem_loss(
batch=batch, mode="val", log_row=log_row, sample_name=sample_name
)
return loss