Skip to content

Commit

Permalink
Reformatting alignment methods and add 2D, need to tidy up.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Dec 17, 2024
1 parent e7435ec commit 62010d8
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 90 deletions.
30 changes: 24 additions & 6 deletions debugging/_test_session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def _prep_recording(recording, plot=False):

return peaks, peak_locations

MOTION = False # True
SAVE = False
MOTION = True # True
SAVE = True
PLOT = False
BIN_UM = 5

Expand Down Expand Up @@ -224,7 +224,7 @@ def _prep_recording(recording, plot=False):
"chunked_bin_size_s": "estimate",
"log_scale": True,
"depth_smooth_um": 10,
"histogram_type": "activity_1d", # "y_only", "2Dy_x", "2Dy_amplitude"" TOOD: better names!
"histogram_type": "activity_1d", # "y_only", "locations_2d", "activity_2d"" TOOD: better names!
}
compute_alignment_kwargs = {
"num_shifts_block": None, # TODO: can be in um so comaprable with window kwargs.
Expand All @@ -247,9 +247,7 @@ def _prep_recording(recording, plot=False):
}

if MOTION:
from session_alignment import align_sessions_after_motion_correction

corrected_recordings_list, extra_info = align_sessions_after_motion_correction(
corrected_recordings_list, extra_info = session_alignment.align_sessions_after_motion_correction(
recordings_list,
motion_info_list,
align_sessions_kwargs={
Expand All @@ -271,7 +269,19 @@ def _prep_recording(recording, plot=False):
compute_alignment_kwargs=compute_alignment_kwargs,
)

<<<<<<< HEAD

plotting_session_alignment.SessionAlignmentWidget(
recordings_list,
peaks_list,
peak_locations_list,
extra_info["session_histogram_list"],
**extra_info["corrected"],
spatial_bin_centers=extra_info["spatial_bin_centers"],
drift_raster_map_kwargs={"clim":(-250, 0)} # TODO: option to fix this across recordings.
)

=======
plotting_session_alignment.SessionAlignmentWidget(
recordings_list,
peaks_list,
Expand All @@ -282,14 +292,22 @@ def _prep_recording(recording, plot=False):
drift_raster_map_kwargs={"clim":(-250, 0)} # TODO: option to fix this across recordings.
)

>>>>>>> 978c5343c (Reformatting alignment methods and add 2D, need to tidy up.)
plt.show()

# TODO: estimate chunk size Hz needs to be scaled for time? is it not been done correctly?

<<<<<<< HEAD
# No, even two sessions is a mess
# TODO: working assumptions, maybe after rigid, make a template for nonrigid alignment
# as at the moment all nonrigid to eachother is a mess
if False:
=======
if False:
# No, even two sessions is a mess
# TODO: working assumptions, maybe after rigid, make a template for nonrigid alignment
# as at the moment all nonrigid to eachother is a mess
>>>>>>> 978c5343c (Reformatting alignment methods and add 2D, need to tidy up.)
A = extra_info["histogram_info_list"][2]["chunked_histograms"]

mean_ = alignment_utils.get_chunked_hist_mean(A)
Expand Down
Binary file added debugging/peak_locs_1.npy
Binary file not shown.
Binary file added debugging/peak_locs_2.npy
Binary file not shown.
Binary file added debugging/peaks_1.npy
Binary file not shown.
Binary file added debugging/peaks_2.npy
Binary file not shown.
152 changes: 103 additions & 49 deletions debugging/playing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
Expand All @@ -9,57 +8,112 @@
)
import matplotlib.pyplot as plt

import spikeinterface.full as si

# --------------------------------------------------------------------------------------
# Load / generate some recordings
# --------------------------------------------------------------------------------------
si.set_global_job_kwargs(n_jobs=10)

recordings_list, _ = generate_session_displacement_recordings(
num_units=50,
recording_durations=[50, 50, 50],
recording_shifts=((0, 0), (0, 50), (0, 75))
)

# --------------------------------------------------------------------------------------
# Compute the peaks / locations with your favourite method
# --------------------------------------------------------------------------------------
# Note if you did motion correction the peaks are on the motion object.
# There is a function 'session_alignment.align_sessions_after_motion_correction()
# you can use instead of the below.

peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment(
recordings_list,
detect_kwargs={"method": "locally_exclusive"},
localize_peaks_kwargs={"method": "grid_convolution"},
)
if __name__ == '__main__':

# --------------------------------------------------------------------------------------
# Do the estimation
# --------------------------------------------------------------------------------------
# For each session, an 'activity histogram' is generated. This can be `entire_session`
# or the session can be chunked into segments and some summary generated taken over then.
# This might be useful if periods of the recording have weird kinetics or noise.
# See `session_alignment.py` for docs on these settings.

non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs()
non_rigid_window_kwargs["rigid"] = True

corrected_recordings_list, extra_info = session_alignment.align_sessions(
recordings_list,
peaks_list,
peak_locations_list,
alignment_order="to_session_1", # "to_session_X" or "to_middle"
non_rigid_window_kwargs=non_rigid_window_kwargs,
)
# --------------------------------------------------------------------------------------
# Load / generate some recordings
# --------------------------------------------------------------------------------------

plotting_session_alignment.SessionAlignmentWidget(
recordings_list,
peaks_list,
peak_locations_list,
extra_info["session_histogram_list"],
**extra_info["corrected"],
spatial_bin_centers=extra_info["spatial_bin_centers"],
drift_raster_map_kwargs={"clim":(-250, 0), "scatter_decimate": 10}
)

plt.show()

recordings_list, _ = generate_session_displacement_recordings(
num_units=20,
recording_durations=[400, 400, 400],
recording_shifts=((0, 0), (0, 200), (0, -125)),
non_rigid_gradient=0.005,
seed=52,
)
if False:
import numpy as np


recordings_list = [
si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-38-28_VR1.zarr\M25_D18_2024-11-05_12-38-28_VR1.zarr"),
si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-08-47_OF1.zarr\M25_D18_2024-11-05_12-08-47_OF1.zarr"),
]

recordings_list = [si.astype(rec, np.float32) for rec in recordings_list]
recordings_list = [si.bandpass_filter(rec) for rec in recordings_list]
recordings_list = [si.common_reference(rec, operator="median") for rec in recordings_list]

# --------------------------------------------------------------------------------------
# Compute the peaks / locations with your favourite method
# --------------------------------------------------------------------------------------
# Note if you did motion correction the peaks are on the motion object.
# There is a function 'session_alignment.align_sessions_after_motion_correction()
# you can use instead of the below.

peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment(
recordings_list,
detect_kwargs={"method": "locally_exclusive"},
localize_peaks_kwargs={"method": "grid_convolution"},
)

if False:
np.save("peaks_1.npy", peaks_list[0])
np.save("peaks_2.npy", peaks_list[1])
np.save("peak_locs_1.npy", peak_locations_list[0])
np.save("peak_locs_2.npy", peak_locations_list[1])

if False:
peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy")]
peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy")]

# --------------------------------------------------------------------------------------
# Do the estimation
# --------------------------------------------------------------------------------------
# For each session, an 'activity histogram' is generated. This can be `entire_session`
# or the session can be chunked into segments and some summary generated taken over then.
# This might be useful if periods of the recording have weird kinetics or noise.
# See `session_alignment.py` for docs on these settings.

non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs()
non_rigid_window_kwargs["rigid"] = False
# non_rigid_window_kwargs["win_shape"] = "rect"
# non_rigid_window_kwargs["win_step_um"] = 25

estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs()
estimate_histogram_kwargs["method"] = "chunked_mean"
estimate_histogram_kwargs["histogram_type"] = "activity_1d"
estimate_histogram_kwargs["bin_um"] = 5

corrected_recordings_list, extra_info = session_alignment.align_sessions(
recordings_list,
peaks_list,
peak_locations_list,
alignment_order="to_session_1", # "to_session_X" or "to_middle"
non_rigid_window_kwargs=non_rigid_window_kwargs,
estimate_histogram_kwargs=estimate_histogram_kwargs,
)

# TODO: nonlinear is not working well 'to middle', investigate
# TODO: also finalise the estimation of bin number of nonrigid.

if False:
plt.plot(extra_info["histogram_info_list"][0]["chunked_histograms"].T, color="black")

M = extra_info["session_histogram_list"][0]
S = extra_info["histogram_info_list"][0]["session_histogram_variation"]

plt.plot(M, color="red")
plt.plot(M + S, color="green")
plt.plot(M - S, color="green")

plt.show()

plotting_session_alignment.SessionAlignmentWidget(
recordings_list,
peaks_list,
peak_locations_list,
extra_info["session_histogram_list"],
**extra_info["corrected"],
spatial_bin_centers=extra_info["spatial_bin_centers"],
drift_raster_map_kwargs={"clim":(-250, 0), "scatter_decimate": 10}
)

plt.show()
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from spikeinterface import BaseRecording
import numpy as np

from spikeinterface.preprocessing import center
from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram
from scipy.optimize import minimize
from scipy.ndimage import gaussian_filter
Expand Down Expand Up @@ -77,7 +79,7 @@ def get_activity_histogram(
activity_histogram *= scaler

if log_scale:
activity_histogram = np.log10(1 + activity_histogram)
activity_histogram = np.log10(1 + activity_histogram) # TODO: make_2d_motion_histogram uses log2

return activity_histogram, temporal_bin_centers, spatial_bin_centers

Expand Down Expand Up @@ -405,7 +407,7 @@ def compute_histogram_crosscorrelation(
mean does not make sense over sessions.
"""
num_sessions = len(session_histogram_list)
num_bins = session_histogram_list[0].size # all hists are same length
num_bins = session_histogram_list.shape[1] # all hists are same length
num_windows = non_rigid_windows.shape[0]

shift_matrix = np.zeros((num_sessions, num_sessions, num_windows))
Expand All @@ -416,36 +418,75 @@ def compute_histogram_crosscorrelation(
for j in range(num_sessions):

# Create the (num windows, num_bins) matrix for this pair of sessions
xcorr_matrix = np.zeros((non_rigid_windows.shape[0], num_bins * 2 - 1))

import matplotlib.pyplot as plt

# TODO: plot everything

num_iter = (
num_bins * 2 - 1 if not num_shifts_block else num_shifts_block * 2
) # TODO: make sure this is clearly defined, it is either side...
xcorr_matrix = np.zeros((non_rigid_windows.shape[0], num_iter))

# For each window, window the session histograms (`window` is binary)
# and perform the cross correlations
for win_idx, window in enumerate(non_rigid_windows):
windowed_histogram_i = session_histogram_list[i, :] * window
windowed_histogram_j = session_histogram_list[j, :] * window

xcorr = np.correlate(
windowed_histogram_i, windowed_histogram_j, mode="full"
) # TODO: add weight option.
# breakpoint()
# TODO: track the 2d histogram through all steps to check everything is working okay

# TOOD: gaussian window with crosscorr, won't it strongly bias zero shifts by maximising the signal at 0?
# TODO: add weight option.
# TODO: damn this is slow for 2D, speed up.
if session_histogram_list.ndim == 3:
windowed_histogram_i = session_histogram_list[i, :] * window[:, np.newaxis]
windowed_histogram_j = session_histogram_list[j, :] * window[:, np.newaxis]

from scipy.signal import correlate2d

# carefully check indices
xcorr = correlate2d(
windowed_histogram_i - np.mean(windowed_histogram_i, axis=1)[:, np.newaxis],
windowed_histogram_j - np.mean(windowed_histogram_j, axis=1)[:, np.newaxis],
) # TOOD: check speed, probs don't remove mean because we want zeros for unmasked version

mid_idx = windowed_histogram_j.shape[1] - 1
xcorr = xcorr[:, mid_idx]

else:
windowed_histogram_i = session_histogram_list[i, :] * window

window_target = True # this makes less sense now that things could be very far apart
if window_target:
windowed_histogram_j = session_histogram_list[j, :] * window
else:
windowed_histogram_j = session_histogram_list[j, :]

xcorr = np.correlate(windowed_histogram_i, windowed_histogram_j, mode="full")

# plt.plot(windowed_histogram_i)
# plt.plot(windowed_histogram_j)
# plt.show()

if num_shifts_block:
window_indices = np.arange(center_bin - num_shifts_block, center_bin + num_shifts_block)
mask = np.zeros_like(xcorr)
mask[window_indices] = 1
xcorr *= mask
xcorr = xcorr[window_indices]
shift_center_bin = (
num_shifts_block # np.floor(num_shifts_block / 2) # TODO: CHECK! and move out of loop!
)
else:
shift_center_bin = center_bin

# plt.plot(xcorr)
# plt.show()

xcorr_matrix[win_idx, :] = xcorr

# TODO: check absolute value of different bins, they are quite different (log scale, zero mean histograms)
# TODO: print out a load of quality metrics from this!

# Smooth the cross-correlations across the bins
if smoothing_sigma_bin:
breakpoint()
import matplotlib.pyplot as plt

plt.plot(xcorr_matrix[0, :])
X = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1)
plt.plot(X[0, :])
plt.show()

xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1)

# Smooth the cross-correlations across the windows
Expand All @@ -470,7 +511,9 @@ def compute_histogram_crosscorrelation(
else:
xcorr_peak = np.argmax(xcorr_matrix, axis=1)

shift = xcorr_peak - center_bin
# breakpoint()

shift = xcorr_peak - shift_center_bin # center_bin
shift_matrix[i, j, :] = shift

return shift_matrix
Expand Down Expand Up @@ -502,8 +545,16 @@ def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray:
"""
abs_shift = np.abs(shift)
pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0)

if array.ndim == 2:
pad_tuple = (pad_tuple, (0, 0))

padded_hist = np.pad(array, pad_tuple, mode="constant")
cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]

if padded_hist.ndim == 2:
cut_padded_array = padded_hist[abs_shift:, :] if shift >= 0 else padded_hist[:-abs_shift, :] # TOOD: tidy up
else:
cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]

return cut_padded_array

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
elif isinstance(spatial_bin_centers, np.ndarray):
spatial_bin_centers = [spatial_bin_centers] * num_histograms

# TOOD: For 2D histogram, will need to subplot and just plot the individual histograms...
for i in range(num_histograms):
self.ax.plot(spatial_bin_centers[i], dp.session_histogram_list[i], color=colors[i], linewidth=linewidths[i])

Expand Down
Loading

0 comments on commit 62010d8

Please sign in to comment.