Skip to content

Commit

Permalink
FIX(register): bug with channels option + FEAT(io/reslice): support s…
Browse files Browse the repository at this point in the history
…treamlines (only TRK for now)
  • Loading branch information
balbasty committed Dec 16, 2024
1 parent 62b17f6 commit d9c329f
Show file tree
Hide file tree
Showing 12 changed files with 700 additions and 10 deletions.
146 changes: 139 additions & 7 deletions nitorch/cli/registration/reslice/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import os
import json
import copy
import torch
from nitorch import io, spatial
from nitorch.cli.cli import commands
Expand All @@ -26,8 +27,21 @@ def reslice(argv=None):
return

read_info(options)
collapse(options)
write_data(options)

options_volumes = options
options_streamlines = copy.deepcopy(options)
options_streamlines.transformations \
= invert_transforms(options_streamlines.transformations)

if options_volumes.volumes:
collapse(options_volumes)
write_volumes(options_volumes)
del options_volumes

if options_streamlines.streamlines:
collapse(options_streamlines)
write_streamlines(options_streamlines)
del options_streamlines

except ParseError as e:
print(help)
Expand Down Expand Up @@ -81,7 +95,17 @@ def read_info(options):
"""Load affine transforms and space info of other volumes"""

def read_file(fname):
f = io.map(fname)
if isinstance(f, io.MappedArray):
return read_volume(fname, f)
elif isinstance(f, io.MappedStreamlines):
return read_streamlines(fname, f)
else:
raise TypeError(type(f))

def read_streamlines(fname, f=None):
o = struct.FileWithInfo()
o.type = "streamlines"
o.fname = fname
o.dir = os.path.dirname(fname) or '.'
o.base = os.path.basename(fname)
Expand All @@ -90,7 +114,22 @@ def read_file(fname):
zext = o.ext
o.base, o.ext = os.path.splitext(o.base)
o.ext += zext
f = io.volumes.map(fname)
f = f or io.streamlines.map(f)
o.affine = f.affine.float()
return o

def read_volume(fname, f=None):
o = struct.FileWithInfo()
o.type = "volume"
o.fname = fname
o.dir = os.path.dirname(fname) or '.'
o.base = os.path.basename(fname)
o.base, o.ext = os.path.splitext(o.base)
if o.ext in ('.gz', '.bz2'):
zext = o.ext
o.base, o.ext = os.path.splitext(o.base)
o.ext += zext
f = f or io.volumes.map(fname)
o.float = nitype(f.dtype).is_floating_point
o.shape = squeeze_to_nd(f.shape, dim=3, channels=1)
o.channels = o.shape[-1]
Expand Down Expand Up @@ -119,6 +158,12 @@ def read_field(fname):
return f.affine.float(), f.shape[:3]

options.files = [read_file(file) for file in options.files]
options.streamlines = [
file for file in options.files if file.type == "streamlines"
]
options.volumes = [
file for file in options.files if file.type == "volume"
]
for trf in options.transformations:
if isinstance(trf, struct.Linear):
trf.affine = read_affine(trf.file)
Expand Down Expand Up @@ -205,8 +250,14 @@ def exponentiate_transforms(transformations, **backend):
return transformations


def invert_transforms(transformations):
transformations = [copy.deepcopy(trf) for trf in reversed(transformations)]
for trf in transformations:
trf.inv = not trf.inv
return transformations


def write_data(options):
def write_volumes(options):

backend = dict(dtype=torch.float32, device=options.device)

Expand Down Expand Up @@ -254,7 +305,7 @@ def write_data(options):
if (options.transformations and
isinstance(options.transformations[0], struct.Linear)):
trf = options.transformations[0]
for file in options.files:
for file in options.volumes:
mat = file.affine.to(**backend)
aff = trf.affine.to(**backend)
file.affine = spatial.affine_lmdiv(aff, mat)
Expand Down Expand Up @@ -302,8 +353,8 @@ def build_from_target(affine, shape, smart=False):
bound=options.bound,
dim=3,
inplace=True)
output = py.make_list(options.output, len(options.files))
for file, ofname in zip(options.files, output):
output = py.make_list(options.output, len(options.volumes))
for file, ofname in zip(options.volumes, output):
is_label = isinstance(options.interpolation, str) and options.interpolation == 'l'
ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
print(f'Reslicing: {file.fname}\n'
Expand Down Expand Up @@ -350,3 +401,84 @@ def build_from_target(affine, shape, smart=False):
io.volumes.save(dat, ofname, like=file.fname, affine=oaffine, dtype=options.dtype)
else:
io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine, dtype=options.dtype)


def write_streamlines(options):

backend = dict(dtype=torch.float32, device=options.device)

# 1) Pre-exponentiate velocities
for trf in options.transformations:
if isinstance(trf, struct.Velocity):
f = io.volumes.map(trf.file)
trf.affine = f.affine
trf.shape = squeeze_to_nd(f.shape, 3, 1)
trf.dat = f.fdata(**backend).reshape(trf.shape)
trf.shape = trf.shape[:3]
if trf.json:
if trf.square:
trf.dat.mul_(0.5)
with open(trf.json) as f:
prm = json.load(f)
prm['voxel_size'] = spatial.voxel_size(trf.affine)
trf.dat = spatial.shoot(trf.dat[None], displacement=True,
return_inverse=trf.inv)
if trf.inv:
trf.dat = trf.dat[-1]
else:
if trf.square:
trf.dat.mul_(0.5)
trf.dat = spatial.exp(trf.dat[None], displacement=True,
inverse=trf.inv)
trf.dat = trf.dat[0] # drop batch dimension
trf.inv = False
trf.square = False
trf.order = 1
elif isinstance(trf, struct.Displacement):
f = io.volumes.map(trf.file)
trf.affine = f.affine
trf.shape = squeeze_to_nd(f.shape, 3, 1)
trf.dat = f.fdata(**backend).reshape(trf.shape)
trf.shape = trf.shape[:3]
if trf.unit == 'mm':
# convert mm displacement to vox displacement
trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
trf.dat = trf.dat[..., 0]
trf.unit = 'vox'

def deform_streamlines(streamlines):
"""Compose all transformations, starting from the final orientation"""
for trf in reversed(options.transformations):
if isinstance(trf, struct.Linear):
streamlines = spatial.affine_matvec(trf.affine.to(streamlines), streamlines)
else:
mat = trf.affine.to(**backend)
if trf.inv:
disp = spatial.grid_inv(trf.dat, type='disp')
order = 1
else:
disp = trf.dat
order = trf.order
imat = spatial.affine_inv(mat)
streamlines = spatial.affine_matvec(imat.to(streamlines), streamlines)
streamlines = streamlines + helpers.pull_grid(disp, streamlines[None, None], interpolation=order)[0, 0]
streamlines = spatial.affine_matvec(mat.to(streamlines), streamlines)
return streamlines

# 4) Loop across input files
output = py.make_list(options.output, len(options.streamlines))
for file, ofname in zip(options.streamlines, output):
ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
print(f'Reslicing: {file.fname}\n'
f' -> {ofname}')
dat = list(io.streamlines.loadf(file.fname, **backend))
offsets = py.cumsum(map(len, dat), exclusive=True)
dat = torch.cat(list(dat))

dat = deform_streamlines(dat)
dat = [
dat[offsets[i]:(offsets[i+1] if i+1 < len(offsets) else None)]
for i in range(len(offsets))
]

io.streamlines.savef(dat, ofname, like=file.fname)
2 changes: 2 additions & 0 deletions nitorch/cli/registration/reslice/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
nitorch reslice *FILES <*TRF> FILE [-t FILE] [-o *FILE]
[-i ORDER] [-b BND] [-p] [-x] [-cpu|gpu]
<FILES> can be paths to volumes or to streamlines.
<TRF> can take values (with additional options):
-l, --linear Linear transform (i.e., affine matrix)
-d, --displacement Dense or free-form displacement field
Expand Down
1 change: 1 addition & 0 deletions nitorch/cli/registration/reslice/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class FileWithInfo(Structure):
ext: str = None # Extension
channels: int = None # Number of channels
float: bool = True # Is raw dtype floating point
type: str = None


class Transform(Structure):
Expand Down
2 changes: 2 additions & 0 deletions nitorch/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,12 @@
from . import metadata
from . import optionals
from . import readers
from . import streamlines
from . import transforms
from . import utils
from . import writers

from .volumes import MappedArray, CatArray, cat, stack
from .transforms import MappedAffine
from .streamlines import MappedStreamlines
from .loadsave import map, load, loadf, save, savef
10 changes: 10 additions & 0 deletions nitorch/io/streamlines/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from . import loadsave
from . import mapping
from . import readers
from . import writers

from .mapping import MappedStreamlines
from .loadsave import map, load, loadf, save, savef

# Import implementations
from .trk import TrkStreamlines
35 changes: 35 additions & 0 deletions nitorch/io/streamlines/loadsave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Specialization for volumes."""
from functools import wraps
from .readers import reader_classes as streamline_reader_classes
from .writers import writer_classes as streamline_writer_classes
from .. import loadsave


@wraps(loadsave.map)
def map(*args, reader_classes=None, **kwargs):
reader_classes = reader_classes or streamline_reader_classes
return loadsave.map(*args, reader_classes=reader_classes, **kwargs)


@wraps(loadsave.load)
def load(*args, reader_classes=None, **kwargs):
reader_classes = reader_classes or streamline_reader_classes
return loadsave.load(*args, reader_classes=reader_classes, **kwargs)


@wraps(loadsave.loadf)
def loadf(*args, reader_classes=None, **kwargs):
reader_classes = reader_classes or streamline_reader_classes
return loadsave.loadf(*args, reader_classes=reader_classes, **kwargs)


@wraps(loadsave.save)
def save(*args, writer_classes=None, **kwargs):
writer_classes = writer_classes or streamline_writer_classes
return loadsave.save(*args, writer_classes=writer_classes, **kwargs)


@wraps(loadsave.savef)
def savef(*args, writer_classes=None, **kwargs):
writer_classes = writer_classes or streamline_writer_classes
return loadsave.savef(*args, writer_classes=writer_classes, **kwargs)
Loading

0 comments on commit d9c329f

Please sign in to comment.