Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 27, 2023
1 parent 73b065a commit 3cbf8f8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ class RandomProjectionClustering:
"cluster_selection_method": "leaf",
},
"cleaning_kwargs": {},
"waveforms" : {"ms_before" : 2, "ms_after" : 2, "max_spikes_per_unit": 100},
"waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100},
"radius_um": 100,
"selection_method": "closest_to_centroid",
"nb_projections": 10,
"ms_before": 1,
"ms_after": 1,
"random_seed": 42,
"smoothing_kwargs" : {"window_length_ms" : 1},
"smoothing_kwargs": {"window_length_ms": 1},
"shared_memory": True,
"tmp_folder": None,
"job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True},
Expand Down Expand Up @@ -84,40 +84,46 @@ def main_function(cls, recording, peaks, params):

### Then we extract the SVD features
node0 = PeakRetriever(recording, peaks)
node1 = ExtractDenseWaveforms(recording, parents=[node0], return_output=False,
ms_before=params['ms_before'],
ms_after=params['ms_after']
node1 = ExtractDenseWaveforms(
recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
)

node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params['smoothing_kwargs'])
node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])

projections = np.random.randn(num_chans, d["nb_projections"])
projections -= projections.mean(0)
projections /= projections.std(0)

nbefore = int(params['ms_before'] * fs / 1000)
nafter = int(params['ms_after'] * fs / 1000)
nbefore = int(params["ms_before"] * fs / 1000)
nafter = int(params["ms_after"] * fs / 1000)
nsamples = nbefore + nafter

import scipy

x = np.random.randn(100, nsamples, num_chans).astype(np.float32)
x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1)

ptps = np.ptp(x, axis=1)
a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000))
ydata = np.cumsum(a)/a.sum()
ydata = np.cumsum(a) / a.sum()
xdata = b[1:]

from scipy.optimize import curve_fit
def sigmoid(x, L ,x0, k, b):
y = L / (1 + np.exp(-k*(x-x0))) + b
return (y)

p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess
def sigmoid(x, L, x0, k, b):
y = L / (1 + np.exp(-k * (x - x0))) + b
return y

p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess
popt, pcov = curve_fit(sigmoid, xdata, ydata, p0)

node3 = RandomProjectionsFeature(recording, parents=[node0, node2], return_output=True,
projections=projections, radius_um=params['radius_um'])
node3 = RandomProjectionsFeature(
recording,
parents=[node0, node2],
return_output=True,
projections=projections,
radius_um=params["radius_um"],
)

pipeline_nodes = [node0, node1, node2, node3]

Expand All @@ -136,7 +142,7 @@ def sigmoid(x, L ,x0, k, b):

all_indices = np.arange(0, peak_labels.size)

max_spikes = params['waveforms']["max_spikes_per_unit"]
max_spikes = params["waveforms"]["max_spikes_per_unit"]
selection_method = params["selection_method"]

for unit_ind in labels:
Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __init__(
parents=None,
projections=None,
sigmoid=None,
radius_um=None
radius_um=None,
):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

Expand All @@ -203,12 +203,12 @@ def get_dtype(self):

def _sigmoid(self, x):
L, x0, k, b = self.sigmoid
y = L / (1 + np.exp(-k*(x-x0))) + b
y = L / (1 + np.exp(-k * (x - x0))) + b
return y

def compute(self, traces, peaks, waveforms):
all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype)

for main_chan in np.unique(peaks["channel_index"]):
(idx,) = np.nonzero(peaks["channel_index"] == main_chan)
(chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
Expand All @@ -221,7 +221,7 @@ def compute(self, traces, peaks, waveforms):
denom = np.sum(wf_ptp, axis=1)
mask = denom != 0
all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis])

return all_projections


Expand Down

0 comments on commit 3cbf8f8

Please sign in to comment.