
Commit

Merge branch 'dev_spikeinterface_v101' into nei_nienborg
ttngu207 committed Jun 4, 2024
2 parents: 214708c + 1a1b18f · commit 5f69808
Showing 5 changed files with 197 additions and 214 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -3,6 +3,12 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [0.4.0] - 2024-05-28

+ Add - support for SpikeInterface version >= 0.101.0 (updated API)
+ Add - memoization of spike sorting results (prevents duplicate runs)


## [0.3.4] - 2024-03-22

+ Add - pytest
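The "SpikeInterface version >= 0.101.0 (updated API)" entry above refers to the SortingAnalyzer-based workflow that replaces the older WaveformExtractor objects: the diff below switches from `si.load_waveforms()` on a `waveform/` folder to `si.load_sorting_analyzer()` on a `sorting_analyzer/` folder. A minimal sketch of producing and re-loading such a folder is included here for orientation; the toy recording and the folder name are illustrative assumptions, not part of this commit.

```python
# Minimal sketch (SpikeInterface >= 0.101); the toy recording and folder name are illustrative.
import spikeinterface as si

# Toy recording/sorting pair standing in for a real spike-sorting run.
recording, sorting = si.generate_ground_truth_recording(durations=[60.0], num_units=5)

# Create the on-disk analyzer folder that the updated element expects to find.
sorting_analyzer = si.create_sorting_analyzer(
    sorting=sorting,
    recording=recording,
    folder="sorting_analyzer",
    format="binary_folder",
)

# Re-load it the same way the updated make() methods below do.
sorting_analyzer = si.load_sorting_analyzer(folder="sorting_analyzer")
si_sorting = sorting_analyzer.sorting
```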
253 changes: 120 additions & 133 deletions element_array_ephys/ephys_no_curation.py
@@ -7,13 +7,12 @@
import datajoint as dj
import numpy as np
import pandas as pd

from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory

from . import ephys_report, probe
from .readers import kilosort, openephys, spikeglx

log = dj.logger
logger = dj.logger

schema = dj.schema()

@@ -822,7 +821,7 @@ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False):

if mkdir:
output_dir.mkdir(parents=True, exist_ok=True)
log.info(f"{output_dir} created!")
logger.info(f"{output_dir} created!")

return output_dir.relative_to(processed_dir) if relative else output_dir

@@ -1028,108 +1027,81 @@ def make(self, key):

# Get channel and electrode-site mapping
electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
channel2electrode_map = electrode_query.fetch(as_dict=True)
channel2electrode_map: dict[int, dict] = {
chn.pop("channel_idx"): chn for chn in channel2electrode_map
chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
}

# Get sorter method and create output directory.
sorter_name = clustering_method.replace(".", "_")
si_waveform_dir = output_dir / sorter_name / "waveform"
si_sorting_dir = output_dir / sorter_name / "spike_sorting"
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"

if si_waveform_dir.exists(): # Read from spikeinterface outputs
if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs
import spikeinterface as si
from spikeinterface import sorters

we: si.WaveformExtractor = si.load_waveforms(
si_waveform_dir, with_recording=False
)
si_sorting: si.sorters.BaseSorter = si.load_extractor(
si_sorting_dir / "si_sorting.pkl", base_folder=output_dir
)
sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
si_sorting = sorting_analyzer.sorting

unit_peak_channel: dict[int, int] = si.get_template_extremum_channel(
we, outputs="index"
) # {unit: peak_channel_id}
# Find representative channel for each unit
unit_peak_channel: dict[int, np.ndarray] = (
si.ChannelSparsity.from_best_channels(
sorting_analyzer,
1,
).unit_id_to_channel_indices
)
unit_peak_channel: dict[int, int] = {
u: chn[0] for u, chn in unit_peak_channel.items()
}

spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
# {unit: spike_count}

spikes = si_sorting.to_spike_vector()

# reorder channel2electrode_map according to recording channel ids
# update channel2electrode_map to match with probe's channel index
channel2electrode_map = {
chn_idx: channel2electrode_map[chn_idx]
for chn_idx in we.channel_ids_to_indices(we.channel_ids)
idx: channel2electrode_map[int(chn_idx)]
for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
}

# Get unit id to quality label mapping
try:
cluster_quality_label_map = pd.read_csv(
si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
delimiter="\t",
cluster_quality_label_map = {
int(unit_id): (
si_sorting.get_unit_property(unit_id, "KSLabel")
if "KSLabel" in si_sorting.get_property_keys()
else "n.a."
)
except FileNotFoundError:
cluster_quality_label_map = {}
else:
cluster_quality_label_map: dict[
int, str
] = cluster_quality_label_map.set_index("cluster_id")[
"KSLabel"
].to_dict() # {unit: quality_label}

# Get electrode where peak unit activity is recorded
peak_electrode_ind = np.array(
[
channel2electrode_map[unit_peak_channel[unit_id]]["electrode"]
for unit_id in si_sorting.unit_ids
]
)

# Get channel depth
channel_depth_ind = np.array(
[
we.get_probe().contact_positions[unit_peak_channel[unit_id]][1]
for unit_id in si_sorting.unit_ids
]
)

# Assign electrode and depth for each spike
new_spikes = np.empty(
spikes.shape,
spikes.dtype.descr + [("electrode", "<i8"), ("depth", "<i8")],
)

for field in spikes.dtype.names:
new_spikes[field] = spikes[field]
del spikes
for unit_id in si_sorting.unit_ids
}

new_spikes["electrode"] = peak_electrode_ind[new_spikes["unit_index"]]
new_spikes["depth"] = channel_depth_ind[new_spikes["unit_index"]]
spike_locations = sorting_analyzer.get_extension("spike_locations")
spikes_df = pd.DataFrame(spike_locations.spikes)

units = []

for unit_id in si_sorting.unit_ids:
for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
unit_id = int(unit_id)
unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
spike_sites = np.array(
[
channel2electrode_map[chn_idx]["electrode"]
for chn_idx in unit_spikes_df.channel_index
]
)
unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
_, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates
spike_times = si_sorting.get_unit_spike_train(
unit_id, return_times=True
)

assert len(spike_times) == len(spike_sites) == len(spike_depths)

units.append(
{
**key,
**channel2electrode_map[unit_peak_channel[unit_id]],
"unit": unit_id,
"cluster_quality_label": cluster_quality_label_map.get(
unit_id, "n.a."
),
"spike_times": si_sorting.get_unit_spike_train(
unit_id, return_times=True
),
"cluster_quality_label": cluster_quality_label_map[unit_id],
"spike_times": spike_times,
"spike_count": spike_count_dict[unit_id],
"spike_sites": new_spikes["electrode"][
new_spikes["unit_index"] == unit_id
],
"spike_depths": new_spikes["depth"][
new_spikes["unit_index"] == unit_id
],
"spike_sites": spike_sites,
"spike_depths": spike_depths,
}
)
else: # read from kilosort outputs
@@ -1268,43 +1240,45 @@ def make(self, key):

# Get channel and electrode-site mapping
electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
channel2electrode_map = electrode_query.fetch(as_dict=True)
channel2electrode_map: dict[int, dict] = {
chn.pop("channel_idx"): chn for chn in channel2electrode_map
chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
}

si_waveform_dir = output_dir / sorter_name / "waveform"
if si_waveform_dir.exists(): # read from spikeinterface outputs
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
import spikeinterface as si
we: si.WaveformExtractor = si.load_waveforms(
si_waveform_dir, with_recording=False
)
unit_id_to_peak_channel_map: dict[
int, np.ndarray
] = si.ChannelSparsity.from_best_channels(
we, 1, peak_sign="neg"
).unit_id_to_channel_indices # {unit: peak_channel_index}

# reorder channel2electrode_map according to recording channel ids
channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist()

sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)

# Find representative channel for each unit
unit_peak_channel: dict[int, np.ndarray] = (
si.ChannelSparsity.from_best_channels(
sorting_analyzer, 1
).unit_id_to_channel_indices
) # {unit: peak_channel_index}
unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}

# update channel2electrode_map to match with probe's channel index
channel2electrode_map = {
chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
idx: channel2electrode_map[int(chn_idx)]
for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
}

templates = sorting_analyzer.get_extension("templates")

def yield_unit_waveforms():
for unit in (CuratedClustering.Unit & key).fetch(
"KEY", order_by="unit"
):
# Get mean waveform for this unit from all channels - (sample x channel)
unit_waveforms = we.get_template(
unit_id=unit["unit"], mode="average", force_dense=True
)
peak_chn_idx = channel_indices.index(
unit_id_to_peak_channel_map[unit["unit"]][0]
unit_waveforms = templates.get_unit_template(
unit_id=unit["unit"], operator="average"
)
unit_peak_waveform = {
**unit,
"peak_electrode_waveform": unit_waveforms[:, peak_chn_idx],
"peak_electrode_waveform": unit_waveforms[
:, unit_peak_channel[unit["unit"]]
],
}

unit_electrode_waveforms = [
@@ -1313,12 +1287,12 @@ def yield_unit_waveforms():
**channel2electrode_map[chn_idx],
"waveform_mean": unit_waveforms[:, chn_idx],
}
for chn_idx in channel_indices
for chn_idx in channel2electrode_map
]

yield unit_peak_waveform, unit_electrode_waveforms

else: # read from kilosort outputs
else: # read from kilosort outputs (ecephys pipeline)
kilosort_dataset = kilosort.Kilosort(output_dir)

acq_software, probe_serial_number = (
@@ -1524,43 +1498,56 @@ def make(self, key):
output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
sorter_name = clustering_method.replace(".", "_")

# find metric_fp
for metric_fp in [
output_dir / "metrics.csv",
output_dir / sorter_name / "metrics" / "metrics.csv",
]:
if metric_fp.exists():
break
else:
raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
import spikeinterface as si

metrics_df = pd.read_csv(metric_fp)
sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
template_metrics = sorting_analyzer.get_extension(
"template_metrics"
).get_data()
metrics_df = pd.concat([qc_metrics, template_metrics], axis=1)

# Conform the dataframe to match the table definition
if "cluster_id" in metrics_df.columns:
metrics_df.set_index("cluster_id", inplace=True)
else:
metrics_df.rename(
columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
columns={
"amplitude_median": "amplitude",
"isi_violations_ratio": "isi_violation",
"isi_violations_count": "number_violation",
"silhouette": "silhouette_score",
"rp_contamination": "contamination_rate",
"drift_ptp": "max_drift",
"drift_mad": "cumulative_drift",
"half_width": "halfwidth",
"peak_trough_ratio": "pt_ratio",
"peak_to_valley": "duration",
},
inplace=True,
)
metrics_df.set_index("cluster_id", inplace=True)
metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
metrics_df.columns = metrics_df.columns.str.lower()

metrics_df.rename(
columns={
"isi_violations_ratio": "isi_violation",
"isi_violations_count": "number_violation",
"silhouette": "silhouette_score",
"rp_contamination": "contamination_rate",
"drift_ptp": "max_drift",
"drift_mad": "cumulative_drift",
"half_width": "halfwidth",
"peak_trough_ratio": "pt_ratio",
},
inplace=True,
)
else: # read from kilosort outputs (ecephys pipeline)
# find metric_fp
for metric_fp in [
output_dir / "metrics.csv",
]:
if metric_fp.exists():
break
else:
raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")

metrics_df = pd.read_csv(metric_fp)

# Conform the dataframe to match the table definition
if "cluster_id" in metrics_df.columns:
metrics_df.set_index("cluster_id", inplace=True)
else:
metrics_df.rename(
columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
)
metrics_df.set_index("cluster_id", inplace=True)

metrics_df.columns = metrics_df.columns.str.lower()

metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
metrics_list = [
dict(metrics_df.loc[unit_key["unit"]], **unit_key)
for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
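When a `sorting_analyzer` folder is present, the updated make() methods above no longer parse Kilosort output files directly; they read SortingAnalyzer extensions instead (templates, spike_locations, quality_metrics, template_metrics). A rough sketch of computing those extensions on an analyzer like the one created earlier follows; the dependency order and default parameters are assumptions and may differ across SpikeInterface releases.

```python
# Rough sketch (SpikeInterface >= 0.101); extension parameters are left at their defaults.
import spikeinterface as si

sorting_analyzer = si.load_sorting_analyzer(folder="sorting_analyzer")

# Extensions the element reads, plus their prerequisites.
sorting_analyzer.compute("random_spikes")     # spike subsample used for waveform extraction
sorting_analyzer.compute("waveforms")
sorting_analyzer.compute("templates")         # read via get_extension("templates") for peak waveforms
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("spike_amplitudes")
sorting_analyzer.compute("spike_locations")   # read via get_extension("spike_locations") for spike depths
sorting_analyzer.compute("quality_metrics")   # read via get_extension("quality_metrics").get_data()
sorting_analyzer.compute("template_metrics")  # read via get_extension("template_metrics").get_data()

# Metrics come back as pandas DataFrames indexed by unit, as QualityMetrics.make() above assumes.
qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
```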
