Skip to content

Commit

Permalink
Merge pull request #453 from nipreps/fix/revise-bspline-fitting
Browse files Browse the repository at this point in the history
FIX: Revision of the B-Spline fitting code
  • Loading branch information
oesteban authored Jul 4, 2024
2 parents bb89bdc + 1883c15 commit 581e460
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 34 deletions.
8 changes: 4 additions & 4 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,9 @@ jobs:

- restore_cache:
keys:
- workdir-v2-{{ .Branch }}-
- workdir-v2-master-
- workdir-v2-
- workdir-v3-{{ .Branch }}-
- workdir-v3-master-
- workdir-v3-
- run:
name: Refreshing cached intermediate results
working_directory: /tmp/src/sdcflows
Expand Down Expand Up @@ -343,7 +343,7 @@ jobs:
--cov sdcflows --cov-report xml:/out/unittests.xml \
sdcflows/
- save_cache:
key: workdir-v2-{{ .Branch }}-{{ .BuildNum }}
key: workdir-v3-{{ .Branch }}-{{ .BuildNum }}
paths:
- /tmp/work
- store_artifacts:
Expand Down
39 changes: 23 additions & 16 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
in_data = File(exists=True, mandatory=True, desc="path to a fieldmap")
in_mask = File(exists=True, desc="path to a brain mask")
bs_spacing = InputMultiObject(
[DEFAULT_ZOOMS_MM],
[DEFAULT_HF_ZOOMS_MM],
traits.Tuple(traits.Float, traits.Float, traits.Float),
usedefault=True,
desc="spacing between B-Spline control points",
)
ridge_alpha = traits.Float(
0.01, usedefault=True, desc="controls the regularization"
1e-4, usedefault=True, desc="controls the regularization"
)
recenter = traits.Enum(
"mode",
"median",
"mode",
"mean",
False,
usedefault=True,
Expand All @@ -80,7 +80,7 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
zooms_min = traits.Union(
traits.Float,
traits.Tuple(traits.Float, traits.Float, traits.Float),
default_value=4.0,
default_value=1.0,
usedefault=True,
desc="limit minimum image zooms, set 0.0 to use the original image",
)
Expand All @@ -90,6 +90,7 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):


class _BSplineApproxOutputSpec(TraitedSpec):
out_intercept = traits.Float
out_field = File(exists=True)
out_coeff = OutputMultiObject(File(exists=True))
out_error = File(exists=True)
Expand Down Expand Up @@ -139,15 +140,15 @@ def _run_interface(self, runtime):

# Load in the fieldmap
fmapnii = nb.load(self.inputs.in_data)
fmapnii = nb.as_closest_canonical(fmapnii)
# fmapnii = nb.as_closest_canonical(fmapnii)
zooms = fmapnii.header.get_zooms()

# Get a mask (or define on the spot to cover the full extent)
masknii = (
nb.load(self.inputs.in_mask) if isdefined(self.inputs.in_mask) else None
)
if masknii is not None:
masknii = nb.as_closest_canonical(masknii)
# if masknii is not None:
# masknii = nb.as_closest_canonical(masknii)

# Determine the shape of bspline coefficients
# This should not change with resizing, so do it first
Expand Down Expand Up @@ -211,9 +212,7 @@ def _run_interface(self, runtime):
)

# Fit the model
model = lm.Ridge(
alpha=self.inputs.ridge_alpha, fit_intercept=False, solver="lsqr"
)
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
for attempt in range(3):
model.fit(colmat, data.reshape(-1))
extreme = np.abs(model.coef_).max()
Expand All @@ -228,6 +227,8 @@ def _run_interface(self, runtime):
f"Extreme value {extreme:.2e} detected in spline coefficients."
)

self._results["out_intercept"] = model.intercept_

# Store coefficients
index = 0
self._results["out_coeff"] = []
Expand All @@ -247,11 +248,11 @@ def _run_interface(self, runtime):
# Interpolating in the original grid will require a new collocation matrix
if need_resize:
fmapnii = nb.load(self.inputs.in_data)
fmapnii = nb.as_closest_canonical(fmapnii)
# fmapnii = nb.as_closest_canonical(fmapnii)
data = fmapnii.get_fdata(dtype="float32") - center
if masknii is not None:
masknii = nb.load(self.inputs.in_mask)
masknii = nb.as_closest_canonical(masknii)
# masknii = nb.as_closest_canonical(masknii)
mask = np.asanyarray(masknii.dataobj) > 1e-4
else:
mask = np.ones_like(fmapnii.dataobj, dtype=bool)
Expand All @@ -267,14 +268,20 @@ def _run_interface(self, runtime):
# Store interpolated field
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
outnii = fmapnii.__class__(interp_data, fmapnii.affine, hdr)
outnii.header["cal_max"] = np.abs(outnii.dataobj).max()
outnii.header["cal_min"] = -outnii.header["cal_max"]
outnii.to_filename(out_name)
self._results["out_field"] = out_name

# Write out fitting-error map
self._results["out_error"] = out_name.replace("_field.", "_error.")
fmapnii.__class__(
data * mask - interp_data, fmapnii.affine, fmapnii.header
).to_filename(self._results["out_error"])
errornii = fmapnii.__class__(
(data - interp_data) * mask, fmapnii.affine, fmapnii.header
)
errornii.header["cal_min"] = 0
errornii.header["cal_max"] = np.max(errornii.dataobj)
errornii.to_filename(self._results["out_error"])

if not self.inputs.extrapolate:
return runtime
Expand Down
41 changes: 32 additions & 9 deletions sdcflows/interfaces/tests/test_bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,41 +40,64 @@
@pytest.mark.parametrize("testnum", range(100))
def test_bsplines(tmp_path, testnum):
"""Test idempotency of B-Splines interpolation + approximation."""
targetshape = (10, 12, 9)
targetshape = (50, 50, 30)

# Generate an oblique affine matrix for the target - it will be a common case.
targetaff = nb.affines.from_matvec(
nb.eulerangles.euler2mat(x=0.9, y=0.001, z=0.001) @ np.diag((2, 3, 4)),
nb.eulerangles.euler2mat(x=-0.9, y=0.001, z=0.001) @ np.diag((2, 2, 2.4)),
)

# Intendedly mis-centered (exercise we may not have volume-centered NIfTIs)
targetaff[:3, 3] = nb.affines.apply_affine(
targetaff, 0.5 * (np.array(targetshape) - 3)
)

mask = np.zeros(targetshape)
mask[10:-10, 10:-10, 6:-6] = 1
# Generate some target grid
targetnii = nb.Nifti1Image(np.ones(targetshape), targetaff, None)
targetnii.to_filename(tmp_path / "target.nii.gz")
targetnii = nb.Nifti1Image(mask, targetaff, None)
targetnii.header.set_qform(targetaff, code=1)
targetnii.header.set_sform(targetaff, code=1)
targetnii.to_filename(tmp_path / "mask.nii.gz")

# Generate random coefficients
gridnii = bspline_grid(targetnii, control_zooms_mm=(4, 6, 8))
coeff = (rng.random(size=gridnii.shape) - 0.5) * 500
gridnii = bspline_grid(targetnii, control_zooms_mm=(40, 40, 16))
coeff = (rng.standard_normal(size=gridnii.shape)) * 100
coeffnii = nb.Nifti1Image(coeff.astype("float32"), gridnii.affine, gridnii.header)
coeffnii.header["cal_max"] = np.abs(coeff).max()
coeffnii.header["cal_min"] = -coeffnii.header["cal_max"]
coeffnii.header.set_qform(gridnii.affine, code=1)
coeffnii.header.set_sform(gridnii.affine, code=1)
coeffnii.to_filename(tmp_path / "coeffs.nii.gz")

os.chdir(tmp_path)
# Check that we can interpolate the coefficients on a target
test1 = ApplyCoeffsField(
in_data=str(tmp_path / "target.nii.gz"),
in_data=str(tmp_path / "mask.nii.gz"),
in_coeff=str(tmp_path / "coeffs.nii.gz"),
pe_dir="j-",
ro_time=1.0,
).run()

fieldnii = nb.load(test1.outputs.out_field)
fielddata = fieldnii.get_fdata()
fielddata -= np.median(fielddata)
fielddata = 200 * fielddata / np.abs(fielddata).max()

fieldnii.header["cal_max"] = np.abs(fielddata).max()
fieldnii.header["cal_min"] = -fieldnii.header["cal_max"]
fieldnii.header.set_qform(targetaff, code=1)
fieldnii.header.set_sform(targetaff, code=1)

nb.Nifti1Image(fielddata, targetaff, fieldnii.header).to_filename(
tmp_path / "testfield.nii.gz",
)

# Approximate the interpolated target
test2 = BSplineApprox(
in_data=test1.outputs.out_field,
bs_spacing=[(4, 6, 8)],
in_data=str(tmp_path / "testfield.nii.gz"),
# in_mask=str(tmp_path / "mask.nii.gz"),
bs_spacing=[(40, 40, 16)],
zooms_min=0,
recenter=False,
ridge_alpha=1e-4,
Expand Down
1 change: 1 addition & 0 deletions sdcflows/utils/wrangler.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def find_estimators(
"Dataset includes `B0FieldIdentifier` metadata."
"Any data missing this metadata will be ignored."
)

for b0_id in b0_ids:
# Found B0FieldIdentifier metadata entries
b0_entities = base_entities.copy()
Expand Down
10 changes: 5 additions & 5 deletions sdcflows/workflows/fit/fieldmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ def init_fmap_wf(omp_nthreads=1, sloppy=False, debug=False, mode="phasediff", na
"""
from ...interfaces.bspline import (
BSplineApprox,
DEFAULT_LF_ZOOMS_MM,
DEFAULT_HF_ZOOMS_MM,
DEFAULT_ZOOMS_MM,
)
from ...interfaces.fmap import CheckRegister

Expand All @@ -114,9 +112,11 @@ def _unzip(fmap_spec):
magnitude_wf = init_magnitude_wf(omp_nthreads=omp_nthreads)
bs_filter = pe.Node(BSplineApprox(), name="bs_filter")
bs_filter.interface._always_run = debug
bs_filter.inputs.bs_spacing = (
[DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM] if not sloppy else [DEFAULT_ZOOMS_MM]
)
bs_filter.inputs.bs_spacing = [DEFAULT_HF_ZOOMS_MM]

if sloppy:
bs_filter.inputs.zooms_min = 4.0

bs_filter.inputs.extrapolate = not debug

# fmt: off
Expand Down

0 comments on commit 581e460

Please sign in to comment.