Enhance the clustering
yger committed Sep 27, 2023
1 parent 0a2c0f6 commit 9f45f2e
Showing 3 changed files with 71 additions and 64 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -20,7 +20,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
     sorter_name = "spykingcircus2"
 
     _default_params = {
-        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
+        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
         "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
         "filtering": {"dtype": "float32"},
         "detection": {"peak_sign": "neg", "detect_threshold": 5},
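A quick usage sketch for context (not part of the commit): the updated "general" defaults are picked up automatically by the public run_sorter API and can still be overridden per call. Here, recording is assumed to be any preprocessed SpikeInterface recording, and the output_folder name is invented:

# Sketch only: launch the sorter and override the new default detection radius.
from spikeinterface.sorters import run_sorter

sorting = run_sorter(
    "spykingcircus2",
    recording,  # assumed: a preprocessed spikeinterface recording object
    output_folder="sc2_output",  # hypothetical folder name
    general={"ms_before": 2, "ms_after": 2, "radius_um": 100},  # the new default shown above
)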
106 changes: 55 additions & 51 deletions src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -18,7 +18,9 @@
 from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip
 from spikeinterface.core import NumpySorting
 from spikeinterface.core import extract_waveforms
-from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks, EnergyFeature
+from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
+from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
+from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
 
 
 class RandomProjectionClustering:
@@ -34,17 +36,17 @@ class RandomProjectionClustering:
             "cluster_selection_method": "leaf",
         },
         "cleaning_kwargs": {},
+        "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100},
         "radius_um": 100,
-        "max_spikes_per_unit": 200,
         "selection_method": "closest_to_centroid",
-        "nb_projections": {"ptp": 8, "energy": 2},
-        "ms_before": 1.5,
-        "ms_after": 1.5,
+        "nb_projections": 10,
+        "ms_before": 1,
+        "ms_after": 1,
         "random_seed": 42,
-        "shared_memory": False,
-        "min_values": {"ptp": 0, "energy": 0},
+        "smoothing_kwargs": {"window_length_ms": 1},
+        "shared_memory": True,
         "tmp_folder": None,
-        "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "10M", "verbose": True, "progress_bar": True},
+        "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True},
     }

@classmethod
@@ -74,50 +76,52 @@ def main_function(cls, recording, peaks, params):

         np.random.seed(d["random_seed"])
 
-        features_params = {}
-        features_list = []
-
-        noise_snippets = None
-
-        for proj_type in ["ptp", "energy"]:
-            if d["nb_projections"][proj_type] > 0:
-                features_list += [f"random_projections_{proj_type}"]
-
-                if d["min_values"][proj_type] == "auto":
-                    if noise_snippets is None:
-                        num_segments = recording.get_num_segments()
-                        num_chunks = 3 * d["max_spikes_per_unit"] // num_segments
-                        noise_snippets = get_random_data_chunks(
-                            recording, num_chunks_per_segment=num_chunks, chunk_size=num_samples, seed=42
-                        )
-                        noise_snippets = noise_snippets.reshape(num_chunks, num_samples, num_chans)
-
-                    if proj_type == "energy":
-                        data = np.linalg.norm(noise_snippets, axis=1)
-                        min_values = np.median(data, axis=0)
-                    elif proj_type == "ptp":
-                        data = np.ptp(noise_snippets, axis=1)
-                        min_values = np.median(data, axis=0)
-                elif d["min_values"][proj_type] > 0:
-                    min_values = d["min_values"][proj_type]
-                else:
-                    min_values = None
-
-                projections = np.random.randn(num_chans, d["nb_projections"][proj_type])
-                features_params[f"random_projections_{proj_type}"] = {
-                    "radius_um": params["radius_um"],
-                    "projections": projections,
-                    "min_values": min_values,
-                }
-
-        features_data = compute_features_from_peaks(
-            recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"]
+        if params["tmp_folder"] is None:
+            name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
+            tmp_folder = get_global_tmp_folder() / name
+        else:
+            tmp_folder = Path(params["tmp_folder"]).absolute()
+
+        ### Then we extract the SVD features
+        node0 = PeakRetriever(recording, peaks)
+        node1 = ExtractDenseWaveforms(
+            recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
         )
 
-        if len(features_data) > 1:
-            hdbscan_data = np.hstack((features_data[0], features_data[1]))
-        else:
-            hdbscan_data = features_data[0]
+        node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])
+
+        projections = np.random.randn(num_chans, d["nb_projections"])
+        projections -= projections.mean(0)
+        projections /= projections.std(0)
+
+        nbefore = int(params["ms_before"] * fs / 1000)
+        nafter = int(params["ms_after"] * fs / 1000)
+        nsamples = nbefore + nafter
+
+        import scipy
+
+        x = np.random.randn(100, nsamples, num_chans).astype(np.float32)
+        x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1)
+
+        ptps = np.ptp(x, axis=1)
+        a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000))
+        ydata = np.cumsum(a) / a.sum()
+        xdata = b[1:]
+
+        from scipy.optimize import curve_fit
+
+        def sigmoid(x, L, x0, k, b):
+            y = L / (1 + np.exp(-k * (x - x0))) + b
+            return y
+
+        p0 = [max(ydata), np.median(xdata), 1, min(ydata)]  # a mandatory initial guess
+        popt, pcov = curve_fit(sigmoid, xdata, ydata, p0)
+
+        node3 = RandomProjectionsFeature(
+            recording, parents=[node0, node2], return_output=True, projections=projections, radius_um=params["radius_um"]
+        )
+
+        pipeline_nodes = [node0, node1, node2, node3]
+
+        hdbscan_data = run_node_pipeline(recording, pipeline_nodes, params["job_kwargs"])
 
         import sklearn
 
Expand All @@ -132,7 +136,7 @@ def main_function(cls, recording, peaks, params):

all_indices = np.arange(0, peak_labels.size)

max_spikes = params["max_spikes_per_unit"]
max_spikes = params['waveforms']["max_spikes_per_unit"]
selection_method = params["selection_method"]

for unit_ind in labels:
Expand Down
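To summarize the new flow in one self-contained sketch (an illustration under assumptions, not the commit's exact code): detected peaks are replayed by a PeakRetriever, dense waveforms are extracted and Savitzky-Golay smoothed, and normalized random projections of the per-channel peak-to-peak amplitudes are computed in a single pass over the traces, ready for HDBSCAN. Here, recording and peaks (the structured array returned by peak detection) are assumed to already exist:

# Hedged sketch of the pipeline wiring introduced by this commit.
import numpy as np
from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature

num_chans = recording.get_num_channels()
projections = np.random.randn(num_chans, 10)  # one column per random projection
projections -= projections.mean(0)            # center each projection ...
projections /= projections.std(0)             # ... and normalize its scale

retriever = PeakRetriever(recording, peaks)   # replays the already-detected peaks
waveforms = ExtractDenseWaveforms(recording, parents=[retriever], return_output=False,
                                  ms_before=1, ms_after=1)
denoised = SavGolDenoiser(recording, parents=[retriever, waveforms], return_output=False,
                          window_length_ms=1)
features = RandomProjectionsFeature(recording, parents=[retriever, denoised], return_output=True,
                                    projections=projections, radius_um=100)

# One pass over the traces computes every node; the returned array of shape
# (num_peaks, nb_projections) is what HDBSCAN clusters downstream.
hdbscan_data = run_node_pipeline(recording, [retriever, waveforms, denoised, features], dict(n_jobs=1))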
27 changes: 15 additions & 12 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -184,41 +184,44 @@ def __init__(
         return_output=True,
         parents=None,
         projections=None,
-        radius_um=150.0,
-        min_values=None,
+        sigmoid=None,
+        radius_um=None,
     ):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
 
         self.projections = projections
-        self.radius_um = radius_um
-        self.min_values = min_values
 
+        self.sigmoid = sigmoid
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
         self.neighbours_mask = self.channel_distance < radius_um
 
-        self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values))
-
+        self.radius_um = radius_um
+        self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um))
         self._dtype = recording.get_dtype()
 
     def get_dtype(self):
         return self._dtype
 
+    def _sigmoid(self, x):
+        L, x0, k, b = self.sigmoid
+        y = L / (1 + np.exp(-k * (x - x0))) + b
+        return y
+
     def compute(self, traces, peaks, waveforms):
         all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype)
 
         for main_chan in np.unique(peaks["channel_index"]):
             (idx,) = np.nonzero(peaks["channel_index"] == main_chan)
             (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
             local_projections = self.projections[chan_inds, :]
-            wf_ptp = (waveforms[idx][:, :, chan_inds]).ptp(axis=1)
+            wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)
 
-            if self.min_values is not None:
-                wf_ptp = (wf_ptp / self.min_values[chan_inds]) ** 4
+            if self.sigmoid is not None:
+                wf_ptp *= self._sigmoid(wf_ptp)
 
             denom = np.sum(wf_ptp, axis=1)
             mask = denom != 0
 
             all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis])
 
         return all_projections
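The weighting logic above in numbers, as a toy numpy sketch (shapes and sigmoid parameters are invented for illustration): per-channel peak-to-peak amplitudes are optionally passed through the fitted sigmoid to down-weight noise-level channels, then projected and normalized so each feature is an amplitude ratio rather than an absolute amplitude:

# Toy example: 4 peaks, 30 samples, 3 neighbour channels, 10 projections.
import numpy as np

rng = np.random.default_rng(42)
waveforms = rng.standard_normal((4, 30, 3)).astype("float32")  # (peaks, samples, channels)
projections = rng.standard_normal((3, 10))                     # (channels, nb_projections)

wf_ptp = np.ptp(waveforms, axis=1)  # (4, 3) peak-to-peak amplitude per channel

# Example sigmoid parameters; in the commit these come from the curve_fit above.
L, x0, k, b = 1.0, 5.0, 1.0, 0.0
wf_ptp = wf_ptp * (L / (1 + np.exp(-k * (wf_ptp - x0))) + b)  # soft-threshold small amplitudes

denom = wf_ptp.sum(axis=1)  # normalization turns features into amplitude ratios
mask = denom != 0
features = np.zeros((4, 10), dtype="float32")
features[mask] = wf_ptp[mask] @ projections / denom[mask][:, None]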
