From 179ac78a14075bbfbe4b699d40da50fd640b843f Mon Sep 17 00:00:00 2001 From: balbasty Date: Mon, 16 Dec 2024 17:36:42 +0000 Subject: [PATCH] FEAT(cli.registration): memory-efficient chunk-wise processing --- nitorch/cli/registration/reslice/main.py | 124 ++++++++++++++++++--- nitorch/cli/registration/reslice/parser.py | 26 +++-- nitorch/cli/registration/reslice/struct.py | 1 + 3 files changed, 125 insertions(+), 26 deletions(-) diff --git a/nitorch/cli/registration/reslice/main.py b/nitorch/cli/registration/reslice/main.py index cff23ed8..03a42018 100644 --- a/nitorch/cli/registration/reslice/main.py +++ b/nitorch/cli/registration/reslice/main.py @@ -3,6 +3,8 @@ import json import copy import torch +import itertools +import math as pymath from nitorch import io, spatial from nitorch.cli.cli import commands from nitorch.core.cli import ParseError @@ -46,8 +48,8 @@ def reslice(argv=None): except ParseError as e: print(help) print(f'[ERROR] {str(e)}', file=sys.stderr) - except Exception as e: - print(f'[ERROR] {str(e)}', file=sys.stderr) + # except Exception as e: + # print(f'[ERROR] {str(e)}', file=sys.stderr) commands['reslice'] = reslice @@ -261,6 +263,9 @@ def write_volumes(options): backend = dict(dtype=torch.float32, device=options.device) + if options.chunk: + options.chunk = py.make_list(options.chunk, 3) + # 1) Pre-exponentiate velocities for trf in options.transformations: if isinstance(trf, struct.Velocity): @@ -342,7 +347,8 @@ def build_from_target(affine, shape, smart=False): # 3) If target is provided, we can build most of the transform once # and just multiply it with a input-wise affine matrix. if options.target: - grid = build_from_target(options.target.affine, options.target.shape) + if not options.chunk: + grid = build_from_target(options.target.affine, options.target.shape) oaffine = options.target.affine # 4) Loop across input files @@ -359,27 +365,28 @@ def build_from_target(affine, shape, smart=False): ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext) print(f'Reslicing: {file.fname}\n' f' -> {ofname}') + vol = io.volumes.map(file.fname) if is_label: backend_int = dict(dtype=torch.long, device=backend['device']) - dat = io.volumes.load(file.fname, **backend_int) + # dat = io.volumes.load(file.fname, **backend_int) opt_pull = dict(opt_pull0) opt_pull['interpolation'] = 1 else: - dat = io.volumes.loadf(file.fname, rand=False, **backend) + # dat = io.volumes.loadf(file.fname, rand=False, **backend) opt_pull = opt_pull0 if options.channels is not None: channels = py.make_list(options.channels) channels = [ list(c) if isinstance(c, range) else - list(range(dat.shape[-1]))[c] if isinstance(c, slice) else + list(range(vol.shape[-1]))[c] if isinstance(c, slice) else c for c in channels ] if not all([isinstance(c, int) for c in channels]): raise ValueError('Channel list should be a list of integers') - dat = dat[..., channels] - dat = dat.reshape([*file.shape, file.channels]) - dat = utils.movedim(dat, -1, 0) + else: + channels = slice(None) + resample = True if not options.target: oaffine = file.affine oshape = file.shape @@ -388,14 +395,97 @@ def build_from_target(affine, shape, smart=False): dtype=oaffine.dtype) factor = spatial.voxel_size(oaffine) / ovx oaffine, oshape = spatial.affine_resize(oaffine, oshape, factor=factor, anchor='f') - grid = build_from_target(oaffine, oshape, smart=not options.voxel_size) - if grid is not None: - mat = file.affine.to(**backend) - imat = spatial.affine_inv(mat) - if options.prefilter and not is_label: - dat = spatial.spline_coeff_nd(dat, **opt_coeff) - dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull) - dat = utils.movedim(dat, 0, -1) + else: + resample = not all( + isinstance(t, struct.Linear) + for t in options.transformations + ) + else: + oaffine = options.target.affine + oshape = options.target.shape + + # -------------------------------------------------------------- + # chunkwise processing + # -------------------------------------------------------------- + chunk = py.make_list(options.chunk, 3) + if resample and chunk and any(c < x for x, c in zip(file.shape, chunk)): + + odat = utils.movedim(torch.empty( + [*oshape, file.channels], + **(backend_int if is_label else backend) + ), -1, 0) + + ncells = [ + int(pymath.ceil(x/c)) + for x, c in zip(file.shape, chunk) + ] + pcells = [list(range(x)) for x in ncells] + for cell in itertools.product(*pcells): + oslicer = tuple(slice(i*c, (i+1)*c) for i, c in zip(cell, chunk)) + + # output chunk + odatc = odat[(Ellipsis, *oslicer)] + oshapec = odatc.shape[-3:] + oaffinec, _ = spatial.affine_sub(oaffine, oshape, oslicer) + + # grid + grid = build_from_target(oaffinec, oshapec) + mat = file.affine.to(**backend) + imat = spatial.affine_inv(mat) + grid = spatial.affine_matvec(imat, grid) + + mn = grid.reshape([-1, 3]).min(dim=0).values + mx = grid.reshape([-1, 3]).max(dim=0).values + for i in range(3): + mn[i].sub_(5).floor_().clamp_(0, file.shape[i]-1) + mx[i].add_(5).ceil_().clamp_(0, file.shape[i]-1) + grid -= mn + + # input chunk + mn, mx = mn.long(), mx.long() + islicer = tuple( + slice(mn1.item(), mx1.item()+1) + for mn1, mx1 in zip(mn, mx) + ) + idatc = vol[(*islicer, Ellipsis)] + if is_label: + idatc = idatc.data(**backend_int) + else: + idatc = idatc.fdata(rand=False, **backend) + idatc = idatc[..., channels] + ishapec = idatc.shape[:3] + idatc = idatc.reshape([*ishapec, file.channels]) + idatc = utils.movedim(idatc, -1, 0) + + # reslice + if options.prefilter and not is_label: + idatc = spatial.spline_coeff_nd(idatc, **opt_coeff) + idatc = helpers.pull(idatc, grid, **opt_pull) + odatc.copy_(idatc) + + dat = utils.movedim(odat, 0, -1) + + # -------------------------------------------------------------- + # one-shot processing + # -------------------------------------------------------------- + else: + if is_label: + dat = dat.data(**backend_int) + else: + dat = dat.fdata(rand=False, **backend) + dat = dat[..., channels] + dat = dat.reshape([*file.shape, file.channels]) + dat = utils.movedim(dat, -1, 0) + + if not options.target: + grid = build_from_target(oaffine, oshape, smart=not options.voxel_size) + if grid is not None: + mat = file.affine.to(**backend) + imat = spatial.affine_inv(mat) + if options.prefilter and not is_label: + dat = spatial.spline_coeff_nd(dat, **opt_coeff) + dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull) + dat = utils.movedim(dat, 0, -1) if is_label: io.volumes.save(dat, ofname, like=file.fname, affine=oaffine, dtype=options.dtype) diff --git a/nitorch/cli/registration/reslice/parser.py b/nitorch/cli/registration/reslice/parser.py index 64bacbd9..d36e2a30 100644 --- a/nitorch/cli/registration/reslice/parser.py +++ b/nitorch/cli/registration/reslice/parser.py @@ -32,14 +32,15 @@ Other tags are: -t, --target Defines the target space. If not provided, minimal reslicing is performed. - -o, --output Name of the output file (default: '*.resliced*') - -i, --interpolation Interpolation order. Use `l` for labels. (1) - -p, --prefilter Apply spline prefilter (yes) - -b, --bound Boundary conditions (dct2) - -x, --extrapolate Extrapolate out-of-bounds data (no) - -v, --voxel-size Voxel size of the resliced space (default: from target) - -c, --channels Channels to load. Can be a range start:stop:step (default: all) - -dt, --dtype Output data type (default: from input) + -o, --output Name of the output file (default: '*.resliced*') + -i, --interpolation Interpolation order. Use `l` for labels. (1) + -p, --prefilter Apply spline prefilter (yes) + -b, --bound Boundary conditions (dct2) + -x, --extrapolate Extrapolate out-of-bounds data (no) + -vx, --voxel-size Voxel size of the resliced space (from target) + -c, --channels Channels to load. Can be a range start:stop:step (all) + -k, --chunk Process data one chunk--of this size--at a time (no) + -dt, --dtype Output data type (from input) -cpu, -gpu Device to use (cpu) The output image is @@ -176,11 +177,18 @@ def parse(args): bool(int(options.prefilter))) elif tag in ('-x', '-ex', '--extrapolate'): options.extrapolate = True - elif tag in ('-v', '-vx', '--voxel-size'): + elif tag in ('-vx', '--voxel-size'): options.voxel_size = [] while cli.next_isvalue(args): val, *args = args options.voxel_size.append(float(val)) + elif tag in ('-k', '--chunk'): + options.chunk = [] + while cli.next_isvalue(args): + val, *args = args + options.chunk.append(int(val)) + if not options.chunk: + options.chunk = [128] elif tag in ('-c', '--channels'): options.channels = [] while cli.next_isvalue(args): diff --git a/nitorch/cli/registration/reslice/struct.py b/nitorch/cli/registration/reslice/struct.py index a09c724d..bd2c311a 100644 --- a/nitorch/cli/registration/reslice/struct.py +++ b/nitorch/cli/registration/reslice/struct.py @@ -39,6 +39,7 @@ class Reslicer(Structure): transformations: list = [] target: str = None voxel_size: list = None + chunk: list = None output: str = '{dir}/{base}.resliced{ext}' interpolation: int = 1 bound: str = 'dct2'