-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Don't let decimate mess with times and skim tests #3519
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,18 +63,15 @@ def __init__( | |
f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames." | ||
) | ||
self._decimation_offset = decimation_offset | ||
resample_rate = self._orig_samp_freq / self._decimation_factor | ||
decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor | ||
|
||
BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate) | ||
BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) | ||
|
||
# in case there was a time_vector, it will be dropped for sanity. | ||
# This is not necessary but consistent with ResampleRecording | ||
for parent_segment in recording._recording_segments: | ||
parent_segment.time_vector = None | ||
self.add_recording_segment( | ||
DecimateRecordingSegment( | ||
parent_segment, | ||
resample_rate, | ||
decimated_sampling_frequency, | ||
self._orig_samp_freq, | ||
decimation_factor, | ||
decimation_offset, | ||
|
@@ -93,22 +90,26 @@ class DecimateRecordingSegment(BaseRecordingSegment): | |
def __init__( | ||
self, | ||
parent_recording_segment, | ||
resample_rate, | ||
decimated_sampling_frequency, | ||
parent_rate, | ||
decimation_factor, | ||
decimation_offset, | ||
dtype, | ||
): | ||
if parent_recording_segment.t_start is None: | ||
new_t_start = None | ||
if parent_recording_segment.time_vector is not None: | ||
time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] | ||
decimated_sampling_frequency = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's this None for? I think Joe has thought about this more so It's tricky for me to think about time vector vs t_start and when we want a frequency of None vs a value. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. currently There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay cool. Maybe we can have a chat about this at some point. I don't want to take up developer meeting time for this necessarily, but since I don't use the time api I don't know it well enough. :) |
||
t_start = None | ||
else: | ||
new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate | ||
time_vector = None | ||
if parent_recording_segment.t_start is None: | ||
t_start = None | ||
else: | ||
t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate) | ||
|
||
# Do not use BasePreprocessorSegment bcause we have to reset the sampling rate! | ||
BaseRecordingSegment.__init__( | ||
self, | ||
sampling_frequency=resample_rate, | ||
t_start=new_t_start, | ||
self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector | ||
) | ||
self._parent_segment = parent_recording_segment | ||
self._decimation_factor = decimation_factor | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,39 +8,74 @@ | |
import numpy as np | ||
|
||
|
||
@pytest.mark.parametrize("N_segments", [1, 2]) | ||
@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101]) | ||
@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101]) | ||
@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) | ||
@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) | ||
def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame): | ||
rec = generate_recording() | ||
|
||
segment_num_samps = [101 + i for i in range(N_segments)] | ||
|
||
@pytest.mark.parametrize("num_segments", [1, 2]) | ||
@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101]) | ||
@pytest.mark.parametrize("decimation_factor", [1, 7, 50]) | ||
def test_decimate(num_segments, decimation_offset, decimation_factor): | ||
segment_num_samps = [20000, 40000] | ||
rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) | ||
|
||
parent_traces = [rec.get_traces(i) for i in range(N_segments)] | ||
parent_traces = [rec.get_traces(i) for i in range(num_segments)] | ||
|
||
if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor: | ||
with pytest.raises(ValueError): | ||
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) | ||
return | ||
|
||
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) | ||
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)] | ||
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] | ||
|
||
if start_frame is None: | ||
start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) | ||
if end_frame is None: | ||
end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) | ||
for start_frame in [0, 1, 5, None, 1000]: | ||
for end_frame in [0, 1, 5, None, 1000]: | ||
if start_frame is None: | ||
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) | ||
if end_frame is None: | ||
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) | ||
|
||
for i in range(N_segments): | ||
for i in range(num_segments): | ||
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] | ||
assert np.all( | ||
decimated_rec.get_traces(i, start_frame, end_frame) | ||
== decimated_parent_traces[i][start_frame:end_frame] | ||
) | ||
|
||
for i in range(num_segments): | ||
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] | ||
assert np.all( | ||
decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame] | ||
) | ||
|
||
|
||
def test_decimate_with_times(): | ||
rec = generate_recording(durations=[5, 10]) | ||
|
||
# test with times | ||
times = [rec.get_times(0) + 10, rec.get_times(1) + 20] | ||
for i, t in enumerate(times): | ||
rec.set_times(t, i) | ||
|
||
decimation_factor = 2 | ||
decimation_offset = 1 | ||
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) | ||
|
||
for segment_index in range(rec.get_num_segments()): | ||
assert np.allclose( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we explicitly set the tolerance that we tolerate? We often get flakiness due to floating point rounding. does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point. I think here all equal will work too. I'll push an update There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @zm711 use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You think 10 decimal places? It looks like assert_array_almost_equal also deals with nan's which is nice. I vaguely remembering Heberto having to slowly relax one of these style of tests because it keep failing. 10 decimals seems super exact. Why not the default? (which is 7 based on most recent docs). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my tests the absolute error was e-15 :) |
||
decimated_rec.get_times(segment_index), | ||
rec.get_times(segment_index)[decimation_offset::decimation_factor], | ||
) | ||
|
||
# test with t_start | ||
rec = generate_recording(durations=[5, 10]) | ||
t_starts = [10, 20] | ||
for t_start, rec_segment in zip(t_starts, rec._recording_segments): | ||
rec_segment.t_start = t_start | ||
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) | ||
for segment_index in range(rec.get_num_segments()): | ||
assert np.allclose( | ||
decimated_rec.get_times(segment_index), | ||
rec.get_times(segment_index)[decimation_offset::decimation_factor], | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_decimate() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be a breaking change no? Should we deprecate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, but it's soo deep in the API that I'm 100% sure it wouldn't affect anyone's workflow.
If it were at the
DecimateRecording
level, than we should have worried about back-compatibility because of saved objects/JSON files. But since it's the segment which is instantiated on the fly we don't have to worry about it (and I think the naming is much more in line with the overll API)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Absolutely agree with the naming. And makes sense if it's deep. I wasn't sure if this was more on the private or public side, but makes sense that the Segment level is basically private.