tdc update and clean
samuelgarcia committed May 24, 2024
2 parents 210d913 + 0df2536 commit 3ab27d4
Showing 38 changed files with 334 additions and 163 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -95,7 +95,7 @@ full = [
     "scikit-learn",
     "networkx",
     "distinctipy",
-    "matplotlib<3.9", # See https://github.com/SpikeInterface/spikeinterface/issues/2863
+    "matplotlib>=3.6", # matplotlib.colormaps
     "cuda-python; platform_system != 'Darwin'",
     "numba",
 ]
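The pin flip above trades an upper bound (added for issue 2863, around the removal of the old colormap getters in matplotlib 3.9) for a floor: the code now relies on the matplotlib.colormaps registry, available from matplotlib 3.6 onward. A minimal sketch of the two access patterns, assuming matplotlib >= 3.6 is installed:

import matplotlib.pyplot as plt

# Old pattern, deprecated in matplotlib 3.7 and removed in 3.9:
# cmap = plt.cm.get_cmap("jet", 10)

# Registry-based replacement used throughout the hunks below:
cmap = plt.colormaps["jet"].resampled(10)  # discretized 10-entry colormap
print(cmap(0))  # RGBA tuple for the first entry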
6 changes: 3 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -488,9 +488,9 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
             self.params["operators"] += [(operator, percentile)]
             templates_array = self.data[key]

-        if save:
-            if not self.sorting_analyzer.is_read_only():
-                self.save()
+            if save:
+                if not self.sorting_analyzer.is_read_only():
+                    self.save()

         if unit_ids is not None:
             unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/core/job_tools.py
@@ -47,7 +47,6 @@
     "chunk_duration",
     "progress_bar",
     "mp_context",
-    "verbose",
    "max_threads_per_process",
 )

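With "verbose" dropped from the job_keys tuple, it is no longer a recognized job keyword. A short sketch of the intended split, assuming fix_job_kwargs validates its input against job_keys:

from spikeinterface.core.job_tools import fix_job_kwargs

# Scheduling options only; a verbose entry here would now be rejected.
job_kwargs = fix_job_kwargs(dict(n_jobs=2, chunk_duration="1s", progress_bar=True))

# Verbosity travels as an explicit function argument instead, e.g.:
# write_binary_recording(recording, file_paths=..., verbose=False, **job_kwargs)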
5 changes: 4 additions & 1 deletion src/spikeinterface/core/recording_tools.py
@@ -73,6 +73,7 @@ def write_binary_recording(
     add_file_extension: bool = True,
     byte_offset: int = 0,
     auto_cast_uint: bool = True,
+    verbose: bool = True,
     **job_kwargs,
 ):
     """
@@ -98,6 +99,8 @@ def write_binary_recording(
     auto_cast_uint: bool, default: True
         If True, unsigned integers are automatically cast to int if the specified dtype is signed
         .. deprecated:: 0.103, use the `unsigned_to_signed` function instead.
+    verbose: bool
+        If True, output is verbose
     {}
     """
     job_kwargs = fix_job_kwargs(job_kwargs)
@@ -138,7 +141,7 @@
     init_func = _init_binary_worker
     init_args = (recording, file_path_dict, dtype, byte_offset, cast_unsigned)
     executor = ChunkRecordingExecutor(
-        recording, func, init_func, init_args, job_name="write_binary_recording", **job_kwargs
+        recording, func, init_func, init_args, job_name="write_binary_recording", verbose=verbose, **job_kwargs
     )
     executor.run()

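Net effect: write_binary_recording now owns its verbose flag and forwards it to ChunkRecordingExecutor instead of fishing it out of **job_kwargs. A hedged usage sketch with a toy recording (the output path is a placeholder):

from spikeinterface.core import generate_recording
from spikeinterface.core.recording_tools import write_binary_recording

recording = generate_recording(durations=[2.0], num_channels=4)
write_binary_recording(
    recording,
    file_paths=["traces.raw"],  # placeholder path
    dtype="int16",
    verbose=False,              # explicit argument, not a job keyword
    n_jobs=2,                   # still routed through **job_kwargs
    chunk_duration="1s",
)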
14 changes: 6 additions & 8 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -585,9 +585,10 @@ def load_from_zarr(cls, folder, recording=None):
             rec_attributes["probegroup"] = None

         # sparsity
-        if "sparsity_mask" in zarr_root.attrs:
-            # sparsity = zarr_root.attrs["sparsity"]
-            sparsity = ChannelSparsity(zarr_root["sparsity_mask"], cls.unit_ids, rec_attributes["channel_ids"])
+        if "sparsity_mask" in zarr_root:
+            sparsity = ChannelSparsity(
+                np.array(zarr_root["sparsity_mask"]), sorting.unit_ids, rec_attributes["channel_ids"]
+            )
         else:
             sparsity = None

@@ -1596,10 +1597,6 @@ def load_data(self):
                 self.data[ext_data_name] = ext_data

         elif self.format == "zarr":
-            # Alessio
-            # TODO: we need decide if we make a copy to memory or keep the lazy loading. For binary_folder it used to be lazy with memmap
-            # but this make the garbage complicated when a data is hold by a plot but the o SortingAnalyzer is delete
-            # lets talk
             extension_group = self._get_zarr_extension_group(mode="r")
             for ext_data_name in extension_group.keys():
                 ext_data_ = extension_group[ext_data_name]
@@ -1615,7 +1612,8 @@ def load_data(self):
                 elif "object" in ext_data_.attrs:
                     ext_data = ext_data_[0]
                 else:
-                    ext_data = ext_data_
+                    # this loads the data into memory
+                    ext_data = np.array(ext_data_)
                 self.data[ext_data_name] = ext_data

     def copy(self, new_sorting_analyzer, unit_ids=None):
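The np.array(ext_data_) call settles the removed TODO in favor of eager loading: extension data read from zarr is materialized as a plain ndarray, so it no longer keeps a lazy handle to the store (the garbage-collection headache the old comment describes). A small illustration, with hypothetical store and dataset paths:

import numpy as np
import zarr

root = zarr.open("analyzer.zarr", mode="r")   # hypothetical store
lazy = root["extensions/dummy/result_two"]    # zarr.Array, chunks read on demand
eager = np.array(lazy)                        # full in-memory copy

# eager survives independently of the store; lazy would keep it alive.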
8 changes: 8 additions & 0 deletions src/spikeinterface/core/sparsity.py
@@ -118,6 +118,14 @@ def __repr__(self):
         txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}"
         return txt

+    def __eq__(self, other):
+        return (
+            isinstance(other, ChannelSparsity)
+            and np.array_equal(self.channel_ids, other.channel_ids)
+            and np.array_equal(self.unit_ids, other.unit_ids)
+            and np.array_equal(self.mask, other.mask)
+        )
+
     @property
     def unit_id_to_channel_ids(self):
         if self._unit_id_to_channel_ids is None:
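The new __eq__ gives ChannelSparsity value semantics: two objects compare equal when their unit ids, channel ids, and boolean masks all match, which the analyzer round-trip assertion added in test_sortinganalyzer.py below relies on. A quick sketch with toy ids and a toy mask:

import numpy as np
from spikeinterface.core.sparsity import ChannelSparsity

mask = np.array([[True, False, True], [False, True, True]])
unit_ids = np.array(["u0", "u1"])
channel_ids = np.array(["c0", "c1", "c2"])

s1 = ChannelSparsity(mask, unit_ids, channel_ids)
s2 = ChannelSparsity(mask.copy(), unit_ids, channel_ids)
assert s1 == s2                  # same ids, same mask
assert (s1 == "other") is False  # isinstance guard handles foreign types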
18 changes: 10 additions & 8 deletions src/spikeinterface/core/tests/test_recording_tools.py
@@ -37,8 +37,8 @@ def test_write_binary_recording(tmp_path):
     file_paths = [tmp_path / "binary01.raw"]

     # Write binary recording
-    job_kwargs = dict(verbose=False, n_jobs=1)
-    write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs)
+    job_kwargs = dict(n_jobs=1)
+    write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs)

     # Check if written data matches original data
     recorder_binary = BinaryRecordingExtractor(
@@ -64,9 +64,11 @@ def test_write_binary_recording_offset(tmp_path):
     file_paths = [tmp_path / "binary01.raw"]

     # Write binary recording
-    job_kwargs = dict(verbose=False, n_jobs=1)
+    job_kwargs = dict(n_jobs=1)
     byte_offset = 125
-    write_binary_recording(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, **job_kwargs)
+    write_binary_recording(
+        recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs
+    )

     # Check if written data matches original data
     recorder_binary = BinaryRecordingExtractor(
@@ -97,8 +99,8 @@ def test_write_binary_recording_parallel(tmp_path):
     file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"]

     # Write binary recording
-    job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn")
-    write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs)
+    job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn")
+    write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs)

     # Check if written data matches original data
     recorder_binary = BinaryRecordingExtractor(
@@ -127,8 +129,8 @@ def test_write_binary_recording_multiple_segment(tmp_path):
     file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"]

     # Write binary recording
-    job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn")
-    write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs)
+    job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn")
+    write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs)

     # Check if written data matches original data
     recorder_binary = BinaryRecordingExtractor(
5 changes: 5 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -155,10 +155,15 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

     data = sorting_analyzer2.get_extension("dummy").data
     assert "result_one" in data
+    assert isinstance(data["result_one"], str)
+    assert isinstance(data["result_two"], np.ndarray)
+    assert data["result_two"].size == original_sorting.to_spike_vector().size
     assert np.array_equal(data["result_two"], sorting_analyzer.get_extension("dummy").data["result_two"])

     assert sorting_analyzer2.return_scaled == sorting_analyzer.return_scaled

+    assert sorting_analyzer2.sparsity == sorting_analyzer.sparsity
+
     # select unit_ids to several format
     for format in ("memory", "binary_folder", "zarr"):
         if format != "memory":
19 changes: 18 additions & 1 deletion src/spikeinterface/sorters/internal/simplesorter.py
@@ -46,6 +46,23 @@ class SimpleSorter(ComponentsBasedSorter):
         "job_kwargs": {"n_jobs": -1, "chunk_duration": "1s"},
     }

+    _params_description = {
+        "apply_preprocessing": "Whether to apply the preprocessing steps, default: False",
+        "waveforms": "A dictionary containing waveform params: 'ms_before' (peak of spike) default: 1.0, 'ms_after' (peak of spike) default: 1.5",
+        "filtering": "A dictionary containing bandpass filter conditions, 'freq_min' default: 300 and 'freq_max' default: 8000.0",
+        "detection": (
+            "A dictionary for specifying the detection conditions of 'peak_sign' (pos or neg) default: 'neg', "
+            "'detect_threshold' (snr) default: 5.0, 'exclude_sweep_ms' default: 1.5, 'radius_um' default: 150.0"
+        ),
+        "features": "A dictionary for the PCA specifying the 'n_components', default: 3",
+        "clustering": (
+            "A dictionary for specifying the clustering parameters: 'method' (to cluster) default: 'hdbscan', "
+            "'min_cluster_size' (min number of spikes per cluster) default: 25, 'allow_single_cluster' default: True, "
+            "'core_dist_n_jobs' (parallelization) default: -1, 'cluster_selection_method' (for hdbscan) default: 'leaf'"
+        ),
+        "job_kwargs": "SpikeInterface job_kwargs (see job_kwargs documentation), default: 'n_jobs': -1, 'chunk_duration': '1s'",
+    }
+
     @classmethod
     def get_sorter_version(cls):
         return "1.0"
@@ -54,7 +71,7 @@ def get_sorter_version(cls):
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
         job_kwargs = params["job_kwargs"]
         job_kwargs = fix_job_kwargs(job_kwargs)
-        job_kwargs.update({"verbose": verbose, "progress_bar": verbose})
+        job_kwargs.update({"progress_bar": verbose})

         from spikeinterface.sortingcomponents.peak_detection import detect_peaks
         from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
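The new _params_description mirrors _default_params key for key and feeds SpikeInterface's parameter-introspection helpers. Assuming the usual get_sorter_params_description accessor and the "simple" registry name, the entries would surface like this:

from spikeinterface.sorters import get_sorter_params_description

descriptions = get_sorter_params_description("simple")
print(descriptions["detection"])  # the detection conditions summary shown above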
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -113,7 +113,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

         job_kwargs = params["job_kwargs"]
         job_kwargs = fix_job_kwargs(job_kwargs)
-        job_kwargs.update({"verbose": verbose, "progress_bar": verbose})
+        job_kwargs.update({"progress_bar": verbose})

         recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

31 changes: 7 additions & 24 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -7,7 +7,6 @@
 from spikeinterface.core import (
     get_noise_levels,
     NumpySorting,
-    get_channel_distances,
     estimate_templates_with_accumulator,
     Templates,
     compute_sparsity,
@@ -18,14 +17,11 @@
 from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
 from spikeinterface.core.basesorting import minimum_spike_dtype

-from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing
+from spikeinterface.sortingcomponents.tools import cache_preprocessing

 # from spikeinterface.qualitymetrics import compute_snrs

 import numpy as np

-import pickle
-import json
-

 class Tridesclous2Sorter(ComponentsBasedSorter):
@@ -87,31 +83,20 @@ def get_sorter_version(cls):

     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
-        job_kwargs = params["job_kwargs"].copy()
-        job_kwargs = fix_job_kwargs(job_kwargs)
-        job_kwargs["progress_bar"] = verbose
-
         from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
-        from spikeinterface.core.node_pipeline import (
-            run_node_pipeline,
-            ExtractDenseWaveforms,
-            ExtractSparseWaveforms,
-            PeakRetriever,
-        )
-        from spikeinterface.sortingcomponents.peak_detection import detect_peaks, DetectPeakLocallyExclusive
+        from spikeinterface.sortingcomponents.peak_detection import detect_peaks
         from spikeinterface.sortingcomponents.peak_selection import select_peaks
-        from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution
         from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection

         from spikeinterface.sortingcomponents.clustering.split import split_clusters
         from spikeinterface.sortingcomponents.clustering.merge import merge_clusters
-        from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse
-        from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks
         from spikeinterface.sortingcomponents.tools import remove_empty_templates

-        from spikeinterface.preprocessing import correct_motion
-        from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
-
+        job_kwargs = params["job_kwargs"].copy()
+        job_kwargs = fix_job_kwargs(job_kwargs)
+        job_kwargs["progress_bar"] = verbose
+
+
         recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

         num_chans = recording_raw.get_num_channels()
@@ -135,8 +120,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             recording = common_reference(recording)

         if params["apply_motion_correction"]:
-            # interpolate_motion_kwargs = motion_info["parameters"]["interpolate_motion_kwargs"]
-
             interpolate_motion_kwargs = dict(
                 direction=1, border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
             )
@@ -732,7 +732,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5))

         # n = self.motion.shape[1]
         # step = int(np.ceil(max(1, n / show_only)))
-        # colors = plt.cm.get_cmap("jet", n)
+        # colors = plt.colormaps["jet"].resampled(n)
         # for i in range(0, n, step):
         #     ax = axs[0]
         #     ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i))
@@ -382,7 +382,7 @@ def create_benchmark(self, key):

         # import matplotlib

-        # my_cmap = plt.get_cmap(cmap)
+        # my_cmap = plt.colormaps[cmap]
         # cNorm = matplotlib.colors.Normalize(vmin=clim[0], vmax=clim[1])
         # scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -63,6 +63,7 @@ class CircusClustering:
         "noise_levels": None,
         "tmp_folder": None,
         "job_kwargs": {},
+        "verbose": True,
     }

     @classmethod
@@ -72,7 +73,7 @@ def main_function(cls, recording, peaks, params):
         job_kwargs = fix_job_kwargs(params["job_kwargs"])

         d = params
-        verbose = job_kwargs.get("verbose", True)
+        verbose = d["verbose"]

         fs = recording.get_sampling_frequency()
         ms_before = params["ms_before"]
@@ -250,7 +251,6 @@ def main_function(cls, recording, peaks, params):
             cleaning_matching_params.pop(value)
         cleaning_matching_params["chunk_duration"] = "100ms"
         cleaning_matching_params["n_jobs"] = 1
-        cleaning_matching_params["verbose"] = False
         cleaning_matching_params["progress_bar"] = False

         cleaning_params = params["cleaning_kwargs"].copy()
@@ -62,7 +62,7 @@ def _split_waveforms(
         local_feature_plot = local_feature

         unique_lab = np.unique(local_labels_with_noise)
-        cmap = plt.get_cmap("jet", unique_lab.size)
+        cmap = plt.colormaps["jet"].resampled(unique_lab.size)
         cmap = {k: cmap(l) for l, k in enumerate(unique_lab)}
         cmap[-1] = "k"
         active_ind = np.arange(local_feature.shape[0])
@@ -145,7 +145,7 @@ def _split_waveforms_nested(
         local_feature_plot = reducer.fit_transform(local_feature)

         unique_lab = np.unique(active_labels_with_noise)
-        cmap = plt.get_cmap("jet", unique_lab.size)
+        cmap = plt.colormaps["jet"].resampled(unique_lab.size)
         cmap = {k: cmap(l) for l, k in enumerate(unique_lab)}
         cmap[-1] = "k"
         cmap[-2] = "b"
@@ -276,7 +276,7 @@ def auto_split_clustering(

         fig, ax = plt.subplots()
         plot_labels_set = np.unique(local_labels_with_noise)
-        cmap = plt.get_cmap("jet", plot_labels_set.size)
+        cmap = plt.colormaps["jet"].resampled(plot_labels_set.size)
         cmap = {k: cmap(l) for l, k in enumerate(plot_labels_set)}
         cmap[-1] = "k"
         cmap[-2] = "b"
@@ -42,7 +42,7 @@ class PositionAndFeaturesClustering:
         "ms_before": 1.5,
         "ms_after": 1.5,
         "cleaning_method": "dip",
-        "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "verbose": True, "progress_bar": True},
+        "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True},
     }

     @classmethod
@@ -38,7 +38,7 @@ class PositionAndPCAClustering:
         "ms_after": 2.5,
         "n_components_by_channel": 3,
         "n_components": 5,
-        "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "verbose": True, "progress_bar": True},
+        "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True},
         "hdbscan_global_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1},
         "hdbscan_local_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1},
         "waveform_mode": "shared_memory",
@@ -26,7 +26,7 @@ class PositionPTPScaledClustering:
         "ptps": None,
         "scales": (1, 1, 10),
         "peak_localization_kwargs": {"method": "center_of_mass"},
-        "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "verbose": True, "progress_bar": True},
+        "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True},
         "hdbscan_kwargs": {
             "min_cluster_size": 20,
             "min_samples": 20,