From d9c329f24271cc8b0aadd87152bc28bb6044e2a0 Mon Sep 17 00:00:00 2001
From: balbasty
Date: Mon, 16 Dec 2024 14:48:42 +0000
Subject: [PATCH] FIX(register): bug with channels option + FEAT(io/reslice):
 support streamlines (only TRK for now)

---
 nitorch/cli/registration/reslice/main.py   | 146 ++++++++++-
 nitorch/cli/registration/reslice/parser.py |   2 +
 nitorch/cli/registration/reslice/struct.py |   1 +
 nitorch/io/__init__.py                     |   2 +
 nitorch/io/streamlines/__init__.py         |  10 +
 nitorch/io/streamlines/loadsave.py         |  35 +++
 nitorch/io/streamlines/mapping.py          | 235 +++++++++++++++++
 nitorch/io/streamlines/readers.py          |  15 ++
 nitorch/io/streamlines/trk.py              | 239 ++++++++++++++++++
 nitorch/io/streamlines/writers.py          |  15 ++
 nitorch/io/transforms/freesurfer/lta.py    |   2 +-
 .../tools/registration/pairwise_preproc.py |   8 +-
 12 files changed, 700 insertions(+), 10 deletions(-)
 create mode 100644 nitorch/io/streamlines/__init__.py
 create mode 100644 nitorch/io/streamlines/loadsave.py
 create mode 100644 nitorch/io/streamlines/mapping.py
 create mode 100644 nitorch/io/streamlines/readers.py
 create mode 100644 nitorch/io/streamlines/trk.py
 create mode 100644 nitorch/io/streamlines/writers.py

diff --git a/nitorch/cli/registration/reslice/main.py b/nitorch/cli/registration/reslice/main.py
index 7c44c234..cff23ed8 100644
--- a/nitorch/cli/registration/reslice/main.py
+++ b/nitorch/cli/registration/reslice/main.py
@@ -1,6 +1,7 @@
 import sys
 import os
 import json
+import copy
 import torch
 from nitorch import io, spatial
 from nitorch.cli.cli import commands
@@ -26,8 +27,21 @@ def reslice(argv=None):
             return
 
         read_info(options)
-        collapse(options)
-        write_data(options)
+
+        options_volumes = options
+        options_streamlines = copy.deepcopy(options)
+        options_streamlines.transformations \
+            = invert_transforms(options_streamlines.transformations)
+
+        if options_volumes.volumes:
+            collapse(options_volumes)
+            write_volumes(options_volumes)
+            del options_volumes
+
+        if options_streamlines.streamlines:
+            collapse(options_streamlines)
+            write_streamlines(options_streamlines)
+            del options_streamlines
 
     except ParseError as e:
         print(help)
@@ -81,7 +95,17 @@ def read_info(options):
     """Load affine transforms and space info of other volumes"""
 
     def read_file(fname):
+        f = io.map(fname)
+        if isinstance(f, io.MappedArray):
+            return read_volume(fname, f)
+        elif isinstance(f, io.MappedStreamlines):
+            return read_streamlines(fname, f)
+        else:
+            raise TypeError(type(f))
+
+    def read_streamlines(fname, f=None):
         o = struct.FileWithInfo()
+        o.type = "streamlines"
         o.fname = fname
         o.dir = os.path.dirname(fname) or '.'
         o.base = os.path.basename(fname)
@@ -90,7 +114,22 @@ def read_file(fname):
             zext = o.ext
             o.base, o.ext = os.path.splitext(o.base)
             o.ext += zext
-        f = io.volumes.map(fname)
+        f = f or io.streamlines.map(fname)
+        o.affine = f.affine.float()
+        return o
+
+    def read_volume(fname, f=None):
+        o = struct.FileWithInfo()
+        o.type = "volume"
+        o.fname = fname
+        o.dir = os.path.dirname(fname) or '.'
+        o.base = os.path.basename(fname)
+        o.base, o.ext = os.path.splitext(o.base)
+        if o.ext in ('.gz', '.bz2'):
+            zext = o.ext
+            o.base, o.ext = os.path.splitext(o.base)
+            o.ext += zext
+        f = f or io.volumes.map(fname)
         o.float = nitype(f.dtype).is_floating_point
         o.shape = squeeze_to_nd(f.shape, dim=3, channels=1)
         o.channels = o.shape[-1]
@@ -119,6 +158,12 @@ def read_field(fname):
         return f.affine.float(), f.shape[:3]
 
     options.files = [read_file(file) for file in options.files]
+    options.streamlines = [
+        file for file in options.files if file.type == "streamlines"
+    ]
+    options.volumes = [
+        file for file in options.files if file.type == "volume"
+    ]
     for trf in options.transformations:
         if isinstance(trf, struct.Linear):
             trf.affine = read_affine(trf.file)
@@ -205,8 +250,14 @@ def exponentiate_transforms(transformations, **backend):
     return transformations
 
 
+def invert_transforms(transformations):
+    transformations = [copy.deepcopy(trf) for trf in reversed(transformations)]
+    for trf in transformations:
+        trf.inv = not trf.inv
+    return transformations
+
-def write_data(options):
+def write_volumes(options):
 
     backend = dict(dtype=torch.float32, device=options.device)
 
@@ -254,7 +305,7 @@ def write_data(options):
     if (options.transformations and
             isinstance(options.transformations[0], struct.Linear)):
         trf = options.transformations[0]
-        for file in options.files:
+        for file in options.volumes:
             mat = file.affine.to(**backend)
             aff = trf.affine.to(**backend)
             file.affine = spatial.affine_lmdiv(aff, mat)
@@ -302,8 +353,8 @@ def build_from_target(affine, shape, smart=False):
                                bound=options.bound,
                                dim=3,
                                inplace=True)
-    output = py.make_list(options.output, len(options.files))
-    for file, ofname in zip(options.files, output):
+    output = py.make_list(options.output, len(options.volumes))
+    for file, ofname in zip(options.volumes, output):
         is_label = isinstance(options.interpolation, str) and options.interpolation == 'l'
         ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
         print(f'Reslicing:   {file.fname}\n'
              f'          -> {ofname}')
@@ -350,3 +401,84 @@ def build_from_target(affine, shape, smart=False):
             io.volumes.save(dat, ofname, like=file.fname, affine=oaffine, dtype=options.dtype)
         else:
             io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine, dtype=options.dtype)
+
+
+def write_streamlines(options):
+
+    backend = dict(dtype=torch.float32, device=options.device)
+
+    # 1) Pre-exponentiate velocities
+    for trf in options.transformations:
+        if isinstance(trf, struct.Velocity):
+            f = io.volumes.map(trf.file)
+            trf.affine = f.affine
+            trf.shape = squeeze_to_nd(f.shape, 3, 1)
+            trf.dat = f.fdata(**backend).reshape(trf.shape)
+            trf.shape = trf.shape[:3]
+            if trf.json:
+                if trf.square:
+                    trf.dat.mul_(0.5)
+                with open(trf.json) as f:
+                    prm = json.load(f)
+                prm['voxel_size'] = spatial.voxel_size(trf.affine)
+                trf.dat = spatial.shoot(trf.dat[None], displacement=True,
+                                        return_inverse=trf.inv, **prm)
+                if trf.inv:
+                    trf.dat = trf.dat[-1]
+            else:
+                if trf.square:
+                    trf.dat.mul_(0.5)
+                trf.dat = spatial.exp(trf.dat[None], displacement=True,
+                                      inverse=trf.inv)
+            trf.dat = trf.dat[0]  # drop batch dimension
+            trf.inv = False
+            trf.square = False
+            trf.order = 1
+        elif isinstance(trf, struct.Displacement):
+            f = io.volumes.map(trf.file)
+            trf.affine = f.affine
+            trf.shape = squeeze_to_nd(f.shape, 3, 1)
+            trf.dat = f.fdata(**backend).reshape(trf.shape)
+            trf.shape = trf.shape[:3]
+            if trf.unit == 'mm':
+                # convert mm displacement to vox displacement
+                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
+                trf.dat = trf.dat[..., 0]
+                trf.unit = 'vox'
+
+    def deform_streamlines(streamlines):
+        """Compose all transformations, starting from the final orientation"""
+        for trf in reversed(options.transformations):
+            if isinstance(trf, struct.Linear):
+                streamlines = spatial.affine_matvec(trf.affine.to(streamlines), streamlines)
+            else:
+                mat = trf.affine.to(**backend)
+                if trf.inv:
+                    disp = spatial.grid_inv(trf.dat, type='disp')
+                    order = 1
+                else:
+                    disp = trf.dat
+                    order = trf.order
+                imat = spatial.affine_inv(mat)
+                streamlines = spatial.affine_matvec(imat.to(streamlines), streamlines)
+                streamlines = streamlines + helpers.pull_grid(disp, streamlines[None, None], interpolation=order)[0, 0]
+                streamlines = spatial.affine_matvec(mat.to(streamlines), streamlines)
+        return streamlines
+
+    # 2) Loop across input files
+    output = py.make_list(options.output, len(options.streamlines))
+    for file, ofname in zip(options.streamlines, output):
+        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
+        print(f'Reslicing:   {file.fname}\n'
+              f'          -> {ofname}')
+        dat = list(io.streamlines.loadf(file.fname, **backend))
+        offsets = py.cumsum(map(len, dat), exclusive=True)
+        dat = torch.cat(list(dat))
+
+        dat = deform_streamlines(dat)
+        dat = [
+            dat[offsets[i]:(offsets[i+1] if i+1 < len(offsets) else None)]
+            for i in range(len(offsets))
+        ]
+
+        io.streamlines.savef(dat, ofname, like=file.fname)
diff --git a/nitorch/cli/registration/reslice/parser.py b/nitorch/cli/registration/reslice/parser.py
index 56b4b60c..64bacbd9 100644
--- a/nitorch/cli/registration/reslice/parser.py
+++ b/nitorch/cli/registration/reslice/parser.py
@@ -7,6 +7,8 @@
 nitorch reslice *FILES <*TRF> FILE [-t FILE] [-o *FILE]
     [-i ORDER] [-b BND] [-p] [-x] [-cpu|gpu]
 
+    <FILES> can be paths to volumes or to streamlines.
+
     <TRF> can take values (with additional options):
     -l, --linear            Linear transform (i.e., affine matrix)
     -d, --displacement      Dense or free-form displacement field
diff --git a/nitorch/cli/registration/reslice/struct.py b/nitorch/cli/registration/reslice/struct.py
index 884a0985..a09c724d 100644
--- a/nitorch/cli/registration/reslice/struct.py
+++ b/nitorch/cli/registration/reslice/struct.py
@@ -11,6 +11,7 @@ class FileWithInfo(Structure):
     ext: str = None               # Extension
     channels: int = None          # Number of channels
     float: bool = True            # Is raw dtype floating point
+    type: str = None              # "volume" or "streamlines"
 
 
 class Transform(Structure):
diff --git a/nitorch/io/__init__.py b/nitorch/io/__init__.py
index 65324dbb..5a6fb7da 100644
--- a/nitorch/io/__init__.py
+++ b/nitorch/io/__init__.py
@@ -206,10 +206,12 @@
 from . import metadata
 from . import optionals
 from . import readers
+from . import streamlines
 from . import transforms
 from . import utils
 from . import writers
 
 from .volumes import MappedArray, CatArray, cat, stack
 from .transforms import MappedAffine
+from .streamlines import MappedStreamlines
 from .loadsave import map, load, loadf, save, savef
diff --git a/nitorch/io/streamlines/__init__.py b/nitorch/io/streamlines/__init__.py
new file mode 100644
index 00000000..1d351274
--- /dev/null
+++ b/nitorch/io/streamlines/__init__.py
@@ -0,0 +1,10 @@
+from . import loadsave
+from . import mapping
+from . import readers
+from . import writers
+
+from .mapping import MappedStreamlines
+from .loadsave import map, load, loadf, save, savef
+
+# Import implementations
+from .trk import TrkStreamlines
diff --git a/nitorch/io/streamlines/loadsave.py b/nitorch/io/streamlines/loadsave.py
new file mode 100644
index 00000000..cf4334c5
--- /dev/null
+++ b/nitorch/io/streamlines/loadsave.py
@@ -0,0 +1,35 @@
+"""Specialization for streamlines."""
+from functools import wraps
+from .readers import reader_classes as streamline_reader_classes
+from .writers import writer_classes as streamline_writer_classes
+from .. import loadsave
+
+
+@wraps(loadsave.map)
+def map(*args, reader_classes=None, **kwargs):
+    reader_classes = reader_classes or streamline_reader_classes
+    return loadsave.map(*args, reader_classes=reader_classes, **kwargs)
+
+
+@wraps(loadsave.load)
+def load(*args, reader_classes=None, **kwargs):
+    reader_classes = reader_classes or streamline_reader_classes
+    return loadsave.load(*args, reader_classes=reader_classes, **kwargs)
+
+
+@wraps(loadsave.loadf)
+def loadf(*args, reader_classes=None, **kwargs):
+    reader_classes = reader_classes or streamline_reader_classes
+    return loadsave.loadf(*args, reader_classes=reader_classes, **kwargs)
+
+
+@wraps(loadsave.save)
+def save(*args, writer_classes=None, **kwargs):
+    writer_classes = writer_classes or streamline_writer_classes
+    return loadsave.save(*args, writer_classes=writer_classes, **kwargs)
+
+
+@wraps(loadsave.savef)
+def savef(*args, writer_classes=None, **kwargs):
+    writer_classes = writer_classes or streamline_writer_classes
+    return loadsave.savef(*args, writer_classes=writer_classes, **kwargs)
diff --git a/nitorch/io/streamlines/mapping.py b/nitorch/io/streamlines/mapping.py
new file mode 100644
index 00000000..bd6be6e9
--- /dev/null
+++ b/nitorch/io/streamlines/mapping.py
@@ -0,0 +1,235 @@
+import numpy as np
+from ..mapping import MappedFile
+
+
+class MappedStreamlines(MappedFile):
+    """Streamlines stored on disk"""
+
+    @classmethod
+    def possible_extensions(cls):
+        """List all possible extensions"""
+        return tuple()
+
+    def __str__(self):
+        return '{}()'.format(type(self).__name__)
+
+    __repr__ = __str__
+
+    @property
+    def affine(self):
+        """
+        Vertex to world transformation matrix.
+        """
+        raise NotImplementedError
+
+    @property
+    def dtype(self):
+        return np.dtype("float64")
+
+    def fdata(self, dtype=None, device=None, numpy=False):
+        """Load the streamlines from file.
+
+        This function tries to return vertices in RAS space.
+
+        Parameters
+        ----------
+        dtype : torch.dtype, optional
+        device : torch.device, optional
+        numpy : bool, default=False
+
+        Returns
+        -------
+        streamlines : iterator[(N, 3) tensor or array]
+            Streamlines
+
+        """
+        raise NotImplementedError
+
+    def data(self, dtype=None, device=None, numpy=False):
+        """Load the streamlines from file.
+
+        The "raw" streamlines are loaded even if they are not in RAS space.
+
+        Parameters
+        ----------
+        dtype : torch.dtype, optional
+        device : torch.device, optional
+        numpy : bool, default=False
+
+        Returns
+        -------
+        streamlines : iterator[(N, 3) tensor or array]
+            Streamlines
+        """
+        raise NotImplementedError
+
+    def scalars(self, dtype=None, device=None, numpy=False, keys=None):
+        """
+        Load the scalars associated with each vertex.
+
+        Parameters
+        ----------
+        dtype : torch.dtype, optional
+        device : torch.device, optional
+        numpy : bool, default=False
+        keys : [list of] str, optional
+            Keys to load
+
+        Returns
+        -------
+        scalars : [dict of] iterator[(N, K) tensor or array]
+            Scalars.
+            If `keys` is a string, return the iterator of scalars directly.
+            Else, return a dictionary mapping keys to scalars.
+        """
+        raise NotImplementedError
+
+    def properties(self, dtype=None, device=None, numpy=False, keys=None):
+        """
+        Load the properties associated with each streamline.
+
+        Parameters
+        ----------
+        dtype : torch.dtype, optional
+        device : torch.device, optional
+        numpy : bool, default=False
+        keys : [list of] str, optional
+            Keys to load
+
+        Returns
+        -------
+        properties : [dict of] iterator[(N, K) tensor or array]
+            Properties.
+            If `keys` is a string, return the iterator of properties directly.
+            Else, return a dictionary mapping keys to properties.
+        """
+        raise NotImplementedError
+
+    def metadata(self, keys=None):
+        """Read additional metadata from the file.
+
+        Parameters
+        ----------
+        keys : sequence of str, optional
+            List of metadata keys to read.
+            If not provided, all (format-specific) known keys are read.
+
+        Returns
+        -------
+        dict
+
+        """
+        raise NotImplementedError
+
+    def set_fdata(self, streamlines):
+        """Set the streamlines data.
+
+        This function only modifies the in-memory representation of the
+        streamlines. The file is not modified. To overwrite the file,
+        call `save` afterward.
+
+        Parameters
+        ----------
+        streamlines : list[(N, 3) tensor or array]
+
+        Returns
+        -------
+        self
+
+        """
+        raise NotImplementedError
+
+    def set_data(self, streamlines):
+        """Set the (raw) streamlines data.
+
+        This function only modifies the in-memory representation of the
+        streamlines. The file is not modified. To overwrite the file,
+        call `save` afterward.
+
+        Parameters
+        ----------
+        streamlines : list[(N, 3) tensor or array]
+
+        Returns
+        -------
+        self
+
+        """
+        raise NotImplementedError
+
+    def set_metadata(self, **meta):
+        """Set additional metadata in the file.
+
+        This function only modifies the in-memory representation of the
+        streamlines. The file is not modified. To overwrite the file,
+        use `save`.
+
+        Parameters
+        ----------
+        **meta : dict
+            Only keys that make sense to the format will effectively
+            be set.
+
+        Returns
+        -------
+        self
+
+        """
+        raise NotImplementedError
+
+    def save(self, file_like=None, **meta):
+        """Save the current streamlines to disk.
+
+        Parameters
+        ----------
+        file_like : str or file object, default=self.file_like
+            Target file to write the streamlines.
+        **meta : dict
+            Additional metadata to set before saving.
+
+        Returns
+        -------
+        self
+
+        """
+        raise NotImplementedError
+
+    savef = save
+
+    @classmethod
+    def save_new(cls, streamlines, file_like, like=None, **meta):
+        """Save new streamlines to disk in the `cls` format.
+
+        Parameters
+        ----------
+        streamlines : MappedStreamlines or list[(N, 3) tensor or array]
+            Streamlines to write
+        file_like : str or file object
+            Target file
+        like : MappedStreamlines or str or file object, optional
+            Template streamlines. Its metadata fields will be copied unless
+            they are overwritten by `meta`.
+        **meta : dict
+            Additional metadata to set before writing.
+
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def savef_new(cls, streamlines, file_like, like=None, **meta):
+        """Save new streamlines to disk in the `cls` format.
+
+        Parameters
+        ----------
+        streamlines : MappedStreamlines or list[(N, 3) tensor or array]
+            Streamlines to write
+        file_like : str or file object
+            Target file
+        like : MappedStreamlines or str or file object, optional
+            Template streamlines. Its metadata fields will be copied unless
+            they are overwritten by `meta`.
+        **meta : dict
+            Additional metadata to set before writing.
+
+        """
+        raise NotImplementedError
diff --git a/nitorch/io/streamlines/readers.py b/nitorch/io/streamlines/readers.py
new file mode 100644
index 00000000..b270d841
--- /dev/null
+++ b/nitorch/io/streamlines/readers.py
@@ -0,0 +1,15 @@
+"""Registered streamline readers.
+
+Classes should be registered in their implementation file by importing
+`reader_classes` and appending to it:
+>>> from nitorch.io.streamlines.readers import reader_classes
+>>> reader_classes.append(MyClassThatCanRead)
+
+This file is kept otherwise empty so that registered readers are not
+erased when the module is autoreloaded.
+"""
+
+from ..readers import reader_classes as all_reader_classes
+
+reader_classes = []
+all_reader_classes.append(reader_classes)
diff --git a/nitorch/io/streamlines/trk.py b/nitorch/io/streamlines/trk.py
new file mode 100644
index 00000000..4f2b6d81
--- /dev/null
+++ b/nitorch/io/streamlines/trk.py
@@ -0,0 +1,239 @@
+import torch
+import numpy as np
+from types import GeneratorType as generator
+from nibabel.streamlines.trk import TrkFile, get_affine_trackvis_to_rasmm
+from nibabel.streamlines.tractogram import Tractogram
+from nitorch.core import dtypes
+from nitorch.spatial import affine_matvec
+from nitorch.io.mapping import AccessType
+from .mapping import MappedStreamlines
+from .readers import reader_classes
+from .writers import writer_classes
+
+
+class TrkStreamlines(MappedStreamlines):
+    """Streamlines stored in a TRK file"""
+
+    readable: AccessType = AccessType.Full
+    writable: AccessType = AccessType.Full
+
+    @classmethod
+    def possible_extensions(cls):
+        return ('.trk',)
+
+    def __init__(self, file_like=None, mode='r', keep_open=False):
+        """
+
+        Parameters
+        ----------
+        file_like : str or file object
+            File to map
+        mode : {'r', 'r+'}, default='r'
+            Read in read-only ('r') or read-and-write ('r+') mode.
+            Modifying the file in-place is only possible in 'r+' mode.
+        keep_open : bool, default=False
+            Does nothing.
+        """
+        self.filename = None
+        self.mode = mode
+        if file_like is None:
+            self._struct = TrkFile(Tractogram())
+        elif isinstance(file_like, TrkStreamlines):
+            self._struct = file_like._struct
+        elif isinstance(file_like, TrkFile):
+            self._struct = file_like
+        elif isinstance(file_like, Tractogram):
+            self._struct = TrkFile(file_like)
+        else:
+            self.filename = file_like
+            if 'r' in mode:
+                self._struct = TrkFile.load(file_like, lazy_load=True)
+            else:
+                self._struct = TrkFile(Tractogram())
+
+    @property
+    def _loaded(self):
+        if not self._struct:
+            # not loaded at all
+            return False
+        if isinstance(self._struct.streamlines, generator):
+            # lazily loaded
+            return False
+        return True
+
+    def __len__(self):
+        return self._struct.header["nb_streamlines"]
+
+    def shape(self):
+        return torch.Size([len(self)])
+
+    @property
+    def dtype(self):
+        return np.dtype('float64')
+
+    @classmethod
+    def _cast_generator(cls, generator, dtype=None, device=None, numpy=False):
+        if not numpy:
+            if dtype is not None:
+                dtype = dtypes.dtype(dtype)
+                if dtype.torch is None:
+                    raise TypeError(
+                        f'Data type {dtype} does not exist in PyTorch.'
+                    )
+            dtype = dtypes.dtype(dtype or np.float64).torch_upcast
+        else:
+            dtype = dtypes.dtype(dtype or np.float64).numpy_upcast
+
+        for elem in generator:
+            yield (
+                np.asarray(elem, dtype=dtype) if numpy else
+                torch.as_tensor(elem, dtype=dtype, device=device)
+            )
+
+    @classmethod
+    def _apply_affine_generator(cls, generator, affine):
+        for elem in generator:
+            elem = affine_matvec(affine.to(elem), elem)
+            yield elem
+
+    def fdata(self, dtype=None, device=None, numpy=False):
+        yield from self._cast_generator(
+            self._struct.tractogram.streamlines,
+            dtype=dtype, device=device, numpy=numpy
+        )
+
+    def data(self, dtype=None, device=None, numpy=False):
+        affine = self.affine.inverse()
+        if numpy:
+            np_dtype = dtypes.dtype(dtype or np.float64).numpy_upcast
+        for streamline in self.fdata(dtype=dtype, device=device):
+            streamline = affine_matvec(affine.to(streamline), streamline)
+            if numpy:
+                streamline = np.asarray(streamline.cpu(), dtype=np_dtype)
+            yield streamline
+
+    def scalars(self, dtype=None, device=None, numpy=False, keys=None):
+        all_scalars = self._struct.tractogram.data_per_point
+        all_keys = list(all_scalars.keys())
+        keys = keys or all_keys
+        return_dict = True
+        if isinstance(keys, str):
+            return_dict = False
+            keys = [keys]
+        scalars = {
+            key: self._cast_generator(
+                all_scalars[key], dtype=dtype, device=device, numpy=numpy
+            )
+            for key in keys
+        }
+        if not return_dict:
+            scalars = next(iter(scalars.values()))
+        return scalars
+
+    def properties(self, dtype=None, device=None, numpy=False, keys=None):
+        all_props = self._struct.tractogram.data_per_streamline
+        all_keys = list(all_props.keys())
+        keys = keys or all_keys
+        return_dict = True
+        if isinstance(keys, str):
+            return_dict = False
+            keys = [keys]
+        props = {
+            key: self._cast_generator(
+                all_props[key], dtype=dtype, device=device, numpy=numpy
+            )
+            for key in keys
+        }
+        if not return_dict:
+            props = next(iter(props.values()))
+        return props
+
+    @property
+    def affine(self):
+        """
+        Vertex to world transformation matrix.
+ """ + return torch.as_tensor( + get_affine_trackvis_to_rasmm(self._struct.header) + ) + + def metadata(self, keys=None): + raise NotImplementedError + + def set_fdata(self, streamlines): + raise NotImplementedError + + def set_data(self, affine): + raise NotImplementedError + + def set_metadata(self, **meta): + raise NotImplementedError + + def save(self, file_like=None, **meta): + # FIXME: meta + self._struct.save(file_like) + + @classmethod + def save_new(cls, streamlines, file_like, like=None, **meta): + # FIXME: meta + header = None + if like: + like = cls(like) + header = like._struct.header + + obj = streamlines + + affine = meta.get("affine", None) + if affine is None and like: + affine = like.affine + + if not isinstance(obj, TrkStreamlines): + if not isinstance(obj, (Tractogram, TrkFile)): + obj = Tractogram(obj, affine_to_rasmm=affine) + if not isinstance(TrkStreamlines, TrkFile): + obj = TrkFile(obj, header) + obj = TrkStreamlines(obj) + + if like: + obj._struct = TrkFile(obj._struct.tractogram, header) + + if "scalars" in meta: + obj._struct.tractogram.data_per_point = meta.pop("scalars") + elif like: + try: + obj._struct.tractogram.data_per_point \ + = like._struct.tractogram.data_per_point + except Exception: + pass + + if "properties" in meta: + obj._struct.tractogram.data_per_streamline = meta.pop("properties") + elif like: + try: + obj._struct.tractogram.data_per_streamline \ + = like._struct.tractogram.data_per_streamline + except Exception: + pass + + for key, value in meta.items(): + if key in header.keys(): + obj._struct.header[key] = value + + obj._struct.save(file_like) + + @classmethod + def savef_new(cls, streamlines, file_like, like=None, **meta): + if like: + like = cls(like) + affine = meta.get("affine", None) + if affine is None and like: + affine = like.affine + if not isinstance(streamlines, (TrkStreamlines, TrkFile, Tractogram)): + if affine is not None: + affine = affine.inverse() + streamlines = cls._apply_affine_generator(streamlines, affine) + cls.save_new(streamlines, file_like, like, **meta) + + +reader_classes.append(TrkStreamlines) +writer_classes.append(TrkStreamlines) diff --git a/nitorch/io/streamlines/writers.py b/nitorch/io/streamlines/writers.py new file mode 100644 index 00000000..1660982f --- /dev/null +++ b/nitorch/io/streamlines/writers.py @@ -0,0 +1,15 @@ +"""Registered transform writers + +Classes should be registered in their implementation file by importing +writer_classes and appending them: +>>> from nitorch.io.transforms.writers import writer_classes +>>> writer_classes.append(MyClassThatCanWrite) + +This file is kept empty to avoid all registered writers to be erased by +autoreloading the module. +""" + +from ..writers import writer_classes as all_writer_classes + +writer_classes = [] +all_writer_classes.append(writer_classes) diff --git a/nitorch/io/transforms/freesurfer/lta.py b/nitorch/io/transforms/freesurfer/lta.py index 5b139a49..2b3580c1 100644 --- a/nitorch/io/transforms/freesurfer/lta.py +++ b/nitorch/io/transforms/freesurfer/lta.py @@ -104,7 +104,7 @@ class LTAStruct(Structure): exposed to the user. """ - type: int = Constants.LINEAR_VOX_TO_VOX # Affine type + type: int = Constants.LINEAR_RAS_TO_RAS # Affine type nxforms: int = 1 # Number of affines stored mean: tuple = None # ? sigma: float = None # ? 
diff --git a/nitorch/tools/registration/pairwise_preproc.py b/nitorch/tools/registration/pairwise_preproc.py
index 6689dcbd..34adaaf4 100644
--- a/nitorch/tools/registration/pairwise_preproc.py
+++ b/nitorch/tools/registration/pairwise_preproc.py
@@ -2,7 +2,7 @@
            'rescale_image', 'discretize_image', 'soft_quantize_image']
 
 from nitorch import io
-from nitorch.core.py import make_list
+from nitorch.core.py import make_list, flatten
 from nitorch.core import dtypes, utils
 from nitorch import spatial
 from . import pairwise_pyramid as pyrutils
@@ -267,8 +267,12 @@ def map_image(fnames, dim=None, channels=None):
             list(range(len(imgs)))[c] if isinstance(c, slice) else c
             for c in channels
         ]
+        channels = flatten(channels)
         if not all([isinstance(c, int) for c in channels]):
-            raise ValueError('Channel list should be a list of integers')
+            raise ValueError(
+                f'Channel list should be a list of integers '
+                f'but received: {channels}'
+            )
         imgs = io.stack([imgs[c] for c in channels])
     return imgs, affine
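
Note on the channels fix: `map_image` first expands slices into lists of
indices, so a mixed specification such as `channels=[0, slice(2, 4)]`
previously produced the nested list `[0, [2, 3]]` and failed the integer
check. Flattening first lets the check and the subsequent `io.stack` see
`[0, 2, 3]`. A minimal sketch of the behaviour (four channels assumed,
for illustration only):

>>> from nitorch.core.py import flatten
>>> channels = [0, slice(2, 4)]
>>> channels = [list(range(4))[c] if isinstance(c, slice) else c
...             for c in channels]
>>> channels
[0, [2, 3]]
>>> flatten(channels)
[0, 2, 3]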
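
Why streamlines invert the transform stack: volumes are resliced by
pulling intensities from the moving image (the transforms map target
voxels to source voxels), whereas streamline vertices must be pushed
from source to target space, hence `invert_transforms` reverses the
list and flips each `inv` flag. A hypothetical invocation, following
the usage string in parser.py (flag spellings may differ):

    nitorch reslice tracts.trk -l affine.lta -t target.nii -o moved.trk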