Skip to content

Commit

Permalink
Merge pull request #3138 from chrishalcrow/fix-nn-calculations
Browse files Browse the repository at this point in the history
Fix nn pca_metric computation and update tests
  • Loading branch information
alejoe91 authored Jul 15, 2024
2 parents eba9f68 + edb8003 commit ad3e924
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 142 deletions.
99 changes: 59 additions & 40 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from ..core import get_random_data_chunks, compute_sparsity
from ..core.template_tools import get_template_extremum_channel


_possible_pc_metric_names = [
"isolation_distance",
"l_ratio",
Expand Down Expand Up @@ -90,7 +89,7 @@ def compute_pc_metrics(
sorting = sorting_analyzer.sorting

if metric_names is None:
metric_names = _possible_pc_metric_names
metric_names = _possible_pc_metric_names.copy()
if qm_params is None:
qm_params = _default_params

Expand All @@ -110,8 +109,13 @@ def compute_pc_metrics(
if "nn_isolation" in metric_names:
pc_metrics["nn_unit_id"] = {}

possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"]

nn_metrics = list(set(metric_names).intersection(possible_nn_metrics))
non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics))

# Compute nspikes and firing rate outside of main loop for speed
if any([n in metric_names for n in ["nn_isolation", "nn_noise_overlap"]]):
if nn_metrics:
n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids)
fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids)
else:
Expand All @@ -120,9 +124,6 @@ def compute_pc_metrics(

run_in_parallel = n_jobs > 1

if run_in_parallel:
parallel_functions = []

# this get dense projection for selected unit_ids
dense_projections, spike_unit_indices = pca_ext.get_some_projections(channel_ids=None, unit_ids=unit_ids)
all_labels = sorting.unit_ids[spike_unit_indices]
Expand All @@ -146,7 +147,7 @@ def compute_pc_metrics(
func_args = (
pcs_flat,
labels,
metric_names,
non_nn_metrics,
unit_id,
unit_ids,
qm_params,
Expand All @@ -156,16 +157,16 @@ def compute_pc_metrics(
)
items.append(func_args)

if not run_in_parallel:
if not run_in_parallel and non_nn_metrics:
units_loop = enumerate(unit_ids)
if progress_bar:
units_loop = tqdm(units_loop, desc="calculate_pc_metrics", total=len(unit_ids))
units_loop = tqdm(units_loop, desc="calculate pc_metrics", total=len(unit_ids))

for unit_ind, unit_id in units_loop:
pca_metrics_unit = pca_metrics_one_unit(items[unit_ind])
for metric_name, metric in pca_metrics_unit.items():
pc_metrics[metric_name][unit_id] = metric
else:
elif run_in_parallel and non_nn_metrics:
with ProcessPoolExecutor(n_jobs) as executor:
results = executor.map(pca_metrics_one_unit, items)
if progress_bar:
Expand All @@ -176,6 +177,37 @@ def compute_pc_metrics(
for metric_name, metric in pca_metrics_unit.items():
pc_metrics[metric_name][unit_id] = metric

for metric_name in nn_metrics:
units_loop = enumerate(unit_ids)
if progress_bar:
units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids))

func = _nn_metric_name_to_func[metric_name]
metric_params = qm_params[metric_name] if metric_name in qm_params else {}

for _, unit_id in units_loop:
try:
res = func(
sorting_analyzer,
unit_id,
seed=seed,
n_spikes_all_units=n_spikes_all_units,
fr_all_units=fr_all_units,
**metric_params,
)
except:
if metric_name == "nn_isolation":
res = (np.nan, np.nan)
elif metric_name == "nn_noise_overlap":
res = np.nan

if metric_name == "nn_isolation":
nn_isolation, nn_unit_id = res
pc_metrics["nn_isolation"][unit_id] = nn_isolation
pc_metrics["nn_unit_id"][unit_id] = nn_unit_id
elif metric_name == "nn_noise_overlap":
pc_metrics["nn_noise_overlap"][unit_id] = res

return pc_metrics


Expand Down Expand Up @@ -677,6 +709,14 @@ def nearest_neighbors_noise_overlap(
templates_ext = sorting_analyzer.get_extension("templates")
assert templates_ext is not None, "nearest_neighbors_isolation() need extension 'templates'"

try:
sorting_analyzer.get_extension("templates").get_data(operator="median")
except KeyError:
warnings.warn(
"nearest_neighbors_isolation() need extension 'templates' calculated with the 'median' operator."
"You can run sorting_analyzer.compute('templates', operators=['average', 'median']) to calculate templates based on both average and median modes."
)

if n_spikes_all_units is None:
n_spikes_all_units = compute_num_spikes(sorting_analyzer)
if fr_all_units is None:
Expand Down Expand Up @@ -955,11 +995,13 @@ def pca_metrics_one_unit(args):
pc_metrics = {}
# metrics
if "isolation_distance" in metric_names or "l_ratio" in metric_names:

try:
isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id)
except:
isolation_distance = np.nan
l_ratio = np.nan

if "isolation_distance" in metric_names:
pc_metrics["isolation_distance"] = isolation_distance
if "l_ratio" in metric_names:
Expand All @@ -973,6 +1015,7 @@ def pca_metrics_one_unit(args):
d_prime = lda_metrics(pcs_flat, labels, unit_id)
except:
d_prime = np.nan

pc_metrics["d_prime"] = d_prime

if "nearest_neighbor" in metric_names:
Expand All @@ -986,36 +1029,6 @@ def pca_metrics_one_unit(args):
pc_metrics["nn_hit_rate"] = nn_hit_rate
pc_metrics["nn_miss_rate"] = nn_miss_rate

if "nn_isolation" in metric_names:
try:
nn_isolation, nn_unit_id = nearest_neighbors_isolation(
we,
unit_id,
seed=seed,
n_spikes_all_units=n_spikes_all_units,
fr_all_units=fr_all_units,
**qm_params["nn_isolation"],
)
except:
nn_isolation = np.nan
nn_unit_id = np.nan
pc_metrics["nn_isolation"] = nn_isolation
pc_metrics["nn_unit_id"] = nn_unit_id

if "nn_noise_overlap" in metric_names:
try:
nn_noise_overlap = nearest_neighbors_noise_overlap(
we,
unit_id,
n_spikes_all_units=n_spikes_all_units,
fr_all_units=fr_all_units,
seed=seed,
**qm_params["nn_noise_overlap"],
)
except:
nn_noise_overlap = np.nan
pc_metrics["nn_noise_overlap"] = nn_noise_overlap

if "silhouette" in metric_names:
silhouette_method = qm_params["silhouette"]["method"]
if "simplified" in silhouette_method:
Expand All @@ -1032,3 +1045,9 @@ def pca_metrics_one_unit(args):
pc_metrics["silhouette_full"] = unit_silhouette_score

return pc_metrics


# Dispatch table from the nearest-neighbour metric names accepted by
# compute_pc_metrics to the functions that compute them. Placed after the
# function definitions so both names are bound when the module is imported.
_nn_metric_name_to_func = {
    "nn_isolation": nearest_neighbors_isolation,
    "nn_noise_overlap": nearest_neighbors_noise_overlap,
}
37 changes: 37 additions & 0 deletions src/spikeinterface/qualitymetrics/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
)


def _small_sorting_analyzer():
    """Build a small in-memory SortingAnalyzer on synthetic data.

    Generates a short (2 s) ground-truth recording with 10 units, keeps only
    three of them under renamed unit ids, and precomputes the extensions the
    quality-metric tests rely on (including median templates).
    """
    recording, sorting = generate_ground_truth_recording(
        durations=[2.0],
        num_units=10,
        seed=1205,
    )

    # Keep a subset of units and rename them so unit-id handling is exercised.
    sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])

    analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

    analyzer.compute(
        {
            "random_spikes": {"seed": 1205},
            "noise_levels": {"seed": 1205},
            "waveforms": {},
            # median operator is required by the nearest-neighbour metrics
            "templates": {"operators": ["average", "median"]},
            "spike_amplitudes": {},
            "spike_locations": {},
            "principal_components": {},
        }
    )

    return analyzer


@pytest.fixture(scope="module")
def small_sorting_analyzer():
    """Module-scoped fixture: the analyzer is built once and shared by all tests in the module."""
    return _small_sorting_analyzer()
39 changes: 4 additions & 35 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,6 @@
job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


def _small_sorting_analyzer():
    """Build a small in-memory SortingAnalyzer for the metric-function tests.

    Creates a 2 s ground-truth recording with 4 units, selects three of them
    under renamed ids, and precomputes the extensions used by the tests.
    """
    recording, sorting = generate_ground_truth_recording(
        durations=[2.0],
        num_units=4,
        seed=1205,
    )

    # Subset and rename units so tests exercise non-trivial unit-id mapping.
    sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"])

    analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

    analyzer.compute(
        {
            "random_spikes": {"seed": 1205},
            "noise_levels": {"seed": 1205},
            "waveforms": {},
            "templates": {},
            "spike_amplitudes": {},
            "spike_locations": {},
            "principal_components": {},
        }
    )

    return analyzer


@pytest.fixture(scope="module")
def small_sorting_analyzer():
    """Module-scoped fixture so the analyzer is constructed only once per test module."""
    return _small_sorting_analyzer()


def test_unit_structure_in_output(small_sorting_analyzer):

qm_params = {
Expand Down Expand Up @@ -126,7 +95,7 @@ def test_unit_id_order_independence(small_sorting_analyzer):
"""

recording = small_sorting_analyzer.recording
sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3])
sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2])

small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

Expand Down Expand Up @@ -161,9 +130,9 @@ def test_unit_id_order_independence(small_sorting_analyzer):
)

for metric, metric_1_data in quality_metrics_1.items():
assert quality_metrics_2[metric][3] == metric_1_data["#3"]
assert quality_metrics_2[metric][2] == metric_1_data["#9"]
assert quality_metrics_2[metric][0] == metric_1_data["#4"]
assert quality_metrics_2[metric][2] == metric_1_data["#3"]
assert quality_metrics_2[metric][7] == metric_1_data["#9"]
assert quality_metrics_2[metric][1] == metric_1_data["#4"]


def _sorting_analyzer_simple():
Expand Down
76 changes: 9 additions & 67 deletions src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,24 @@
import pytest
from pathlib import Path
import numpy as np
from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
)


from spikeinterface.qualitymetrics import (
compute_pc_metrics,
nearest_neighbors_isolation,
nearest_neighbors_noise_overlap,
)


job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


def _sorting_analyzer_simple():
    """Build a sparse in-memory SortingAnalyzer on a 50 s synthetic recording.

    Uses 6 channels / 10 units with fixed seeds, and precomputes the
    extensions needed by the pca-metric tests (waveforms, templates with
    average/std/median operators, principal components, amplitudes).
    """
    rec, sort = generate_ground_truth_recording(
        durations=[50.0],
        sampling_frequency=30_000.0,
        num_channels=6,
        num_units=10,
        generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
        seed=2205,
    )

    analyzer = create_sorting_analyzer(sort, rec, format="memory", sparse=True)

    analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205)
    analyzer.compute("noise_levels")
    analyzer.compute("waveforms", **job_kwargs)
    analyzer.compute("templates", operators=["average", "std", "median"])
    analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)
    analyzer.compute("spike_amplitudes", **job_kwargs)

    return analyzer


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
    """Module-scoped fixture: builds the (relatively expensive) analyzer once for all tests here."""
    return _sorting_analyzer_simple()


def test_calculate_pc_metrics(sorting_analyzer_simple):
def test_calculate_pc_metrics(small_sorting_analyzer):
import pandas as pd

sorting_analyzer = sorting_analyzer_simple
res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True)
sorting_analyzer = small_sorting_analyzer
res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205)
res1 = pd.DataFrame(res1)

res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True)
res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205)
res2 = pd.DataFrame(res2)

for k in res1.columns:
mask = ~np.isnan(res1[k].values)
if np.any(mask):
assert np.array_equal(res1[k].values[mask], res2[k].values[mask])


def test_nearest_neighbors_isolation(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
this_unit_id = sorting_analyzer.unit_ids[0]
nearest_neighbors_isolation(sorting_analyzer, this_unit_id)


def test_nearest_neighbors_noise_overlap(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
this_unit_id = sorting_analyzer.unit_ids[0]
nearest_neighbors_noise_overlap(sorting_analyzer, this_unit_id)

for metric_name in res1.columns:
if metric_name != "nn_unit_id":
assert not np.all(np.isnan(res1[metric_name].values))
assert not np.all(np.isnan(res2[metric_name].values))

if __name__ == "__main__":
sorting_analyzer = _sorting_analyzer_simple()
test_calculate_pc_metrics(sorting_analyzer)
test_nearest_neighbors_isolation(sorting_analyzer)
test_nearest_neighbors_noise_overlap(sorting_analyzer)
assert np.array_equal(res1[metric_name].values, res2[metric_name].values)

0 comments on commit ad3e924

Please sign in to comment.