Svd convolutions #2045

Merged: 32 commits merged into SpikeInterface:main from yger:svd_convolutions on Sep 28, 2023

Commits (32) — changes shown from all commits
68fe2ba
OMP with SVD decomposition
yger Sep 26, 2023
cc47204
Increase default rank
yger Sep 26, 2023
10c33c1
To be tried
yger Sep 26, 2023
b2a9b70
WIP
yger Sep 26, 2023
3c94594
Working with circus2
yger Sep 26, 2023
46149ef
Put OMP with SVD as default
yger Sep 26, 2023
f21d80b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
a275bca
Patch
yger Sep 26, 2023
b5aa0da
Merge branch 'svd_convolutions' of github.com:yger/spikeinterface int…
yger Sep 26, 2023
85eb432
Cleaning useless functions
yger Sep 26, 2023
15ae432
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
41155a1
Changing the internal representation of overlaps
yger Sep 27, 2023
b5c3538
Merge branch 'svd_convolutions' of github.com:yger/spikeinterface int…
yger Sep 27, 2023
97aff7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
8da6b79
Keeping the two matching engines for more tests before merging and fi…
yger Sep 27, 2023
e4189a9
Merge branch 'svd_convolutions' of github.com:yger/spikeinterface int…
yger Sep 27, 2023
a6b4774
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
257c74c
Slight misalignment
yger Sep 27, 2023
ff282ad
Merge branch 'svd_convolutions' of github.com:yger/spikeinterface int…
yger Sep 27, 2023
0a2c0f6
Default SVD Peeler is now good to go
yger Sep 27, 2023
5fbc88d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
67f1306
Merge branch 'SpikeInterface:main' into svd_convolutions
yger Sep 27, 2023
9f45f2e
Enhance the clustering
yger Sep 27, 2023
73b065a
Merge branch 'svd_convolutions' of github.com:yger/spikeinterface int…
yger Sep 27, 2023
3cbf8f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
a0fabe1
Merge branch 'main' into svd_convolutions
yger Sep 28, 2023
daddd8c
Adding a lookup table
yger Sep 28, 2023
d7dcbe0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2023
d623da3
typos for cleaning via matching
yger Sep 28, 2023
b2bcb00
Merge branch 'svd_convolutions' of github.com:yger/spikeinterface int…
yger Sep 28, 2023
fdb8466
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2023
c55a3cc
Merge branch 'SpikeInterface:main' into svd_convolutions
yger Sep 28, 2023
src/spikeinterface/sorters/internal/spyking_circus2.py — 4 changes: 2 additions & 2 deletions

@@ -20,7 +20,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
     sorter_name = "spykingcircus2"
 
     _default_params = {
-        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
+        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
         "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
         "filtering": {"dtype": "float32"},
         "detection": {"peak_sign": "neg", "detect_threshold": 5},
@@ -151,7 +151,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         matching_job_params["chunk_duration"] = "100ms"
 
         spikes = find_spikes_from_templates(
-            recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params
+            recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params
        )
 
         if verbose:
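The user-visible changes here are the wider sparsity radius and the switch of the default matching engine. For orientation, a minimal, hedged sketch of calling the new engine directly — the method string and the method_kwargs keys mirror remove_duplicates_via_matching in the next file, while recording_f, sorting, and the job arguments are illustrative placeholders, not code from this PR:

    # Minimal sketch, assuming a preprocessed recording `recording_f` and a
    # sorting `sorting`; the kwargs values are illustrative only.
    from spikeinterface import get_noise_levels
    from spikeinterface.core import extract_waveforms
    from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

    we = extract_waveforms(recording_f, sorting, folder="waveforms", sparse=True)

    method_kwargs = {
        "waveform_extractor": we,
        "noise_levels": get_noise_levels(recording_f),
        "amplitudes": [0.95, 1.05],  # same bounds as in remove_duplicates_via_matching
        "omp_min_sps": 0.1,
    }

    # "circus-omp-svd" is the new default engine introduced by this PR
    spikes = find_spikes_from_templates(
        recording_f, method="circus-omp-svd", method_kwargs=method_kwargs,
        n_jobs=1, chunk_duration="100ms", progress_bar=True
    )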
src/spikeinterface/sortingcomponents/clustering/clustering_tools.py — 42 changes: 25 additions & 17 deletions

@@ -539,14 +539,14 @@ def remove_duplicates_via_matching(
     method_kwargs={},
     job_kwargs={},
     tmp_folder=None,
+    method="circus-omp-svd",
 ):
     from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
     from spikeinterface import get_noise_levels
     from spikeinterface.core import BinaryRecordingExtractor
     from spikeinterface.core import NumpySorting
     from spikeinterface.core import extract_waveforms
     from spikeinterface.core import get_global_tmp_folder
-    from spikeinterface.sortingcomponents.matching.circus import get_scipy_shape
     import string, random, shutil, os
     from pathlib import Path
 
@@ -591,19 +591,12 @@ def remove_duplicates_via_matching(
 
     chunk_size = duration + 3 * margin
 
-    dummy_filter = np.empty((num_chans, duration), dtype=np.float32)
-    dummy_traces = np.empty((num_chans, chunk_size), dtype=np.float32)
-
-    fshape, axes = get_scipy_shape(dummy_filter, dummy_traces, axes=1)
-
     method_kwargs.update(
         {
             "waveform_extractor": waveform_extractor,
             "noise_levels": noise_levels,
             "amplitudes": [0.95, 1.05],
             "omp_min_sps": 0.1,
-            "templates": None,
-            "overlaps": None,
         }
     )
 
@@ -618,16 +611,31 @@
 
         method_kwargs.update({"ignored_ids": ignore_ids + [i]})
         spikes, computed = find_spikes_from_templates(
-            sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
+            sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
         )
-        method_kwargs.update(
-            {
-                "overlaps": computed["overlaps"],
-                "templates": computed["templates"],
-                "norms": computed["norms"],
-                "sparsities": computed["sparsities"],
-            }
-        )
+        if method == "circus-omp-svd":
+            method_kwargs.update(
+                {
+                    "overlaps": computed["overlaps"],
+                    "templates": computed["templates"],
+                    "norms": computed["norms"],
+                    "temporal": computed["temporal"],
+                    "spatial": computed["spatial"],
+                    "singular": computed["singular"],
+                    "units_overlaps": computed["units_overlaps"],
+                    "unit_overlaps_indices": computed["unit_overlaps_indices"],
+                    "sparsity_mask": computed["sparsity_mask"],
+                }
+            )
+        elif method == "circus-omp":
+            method_kwargs.update(
+                {
+                    "overlaps": computed["overlaps"],
+                    "templates": computed["templates"],
+                    "norms": computed["norms"],
+                    "sparsities": computed["sparsities"],
+                }
+            )
         valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging)
         if np.sum(valid) > 0:
             if np.sum(valid) == 1:
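The extra keys cached for "circus-omp-svd" (temporal, spatial, singular, plus the overlap bookkeeping) point at the idea behind the PR's title: each template is stored as a truncated SVD so the matching convolutions run in a low-rank basis. A minimal sketch of that factorization with toy shapes and an assumed rank of 5 — the real engine's buffer layout may differ:

    import numpy as np

    # Toy shapes; `rank` is an assumed truncation, not the PR's default
    num_templates, num_samples, num_chans, rank = 20, 120, 32, 5
    templates = np.random.randn(num_templates, num_samples, num_chans).astype(np.float32)

    # Batched SVD of the (num_samples, num_chans) matrix of each template
    temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False)
    temporal = temporal[:, :, :rank]  # (num_templates, num_samples, rank)
    singular = singular[:, :rank]     # (num_templates, rank)
    spatial = spatial[:, :rank, :]    # (num_templates, rank, num_chans)

    # Rank-r reconstruction: convolving with a template then costs `rank`
    # one-dimensional convolutions instead of one per channel. Random data
    # is used here only to check shapes; real templates concentrate most of
    # their energy in the first few components.
    approx = np.einsum("tsr,tr,trc->tsc", temporal, singular, spatial)
    print("relative error:", np.linalg.norm(templates - approx) / np.linalg.norm(templates))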
src/spikeinterface/sortingcomponents/clustering/random_projections.py — 114 changes: 63 additions & 51 deletions

@@ -18,7 +18,9 @@
 from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip
 from spikeinterface.core import NumpySorting
 from spikeinterface.core import extract_waveforms
-from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks, EnergyFeature
+from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
+from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
+from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
 
 
 class RandomProjectionClustering:
@@ -34,17 +36,17 @@ class RandomProjectionClustering:
             "cluster_selection_method": "leaf",
         },
         "cleaning_kwargs": {},
+        "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100},
         "radius_um": 100,
-        "max_spikes_per_unit": 200,
         "selection_method": "closest_to_centroid",
-        "nb_projections": {"ptp": 8, "energy": 2},
-        "ms_before": 1.5,
-        "ms_after": 1.5,
+        "nb_projections": 10,
+        "ms_before": 1,
+        "ms_after": 1,
         "random_seed": 42,
-        "shared_memory": False,
-        "min_values": {"ptp": 0, "energy": 0},
+        "smoothing_kwargs": {"window_length_ms": 1},
+        "shared_memory": True,
         "tmp_folder": None,
-        "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "10M", "verbose": True, "progress_bar": True},
+        "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True},
     }
 
     @classmethod
@@ -74,50 +76,60 @@ def main_function(cls, recording, peaks, params):
 
         np.random.seed(d["random_seed"])
 
-        features_params = {}
-        features_list = []
-
-        noise_snippets = None
-
-        for proj_type in ["ptp", "energy"]:
-            if d["nb_projections"][proj_type] > 0:
-                features_list += [f"random_projections_{proj_type}"]
-
-                if d["min_values"][proj_type] == "auto":
-                    if noise_snippets is None:
-                        num_segments = recording.get_num_segments()
-                        num_chunks = 3 * d["max_spikes_per_unit"] // num_segments
-                        noise_snippets = get_random_data_chunks(
-                            recording, num_chunks_per_segment=num_chunks, chunk_size=num_samples, seed=42
-                        )
-                        noise_snippets = noise_snippets.reshape(num_chunks, num_samples, num_chans)
-
-                    if proj_type == "energy":
-                        data = np.linalg.norm(noise_snippets, axis=1)
-                        min_values = np.median(data, axis=0)
-                    elif proj_type == "ptp":
-                        data = np.ptp(noise_snippets, axis=1)
-                        min_values = np.median(data, axis=0)
-                elif d["min_values"][proj_type] > 0:
-                    min_values = d["min_values"][proj_type]
-                else:
-                    min_values = None
-
-                projections = np.random.randn(num_chans, d["nb_projections"][proj_type])
-                features_params[f"random_projections_{proj_type}"] = {
-                    "radius_um": params["radius_um"],
-                    "projections": projections,
-                    "min_values": min_values,
-                }
-
-        features_data = compute_features_from_peaks(
-            recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"]
-        )
-
-        if len(features_data) > 1:
-            hdbscan_data = np.hstack((features_data[0], features_data[1]))
-        else:
-            hdbscan_data = features_data[0]
+        if params["tmp_folder"] is None:
+            name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
+            tmp_folder = get_global_tmp_folder() / name
+        else:
+            tmp_folder = Path(params["tmp_folder"]).absolute()
+
+        ### Then we extract the SVD features
+        node0 = PeakRetriever(recording, peaks)
+        node1 = ExtractDenseWaveforms(
+            recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
+        )
+
+        node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])
+
+        projections = np.random.randn(num_chans, d["nb_projections"])
+        projections -= projections.mean(0)
+        projections /= projections.std(0)
+
+        nbefore = int(params["ms_before"] * fs / 1000)
+        nafter = int(params["ms_after"] * fs / 1000)
+        nsamples = nbefore + nafter
+
+        import scipy
+
+        x = np.random.randn(100, nsamples, num_chans).astype(np.float32)
+        x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1)
+
+        ptps = np.ptp(x, axis=1)
+        a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000))
+        ydata = np.cumsum(a) / a.sum()
+        xdata = b[1:]
+
+        from scipy.optimize import curve_fit
+
+        def sigmoid(x, L, x0, k, b):
+            y = L / (1 + np.exp(-k * (x - x0))) + b
+            return y
+
+        p0 = [max(ydata), np.median(xdata), 1, min(ydata)]  # a mandatory initial guess
+        popt, pcov = curve_fit(sigmoid, xdata, ydata, p0)
+
+        node3 = RandomProjectionsFeature(
+            recording,
+            parents=[node0, node2],
+            return_output=True,
+            projections=projections,
+            radius_um=params["radius_um"],
+        )
+
+        pipeline_nodes = [node0, node1, node2, node3]
+
+        hdbscan_data = run_node_pipeline(
+            recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features"
+        )
 
         import sklearn
 
@@ -132,7 +144,7 @@ def main_function(cls, recording, peaks, params):
 
         all_indices = np.arange(0, peak_labels.size)
 
-        max_spikes = params["max_spikes_per_unit"]
+        max_spikes = params["waveforms"]["max_spikes_per_unit"]
         selection_method = params["selection_method"]
 
         for unit_ind in labels:
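A note on the curve_fit block above: it fits a sigmoid to the cumulative histogram of peak-to-peak amplitudes measured on Savitzky-Golay-filtered noise, yet in this revision the fitted popt is not passed on to RandomProjectionsFeature (its new sigmoid argument, introduced in the next file, defaults to None). A hedged sketch of the weighting such a fit would apply, with made-up parameters:

    import numpy as np

    # Hypothetical fitted parameters (L, x0, k, b); real values come from curve_fit
    popt = (1.0, 4.0, 2.0, 0.0)

    def sigmoid(x, L, x0, k, b):
        return L / (1 + np.exp(-k * (x - x0))) + b

    wf_ptp = np.array([1.0, 4.0, 12.0])     # toy per-channel peak-to-peak amplitudes
    print(wf_ptp * sigmoid(wf_ptp, *popt))  # noise-level amplitudes are damped toward 0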
src/spikeinterface/sortingcomponents/features_from_peaks.py — 27 changes: 15 additions & 12 deletions

@@ -184,41 +184,44 @@ def __init__(
         return_output=True,
         parents=None,
         projections=None,
-        radius_um=150.0,
-        min_values=None,
+        sigmoid=None,
+        radius_um=None,
     ):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
 
         self.projections = projections
-        self.radius_um = radius_um
-        self.min_values = min_values
-
+        self.sigmoid = sigmoid
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
         self.neighbours_mask = self.channel_distance < radius_um
-
-        self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values))
-
+        self.radius_um = radius_um
+        self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um))
         self._dtype = recording.get_dtype()
 
     def get_dtype(self):
         return self._dtype
 
+    def _sigmoid(self, x):
+        L, x0, k, b = self.sigmoid
+        y = L / (1 + np.exp(-k * (x - x0))) + b
+        return y
+
     def compute(self, traces, peaks, waveforms):
         all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype)
 
         for main_chan in np.unique(peaks["channel_index"]):
             (idx,) = np.nonzero(peaks["channel_index"] == main_chan)
             (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
             local_projections = self.projections[chan_inds, :]
-            wf_ptp = (waveforms[idx][:, :, chan_inds]).ptp(axis=1)
+            wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)
 
-            if self.min_values is not None:
-                wf_ptp = (wf_ptp / self.min_values[chan_inds]) ** 4
+            if self.sigmoid is not None:
+                wf_ptp *= self._sigmoid(wf_ptp)
 
             denom = np.sum(wf_ptp, axis=1)
             mask = denom != 0
 
             all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis])
 
         return all_projections

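In short, compute() reduces each waveform to its per-channel PTP profile, optionally reweighted by the sigmoid, normalizes it to sum to one, and projects it onto the random directions. A self-contained sketch with toy shapes, mirroring the logic above:

    import numpy as np

    num_spikes, num_samples, num_chans, nb_projections = 50, 60, 16, 10
    waveforms = np.random.randn(num_spikes, num_samples, num_chans).astype(np.float32)
    projections = np.random.randn(num_chans, nb_projections)

    wf_ptp = np.ptp(waveforms, axis=1)  # (num_spikes, num_chans) amplitude profile
    denom = np.sum(wf_ptp, axis=1)      # per-spike normalization
    mask = denom != 0                   # guard against all-zero profiles

    features = np.zeros((num_spikes, nb_projections), dtype=np.float32)
    features[mask] = np.dot(wf_ptp[mask], projections) / denom[mask][:, np.newaxis]
    # `features` is the (num_spikes, nb_projections) matrix handed to HDBSCAN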