diff --git a/brainbox/behavior/dlc.py b/brainbox/behavior/dlc.py
index fff439fe5..c5451d561 100644
--- a/brainbox/behavior/dlc.py
+++ b/brainbox/behavior/dlc.py
@@ -1,17 +1,22 @@
 """
 Set of functions to deal with dlc data
 """
-import numpy as np
-import scipy.interpolate as interpolate
 import logging
+import pandas as pd
 import warnings
-from one.api import ONE
+
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+import scipy.interpolate as interpolate
+from scipy.stats import zscore
+
 from ibllib.dsp.smooth import smooth_interpolate_savgol
+from brainbox.processing import bincount2D
+import brainbox.behavior.wheel as bbox_wheel

 logger = logging.getLogger('ibllib')
-one = ONE()
-
 SAMPLING = {'left': 60,
             'right': 150,
             'body': 30}
@@ -19,6 +24,25 @@
             'right': 1,
             'body': 1}

+T_BIN = 0.02  # sec
+WINDOW_LEN = 2  # sec
+WINDOW_LAG = -0.5  # sec
+
+
+# For plotting we use a window around the event that the data is aligned to, starting WINDOW_LAG before
+# and ending WINDOW_LEN after the event
+def plt_window(x):
+    return x + WINDOW_LAG, x + WINDOW_LEN
+
+
+def insert_idx(array, values):
+    idx = np.searchsorted(array, values, side="left")
+    # Choose lower index if insertion would be after last index or if lower index is closer
+    idx[idx == len(array)] -= 1
+    idx[np.where(abs(values - array[idx - 1]) < abs(values - array[idx]))] -= 1
+    # If 0 index was reduced, revert
+    idx[idx == -1] = 0
+    return idx
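# ---- [Editor's note] Illustrative sketch, not part of the patch. ----
# How plt_window and insert_idx (added above) are meant to compose: event times
# become index windows into an evenly sampled time vector. All values below are
# made up for illustration and assume the definitions above are in scope.
import numpy as np

camera_times = np.arange(0, 100, 1 / 60)        # 60 Hz timestamps, e.g. the left camera
events = np.array([10.0, 42.3, 77.7])           # e.g. trials_df['stimOn_times']
start_window, end_window = plt_window(events)   # (events - 0.5 s, events + 2 s)
start_idx = insert_idx(camera_times, start_window)
# windows are then cut to a fixed length rather than found with a second search:
end_idx = start_idx + int(WINDOW_LEN * 60)      # 60 = SAMPLING['left']
# ---- [End editor's note] ----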
+

 def likelihood_threshold(dlc, threshold=0.9):
     """
@@ -184,10 +208,10 @@ def get_smooth_pupil_diameter(diameter_raw, camera, std_thresh=5, nan_thresh=1):
     """
     # set framerate of camera
     if camera == 'left':
-        fr = 60  # set by hardware
+        fr = SAMPLING['left']  # set by hardware
         window = 31  # works well empirically
     elif camera == 'right':
-        fr = 150  # set by hardware
+        fr = SAMPLING['right']  # set by hardware
         window = 75  # works well empirically
     else:
         raise NotImplementedError("camera has to be 'left' or 'right'")
@@ -214,3 +238,301 @@ def get_smooth_pupil_diameter(diameter_raw, camera, std_thresh=5, nan_thresh=1):
             diameter_smoothed[(b + 1):(e + 1)] = np.nan  # offset by 1 due to earlier diff

     return diameter_smoothed
+
+
+def plot_trace_on_frame(frame, dlc_df, cam):
+    """
+    Plots dlc traces as scatter plots on a frame of the video.
+    For left and right videos, also plots the whisker pad and eye and tongue zooms.
+
+    :param frame: np.array, single video frame to plot on
+    :param dlc_df: pd.DataFrame, dlc traces with _x, _y and _likelihood info for each trace
+    :param cam: str, which camera to process ('left', 'right', 'body')
+    :returns: matplotlib.axis
+    """
+    # Define colors
+    colors = {'tail_start': '#636EFA',
+              'nose_tip': '#636EFA',
+              'paw_l': '#EF553B',
+              'paw_r': '#00CC96',
+              'pupil_bottom_r': '#AB63FA',
+              'pupil_left_r': '#FFA15A',
+              'pupil_right_r': '#19D3F3',
+              'pupil_top_r': '#FF6692',
+              'tongue_end_l': '#B6E880',
+              'tongue_end_r': '#FF97FF'}
+    # Threshold the dlc traces
+    dlc_df = likelihood_threshold(dlc_df)
+    # Features without tube
+    features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc_df.keys() if 'tube' not in x])
+    # Normalize the number of points across cameras
+    dlc_df_norm = pd.DataFrame()
+    for feat in features:
+        dlc_df_norm[f'{feat}_x'] = dlc_df[f'{feat}_x'][0::int(SAMPLING[cam] / 10)]
+        dlc_df_norm[f'{feat}_y'] = dlc_df[f'{feat}_y'][0::int(SAMPLING[cam] / 10)]
+        # Scatter
+        plt.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=0.05, s=2, label=feat, c=colors[feat])
+
+    plt.axis('off')
+    plt.imshow(frame, cmap='gray')
+    plt.tight_layout()
+
+    ax = plt.gca()
+    if cam == 'body':
+        plt.title(f'{cam.capitalize()} camera')
+        return ax
+    # For left and right cam plot whisker pad rectangle
+    # heuristic: square with side length half the distance between nose and pupil and anchored on midpoint
+    p_nose = np.array(dlc_df[['nose_tip_x', 'nose_tip_y']].mean())
+    p_pupil = np.array(dlc_df[['pupil_top_r_x', 'pupil_top_r_y']].mean())
+    p_anchor = np.mean([p_nose, p_pupil], axis=0)
+    dist = np.linalg.norm(p_nose - p_pupil)
+    rect = matplotlib.patches.Rectangle((int(p_anchor[0] - dist / 4), int(p_anchor[1])), int(dist / 2), int(dist / 3),
+                                        linewidth=1, edgecolor='lime', facecolor='none')
+    ax.add_patch(rect)
+    # Plot eye region zoom
+    inset_anchor = 0 if cam == 'right' else 0.5
+    ax_ins = ax.inset_axes([inset_anchor, -0.5, 0.5, 0.5])
+    ax_ins.imshow(frame, cmap='gray', origin="lower")
+    for feat in features:
+        ax_ins.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=1, s=0.001, label=feat, c=colors[feat])
+    ax_ins.set_xlim(int(p_pupil[0] - 33 * RESOLUTION[cam] / 2), int(p_pupil[0] + 33 * RESOLUTION[cam] / 2))
+    ax_ins.set_ylim(int(p_pupil[1] + 38 * RESOLUTION[cam] / 2), int(p_pupil[1] - 28 * RESOLUTION[cam] / 2))
+    ax_ins.axis('off')
+    # Plot tongue region zoom
+    p1 = np.array(dlc_df[['tube_top_x', 'tube_top_y']].mean())
+    p2 = np.array(dlc_df[['tube_bottom_x', 'tube_bottom_y']].mean())
+    p_tongue = np.nanmean([p1, p2], axis=0)
+    inset_anchor = 0 if cam == 'left' else 0.5
+    ax_ins = ax.inset_axes([inset_anchor, -0.5, 0.5, 0.5])
+    ax_ins.imshow(frame, cmap='gray', origin="upper")
+    for feat in features:
+        ax_ins.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=1, s=0.001, label=feat, c=colors[feat])
+    ax_ins.set_xlim(int(p_tongue[0] - 60 * RESOLUTION[cam] / 2), int(p_tongue[0] + 100 * RESOLUTION[cam] / 2))
+    ax_ins.set_ylim(int(p_tongue[1] + 60 * RESOLUTION[cam] / 2), int(p_tongue[1] - 100 * RESOLUTION[cam] / 2))
+    ax_ins.axis('off')
+
+    plt.title(f'{cam.capitalize()} camera')
+    return ax
+
+
+def plot_wheel_position(wheel_position, wheel_time, trials_df):
+    """
+    Plots wheel position across trials, colored by which side was chosen
+
+    :param wheel_position: np.array, interpolated wheel position
+    :param wheel_time: np.array, interpolated wheel timestamps
+    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
+    :returns: matplotlib.axis
+    """
+    # Interpolate wheel data
+    wheel_position, wheel_time = bbox_wheel.interpolate_position(wheel_time, wheel_position, freq=1 / T_BIN)
+    # Create a window around the stimulus onset
+    start_window, end_window = plt_window(trials_df['stimOn_times'])
+    # Translating the time window into an index window
+    start_idx = insert_idx(wheel_time, start_window)
+    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
+    # Getting the wheel position for each window, normalize to first value of each window
+    trials_df['wheel_position'] = [wheel_position[start_idx[w]:end_idx[w]] - wheel_position[start_idx[w]]
+                                   for w in range(len(start_idx))]
+    # Plotting
+    times = np.arange(len(trials_df['wheel_position'][0])) * T_BIN + WINDOW_LAG
+    for side, label, color in zip([-1, 1], ['right', 'left'], ['darkred', '#1f77b4']):
+        side_df = trials_df[trials_df['choice'] == side]
+        for idx in side_df.index:
+            plt.plot(times, side_df.loc[idx, 'wheel_position'], c=color, alpha=0.5, linewidth=0.05)
+        plt.plot(times, side_df['wheel_position'].mean(), c=color, linewidth=2, label=f'{label} turn')
+
+    plt.axvline(x=0, linestyle='--', c='k', label='stimOn')
+    plt.axhline(y=-0.26, linestyle='--', c='g', label='reward')
+    plt.ylim([-0.27, 0.27])
+    plt.xlabel('time [sec]')
+    plt.ylabel('wheel position [rad]')
+    plt.legend(loc='center right')
+    plt.title('Wheel position')
+    plt.tight_layout()
+
+    return plt.gca()
+
+
+def _bin_window_licks(lick_times, trials_df):
+    """
+    Helper function to bin and window the lick times and get them into trials df for plotting
+
+    :param lick_times: np.array, timestamps of lick events
+    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial)
+    :returns: pd.DataFrame with binned, windowed lick times for plotting
+    """
+    # Bin the licks
+    lick_bins, bin_times, _ = bincount2D(lick_times, np.ones(len(lick_times)), T_BIN)
+    lick_bins = np.squeeze(lick_bins)
+    start_window, end_window = plt_window(trials_df['feedback_times'])
+    # Translating the time window into an index window
+    start_idx = insert_idx(bin_times, start_window)
+    end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
+    # Get the binned licks for each window
+    trials_df['lick_bins'] = [lick_bins[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
+    # Remove windows that exceed the bins
+    trials_df['end_idx'] = end_idx
+    trials_df = trials_df[trials_df['end_idx'] <= len(lick_bins)]
+    return trials_df
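# ---- [Editor's note] Illustrative sketch, not part of the patch. ----
# The binning step of _bin_window_licks in isolation: bincount2D with a constant
# y vector collapses to a 1D histogram of lick counts per T_BIN-sized time bin.
# The lick times are random stand-in data.
import numpy as np
from brainbox.processing import bincount2D

lick_times = np.sort(np.random.uniform(0, 100, 500))  # stand-in for licks.times
counts, bin_times, _ = bincount2D(lick_times, np.ones(len(lick_times)), 0.02)
lick_bins = np.squeeze(counts)  # a single row, since all events share one y value
# ---- [End editor's note] ----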
+
+
+def plot_lick_hist(lick_times, trials_df):
+    """
+    Plots histogram of lick events aligned to feedback time, separately for correct and incorrect trials
+
+    :param lick_times: np.array, timestamps of lick events
+    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
+                      'feedbackType' (1 for correct, -1 for incorrect trials)
+    :returns: matplotlib.axis
+    """
+    licks_df = _bin_window_licks(lick_times, trials_df)
+    # Plot
+    times = np.arange(len(licks_df['lick_bins'][0])) * T_BIN + WINDOW_LAG
+    correct = licks_df[licks_df['feedbackType'] == 1]['lick_bins']
+    incorrect = licks_df[licks_df['feedbackType'] == -1]['lick_bins']
+    plt.plot(times, pd.DataFrame.from_dict(dict(zip(correct.index, correct.values))).mean(axis=1),
+             c='k', label='correct trial')
+    plt.plot(times, pd.DataFrame.from_dict(dict(zip(incorrect.index, incorrect.values))).mean(axis=1),
+             c='gray', label='incorrect trial')
+    plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
+    plt.title('Lick events')
+    plt.xlabel('time [sec]')
+    plt.ylabel('lick events [a.u.]')
+    plt.legend(loc='lower right')
+    return plt.gca()
+
+
+def plot_lick_raster(lick_times, trials_df):
+    """
+    Plots lick raster for correct trials
+
+    :param lick_times: np.array, timestamps of lick events
+    :param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
+                      'feedbackType' (1 for correct, -1 for incorrect trials)
+    :returns: matplotlib.axis
+    """
+    licks_df = _bin_window_licks(lick_times, trials_df)
+    plt.imshow(list(licks_df[licks_df['feedbackType'] == 1]['lick_bins']), aspect='auto',
+               extent=[-0.5, 1.5, len(licks_df['lick_bins'][0]), 0], cmap='gray_r')
+    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
+    plt.ylabel('trials')
+    plt.xlabel('time [sec]')
+    plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
+    plt.title('Lick events per correct trial')
+    plt.tight_layout()
+    return plt.gca()
+
+
+def plot_motion_energy_hist(camera_dict, trials_df):
+    """
+    Plots mean motion energy of given cameras, aligned to stimulus onset.
+
+    :param camera_dict: dict, one key for each camera to be plotted (e.g. 'left'), value is another dict with items
+                        'motion_energy' (np.array, motion energy calculated from this camera) and
+                        'times' (np.array, camera timestamps)
+    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
+    :returns: matplotlib.axis
+    """
+    colors = {'left': '#bd7a98',
+              'right': '#2b6f39',
+              'body': '#035382'}
+
+    start_window, end_window = plt_window(trials_df['stimOn_times'])
+    for cam in camera_dict.keys():
+        try:
+            motion_energy = zscore(camera_dict[cam]['motion_energy'], nan_policy='omit')
+            start_idx = insert_idx(camera_dict[cam]['times'], start_window)
+            end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
+            me_all = [motion_energy[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
+            times = np.arange(len(me_all[0])) / SAMPLING[cam] + WINDOW_LAG
+            me_mean = np.mean(me_all, axis=0)
+            me_std = np.std(me_all, axis=0) / np.sqrt(len(me_all))
+            plt.plot(times, me_mean, label=f'{cam} cam', color=colors[cam], linewidth=2)
+            plt.fill_between(times, me_mean + me_std, me_mean - me_std, color=colors[cam], alpha=0.2)
+        except AttributeError:
+            logger.warning(f"Cannot load motion energy and times data for {cam} camera")
+
+    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
+    plt.ylabel('z-scored motion energy [a.u.]')
+    plt.xlabel('time [sec]')
+    plt.axvline(x=0, label='stimOn', linestyle='--', c='k')
+    plt.legend(loc='lower right')
+    plt.title('Motion Energy')
+    return plt.gca()
+
+
+def plot_speed_hist(dlc_df, cam_times, trials_df, feature='paw_r', cam='left', legend=True):
+    """
+    Plots speed histogram of a given dlc feature, aligned to stimulus onset, separately for correct and
+    incorrect trials
+
+    :param dlc_df: pd.DataFrame, dlc traces with _x, _y and _likelihood info for each trace
+    :param cam_times: np.array, camera timestamps
+    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
+    :param feature: str, feature with trace in dlc_df for which to plot speed hist, default is 'paw_r'
+    :param cam: str, camera to use ('body', 'left', 'right'), default is 'left'
+    :param legend: bool, whether to add legend to the plot, default is True
+    :returns: matplotlib.axis
+    """
+    # Threshold the dlc traces
+    dlc_df = likelihood_threshold(dlc_df)
+    # Get speeds
+    speeds = get_speed(dlc_df, cam_times, camera=cam, feature=feature)
+    # Windows aligned to stimulus onset
+    start_window, end_window = plt_window(trials_df['stimOn_times'])
+    start_idx = insert_idx(cam_times, start_window)
+    end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
+    # Add speeds to trials_df
+    trials_df[f'speed_{feature}'] = [speeds[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
+    # Plot
+    times = np.arange(len(trials_df[f'speed_{feature}'][0])) / SAMPLING[cam] + WINDOW_LAG
+    # Need to expand the series of lists into a dataframe first, for the nan skipping to work
+    correct = trials_df[trials_df['feedbackType'] == 1][f'speed_{feature}']
+    incorrect = trials_df[trials_df['feedbackType'] == -1][f'speed_{feature}']
+    plt.plot(times, pd.DataFrame.from_dict(dict(zip(correct.index, correct.values))).mean(axis=1),
+             c='k', label='correct trial')
+    plt.plot(times, pd.DataFrame.from_dict(dict(zip(incorrect.index, incorrect.values))).mean(axis=1),
+             c='gray', label='incorrect trial')
+    plt.axvline(x=0, label='stimOn', linestyle='--', c='r')
+    plt.title(f'{feature.split("_")[0].capitalize()} speed')
+    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
+    plt.xlabel('time [sec]')
+    plt.ylabel('speed [px/sec]')
+    if legend:
+        plt.legend()
+
+    return plt.gca()
+
+
+def plot_pupil_diameter_hist(pupil_diameter, cam_times, trials_df, cam='left'):
+    """
+    Plots histogram of pupil diameter aligned to stimulus onset and feedback time.
+
+    :param pupil_diameter: np.array, (smoothed) pupil diameter estimate
+    :param cam_times: np.array, camera timestamps
+    :param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial) and
+                      'feedback_times' (time of feedback for each trial)
+    :param cam: str, camera to use ('body', 'left', 'right'), default is 'left'
+    :returns: matplotlib.axis
+    """
+    for align_to, color in zip(['stimOn_times', 'feedback_times'], ['red', 'purple']):
+        start_window, end_window = plt_window(trials_df[align_to])
+        start_idx = insert_idx(cam_times, start_window)
+        end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
+        # Per trial norm
+        pupil_all = [zscore(list(pupil_diameter[start_idx[i]:end_idx[i]])) for i in range(len(start_idx))]
+        pupil_all_norm = [trial - trial[0] for trial in pupil_all]
+
+        pupil_mean = np.nanmean(pupil_all_norm, axis=0)
+        pupil_std = np.nanstd(pupil_all_norm, axis=0) / np.sqrt(len(pupil_all_norm))
+        times = np.arange(len(pupil_all_norm[0])) / SAMPLING[cam] + WINDOW_LAG
+
+        plt.plot(times, pupil_mean, label=align_to.split("_")[0], color=color)
+        plt.fill_between(times, pupil_mean + pupil_std, pupil_mean - pupil_std, color=color, alpha=0.5)
+    plt.axvline(x=0, linestyle='--', c='k')
+    plt.title('Pupil diameter')
+    plt.xlabel('time [sec]')
+    plt.xticks([-0.5, 0, 0.5, 1, 1.5])
+    plt.ylabel('pupil diameter [px]')
+    plt.legend(loc='lower right', title='aligned to')
diff --git a/brainbox/io/one.py b/brainbox/io/one.py
index 69feee567..b0dc624cf 100644
--- a/brainbox/io/one.py
+++ b/brainbox/io/one.py
@@ -12,7 +12,7 @@
 from ibllib.io import spikeglx
 from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
-from ibllib.ephys.neuropixel import SITES_COORDINATES, TIP_SIZE_UM
+from ibllib.ephys.neuropixel import SITES_COORDINATES, TIP_SIZE_UM, trace_header
 from ibllib.atlas import atlas
 from ibllib.atlas import AllenAtlas
 from ibllib.pipes import histology
@@ -221,7 +221,7 @@ def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None):
     return channels


-def channel_locations_interpolation(channels_aligned, channels, brain_regions=None):
+def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None):
     """
     oftentimes the channel map for different spike sorters may be different so interpolate the alignment onto
     if there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field
@@ -238,6 +238,10 @@ def channel_locations_interpolation(channels_aligned, channels, brain_regions=None):
         'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um'
     :return: Bunch or dictionary of channels with brain coordinates keys
     """
+    NEUROPIXEL_VERSION = 1
+    h = trace_header(version=NEUROPIXEL_VERSION)
+    if channels is None:
+        channels = {'localCoordinates': np.c_[h['x'], h['y']]}
     nch = channels['localCoordinates'].shape[0]
     if set(['x', 'y', 'z']).issubset(set(channels_aligned.keys())):
         channels_aligned = _channels_bunch2alf(channels_aligned)
@@ -245,9 +249,7 @@ def channel_locations_interpolation(channels_aligned, channels, brain_regions=None):
         aligned_depths = channels_aligned['localCoordinates'][:, 1]
     else:  # this is an edge case for a few spike sorting sessions
         assert channels_aligned['mlapdv'].shape[0] == 384
-        NEUROPIXEL_VERSION = 1
-        from ibllib.ephys.neuropixel import trace_header
-        aligned_depths = trace_header(version=NEUROPIXEL_VERSION)['y']
+        aligned_depths = h['y']
     depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True)
     depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True)
     channels['mlapdv'] = np.zeros((nch, 3))
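# ---- [Editor's note] Illustrative sketch, not part of the patch. ----
# What the new channels=None default amounts to: when no channel map is passed,
# the function falls back to the Neuropixel 1 site geometry from trace_header.
import numpy as np
from ibllib.ephys.neuropixel import trace_header

h = trace_header(version=1)  # NP1 site geometry
default_channels = {'localCoordinates': np.c_[h['x'], h['y']]}
# channel_locations_interpolation(channels_aligned) is now equivalent to
# channel_locations_interpolation(channels_aligned, default_channels)
# ---- [End editor's note] ----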
diff --git a/ibllib/dsp/voltage.py b/ibllib/dsp/voltage.py
index d3e93f1cb..8c30986e6 100644
--- a/ibllib/dsp/voltage.py
+++ b/ibllib/dsp/voltage.py
@@ -109,6 +109,37 @@ def fk(x, si=.002, dx=1, vbounds=None, btype='highpass', ntr_pad=0, ntr_tap=None
     return xf / gain


+def car(x, collection=None, lagc=300, butter_kwargs=None):
+    """
+    Applies common average referencing with optional automatic gain control
+    :param x: the input array to be filtered, shape (ntraces, ns); axis=0 is the spatial dimension,
+     axis=1 the temporal dimension
+    :param collection: vector of grouping labels (one per trace), to apply the referencing per group
+    :param lagc: window size in samples for time domain automatic gain control (no agc otherwise)
+    :param butter_kwargs: filtering parameters: defaults: {'N': 3, 'Wn': 0.1, 'btype': 'highpass'}
+    :return:
+    """
+    if butter_kwargs is None:
+        butter_kwargs = {'N': 3, 'Wn': 0.1, 'btype': 'highpass'}
+    if collection is not None:
+        xout = np.zeros_like(x)
+        for c in np.unique(collection):
+            sel = collection == c
+            xout[sel, :] = car(x=x[sel, :], collection=None, lagc=lagc,
+                               butter_kwargs=butter_kwargs)
+        return xout
+
+    # apply agc and keep the gain in handy
+    if not lagc:
+        xf = np.copy(x)
+        gain = 1
+    else:
+        xf, gain = agc(x, wl=lagc, si=1.0)
+    # apply CAR and then un-apply the gain
+    xf = xf - np.median(xf, axis=0)
+    return xf / gain
+
+
 def kfilt(x, collection=None, ntr_pad=0, ntr_tap=None, lagc=300, butter_kwargs=None):
     """
     Applies a butterworth filter on the 0-axis with tapering / padding
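# ---- [Editor's note] Illustrative sketch, not part of the patch. ----
# Minimal use of car() on synthetic data: a sine shared by all traces is a
# common-mode artefact, which the median reference removes. All values are made
# up; the 100 ms agc window mirrors what destripe's 'car' option uses below.
import numpy as np

fs = 30000
nc, ns = 64, fs  # one second of 64 synthetic traces
rng = np.random.default_rng(0)
x = rng.standard_normal((nc, ns)) * 1e-6
x = x + np.sin(2 * np.pi * 50 * np.arange(ns) / fs) * 1e-5  # broadcast to all traces
x_car = car(x, lagc=int(0.1 * fs))
# ---- [End editor's note] ----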
@@ -209,6 +240,7 @@ def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, channel_labels=None):
         True: deduces the bad channels from the data provided
     :param butter_kwargs: (optional, None) butterworth params, see the code for the defaults dict
     :param k_kwargs: (optional, None) K-filter params, see the code for the defaults dict
+     can also be set to 'car', in which case the median across channels will be subtracted
     :return: x, filtered array
     """
     if butter_kwargs is None:
@@ -216,6 +248,11 @@ def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, channel_labels=None):
     if k_kwargs is None:
         k_kwargs = {'ntr_pad': 60, 'ntr_tap': 0, 'lagc': 3000,
                     'butter_kwargs': {'N': 3, 'Wn': 0.01, 'btype': 'highpass'}}
+        spatial_fcn = lambda dat: kfilt(dat, **k_kwargs)  # noqa
+    elif isinstance(k_kwargs, dict):
+        spatial_fcn = lambda dat: kfilt(dat, **k_kwargs)  # noqa
+    else:
+        spatial_fcn = lambda dat: car(dat, lagc=int(0.1 * fs))  # noqa
     h = neuropixel.trace_header(version=neuropixel_version)
     if channel_labels is True:
         channel_labels, _ = detect_bad_channels(x, fs)
@@ -231,9 +268,9 @@ def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, channel_labels=None):
     if channel_labels is not None:
         x = interpolate_bad_channels(x, channel_labels, h)
         inside_brain = np.where(channel_labels != 3)[0]
-        x[inside_brain, :] = kfilt(x[inside_brain, :], **k_kwargs)  # apply the k-filter
+        x[inside_brain, :] = spatial_fcn(x[inside_brain, :])  # apply the spatial filter
     else:
-        x = kfilt(x, **k_kwargs)
+        x = spatial_fcn(x)
     return x

@@ -245,7 +282,7 @@ def decompress_destripe_cbin(sr_file, output_file=None, h=None, wrot=None, appen
     Production version with optimized FFTs - requires pyfftw
     :param sr: seismic reader object (spikeglx.Reader)
     :param output_file: (optional, defaults to .bin extension of the compressed bin file)
-    :param h: (optional)
+    :param h: (optional) neuropixel trace header. Dictionary with key 'sample_shift'
     :param wrot: (optional) whitening matrix [nc x nc] or amplitude scalar to apply to the output
     :param append: (optional, False) for chronic recordings, append to end of file
     :param nc_out: (optional, True) saves non selected channels (synchronisation trace) in output
@@ -273,7 +310,7 @@ def decompress_destripe_cbin(sr_file, output_file=None, h=None, wrot=None, appen
     k_kwargs = {'ntr_pad': 60, 'ntr_tap': 0, 'lagc': 3000,
                 'butter_kwargs': {'N': 3, 'Wn': 0.01, 'btype': 'highpass'}}
     h = neuropixel.trace_header(version=1) if h is None else h
-    ncv = h['x'].size  # number of channels
+    ncv = h['sample_shift'].size  # number of channels
     output_file = sr.file_bin.with_suffix('.bin') if output_file is None else output_file
     assert output_file != sr.file_bin
     taper = np.r_[0, scipy.signal.windows.cosine((SAMPLES_TAPER - 1) * 2), 0]
@@ -504,7 +541,7 @@ def nxcor(x, ref):

     xcor = channels_similarity(raw)
     fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units; uV ** 2 / Hz
-    sos_hp = scipy.signal.butter(**{'N': 3, 'Wn': 1000 / fs / 2, 'btype': 'highpass'}, output='sos')
+    sos_hp = scipy.signal.butter(**{'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}, output='sos')
     hf = scipy.signal.sosfiltfilt(sos_hp, raw)
     xcorf = channels_similarity(hf)

@@ -513,7 +550,7 @@ def nxcor(x, ref):
         'rms_raw': rms(raw),  # very similar to the rms after butterworth filter
         'xcor_hf': detrend(xcor, 11),
         'xcor_lf': xcorf - detrend(xcorf, 11) - 1,
-        'psd_hf': np.mean(psd[:, fscale > 12000], axis=-1),
+        'psd_hf': np.mean(psd[:, fscale > (fs / 2 * 0.8)], axis=-1),  # 80% of Nyquist
     })

     # make recommendation
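# ---- [Editor's note] Illustrative check, not part of the patch. ----
# scipy.signal.butter expects Wn normalized to the Nyquist frequency when no fs
# argument is given, i.e. Wn = f_cut / (fs / 2) = f_cut * 2 / fs — hence the fix above.
import scipy.signal

fs, f_cut = 30000, 300
wn = f_cut / (fs / 2)  # == 300 / fs * 2 == 0.02
sos = scipy.signal.butter(N=3, Wn=wn, btype='highpass', output='sos')
# the old expression, 1000 / fs / 2, evaluates to ~0.0167 for fs=30 kHz, i.e. a
# ~250 Hz corner rather than the frequency it appears to target
# ---- [End editor's note] ----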
diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py
index 7b30cc6c0..736596cc5 100644
--- a/ibllib/io/extractors/camera.py
+++ b/ibllib/io/extractors/camera.py
@@ -553,7 +553,7 @@ def groom_pin_state(gpio, audio, ts, tolerance=2., display=False, take='first',
         downs = ts[high2low] - ts[high2low][0]
         offsets = audio_times[1::2] - audio_times[1]
         assigned = attribute_times(offsets, downs, tol=tolerance, take=take)
-        unassigned = np.setdiff1d(np.arange(onsets.size), assigned[assigned > -1])
+        unassigned = np.setdiff1d(np.arange(offsets.size), assigned[assigned > -1])
         if unassigned.size > 0:
             _logger.debug(f'{unassigned.size} audio TTL falls were not detected by the camera')
         # Check that all pin state downticks could be attributed to an offset TTL
diff --git a/ibllib/io/spikeglx.py b/ibllib/io/spikeglx.py
index f1e11f41e..aa871c8f4 100644
--- a/ibllib/io/spikeglx.py
+++ b/ibllib/io/spikeglx.py
@@ -478,6 +478,11 @@ def _geometry_from_meta(meta_data):
     """
     cm = _map_channels_from_meta(meta_data)
     major_version = _get_neuropixel_major_version_from_meta(meta_data)
+    if cm is None:
+        _logger.warning("Meta data doesn't have geometry (snsShankMap field), returning defaults")
+        th = neuropixel.trace_header(version=major_version)
+        th['flag'] = th['x'] * 0 + 1.
+        return th
     th = cm.copy()
     if major_version == 1:
         # the spike sorting channel maps have a flipped version of the channel map
diff --git a/ibllib/pipes/ephys_preprocessing.py b/ibllib/pipes/ephys_preprocessing.py
index 0aa8539dd..87c445c57 100644
--- a/ibllib/pipes/ephys_preprocessing.py
+++ b/ibllib/pipes/ephys_preprocessing.py
@@ -9,6 +9,7 @@
 import numpy as np
 import pandas as pd
+
 import one.alf.io as alfio

 from ibllib.misc import check_nvidia_driver
@@ -24,6 +25,8 @@
 from ibllib.qc.camera import run_all_qc as run_camera_qc
 from ibllib.qc.dlc import DlcQC
 from ibllib.dsp import rms
+from ibllib.plots.figures import dlc_qc_plot
+from ibllib.plots.snapshot import ReportSnapshot
 from brainbox.behavior.dlc import likelihood_threshold, get_licks, get_pupil_diameter, get_smooth_pupil_diameter

 _logger = logging.getLogger("ibllib")
@@ -871,85 +874,104 @@ class EphysPostDLC(tasks.Task):
                            ('_ibl_rightCamera.times.npy', 'alf', True),
                            ('_ibl_leftCamera.times.npy', 'alf', True),
                            ('_ibl_bodyCamera.times.npy', 'alf', True)],
+        # More files are required for all panels of the DLC QC plot to function
         'output_files': [('_ibl_leftCamera.features.pqt', 'alf', True),
                         ('_ibl_rightCamera.features.pqt', 'alf', True),
                         ('licks.times.npy', 'alf', True)]
     }

-    def _run(self, overwrite=False, run_qc=True):
+    def _run(self, overwrite=False, run_qc=True, plot_qc=True):
         # Check if output files exist locally
         exist, output_files = self.assert_expected(self.signature['output_files'], silent=True)
         if exist and not overwrite:
-            _logger.warning('EphysPostDLC outputs exist and overwrite=False, skipping.')
-            return output_files
-        if exist and overwrite:
-            _logger.warning('EphysPostDLC outputs exist and overwrite=True, overwriting existing outputs.')
-        # Find all available dlc traces and dlc times
-        dlc_files = list(Path(self.session_path).joinpath('alf').glob('_ibl_*Camera.dlc.*'))
-        for dlc_file in dlc_files:
-            _logger.debug(dlc_file)
-        output_files = []
-        combined_licks = []
-
-        for dlc_file in dlc_files:
-            # Catch unforeseen exceptions and move on to next cam
-            try:
-                cam = label_from_path(dlc_file)
-                # load dlc trace and camera times
-                dlc = pd.read_parquet(dlc_file)
-                dlc_thresh = likelihood_threshold(dlc, 0.9)
-                # try to load respective camera times
-                try:
-                    dlc_t = np.load(next(Path(self.session_path).joinpath('alf').glob(f'_ibl_{cam}Camera.times.*npy')))
-                    times = True
-                except StopIteration:
-                    _logger.error(f'No camera.times found for {cam} camera. '
-                                  f'Computations using camera.times will be skipped')
-                    self.status = -1
-                    times = False
-
-                # These features are only computed from left and right cam
-                if cam in ('left', 'right'):
-                    features = pd.DataFrame()
-                    # If camera times are available, get the lick time stamps for combined array
-                    if times:
-                        _logger.info(f"Computing lick times for {cam} camera.")
-                        combined_licks.append(get_licks(dlc_thresh, dlc_t))
-                    else:
-                        _logger.warning(f"Skipping lick times for {cam} camera as no camera.times available.")
-                    # Compute pupil diameter, raw and smoothed
-                    _logger.info(f"Computing raw pupil diameter for {cam} camera.")
-                    features['pupilDiameter_raw'] = get_pupil_diameter(dlc_thresh)
-                    _logger.info(f"Computing smooth pupil diameter for {cam} camera.")
-                    features['pupilDiameter_smooth'] = get_smooth_pupil_diameter(features['pupilDiameter_raw'], cam)
-                    # Safe to pqt
-                    features_file = Path(self.session_path).joinpath('alf', f'_ibl_{cam}Camera.features.pqt')
-                    features.to_parquet(features_file)
-                    output_files.append(features_file)
-
-                # For all cams, compute DLC qc if times available
-                if times and run_qc:
-                    # Setting download_data to False because at this point the data should be there
-                    qc = DlcQC(self.session_path, side=cam, one=self.one, download_data=False)
-                    qc.run(update=True)
-                else:
-                    if not times:
-                        _logger.warning(f"Skipping QC for {cam} camera as no camera.times available")
-                    if not run_qc:
-                        _logger.warning(f"Skipping QC for {cam} camera as run_qc=False")
-            except BaseException:
-                _logger.error(traceback.format_exc())
-                self.status = -1
-                continue
-
-        # Combined lick times
-        if len(combined_licks) > 0:
-            lick_times_file = Path(self.session_path).joinpath('alf', 'licks.times.npy')
-            np.save(lick_times_file, sorted(np.concatenate(combined_licks)))
-            output_files.append(lick_times_file)
-        else:
-            _logger.warning("No lick times computed for this session.")
+            _logger.warning('EphysPostDLC outputs exist and overwrite=False, skipping computations of outputs.')
+        else:
+            if exist and overwrite:
+                _logger.warning('EphysPostDLC outputs exist and overwrite=True, overwriting existing outputs.')
+            # Find all available dlc traces and dlc times
+            dlc_files = list(Path(self.session_path).joinpath('alf').glob('_ibl_*Camera.dlc.*'))
+            for dlc_file in dlc_files:
+                _logger.debug(dlc_file)
+            output_files = []
+            combined_licks = []
+
+            for dlc_file in dlc_files:
+                # Catch unforeseen exceptions and move on to next cam
+                try:
+                    cam = label_from_path(dlc_file)
+                    # load dlc trace and camera times
+                    dlc = pd.read_parquet(dlc_file)
+                    dlc_thresh = likelihood_threshold(dlc, 0.9)
+                    # try to load respective camera times
+                    try:
+                        dlc_t = np.load(next(Path(self.session_path).joinpath('alf').glob(f'_ibl_{cam}Camera.times.*npy')))
+                        times = True
+                    except StopIteration:
+                        _logger.error(f'No camera.times found for {cam} camera. '
+                                      f'Computations using camera.times will be skipped')
+                        self.status = -1
+                        times = False
+
+                    # These features are only computed from left and right cam
+                    if cam in ('left', 'right'):
+                        features = pd.DataFrame()
+                        # If camera times are available, get the lick time stamps for combined array
+                        if times:
+                            _logger.info(f"Computing lick times for {cam} camera.")
+                            combined_licks.append(get_licks(dlc_thresh, dlc_t))
+                        else:
+                            _logger.warning(f"Skipping lick times for {cam} camera as no camera.times available.")
+                        # Compute pupil diameter, raw and smoothed
+                        _logger.info(f"Computing raw pupil diameter for {cam} camera.")
+                        features['pupilDiameter_raw'] = get_pupil_diameter(dlc_thresh)
+                        _logger.info(f"Computing smooth pupil diameter for {cam} camera.")
+                        features['pupilDiameter_smooth'] = get_smooth_pupil_diameter(features['pupilDiameter_raw'], cam)
+                        # Save to pqt
+                        features_file = Path(self.session_path).joinpath('alf', f'_ibl_{cam}Camera.features.pqt')
+                        features.to_parquet(features_file)
+                        output_files.append(features_file)
+
+                    # For all cams, compute DLC qc if times available
+                    if times and run_qc:
+                        # Setting download_data to False because at this point the data should be there
+                        qc = DlcQC(self.session_path, side=cam, one=self.one, download_data=False)
+                        qc.run(update=True)
+                    else:
+                        if not times:
+                            _logger.warning(f"Skipping QC for {cam} camera as no camera.times available")
+                        if not run_qc:
+                            _logger.warning(f"Skipping QC for {cam} camera as run_qc=False")
+                except BaseException:
+                    _logger.error(traceback.format_exc())
+                    self.status = -1
+                    continue
+
+            # Combined lick times
+            if len(combined_licks) > 0:
+                lick_times_file = Path(self.session_path).joinpath('alf', 'licks.times.npy')
+                np.save(lick_times_file, sorted(np.concatenate(combined_licks)))
+                output_files.append(lick_times_file)
+            else:
+                _logger.warning("No lick times computed for this session.")
+
+        if plot_qc:
+            _logger.info("Creating DLC QC plot")
+            try:
+                session_id = self.one.path2eid(self.session_path)
+                fig_path = self.session_path.joinpath('snapshot', 'dlc_qc_plot.png')
+                if not fig_path.parent.exists():
+                    fig_path.parent.mkdir(parents=True, exist_ok=True)
+                fig = dlc_qc_plot(self.one.path2eid(self.session_path), one=self.one)
+                fig.savefig(fig_path)
+                snp = ReportSnapshot(self.session_path, session_id, one=self.one)
+                snp.outputs = [fig_path]
+                snp.register_images(widths=['orig'],
+                                    function=str(dlc_qc_plot.__module__) + '.' + str(dlc_qc_plot.__name__))
+            except BaseException:
+                _logger.error('Could not create and/or upload DLC QC Plot')
+                _logger.error(traceback.format_exc())
+                self.status = -1

         return output_files
diff --git a/ibllib/pipes/histology.py b/ibllib/pipes/histology.py
index 54761c229..cd5461aca 100644
--- a/ibllib/pipes/histology.py
+++ b/ibllib/pipes/histology.py
@@ -347,11 +347,11 @@ def create_channel_dict(traj, brain_locations):
     channel_dict = []
     for i in np.arange(brain_locations.id.size):
         channel_dict.append({
-            'x': brain_locations.xyz[i, 0] * 1e6,
-            'y': brain_locations.xyz[i, 1] * 1e6,
-            'z': brain_locations.xyz[i, 2] * 1e6,
-            'axial': brain_locations.axial[i],
-            'lateral': brain_locations.lateral[i],
+            'x': np.float64(brain_locations.xyz[i, 0] * 1e6),
+            'y': np.float64(brain_locations.xyz[i, 1] * 1e6),
+            'z': np.float64(brain_locations.xyz[i, 2] * 1e6),
+            'axial': np.float64(brain_locations.axial[i]),
+            'lateral': np.float64(brain_locations.lateral[i]),
             'brain_region': int(brain_locations.id[i]),
             'trajectory_estimate': traj['id']
         })
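# ---- [Editor's note] Illustrative repro, not part of the patch. ----
# Presumable motivation for the np.float64 casts above: numpy float32 scalars are
# not JSON serializable, while np.float64 subclasses Python's float and is.
import json
import numpy as np

json.dumps({'x': np.float64(1.5)})    # fine, np.float64 is a float subclass
# json.dumps({'x': np.float32(1.5)})  # raises TypeError: not JSON serializable
# ---- [End editor's note] ----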
diff --git a/ibllib/plots/figures.py b/ibllib/plots/figures.py
index cab212d3b..77662c959 100644
--- a/ibllib/plots/figures.py
+++ b/ibllib/plots/figures.py
@@ -1,33 +1,114 @@
 """
 Module that produces figures, usually for the extraction pipeline
 """
+import logging
 from pathlib import Path
+from string import ascii_uppercase

 import numpy as np
+import pandas as pd
 import scipy.signal
+import matplotlib.pyplot as plt

+from ibllib.dsp import voltage
+from ibllib.plots.snapshot import ReportSnapshot
+from one.api import ONE
+import one.alf.io as alfio
+from one.alf.exceptions import ALFObjectNotFound
+from ibllib.io.video import get_video_frame, url_from_eid
+from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \
+    plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist
+
+logger = logging.getLogger('ibllib')
+
+
+class BadChannelsAp(ReportSnapshot):
+    """
+    Plots raw electrophysiology AP band
+    :param session_path: session path
+    :param probe_id: str, UUID of the probe insertion for which to create the plot
+    :param **kwargs: keyword arguments passed to tasks.Task
+    """
+    signature = {
+        'input_files': [],  # see setUp method for declaration of inputs
+        'output_files': []  # see setUp method for declaration of outputs
+    }
+
+    def __init__(self, session_path, probe_id, **kwargs):
+        self.content_type = 'probeinsertion'
+        self.pid = probe_id
+        super(BadChannelsAp, self).__init__(session_path, probe_id, content_type=self.content_type, **kwargs)
+
+    @staticmethod
+    def spike_sorting_signature(pname=None):
+        pname = pname if pname is not None else "probe*"
+        input_signature = [('*ap.meta', f'raw_ephys_data/{pname}', True),
+                           ('*ap.ch', f'raw_ephys_data/{pname}', False),
+                           ('*ap.cbin', f'raw_ephys_data/{pname}', False)]
+        output_signature = [('destripe.png', f'snapshot/{pname}', True),
+                            ('highpass.png', f'snapshot/{pname}', True)]
+        return input_signature, output_signature
+
+    def _run(self):
+        """runs for initiated PID, streams data, destripe and check bad channels"""
+        assert self.pid
+        SNAPSHOT_LABEL = "raw_ephys_bad_channels"
+        eid, pname = self.one.pid2eid(self.pid)
+        output_directory = self.session_path.joinpath('snapshot', pname)
+        output_files = list(output_directory.glob(f'{SNAPSHOT_LABEL}*'))
+        if len(output_files) == 4:
+            return output_files
+        output_directory.mkdir(exist_ok=True, parents=True)
+        from brainbox.io.spikeglx import stream
+        T0 = 60 * 30
+        sr, t0 = stream(self.pid, T0, nsecs=1, one=self.one)
+        raw = sr[:, :-sr.nsync].T
+        channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
+        _, _, output_files = ephys_bad_channels(
+            raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features,
+            title=SNAPSHOT_LABEL, destripe=True, save_dir=output_directory)
+        return output_files
+
+
+def ephys_bad_channels(raw, fs, channel_labels, channel_features, title="ephys_bad_channels", save_dir=None,
+                       destripe=False, eqcs=None):
+    nc, ns = raw.shape
+    rl = ns / fs
+    if fs >= 2600:  # AP band
+        ylim_rms = [0, 100]
+        ylim_psd_hf = [0, 0.1]
+        eqc_xrange = [450, 500]
+        butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
+        eqc_gain = - 90
+    else:
+        # we are working with the LFP
+        ylim_rms = [0, 1000]
+        ylim_psd_hf = [0, 1]
+        eqc_xrange = [450, 950]
+        butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
+        eqc_gain = - 78

-def ephys_bad_channels(raw, fs, channel_labels, channel_features, title="ephys_bad_channels", save_dir=None):
-    nc = raw.shape[0]
     inoisy = np.where(channel_labels == 2)[0]
     idead = np.where(channel_labels == 1)[0]
     ioutside = np.where(channel_labels == 3)[0]
     from easyqc.gui import viewseis
-    import matplotlib.pyplot as plt

     # display voltage traces
-    eqcs = []
-    butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
+    eqcs = [] if eqcs is None else eqcs
     # butterworth, for display only
     sos = scipy.signal.butter(**butter_kwargs, output='sos')
     butt = scipy.signal.sosfiltfilt(sos, raw)
-    eqcs.append(viewseis(butt.T, si=1 / fs * 1e3, title='butt', taxis=0))
+    eqcs.append(viewseis(butt.T, si=1 / fs * 1e3, title='highpass', taxis=0))
+    if destripe:
+        dest = voltage.destripe(raw, fs=fs, channel_labels=channel_labels)
+        eqcs.append(viewseis(dest.T, si=1 / fs * 1e3, title='destripe', taxis=0))
+        eqcs.append(viewseis((butt - dest).T, si=1 / fs * 1e3, title='difference', taxis=0))
     for eqc in eqcs:
-        y, x = np.meshgrid(ioutside, np.linspace(0, 1 * 1e3, 500))
+        y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
         eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
-        y, x = np.meshgrid(inoisy, np.linspace(0, 1 * 1e3, 500))
+        y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
         eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
-        y, x = np.meshgrid(idead, np.linspace(0, 1 * 1e3, 500))
+        y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
         eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')
     # display features
     fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
@@ -35,11 +116,11 @@ def ephys_bad_channels(raw, fs, channel_labels, channel_features, title="ephys_bad_channels",
     # fig.suptitle(f"pid:{pid}, \n eid:{eid}, \n {one.eid2path(eid).parts[-3:]}, {pname}")
     fig.suptitle(title)
     axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
-    axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=[0, 100])
+    axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)

     axs[1, 0].plot(channel_features['psd_hf'])
     axs[1, 0].plot(inoisy, np.minimum(channel_features['psd_hf'][inoisy], 0.0999), 'xr')
-    axs[1, 0].set(title='PSD above 12kHz', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=[0, 0.1])
+    axs[1, 0].set(title='PSD above 80% Nyquist', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=ylim_psd_hf)
     axs[1, 0].legend = ['psd', 'noisy']

     axs[0, 1].plot(channel_features['xcor_hf'])
@@ -54,19 +135,231 @@ def ephys_bad_channels(raw, fs, channel_labels, channel_features, title="ephys_bad_channels",
     axs[1, 1].imshow(20 * np.log10(psd).T, extent=[0, nc - 1, fscale[0], fscale[-1]], origin='lower', aspect='auto',
                      vmin=-50, vmax=-20)
     axs[1, 1].set(title='PSD', xlabel='channel number', ylabel="Frequency (Hz)")
-    axs[1, 1].plot(idead, idead * 0 + fs / 2 - 500, 'xb')
-    axs[1, 1].plot(inoisy, inoisy * 0 + fs / 2 - 500, 'xr')
-    axs[1, 1].plot(ioutside, ioutside * 0 + fs / 2 - 500, 'xy')
+    axs[1, 1].plot(idead, idead * 0 + fs / 4, 'xb')
+    axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
+    axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')

-    eqcs[0].ctrl.set_gain(-90)
+    eqcs[0].ctrl.set_gain(eqc_gain)
     eqcs[0].resize(1960, 1200)
-    eqcs[0].viewBox_seismic.setXRange(450, 500)
+    eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
     eqcs[0].viewBox_seismic.setYRange(0, nc)
     eqcs[0].ctrl.propagate()

     if save_dir is not None:
-        fig.savefig(Path(save_dir).joinpath(f"{title}.png"))
+        output_files = [Path(save_dir).joinpath(f"{title}.png")]
+        fig.savefig(output_files[0])
         for eqc in eqcs:
-            eqc.grab().save(str(Path(save_dir).joinpath(f"{title}_data_{eqc.windowTitle()}.png")))
+            output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
+            eqc.grab().save(str(output_files[-1]))
+        return fig, eqcs, output_files
+    else:
+        return fig, eqcs
+
+
+def raw_destripe(raw, fs, t0, i_plt, n_plt,
+                 fig=None, axs=None, savedir=None, detect_badch=True,
+                 SAMPLE_SKIP=200, DISPLAY_TIME=0.05, N_CHAN=384,
+                 MIN_X=-0.00011, MAX_X=0.00011):
+    '''
+    :param raw: raw ephys data, Nc x Ns (channels by samples)
+    :param fs: sampling freq (Hz) of the raw ephys data
+    :param t0: time (s) of ephys sample beginning from session start
+    :param i_plt: index of the subplot in which to display the image (starts from 0, has to be < n_plt)
+    :param n_plt: total number of subplots on the figure
+    :param fig: figure handle
+    :param axs: axis handle
+    :param savedir: filename, including directory, to save figure to
+    :param detect_badch: boolean, whether or not to detect bad channels
+    :param SAMPLE_SKIP: number of samples to skip at the beginning of the ephys sample for display
+    :param DISPLAY_TIME: time (s) to display
+    :param N_CHAN: number of expected channels on the probe
+    :param MIN_X: min voltage for color range
+    :param MAX_X: max voltage for color range
+    :return: fig, axs
+    '''
+
+    # Import
+    from ibllib.dsp import voltage
+    from ibllib.plots import Density
+
+    # Init fig
+    if fig is None or axs is None:
+        fig, axs = plt.subplots(nrows=1, ncols=n_plt, figsize=(14, 5), gridspec_kw={'width_ratios': [4] * n_plt})
+
+    if i_plt > len(axs) - 1:  # Error
+        raise ValueError(f'The given increment of subplot ({i_plt+1}) '
+                         f'is larger than the total number of subplots ({len(axs)})')
+
+    [nc, ns] = raw.shape
+    if nc == N_CHAN:
+        destripe = voltage.destripe(raw, fs=fs)
+        X = destripe[:, :int(DISPLAY_TIME * fs)].T
+        Xs = X[SAMPLE_SKIP:].T  # Remove artifact at beginning
+        Tplot = Xs.shape[1] / fs
+
+        # PLOT RAW DATA
+        d = Density(-Xs, fs=fs, taxis=1, ax=axs[i_plt], vmin=MIN_X, vmax=MAX_X, cmap='Greys')  # noqa
+        axs[i_plt].set_ylabel('')
+        axs[i_plt].set_xlim((0, Tplot * 1e3))
+        axs[i_plt].set_ylim((0, nc))
+
+        # Init title
+        title_plt = f't0 = {int(t0 / 60)} min'
+
+        if detect_badch:
+            # Detect and remove bad channels prior to spike detection
+            labels, xfeats = voltage.detect_bad_channels(raw, fs)
+            idx_badchan = np.where(labels != 0)[0]
+            # Plot bad channels on raw data
+            x, y = np.meshgrid(idx_badchan, np.linspace(0, Tplot * 1e3, 20))
+            axs[i_plt].plot(y.flatten(), x.flatten(), '.k', markersize=1)
+            # Append title
+            title_plt += f', n={len(idx_badchan)} bad ch'
+
+        # Set title
+        axs[i_plt].title.set_text(title_plt)
+
+    else:
+        axs[i_plt].title.set_text(f'CANNOT DESTRIPE, N CHAN = {nc}')
+
+    # Amend some axis style
+    if i_plt > 0:
+        axs[i_plt].set_yticklabels('')
+
+    # Fig layout
+    fig.tight_layout()
+    if savedir is not None:
+        fig.savefig(fname=savedir)
+
+    return fig, axs
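# ---- [Editor's note] Illustrative sketch, not part of the patch. ----
# Typical raw_destripe usage: several snippets on a shared figure. A real caller
# would read chunks from a spikeglx.Reader; random data stands in here.
import numpy as np

fs = 30000
rng = np.random.default_rng(42)
fig, axs = None, None
for i, t0 in enumerate([600, 1800, 3000]):  # seconds from session start
    raw = rng.standard_normal((384, int(0.1 * fs))) * 1e-5  # stand-in ephys chunk
    fig, axs = raw_destripe(raw, fs=fs, t0=t0, i_plt=i, n_plt=3, fig=fig, axs=axs)
# ---- [End editor's note] ----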
+
+
+def dlc_qc_plot(eid, one=None):
+    """
+    Creates DLC QC plot.
+    Data is searched first locally, then on Alyx. Panels that lack required data are skipped.
+
+    Required data to create all panels:
+        'raw_video_data/_iblrig_bodyCamera.raw.mp4',
+        'raw_video_data/_iblrig_leftCamera.raw.mp4',
+        'raw_video_data/_iblrig_rightCamera.raw.mp4',
+        'alf/_ibl_bodyCamera.dlc.pqt',
+        'alf/_ibl_leftCamera.dlc.pqt',
+        'alf/_ibl_rightCamera.dlc.pqt',
+        'alf/_ibl_bodyCamera.times.npy',
+        'alf/_ibl_leftCamera.times.npy',
+        'alf/_ibl_rightCamera.times.npy',
+        'alf/_ibl_leftCamera.features.pqt',
+        'alf/rightROIMotionEnergy.position.npy',
+        'alf/leftROIMotionEnergy.position.npy',
+        'alf/bodyROIMotionEnergy.position.npy',
+        'alf/_ibl_trials.choice.npy',
+        'alf/_ibl_trials.feedbackType.npy',
+        'alf/_ibl_trials.feedback_times.npy',
+        'alf/_ibl_trials.stimOn_times.npy',
+        'alf/_ibl_wheel.position.npy',
+        'alf/_ibl_wheel.timestamps.npy',
+        'alf/licks.times.npy',
+
+    :param eid: Session ID
+    :param one: ONE instance, if None is given, default ONE is instantiated
+    :returns: Matplotlib figure
+    """
+
+    one = one or ONE()
+    data = {}
+    # Camera data
+    for cam in ['left', 'right', 'body']:
+        # Load a single frame for each video, first check if data is local, otherwise stream
+        video_path = one.eid2path(eid).joinpath('raw_video_data', f'_iblrig_{cam}Camera.raw.mp4')
+        if not video_path.exists():
+            try:
+                video_path = url_from_eid(eid, one=one)[cam]
+            except KeyError:
+                logger.warning(f"No raw video data found for {cam} camera, some DLC QC plots have to be skipped.")
+                data[f'{cam}_frame'] = None
+        try:
+            data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0]
+        except TypeError:
+            logger.warning(f"Could not load video frame for {cam} camera, some DLC QC plots have to be skipped.")
+            data[f'{cam}_frame'] = None
+        # Load other video associated data
+        for feat in ['dlc', 'times', 'features', 'ROIMotionEnergy']:
+            # Check locally first, then try to load from alyx, if nothing works, set to None
+            local_file = list(one.eid2path(eid).joinpath('alf').glob(f'*{cam}Camera.{feat}*'))
+            alyx_file = [ds for ds in one.list_datasets(eid) if f'{cam}Camera.{feat}' in ds]
+            if feat == 'features' and cam in ['body', 'right']:
+                continue
+            elif len(local_file) > 0:
+                data[f'{cam}_{feat}'] = alfio.load_file_content(local_file[0])
+            elif len(alyx_file) > 0:
+                data[f'{cam}_{feat}'] = one.load_dataset(eid, alyx_file[0])
+            else:
+                logger.warning(f"Could not load _ibl_{cam}Camera.{feat}, some DLC QC plots have to be skipped.")
+                data[f'{cam}_{feat}'] = None
+    # Session data
+    for alf_object in ['trials', 'wheel', 'licks']:
+        try:
+            data[f'{alf_object}'] = alfio.load_object(one.eid2path(eid).joinpath('alf'), alf_object)
+            continue
+        except ALFObjectNotFound:
+            pass
+        try:
+            data[f'{alf_object}'] = one.load_object(eid, alf_object)
+        except ALFObjectNotFound:
+            logger.warning(f"Could not load {alf_object} object for session {eid}, some plots have to be skipped.")
+            data[f'{alf_object}'] = None
+    # Simplify to what we actually need
+    data['licks'] = data['licks'].times if data['licks'] else None
+    data['left_pupil'] = data['left_features'].pupilDiameter_smooth if data['left_features'] is not None else None
+    data['wheel_time'] = data['wheel'].timestamps if data['wheel'] is not None else None
+    data['wheel_position'] = data['wheel'].position if data['wheel'] is not None else None
+    if data['trials']:
+        data['trials'] = pd.DataFrame(
+            {k: data['trials'][k] for k in ['stimOn_times', 'feedback_times', 'choice', 'feedbackType']})
+        # Discard nan events and too long trials
+        data['trials'] = data['trials'].dropna()
+        data['trials'] = data['trials'].drop(
+            data['trials'][(data['trials']['feedback_times'] - data['trials']['stimOn_times']) > 10].index)
+    # List panels: axis functions and inputs
+    panels = [(plot_trace_on_frame, {'frame': data['left_frame'], 'dlc_df': data['left_dlc'], 'cam': 'left'}),
+              (plot_trace_on_frame, {'frame': data['right_frame'], 'dlc_df': data['right_dlc'], 'cam': 'right'}),
+              (plot_trace_on_frame, {'frame': data['body_frame'], 'dlc_df': data['body_dlc'], 'cam': 'body'}),
+              (plot_wheel_position,
+               {'wheel_position': data['wheel_position'], 'wheel_time': data['wheel_time'], 'trials_df': data['trials']}),
+              (plot_motion_energy_hist,
+               {'camera_dict': {'left': {'motion_energy': data['left_ROIMotionEnergy'], 'times': data['left_times']},
+                                'right': {'motion_energy': data['right_ROIMotionEnergy'], 'times': data['right_times']},
+                                'body': {'motion_energy': data['body_ROIMotionEnergy'], 'times': data['body_times']}},
+                'trials_df': data['trials']}),
+              (plot_speed_hist,
+               {'dlc_df': data['left_dlc'], 'cam_times': data['left_times'], 'trials_df': data['trials']}),
+              (plot_speed_hist,
+               {'dlc_df': data['left_dlc'], 'cam_times': data['left_times'], 'trials_df': data['trials'],
+                'feature': 'nose_tip', 'legend': False}),
+              (plot_lick_hist, {'lick_times': data['licks'], 'trials_df': data['trials']}),
+              (plot_lick_raster, {'lick_times': data['licks'], 'trials_df': data['trials']}),
+              (plot_pupil_diameter_hist,
+               {'pupil_diameter': data['left_pupil'], 'cam_times': data['left_times'], 'trials_df': data['trials']})
+              ]
+    # Plotting
+    plt.rcParams.update({'font.size': 10})
+    fig = plt.figure(figsize=(17, 10))
+    for i, panel in enumerate(panels):
+        ax = plt.subplot(2, 5, i + 1)
+        ax.text(-0.1, 1.15, ascii_uppercase[i], transform=ax.transAxes, fontsize=16, fontweight='bold')
+        # Check if any of the inputs is None
+        if any([v is None for v in panel[1].values()]):
+            ax.text(.5, .5, f"Data incomplete\n{panel[0].__name__}", color='r', fontweight='bold',
+                    fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
+            plt.axis('off')
+        else:
+            try:
+                panel[0](**panel[1])
+            except BaseException:
+                ax.text(.5, .5, f'Error in \n{panel[0].__name__}', color='r', fontweight='bold',
+                        fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
+                plt.axis('off')
+    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
-    return fig, eqcs[0]
+    return fig
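# ---- [Editor's note] Illustrative sketch, not part of the patch. ----
# Minimal dlc_qc_plot usage; the eid below is a placeholder and ONE credentials
# are assumed to be configured. Panels with missing inputs are annotated on the
# figure rather than raising.
from one.api import ONE
from ibllib.plots.figures import dlc_qc_plot

one = ONE()
eid = 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx'  # placeholder session UUID
fig = dlc_qc_plot(eid, one=one)
fig.savefig('dlc_qc_plot.png')
# ---- [End editor's note] ----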
""" @@ -23,6 +49,13 @@ def __init__(self, object_id, content_type='session', one=None): self.content_type = content_type self.images = [] + def plot(self): + """ + Placeholder method to be overriden by child object + :return: + """ + pass + def generate_image(self, plt_func, plt_kwargs): """ Takes a plotting function and adds the output to the Snapshot.images list for registration @@ -31,25 +64,37 @@ def generate_image(self, plt_func, plt_kwargs): :param plt_kwargs: Dictionary with keyword arguments for the plotting function """ img_path = plt_func(**plt_kwargs) - self.images.append(img_path) + if isinstance(img_path, list): + self.images.extend(img_path) + else: + self.images.append(img_path) return img_path - def register_image(self, image_file, text='', width=None): + def register_image(self, image_file, text='', json_field=None, width=None): """ Registers an image as a Note, attached to the object specified by Snapshot.object_id :param image_file: Path to the image to to registered - :param text: Text to describe the image, defaults ot empty string + :param text: str, text to describe the image, defaults ot empty string + :param json_field: dict, to be added to the json field of the Note :param width: width to scale the image to, defaults to None (scale to UPLOADED_IMAGE_WIDTH in alyx.settings.py), other options are 'orig' (don't change size) or any integer (scale to width=int, aspect ratios won't be changed) :returns: dict, note as registered in database """ - fig_open = open(image_file, 'rb') + # the protocol is not compatible with byte streaming and json, so serialize the json object here note = { 'user': self.one.alyx.user, 'content_type': self.content_type, 'object_id': self.object_id, - 'text': text, 'width': width} + 'text': text, 'width': width, 'json': json.dumps(json_field)} _logger.info(f'Registering image to {self.content_type} with id {self.object_id}') + # to make sure an eventual note gets deleted with the image call the delete REST endpoint first + current_note = self.one.alyx.rest('notes', 'list', + django=f"object_id,{self.object_id},text,{text},json__name,{text}", + no_cache=True) + if len(current_note) == 1: + self.one.alyx.rest('notes', 'delete', id=current_note[0]['id']) + # Open image for upload + fig_open = open(image_file, 'rb') # Catch error that results from object_id - content_type mismatch try: note_db = self.one.alyx.rest('notes', 'create', data=note, files={'image': fig_open}) @@ -64,7 +109,7 @@ def register_image(self, image_file, text='', width=None): fig_open.close() raise - def register_images(self, image_list=None, texts=[''], widths=[None]): + def register_images(self, image_list=None, texts=None, widths=None, jsons=None): """ Registers a list of images as Notes, attached to the object specified by Snapshot.object_id. The images can be passed as image_list. If None are passed, will try to register the images in Snapshot.images. @@ -74,7 +119,8 @@ def register_images(self, image_list=None, texts=[''], widths=[None]): :param texts: List of text to describe the images. If len(texts)==1, the same text will be used for all images :param widths: List of width to scale the figure to (see Snapshot.register_image). If len(widths)==1, the same width will be used for all images - + :param jsons: List of dictionaries to populate the json field of the note in Alyx. 
If len(jsons)==1, + the same dict will be used for all images :returns: list of dicts, notes as registered in database """ if not image_list or len(image_list) == 0: @@ -84,11 +130,17 @@ def register_images(self, image_list=None, texts=[''], widths=[None]): return else: image_list = self.images + widths = widths or [None] + texts = texts or [''] + jsons = jsons or [None] + if len(texts) == 1: texts = len(image_list) * texts if len(widths) == 1: widths = len(image_list) * widths + if len(jsons) == 1: + jsons = len(image_list) * jsons note_dbs = [] - for figure, text, width in zip(image_list, texts, widths): - note_dbs.append(self.register_image(figure, text=text, width=width)) + for figure, text, width, json_field in zip(image_list, texts, widths, jsons): + note_dbs.append(self.register_image(figure, text=text, width=width, json_field=json_field)) return note_dbs diff --git a/ibllib/tests/test_plots.py b/ibllib/tests/test_plots.py index 278b35e86..5e21055b3 100644 --- a/ibllib/tests/test_plots.py +++ b/ibllib/tests/test_plots.py @@ -11,23 +11,25 @@ from ibllib.tests import TEST_DB from ibllib.plots.snapshot import Snapshot +from ibllib.plots.figures import dlc_qc_plot WIDTH, HEIGHT = 1000, 100 class TestSnapshot(unittest.TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): # Make a small image an store in tmp file - self.tmp_dir = tempfile.TemporaryDirectory() - self.img_file = Path(self.tmp_dir.name).joinpath('test.png') + cls.tmp_dir = tempfile.TemporaryDirectory() + cls.img_file = Path(cls.tmp_dir.name).joinpath('test.png') image = Image.new('RGBA', size=(WIDTH, HEIGHT), color=(155, 0, 0)) - image.save(self.img_file, 'png') + image.save(cls.img_file, 'png') image.close() # set up ONE - self.one = ONE(**TEST_DB) + cls.one = ONE(**TEST_DB) # Collect all notes to delete them later - self.notes = [] + cls.notes = [] def _get_image(self, url): # This is a bit of a hack because when running a the server locally, the request to the media folder fail @@ -91,8 +93,8 @@ def test_register_multiple(self): self.notes.extend(snp.register_images([self.img_file, self.img_file, self.img_file], texts=['first', 'second', 'third'], widths=[None, 'orig', 200])) for i in range(3): - self.assertEqual(self.notes[i]['text'], expected_texts[i]) - img_db = self._get_image(self.notes[i]['image']) + self.assertEqual(self.notes[i - 3]['text'], expected_texts[i]) + img_db = self._get_image(self.notes[i - 3]['image']) with Image.open(img_db) as im: self.assertEqual(im.size, expected_sizes[i]) # Registering multiple figures by adding to self.figures @@ -103,8 +105,8 @@ def test_register_multiple(self): snp.images.extend([self.img_file, self.img_file, self.img_file]) self.notes.extend(snp.register_images(texts=['always the same'], widths=[200])) for i in range(3): - self.assertEqual(self.notes[i + 3]['text'], 'always the same') - img_db = self._get_image(self.notes[i + 3]['image']) + self.assertEqual(self.notes[i - 3]['text'], 'always the same') + img_db = self._get_image(self.notes[i - 3]['image']) with Image.open(img_db) as im: self.assertEqual(im.size, expected_sizes[2]) @@ -123,9 +125,32 @@ def make_img(size, out_path): with Image.open(out_path) as im: self.assertEqual(im.size, (WIDTH, HEIGHT)) - def tearDown(self): + @classmethod + def tearDownClass(cls): # Clean up tmp dir - self.tmp_dir.cleanup() + cls.tmp_dir.cleanup() # Delete all notes - for note in self.notes: - self.one.alyx.rest('notes', 'delete', id=note['id']) + for note in cls.notes: + cls.one.alyx.rest('notes', 'delete', id=note['id']) + + 
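# ---- [Editor's note] Illustrative sketch, not part of the patch. ----
# The extended register_images API with its length-1 broadcasting; the object id
# and file names are placeholders.
from ibllib.plots.snapshot import Snapshot

snp = Snapshot(object_id='00000000-0000-0000-0000-000000000000', content_type='session')
snp.images.extend(['fig1.png', 'fig2.png'])
notes = snp.register_images(texts=['QC figure'], widths=['orig'],
                            jsons=[{'tag': '## report ##'}])  # one dict reused for both
# ---- [End editor's note] ----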
diff --git a/ibllib/tests/test_plots.py b/ibllib/tests/test_plots.py
index 278b35e86..5e21055b3 100644
--- a/ibllib/tests/test_plots.py
+++ b/ibllib/tests/test_plots.py
@@ -11,23 +11,25 @@
 from ibllib.tests import TEST_DB
 from ibllib.plots.snapshot import Snapshot
+from ibllib.plots.figures import dlc_qc_plot

 WIDTH, HEIGHT = 1000, 100


 class TestSnapshot(unittest.TestCase):

-    def setUp(self):
+    @classmethod
+    def setUpClass(cls):
         # Make a small image and store it in a tmp file
-        self.tmp_dir = tempfile.TemporaryDirectory()
-        self.img_file = Path(self.tmp_dir.name).joinpath('test.png')
+        cls.tmp_dir = tempfile.TemporaryDirectory()
+        cls.img_file = Path(cls.tmp_dir.name).joinpath('test.png')
         image = Image.new('RGBA', size=(WIDTH, HEIGHT), color=(155, 0, 0))
-        image.save(self.img_file, 'png')
+        image.save(cls.img_file, 'png')
         image.close()
         # set up ONE
-        self.one = ONE(**TEST_DB)
+        cls.one = ONE(**TEST_DB)
         # Collect all notes to delete them later
-        self.notes = []
+        cls.notes = []

     def _get_image(self, url):
         # This is a bit of a hack because when running the server locally, the request to the media folder fails
@@ -91,8 +93,8 @@ def test_register_multiple(self):
         self.notes.extend(snp.register_images([self.img_file, self.img_file, self.img_file],
                                               texts=['first', 'second', 'third'], widths=[None, 'orig', 200]))
         for i in range(3):
-            self.assertEqual(self.notes[i]['text'], expected_texts[i])
-            img_db = self._get_image(self.notes[i]['image'])
+            self.assertEqual(self.notes[i - 3]['text'], expected_texts[i])
+            img_db = self._get_image(self.notes[i - 3]['image'])
             with Image.open(img_db) as im:
                 self.assertEqual(im.size, expected_sizes[i])
         # Registering multiple figures by adding to self.figures
@@ -103,8 +105,8 @@ def test_register_multiple(self):
         snp.images.extend([self.img_file, self.img_file, self.img_file])
         self.notes.extend(snp.register_images(texts=['always the same'], widths=[200]))
         for i in range(3):
-            self.assertEqual(self.notes[i + 3]['text'], 'always the same')
-            img_db = self._get_image(self.notes[i + 3]['image'])
+            self.assertEqual(self.notes[i - 3]['text'], 'always the same')
+            img_db = self._get_image(self.notes[i - 3]['image'])
             with Image.open(img_db) as im:
                 self.assertEqual(im.size, expected_sizes[2])

@@ -123,9 +125,32 @@ def make_img(size, out_path):
         with Image.open(out_path) as im:
             self.assertEqual(im.size, (WIDTH, HEIGHT))

-    def tearDown(self):
+    @classmethod
+    def tearDownClass(cls):
         # Clean up tmp dir
-        self.tmp_dir.cleanup()
+        cls.tmp_dir.cleanup()
         # Delete all notes
-        for note in self.notes:
-            self.one.alyx.rest('notes', 'delete', id=note['id'])
+        for note in cls.notes:
+            cls.one.alyx.rest('notes', 'delete', id=note['id'])
+
+
+class TestDlcQcPlot(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tmp_dir = tempfile.TemporaryDirectory()
+        cls.one = ONE(**TEST_DB)
+
+    @classmethod
+    def tearDownClass(cls):
+        # Clean up tmp dir
+        cls.tmp_dir.cleanup()
+
+    def test_without_inputs(self):
+        eid = '3473f9d2-aa5d-41a6-9048-c65d0b7ab97c'
+        with self.assertLogs('ibllib', 'WARNING'):
+            fig = dlc_qc_plot(eid, self.one)
+        fig_path = Path(self.tmp_dir.name).joinpath('dlc_qc_plot.png')
+        fig.savefig(fig_path)
+        with Image.open(fig_path) as im:
+            self.assertEqual(im.size, (1700, 1000))
diff --git a/release_notes.md b/release_notes.md
index 006767fd6..7fdab6571 100644
--- a/release_notes.md
+++ b/release_notes.md
@@ -1,3 +1,10 @@
+## Release Notes 2.6
+### Release Notes 2.6.0 2021-12-08
+- New ReportSnapshot class
+- DLC QC plots, as part of EphysPostDLC task
+- BadChannelsAP plots for ephys QC
+- Fix typo in camera extractor
+
 ## Release Notes 2.5
 ### Release Notes 2.5.1 2021-11-25
 - SpikeSorting task overwrites old tar file on rerun
diff --git a/setup.py b/setup.py
index 502a290a5..e43e60736 100644
--- a/setup.py
+++ b/setup.py
@@ -24,7 +24,7 @@
 setup(
     name='ibllib',
-    version='2.5.1',
+    version='2.6.0',
     python_requires='>={}.{}'.format(*REQUIRED_PYTHON),
     description='IBL libraries',
     license="MIT",