Skip to content

Commit

Permalink
FEAT(cli.registration): memory-efficient chunk-wise processing
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty committed Dec 16, 2024
1 parent d9c329f commit 179ac78
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 26 deletions.
124 changes: 107 additions & 17 deletions nitorch/cli/registration/reslice/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import json
import copy
import torch
import itertools
import math as pymath
from nitorch import io, spatial
from nitorch.cli.cli import commands
from nitorch.core.cli import ParseError
Expand Down Expand Up @@ -46,8 +48,8 @@ def reslice(argv=None):
except ParseError as e:
print(help)
print(f'[ERROR] {str(e)}', file=sys.stderr)
except Exception as e:
print(f'[ERROR] {str(e)}', file=sys.stderr)
# except Exception as e:
# print(f'[ERROR] {str(e)}', file=sys.stderr)


commands['reslice'] = reslice
Expand Down Expand Up @@ -261,6 +263,9 @@ def write_volumes(options):

backend = dict(dtype=torch.float32, device=options.device)

if options.chunk:
options.chunk = py.make_list(options.chunk, 3)

# 1) Pre-exponentiate velocities
for trf in options.transformations:
if isinstance(trf, struct.Velocity):
Expand Down Expand Up @@ -342,7 +347,8 @@ def build_from_target(affine, shape, smart=False):
# 3) If target is provided, we can build most of the transform once
# and just multiply it with a input-wise affine matrix.
if options.target:
grid = build_from_target(options.target.affine, options.target.shape)
if not options.chunk:
grid = build_from_target(options.target.affine, options.target.shape)
oaffine = options.target.affine

# 4) Loop across input files
Expand All @@ -359,27 +365,28 @@ def build_from_target(affine, shape, smart=False):
ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
print(f'Reslicing: {file.fname}\n'
f' -> {ofname}')
vol = io.volumes.map(file.fname)
if is_label:
backend_int = dict(dtype=torch.long, device=backend['device'])
dat = io.volumes.load(file.fname, **backend_int)
# dat = io.volumes.load(file.fname, **backend_int)
opt_pull = dict(opt_pull0)
opt_pull['interpolation'] = 1
else:
dat = io.volumes.loadf(file.fname, rand=False, **backend)
# dat = io.volumes.loadf(file.fname, rand=False, **backend)
opt_pull = opt_pull0
if options.channels is not None:
channels = py.make_list(options.channels)
channels = [
list(c) if isinstance(c, range) else
list(range(dat.shape[-1]))[c] if isinstance(c, slice) else
list(range(vol.shape[-1]))[c] if isinstance(c, slice) else
c for c in channels
]
if not all([isinstance(c, int) for c in channels]):
raise ValueError('Channel list should be a list of integers')
dat = dat[..., channels]
dat = dat.reshape([*file.shape, file.channels])
dat = utils.movedim(dat, -1, 0)
else:
channels = slice(None)

resample = True
if not options.target:
oaffine = file.affine
oshape = file.shape
Expand All @@ -388,14 +395,97 @@ def build_from_target(affine, shape, smart=False):
dtype=oaffine.dtype)
factor = spatial.voxel_size(oaffine) / ovx
oaffine, oshape = spatial.affine_resize(oaffine, oshape, factor=factor, anchor='f')
grid = build_from_target(oaffine, oshape, smart=not options.voxel_size)
if grid is not None:
mat = file.affine.to(**backend)
imat = spatial.affine_inv(mat)
if options.prefilter and not is_label:
dat = spatial.spline_coeff_nd(dat, **opt_coeff)
dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull)
dat = utils.movedim(dat, 0, -1)
else:
resample = not all(
isinstance(t, struct.Linear)
for t in options.transformations
)
else:
oaffine = options.target.affine
oshape = options.target.shape

# --------------------------------------------------------------
# chunkwise processing
# --------------------------------------------------------------
chunk = py.make_list(options.chunk, 3)
if resample and chunk and any(c < x for x, c in zip(file.shape, chunk)):

odat = utils.movedim(torch.empty(
[*oshape, file.channels],
**(backend_int if is_label else backend)
), -1, 0)

ncells = [
int(pymath.ceil(x/c))
for x, c in zip(file.shape, chunk)
]
pcells = [list(range(x)) for x in ncells]
for cell in itertools.product(*pcells):
oslicer = tuple(slice(i*c, (i+1)*c) for i, c in zip(cell, chunk))

# output chunk
odatc = odat[(Ellipsis, *oslicer)]
oshapec = odatc.shape[-3:]
oaffinec, _ = spatial.affine_sub(oaffine, oshape, oslicer)

# grid
grid = build_from_target(oaffinec, oshapec)
mat = file.affine.to(**backend)
imat = spatial.affine_inv(mat)
grid = spatial.affine_matvec(imat, grid)

mn = grid.reshape([-1, 3]).min(dim=0).values
mx = grid.reshape([-1, 3]).max(dim=0).values
for i in range(3):
mn[i].sub_(5).floor_().clamp_(0, file.shape[i]-1)
mx[i].add_(5).ceil_().clamp_(0, file.shape[i]-1)
grid -= mn

# input chunk
mn, mx = mn.long(), mx.long()
islicer = tuple(
slice(mn1.item(), mx1.item()+1)
for mn1, mx1 in zip(mn, mx)
)
idatc = vol[(*islicer, Ellipsis)]
if is_label:
idatc = idatc.data(**backend_int)
else:
idatc = idatc.fdata(rand=False, **backend)
idatc = idatc[..., channels]
ishapec = idatc.shape[:3]
idatc = idatc.reshape([*ishapec, file.channels])
idatc = utils.movedim(idatc, -1, 0)

# reslice
if options.prefilter and not is_label:
idatc = spatial.spline_coeff_nd(idatc, **opt_coeff)
idatc = helpers.pull(idatc, grid, **opt_pull)
odatc.copy_(idatc)

dat = utils.movedim(odat, 0, -1)

# --------------------------------------------------------------
# one-shot processing
# --------------------------------------------------------------
else:
if is_label:
dat = dat.data(**backend_int)
else:
dat = dat.fdata(rand=False, **backend)
dat = dat[..., channels]
dat = dat.reshape([*file.shape, file.channels])
dat = utils.movedim(dat, -1, 0)

if not options.target:
grid = build_from_target(oaffine, oshape, smart=not options.voxel_size)
if grid is not None:
mat = file.affine.to(**backend)
imat = spatial.affine_inv(mat)
if options.prefilter and not is_label:
dat = spatial.spline_coeff_nd(dat, **opt_coeff)
dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull)
dat = utils.movedim(dat, 0, -1)

if is_label:
io.volumes.save(dat, ofname, like=file.fname, affine=oaffine, dtype=options.dtype)
Expand Down
26 changes: 17 additions & 9 deletions nitorch/cli/registration/reslice/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@
Other tags are:
-t, --target Defines the target space.
If not provided, minimal reslicing is performed.
-o, --output Name of the output file (default: '*.resliced*')
-i, --interpolation Interpolation order. Use `l` for labels. (1)
-p, --prefilter Apply spline prefilter (yes)
-b, --bound Boundary conditions (dct2)
-x, --extrapolate Extrapolate out-of-bounds data (no)
-v, --voxel-size Voxel size of the resliced space (default: from target)
-c, --channels Channels to load. Can be a range start:stop:step (default: all)
-dt, --dtype Output data type (default: from input)
-o, --output Name of the output file (default: '*.resliced*')
-i, --interpolation Interpolation order. Use `l` for labels. (1)
-p, --prefilter Apply spline prefilter (yes)
-b, --bound Boundary conditions (dct2)
-x, --extrapolate Extrapolate out-of-bounds data (no)
-vx, --voxel-size Voxel size of the resliced space (from target)
-c, --channels Channels to load. Can be a range start:stop:step (all)
-k, --chunk Process data one chunk--of this size--at a time (no)
-dt, --dtype Output data type (from input)
-cpu, -gpu Device to use (cpu)
The output image is
Expand Down Expand Up @@ -176,11 +177,18 @@ def parse(args):
bool(int(options.prefilter)))
elif tag in ('-x', '-ex', '--extrapolate'):
options.extrapolate = True
elif tag in ('-v', '-vx', '--voxel-size'):
elif tag in ('-vx', '--voxel-size'):
options.voxel_size = []
while cli.next_isvalue(args):
val, *args = args
options.voxel_size.append(float(val))
elif tag in ('-k', '--chunk'):
options.chunk = []
while cli.next_isvalue(args):
val, *args = args
options.chunk.append(int(val))
if not options.chunk:
options.chunk = [128]
elif tag in ('-c', '--channels'):
options.channels = []
while cli.next_isvalue(args):
Expand Down
1 change: 1 addition & 0 deletions nitorch/cli/registration/reslice/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Reslicer(Structure):
transformations: list = []
target: str = None
voxel_size: list = None
chunk: list = None
output: str = '{dir}/{base}.resliced{ext}'
interpolation: int = 1
bound: str = 'dct2'
Expand Down

0 comments on commit 179ac78

Please sign in to comment.