From e149d82422c3271764b66e9b86a6cf53615e4aac Mon Sep 17 00:00:00 2001
From: balbasty
Date: Wed, 16 Oct 2024 14:30:36 +0100
Subject: [PATCH] Fix(align_tpm): use quadratic interpolation in logit space
 instead of linear interpolation in prob space (mimics spm_maff8)

---
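[Reviewer note, outside the commit message: the sketch below illustrates the
resampling scheme this patch switches to, i.e. degree-2 (quadratic) B-spline
interpolation of log-probabilities followed by a softmax, instead of linear
interpolation of the probabilities themselves. The helper name
`warp_tpm_logits`, the `tiny` value, and the unbatched (K, *spatial) layout
(hence `dim=0` in the softmax, where the patched code uses `dim=1`) are
illustrative assumptions, not taken from the patch; the TPM is assumed to be
already normalised across classes, whereas the patch renormalises it first.]

    from nitorch import spatial
    from nitorch.core import math

    def warp_tpm_logits(tpm, phi, tiny=1e-8):
        """Warp a TPM with quadratic splines in logit space (sketch)."""
        # probabilities -> log-probabilities, clamped away from exact 0/1
        logtpm = tpm.clamp(tiny, 1 - tiny).log()
        # spline-prefilter so the degree-2 spline interpolates the nodes exactly
        logtpm = spatial.spline_coeff_nd(logtpm, dim=3,
                                         interpolation=2, bound='replicate')
        # sample log-probabilities at the deformed locations
        mov = spatial.grid_pull(logtpm, phi, interpolation=2,
                                bound='replicate', extrapolate=True)
        # softmax maps the interpolated logits back onto the simplex
        return math.softmax(mov, dim=0)  # dim=0 assumes no batch dimension
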
 nitorch/tools/registration/affine_tpm.py      | 46 +++++++++---
 .../tools/registration/pairwise_preproc.py    | 74 +++++++++++++++----
 2 files changed, 94 insertions(+), 26 deletions(-)

diff --git a/nitorch/tools/registration/affine_tpm.py b/nitorch/tools/registration/affine_tpm.py
index 97cc25c3..ab5bee5c 100644
--- a/nitorch/tools/registration/affine_tpm.py
+++ b/nitorch/tools/registration/affine_tpm.py
@@ -33,7 +33,7 @@
 𝓛 = 𝔼_q[ln p(𝒙)] = ∑ₙᵢ q(𝑥ₙ = 𝑖) ln ∑ⱼ 𝐻ᵢⱼ (𝜇 ∘ 𝜙)ₙⱼ
 """
 import torch
-from nitorch.core import linalg, utils, py
+from nitorch.core import linalg, utils, py, math
 from nitorch import spatial, io
 from .utils import jg, jhj, affine_grid_backward
 import nitorch.plot as niplt
@@ -154,6 +154,18 @@ def align_tpm(dat, tpm=None, weights=None, spacing=(8, 4), device=None,
     # ------------------------------------------------------------------
     dat = discretize(dat, nbins=bins, mask=weights)
 
+    # ------------------------------------------------------------------
+    # PREFILTER TPM
+    # ------------------------------------------------------------------
+    logtpm = tpm.clone()
+    # ensure normalized
+    logtpm = logtpm.clamp(tiny, 1-tiny).div_(logtpm.sum(0, keepdim=True))
+    # transform to logits
+    logtpm = logtpm.add_(tiny).log_()
+    # spline prefilter
+    splineopt = dict(interpolation=2, bound='replicate')
+    logtpm = spatial.spline_coeff_nd(logtpm, dim=3, inplace=True, **splineopt)
+
     # ------------------------------------------------------------------
     # OPTIONS
     # ------------------------------------------------------------------
@@ -175,8 +187,8 @@ def do_spacing(sp):
         if not sp:
             return dat0, affine_dat0, weights0
         sp = [max(1, int(pymath.floor(sp / vx1))) for vx1 in vx]
-        sp = [slice(None, None, sp1) for sp1 in sp]
-        affine_dat, _ = spatial.affine_sub(affine_dat0, dat0.shape[-dim:], tuple(sp))
+        sp = tuple([slice(None, None, sp1) for sp1 in sp])
+        affine_dat, _ = spatial.affine_sub(affine_dat0, dat0.shape[-dim:], sp)
         dat = dat0[(Ellipsis, *sp)]
         if weights0 is not None:
             weights = weights0[(Ellipsis, *sp)]
@@ -234,7 +246,7 @@ def do_spacing(sp):
         if reorient is not None:
             affine_dat = reorient.matmul(affine_dat)
 
-        mi, aff, prm = fit_affine_tpm(dat, tpm, affine_dat, affine_tpm,
+        mi, aff, prm = fit_affine_tpm(dat, logtpm, affine_dat, affine_tpm,
                                       weights, **opt, prm=prm)
 
         if reorient is not None:
@@ -263,7 +275,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
     affine_tpm : (4, 4) tensor
     weights : (*spatial) tensor
     basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
-    fwhm : float, default=J/32
+    fwhm : float, default=J/64
     max_iter_gn : int, default=100
     max_iter_em : int, default=32
     max_line_search : int, default=12
@@ -276,6 +288,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
     prm : (F) tensor
 
     """
+    # !!! NOTE: `tpm` must contain spline-prefiltered log-probabilities
+
     dim = tpm.dim() - 1
 
     # ------------------------------------------------------------------
@@ -326,7 +340,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
         affine_tpm = affine_tpm.to(**utils.backend(tpm))
 
     shape = dat.shape[-dim:]
-    tpm = tpm.to(dat.device).clamp(tiny, 1-tiny)
+    tpm = tpm.to(dat.device)
     basis = make_basis(basis, dim, **utils.backend(tpm))
     F = len(basis)
 
@@ -337,7 +351,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
     em_opt = dict(fwhm=fwhm, max_iter=max_iter_em, weights=weights,
                   verbose=verbose-2)
     drv_opt = dict(weights=weights)
-    pull_opt = dict(bound='replicate', extrapolate=True)
+    pull_opt = dict(bound='replicate', extrapolate=True, interpolation=2)
 
     # ------------------------------------------------------------------
     # OPTIMIZE
@@ -365,6 +379,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
 
             # --- warp TPM ---------------------------------------------
             mov = spatial.grid_pull(tpm, phi, **pull_opt)
+            mov = math.softmax(mov, dim=1)
 
             # --- mutual info ------------------------------------------
             mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
@@ -399,8 +414,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
                 end = '\n' if verbose >= 2 else '\r'
                 print(f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}',
                       end=end)
-            if mi.mean() - mi0.mean() < 1e-4:
-                # print('converged', mi.mean() - mi0.mean())
+            if mi.mean() - mi0.mean() < 0: #1e-4:
+                print('converged', mi.mean() - mi0.mean())
                 break
 
         # --------------------------------------------------------------
@@ -412,8 +427,13 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
         g = g.sum(0)
         h = h.sum(0)
 
-        # --- chain rule -----------------------------------------------
+        # --- spatial derivatives --------------------------------------
+        mov = mov.unsqueeze(-1)
         gmov = spatial.grid_grad(tpm, phi, **pull_opt)
+        gmov = mov * (gmov - (mov * gmov).sum(1, keepdim=True))
+        mov = mov.squeeze(-1)
+
+        # --- chain rule -----------------------------------------------
         gaff = lmdiv(affine_tpm, mm(gaff, affine))
         g, h = chain_rule(g, h, gmov, gaff, maj=False)
         del gmov
@@ -421,7 +441,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
         # --- Gauss-Newton ---------------------------------------------
         h.diagonal(0, -1, -2).add_(h.diagonal(0, -1, -2).abs().max() * 1e-5)
         delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)
-        foo = 0
+
+        plot_registration(dat, mov, f'{basis_name} | {n_iter}')
 
     if verbose == 1:
         print('')
@@ -898,7 +919,8 @@ def discretize(dat, nbins=256, mask=None):
 
 def get_spm_prior(**backend):
     """Download the SPM prior"""
-    url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
+    # url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
+    url = 'https://github.com/spm/spm12/raw/refs/heads/main/tpm/TPM.nii'
     fname = os.path.join(cache_dir, 'SPM12_TPM.nii')
     if not os.path.exists(fname):
         os.makedirs(cache_dir, exist_ok=True)
diff --git a/nitorch/tools/registration/pairwise_preproc.py b/nitorch/tools/registration/pairwise_preproc.py
index 6e90a2d1..6689dcbd 100644
--- a/nitorch/tools/registration/pairwise_preproc.py
+++ b/nitorch/tools/registration/pairwise_preproc.py
@@ -12,7 +12,7 @@
 
 def preproc_image(input, mask=None, label=False, missing=0,
                   world=None, affine=None, rescale=.95,
-                  pad=None, bound='zero', fwhm=None,
+                  pad=None, bound='zero', fwhm=None, channels=None,
                   dim=None, device=None, **kwargs):
     """Load an image and preprocess it as required
 
@@ -43,6 +43,8 @@ def preproc_image(input, mask=None, label=False, missing=0,
     fwhm : [sequence of] float
         Smooth the volume with a Gaussian kernel of that FWHM.
         If last element is "mm", values are in mm and converted to voxels.
+    channels : [sequence of] int or range or slice
+        Channels to load
     dim : int, optional
         Number of spatial dimensions
     device : torch.device
@@ -58,14 +60,30 @@ def preproc_image(input, mask=None, label=False, missing=0,
         Orientation matrix
     """
 
-    if not torch.is_tensor(input):
-        dat, mask0, affine0 = load_image(input, dim=dim, device=device,
-                                         label=label, missing=missing)
-    else:
-        dat = input
-        mask0 = torch.isfinite(dat)
-        dat = dat.masked_fill(~mask0, 0)
-        affine0 = spatial.affine_default(dat.shape[1:])
+    dat, mask0, affine0 = load_image(input, dim=dim, device=device,
+                                     label=label, missing=missing,
+                                     channels=channels)
+
+    # if not torch.is_tensor(input):
+    #     dat, mask0, affine0 = load_image(input, dim=dim, device=device,
+    #                                      label=label, missing=missing,
+    #                                      channels=channels)
+    # else:
+    #     dat = input
+    #     if channels is not None:
+    #         channels = make_list(channels)
+    #         channels = [
+    #             list(c) if isinstance(c, range) else
+    #             list(range(len(dat)))[c] if isinstance(c, slice) else
+    #             c for c in channels
+    #         ]
+    #         if not all([isinstance(c, int) for c in channels]):
+    #             raise ValueError('Channel list should be a list of integers')
+    #         dat = dat[channels]
+    #     mask0 = torch.isfinite(dat)
+    #     dat = dat.masked_fill(~mask0, 0)
+    #     affine0 = spatial.affine_default(dat.shape[1:])
+
     dim = dat.dim() - 1
 
     # load user-defined mask
@@ -199,7 +217,7 @@ def prepare_pyramid_levels(images, levels, dim=None, **opt):
     return pyrutils.pyramid_levels(vxs, shapes, levels, **opt)
 
-def map_image(fnames, dim=None):
+def map_image(fnames, dim=None, channels=None):
     """Map an ND image from disk
 
     Parameters
     ----------
@@ -229,7 +247,6 @@ def map_image(fnames, dim=None):
         affine = img.affine
         if dim is None:
             dim = img.affine.shape[-1] - 1
-        # img = img.fdata(rand=True, device=device)
         if img.dim > dim:
             img = img.movedim(-1, 0)
         else:
@@ -241,10 +258,24 @@ def map_image(fnames, dim=None):
         imgs.append(img)
         del img
     imgs = io.cat(imgs, dim=0)
+
+    # select a subset of channels
+    if channels is not None:
+        channels = make_list(channels)
+        channels = [
+            list(c) if isinstance(c, range) else
+            list(range(len(imgs)))[c] if isinstance(c, slice) else
+            c for c in channels
+        ]
+        if not all([isinstance(c, int) for c in channels]):
+            raise ValueError('Channel list should be a list of integers')
+        imgs = io.stack([imgs[c] for c in channels])
+
     return imgs, affine
 
 
-def load_image(input, dim=None, device=None, label=False, missing=0):
+def load_image(input, dim=None, device=None, label=False, missing=0,
+               channels=None):
     """
     Load a N-D image from disk
 
@@ -272,15 +303,30 @@ def load_image(input, dim=None, device=None, label=False, missing=0):
         Orientation matrix
     """
     if not torch.is_tensor(input):
-        dat, affine = map_image(input, dim)
+        dat, affine = map_image(input, dim, channels=channels)
     else:
         dat, affine = input, spatial.affine_default(input.shape[1:])
+
+    if channels is not None:
+        channels = make_list(channels)
+        channels = [
+            list(c) if isinstance(c, range) else
+            list(range(len(dat)))[c] if isinstance(c, slice) else
+            c for c in channels
+        ]
+        if not all([isinstance(c, int) for c in channels]):
+            raise ValueError('Channel list should be a list of integers')
+        dat = dat[channels]
+
     if label:
         dtype = dat.dtype
         if isinstance(dtype, (list, tuple)):
             dtype = dtype[0]
         dtype = dtypes.as_torch(dtype, upcast=True)
-        dat0 = dat.data(device=device, dtype=dtype)[0]  # assume single channel
+        if torch.is_tensor(dat):
+            dat0 = dat[0]
+        else:
+            dat0 = dat.data(device=device, dtype=dtype)[0]  # assume single channel
         if label is True:
             label = dat0.unique(sorted=True)
             label = label[label != 0].tolist()