Skip to content

Commit

Permalink
Temp save, will need to remove many parts.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Aug 28, 2024
1 parent 81ff1c5 commit e7c22dc
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 11 deletions.
4 changes: 3 additions & 1 deletion debugging/alignment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges,
"""
TODO: assumes 1-segment recording
"""
# TODO: this is weird, don't return spatial_bin_edges here... amybe assert..
entire_session_hist, temporal_bin_edges, spatial_bin_edges = \
make_2d_motion_histogram(
recording,
Expand All @@ -35,6 +36,7 @@ def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges,
hist_margin_um=None,
spatial_bin_edges=spatial_bin_edges,
)

entire_session_hist = entire_session_hist[0]

entire_session_hist /= recording.get_duration(segment_index=0)
Expand Down Expand Up @@ -290,7 +292,7 @@ def prep_recording(recording, plot=False):
:param recording:
:return:
"""
peaks = detect_peaks(recording, method="by_channel") # "locally_exclusive")
peaks = detect_peaks(recording, method="locally_exclusive")

peak_locations = localize_peaks(recording, peaks,
method="grid_convolution")
Expand Down
Binary file modified debugging/all_recordings.pickle
Binary file not shown.
66 changes: 66 additions & 0 deletions debugging/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
import numpy as np
import plotting
import alignment_utils
import matplotlib.pyplot as plt
import pickle
scalings = [np.ones(10), np.r_[np.zeros(3), np.ones(7)]]

SAVE = True

if SAVE:
recordings_list, _ = generate_session_displacement_recordings(
non_rigid_gradient=None,
num_units=35,
recording_durations=(100, 100),
recording_shifts=(
(0, 0), (0, 0),
),
recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings},
seed=None,
)

peaks_list = []
peak_locations_list = []

for recording in recordings_list:
peaks, peak_locations = alignment_utils.prep_recording(
recording, plot=False,
)
peaks_list.append(peaks)
peak_locations_list.append(peak_locations)

# something relatively easy, only 15 units
with open('all_recordings.pickle', 'wb') as handle:
pickle.dump((recordings_list, peaks_list, peak_locations_list),
handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('all_recordings.pickle', 'rb') as handle:
recordings_list, peaks_list, peak_locations_list = pickle.load(handle)

bin_um = 2

# TODO: own function
min_y = np.min([np.min(locs["y"]) for locs in peak_locations_list])
max_y = np.max([np.max(locs["y"]) for locs in peak_locations_list])

spatial_bin_edges = np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin...
spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges) # TODO: own function

session_histogram_list = []
for recording, peaks, peak_locations in zip(recordings_list, peaks_list, peak_locations_list):

hist, temp, spat = alignment_utils.get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges, log_scale=False)

session_histogram_list.append(
hist
)
# TODO: need to get all outputs and check are same size
plotting.SessionAlignmentWidget(
recordings_list,
peaks_list,
peak_locations_list,
session_histogram_list,
histogram_spatial_bin_centers=spatial_bin_centers,
)
plt.show()
5 changes: 3 additions & 2 deletions debugging/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax_top.set_title(f"Session {i + 1}")
ax_top.set_xlabel(None)

plot = DriftRasterMapWidget(
DriftRasterMapWidget(
dp.peaks_list[i],
dp.corrected_peak_locations_list[i],
recording=dp.recordings_list[i],
Expand All @@ -163,7 +163,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

ax = fig.add_subplot(gs[num_rows, :])

plot = SessionAlignmentHistogramWidget(
SessionAlignmentHistogramWidget(
dp.session_histogram_list,
dp.histogram_spatial_bin_centers,
ax=ax,
Expand Down Expand Up @@ -242,6 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
if isinstance(linewidths, int):
linewidths = [linewidths] * num_histograms

# TODO: this leads to quite unexpected behaviours, figure something else out.
if spatial_bin_centers is None:
num_bins = dp.session_histogram_list[0].size
spatial_bin_centers = [np.arange(num_bins)] * num_histograms
Expand Down
2 changes: 1 addition & 1 deletion debugging/session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def estimate_inter_session_displacement(
max_y = np.max([np.max(locs["y"]) for locs in peak_locations_list])

spatial_bin_edges = np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin...
spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges)
spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges) # TODO: own function

# Estimate an activity histogram per-session
all_session_hists = [] # TODO: probably better as a dict
Expand Down
2 changes: 0 additions & 2 deletions debugging/test_session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@
# handle the case where the passed recordings are not motion correction recordings.


# 1) get all commits and PRs in order. Work on the original PR
# 2) investigate why the expected peaks do not drop when recording_amplitude_scalings (rename) is used
# 3) think about and add new neurons that are introduced when shifted

# 4) add interpolation of the histograms prior to cross correlation
Expand Down
90 changes: 85 additions & 5 deletions src/spikeinterface/generation/session_displacement_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,27 @@ def generate_session_displacement_recordings(

# Fix generate template kwargs, so they are the same for every created recording.
# Also fix unit firing rates across recordings.
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)
fixed_generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)

fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed)
generate_sorting_kwargs["firing_rates"] = fixed_firing_rates
fixed_generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs)
fixed_generate_sorting_kwargs["firing_rates"] = fixed_firing_rates

extend = True
if extend:
num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = (
_update_kwargs_for_extended_units(
num_units,
channel_locations,
unit_locations,
generate_unit_locations_kwargs,
generate_templates_kwargs,
generate_sorting_kwargs,
fixed_generate_templates_kwargs,
fixed_generate_sorting_kwargs,
seed,
)
)

# Start looping over parameters, creating recordings shifted
# and scaled as required
Expand Down Expand Up @@ -174,7 +191,7 @@ def generate_session_displacement_recordings(
num_units=num_units,
sampling_frequency=sampling_frequency,
durations=[duration],
**generate_sorting_kwargs,
**fixed_generate_sorting_kwargs,
extra_outputs=True,
seed=seed,
)
Expand All @@ -195,7 +212,7 @@ def generate_session_displacement_recordings(
unit_locations_moved,
sampling_frequency=sampling_frequency,
seed=seed,
**generate_templates_kwargs,
**fixed_generate_templates_kwargs,
)

if recording_amplitude_scalings is not None:
Expand All @@ -210,7 +227,7 @@ def generate_session_displacement_recordings(

# Bring it all together in a `InjectTemplatesRecording` and
# propagate all relevant metadata to the recording.
ms_before = generate_templates_kwargs["ms_before"]
ms_before = fixed_generate_templates_kwargs["ms_before"]
nbefore = int(sampling_frequency * ms_before / 1000.0)

recording = InjectTemplatesRecording(
Expand Down Expand Up @@ -389,3 +406,66 @@ def _check_generate_session_displacement_arguments(
"The entry for each recording in `recording_amplitude_scalings` "
"must have the same length as the number of units."
)


def _update_kwargs_for_extended_units(
num_units,
channel_locations,
unit_locations,
generate_unit_locations_kwargs,
generate_templates_kwargs,
generate_sorting_kwargs,
fixed_generate_templates_kwargs,
fixed_generate_sorting_kwargs,
seed,
):

seed_top = seed + 1 if seed is not None else None
seed_bottom = seed - 1 if seed is not None else None

# Set unit locations above and below the probe
channel_locations_extend_top = channel_locations.copy()
channel_locations_extend_top[:, 1] -= np.max(channel_locations[:, 1])

extend_top_locations = generate_unit_locations(
num_units,
channel_locations_extend_top,
seed=seed_top, # explain
**generate_unit_locations_kwargs,
)

channel_locations_extend_bottom = channel_locations.copy()
channel_locations_extend_bottom[:, 1] += np.max(channel_locations[:, 1])

extend_bottom_locations = generate_unit_locations(
num_units,
channel_locations_extend_bottom,
seed=seed_bottom,
**generate_unit_locations_kwargs,
)

unit_locations = np.r_[extend_bottom_locations, unit_locations, extend_top_locations]

# Set firing rates and params for these units

# Do the same here
template_kwargs_top = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_top)
template_kwargs_bottom = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_bottom)

for key in fixed_generate_templates_kwargs["unit_params"].keys():
fixed_generate_templates_kwargs["unit_params"][key] = np.r_[
template_kwargs_top["unit_params"][key],
fixed_generate_templates_kwargs["unit_params"][key],
template_kwargs_bottom["unit_params"][key],
]

firing_rates_top = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_top)
firing_rates_bottom = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_bottom)

fixed_generate_sorting_kwargs["firing_rates"] = np.r_[
firing_rates_top, fixed_generate_sorting_kwargs["firing_rates"], firing_rates_bottom
]

num_units *= 3

return num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs

0 comments on commit e7c22dc

Please sign in to comment.