Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

homework 3. committing slice_timing_image.py after interpolation loo… #10

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions day3/slice_time_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,21 @@
remember to do ``reload(slice_time_script)`` to get the new version.
"""

# Add any extra imports here
import sys
# Compatibility settings
from __future__ import print_function
from __future__ import division
from scipy.interpolate import InterpolatedUnivariateSpline
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very minor comment - but the standard habit here is to import Python standard libraries such as sys and os first, then very common libraries such as numpy, then less common libraries such as scipy, and nibabel. It's a tradition that helps the reader look for a given import quickly.


# Import modules
import os
import sys
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

# Parameter declaration
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice.

plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['image.interpolation'] = 'nearest'

def slice_time_image(img, slice_times, TR):
"""Run slice-timing correction on nibabel image 'img'.
Expand All @@ -46,11 +56,37 @@ def slice_time_image(img, slice_times, TR):
st_img : Nifti1Image
A new copy of the input image with slice-time interpolation applied
"""

# Get the image data as an array;
data = img.get_data() # The current function structure leads to many loadings of the image file, making the process slower and increasing memory demands
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should generally keep lines to 80 characters length, to make lines easier to read in a variety of text-editors and web-interfaces, like this one.

In fact, you don't need to do img.get_data() below to find the number of slices, you can just do img = nib.load(fname); n_slices = img.shape[2]. This doesn't load the data from disk - see http://nipy.org/nibabel/nibabel_images.html#the-image-data-array

n_x_voxels = data.shape[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, using so-called "tuple-unpacking":

n_x_voxels, n_y_voxels, n_slices, n_tr = data.shape

n_y_voxels = data.shape[1]
n_slices = data.shape[2] # number of slices on z dimension
n_tr = data.shape[3]

# Make a new empty array "interp_data" to hold the interpolated data;
interp_data = np.zeros(data.shape)

# Do the interpolation;
tr_onsets = np.arange(n_tr) * TR # define onset of each TR in sec
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice - good variable naming here and below.


# for each slice interpolate time series to the onset of the first slice acquired in the TR
for iz in range(n_slices):
all_iz_times = tr_onsets + slice_times[iz]
# for each voxel in the slice do an interpolation; could possibly reshape matrix first, but more difficult to read and then have to inverse reshape
for ix in range(n_x_voxels):
for iy in range (n_y_voxels):
voxel_timeseries = data[ix,iy,iz,:]
interp = InterpolatedUnivariateSpline(all_iz_times, voxel_timeseries, k=1)
interp_data[ix,iy,iz,:] = interp(tr_onsets) #we could potentially also interpolate for all n_slices intervals and get n_tr * n_slices time points.

# debug
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally remove commented out stuff when you've done, it can be distracting for the reviewer.

# plt.plot(tr_onsets[:10]+slice_times[10], data[0,0,10,:10], ':+')
# plt.plot(tr_onsets[:10], interp_data[0,0,10,:10], ':kx')

# This is how to make a new image with the interpolated data
new_img = nib.Nifti1Image(interp_data, img.affine, img.header)

return new_img


Expand All @@ -71,7 +107,15 @@ def slice_time_file(fname, slice_times, TR):
TR: double
Repetition time in seconds.
"""

img = nib.load(fname)
interp_img = slice_time_image(img,slice_times,TR)

# Hint: use os.path.split and os.path.join to make the new filename
path, filename = os.path.split(fname)
new_fname = 'a_' + filename
new_fname = os.path.join(path,new_fname)

nib.save(interp_img, new_fname)


Expand All @@ -80,16 +124,38 @@ def main():
"""
# Get the filename from the command line parameters
fname = sys.argv[1]
#fname = 'ds107_sub012_t1r2.nii' # filename used

img = nib.load(fname)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for comment higher up - to save memory and time, this should be:

img = nib.load(fname)
n_slices = img.shape[2]

data = img.get_data()
n_slices = data.shape[2] # number of slices on z dimension

# Assume the TR
TR = 2.0
# Assume the slices were acquired even slices first, inferior to
# superior, then odd slices, inferior to superior, where the most inferior
# slice is index 0 and 0 is an even number. What are the slice acquisition
# times in seconds, where the first value is the acquisition time of slice
# 0, the second is acquisition time of slice 1, etc?
slice_times = ?
slice_time_file(fname, slice_times, TR)

interval = TR / n_slices # acquisition time interval in seconds between each consecutive slice
# Identify sequence of slices per TR
even_slices = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, you can use the step argument to range and do:

slice_seq = list(range(0, n_slices, 2)) + list(range(1, n_slices, 2))

odd_slices = []
for islice in range(n_slices):
if islice % 2 == 0:
even_slices = even_slices + [islice]
else:
odd_slices = odd_slices + [islice]
slice_seq = even_slices + odd_slices
lag = 0
slice_lag = np.zeros(n_slices)
for it in slice_seq:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like it!

You don't really need slice_lag, could just use slice_times here.

slice_lag[it] = lag
lag = lag + interval

slice_times = slice_lag
slice_time_file(fname, slice_times, TR)

if __name__ == '__main__':
# If this file is being run as a script, execute the "main" function
Expand Down