From c9b4c23c0affedc77c8baef19215cf6a0512fd7e Mon Sep 17 00:00:00 2001 From: as31 Date: Mon, 5 Aug 2019 19:30:46 +0100 Subject: [PATCH] 3d motion changes demo_3d_motion_correction_pipeline.py has demo of 3d motion correction (rigid or pwrigid). Outstanding bugs: see TODO comments in motion_correction.py --- caiman/base/movies.py | 114 +++- caiman/motion_correction.py | 565 ++++++++++++++++-- caiman/source_extraction/cnmf/params.py | 13 +- .../NWB/demo_3d_motion_correction_pipeline.py | 163 +++++ 4 files changed, 785 insertions(+), 70 deletions(-) create mode 100644 use_cases/NWB/demo_3d_motion_correction_pipeline.py diff --git a/caiman/base/movies.py b/caiman/base/movies.py index e078cfa45..56be4b447 100644 --- a/caiman/base/movies.py +++ b/caiman/base/movies.py @@ -206,6 +206,86 @@ def motion_correct(self, return self, shifts, xcorrs, template + def motion_correct_3d(self, + max_shift_z=5, + max_shift_w=5, + max_shift_h=5, + num_frames_template=None, + template=None, + method='opencv', + remove_blanks=False, interpolation='cubic'): + """ + Extract shifts and motion corrected movie automatically, + + for more control consider the functions extract_shifts and apply_shifts + Disclaimer, it might change the object itself. + + Args: + max_shift,z,max_shift_w,max_shift_h: maximum pixel shifts allowed when + correcting in the axial, width, and height directions + + template: if a good template for frame by frame correlation exists + it can be passed. If None it is automatically computed + + method: depends on what is installed 'opencv' or 'skimage'. 'skimage' + is an order of magnitude slower + + num_frames_template: if only a subset of the movies needs to be loaded + for efficiency/speed reasons + + + Returns: + self: motion corected movie, it might change the object itself + + shifts : tuple, contains x, y, and z shifts and correlation with template + + xcorrs: cross correlation of the movies with the template + + template: the computed template + """ + + if template is None: # if template is not provided it is created + if num_frames_template is None: + num_frames_template = old_div( + 10e7, (self.shape[1] * self.shape[2])) + + frames_to_skip = int(np.maximum( + 1, old_div(self.shape[0], num_frames_template))) + + # sometimes it is convenient to only consider a subset of the + # movie when computing the median + submov = self[::frames_to_skip, :].copy() + templ = submov.bin_median_3d() # create template with portion of movie + shifts, xcorrs = submov.extract_shifts_3d(max_shift_z=max_shift_z, # NOTE: extract_shifts_3d has not been implemented yet - use skimage + max_shift_w=max_shift_w, max_shift_h=max_shift_h, template=templ, method=method) + submov.apply_shifts_3d( # NOTE: apply_shifts_3d has not been implemented yet + shifts, interpolation=interpolation, method=method) + template = submov.bin_median_3d() + del submov + m = self.copy() + shifts, xcorrs = m.extract_shifts_3d(max_shift_z=max_shift_z, + max_shift_w=max_shift_w, max_shift_h=max_shift_h, template=template, method=method) + m = m.apply_shifts_3d( + shifts, interpolation=interpolation, method=method) + template = (m.bin_median_3d()) + del m + else: + template = template - np.percentile(template, 8) + + # now use the good template to correct + shifts, xcorrs = self.extract_shifts_3d(max_shift_z=max_shift_z, + max_shift_w=max_shift_w, max_shift_h=max_shift_h, template=template, method=method) + self = self.apply_shifts_3d( + shifts, interpolation=interpolation, method=method) + + if remove_blanks: + max_z, max_h, max_w = np.max(shifts, axis=0) + min_z, min_h, min_w = np.min(shifts, axis=0) + self = self.crop(crop_top=max_z, crop_bottom=min_z, # NOTE: edge boundaries for z dimension need to be tested + crop_left=max_h, crop_right=-min_h + 1, crop_begin=max_w, crop_end=-min_w) + + return self, shifts, xcorrs, template + def bin_median(self, window=10): """ compute median of 3D array in along axis o by binning values @@ -226,6 +306,26 @@ def bin_median(self, window=10): num_frames = num_windows * window return np.nanmedian(np.nanmean(np.reshape(self[:num_frames], (window, num_windows, d1, d2)), axis=0), axis=0) + def bin_median_3d(self, window=10): + """ compute median of 4D array in along axis o by binning values + + Args: + mat: ndarray + input 4D matrix, (T, h, w, z) + + window: int + number of frames in a bin + + Returns: + img: + median image + + """ + T, d1, d2, d3 = np.shape(self) + num_windows = np.int(old_div(T, window)) + num_frames = num_windows * window + return np.nanmedian(np.nanmean(np.reshape(self[:num_frames], (window, num_windows, d1, d2, d3)), axis=0), axis=0) + def extract_shifts(self, max_shift_w:int=5, max_shift_h:int=5, template=None, method:str='opencv') -> Tuple[List, List]: """ Performs motion correction using the opencv matchtemplate function. At every iteration a template is built by taking the median of all frames and then used to align the other frames. @@ -1133,7 +1233,7 @@ def animate(i): def load(file_name, fr:float=30, start_time:float=0, meta_data:Dict=None, subindices=None, shape:Tuple[int,int]=None, var_name_hdf5:str='mov', in_memory:bool=False, is_behavior:bool=False, - bottom=0, top=0, left=0, right=0, channel=None, outtype=np.float32) -> Any: + bottom=0, top=0, left=0, right=0, channel=None, outtype=np.float32, is3D:bool=False) -> Any: """ load movie from file. Supports a variety of formats. tif, hdf5, npy and memory mapped. Matlab is experimental. @@ -1195,7 +1295,7 @@ def load(file_name, fr:float=30, start_time:float=0, meta_data:Dict=None, subind meta_data=meta_data, subindices=subindices, bottom=bottom, top=top, left=left, right=right, channel = channel, outtype=outtype, - var_name_hdf5=var_name_hdf5) + var_name_hdf5=var_name_hdf5,is3D=is3D) if max(top, bottom, left, right) > 0: logging.error('top bottom etc... not supported for single movie input') @@ -1369,10 +1469,12 @@ def rgb2gray(rgb): else: if type(subindices).__module__ is 'numpy': subindices = subindices.tolist() - images = np.array( - fgroup[subindices]).squeeze() - #if images.ndim > 3: - # images = images[:, 0] + if len(fgroup.shape) > 3: + images = np.array( + fgroup[subindices]).squeeze() + else: + images = np.array( + fgroup[subindices]).squeeze() #input_arr = images return movie(images.astype(outtype)) diff --git a/caiman/motion_correction.py b/caiman/motion_correction.py index 31ef94b78..75e533ab1 100644 --- a/caiman/motion_correction.py +++ b/caiman/motion_correction.py @@ -57,6 +57,8 @@ import pylab as pl import tifffile from typing import List, Optional +from skimage.transform import resize as resize_sk +from skimage.transform import warp as warp_sk import caiman as cm import caiman.base.movies @@ -96,7 +98,7 @@ class implementing motion correction operations def __init__(self, fname, min_mov=None, dview=None, max_shifts=(6, 6), niter_rig=1, splits_rig=14, num_splits_to_process_rig=None, strides=(96, 96), overlaps=(32, 32), splits_els=14, num_splits_to_process_els=[7, None], upsample_factor_grid=4, max_deviation_rigid=3, shifts_opencv=True, nonneg_movie=True, gSig_filt=None, - use_cuda=False, border_nan=True, pw_rigid=False, num_frames_split=80, var_name_hdf5='mov'): + use_cuda=False, border_nan=True, pw_rigid=False, num_frames_split=80, var_name_hdf5='mov',is3D=False): """ Constructor class for motion correction operations @@ -165,6 +167,9 @@ def __init__(self, fname, min_mov=None, dview=None, max_shifts=(6, 6), niter_rig var_name_hdf5: str, default: 'mov' If loading from hdf5, name of the variable to load + is3D: bool, default: False + Flag for 3D motion correction + Returns: self @@ -197,6 +202,7 @@ def __init__(self, fname, min_mov=None, dview=None, max_shifts=(6, 6), niter_rig self.border_nan = border_nan self.pw_rigid = pw_rigid self.var_name_hdf5 = var_name_hdf5 + self.is3D = is3D if self.use_cuda and not HAS_CUDA: logging.debug("pycuda is unavailable. Falling back to default FFT.") @@ -232,7 +238,13 @@ def motion_correct(self, template=None, save_movie=False): if self.pw_rigid: self.motion_correct_pwrigid(template=template, save_movie=save_movie) - b0 = np.ceil(np.maximum(np.max(np.abs(self.x_shifts_els)), + if self.is3D: + # TODO - error at this point after saving + b0 = np.ceil(np.maximum(np.max(np.abs(self.x_shifts_els)), + np.max(np.abs(self.y_shifts_els)), + np.max(np.abs(self.z_shifts_els)))) + else: + b0 = np.ceil(np.maximum(np.max(np.abs(self.x_shifts_els)), np.max(np.abs(self.y_shifts_els)))) else: self.motion_correct_rigid(template=template, save_movie=save_movie) @@ -246,7 +258,7 @@ def motion_correct_rigid(self, template=None, save_movie=False): Perform rigid motion correction Args: - template: ndarray 2D + template: ndarray 2D (or 3D) if known, one can pass a template to register the frames to save_movie_rigid:Bool @@ -262,7 +274,7 @@ def motion_correct_rigid(self, template=None, save_movie=False): self.templates_rig: list of templates. one for each chunk - self.shifts_rig: shifts in x and y per frame + self.shifts_rig: shifts in x and y (and z if 3D) per frame """ logging.debug('Entering Rigid Motion Correction') logging.debug(-self.min_mov) # XXX why the minus? @@ -287,7 +299,8 @@ def motion_correct_rigid(self, template=None, save_movie=False): gSig_filt=self.gSig_filt, use_cuda=self.use_cuda, border_nan=self.border_nan, - var_name_hdf5=self.var_name_hdf5) + var_name_hdf5=self.var_name_hdf5, + is3D=self.is3D) if template is None: self.total_template_rig = _total_template_rig @@ -305,7 +318,7 @@ def motion_correct_pwrigid( """Perform pw-rigid motion correction Args: - template: ndarray 2D + template: ndarray 2D (or 3D) if known, one can pass a template to register the frames to save_movie:Bool @@ -322,8 +335,9 @@ def motion_correct_pwrigid( self.templates_els: template updated by iterating over the chunks self.x_shifts_els: shifts in x per frame per patch self.y_shifts_els: shifts in y per frame per patch + self.z_shifts_els: shifts in z per frame per patch (if 3D) self.coord_shifts_els: coordinates associated to the patch for - values in x_shifts_els and y_shifts_els + values in x_shifts_els and y_shifts_els (and z_shifts_els if 3D) self.total_template_els: list of templates. one for each chunk Raises: @@ -342,20 +356,24 @@ def motion_correct_pwrigid( self.templates_els:List = [] self.x_shifts_els:List = [] self.y_shifts_els:List = [] + if self.is3D: + self.z_shifts_els:List = [] + self.coord_shifts_els:List = [] for name_cur in self.fname: for num_splits_to_process in self.num_splits_to_process_els: _fname_tot_els, new_template_els, _templates_els,\ - _x_shifts_els, _y_shifts_els, _coord_shifts_els = motion_correct_batch_pwrigid( + _x_shifts_els, _y_shifts_els, _z_shifts_els, _coord_shifts_els = motion_correct_batch_pwrigid( name_cur, self.max_shifts, self.strides, self.overlaps, -self.min_mov, dview=self.dview, upsample_factor_grid=self.upsample_factor_grid, max_deviation_rigid=self.max_deviation_rigid, splits=self.splits_els, num_splits_to_process=num_splits_to_process, num_iter=num_iter, template=self.total_template_els, shifts_opencv=self.shifts_opencv, save_movie=save_movie, nonneg_movie=self.nonneg_movie, gSig_filt=self.gSig_filt, - use_cuda=self.use_cuda, border_nan=self.border_nan, var_name_hdf5=self.var_name_hdf5) - if show_template: - pl.imshow(new_template_els) - pl.pause(.5) + use_cuda=self.use_cuda, border_nan=self.border_nan, var_name_hdf5=self.var_name_hdf5, is3D=self.is3D) + if not self.is3D: + if show_template: + pl.imshow(new_template_els) + pl.pause(.5) if np.isnan(np.sum(new_template_els)): raise Exception( 'Template contains NaNs, something went wrong. Reconsider the parameters') @@ -367,6 +385,8 @@ def motion_correct_pwrigid( self.templates_els += _templates_els self.x_shifts_els += _x_shifts_els self.y_shifts_els += _y_shifts_els + if self.is3D: + self.z_shifts_els += _z_shifts_els self.coord_shifts_els += _coord_shifts_els # return self @@ -590,7 +610,8 @@ def motion_correct_oneP_rigid( shifts_opencv=True, nonneg_movie=True, gSig_filt=gSig_filt, - border_nan=border_nan) + border_nan=border_nan, + is3D=False) mc.motion_correct_rigid(save_movie=save_movie, template=new_templ) @@ -655,7 +676,8 @@ def motion_correct_oneP_nonrigid( splits_els=splits_els, upsample_factor_grid=upsample_factor_grid, max_deviation_rigid=max_deviation_rigid, - border_nan=border_nan) + border_nan=border_nan, + is3D=False) mc.motion_correct_pwrigid(save_movie=True, template=new_templ) return mc @@ -974,6 +996,38 @@ def bin_median(mat, window=10, exclude_nans=True): return img +def bin_median_3d(mat, window=10, exclude_nans=True): + """ compute median of 4D array in along axis o by binning values + + Args: + mat: ndarray + input 4D matrix, (T, h, w, z) + + window: int + number of frames in a bin + + Returns: + img: + median image + + Raises: + Exception 'Path to template does not exist:'+template + """ + + T, d1, d2, d3 = np.shape(mat) + if T < window: + window = T + num_windows = np.int(old_div(T, window)) + num_frames = num_windows * window + if exclude_nans: + img = np.nanmedian(np.nanmean(np.reshape( + mat[:num_frames], (window, num_windows, d1, d2, d3)), axis=0), axis=0) + else: + img = np.median(np.mean(np.reshape( + mat[:num_frames], (window, num_windows, d1, d2, d3)), axis=0), axis=0) + + return img + def process_movie_parallel(arg_in): #todo: todocument fname, fr, margins_out, template, max_shift_w, max_shift_h, remove_blanks, apply_smooth, save_hdf5 = arg_in @@ -1267,18 +1321,65 @@ def close_cuda_process(n): #%% -def register_translation_3d(src_image, target_image, space = "real", - shifts_lb = None, shifts_ub = None, - max_shifts = [10,10,10], upsample_factor = 1): +def register_translation_3d(src_image, target_image, upsample_factor = 1, + space = "real", shifts_lb = None, shifts_ub = None, + max_shifts = [10,10,10]): """ Simple script for registering translation in 3D using an FFT approach. + + Args: + src_image : ndarray + Reference image. + + target_image : ndarray + Image to register. Must be same dimensionality as ``src_image``. + + upsample_factor : int, optional + Upsampling factor. Images will be registered to within + ``1 / upsample_factor`` of a pixel. For example + ``upsample_factor == 20`` means the images will be registered + within 1/20th of a pixel. Default is 1 (no upsampling) + + space : string, one of "real" or "fourier" + Defines how the algorithm interprets input data. "real" means data + will be FFT'd to compute the correlation, while "fourier" data will + bypass FFT of input data. Case insensitive. + + Returns: + shifts : ndarray + Shift vector (in pixels) required to register ``target_image`` with + ``src_image``. Axis ordering is consistent with numpy (e.g. Z, Y, X) + + error : float + Translation invariant normalized RMS error between ``src_image`` and + ``target_image``. + + phasediff : float + Global phase difference between the two images (should be + zero if images are non-negative). + + Raises: + NotImplementedError "Error: register_translation_3d only supports " + "subpixel registration for 3D images" + + ValueError "Error: images must really be same size for " + "register_translation_3d" + + ValueError "Error: register_translation_3d only knows the \"real\" " + "and \"fourier\" values for the ``space`` argument." + """ # images must be the same shape if src_image.shape != target_image.shape: raise ValueError("Error: images must really be same size for " - "register_translation") + "register_translation_3d") + + # only 3D data makes sense right now + if src_image.ndim != 3 and upsample_factor > 1: + raise NotImplementedError("Error: register_translation_3d only supports " + "subpixel registration for 3D images") # assume complex data is already in Fourier space if space.lower() == 'fourier': @@ -1292,12 +1393,18 @@ def register_translation_3d(src_image, target_image, space = "real", target_image, dtype=np.complex64, copy=False) src_freq = np.fft.fftn(src_image_cpx) target_freq = np.fft.fftn(target_image_cpx) + else: + raise ValueError("Error: register_translation_3d only knows the \"real\" " + "and \"fourier\" values for the ``space`` argument.") shape = src_freq.shape image_product = src_freq * target_freq.conj() cross_correlation = np.fft.ifftn(image_product) - CCmax = cross_correlation.max() +# cross_correlation = ifftn(image_product) # TODO CHECK why this line is different new_cross_corr = np.abs(cross_correlation) + + CCmax = cross_correlation.max() + del cross_correlation if (shifts_lb is not None) or (shifts_ub is not None): @@ -1327,9 +1434,13 @@ def register_translation_3d(src_image, target_image, space = "real", maxima = np.unravel_index(np.argmax(new_cross_corr), new_cross_corr.shape) midpoints = np.array([np.fix(axis_size//2) for axis_size in shape]) +# maxima = np.unravel_index(np.argmax(new_cross_corr),cross_correlation.shape) +# midpoints = np.array([np.fix(old_div(axis_size, 2)) for axis_size in shape]) + shifts = np.array(maxima, dtype=np.float32) shifts[shifts > midpoints] -= np.array(shape)[shifts > midpoints] + if upsample_factor > 1: shifts = old_div(np.round(shifts * upsample_factor), upsample_factor) @@ -1757,9 +1868,10 @@ def apply_shifts_dft(src_freq, shifts, diffphase, is_freq=True, border_nan=True) if min_w < 0: new_img[:, min_w:] = new_img[:, min_w-1, np.newaxis] if is3D: - new_img[:, :, :max_d] = new_img[:, :, max_d] + if max_d > 0: + new_img[:, :, :max_d] = new_img[:, :, max_d, np.newaxis] if min_d < 0: - new_img[:, :, min_d:] = new_img[:, :, min_d-1] + new_img[:, :, min_d:] = new_img[:, :, min_d-1, np.newaxis] return new_img @@ -1776,7 +1888,7 @@ def sliding_window(image, overlaps, strides): dimension of the patch strides: tuple - stride in wach dimension + stride in each dimension Returns: iterator containing five items @@ -1795,6 +1907,39 @@ def sliding_window(image, overlaps, strides): # yield the current window yield (dim_1, dim_2, x, y, image[x:x + windowSize[0], y:y + windowSize[1]]) +def sliding_window_3d(image, overlaps, strides): + """ efficiently and lazily slides a window across the image + + Args: + img:ndarray 3D + image that needs to be slices + + windowSize: tuple + dimension of the patch + + strides: tuple + stride in each dimension + + Returns: + iterator containing seven items + dim_1, dim_2, dim_3 coordinates in the patch grid + x, y, z: bottom border of the patch in the original matrix + + patch: the patch + """ + windowSize = np.add(overlaps, strides) + range_1 = list(range( + 0, image.shape[0] - windowSize[0], strides[0])) + [image.shape[0] - windowSize[0]] + range_2 = list(range( + 0, image.shape[1] - windowSize[1], strides[1])) + [image.shape[1] - windowSize[1]] + range_3 = list(range( + 0, image.shape[2] - windowSize[2], strides[2])) + [image.shape[2] - windowSize[2]] + for dim_1, x in enumerate(range_1): + for dim_2, y in enumerate(range_2): + for dim_3, z in enumerate(range_3): + # yield the current window + yield (dim_1, dim_2, dim_3, x, y, z, image[x:x + windowSize[0], y:y + windowSize[1], z:z + windowSize[2]]) + def iqr(a): return np.percentile(a, 75) - np.percentile(a, 25) @@ -2098,6 +2243,272 @@ def tile_and_correct(img, template, strides, overlaps, max_shifts, newoverlaps=N except: pass return new_img - add_to_movie, total_shifts, start_step, xy_grid + +#%% +def tile_and_correct_3d(img, template, strides, overlaps, max_shifts, newoverlaps=None, newstrides=None, upsample_factor_grid=4, + upsample_factor_fft=10, show_movie=False, max_deviation_rigid=2, add_to_movie=0, shifts_opencv=False, gSig_filt=None, + use_cuda=False, border_nan=True): + """ perform piecewise rigid motion correction iteration, by + 1) dividing the FOV in patches + 2) motion correcting each patch separately + 3) upsampling the motion correction vector field + 4) stiching back together the corrected subpatches + + Args: + img: ndaarray 3D + image to correct + + template: ndarray + reference image + + strides: tuple + strides of the patches in which the FOV is subdivided + + overlaps: tuple + amount of pixel overlaping between patches along each dimension + + max_shifts: tuple + max shifts in x, y, and z + + newstrides:tuple + strides between patches along each dimension when upsampling the vector fields + + newoverlaps:tuple + amount of pixel overlaping between patches along each dimension when upsampling the vector fields + + upsample_factor_grid: int + if newshapes or newstrides are not specified this is inferred upsampling by a constant factor the cvector field + + upsample_factor_fft: int + resolution of fractional shifts + + show_movie: boolean whether to visualize the original and corrected frame during motion correction + + max_deviation_rigid: int + maximum deviation in shifts of each patch from the rigid shift (should not be large) + + add_to_movie: if movie is too negative the correction might have some issues. In this case it is good to add values so that it is non negative most of the times + + filt_sig_size: tuple + standard deviation and size of gaussian filter to center filter data in case of one photon imaging data + + use_cuda : bool, optional + Use skcuda.fft (if available). Default: False + + border_nan : bool or string, optional + specifies how to deal with borders. (True, False, 'copy', 'min') + + Returns: + (new_img, total_shifts, start_step, xyz_grid) + new_img: ndarray, corrected image + + + """ + + img = img.astype(np.float64).copy() + template = template.astype(np.float64).copy() + + if gSig_filt is not None: + + img_orig = img.copy() + img = high_pass_filter_space(img_orig, gSig_filt) + + img = img + add_to_movie + template = template + add_to_movie + + # compute rigid shifts + rigid_shts, sfr_freq, diffphase = register_translation_3d( + img, template, upsample_factor=upsample_factor_fft, max_shifts=max_shifts) + + if max_deviation_rigid == 0: # if rigid shifts only + +# if shifts_opencv: + # NOTE: opencv does not support 3D operations - skimage is used instead + # else: + + if gSig_filt is not None: + raise Exception( + 'The use of FFT and filtering options have not been tested. Set opencv=True') + + new_img = apply_shifts_dft( # TODO: check + sfr_freq, (rigid_shts[0], rigid_shts[1], rigid_shts[2]), diffphase, border_nan=border_nan) + + return new_img - add_to_movie, (-rigid_shts[0], -rigid_shts[1]), None, None + else: + # extract patches + templates = [ + it[-1] for it in sliding_window_3d(template, overlaps=overlaps, strides=strides)] + xyz_grid = [(it[0], it[1], it[2]) for it in sliding_window_3d( + template, overlaps=overlaps, strides=strides)] + num_tiles = np.prod(np.add(xyz_grid[-1], 1)) + imgs = [it[-1] + for it in sliding_window_3d(img, overlaps=overlaps, strides=strides)] + dim_grid = tuple(np.add(xyz_grid[-1], 1)) + + if max_deviation_rigid is not None: + + lb_shifts = np.ceil(np.subtract( + rigid_shts, max_deviation_rigid)).astype(int) + ub_shifts = np.floor( + np.add(rigid_shts, max_deviation_rigid)).astype(int) + + else: + + lb_shifts = None + ub_shifts = None + + # extract shifts for each patch + shfts_et_all = [register_translation_3d( + a, b, c, shifts_lb=lb_shifts, shifts_ub=ub_shifts, max_shifts=max_shifts) for a, b, c in zip( + imgs, templates, [upsample_factor_fft] * num_tiles)] + shfts = [sshh[0] for sshh in shfts_et_all] + diffs_phase = [sshh[2] for sshh in shfts_et_all] + # create a vector field + shift_img_x = np.reshape(np.array(shfts)[:, 0], dim_grid) + shift_img_y = np.reshape(np.array(shfts)[:, 1], dim_grid) + shift_img_z = np.reshape(np.array(shfts)[:, 2], dim_grid) + diffs_phase_grid = np.reshape(np.array(diffs_phase), dim_grid) + + # shifts_opencv doesn't make sense here- replace with shifts_skimage + if shifts_opencv: + if gSig_filt is not None: + img = img_orig + + dims = img.shape + x_grid, y_grid, z_grid = np.meshgrid(np.arange(0., dims[1]).astype( + np.float32), np.arange(0., dims[0]).astype(np.float32), + np.arange(0., dims[2]).astype(np.float32)) + m_reg = warp_sk(img, np.stack((resize_sk(shift_img_y.astype(np.float32), dims) + x_grid, + resize_sk(shift_img_x.astype(np.float32), dims) + y_grid, + resize_sk(shift_img_z.astype(np.float32), dims) + z_grid),axis=0), + order=3, mode='constant') + # borderValue=add_to_movie) + total_shifts = [ + (-x, -y, -z) for x, y, z in zip(shift_img_x.reshape(num_tiles), shift_img_y.reshape(num_tiles), shift_img_z.reshape(num_tiles))] + return m_reg - add_to_movie, total_shifts, None, None + + # create automatically upsample parameters if not passed + if newoverlaps is None: + newoverlaps = overlaps + if newstrides is None: + newstrides = tuple( + np.round(np.divide(strides, upsample_factor_grid)).astype(np.int)) + + newshapes = np.add(newstrides, newoverlaps) + + imgs = [it[-1] + for it in sliding_window_3d(img, overlaps=newoverlaps, strides=newstrides)] + + xyz_grid = [(it[0], it[1], it[2]) for it in sliding_window_3d( + img, overlaps=newoverlaps, strides=newstrides)] + + start_step = [(it[3], it[4], it[5]) for it in sliding_window_3d( + img, overlaps=newoverlaps, strides=newstrides)] + + dim_new_grid = tuple(np.add(xyz_grid[-1], 1)) + + shift_img_x = resize_sk( + shift_img_x, dim_new_grid[::-1], order=3) + shift_img_y = resize_sk( + shift_img_y, dim_new_grid[::-1], order=3) + shift_img_z = resize_sk( + shift_img_z, dim_new_grid[::-1], order=3) + diffs_phase_grid_us = resize_sk( + diffs_phase_grid, dim_new_grid[::-1], order=3) + + num_tiles = np.prod(dim_new_grid) + + # what dimension shear should be looked at? shearing for 3d point scanning happens in y and z but no for plane-scanning + max_shear = np.percentile( + [np.max(np.abs(np.diff(ssshh, axis=xxsss))) for ssshh, xxsss in itertools.product( + [shift_img_x, shift_img_y], [0, 1])], 75) + + total_shifts = [ + (-x, -y, -z) for x, y, z in zip(shift_img_x.reshape(num_tiles), shift_img_y.reshape(num_tiles), shift_img_z.reshape(num_tiles))] + total_diffs_phase = [ + dfs for dfs in diffs_phase_grid_us.reshape(num_tiles)] + + if shifts_opencv: + if gSig_filt is not None: + img = img_orig + imgs = [ + it[-1] for it in sliding_window_3d(img, overlaps=newoverlaps, strides=newstrides)] + + imgs = [apply_shift_iteration(im, sh, border_nan=border_nan) + for im, sh in zip(imgs, total_shifts)] + + else: + if gSig_filt is not None: + raise Exception( + 'The use of FFT and filtering options have not been tested. Set opencv=True') + + imgs = [apply_shifts_dft(im, ( + sh[0], sh[1], sh[2]), dffphs, is_freq=False, border_nan=border_nan) for im, sh, dffphs in zip( + imgs, total_shifts, total_diffs_phase)] + + normalizer = np.zeros_like(img) * np.nan + new_img = np.zeros_like(img) * np.nan + + weight_matrix = create_weight_matrix_for_blending( + img, newoverlaps, newstrides) + + if max_shear < 0.5: + for (x, y, z), (_, _, _), im, (_, _, _), weight_mat in zip(start_step, xyz_grid, imgs, total_shifts, weight_matrix): + + prev_val_1 = normalizer[x:x + newshapes[0], y:y + newshapes[1], z:z + newshapes[2]] + + normalizer[x:x + newshapes[0], y:y + newshapes[1], z:z + newshapes[2]] = np.nansum( + np.dstack([~np.isnan(im) * 1 * weight_mat, prev_val_1]), -1) + prev_val = new_img[x:x + newshapes[0], y:y + newshapes[1], z:z + newshapes[2]] + new_img[x:x + newshapes[0], y:y + newshapes[1], z:z + newshapes[2] + ] = np.nansum(np.dstack([im * weight_mat, prev_val]), -1) + + new_img = old_div(new_img, normalizer) + + else: # in case the difference in shift between neighboring patches is larger than 0.5 pixels we do not interpolate in the overlaping area + half_overlap_x = np.int(newoverlaps[0] / 2) + half_overlap_y = np.int(newoverlaps[1] / 2) + half_overlap_z = np.int(newoverlaps[2] / 2) + + for (x, y, z), (idx_0, idx_1, idx_2), im, (_, _, _), weight_mat in zip(start_step, xyz_grid, imgs, total_shifts, weight_matrix): + + if idx_0 == 0: + x_start = x + else: + x_start = x + half_overlap_x + + if idx_1 == 0: + y_start = y + else: + y_start = y + half_overlap_y + + if idx_2 == 0: + z_start = z + else: + z_start = z + half_overlap_z + + x_end = x + newshapes[0] + y_end = y + newshapes[1] + z_end = z + newshapes[2] + new_img[x_start:x_end,y_start:y_end, + z_start:z_end] = im[x_start - x:, y_start - y:, z_start -z:] + + if show_movie: + img = apply_shifts_dft( + sfr_freq, (-rigid_shts[0], -rigid_shts[1], -rigid_shts[2]), diffphase, border_nan=border_nan) + img_show = np.vstack([new_img, img]) + + img_show = resize_sk(img_show, None, fx=1, fy=1, fz=1) + + cv2.imshow('frame', old_div(img_show, np.percentile(template, 99))) + cv2.waitKey(int(1. / 500 * 1000)) + + else: + try: + cv2.destroyAllWindows() + except: + pass + return new_img - add_to_movie, total_shifts, start_step, xyz_grid #%% def compute_flow_single_frame(frame, templ, pyr_scale=.5, levels=3, winsize=100, iterations=15, poly_n=5, @@ -2201,7 +2612,7 @@ def compute_metrics_motion_correction(fname, final_size_x, final_size_y, swap_di def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_splits_to_process=None, num_iter=1, template=None, shifts_opencv=False, save_movie_rigid=False, add_to_movie=None, nonneg_movie=False, gSig_filt=None, subidx=slice(None, None, 1), use_cuda=False, - border_nan=True, var_name_hdf5='mov'): + border_nan=True, var_name_hdf5='mov', is3D=False): """ Function that perform memory efficient hyper parallelized rigid motion corrections while also saving a memory mappable file @@ -2210,7 +2621,7 @@ def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_spl name of the movie to motion correct. It should not contain nans. All the loadable formats from CaImAn are acceptable max_shifts: tuple - x and y maximum allowd shifts + x and y (and z if 3D) maximum allowed shifts dview: ipyparallel view used to perform parallel computing @@ -2254,16 +2665,12 @@ def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_spl Exception 'The movie contains nans. Nans are not allowed!' """ + corrected_slicer = slice(subidx.start, subidx.stop, subidx.step * 10) + m = cm.load(fname, var_name_hdf5=var_name_hdf5, subindices=corrected_slicer) - - dims, T = cm.source_extraction.cnmf.utilities.get_file_size(fname, var_name_hdf5=var_name_hdf5) - - - - if T < 3000: - corrected_slicer = slice(subidx.start, subidx.stop, subidx.step * 10) + if m.shape[0] < 300: m = cm.load(fname, var_name_hdf5=var_name_hdf5, subindices=corrected_slicer) - elif T < 5000: + elif m.shape[0] < 500: corrected_slicer = slice(subidx.start, subidx.stop, subidx.step * 5) m = cm.load(fname, var_name_hdf5=var_name_hdf5, subindices=corrected_slicer) else: @@ -2280,9 +2687,14 @@ def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_spl if gSig_filt is not None: m = cm.movie( np.array([high_pass_filter_space(m_, gSig_filt) for m_ in m])) - - template = caiman.motion_correction.bin_median( - m.motion_correct(max_shifts[1], max_shifts[0], template=None)[0]) + if is3D: + # TODO - motion_correct_3d needs to be implemented in movies.py + template = caiman.motion_correction.bin_median_3d(m) # motion_correct_3d has not been implemented yet - instead initialize to just median image +# template = caiman.motion_correction.bin_median_3d( +# m.motion_correct_3d(max_shifts[2], max_shifts[1], max_shifts[0], template=None)[0]) + else: + template = caiman.motion_correction.bin_median( + m.motion_correct(max_shifts[1], max_shifts[0], template=None)[0]) new_templ = template if add_to_movie is None: @@ -2309,9 +2721,11 @@ def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_spl dview=dview, save_movie=save_movie, base_name=os.path.split( fname)[-1][:-4] + '_rig_', subidx = subidx, num_splits=num_splits_to_process, shifts_opencv=shifts_opencv, nonneg_movie=nonneg_movie, gSig_filt=gSig_filt, - use_cuda=use_cuda, border_nan=border_nan, var_name_hdf5=var_name_hdf5) - - new_templ = np.nanmedian(np.dstack([r[-1] for r in res_rig]), -1) + use_cuda=use_cuda, border_nan=border_nan, var_name_hdf5=var_name_hdf5, is3D=is3D) + if is3D: + new_templ = np.nanmedian(np.stack([r[-1] for r in res_rig]), 0) + else: + new_templ = np.nanmedian(np.dstack([r[-1] for r in res_rig]), -1) if gSig_filt is not None: new_templ = high_pass_filter_space(new_templ, gSig_filt) @@ -2324,6 +2738,10 @@ def motion_correct_batch_rigid(fname, max_shifts, dview=None, splits=56, num_spl shift_info, idxs, tmpl = rr templates.append(tmpl) shifts += [[sh[0][0], sh[0][1]] for sh in shift_info[:len(idxs)]] + # if is3D: + # shifts += [[sh[0][0], sh[0][1], sh[0][2]] for sh in shift_info[:len(idxs)]] + # else: + # shifts += [[sh[0][0], sh[0][1]] for sh in shift_info[:len(idxs)]] return fname_tot_rig, total_template, templates, shifts @@ -2331,7 +2749,7 @@ def motion_correct_batch_pwrigid(fname, max_shifts, strides, overlaps, add_to_mo dview=None, upsample_factor_grid=4, max_deviation_rigid=3, splits=56, num_splits_to_process=None, num_iter=1, template=None, shifts_opencv=False, save_movie=False, nonneg_movie=False, gSig_filt=None, - use_cuda=False, border_nan=True, var_name_hdf5='mov'): + use_cuda=False, border_nan=True, var_name_hdf5='mov', is3D=False): """ Function that perform memory efficient hyper parallelized rigid motion corrections while also saving a memory mappable file @@ -2340,10 +2758,10 @@ def motion_correct_batch_pwrigid(fname, max_shifts, strides, overlaps, add_to_mo name of the movie to motion correct. It should not contain nans. All the loadable formats from CaImAn are acceptable strides: tuple - strides of patches along x and y + strides of patches along x and y (and z if 3D) overlaps: - overlaps of patches along x and y. exmaple. If strides = (64,64) and overlaps (32,32) patches will be (96,96) + overlaps of patches along x and y (and z if 3D). example: If strides = (64,64) and overlaps (32,32) patches will be (96,96) newstrides: tuple overlaps after upsampling @@ -2352,7 +2770,7 @@ def motion_correct_batch_pwrigid(fname, max_shifts, strides, overlaps, add_to_mo strides after upsampling max_shifts: tuple - x and y maximum allowd shifts + x and y maximum allowed shifts (and z if 3D) dview: ipyparallel view used to perform parallel computing @@ -2421,7 +2839,7 @@ def motion_correct_batch_pwrigid(fname, max_shifts, strides, overlaps, add_to_mo upsample_factor_grid=upsample_factor_grid, order='F', dview=dview, save_movie=save_movie, base_name=os.path.split(fname)[-1][:-4] + '_els_', num_splits=num_splits_to_process, shifts_opencv=shifts_opencv, nonneg_movie=nonneg_movie, gSig_filt=gSig_filt, - use_cuda=use_cuda, border_nan=border_nan, var_name_hdf5=var_name_hdf5) + use_cuda=use_cuda, border_nan=border_nan, var_name_hdf5=var_name_hdf5, is3D=is3D) new_templ = np.nanmedian(np.dstack([r[-1] for r in res_el]), -1) if gSig_filt is not None: @@ -2431,17 +2849,25 @@ def motion_correct_batch_pwrigid(fname, max_shifts, strides, overlaps, add_to_mo templates = [] x_shifts = [] y_shifts = [] + z_shifts = [] coord_shifts = [] for rr in res_el: shift_info_chunk, idxs_chunk, tmpl_chunk = rr templates.append(tmpl_chunk) for shift_info, _ in zip(shift_info_chunk, idxs_chunk): - total_shift, _, xy_grid = shift_info - x_shifts.append(np.array([sh[0] for sh in total_shift])) - y_shifts.append(np.array([sh[1] for sh in total_shift])) - coord_shifts.append(xy_grid) + if is3D: + total_shift, _, xyz_grid = shift_info + x_shifts.append(np.array([sh[0] for sh in total_shift])) + y_shifts.append(np.array([sh[1] for sh in total_shift])) + z_shifts.append(np.array([sh[2] for sh in total_shift])) + coord_shifts.append(xyz_grid) + else: + total_shift, _, xy_grid = shift_info + x_shifts.append(np.array([sh[0] for sh in total_shift])) + y_shifts.append(np.array([sh[1] for sh in total_shift])) + coord_shifts.append(xy_grid) - return fname_tot_els, total_template, templates, x_shifts, y_shifts, coord_shifts + return fname_tot_els, total_template, templates, x_shifts, y_shifts, z_shifts, coord_shifts #%% in parallel @@ -2467,7 +2893,7 @@ def tile_and_correct_wrapper(params): img_name, out_fname, idxs, shape_mov, template, strides, overlaps, max_shifts,\ add_to_movie, max_deviation_rigid, upsample_factor_grid, newoverlaps, newstrides, \ - shifts_opencv, nonneg_movie, gSig_filt, is_fiji, use_cuda, border_nan, var_name_hdf5 = params + shifts_opencv, nonneg_movie, gSig_filt, is_fiji, use_cuda, border_nan, var_name_hdf5, is3D = params name, extension = os.path.splitext(img_name)[:2] extension = extension.lower() @@ -2486,13 +2912,26 @@ def tile_and_correct_wrapper(params): # elif extension == '.avi': # imgs = cm.load(img_name, subindices=np.array(idxs)) - imgs = cm.load(img_name, subindices=idxs, var_name_hdf5=var_name_hdf5) - + imgs = cm.load(img_name, subindices=idxs, var_name_hdf5=var_name_hdf5,is3D=is3D) +# if is3D: +# imgs = np.transpose(imgs,(1,0,2,3)) mc = np.zeros(imgs.shape, dtype=np.float32) for count, img in enumerate(imgs): if count % 10 == 0: logging.debug(count) - mc[count], total_shift, start_step, xy_grid = tile_and_correct(img, template, strides, overlaps, max_shifts, + if is3D: + mc[count], total_shift, start_step, xyz_grid = tile_and_correct_3d(img, template, strides, overlaps, max_shifts, + add_to_movie=add_to_movie, newoverlaps=newoverlaps, + newstrides=newstrides, + upsample_factor_grid=upsample_factor_grid, + upsample_factor_fft=10, show_movie=False, + max_deviation_rigid=max_deviation_rigid, + shifts_opencv=shifts_opencv, gSig_filt=gSig_filt, + use_cuda=use_cuda, border_nan=border_nan) + shift_info.append([total_shift, start_step, xyz_grid]) + + else: + mc[count], total_shift, start_step, xy_grid = tile_and_correct(img, template, strides, overlaps, max_shifts, add_to_movie=add_to_movie, newoverlaps=newoverlaps, newstrides=newstrides, upsample_factor_grid=upsample_factor_grid, @@ -2500,7 +2939,7 @@ def tile_and_correct_wrapper(params): max_deviation_rigid=max_deviation_rigid, shifts_opencv=shifts_opencv, gSig_filt=gSig_filt, use_cuda=use_cuda, border_nan=border_nan) - shift_info.append([total_shift, start_step, xy_grid]) + shift_info.append([total_shift, start_step, xy_grid]) if out_fname is not None: outv = np.memmap(out_fname, mode='r+', dtype=np.float32, @@ -2519,7 +2958,7 @@ def motion_correction_piecewise(fname, splits, strides, overlaps, add_to_movie=0 max_shifts=(12, 12), max_deviation_rigid=3, newoverlaps=None, newstrides=None, upsample_factor_grid=4, order='F', dview=None, save_movie=True, base_name=None, subidx = None, num_splits=None, shifts_opencv=False, nonneg_movie=False, gSig_filt=None, - use_cuda=False, border_nan=True, var_name_hdf5='mov'): + use_cuda=False, border_nan=True, var_name_hdf5='mov', is3D=False): """ """ @@ -2529,9 +2968,12 @@ def motion_correction_piecewise(fname, splits, strides, overlaps, add_to_movie=0 is_fiji = False dims, T = cm.source_extraction.cnmf.utilities.get_file_size(fname, var_name_hdf5=var_name_hdf5) - d1, d2 = dims + if is3D: + d1, d2, d3 = dims + else: + d1, d2 = dims - if isinstance(splits, int): + if type(splits) is int: if subidx is None: rng = range(T) else: @@ -2542,13 +2984,15 @@ def motion_correction_piecewise(fname, splits, strides, overlaps, add_to_movie=0 else: idxs = splits save_movie = False - if template is None: raise Exception('Not implemented') - shape_mov = (d1 * d2, T) + if is3D: + shape_mov = (d1 * d2 * d3, T) + else: + shape_mov = (d1 * d2, T) + dims = d1, d2 - dims = d1, d2 if num_splits is not None: idxs = np.array(idxs)[np.random.randint(0, len(idxs), num_splits)] save_movie = False @@ -2566,12 +3010,11 @@ def motion_correction_piecewise(fname, splits, strides, overlaps, add_to_movie=0 fname_tot = None pars = [] - for idx in idxs: pars.append([fname, fname_tot, idx, shape_mov, template, strides, overlaps, max_shifts, np.array( add_to_movie, dtype=np.float32), max_deviation_rigid, upsample_factor_grid, newoverlaps, newstrides, shifts_opencv, nonneg_movie, gSig_filt, is_fiji, - use_cuda, border_nan, var_name_hdf5]) + use_cuda, border_nan, var_name_hdf5, is3D]) if dview is not None: logging.info('** Starting parallel motion correction **') diff --git a/caiman/source_extraction/cnmf/params.py b/caiman/source_extraction/cnmf/params.py index f065dbce6..c81bbafcf 100644 --- a/caiman/source_extraction/cnmf/params.py +++ b/caiman/source_extraction/cnmf/params.py @@ -476,7 +476,10 @@ def __init__(self, fnames=None, dims=None, dxy=(1, 1), gSig_filt: int or None, default: None size of kernel for high pass spatial filtering in 1p data. If None no spatial filtering is performed - + + is3D: bool, default: False + flag for 3D recordings for motion correction + max_deviation_rigid: int, default: 3 maximum deviation in pixels between rigid shifts and shifts of individual patches @@ -704,8 +707,9 @@ def __init__(self, fnames=None, dims=None, dxy=(1, 1), } self.motion = { - 'border_nan': 'copy', # flag for allowing NaN in the boundaries + 'border_nan': 'copy', # flag for allowing NaN in the boundaries 'gSig_filt': None, # size of kernel for high pass spatial filtering in 1p data + 'is3D': False, # flag for 3D recordings for motion correction 'max_deviation_rigid': 3, # maximum deviation between rigid and non-rigid 'max_shifts': (6, 6), # maximum shifts per dimension (in pixels) 'min_mov': None, # minimum value of movie @@ -730,7 +734,10 @@ def __init__(self, fnames=None, dims=None, dxy=(1, 1), if self.data['fnames'] is not None: if isinstance(self.data['fnames'], str): self.data['fnames'] = [self.data['fnames']] - T = get_file_size(self.data['fnames'], var_name_hdf5=self.data['var_name_hdf5'])[1] + if self.motion['is3D']: + T = get_file_size(self.data['fnames'], var_name_hdf5=self.data['var_name_hdf5'])[0][0] + else: + T = get_file_size(self.data['fnames'], var_name_hdf5=self.data['var_name_hdf5'])[1] if len(self.data['fnames']) > 1: T = T[0] num_splits = T//max(self.motion['num_frames_split'],10) diff --git a/use_cases/NWB/demo_3d_motion_correction_pipeline.py b/use_cases/NWB/demo_3d_motion_correction_pipeline.py new file mode 100644 index 000000000..0806951a9 --- /dev/null +++ b/use_cases/NWB/demo_3d_motion_correction_pipeline.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +""" +This script follows closely the demo_pipeline_nwb.py script but for a 3d +dataset and only through the motion correction process. +""" + +import cv2 +import glob +import logging +import matplotlib.pyplot as plt +import numpy as np +import os + +try: + cv2.setNumThreads(0) +except: + pass + +try: + if __IPYTHON__: + # this is used for debugging purposes only. allows to reload classes + # when changed + get_ipython().magic('load_ext autoreload') + get_ipython().magic('autoreload 2') +except NameError: + pass + + +import caiman as cm +from caiman.motion_correction import MotionCorrect +from caiman.source_extraction.cnmf import cnmf as cnmf +from caiman.source_extraction.cnmf import params as params +from caiman.utils.utils import download_demo +from caiman.paths import caiman_datadir + +# %% +# Set up the logger (optional); change this if you like. +# You can log to a file using the filename parameter, or make the output more +# or less verbose by setting level to logging.DEBUG, logging.INFO, +# logging.WARNING, or logging.ERROR + +logging.basicConfig(format= + "%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s]"\ + "[%(process)d] %(message)s", + level=logging.WARNING) + +#%% +def main(): + pass # For compatibility between running under Spyder and the CLI + +#%% Select file(s) to be processed (download if not present) +# fnames = [os.path.join(caiman_datadir(), 'example_movies/sampled3dMovieRigid.nwb')] + fnames = [os.path.join(caiman_datadir(), 'example_movies/sampled3dMovie2.nwb')] + # filename to be created or processed + # dataset dependent parameters + fr = 5 # imaging rate in frames per second + decay_time = 0.4 # length of a typical transient in seconds + + starting_time = 0. +#%% load the file and save it in the NWB format (if it doesn't exist already) + if not os.path.exists(fnames[0]): +# fnames_orig = [os.path.join(caiman_datadir(), 'example_movies/sampled3dMovieRigid.h5')] # filename to be processed + fnames_orig = [os.path.join(caiman_datadir(), 'example_movies/sampled3dMovie2.h5')] # filename to be processed + orig_movie = cm.load(fnames_orig, fr=fr, is3D=True) + # orig_movie = cm.load_movie_chain(fnames_orig,fr=fr,is3D=True) + + # save file in NWB format with various additional info + orig_movie.save(fnames[0], sess_desc='test', identifier='demo 3d', + exp_desc='demo movie', imaging_plane_description='multi plane', + emission_lambda=520.0, indicator='none', + location='visual cortex', starting_time=starting_time, + experimenter='NAOMi', lab_name='Tank Lab', + institution='Princeton U', + experiment_description='Experiment Description', + session_id='Session 1', + var_name_hdf5='TwoPhotonSeries') +#%% First setup some parameters for data and motion correction + + + # motion correction parameters + dxy = (1., 1., 5.) # spatial resolution in x, y, and z in (um per pixel) + # note the lower than usual spatial resolution here + max_shift_um = (10., 10., 10.) # maximum shift in um + patch_motion_um = (50., 50., 30.) # patch size for non-rigid correction in um +# pw_rigid = False # flag to select rigid vs pw_rigid motion correction + niter_rig = 1 + pw_rigid = True # flag to select rigid vs pw_rigid motion correction + # maximum allowed rigid shift in pixels + max_shifts = [int(a/b) for a, b in zip(max_shift_um, dxy)] + # start a new patch for pw-rigid motion correction every x pixels + strides = tuple([int(a/b) for a, b in zip(patch_motion_um, dxy)]) + # overlap between pathes (size of patch in pixels: strides+overlaps) + overlaps = (24, 24, 4) + # maximum deviation allowed for patch with respect to rigid shifts + max_deviation_rigid = 3 + is3D = True + + mc_dict = { + 'fnames': fnames, + 'fr': fr, + 'decay_time': decay_time, + 'dxy': dxy, + 'pw_rigid': pw_rigid, + 'niter_rig': niter_rig, + 'max_shifts': max_shifts, + 'strides': strides, + 'overlaps': overlaps, + 'max_deviation_rigid': max_deviation_rigid, + 'border_nan': 'copy', + 'var_name_hdf5': 'acquisition/TwoPhotonSeries', + 'is3D': is3D, + 'splits_els': 12, + 'splits_rig': 12 + } + + opts = params.CNMFParams(params_dict=mc_dict) #NOTE: default adjustments of parameters are not set yet, manually setting them now + +# %% play the movie (optional) + # playing the movie using opencv. It requires loading the movie in memory. + # To close the video press q + display_images = True + if display_images: + m_orig = cm.load_movie_chain(fnames, var_name_hdf5=opts.data['var_name_hdf5'],is3D=True) + T, h, w, z = m_orig.shape # Time, plane, height, weight + m_orig = np.reshape(np.transpose(m_orig, (3,0,1,2)), (T*z, h, w)) + ds_ratio = 0.2 + moviehandle = m_orig.resize(1, 1, ds_ratio) + moviehandle.play(q_max=99.5, fr=60, magnification=2) + +# %% start a cluster for parallel processing +# NOTE: ignore dview right now for debugging purposes +# c, dview, n_processes = cm.cluster.setup_cluster( +# backend='local', n_processes=None, single_thread=False) + +# %%% MOTION CORRECTION + # first we create a motion correction object with the specified parameters + mc = MotionCorrect(fnames, dview=None, var_name_hdf5=opts.data['var_name_hdf5'], **opts.get_group('motion')) +# mc = MotionCorrect(fnames, dview=dview, var_name_hdf5=opts.data['var_name_hdf5'], **opts.get_group('motion')) + # note that the file is not loaded in memory + +# %% Run (piecewise-rigid motion) correction using NoRMCorre + mc.motion_correct(save_movie=True) + +# %% compare with original movie + if display_images: + m_orig = cm.load_movie_chain(fnames, var_name_hdf5=opts.data['var_name_hdf5'],is3D=True) + T, h, w, z = m_orig.shape # Time, plane, height, weight + m_orig = np.reshape(np.transpose(m_orig, (3,0,1,2)), (T*z, h, w)) + + m_els = cm.load(mc.mmap_file,is3D=True) + m_els = np.reshape(np.transpose(m_els, (3,0,1,2)), (T*z, h, w)) + + ds_ratio = 0.2 + moviehandle = cm.concatenate([m_orig.resize(1, 1, ds_ratio) - mc.min_mov*mc.nonneg_movie, + m_els.resize(1, 1, ds_ratio)], axis=2) + moviehandle.play(fr=60, q_max=99.5, magnification=2) # press q to exit + +# %% +# This is to mask the differences between running this demo in Spyder +# versus from the CLI +if __name__ == "__main__": + main() \ No newline at end of file