From be8e839ff9c36560ec4f05abe1e4e565ba55bbc8 Mon Sep 17 00:00:00 2001 From: Frank Charles DeGuire Date: Tue, 11 Jun 2024 08:41:23 -0700 Subject: [PATCH 1/4] Draft of GPU motion correction --- python/pipeline/utils/galvo_corrections.py | 62 +++++++++++++++++----- 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/python/pipeline/utils/galvo_corrections.py b/python/pipeline/utils/galvo_corrections.py index 09c2ac62..2c89f4df 100644 --- a/python/pipeline/utils/galvo_corrections.py +++ b/python/pipeline/utils/galvo_corrections.py @@ -74,8 +74,17 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8): ..note:: Based in imreg_dft.translation(). """ - import pyfftw - from imreg_dft import utils + try: + import cupy as cp + from cupy.fft import fft2, ifft2, fftshift + from cupy import abs + gpu_flag = True + except Exception: + gpu_flag = False + import pyfftw + from imreg_dft import utils + from numpy.fft import fftshift + from numpy import abs # Add third dimension if scan is a single image if scan.ndim == 2: @@ -84,27 +93,52 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8): # Get some params image_height, image_width, num_frames = scan.shape taper = np.outer(signal.tukey(image_height, 0.2), signal.tukey(image_width, 0.2)) + + if gpu_flag: + + # transfer arrays to the GPU + template = cp.array(template) + taper = cp.array(taper) + + # get fourier transform of template + template_freq = cp.fft.fft(template * taper).conj() + abs_template_freq = abs(template_freq) + eps = abs_template_freq.max() * 1e-15 + + else: - # Prepare fftw - frame = pyfftw.empty_aligned((image_height, image_width), dtype='complex64') - fft = pyfftw.builders.fft2(frame, threads=num_threads, overwrite_input=in_place, - avoid_copy=True) - ifft = pyfftw.builders.ifft2(frame, threads=num_threads, overwrite_input=in_place, - avoid_copy=True) + # Prepare fftw + frame = pyfftw.empty_aligned((image_height, image_width), dtype='complex64') + fft2 = pyfftw.builders.fft2(frame, threads=num_threads, overwrite_input=in_place, + avoid_copy=True) + ifft2 = pyfftw.builders.ifft2(frame, threads=num_threads, overwrite_input=in_place, + avoid_copy=True) - # Get fourier transform of template - template_freq = fft(template * taper).conj() # we only need the conjugate - abs_template_freq = abs(template_freq) - eps = abs_template_freq.max() * 1e-15 + # Get fourier transform of template + template_freq = fft(template * taper).conj() # we only need the conjugate + abs_template_freq = abs(template_freq) + eps = abs_template_freq.max() * 1e-15 + # Compute subpixel shifts per image y_shifts = np.empty(num_frames) x_shifts = np.empty(num_frames) for i in range(num_frames): + + # transfer to GPU if necessary + if gpu_flag: + scan_frame = cp.array(scan[:, :, i]) + else: + scan_frame = scan[:, :, i] + # Compute correlation via cross power spectrum - image_freq = fft(scan[:, :, i] * taper) + image_freq = fft2(scan_frame * taper) cross_power = (image_freq * template_freq) / (abs(image_freq) * abs_template_freq + eps) - shifted_cross_power = np.fft.fftshift(abs(ifft(cross_power))) + shifted_cross_power = fftshift(abs(ifft2(cross_power))) + + # transfer back from GPU if necessary + if gpu_flag: + shifted_cross_power = shifted_cross_power.get() # Get best shift shifts = np.unravel_index(np.argmax(shifted_cross_power), shifted_cross_power.shape) From 7b55965ba1fac21cf189298171dbe26ee294b59f Mon Sep 17 00:00:00 2001 From: Frank Charles DeGuire Date: Tue, 11 Jun 2024 08:45:44 -0700 Subject: [PATCH 2/4] Add GPU flag --- python/pipeline/utils/galvo_corrections.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/python/pipeline/utils/galvo_corrections.py b/python/pipeline/utils/galvo_corrections.py index 2c89f4df..9ca8e580 100644 --- a/python/pipeline/utils/galvo_corrections.py +++ b/python/pipeline/utils/galvo_corrections.py @@ -58,7 +58,7 @@ def compute_raster_phase(image, temporal_fill_fraction): return angle_shift -def compute_motion_shifts(scan, template, in_place=True, num_threads=8): +def compute_motion_shifts(scan, template, in_place=True, num_threads=8, try_gpu=True): """ Compute shifts in y and x for rigid subpixel motion correction. Returns the number of pixels that each image in the scan was to the right (x_shift) @@ -74,13 +74,16 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8): ..note:: Based in imreg_dft.translation(). """ - try: - import cupy as cp - from cupy.fft import fft2, ifft2, fftshift - from cupy import abs - gpu_flag = True - except Exception: - gpu_flag = False + gpu_flag = False + if try_gpu: + try: + import cupy as cp + from cupy.fft import fft2, ifft2, fftshift + from cupy import abs + gpu_flag = True + except Exception: + gpu_flag = False + if not gpu_flag: import pyfftw from imreg_dft import utils from numpy.fft import fftshift From 6ea3ea0a8021ad6004028a410a470d8d0651f36b Mon Sep 17 00:00:00 2001 From: Frank Charles DeGuire Date: Tue, 11 Jun 2024 08:48:02 -0700 Subject: [PATCH 3/4] Add GPU docstring --- python/pipeline/utils/galvo_corrections.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pipeline/utils/galvo_corrections.py b/python/pipeline/utils/galvo_corrections.py index 9ca8e580..8e33c4b2 100644 --- a/python/pipeline/utils/galvo_corrections.py +++ b/python/pipeline/utils/galvo_corrections.py @@ -69,6 +69,8 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8, try_gpu= :param np.array template: 2-d template image. Each frame in scan is aligned to this. :param bool in_place: Whether the scan can be overwritten. :param int num_threads: Number of threads used for the ffts. + :param bool try_gpu: Whether to try to compute motion shifts on a GPU device. Will + default to CPU computation if no GPU device is found. :returns: (y_shifts, x_shifts) Two arrays (num_frames) with the y, x motion shifts. From 6578ff8e911ea11feded028be5d94b2bd3df6b3b Mon Sep 17 00:00:00 2001 From: Frank Charles DeGuire Date: Tue, 11 Jun 2024 09:09:13 -0700 Subject: [PATCH 4/4] Confirmed functionality of GPU motion correction --- python/pipeline/utils/galvo_corrections.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/python/pipeline/utils/galvo_corrections.py b/python/pipeline/utils/galvo_corrections.py index 8e33c4b2..248b7b63 100644 --- a/python/pipeline/utils/galvo_corrections.py +++ b/python/pipeline/utils/galvo_corrections.py @@ -2,9 +2,12 @@ import numpy as np import datajoint as dj from scipy import interpolate as interp -from scipy import signal from scipy import ndimage from tqdm import tqdm +try: + from scipy.signal import tukey +except Exception: + from scipy.signal.windows import tukey from ..exceptions import PipelineException from ..utils.signal import mirrconv @@ -76,18 +79,22 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8, try_gpu= ..note:: Based in imreg_dft.translation(). """ + from imreg_dft import utils + gpu_flag = False if try_gpu: try: + # functions needed for GPU computations import cupy as cp from cupy.fft import fft2, ifft2, fftshift from cupy import abs - gpu_flag = True + gpu_flag = True except Exception: gpu_flag = False + if not gpu_flag: + # functions needed for CPU computations import pyfftw - from imreg_dft import utils from numpy.fft import fftshift from numpy import abs @@ -97,7 +104,7 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8, try_gpu= # Get some params image_height, image_width, num_frames = scan.shape - taper = np.outer(signal.tukey(image_height, 0.2), signal.tukey(image_width, 0.2)) + taper = np.outer(tukey(image_height, 0.2), tukey(image_width, 0.2)) if gpu_flag: @@ -106,7 +113,7 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8, try_gpu= taper = cp.array(taper) # get fourier transform of template - template_freq = cp.fft.fft(template * taper).conj() + template_freq = fft2(template * taper).conj() abs_template_freq = abs(template_freq) eps = abs_template_freq.max() * 1e-15 @@ -120,7 +127,7 @@ def compute_motion_shifts(scan, template, in_place=True, num_threads=8, try_gpu= avoid_copy=True) # Get fourier transform of template - template_freq = fft(template * taper).conj() # we only need the conjugate + template_freq = fft2(template * taper).conj() # we only need the conjugate abs_template_freq = abs(template_freq) eps = abs_template_freq.max() * 1e-15