diff --git a/debugging/alignment_utils.py b/debugging/alignment_utils.py index d6d26ed6ef..939d8a0623 100644 --- a/debugging/alignment_utils.py +++ b/debugging/alignment_utils.py @@ -23,6 +23,7 @@ def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges, """ TODO: assumes 1-segment recording """ + # TODO: this is weird, don't return spatial_bin_edges here... amybe assert.. entire_session_hist, temporal_bin_edges, spatial_bin_edges = \ make_2d_motion_histogram( recording, @@ -35,6 +36,7 @@ def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges, hist_margin_um=None, spatial_bin_edges=spatial_bin_edges, ) + entire_session_hist = entire_session_hist[0] entire_session_hist /= recording.get_duration(segment_index=0) @@ -290,7 +292,7 @@ def prep_recording(recording, plot=False): :param recording: :return: """ - peaks = detect_peaks(recording, method="by_channel") # "locally_exclusive") + peaks = detect_peaks(recording, method="locally_exclusive") peak_locations = localize_peaks(recording, peaks, method="grid_convolution") diff --git a/debugging/all_recordings.pickle b/debugging/all_recordings.pickle index 38e170a6e4..aaf1be330f 100644 Binary files a/debugging/all_recordings.pickle and b/debugging/all_recordings.pickle differ diff --git a/debugging/main.py b/debugging/main.py new file mode 100644 index 0000000000..8d2698eefd --- /dev/null +++ b/debugging/main.py @@ -0,0 +1,66 @@ +from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings +import numpy as np +import plotting +import alignment_utils +import matplotlib.pyplot as plt +import pickle +scalings = [np.ones(10), np.r_[np.zeros(3), np.ones(7)]] + +SAVE = True + +if SAVE: + recordings_list, _ = generate_session_displacement_recordings( + non_rigid_gradient=None, + num_units=35, + recording_durations=(100, 100), + recording_shifts=( + (0, 0), (0, 0), + ), + recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings}, + seed=None, + ) + + peaks_list = [] + peak_locations_list = [] + + for recording in recordings_list: + peaks, peak_locations = alignment_utils.prep_recording( + recording, plot=False, + ) + peaks_list.append(peaks) + peak_locations_list.append(peak_locations) + + # something relatively easy, only 15 units + with open('all_recordings.pickle', 'wb') as handle: + pickle.dump((recordings_list, peaks_list, peak_locations_list), + handle, protocol=pickle.HIGHEST_PROTOCOL) + +with open('all_recordings.pickle', 'rb') as handle: + recordings_list, peaks_list, peak_locations_list = pickle.load(handle) + +bin_um = 2 + +# TODO: own function +min_y = np.min([np.min(locs["y"]) for locs in peak_locations_list]) +max_y = np.max([np.max(locs["y"]) for locs in peak_locations_list]) + +spatial_bin_edges = np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin... +spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges) # TODO: own function + +session_histogram_list = [] +for recording, peaks, peak_locations in zip(recordings_list, peaks_list, peak_locations_list): + + hist, temp, spat = alignment_utils.get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges, log_scale=False) + + session_histogram_list.append( + hist + ) +# TODO: need to get all outputs and check are same size +plotting.SessionAlignmentWidget( + recordings_list, + peaks_list, + peak_locations_list, + session_histogram_list, + histogram_spatial_bin_centers=spatial_bin_centers, +) +plt.show() diff --git a/debugging/plotting.py b/debugging/plotting.py index 3569ce9deb..356cc924ba 100644 --- a/debugging/plotting.py +++ b/debugging/plotting.py @@ -143,7 +143,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax_top.set_title(f"Session {i + 1}") ax_top.set_xlabel(None) - plot = DriftRasterMapWidget( + DriftRasterMapWidget( dp.peaks_list[i], dp.corrected_peak_locations_list[i], recording=dp.recordings_list[i], @@ -163,7 +163,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax = fig.add_subplot(gs[num_rows, :]) - plot = SessionAlignmentHistogramWidget( + SessionAlignmentHistogramWidget( dp.session_histogram_list, dp.histogram_spatial_bin_centers, ax=ax, @@ -242,6 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if isinstance(linewidths, int): linewidths = [linewidths] * num_histograms + # TODO: this leads to quite unexpected behaviours, figure something else out. if spatial_bin_centers is None: num_bins = dp.session_histogram_list[0].size spatial_bin_centers = [np.arange(num_bins)] * num_histograms diff --git a/debugging/session_alignment.py b/debugging/session_alignment.py index 89591d8d27..989093ce5e 100644 --- a/debugging/session_alignment.py +++ b/debugging/session_alignment.py @@ -125,7 +125,7 @@ def estimate_inter_session_displacement( max_y = np.max([np.max(locs["y"]) for locs in peak_locations_list]) spatial_bin_edges = np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin... - spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges) + spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges) # TODO: own function # Estimate an activity histogram per-session all_session_hists = [] # TODO: probably better as a dict diff --git a/debugging/test_session_alignment.py b/debugging/test_session_alignment.py index 385c7c99a9..1269490ffe 100644 --- a/debugging/test_session_alignment.py +++ b/debugging/test_session_alignment.py @@ -46,8 +46,6 @@ # handle the case where the passed recordings are not motion correction recordings. -# 1) get all commits and PRs in order. Work on the original PR -# 2) investigate why the expected peaks do not drop when recording_amplitude_scalings (rename) is used # 3) think about and add new neurons that are introduced when shifted # 4) add interpolation of the histograms prior to cross correlation diff --git a/src/spikeinterface/generation/session_displacement_generator.py b/src/spikeinterface/generation/session_displacement_generator.py index d769983d69..175e7f5c6a 100644 --- a/src/spikeinterface/generation/session_displacement_generator.py +++ b/src/spikeinterface/generation/session_displacement_generator.py @@ -141,10 +141,27 @@ def generate_session_displacement_recordings( # Fix generate template kwargs, so they are the same for every created recording. # Also fix unit firing rates across recordings. - generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) + fixed_generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed) - generate_sorting_kwargs["firing_rates"] = fixed_firing_rates + fixed_generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs) + fixed_generate_sorting_kwargs["firing_rates"] = fixed_firing_rates + + extend = True + if extend: + num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = ( + _update_kwargs_for_extended_units( + num_units, + channel_locations, + unit_locations, + generate_unit_locations_kwargs, + generate_templates_kwargs, + generate_sorting_kwargs, + fixed_generate_templates_kwargs, + fixed_generate_sorting_kwargs, + seed, + ) + ) # Start looping over parameters, creating recordings shifted # and scaled as required @@ -174,7 +191,7 @@ def generate_session_displacement_recordings( num_units=num_units, sampling_frequency=sampling_frequency, durations=[duration], - **generate_sorting_kwargs, + **fixed_generate_sorting_kwargs, extra_outputs=True, seed=seed, ) @@ -195,7 +212,7 @@ def generate_session_displacement_recordings( unit_locations_moved, sampling_frequency=sampling_frequency, seed=seed, - **generate_templates_kwargs, + **fixed_generate_templates_kwargs, ) if recording_amplitude_scalings is not None: @@ -210,7 +227,7 @@ def generate_session_displacement_recordings( # Bring it all together in a `InjectTemplatesRecording` and # propagate all relevant metadata to the recording. - ms_before = generate_templates_kwargs["ms_before"] + ms_before = fixed_generate_templates_kwargs["ms_before"] nbefore = int(sampling_frequency * ms_before / 1000.0) recording = InjectTemplatesRecording( @@ -389,3 +406,66 @@ def _check_generate_session_displacement_arguments( "The entry for each recording in `recording_amplitude_scalings` " "must have the same length as the number of units." ) + + +def _update_kwargs_for_extended_units( + num_units, + channel_locations, + unit_locations, + generate_unit_locations_kwargs, + generate_templates_kwargs, + generate_sorting_kwargs, + fixed_generate_templates_kwargs, + fixed_generate_sorting_kwargs, + seed, +): + + seed_top = seed + 1 if seed is not None else None + seed_bottom = seed - 1 if seed is not None else None + + # Set unit locations above and below the probe + channel_locations_extend_top = channel_locations.copy() + channel_locations_extend_top[:, 1] -= np.max(channel_locations[:, 1]) + + extend_top_locations = generate_unit_locations( + num_units, + channel_locations_extend_top, + seed=seed_top, # explain + **generate_unit_locations_kwargs, + ) + + channel_locations_extend_bottom = channel_locations.copy() + channel_locations_extend_bottom[:, 1] += np.max(channel_locations[:, 1]) + + extend_bottom_locations = generate_unit_locations( + num_units, + channel_locations_extend_bottom, + seed=seed_bottom, + **generate_unit_locations_kwargs, + ) + + unit_locations = np.r_[extend_bottom_locations, unit_locations, extend_top_locations] + + # Set firing rates and params for these units + + # Do the same here + template_kwargs_top = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_top) + template_kwargs_bottom = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_bottom) + + for key in fixed_generate_templates_kwargs["unit_params"].keys(): + fixed_generate_templates_kwargs["unit_params"][key] = np.r_[ + template_kwargs_top["unit_params"][key], + fixed_generate_templates_kwargs["unit_params"][key], + template_kwargs_bottom["unit_params"][key], + ] + + firing_rates_top = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_top) + firing_rates_bottom = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_bottom) + + fixed_generate_sorting_kwargs["firing_rates"] = np.r_[ + firing_rates_top, fixed_generate_sorting_kwargs["firing_rates"], firing_rates_bottom + ] + + num_units *= 3 + + return num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs