diff --git a/python/pipeline/utils/galvo_corrections.py b/python/pipeline/utils/galvo_corrections.py index 09c2ac62..248b7b63 100644 --- a/python/pipeline/utils/galvo_corrections.py +++ b/python/pipeline/utils/galvo_corrections.py @@ -2,9 +2,12 @@ import numpy as np import datajoint as dj from scipy import interpolate as interp -from scipy import signal from scipy import ndimage from tqdm import tqdm +try: + from scipy.signal import tukey +except Exception: + from scipy.signal.windows import tukey from ..exceptions import PipelineException from ..utils.signal import mirrconv @@ -58,7 +61,7 @@ def compute_raster_phase(image, temporal_fill_fraction): return angle_shift -def compute_motion_shifts(scan, template, in_place=True, num_threads=8): +def compute_motion_shifts(scan, template, in_place=True, num_threads=8, try_gpu=True): """ Compute shifts in y and x for rigid subpixel motion correction. Returns the number of pixels that each image in the scan was to the right (x_shift) @@ -69,13 +72,31 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8): :param np.array template: 2-d template image. Each frame in scan is aligned to this. :param bool in_place: Whether the scan can be overwritten. :param int num_threads: Number of threads used for the ffts. + :param bool try_gpu: Whether to try to compute motion shifts on a GPU device. Will + default to CPU computation if no GPU device is found. :returns: (y_shifts, x_shifts) Two arrays (num_frames) with the y, x motion shifts. ..note:: Based in imreg_dft.translation(). """ - import pyfftw from imreg_dft import utils + + gpu_flag = False + if try_gpu: + try: + # functions needed for GPU computations + import cupy as cp + from cupy.fft import fft2, ifft2, fftshift + from cupy import abs + gpu_flag = True + except Exception: + gpu_flag = False + + if not gpu_flag: + # functions needed for CPU computations + import pyfftw + from numpy.fft import fftshift + from numpy import abs # Add third dimension if scan is a single image if scan.ndim == 2: @@ -83,28 +104,53 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8): # Get some params image_height, image_width, num_frames = scan.shape - taper = np.outer(signal.tukey(image_height, 0.2), signal.tukey(image_width, 0.2)) + taper = np.outer(tukey(image_height, 0.2), tukey(image_width, 0.2)) + + if gpu_flag: + + # transfer arrays to the GPU + template = cp.array(template) + taper = cp.array(taper) + + # get fourier transform of template + template_freq = fft2(template * taper).conj() + abs_template_freq = abs(template_freq) + eps = abs_template_freq.max() * 1e-15 + + else: - # Prepare fftw - frame = pyfftw.empty_aligned((image_height, image_width), dtype='complex64') - fft = pyfftw.builders.fft2(frame, threads=num_threads, overwrite_input=in_place, - avoid_copy=True) - ifft = pyfftw.builders.ifft2(frame, threads=num_threads, overwrite_input=in_place, - avoid_copy=True) + # Prepare fftw + frame = pyfftw.empty_aligned((image_height, image_width), dtype='complex64') + fft2 = pyfftw.builders.fft2(frame, threads=num_threads, overwrite_input=in_place, + avoid_copy=True) + ifft2 = pyfftw.builders.ifft2(frame, threads=num_threads, overwrite_input=in_place, + avoid_copy=True) - # Get fourier transform of template - template_freq = fft(template * taper).conj() # we only need the conjugate - abs_template_freq = abs(template_freq) - eps = abs_template_freq.max() * 1e-15 + # Get fourier transform of template + template_freq = fft2(template * taper).conj() # we only need the conjugate + abs_template_freq = abs(template_freq) + eps = abs_template_freq.max() * 1e-15 + # Compute subpixel shifts per image y_shifts = np.empty(num_frames) x_shifts = np.empty(num_frames) for i in range(num_frames): + + # transfer to GPU if necessary + if gpu_flag: + scan_frame = cp.array(scan[:, :, i]) + else: + scan_frame = scan[:, :, i] + # Compute correlation via cross power spectrum - image_freq = fft(scan[:, :, i] * taper) + image_freq = fft2(scan_frame * taper) cross_power = (image_freq * template_freq) / (abs(image_freq) * abs_template_freq + eps) - shifted_cross_power = np.fft.fftshift(abs(ifft(cross_power))) + shifted_cross_power = fftshift(abs(ifft2(cross_power))) + + # transfer back from GPU if necessary + if gpu_flag: + shifted_cross_power = shifted_cross_power.get() # Get best shift shifts = np.unravel_index(np.argmax(shifted_cross_power), shifted_cross_power.shape)