Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP, ENH] begin integrating fastsurfer #1112

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]

steps:
- name: Checkout repo
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "AFQ/nn/FastSurfer"]
path = AFQ/nn/FastSurfer
url = https://github.com/santiestrada32/FastSurfer.git
35 changes: 34 additions & 1 deletion AFQ/data/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
import logging
import time
import requests
from tqdm import tqdm

import warnings
Expand Down Expand Up @@ -53,7 +54,8 @@
"fetch_stanford_hardi_tractography",
"read_stanford_hardi_tractography",
"organize_stanford_data",
"fetch_stanford_hardi_lv1"]
"fetch_stanford_hardi_lv1",
"download_hypvinn"]


# Set a user-writeable file-system location to put files:
Expand Down Expand Up @@ -1831,3 +1833,34 @@ def fetch_hbn_afq(subjects, path=None):
"GeneratedBy": [{'Name': 'afq'}]})

return data_files, op.join(my_path, "HBN")


def download_hypvinn():
    """
    Download HypVINN model checkpoints and configuration files.

    Fetches the three per-plane (axial, coronal, sagittal) HypVINN
    checkpoint (.pkl) files and their YAML configs into
    ``{afq_home}/fs_checkpoints``, skipping any file already on disk.

    Returns
    -------
    ckpt_paths : dict
        Mapping of plane name ("axial", "coronal", "sagittal") to the
        local path of the checkpoint (.pkl) file.
    cfg_paths : dict
        Mapping of plane name to the local path of the configuration
        (.yaml) file.

    Raises
    ------
    requests.HTTPError
        If a download request returns an unsuccessful HTTP status.
    """
    HYPVINN_URL = "https://b2share.fz-juelich.de/api/files/7133b542-733b-4cc6-a284-5c333ff25f78"  # noqa
    HYPVINN_CFG_URL = "https://raw.githubusercontent.com/santiestrada32/FastSurfer/dev/HypVINN/config/"  # noqa
    planes = ["axial", "coronal", "sagittal"]

    os.makedirs(op.join(afq_home, "fs_checkpoints"), exist_ok=True)

    def _fetch_if_missing(url, dest):
        # Download url to dest unless dest already exists.
        # raise_for_status prevents a failed request (e.g. an HTML
        # error page) from being cached on disk as a model file.
        if not op.exists(dest):
            response = requests.get(url, verify=True)
            response.raise_for_status()
            with open(dest, "wb") as f:
                f.write(response.content)

    ckpt_paths = {}
    cfg_paths = {}
    for plane in planes:
        ckpt_paths[plane] = op.join(
            afq_home,
            f"fs_checkpoints/HypVINN_{plane}_v1.0.0.pkl")
        _fetch_if_missing(
            f"{HYPVINN_URL}/HypVINN_{plane}_v1.0.0.pkl",
            ckpt_paths[plane])
        cfg_paths[plane] = op.join(
            afq_home,
            f"fs_checkpoints/HypVINN_{plane}_v1.0.0.yaml")
        _fetch_if_missing(
            f"{HYPVINN_CFG_URL}/HypVINN_{plane}_v1.0.0.yaml",
            cfg_paths[plane])
    return ckpt_paths, cfg_paths
22 changes: 15 additions & 7 deletions AFQ/definitions/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class ImageFile(ImageDefinition):
Additional filters to pass to bids_layout.get() to identify
the file.
Default: {}
resample : bool, optional
Whether to resample the image to the DWI data.
Default: True

Examples
--------
Expand All @@ -111,7 +114,7 @@ class ImageFile(ImageDefinition):
"seed_threshold": 0.1})
"""

def __init__(self, path=None, suffix=None, filters={}):
def __init__(self, path=None, suffix=None, filters={}, resample=True):
if path is None and suffix is None:
raise ValueError((
"One of `path` or `suffix` must set to "
Expand All @@ -126,6 +129,8 @@ def __init__(self, path=None, suffix=None, filters={}):
self.filters = filters
self.fnames = {}

self.resample = resample

def find_path(self, bids_layout, from_path,
subject, session, required=True):
if self._from_path:
Expand Down Expand Up @@ -169,14 +174,17 @@ def _image_getter_helper(dwi, bids_info):
image_data_orig, image_file)

# Resample to DWI data:
image_data = _resample_image(
image_data,
dwi.get_fdata(),
image_affine,
dwi.affine)
if self.resample:
image_data = _resample_image(
image_data,
dwi.get_fdata(),
image_affine,
dwi.affine)
image_affine = dwi.affine

return nib.Nifti1Image(
image_data.astype(np.float32),
dwi.affine), meta
image_affine), meta
if task_name == "data":
def image_getter(dwi, bids_info):
return _image_getter_helper(dwi, bids_info)
Expand Down
1 change: 1 addition & 0 deletions AFQ/nn/FastSurfer
Submodule FastSurfer added at 6c373a
Empty file added AFQ/nn/__init__.py
Empty file.
66 changes: 66 additions & 0 deletions AFQ/nn/fastsurfer_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import sys
import tempfile
import os.path as op
import argparse
import logging
import importlib.util
from AFQ.data.fetch import download_hypvinn

# FastSurfer is vendored as a git submodule inside the installed AFQ
# package (AFQ/nn/FastSurfer); it is not pip-installable, so its
# directory is appended to sys.path to make `HypVINN` importable.
package_dir = op.dirname(importlib.util.find_spec("AFQ").origin)
fastsurfer_path = op.join(package_dir, "nn", "FastSurfer")
if fastsurfer_path not in sys.path:
    sys.path.append(fastsurfer_path)

try:
    from HypVINN.run_prediction import (
        get_prediction, load_volumes, set_up_cfgs)
    from HypVINN.inference import Inference
    from HypVINN.config.hypvinn_global_var import HYPVINN_CLASS_NAMES
except TypeError:
    # NOTE(review): presumably importing HypVINN under an older Python
    # raises TypeError (e.g. from newer typing syntax evaluated at import
    # time) — TODO confirm. ImportError/SyntaxError are NOT caught here,
    # so a missing submodule still fails with its original error.
    raise ValueError("FastSurfer requires python 3.10 or higher")


logger = logging.getLogger('AFQ')


def run_hypvinn(t1, device="cpu"):
    """
    Run HypVINN hypothalamus segmentation on a T1-weighted image.

    Downloads the HypVINN checkpoints/configs if needed, builds the
    per-view (axial/sagittal/coronal) configuration, loads the input
    volume, and runs multi-view prediction.

    Parameters
    ----------
    t1 : str
        Path to the T1-weighted image.
    device : str, optional
        Device to run inference on. Default: "cpu".

    Returns
    -------
    pred_classes
        Predicted segmentation from HypVINN.
    HYPVINN_CLASS_NAMES
        Mapping of label values to structure names.
    affine
        Affine of the loaded input volume.
    """
    ckpt_paths, cfg_paths = download_hypvinn()

    input_dir = op.dirname(t1)
    scratch_dir = tempfile.gettempdir()

    # Mirror the CLI argument namespace that upstream HypVINN expects;
    # outputs and logs go to the temp directory.
    args = argparse.Namespace(
        in_dir=input_dir,
        out_dir=scratch_dir,
        sid='subject', log_name=scratch_dir,
        orig_name=t1,
        t2=None, registration=True, qc_snapshots=False,
        reg_type='coreg', device=device, viewagg_device='auto',
        threads=8, batch_size=1, async_io=False, allow_root=False,
        ckpt_cor=ckpt_paths["coronal"],
        ckpt_ax=ckpt_paths["axial"],
        ckpt_sag=ckpt_paths["sagittal"],
        cfg_cor=cfg_paths["coronal"],
        cfg_ax=cfg_paths["axial"],
        cfg_sag=cfg_paths["sagittal"],
        t1=t1, mode='t1')

    # One config/checkpoint pair per anatomical view.
    view_ops = {
        "axial": {"cfg": set_up_cfgs(args.cfg_ax, args),
                  "ckpt": args.ckpt_ax},
        "sagittal": {"cfg": set_up_cfgs(args.cfg_sag, args),
                     "ckpt": args.ckpt_sag},
        "coronal": {"cfg": set_up_cfgs(args.cfg_cor, args),
                    "ckpt": args.ckpt_cor},
    }

    modalities, affine, _, orig_zoom, orig_size = load_volumes(
        mode=args.mode, t1_path=args.t1, t2_path=args.t2)

    # The coronal config is used to construct the model, as upstream does.
    model = Inference(cfg=view_ops["coronal"]["cfg"], args=args)

    pred_classes = get_prediction(
        args.sid, modalities, orig_zoom, model, gt_shape=orig_size,
        view_opts=view_ops, out_scale=None, mode=args.mode,
        logger=logger)

    return pred_classes, HYPVINN_CLASS_NAMES, affine
161 changes: 161 additions & 0 deletions AFQ/nn/profile_roi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from skimage.morphology import skeletonize_3d, binary_dilation
import numpy as np
import nibabel as nib
from dipy.align import resample


def _find_longest_true_series_indices(input_list):
max_length = 0
current_length = 0
start_index = 0
max_start_index = -1
max_end_index = -1
for i, value in enumerate(input_list):
if value:
if current_length == 0:
start_index = i
current_length += 1
else:
if current_length > max_length:
max_length = current_length
max_start_index = start_index
max_end_index = i
current_length = 0
if current_length > max_length:
max_length = current_length
max_start_index = start_index
max_end_index = len(input_list)
return max_start_index, max_end_index


def roi_from_segmentation(seg, label, dwi):
    """
    Resample a region of interest (ROI) from a
    segmentation to the space of a diffusion-weighted
    image (DWI).
    """
    # Binary mask of the requested label, dilated by one voxel before
    # resampling so thin structures survive the interpolation.
    label_mask = seg.get_fdata() == label
    dilated_mask = binary_dilation(label_mask)
    in_dwi_space = resample(
        dilated_mask,
        dwi.get_fdata()[..., 0],
        seg.affine,
        dwi.affine)
    # Re-binarize after interpolation.
    return in_dwi_space.get_fdata() > 0


def skeleton_from_roi(roi, affine, orientation_axis, jump_threshold=3):
    """
    Skeletonize a region of interest (ROI)

    Parameters
    ----------
    roi : ndarray
        A 3D binary array with the ROI.
    affine : ndarray
        The affine transformation of the ROI.
    orientation_axis : ndarray
        Axis used to orient profile. One of
        "L", "P", "I".
    jump_threshold : int, optional
        The maximum distance between two points in the
        skeleton of the ROI. Used to prune the skeleton.
        Default: 3.
    """
    # Reduce the ROI to a one-voxel-wide skeleton; fall back to the
    # full ROI if skeletonization erases everything.
    skeleton = skeletonize_3d(roi)
    if np.sum(skeleton) == 0:
        skeleton = roi
    skel_pts = np.asarray(np.where(skeleton)).T

    if len(skel_pts) == 0:
        return skel_pts

    # Keep only the longest run of consecutive points whose spacing is
    # below jump_threshold (prunes small side branches of the skeleton).
    gaps = np.linalg.norm(
        skel_pts[1:] - skel_pts[:-1], ord=2, axis=1)
    run_start, run_end = _find_longest_true_series_indices(
        gaps < jump_threshold)
    skel_pts = skel_pts[run_start:run_end]

    # Flip the point ordering, if needed, so the profile runs along
    # orientation_axis, taking the image axcodes into account.
    axcodes = nib.aff2axcodes(affine)
    axis_lookup = {"L": 0, "P": 1, "I": 2}
    if orientation_axis not in axis_lookup:
        raise ValueError("Invalid orientation_axis. "
                         "Valid options are 'L', 'P', 'I'.")
    dim = axis_lookup[orientation_axis]
    ascending = skel_pts[0, dim] - skel_pts[-1, dim] < 0
    if np.logical_xor(ascending, axcodes[dim] == orientation_axis):
        skel_pts = skel_pts[::-1]

    return skel_pts


def profile_roi(
        roi, skel_pts, scalar_data, d_plane_thresh=1):
    """
    Calculate the tract profile of a set of scalars
    within a region of interest (ROI). Finds the maximum
    value of each scalar within a disk
    around each point in the skeleton of the ROI.

    Parameters
    ----------
    roi : ndarray
        A 3D binary array with the ROI.
    skel_pts : ndarray
        A 2D array with the coordinates of the
        skeleton of the ROI.
    scalar_data : ndarray
        A 3D array with the scalar data.
    d_plane_thresh : int, optional
        The maximum distance between a point in the
        ROI and the plane defined by the skeleton of
        the ROI.
        Default: 1.

    Returns
    -------
    tract_profile : ndarray
        The profile, linearly interpolated to 100 points. Nodes with
        no qualifying ROI voxel keep the value -inf.
    """
    pts_len = skel_pts.shape[0]
    # Disk radius estimated from the ROI volume, assuming it is a tube
    # whose cross-sections are spread evenly across the skeleton nodes.
    roi_rad = int(np.sqrt((np.sum(roi) / pts_len) / np.pi))
    tract_profile = np.full(pts_len, -np.inf)
    for nodeid in range(pts_len):
        # Local tangent direction from the neighboring skeleton points.
        min_idx = max(0, nodeid - 1)
        max_idx = min(pts_len - 1, nodeid + 1)
        n_vec = skel_pts[min_idx] - skel_pts[max_idx]
        n_vec_norm = np.linalg.norm(n_vec)
        if n_vec_norm == 0:
            # Degenerate tangent (duplicate neighboring points):
            # previously this produced NaNs whose comparisons were all
            # False, leaving the node at -inf; skip explicitly instead
            # of dividing by zero.
            continue
        n_vec = n_vec / n_vec_norm
        c_pt = skel_pts[nodeid]

        # Bounding box around the node, clamped to the volume: without
        # clamping, negative indices silently wrap around to the far
        # side of the array and large indices raise IndexError for
        # nodes near the volume edge.
        dr = np.zeros((3, 2), dtype=int)
        for dim in range(3):
            dr[dim, 0] = max(0, int(c_pt[dim] - roi_rad))
            dr[dim, 1] = min(roi.shape[dim], int(c_pt[dim] + roi_rad))
        for ii in range(dr[0, 0], dr[0, 1]):
            for jj in range(dr[1, 0], dr[1, 1]):
                for kk in range(dr[2, 0], dr[2, 1]):
                    if roi[ii, jj, kk]:
                        euc_d = c_pt - np.asarray([ii, jj, kk])
                        # Keep voxels close to the node's normal plane
                        # and within the disk radius; record the max
                        # scalar among them.
                        d_plane = np.abs(np.sum(n_vec * euc_d))
                        d_point = np.sum(euc_d**2)**0.5
                        if d_plane <= d_plane_thresh and d_point <= roi_rad:
                            if scalar_data[ii, jj, kk] >\
                                    tract_profile[nodeid]:
                                tract_profile[nodeid] =\
                                    scalar_data[ii, jj, kk]

    # Interpolate the tract profile to have 100 points
    tract_profile = np.interp(
        np.linspace(0, pts_len - 1, num=100),
        np.linspace(0, pts_len - 1, num=pts_len),
        tract_profile)

    return tract_profile
Loading
Loading