diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 81cda212b2..42e9a20f3c 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -200,18 +200,11 @@ def __init__( if peak_amplitudes is not None: peak_amplitudes = peak_amplitudes[peak_mask] - if recording is not None: - sampling_frequency = recording.sampling_frequency - times = recording.get_times(segment_index=segment_index) - else: - times = None - plot_data = dict( peaks=peaks, peak_locations=peak_locations, peak_amplitudes=peak_amplitudes, direction=direction, - times=times, sampling_frequency=sampling_frequency, segment_index=segment_index, depth_lim=depth_lim, @@ -238,10 +231,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - if dp.times is None: + if dp.recording is None: peak_times = dp.peaks["sample_index"] / dp.sampling_frequency else: - peak_times = dp.times[dp.peaks["sample_index"]] + peak_times = dp.recording.sample_index_to_time(dp.peaks["sample_index"], segment_index=dp.segment_index) peak_locs = dp.peak_locations[dp.direction] if dp.scatter_decimate is not None: @@ -340,12 +333,12 @@ def __init__( raise ValueError( "plot drift map : the Motion object is multi-segment you must provide segment_index=XX" ) - - times = recording.get_times() if recording is not None else None + assert recording.get_num_segments() == len( + motion.displacement + ), "The number of segments in the recording must be the same as the number of segments in the motion object" plot_data = dict( sampling_frequency=motion_info["parameters"]["sampling_frequency"], - times=times, segment_index=segment_index, depth_lim=depth_lim, motion_lim=motion_lim,