diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index fca521c90f..79e6d1c8e3 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -388,9 +388,11 @@ def compute_refrac_period_violations( nb_violations = {} rp_contamination = {} - for i, unit_id in enumerate(unit_ids): + for unit_index, unit_id in enumerate(sorting.unit_ids): + if unit_id not in unit_ids: + continue - nb_violations[unit_id] = n_v = nb_rp_violations[i] + nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] N = num_spikes[unit_id] if N == 0: rp_contamination[unit_id] = np.nan @@ -1083,10 +1085,10 @@ def compute_drift_metrics( spikes_in_bin = spikes_in_segment[i0:i1] spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] - for unit_ind in np.arange(len(unit_ids)): - mask = spikes_in_bin["unit_index"] == unit_ind + for unit_index, unit_id in enumerate(unit_ids): + mask = spikes_in_bin["unit_index"] == sorting.id_to_index(unit_id) if np.sum(mask) >= min_spikes_per_interval: - median_positions[unit_ind, bin_index] = np.median(spike_locations_in_bin[mask]) + median_positions[unit_index, bin_index] = np.median(spike_locations_in_bin[mask]) if median_position_segments is None: median_position_segments = median_positions else: diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 6652ea6654..9175748a4a 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -84,24 +84,44 @@ def small_sorting_analyzer(): def test_unit_structure_in_output(small_sorting_analyzer): - for metric_name in get_quality_metric_list(): - result = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer) - if isinstance(result, dict): - assert list(result.keys()) == ["#3", "#9", "#4"] - else: - for one_result in result: - assert list(one_result.keys()) == ["#3", "#9", "#4"] + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, + } for metric_name in get_quality_metric_list(): - result = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, unit_ids=["#9", "#3"]) - if isinstance(result, dict): - assert list(result.keys()) == ["#9", "#3"] + try: + qm_param = qm_params[metric_name] + except: + qm_param = {} + + result_all = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, **qm_param) + result_sub = _misc_metric_name_to_func[metric_name]( + sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param + ) + + if isinstance(result_all, dict): + assert list(result_all.keys()) == ["#3", "#9", "#4"] + assert list(result_sub.keys()) == ["#4", "#9"] + assert result_sub["#9"] == result_all["#9"] + assert result_sub["#4"] == result_all["#4"] + else: - for one_result in result: - print(metric_name) - assert list(one_result.keys()) == ["#9", "#3"] + for result_ind, result in enumerate(result_sub): + + assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"] + assert result_sub[result_ind].keys() == set(["#4", "#9"]) + + assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"] + assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"] def test_unit_id_order_independence(small_sorting_analyzer): @@ -110,12 +130,9 @@ def test_unit_id_order_independence(small_sorting_analyzer): and checks that their calculated quality metrics are independent of the ordering and labelling. """ - recording, sorting = generate_ground_truth_recording( - durations=[2.0], - num_units=4, - seed=1205, - ) - sorting = sorting.select_units([0, 2, 3]) + recording = small_sorting_analyzer.recording + sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3]) + small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") extensions_to_compute = {