Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sc2 updates #2086

Merged
merged 3 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

# recording_f = whiten(recording_f, dtype="float32")
recording_f = zscore(recording_f, dtype="float32")
noise_levels = np.ones(num_channels, dtype=np.float32)

## Then, we are detecting peaks with a locally_exclusive method
detection_params = params["detection"].copy()
Expand All @@ -87,7 +88,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])

noise_levels = np.ones(num_channels, dtype=np.float32)
selection_params.update({"noise_levels": noise_levels})
selected_peaks = select_peaks(
peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params
Expand All @@ -107,6 +107,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params.update(dict(shared_memory=params["shared_memory"]))
clustering_params["job_kwargs"] = job_kwargs
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
clustering_params.update({"noise_levels": noise_levels})

labels, peak_labels = find_cluster_from_peaks(
recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from spikeinterface.core import extract_waveforms
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
ExtractDenseWaveforms,
ExtractSparseWaveforms,
PeakRetriever,
)


class RandomProjectionClustering:
Expand All @@ -43,7 +48,8 @@ class RandomProjectionClustering:
"ms_before": 1,
"ms_after": 1,
"random_seed": 42,
"smoothing_kwargs": {"window_length_ms": 1},
"noise_levels": None,
"smoothing_kwargs": {"window_length_ms": 0.25},
"shared_memory": True,
"tmp_folder": None,
"job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True},
Expand Down Expand Up @@ -72,7 +78,10 @@ def main_function(cls, recording, peaks, params):
num_samples = nbefore + nafter
num_chans = recording.get_num_channels()

noise_levels = get_noise_levels(recording, return_scaled=False)
if d["noise_levels"] is None:
noise_levels = get_noise_levels(recording, return_scaled=False)
else:
noise_levels = d["noise_levels"]

np.random.seed(d["random_seed"])

Expand All @@ -82,10 +91,16 @@ def main_function(cls, recording, peaks, params):
else:
tmp_folder = Path(params["tmp_folder"]).absolute()

### Then we extract the SVD features
tmp_folder.mkdir(parents=True, exist_ok=True)

node0 = PeakRetriever(recording, peaks)
node1 = ExtractDenseWaveforms(
recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
node1 = ExtractSparseWaveforms(
recording,
parents=[node0],
return_output=False,
ms_before=params["ms_before"],
ms_after=params["ms_after"],
radius_um=params["radius_um"],
)

node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])
Expand Down Expand Up @@ -123,6 +138,8 @@ def sigmoid(x, L, x0, k, b):
return_output=True,
projections=projections,
radius_um=params["radius_um"],
sigmoid=None,
sparse=True,
)

pipeline_nodes = [node0, node1, node2, node3]
Expand All @@ -136,6 +153,18 @@ def sigmoid(x, L, x0, k, b):
clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
peak_labels = clustering[0]

# peak_labels = -1 * np.ones(len(peaks), dtype=int)
# nb_clusters = 0
# for c in np.unique(peaks['channel_index']):
# mask = peaks['channel_index'] == c
# clustering = hdbscan.hdbscan(hdbscan_data[mask], **d['hdbscan_kwargs'])
# local_labels = clustering[0]
# valid_clusters = local_labels > -1
# if np.sum(valid_clusters) > 0:
# local_labels[valid_clusters] += nb_clusters
# peak_labels[mask] = local_labels
# nb_clusters += len(np.unique(local_labels[valid_clusters]))

labels = np.unique(peak_labels)
labels = labels[labels >= 0]

Expand Down Expand Up @@ -174,15 +203,6 @@ def sigmoid(x, L, x0, k, b):
if verbose:
print("We found %d raw clusters, starting to clean with matching..." % (len(labels)))

# create a tmp folder
if params["tmp_folder"] is None:
name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
tmp_folder = get_global_tmp_folder() / name
else:
tmp_folder = Path(params["tmp_folder"])

tmp_folder.mkdir(parents=True, exist_ok=True)

sorting_folder = tmp_folder / "sorting"
unit_ids = np.arange(len(np.unique(spikes["unit_index"])))
sorting = NumpySorting(spikes, fs, unit_ids=unit_ids)
Expand Down
9 changes: 7 additions & 2 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
projections=None,
sigmoid=None,
radius_um=None,
sparse=True,
):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

Expand All @@ -195,7 +196,8 @@ def __init__(
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.radius_um = radius_um
self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um))
self.sparse = sparse
self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse))
self._dtype = recording.get_dtype()

def get_dtype(self):
Expand All @@ -213,7 +215,10 @@ def compute(self, traces, peaks, waveforms):
(idx,) = np.nonzero(peaks["channel_index"] == main_chan)
(chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
local_projections = self.projections[chan_inds, :]
wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)
if self.sparse:
wf_ptp = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1)
else:
wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)

if self.sigmoid is not None:
wf_ptp *= self._sigmoid(wf_ptp)
Expand Down