Skip to content

Commit

Permalink
Update tests and fix the failures: rp_violation and drift
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Jul 2, 2024
1 parent 11ae9aa commit 6529bd3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
12 changes: 7 additions & 5 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,11 @@ def compute_refrac_period_violations(
nb_violations = {}
rp_contamination = {}

for i, unit_id in enumerate(unit_ids):
for unit_index, unit_id in enumerate(sorting.unit_ids):
if unit_id not in unit_ids:
continue

nb_violations[unit_id] = n_v = nb_rp_violations[i]
nb_violations[unit_id] = n_v = nb_rp_violations[unit_index]
N = num_spikes[unit_id]
if N == 0:
rp_contamination[unit_id] = np.nan
Expand Down Expand Up @@ -1083,10 +1085,10 @@ def compute_drift_metrics(
spikes_in_bin = spikes_in_segment[i0:i1]
spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]

for unit_ind in np.arange(len(unit_ids)):
mask = spikes_in_bin["unit_index"] == unit_ind
for unit_index, unit_id in enumerate(unit_ids):
mask = spikes_in_bin["unit_index"] == sorting.id_to_index(unit_id)
if np.sum(mask) >= min_spikes_per_interval:
median_positions[unit_ind, bin_index] = np.median(spike_locations_in_bin[mask])
median_positions[unit_index, bin_index] = np.median(spike_locations_in_bin[mask])
if median_position_segments is None:
median_position_segments = median_positions
else:
Expand Down
55 changes: 36 additions & 19 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,24 +84,44 @@ def small_sorting_analyzer():


def test_unit_structure_in_output(small_sorting_analyzer):
    """Check that every quality metric returns results keyed by unit id.

    For each registered metric, compute it once over all units and once over a
    subset (``unit_ids=["#4", "#9"]``), then verify that:
      * dict-style results are keyed by the expected unit ids, in order;
      * the subset computation returns only the requested units;
      * per-unit values computed on the subset match those computed on all units
        (i.e. restricting ``unit_ids`` does not change the metric values).
    """
    # Per-metric keyword arguments chosen so the metrics are computable on the
    # small fixture (short duration, few spikes); metrics not listed here take
    # their defaults.
    qm_params = {
        "presence_ratio": {"bin_duration_s": 0.1},
        "amplitude_cutoff": {"num_histogram_bins": 3},
        "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3},
        "firing_range": {"bin_size_s": 1},
        "isi_violation": {"isi_threshold_ms": 10},
        "drift": {"interval_s": 1, "min_spikes_per_interval": 5},
        "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15},
        "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0},
    }

    for metric_name in get_quality_metric_list():
        # Use the metric-specific params when present, otherwise defaults.
        # dict.get avoids the bare `except:` anti-pattern.
        qm_param = qm_params.get(metric_name, {})

        result_all = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, **qm_param)
        result_sub = _misc_metric_name_to_func[metric_name](
            sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param
        )

        if isinstance(result_all, dict):
            # Single-dict metrics: keys follow the analyzer's unit order for the
            # full computation and the requested order for the subset.
            assert list(result_all.keys()) == ["#3", "#9", "#4"]
            assert list(result_sub.keys()) == ["#4", "#9"]
            assert result_sub["#9"] == result_all["#9"]
            assert result_sub["#4"] == result_all["#4"]

        else:
            # Multi-output metrics return a tuple/list of per-unit dicts;
            # check each output dict the same way.
            for result_ind in range(len(result_sub)):
                assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"]
                assert result_sub[result_ind].keys() == {"#4", "#9"}

                assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"]
                assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"]

def test_unit_id_order_independence(small_sorting_analyzer):
Expand All @@ -110,12 +130,9 @@ def test_unit_id_order_independence(small_sorting_analyzer):
and checks that their calculated quality metrics are independent of the ordering and labelling.
"""

recording, sorting = generate_ground_truth_recording(
durations=[2.0],
num_units=4,
seed=1205,
)
sorting = sorting.select_units([0, 2, 3])
recording = small_sorting_analyzer.recording
sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3])

small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

extensions_to_compute = {
Expand Down

0 comments on commit 6529bd3

Please sign in to comment.