Skip to content

Commit

Permalink
Testing on the session_alignment side.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Dec 19, 2024
1 parent ede589b commit 45502ce
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 144 deletions.
2 changes: 1 addition & 1 deletion debugging/playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
recordings_list,
peaks_list,
peak_locations_list,
alignment_order="to_session_1", # "to_session_X" or "to_middle"
alignment_order="to_session_2", # "to_session_X" or "to_middle"
non_rigid_window_kwargs=non_rigid_window_kwargs,
estimate_histogram_kwargs=estimate_histogram_kwargs,
)
Expand Down
139 changes: 91 additions & 48 deletions debugging/playing2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,66 +33,84 @@ def cross_correlate(sig1, sig2, thr= None):

return shift

def cross_correlate_with_scale(signa11_blanked, signal2_blanked, thr=100, plot=True):
def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True):
"""
"""
best_correlation = 0
best_displacements = np.zeros_like(signa11_blanked)

# TODO: use kriging interp

xcorr = []
for s in np.arange(-thr, thr): # TODO: we are off by one here

shift_signal1_blanked = shift_array_fill_zeros(signa11_blanked, s)
for scale in np.linspace(0.85, 1.15, 10):

nonzero = np.where(signa11_blanked > 0)[0]
if not np.any(nonzero):
continue

midpoint = nonzero[0] + np.ptp(nonzero) / 2
x_scale = (x - midpoint) * scale + midpoint

interp_f = scipy.interpolate.interp1d(x_scale, signa11_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging

scaled_func = interp_f(x)

# plt.plot(signa11_blanked)
# plt.plot(scaled_func)
# plt.show()

x = np.arange(shift_signal1_blanked.size)
# breakpoint()

xcorr_scale = []
for scale in np.linspace(0.75, 1.25, 10):
for sh in np.arange(-thr, thr): # TODO: we are off by one here

midpoint = np.argmax(shift_signal1_blanked) # assumes x is 0 .. n TODO: IMPROVE
xs = (x - midpoint) * scale + midpoint
shift_signal1_blanked = shift_array_fill_zeros(scaled_func, sh)

x_shift = x_scale - sh # TODO: rename

# is this pull back?
interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging
# interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging

scaled_func = interp_f(x)
# scaled_func = interp_f(x_shift)

corr_value = np.correlate(
scaled_func - np.mean(scaled_func),
shift_signal1_blanked - np.mean(shift_signal1_blanked),
signal2_blanked - np.mean(signal2_blanked),
) / signa11_blanked.size

xcorr_scale.append(
corr_value
)
if corr_value > best_correlation:
best_displacements = x_shift
best_correlation = corr_value

if plot and corr_value > 0.0045: # and np.abs(s) < 10:
if False and np.abs(sh) == 1:
print(corr_value)

plt.plot(shift_signal1_blanked)
plt.plot(signal2_blanked)
plt.show()
# plt.draw() # Draw the updated figure
# plt.pause(0.1) # Pause for 0.5 seconds before updating
# plt.clf()

plt.plot(scaled_func)
plt.plot(signal2_blanked)
plt.show()
# plt.title(f"corr value: {corr_value}")
# plt.draw() # Draw the updated figure
# plt.pause(0.1) # Pause for 0.5 seconds before updating
# plt.clf()
# breakpoint()

xcorr.append(np.max(np.r_[xcorr_scale]))

xcorr = np.r_[xcorr]
# shift = np.argmax(xcorr) - thr
# xcorr.append(np.max(np.r_[xcorr_scale]))

print("MAX", np.max(xcorr))
if False:
xcorr = np.r_[xcorr]
# shift = np.argmax(xcorr) - thr

if np.max(xcorr) < 0.0001:
shift = 0
else:
shift = np.argmax(xcorr) - thr
print("MAX", np.max(xcorr))

if np.max(xcorr) < 0.0001:
shift = 0
else:
shift = np.argmax(xcorr) - thr

print("output shift", shift)
print("output shift", shift)

return shift
return best_displacements

# plt.plot(signal1)
# plt.plot(signal2)
Expand All @@ -104,6 +122,8 @@ def get_shifts(signal1, signal2, windows, plot=True):
signa11_blanked = signal1.copy()
signal2_blanked = signal2.copy()

best_displacements = np.zeros_like(signal1)

if (first_idx := windows[0][0]) != 0:
print("first idx", first_idx)
signa11_blanked[:first_idx] = 0
Expand All @@ -115,29 +135,39 @@ def get_shifts(signal1, signal2, windows, plot=True):
signal2_blanked[last_idx:] = 0

segment_shifts = np.empty(len(windows))
cum_shifts = []


x = np.arange(signa11_blanked.size)
x_orig = x.copy()

for round in range(len(windows)):

if round == 0:
shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger!
else:
shift = cross_correlate_with_scale(signa11_blanked, signal2_blanked, thr=100, plot=False)
#if round == 0:
# shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger!
#else:
displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=200, plot=False)



# breakpoint()

cum_shifts.append(shift)
print("shift", shift)
interpf = scipy.interpolate.interp1d(displacements, signa11_blanked, fill_value=0.0, bounds_error=False) # TODO: move away from this indexing sceheme
signa11_blanked = interpf(x)



# cum_shifts.append(shift)
# print("shift", shift)

# shift the signal1, or use indexing

signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift)
# signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift) # INTERP HERE, KRIGING. but will accumulate interpolation errors...

if plot:
print("round", round)
plt.plot(signa11_blanked)
plt.plot(signal2_blanked)
plt.show()
# if plot:
# print("round", round)
# plt.plot(signa11_blanked)
# plt.plot(signal2_blanked)
# plt.show()

window_corrs = np.empty(len(windows))
for i, idx in enumerate(windows):
Expand All @@ -148,15 +178,28 @@ def get_shifts(signal1, signal2, windows, plot=True):

max_window = np.argmax(window_corrs) # TODO: cutoff!

small_shift = cross_correlate(signa11_blanked[windows[max_window]], signal2_blanked[windows[max_window]], thr=windows[max_window].size //2)
if False:
small_shift = cross_correlate(signa11_blanked[windows[max_window]], signal2_blanked[windows[max_window]], thr=windows[max_window].size //2)
signa11_blanked = shift_array_fill_zeros(signa11_blanked, small_shift)
segment_shifts[max_window] = np.sum(cum_shifts) + small_shift

signa11_blanked = shift_array_fill_zeros(signa11_blanked, small_shift)
best_displacements[windows[max_window]] = displacements[windows[max_window]]

segment_shifts[max_window] = np.sum(cum_shifts) + small_shift
x = displacements

signa11_blanked[windows[max_window]] = 0
signal2_blanked[windows[max_window]] = 0

# TODO: need to carry over displacements!

print(best_displacements)
interpf = scipy.interpolate.interp1d(best_displacements, signal1, fill_value=0.0, bounds_error=False) # TODO: move away from this indexing sceheme
final = interpf(x_orig)

plt.plot(final)
plt.plot(signal2)
plt.show()

return segment_shifts


Expand Down
Loading

0 comments on commit 45502ce

Please sign in to comment.