Skip to content

Commit

Permalink
Fix firing range when bin size is to small (#2054)
Browse files Browse the repository at this point in the history
* Fix firing range when bin size is to small

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
alejoe91 and pre-commit-ci[bot] authored Sep 29, 2023
1 parent 2f354c4 commit c8be1a0
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 8 deletions.
9 changes: 9 additions & 0 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,15 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95),
if unit_ids is None:
unit_ids = sorting.unit_ids

if all(
[
waveform_extractor.get_num_samples(segment_index) < bin_size_samples
for segment_index in range(waveform_extractor.get_num_segments())
]
):
warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. Firing ranges are set to NaN.")
return {unit_id: np.nan for unit_id in unit_ids}

# for each segment, we compute the firing rate histogram and we concatenate them
firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids}
for segment_index in range(waveform_extractor.get_num_segments()):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def test_calculate_firing_range(waveform_extractor_simple):
firing_ranges = compute_firing_ranges(we)
print(firing_ranges)

with pytest.warns(UserWarning) as w:
firing_ranges_nan = compute_firing_ranges(we, bin_size_s=we.get_total_duration() + 1)
assert np.all([np.isnan(f) for f in firing_ranges_nan.values()])


def test_calculate_amplitude_cutoff(waveform_extractor_simple):
we = waveform_extractor_simple
Expand Down Expand Up @@ -378,7 +382,7 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
if __name__ == "__main__":
sim_data = _simulated_data()
we = _waveform_extractor_simple()
we_violations = _waveform_extractor_violations(sim_data)
# we_violations = _waveform_extractor_violations(sim_data)
# test_calculate_amplitude_cutoff(we)
# test_calculate_presence_ratio(we)
# test_calculate_amplitude_median(we)
Expand All @@ -387,4 +391,4 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
# test_calculate_drift_metrics(we)
# test_synchrony_metrics(we)
test_calculate_firing_range(we)
test_calculate_amplitude_cv_metrics(we)
# test_calculate_amplitude_cv_metrics(we)
Original file line number Diff line number Diff line change
Expand Up @@ -584,13 +584,13 @@ def plot_motions_several_benchmarks(benchmarks):
_simpleaxis(ax)


def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None):
def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None):
if ax is None:
fig, ax = plt.subplots(figsize=(5, 5))

for count, benchmark in enumerate(benchmarks):
color = colors[count] if colors is not None else None

if detailed:
bottom = 0
i = 0
Expand All @@ -606,8 +606,6 @@ def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=No
else:
total_run_time = np.sum([value for key, value in benchmark.run_times.items()])
ax.bar([count], [total_run_time], color=color, edgecolor="black")



# ax.legend()
ax.set_ylabel("speed (s)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,15 @@ def run_sorters(self, skip_already_done=True):
recording = self.recordings[case["recording"]]
output_folder = self.folder / f"tmp_sortings_{label}"
if output_folder.exists() and skip_already_done:
print('already done')
print("already done")
sorting = read_sorter_folder(output_folder)
else:
sorting = run_sorter(
sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder
sorter_name,
recording,
output_folder,
**sorter_params,
delete_output_folder=self.delete_output_folder,
)
self.sortings[label] = sorting

Expand Down

0 comments on commit c8be1a0

Please sign in to comment.