Skip to content

Commit

Permalink
Adding a few more options.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Aug 28, 2024
1 parent ff2de84 commit c0eb2f5
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 72 deletions.
120 changes: 77 additions & 43 deletions debugging/alignment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# -----------------------------------------------------------------------------

# TODO: this function might be pointless
def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges):
def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges, log_scale):
"""
TODO: assumes 1-segment recording
"""
Expand All @@ -41,11 +41,14 @@ def get_entire_session_hist(recording, peaks, peak_locations, spatial_bin_edges)

spatial_centers = get_bin_centers(spatial_bin_edges)

if log_scale:
entire_session_hist = np.log10(1 + entire_session_hist)

return entire_session_hist, temporal_bin_edges, spatial_centers


def get_chunked_histogram( # TODO: this function might be pointless
recording, peaks, peak_locations, bin_s, spatial_bin_edges, weight_with_amplitude=False
recording, peaks, peak_locations, bin_s, spatial_bin_edges, log_scale, weight_with_amplitude=False
):
chunked_session_hist, temporal_bin_edges, _ = \
make_2d_motion_histogram(
Expand All @@ -66,6 +69,9 @@ def get_chunked_histogram( # TODO: this function might be pointless
bin_times = np.diff(temporal_bin_edges)[:, np.newaxis]
chunked_session_hist /= bin_times

if log_scale:
chunked_session_hist = np.log10(1 + chunked_session_hist)

return chunked_session_hist, temporal_centers, spatial_centers

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -354,72 +360,81 @@ def run_kilosort_like_rigid_registration(all_hists, non_rigid_windows):
return -optimal_shift_indices # TODO: these are reversed at this stage


# TODO: I wonder if it is better to estimate the hitsogram with finer bin size
# than try and interpolate the xcorr. What about smoothing the activity histograms directly?

# TOOD: the iterative_template seems a little different to the interpolation
# of nonrigid segments that is described in the NP2.0 paper. Oh, the KS
# implementation is different to that described in the paper/ where is the
# Akima spline interpolation?

# TODO: make sure that the num bins will always align.
# Apply the linear shifts, don't roll, as we don't want circular (why would the top of the probe appear at the bottom?)
# They roll the windowed version that is zeros, but here we want all done up front to simplify later code

# TODO: this is basically a re-implimentation of the nonrigid part
# of iterative template. Want to leave separate for now for prototyping
# but should combine the shared parts later.

# TOOD: important differenence, this does not roll, will need to test when new spikes are added...

# TODO: try out logarithmic scaling as some neurons fire too much...



def run_alignment_estimation(
all_session_hists, spatial_bin_centers, rigid, robust=False
all_session_hists, spatial_bin_centers, rigid, num_nonrigid_bins, robust=False
):
"""
"""
# TODO: figure out best way to represent this, should probably be
# suffix _list instead of prefixed all_ for consistency
if isinstance(all_session_hists, list):
all_session_hists = np.array(all_session_hists) # TODO: figure out best way to represent this, should probably be suffix _list instead of prefixed all_ for consistency
all_session_hists = np.array(all_session_hists)

num_bins = spatial_bin_centers.size
num_sessions = all_session_hists.shape[0]

# TODO: rename
hist_array = _compute_rigid_hist_crosscorr(
num_sessions, num_bins, all_session_hists, robust
) # TODO: rename
)

optimal_shift_indices = -np.mean(hist_array, axis=0)[:, np.newaxis]
# (2, 1)

# First, perform the rigid alignment.

if rigid:
# TODO: this just shifts everything to the center. It would be (better?)
# to optmize so all shifts are about the same.
# TODO: used to get window center, for now just get them from the spatial bin
# centers and use no margin, which was applied earlier. Same below.
non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
spatial_bin_centers,
# TODO: used to get window center, for now just get them from the spatial bin centers and use no margin, which was applied earlier
spatial_bin_centers,
rigid=True,
win_shape="gaussian", # rect
win_step_um=None, # TODO: expose! CHECK THIS!
# win_scale_um=win_scale_um,
win_shape="gaussian",
win_step_um=None,
win_margin_um=0,
# zero_threshold=None,
# win_scale_um=win_scale_um,
# zero_threshold=None,
)

return optimal_shift_indices, non_rigid_window_centers # TODO: rename rigid, also this is weird to pass back bins in the rigid case

# TODO: this is basically a re-implimentation of the nonrigid part
# of iterative template. Want to leave separate for now for prototyping
# but should combine the shared parts later.
# TODO: rename rigid, also this is weird to pass back bins in the rigid case
return optimal_shift_indices, non_rigid_window_centers

num_steps = 7
win_step_um = (np.max(spatial_bin_centers) - np.min(spatial_bin_centers)) / num_steps
win_step_um = (np.max(spatial_bin_centers) - np.min(spatial_bin_centers)) / num_nonrigid_bins

non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
spatial_bin_centers, # TODO: used to get window center, for now just get them from the spatial bin centers and use no margin, which was applied earlier
spatial_bin_centers,
spatial_bin_centers,
rigid=False,
win_shape="gaussian", # rect
win_shape="gaussian",
win_step_um=win_step_um, # TODO: expose!
# win_scale_um=win_scale_um,
win_margin_um=0,
# zero_threshold=None,
# win_scale_um=win_scale_um,
# zero_threshold=None,
)
# TODO: I wonder if it is better to estimate the hitsogram with finer bin size
# than try and interpolate the xcorr. What about smoothing the activity histograms directly?

# TOOD: the iterative_template seems a little different to the interpolation
# of nonrigid segments that is described in the NP2.0 paper. Oh, the KS
# implementation is different to that described in the paper/ where is the
# Akima spline interpolation?

# TODO: make sure that the num bins will always align.
# Apply the linear shifts, don't roll, as we don't want circular (why would the top of the probe appear at the bottom?)
# They roll the windowed version that is zeros, but here we want all done up front to simplify later code

import matplotlib.pyplot as plt

# TODO: for recursive version, shift cannot be larger than previous shift!
# Shift the histograms according to the rigid shift
shifted_histograms = np.zeros_like(all_session_hists)
for i in range(all_session_hists.shape[0]):

Expand All @@ -431,20 +446,39 @@ def run_alignment_estimation(
cut_padded_hist = padded_hist[abs_shift:] if shift > 0 else padded_hist[:-abs_shift]
shifted_histograms[i, :] = cut_padded_hist

# For each nonrigid window, compute the shift
non_rigid_shifts = np.zeros((num_sessions, non_rigid_windows.shape[0]))
for i, window in enumerate(non_rigid_windows): # TODO: use same name
for i, window in enumerate(non_rigid_windows):

windowed_histogram = shifted_histograms * window

# NOTE: this method just xcorr the entire window,
# does not provide subset of samples like kilosort_like
window_hist_array = _compute_rigid_hist_crosscorr(
num_sessions, num_bins, windowed_histogram, robust=False # this method just xcorr the entire window does not provide subset of samples like kilosort_like
num_sessions, num_bins, windowed_histogram, robust=False
)
non_rigid_shifts[:, i] = -np.mean(window_hist_array, axis=0)

return optimal_shift_indices + non_rigid_shifts, non_rigid_window_centers # TODO: tidy up
akima = False # TODO: decide whether to keep, factor to own function
if akima:
from scipy.interpolate import Akima1DInterpolator
x = win_step_um * np.arange(non_rigid_windows.shape[0])
xs = spatial_bin_centers

new_nonrigid_shifts = np.zeros((non_rigid_shifts.shape[0], num_bins))
for ses_idx in range(non_rigid_shifts.shape[0]):

y = non_rigid_shifts[ses_idx]
y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs) # requires scipy 14
new_nonrigid_shifts[ses_idx, :] = y_new

shifts = optimal_shift_indices + new_nonrigid_shifts
non_rigid_window_centers = spatial_bin_centers
else:
shifts = optimal_shift_indices + non_rigid_shifts

return shifts, non_rigid_window_centers

# TODO: what about the Akima Spline
# TODO: try out logarithmic scaling as some neurons fire too much...

def _compute_rigid_hist_crosscorr(num_sessions, num_bins, all_session_hists, robust=False):
""""""
Expand Down
Binary file modified debugging/all_recordings.pickle
Binary file not shown.
40 changes: 16 additions & 24 deletions debugging/session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,6 @@
- how to measure 'confidence' (peak height? std?) larger peaks may have
higher std, but we care about them more, so I think this is largely pointless.
weight_on_confidence = True
# TODO: better handle single-time point estimation.
if weight_on_confidence and np.any(std_devs): # TODO: there is no reason this can be done just for poisson, can be done for all... maybe POisson has better variance estimate, do properly!
# do exponential
# this is a bad idea, we literally want to weight on height!
stds = np.array(std_devs)
stds = stds[~(stds==0)]
stds = (stds - np.min(stds)) / (np.max(stds) - np.min(stds))
# TODO: or weight by confidence? this is basically the same as weighting by signal due to poisson variation
stds = stds * (2 - np.exp(2 * stds)) # TODO: expose param, does this even make sense? does it scale?
stds[np.where(stds<0)] = 0
trimmed_percentiles = (20, 80) # TODO: this is originally in the context of Poisson estimation
if trimmed_percentiles is not False:
Expand All @@ -66,12 +54,12 @@
# entire session. Otherwise, we will want to add chunking as part of above.

def run_inter_session_displacement_correction(
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid=True
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, log_scale=True, rigid=True, num_nonrigid_bins=7
): # TOOD: rename
"""
"""
motion_estimates_list, all_temporal_bin_centers, spatial_bin_centers, non_rigid_bin_centers, histogram_info = estimate_inter_session_displacement(
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid, log_scale, num_nonrigid_bins
)

# _, non_ridgid_spatial_windows = alignment_utils.get_spatial_windows_alignment(
Expand All @@ -83,7 +71,7 @@ def run_inter_session_displacement_correction(
)

corrected_peak_locations_list, corrected_session_histogram_list = _session_displacement_correct_peaks_and_generate_histogram(
corrected_recordings_list, peaks_list, peak_locations_list, motion_objects_list, spatial_bin_centers
corrected_recordings_list, peaks_list, peak_locations_list, motion_objects_list, spatial_bin_centers, log_scale
)

extra_outputs_dict = {
Expand All @@ -101,7 +89,7 @@ def run_inter_session_displacement_correction(


def _session_displacement_correct_peaks_and_generate_histogram(
recordings_list, peaks_list, peak_locations_list, motion_objects_list, spatial_bin_centers
recordings_list, peaks_list, peak_locations_list, motion_objects_list, spatial_bin_centers, log_scale
):
"""
"""
Expand All @@ -121,13 +109,13 @@ def _session_displacement_correct_peaks_and_generate_histogram(

for i in range(len(corrected_peak_locations_list)): # TODO: unwrap a bit
corrected_session_histogram_list.append(
alignment_utils.get_entire_session_hist(recordings_list[i], peaks_list[i], corrected_peak_locations_list[i], spatial_bin_centers)[0]
alignment_utils.get_entire_session_hist(recordings_list[i], peaks_list[i], corrected_peak_locations_list[i], spatial_bin_centers, log_scale)[0]
)

return corrected_peak_locations_list, corrected_session_histogram_list

def estimate_inter_session_displacement(
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid
recordings_list, peaks_list, peak_locations_list, bin_um, histogram_estimation_method, alignment_method, rigid, log_scale, num_nonrigid_bins
):
"""
"""
Expand All @@ -148,7 +136,7 @@ def estimate_inter_session_displacement(
for recording, peaks, peak_locations in zip(recordings_list, peaks_list, peak_locations_list):

session_hist, temporal_bin_centers, session_chunked_hists, chunked_hist_stdevs = _get_single_session_activity_histogram(
recording, peaks, peak_locations, histogram_estimation_method, spatial_bin_edges
recording, peaks, peak_locations, histogram_estimation_method, spatial_bin_edges, log_scale
)

all_session_hists.append(session_hist)
Expand All @@ -167,7 +155,7 @@ def estimate_inter_session_displacement(
) * bin_um
else:
all_motion_arrays, non_rigid_bin_centers = alignment_utils.run_alignment_estimation( # TODO: rename because some times rigid!
all_session_hists, spatial_bin_centers, rigid
all_session_hists, spatial_bin_centers, rigid, num_nonrigid_bins
) # TODO: here the motion arrays are made negative initially. In motion correction they are done later. Discuss with others and make consistent.
all_motion_arrays *= bin_um

Expand All @@ -180,7 +168,7 @@ def estimate_inter_session_displacement(
return all_motion_arrays, all_temporal_bin_centers, spatial_bin_centers, non_rigid_bin_centers, extra_outputs_dict


def _get_single_session_activity_histogram(recording, peaks, peak_locations, method, spatial_bin_edges):
def _get_single_session_activity_histogram(recording, peaks, peak_locations, method, spatial_bin_edges, log_scale):
"""
"""
accepted_methods = ["entire_session", "chunked_mean", "chunked_median", "chunked_supremum", "chunked_poisson"]
Expand All @@ -190,18 +178,22 @@ def _get_single_session_activity_histogram(recording, peaks, peak_locations, met
)
# First, get the histogram across the entire session
entire_session_hist, temporal_bin_centers, _ = alignment_utils.get_entire_session_hist( # TODO: assert spatial bin edges
recording, peaks, peak_locations, spatial_bin_edges
recording, peaks, peak_locations, spatial_bin_edges, log_scale=False
)

if method == "entire_session":

if log_scale:
entire_session_hist = np.log10(1 + entire_session_hist)

return entire_session_hist, temporal_bin_centers, None, None

# If method is not "entire_session", estimate the session
# histogram based on histograms calculated from chunks.
bin_s, percentile_lambda = alignment_utils.estimate_chunk_size(entire_session_hist, recording)
bin_s, percentile_lambda = alignment_utils.estimate_chunk_size(entire_session_hist, recording) # TODO: handle with log properly

chunked_session_hist, chunked_temporal_bin_centers, _ = alignment_utils.get_chunked_histogram( # TODO: do the centering higher levle as duplciating
recording, peaks, peak_locations, bin_s, spatial_bin_edges
recording, peaks, peak_locations, bin_s, spatial_bin_edges, log_scale
)
session_std = np.sum(np.std(chunked_session_hist, axis=0)) / chunked_session_hist.shape[1]

Expand Down
16 changes: 11 additions & 5 deletions debugging/test_session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,24 @@

# What we really want to find is maximal subset of the data that matches

# TOOD: here use natural log for scaling, should prob go to base 10

# TODO: major check, refactor and tidy up
# list out carefully all notes
# handle the case where the passed recordings are not motion correction recordings.

SAVE = False
PLOT = False
BIN_UM = 2

if SAVE:
scalings = [np.ones(25), np.r_[np.zeros(10), np.ones(15)]] # TODO: there is something wrong here, because why are the maximum histograms not removed?
recordings_list, _ = generate_session_displacement_recordings(
non_rigid_gradient=0.1, # None,
num_units=25,
recording_durations=(100, 100),
non_rigid_gradient=None, # 0.05, # None,
num_units=15,
recording_durations=(100, 100, 100, 100),
recording_shifts=(
(0, 0), (0, 75),
(0, 0), (0, 75), (0, -125), (0, 25),
),
recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings},
seed=None,
Expand All @@ -76,7 +82,7 @@


corrected_recordings_list, motion_objects_list, extra_info = session_alignment.run_inter_session_displacement_correction(
recordings_list, peaks_list, peak_locations_list, bin_um=BIN_UM, histogram_estimation_method="entire_session", alignment_method="mean_crosscorr", rigid=True
recordings_list, peaks_list, peak_locations_list, bin_um=BIN_UM, histogram_estimation_method="entire_session", alignment_method="mean_crosscorr", rigid=False, log_scale=True, num_nonrigid_bins=7
)

plotting.SessionAlignmentWidget(
Expand Down

0 comments on commit c0eb2f5

Please sign in to comment.