Skip to content

Commit

Permalink
Revert "Try a recursive nonrigid alignment, doesn't really work."
Browse files Browse the repository at this point in the history
This reverts commit f27328c.
  • Loading branch information
JoeZiminski committed Aug 28, 2024
1 parent 818cf86 commit ff2de84
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 133 deletions.
144 changes: 17 additions & 127 deletions debugging/alignment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,132 +392,6 @@ def run_alignment_estimation(
# of iterative template. Want to leave separate for now for prototyping
# but should combine the shared parts later.

# TODO: I wonder if it is better to estimate the hitsogram with finer bin size
# than try and interpolate the xcorr. What about smoothing the activity histograms directly?

# TOOD: the iterative_template seems a little different to the interpolation
# of nonrigid segments that is described in the NP2.0 paper. Oh, the KS
# implementation is different to that described in the paper/ where is the
# Akima spline interpolation?

# TODO: make sure that the num bins will always align.
# Apply the linear shifts, don't roll, as we don't want circular (why would the top of the probe appear at the bottom?)
# They roll the windowed version that is zeros, but here we want all done up front to simplify later code

# TODO: what about the Akima Spline, this would be cooler
# TODO: try out logarithmic scaling as some neurons fire too much...

num_bins = spatial_bin_centers.shape[0]
# assert spatial_bin_centers.shape[0] %2 == 0, "num channels must be even"

min_num_bins = 10
divs = 2**np.arange(10)
divs = divs[np.where(num_bins / divs > min_num_bins)]

accumulated_shifts = []

step_shifts = []
for step_idx, num_steps in enumerate(divs): # TOOD: use this compeltely dynamically, for rigid, kilosort-like and recursive


bin_edges = np.arange(num_steps + 1) * (num_bins/num_steps)
print(bin_edges)
bin_edges = bin_edges.astype(int)
# bin_edges = np.arange(1, num_steps + 1)[::-1]
# bin_edges = np.r_[0, (num_bins / bin_edges).astype(int)]

non_rigid_windows = np.zeros((num_steps, num_bins))
non_rigid_window_centers = np.zeros(num_steps)

for i in range(num_steps):
non_rigid_windows[i, bin_edges[i]:bin_edges[i+1]] = 1
non_rigid_window_centers[i] = np.mean(spatial_bin_centers[non_rigid_windows[i].astype(bool)]) # TODO: maybe not mean, maybe (max - min)/ 2 ... :S

if num_steps == 1:
shifted_histograms = all_session_hists[:, :, np.newaxis]
else:
shifted_histograms = np.repeat(shifted_histograms, 2, axis=2)

# shifted_histograms *= non_rigid_windows # TODO

window_shifts = []
for win_idx in range(shifted_histograms.shape[2]):

window_hist = shifted_histograms[:, :, win_idx] * non_rigid_windows[win_idx]

if np.any(window_hist):
# this method just xcorr the entire window does not provide subset of samples like kilosort_like
window_hist_array = _compute_rigid_hist_crosscorr(num_sessions, num_bins, window_hist, robust=False)

all_ses_shifts = -np.mean(window_hist_array, axis=0)

if np.any(all_ses_shifts > 400):
breakpoint()

else:
all_ses_shifts = np.zeros(window_hist.shape[0])

window_shifts.append(all_ses_shifts)

# perform shift
for ses_idx, shift in enumerate(all_ses_shifts):

abs_shift = np.abs(shift).astype(int)

if abs_shift == 0:
cut_padded_hist = window_hist[ses_idx]
else:
pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) # TODO: check direction!

padded_hist = np.pad(window_hist[ses_idx], pad_tuple, mode="constant")
cut_padded_hist = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]

shifted_histograms[ses_idx, :, win_idx] = cut_padded_hist

step_shifts.append(window_shifts)

breakpoint()
"""
try:
splitto_binno = np.split(y, divo)
except:
breakpoint()
hist_idx = np.where(i in binno for binno in splitto_binno)[0]
for ses_idx in range(num_sessions):
shift = int(accumulated_shifts[seg_idx][ses_idx, hist_idx])
abs_shift = np.abs(shift)
pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) # TODO: check direction!
padded_hist = np.pad(all_session_hists[ses_idx, :], pad_tuple, mode="constant")
cut_padded_hist = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]
try:
shifted_histograms[ses_idx, hist_idx, :] = cut_padded_hist
except:
breakpoint()
non_rigid_shifts = np.zeros((num_sessions, non_rigid_windows.shape[0]))
for i, window in enumerate(non_rigid_windows): # TODO: use same name
windowed_histogram = shifted_histograms[:, i, :] * window # these are shifted, but not windows. Maybe better to window then shift like kilosort.
window_hist_array = _compute_rigid_hist_crosscorr(
num_sessions, num_bins, windowed_histogram, robust=False
# this method just xcorr the entire window does not provide subset of samples like kilosort_like
)
non_rigid_shifts[:, i] = -np.mean(window_hist_array, axis=0)
accumulated_shifts.append(non_rigid_shifts)
"""

return optimal_shift_indices + non_rigid_shifts, non_rigid_window_centers # TODO: tidy up


"""
num_steps = 7
win_step_um = (np.max(spatial_bin_centers) - np.min(spatial_bin_centers)) / num_steps

Expand All @@ -531,7 +405,21 @@ def run_alignment_estimation(
win_margin_um=0,
# zero_threshold=None,
)
# TODO: I wonder if it is better to estimate the hitsogram with finer bin size
# than try and interpolate the xcorr. What about smoothing the activity histograms directly?

# TOOD: the iterative_template seems a little different to the interpolation
# of nonrigid segments that is described in the NP2.0 paper. Oh, the KS
# implementation is different to that described in the paper/ where is the
# Akima spline interpolation?

# TODO: make sure that the num bins will always align.
# Apply the linear shifts, don't roll, as we don't want circular (why would the top of the probe appear at the bottom?)
# They roll the windowed version that is zeros, but here we want all done up front to simplify later code

import matplotlib.pyplot as plt

# TODO: for recursive version, shift cannot be larger than previous shift!
shifted_histograms = np.zeros_like(all_session_hists)
for i in range(all_session_hists.shape[0]):

Expand All @@ -554,7 +442,9 @@ def run_alignment_estimation(
non_rigid_shifts[:, i] = -np.mean(window_hist_array, axis=0)

return optimal_shift_indices + non_rigid_shifts, non_rigid_window_centers # TODO: tidy up
"""

# TODO: what about the Akima Spline
# TODO: try out logarithmic scaling as some neurons fire too much...

def _compute_rigid_hist_crosscorr(num_sessions, num_bins, all_session_hists, robust=False):
""""""
Expand Down
Binary file modified debugging/all_recordings.pickle
Binary file not shown.
3 changes: 1 addition & 2 deletions debugging/session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ def estimate_inter_session_displacement(
min_y = np.min([np.min(locs["y"]) for locs in peak_locations_list])
max_y = np.max([np.max(locs["y"]) for locs in peak_locations_list])

# TOOD: specifically chosen to get num bins to work!!!!!!!!!!!!!!!!!! #######################################################################################################
spatial_bin_edges = np.linspace(min_y, max_y, 1024 + 1) # np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin...
spatial_bin_edges = np.arange(min_y, max_y + bin_um, bin_um) # TODO: expose a margin...
spatial_bin_centers = alignment_utils.get_bin_centers(spatial_bin_edges)

# Estimate an activity histogram per-session
Expand Down
7 changes: 3 additions & 4 deletions debugging/test_session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
# What we really want to find is maximal subset of the data that matches

SAVE = False
PLOT = True
PLOT = False
BIN_UM = 2

if SAVE:
Expand All @@ -50,7 +50,7 @@
num_units=25,
recording_durations=(100, 100),
recording_shifts=(
(0, 0), (0, 75), # TODO: check the histogram, why is this shift not actually 75 um!??!?!? could be an x-axis plotting issue...
(0, 0), (0, 75),
),
recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings},
seed=None,
Expand All @@ -76,10 +76,9 @@


corrected_recordings_list, motion_objects_list, extra_info = session_alignment.run_inter_session_displacement_correction(
recordings_list, peaks_list, peak_locations_list, bin_um=BIN_UM, histogram_estimation_method="entire_session", alignment_method="mean_crosscorr", rigid=False
recordings_list, peaks_list, peak_locations_list, bin_um=BIN_UM, histogram_estimation_method="entire_session", alignment_method="mean_crosscorr", rigid=True
)

# TODO: make sure raster plot y-axis are aligned
plotting.SessionAlignmentWidget(
recordings_list,
peaks_list,
Expand Down

0 comments on commit ff2de84

Please sign in to comment.