Cleaning imports in circus files
yger committed Mar 15, 2024
1 parent 1fd6973 commit d721304
Showing 2 changed files with 6 additions and 26 deletions.
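The deletions in both files are dead imports: names bound in the module namespace but never referenced (shutil, os, get_channel_distances, QuantileTransformer, and so on). A linter such as flake8 or ruff flags these automatically as rule F401; purely as an illustration of the idea, here is a stdlib-only sketch of such a check. The path in the final comment is an example, and the heuristic is rough (it ignores re-exports, __all__, and uses inside strings):

    import ast

    def unused_imports(path):
        """List names a module imports but never references (rough heuristic)."""
        with open(path, encoding="utf-8") as f:
            tree = ast.parse(f.read())
        imported, used = set(), set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    # `import a.b` binds the top-level name `a`
                    imported.add((alias.asname or alias.name).split(".")[0])
            elif isinstance(node, ast.ImportFrom):
                for alias in node.names:
                    imported.add(alias.asname or alias.name)
            elif isinstance(node, ast.Name):
                used.add(node.id)
        return sorted(imported - used)

    # e.g. unused_imports("src/spikeinterface/sortingcomponents/clustering/circus.py")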
20 changes: 4 additions & 16 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -3,7 +3,6 @@
 # """Sorting components: clustering"""
 from pathlib import Path
 
-import shutil
 import numpy as np
 
 try:
@@ -13,16 +12,13 @@
 except:
     HAVE_HDBSCAN = False
 
-import random, string, os
-from spikeinterface.core import get_global_tmp_folder, get_channel_distances
+import random, string
+from spikeinterface.core import get_global_tmp_folder
 from spikeinterface.core.basesorting import minimum_spike_dtype
-from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler
-from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers, estimate_templates
-from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip
-from spikeinterface.core import NumpySorting
+from spikeinterface.core.waveform_tools import estimate_templates
+from .clustering_tools import remove_duplicates_via_matching
 from spikeinterface.core.recording_tools import get_noise_levels
 from spikeinterface.core.job_tools import fix_job_kwargs
-from spikeinterface.core import extract_waveforms
 from spikeinterface.sortingcomponents.peak_selection import select_peaks
 from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
 from sklearn.decomposition import TruncatedSVD
@@ -32,7 +28,6 @@
 import pickle, json
 from spikeinterface.core.node_pipeline import (
     run_node_pipeline,
-    ExtractDenseWaveforms,
     ExtractSparseWaveforms,
     PeakRetriever,
 )
@@ -59,7 +54,6 @@ class CircusClustering:
         "n_svd": [5, 10],
         "ms_before": 0.5,
         "ms_after": 0.5,
-        "random_seed": 42,
         "noise_levels": None,
         "tmp_folder": None,
         "job_kwargs": {},
@@ -74,16 +68,11 @@ def main_function(cls, recording, peaks, params):
         d = params
         verbose = job_kwargs.get("verbose", False)
 
-        peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
-
         fs = recording.get_sampling_frequency()
         ms_before = params["ms_before"]
         ms_after = params["ms_after"]
         nbefore = int(ms_before * fs / 1000.0)
         nafter = int(ms_after * fs / 1000.0)
-        num_samples = nbefore + nafter
         num_chans = recording.get_num_channels()
-        np.random.seed(d["random_seed"])
 
         if params["tmp_folder"] is None:
             name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
@@ -119,7 +108,6 @@ def main_function(cls, recording, peaks, params):
             json.dump(model_params, f)
 
         # features
-        features_folder = model_folder / "features"
         node0 = PeakRetriever(recording, peaks)
 
         radius_um = params["radius_um"]
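A note on what survives the cleanup in circus.py: random and string stay because they back the fallback working-folder name, and the retained nbefore/nafter lines are plain millisecond-to-sample conversions, which is why num_samples = nbefore + nafter could go once nothing consumed it. A minimal standalone sketch of both retained patterns, with an assumed 30 kHz sampling rate for illustration:

    import random
    import string

    # eight-character alphanumeric tag for a throwaway working folder,
    # mirroring the branch taken when params["tmp_folder"] is None
    name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))

    fs = 30_000.0                      # assumed sampling rate in Hz
    nbefore = int(0.5 * fs / 1000.0)   # ms_before = 0.5 ms -> 15 samples
    nafter = int(0.5 * fs / 1000.0)    # ms_after = 0.5 ms -> 15 samples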
12 changes: 2 additions & 10 deletions
@@ -13,14 +13,9 @@
 except:
     HAVE_HDBSCAN = False
 
-import random, string, os
 from spikeinterface.core.basesorting import minimum_spike_dtype
-from spikeinterface.core import get_global_tmp_folder, get_channel_distances, get_random_data_chunks
-from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler
-from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers, estimate_templates
-from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip
-from spikeinterface.core import NumpySorting
-from spikeinterface.core import extract_waveforms
+from spikeinterface.core.waveform_tools import estimate_templates
+from .clustering_tools import remove_duplicates_via_matching
 from spikeinterface.core.recording_tools import get_noise_levels
 from spikeinterface.core.job_tools import fix_job_kwargs
 from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
@@ -99,7 +94,6 @@ def main_function(cls, recording, peaks, params):
 
         nbefore = int(params["ms_before"] * fs / 1000)
         nafter = int(params["ms_after"] * fs / 1000)
-        nsamples = nbefore + nafter
 
         # if params["feature"] == "ptp":
         #     noise_values = np.ptp(rng.randn(1000, nsamples), axis=1)
@@ -124,8 +118,6 @@
             recording, pipeline_nodes, job_kwargs=job_kwargs, job_name="extracting features"
         )
 
-        import sklearn
-
         clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
         peak_labels = clustering[0]
 
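Both files keep the optional-dependency guard around hdbscan shown in the hunks above, and the clustering call uses the library's functional API, whose returned tuple carries the label array first (hence clustering[0]). A self-contained sketch of the pattern; the random matrix and min_cluster_size are stand-ins for the real projected features and the hdbscan_kwargs from params:

    import numpy as np

    try:
        import hdbscan

        HAVE_HDBSCAN = True
    except ImportError:  # the source guards with a bare `except:`; this is the narrow form
        HAVE_HDBSCAN = False

    if HAVE_HDBSCAN:
        hdbscan_data = np.random.randn(500, 5)  # stand-in for SVD-projected peaks
        clustering = hdbscan.hdbscan(hdbscan_data, min_cluster_size=20)
        peak_labels = clustering[0]  # first element of the returned tuple is the labels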
