diff --git a/cobrawap/pipeline/stage04_wave_detection/Snakefile b/cobrawap/pipeline/stage04_wave_detection/Snakefile index f761eb0a..d507aaec 100644 --- a/cobrawap/pipeline/stage04_wave_detection/Snakefile +++ b/cobrawap/pipeline/stage04_wave_detection/Snakefile @@ -59,7 +59,7 @@ use rule template as optical_flow with: script = SCRIPTS / 'optical_flow.py' params: params('alpha', 'max_Niter', 'convergence_limit', 'gaussian_sigma', - 'derivative_filter', 'use_phases', config=config) + 'derivative_filter', 'use_phases', 'n_jobs', 'max_padding_iterations', config=config) output: Path('{dir}') / 'optical_flow' / f'optical_flow.{config.NEO_FORMAT}', output_img = Path('{dir}') / 'optical_flow' / f'optical_flow.{config.PLOT_FORMAT}' diff --git a/cobrawap/pipeline/stage04_wave_detection/scripts/optical_flow.py b/cobrawap/pipeline/stage04_wave_detection/scripts/optical_flow.py index f983cdd3..9737bc94 100644 --- a/cobrawap/pipeline/stage04_wave_detection/scripts/optical_flow.py +++ b/cobrawap/pipeline/stage04_wave_detection/scripts/optical_flow.py @@ -10,6 +10,8 @@ import numpy as np from copy import copy import matplotlib.pyplot as plt +from joblib import Parallel, delayed +from tqdm import tqdm from utils.io_utils import load_neo, write_neo, save_plot from utils.parse import none_or_str, str_to_bool from utils.neo_utils import imagesequence_to_analogsignal, analogsignal_to_imagesequence @@ -34,6 +36,10 @@ help='Filter kernel to use for calculating spatial derivatives') CLI.add_argument("--use_phases", nargs='?', type=str_to_bool, default=False, help='whether to use signal phase instead of amplitude') +CLI.add_argument("--max_padding_iterations", nargs='?', type=int, default=1000, + help='maximum number of padding interpolation iterations') +CLI.add_argument("--n_jobs", nargs='?', type=int, default=1, + help='number of parallel horn_schunck_step executions') def horn_schunck_step(frame, next_frame, alpha, max_Niter, convergence_limit, kernelHS, kernelT, kernelX, kernelY, @@ -98,60 +104,81 @@ def compute_derivatives(frame, next_frame, kernelX, kernelY, def horn_schunck(frames, alpha, max_Niter, convergence_limit, kernelHS, kernelT, kernelX, kernelY, - are_phases=False): + are_phases=False, n_jobs=1, max_padding_iterations=1000): nan_channels = np.where(np.bitwise_not(np.isfinite(frames[0]))) - frames = interpolate_empty_sites(frames, are_phases) - - vector_frames = np.zeros(frames.shape, dtype=complex) - - for i, frame in enumerate(frames[:-1]): - next_frame = frames[i+1] - - vector_frames[i] = horn_schunck_step(frame, - next_frame, - alpha=alpha, - max_Niter=max_Niter, - convergence_limit=convergence_limit, - kernelHS=kernelHS, - kernelT=kernelT, - kernelX=kernelX, - kernelY=kernelY, - are_phases=are_phases) - vector_frames[i][nan_channels] = np.nan + np.nan*1j + frames = interpolate_empty_sites(frames, are_phases=are_phases, + max_iters=max_padding_iterations, n_jobs=n_jobs) + + print("Calculating horn schunck steps") + vector_frames = Parallel(n_jobs=n_jobs)( + delayed(horn_schunck_step) + (frames[i], + frames[i+1], + alpha=alpha, + max_Niter=max_Niter, + convergence_limit=convergence_limit, + kernelHS=kernelHS, + kernelT=kernelT, + kernelX=kernelX, + kernelY=kernelY, + are_phases=are_phases) + for i in tqdm(range(len(frames[:-1])), ascii=True)) + + vector_frames = np.asarray(vector_frames, dtype=complex) + vector_frames[:,nan_channels[0],nan_channels[1]] = np.nan + np.nan*1j frames[:,nan_channels[0],nan_channels[1]] = np.nan return vector_frames -def interpolate_empty_sites(frames, are_phases=False): - if np.isfinite(frames).all(): - return frames - dim_y, dim_x = frames[0].shape - grid_y, grid_x = np.meshgrid([-1,0,1],[-1,0,1], indexing='ij') - - for i, frame in enumerate(frames): - new_frame = copy(frame) - while not np.isfinite(new_frame).all(): - y, x = np.where(np.bitwise_not(np.isfinite(new_frame))) +def interpolation_step(frame, grid_y, grid_x, are_phases=False, max_iters=1000): + dim_y, dim_x = frame.shape + if np.isfinite(frame).all(): + return frame + else: + for _ in range(max_iters): + new_frame = copy(frame) + y, x = np.where(np.bitwise_not(np.isfinite(frame))) # loop over nan-sites - for xi, yi in zip(x,y): + for xi, yi in zip(x, y): + neighbours = [] # collect neighbours of each site for dx, dy in zip(grid_x.flatten(), grid_y.flatten()): - xn = xi+dx - yn = yi+dy + xn = xi + dx + yn = yi + dy if (0 <= xn) & (xn < dim_x) & (0 <= yn) & (yn < dim_y): - neighbours.append(frames[i, yn, xn]) + neighbours.append(frame[yn, xn]) # average over neihbour values if np.isfinite(neighbours).any(): if are_phases: - vectors = np.exp(1j*np.array(neighbours)) - new_frame[yi,xi] = np.angle(np.nansum(vectors)) + vectors = np.exp(1j * np.array(neighbours)) + new_frame[yi, xi] = np.angle(np.nansum(vectors)) else: - new_frame[yi,xi] = np.nansum(neighbours) - frames[i] = new_frame - return frames + new_frame[yi, xi] = np.nansum(neighbours) + frame = new_frame + if np.isfinite(frame).all(): + break + return frame + + +def interpolate_empty_sites(frames, are_phases=False, max_iters=1000, n_jobs=1): + if np.isfinite(frames).all(): + return frames + else: + grid_y, grid_x = np.meshgrid([-1, 0, 1], [-1, 0, 1], indexing='ij') + print('Frames interpolation') + frames = Parallel(n_jobs=n_jobs)( + delayed(interpolation_step) + (frames[i], + grid_y=grid_y, + grid_x=grid_x, + are_phases=are_phases, + max_iters=max_iters) + for i in tqdm(range(len(frames)), ascii=True)) + frames = np.asarray(frames) + return frames def smooth_frames(frames, sigma): @@ -259,7 +286,10 @@ def is_phase_signal(signal, use_phases): kernelY=kernel.y, kernelT=kernelT, kernelHS=kernelHS, - are_phases=args.use_phases) + are_phases=args.use_phases, + n_jobs=args.n_jobs, + max_padding_iterations=args.max_padding_iterations) + if np.sum(args.gaussian_sigma): vector_frames = smooth_frames(vector_frames, sigma=args.gaussian_sigma)