Skip to content

Commit

Permalink
Trying different alg
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Dec 19, 2024
1 parent 300b781 commit 8e5ba68
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 28 deletions.
42 changes: 25 additions & 17 deletions debugging/playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import matplotlib.pyplot as plt

import spikeinterface.full as si
import numpy as np


si.set_global_job_kwargs(n_jobs=10)

Expand All @@ -19,8 +21,6 @@
# Load / generate some recordings
# --------------------------------------------------------------------------------------



recordings_list, _ = generate_session_displacement_recordings(
num_units=20,
recording_durations=[400, 400, 400],
Expand Down Expand Up @@ -48,21 +48,23 @@
# There is a function 'session_alignment.align_sessions_after_motion_correction()
# you can use instead of the below.

peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment(
recordings_list,
detect_kwargs={"method": "locally_exclusive"},
localize_peaks_kwargs={"method": "grid_convolution"},
)

if False:
peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment(
recordings_list,
detect_kwargs={"method": "locally_exclusive"},
localize_peaks_kwargs={"method": "grid_convolution"},
)

np.save("peaks_1.npy", peaks_list[0])
np.save("peaks_2.npy", peaks_list[1])
np.save("peaks_3.npy", peaks_list[2])
np.save("peak_locs_1.npy", peak_locations_list[0])
np.save("peak_locs_2.npy", peak_locations_list[1])
np.save("peak_locs_3.npy", peak_locations_list[2])

if False:
peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy")]
peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy")]
# if False:
peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy"), np.load("peaks_3.npy")]
peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy"), np.load("peak_locs_3.npy")]

# --------------------------------------------------------------------------------------
# Do the estimation
Expand All @@ -73,14 +75,20 @@
# See `session_alignment.py` for docs on these settings.

non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs()
non_rigid_window_kwargs["rigid"] = False
# non_rigid_window_kwargs["win_shape"] = "rect"
# non_rigid_window_kwargs["win_step_um"] = 25
non_rigid_window_kwargs["rigid_mode"] = "nonrigid"
non_rigid_window_kwargs["win_shape"] = "rect"
non_rigid_window_kwargs["win_step_um"] = 100.0
non_rigid_window_kwargs["win_scale_um"] = 200.0

estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs()
estimate_histogram_kwargs["method"] = "chunked_mean"
estimate_histogram_kwargs["histogram_type"] = "activity_1d"
estimate_histogram_kwargs["bin_um"] = 5
estimate_histogram_kwargs["method"] = "chunked_median"
estimate_histogram_kwargs["histogram_type"] = "activity_1d" # TODO: investigate this case thoroughly
estimate_histogram_kwargs["bin_um"] = 2
estimate_histogram_kwargs["log_scale"] = True
estimate_histogram_kwargs["weight_with_amplitude"] = False

compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs()
compute_alignment_kwargs["num_shifts_block"] = 300

corrected_recordings_list, extra_info = session_alignment.align_sessions(
recordings_list,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def get_activity_histogram(
bin_s: float | None,
depth_smooth_um: float | None,
scale_to_hz: bool = False,
weight_with_amplitude: bool = False,
):
"""
Generate a 2D activity histogram for the session. Wraps the underlying
Expand Down Expand Up @@ -57,7 +58,7 @@ def get_activity_histogram(
recording,
peaks,
peak_locations,
weight_with_amplitude=False,
weight_with_amplitude=weight_with_amplitude,
direction="y",
bin_s=(bin_s if bin_s is not None else recording.get_duration(segment_index=0)),
bin_um=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
elif isinstance(spatial_bin_centers, np.ndarray):
spatial_bin_centers = [spatial_bin_centers] * num_histograms

# TOOD: For 2D histogram, will need to subplot and just plot the individual histograms...
if dp.session_histogram_list[0].ndim == 2:
histogram_list = [np.sum(hist_, axis=1) for hist_ in dp.session_histogram_list]
print("2D histogram passed, will be summed across first (i.e. amplitude) axis.")
else:
histogram_list = dp.session_histogram_list

for i in range(num_histograms):
self.ax.plot(spatial_bin_centers[i], dp.session_histogram_list[i], color=colors[i], linewidth=linewidths[i])
self.ax.plot(spatial_bin_centers[i], histogram_list[i], color=colors[i], linewidth=linewidths[i])

if legend is not None:
self.ax.legend(legend)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def get_estimate_histogram_kwargs() -> dict:
"log_scale": False,
"depth_smooth_um": None,
"histogram_type": "activity_1d",
"weight_with_amplitude": True,
}


Expand All @@ -66,7 +67,7 @@ def get_compute_alignment_kwargs() -> dict:
windows along the probe depth. See `get_spatial_windows`.
"""
return {
"num_shifts_block": 100, # TODO: estimate this properly, make take as some factor of the window width? Also check if it is 2x the block xcorr in motion correction
"num_shifts_block": 50, # TODO: estimate this properly, make take as some factor of the window width? Also check if it is 2x the block xcorr in motion correction
"interpolate": False,
"interp_factor": 10,
"kriging_sigma": 1,
Expand All @@ -88,7 +89,7 @@ def get_non_rigid_window_kwargs():
defined in the function signature.
"""
return {
"rigid": True,
"rigid_mode": "rigid", # "rigid", "rigid_nonrigid", "nonrigid"
"win_shape": "gaussian",
"win_step_um": 50,
"win_scale_um": 50,
Expand Down Expand Up @@ -314,7 +315,7 @@ def align_sessions_after_motion_correction(
# shifts together.
if (
"non_rigid_window_kwargs" in align_sessions_kwargs
and align_sessions_kwargs["non_rigid_window_kwargs"]["rigid"] is False
and "nonrigid" in align_sessions_kwargs["non_rigid_window_kwargs"]["rigid_mode"]
):

if motion_window_kwargs["rigid"] is False:
Expand Down Expand Up @@ -392,6 +393,7 @@ def _compute_session_histograms(
chunked_bin_size_s: float | "estimate",
depth_smooth_um: float,
log_scale: bool,
weight_with_amplitude: bool,
) -> tuple[list[np.ndarray], list[np.ndarray], np.ndarray, np.ndarray, list[dict]]:
"""
Compute a 1d activity histogram for the session. As
Expand Down Expand Up @@ -455,6 +457,7 @@ def _compute_session_histograms(
log_scale,
chunked_bin_size_s,
depth_smooth_um,
weight_with_amplitude,
)
temporal_bin_centers_list.append(temporal_bin_centers)
session_histogram_list.append(session_hist)
Expand All @@ -479,6 +482,7 @@ def _get_single_session_activity_histogram(
log_scale: bool,
chunked_bin_size_s: float | "estimate",
depth_smooth_um: float,
weight_with_amplitude: bool,
) -> tuple[np.ndarray, np.ndarray, dict]:
"""
Compute an activity histogram for a single session.
Expand Down Expand Up @@ -534,6 +538,7 @@ def _get_single_session_activity_histogram(
bin_s=None,
depth_smooth_um=None,
scale_to_hz=False,
weight_with_amplitude=weight_with_amplitude
)

# It is important that the passed histogram is scaled to firing rate in Hz
Expand Down Expand Up @@ -824,12 +829,69 @@ def _correct_session_displacement(
estimate_histogram_kwargs["log_scale"],
estimate_histogram_kwargs["chunked_bin_size_s"],
estimate_histogram_kwargs["depth_smooth_um"],
estimate_histogram_kwargs["weight_with_amplitude"],
)
corrected_session_histogram_list.append(session_hist)

return corrected_peak_locations_list, corrected_session_histogram_list


def get_shifts(signal1, signal2, windows):

import matplotlib.pyplot as plt

signa11_blanked = signal1.copy()
signal2_blanked = signal2.copy()

if (first_idx := windows[0][0]) != 0:
print("first idx", first_idx)
signa11_blanked[:first_idx] = 0
signal2_blanked[:first_idx] = 0

if (last_idx := windows[-1][-1]) != signal1.size - 1: #double check
print("last idx", last_idx)
signa11_blanked[last_idx:] = 0
signal2_blanked[last_idx:] = 0

segment_shifts = np.empty(windows.shape[0])
cum_shifts = []

for round in range(windows.shape[0]):

xcorr = np.correlate(signa11_blanked, signal2_blanked, mode="full")

if np.max(xcorr) < 0.01:
shift = 0
else:
shift = np.argmax(xcorr) - xcorr.size // 2
cum_shifts.append(shift)
print(shift)

# shift the signal1, or use indexing
signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, shift)

# plt.plot(signa11_blanked)
# plt.plot(signal2_blanked)
# plt.show()

window_corrs = np.empty(windows.shape[0])
for i, idx in enumerate(windows):

window_corrs[i] = np.correlate(signa11_blanked[idx] - np.mean(signa11_blanked[idx]), signal2_blanked[idx] - np.mean(signal2_blanked[idx]))

max_window = np.argmax(window_corrs)

segment_shifts[max_window] = np.sum(cum_shifts)

# print(segment_shifts[max_window])

# TODO: this is interacting with the shift to make spikes!
signa11_blanked[windows[max_window]] = 0
signal2_blanked[windows[max_window]] = 0

return segment_shifts


def _compute_session_alignment(
session_histogram_list: list[np.ndarray],
contact_depths: np.ndarray,
Expand Down Expand Up @@ -868,6 +930,9 @@ def _compute_session_alignment(

akima_interp_nonrigid = compute_alignment_kwargs.pop("akima_interp_nonrigid")

rigid_mode = non_rigid_window_kwargs.pop("rigid_mode") # TODO: carefully check all popped kwargs
non_rigid_window_kwargs["rigid"] = rigid_mode == "rigid"

non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
contact_depths, spatial_bin_centers, **non_rigid_window_kwargs
)
Expand All @@ -878,7 +943,7 @@ def _compute_session_alignment(
compute_alignment_kwargs,
)

if non_rigid_window_kwargs["rigid"]:
if rigid_mode == "rigid":
return rigid_shifts, non_rigid_windows, non_rigid_window_centers

# For non-rigid, first shift the histograms according to the rigid shift
Expand All @@ -888,7 +953,7 @@ def _compute_session_alignment(
# for non-rigid, it makes sense to start without rigid alignment
shifted_histograms = session_histogram_array.copy()

if False:
if rigid_mode == "rigid_nonrigid": # TOOD: add to docs
shifted_histograms = np.zeros_like(session_histogram_array)
for ses_idx, orig_histogram in enumerate(session_histogram_array):

Expand All @@ -897,10 +962,45 @@ def _compute_session_alignment(
)
shifted_histograms[ses_idx, :] = shifted_histogram


nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0]))

windows = []
for i in range(non_rigid_windows.shape[0]):
idxs = np.arange(non_rigid_windows.shape[1])[non_rigid_windows[i, :].astype(bool)]
windows.append(idxs)
# TODO: check assumptions these are always the same size

windows = np.vstack(windows)

# import matplotlib.pyplot as plt
# plt.plot(non_rigid_windows.T)
# plt.show()

windows1 = windows[::2, :]
windows2 = windows[1::2, :]

nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0], non_rigid_windows.shape[0]))

for i in range(shifted_histograms.shape[0]):
for j in range(shifted_histograms.shape[0]):

shifts1 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows1)
shifts2 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows2)
shifts = np.empty(shifts1.size + shifts2.size)
# breakpoint()
shifts[::2] = shifts1
shifts[1::2] = (shifts1[:-1] + shifts1[1:]) / 2# np.shifts2
# breakpoint()
nonrigid_session_offsets_matrix[i, j, :] = shifts

# TODO: there are gaps in between rect, rect seems weird, they are non-overlapping :S

# breakpoint()
# Then compute the nonrigid shifts
nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
shifted_histograms, non_rigid_windows, **compute_alignment_kwargs
)
# nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
# shifted_histograms, non_rigid_windows, **compute_alignment_kwargs
# )
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)

# Akima interpolate the nonrigid bins if required.
Expand All @@ -911,8 +1011,12 @@ def _compute_session_alignment(
shifts = interp_nonrigid_shifts # rigid_shifts + interp_nonrigid_shifts
non_rigid_window_centers = spatial_bin_centers
else:
# TODO: so check + add a test, the interpolator will handle this?
shifts = non_rigid_shifts # rigid_shifts + non_rigid_shifts

if rigid_mode == "rigid_nonrigid":
shifts += rigid_shifts

return shifts, non_rigid_windows, non_rigid_window_centers


Expand Down

0 comments on commit 8e5ba68

Please sign in to comment.