Skip to content

Commit

Permalink
Merge branch 'main' into ibl_move_import_inside
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Sep 29, 2023
2 parents e0bcb28 + 427d7b5 commit c2d369a
Show file tree
Hide file tree
Showing 9 changed files with 476 additions and 115 deletions.
15 changes: 8 additions & 7 deletions doc/how_to/load_matlab_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Here, we present a MATLAB code that creates a random dataset and writes it to a
Loading Data in SpikeInterface
------------------------------

After executing the above MATLAB code, a binary file named `your_data_as_a_binary.bin` will be created in your MATLAB directory. To load this file in Python, you'll need its full path.
After executing the above MATLAB code, a binary file named :code:`your_data_as_a_binary.bin` will be created in your MATLAB directory. To load this file in Python, you'll need its full path.

Use the following Python script to load the binary data into SpikeInterface:

Expand All @@ -55,7 +55,7 @@ Use the following Python script to load the binary data into SpikeInterface:
# Load data using SpikeInterface
recording = si.read_binary(file_path, sampling_frequency=sampling_frequency,
num_channels=num_channels, dtype=dtype)
num_channels=num_channels, dtype=dtype)
# Confirm that the data was loaded correctly by comparing the data shapes and see they match the MATLAB data
print(recording.get_num_frames(), recording.get_num_channels())
Expand All @@ -65,18 +65,18 @@ Follow the steps above to seamlessly import your MATLAB data into SpikeInterface
Common Pitfalls & Tips
----------------------

1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use `time_axis=1` in `si.read_binary()`.
1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use :code:`time_axis=1` in :code:`si.read_binary()`.
2. **File Path**: Always double-check the Python file path.
3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's `double` is equivalent to Numpy's `float64`.
4. **Sampling Frequency**: Set the appropriate sampling frequency in Hz for SpikeInterface.
5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing numpy's [Numpy for MATLAB Users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html) guide.
5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing numpy's `Numpy for MATLAB Users <https://numpy.org/doc/stable/user/numpy-for-matlab-users.html>`_ guide.

Using gains and offsets for integer data
----------------------------------------

Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units, you can apply a gain and an offset.
In SpikeInterface, you can use the `gain_to_uV` and `offset_to_uV` parameters, since traces are handled in microvolts (uV). Both parameters can be integrated into the `read_binary` function.
If your data in MATLAB is stored as `int16`, and you know the gain and offset, you can use the following code to load the data:
In SpikeInterface, you can use the :code:`gain_to_uV` and :code:`offset_to_uV` parameters, since traces are handled in microvolts (uV). Both parameters can be integrated into the :code:`read_binary` function.
If your data in MATLAB is stored as :code:`int16`, and you know the gain and offset, you can use the following code to load the data:

.. code-block:: python
Expand All @@ -90,7 +90,8 @@ If your data in MATLAB is stored as `int16`, and you know the gain and offset, y
num_channels=num_channels, dtype=dtype_int,
gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV)
recording.get_traces(return_scaled=True) # Return traces in micro volts (uV)
recording.get_traces() # Return traces in original units [type: int]
recording.get_traces(return_scaled=True) # Return traces in micro volts (uV) [type: float]
This will equip your recording object with capabilities to convert the data to float values in uV using the :code:`get_traces()` method with the :code:`return_scaled` parameter set to :code:`True`.
Expand Down
21 changes: 0 additions & 21 deletions src/spikeinterface/extractors/cellexplorersortingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(
sampling_frequency: float | None = None,
session_info_file_path: str | Path | None = None,
spikes_matfile_path: str | Path | None = None,
session_info_matfile_path: str | Path | None = None,
):
try:
from pymatreader import read_mat
Expand All @@ -67,26 +66,6 @@ def __init__(
)
file_path = spikes_matfile_path if file_path is None else file_path

if session_info_matfile_path is not None:
# Raise an error if the warning period has expired
deprecation_issued = datetime.datetime(2023, 4, 1)
deprecation_deadline = deprecation_issued + datetime.timedelta(days=180)
if datetime.datetime.now() > deprecation_deadline:
raise ValueError(
"The session_info_matfile_path argument is no longer supported in. Use session_info_file_path instead."
)

# Otherwise, issue a DeprecationWarning
else:
warnings.warn(
"The session_info_matfile_path argument is deprecated and will be removed in six months. "
"Use session_info_file_path instead.",
DeprecationWarning,
)
session_info_file_path = (
session_info_matfile_path if session_info_file_path is None else session_info_file_path
)

self.spikes_cellinfo_path = Path(file_path)
self.session_path = self.spikes_cellinfo_path.parent
self.session_id = self.spikes_cellinfo_path.stem.split(".")[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class CellExplorerSortingTest(SortingCommonTestSuite, unittest.TestCase):
(
"cellexplorer/dataset_2/20170504_396um_0um_merge.spikes.cellinfo.mat",
{
"session_info_matfile_path": local_folder
"session_info_file_path": local_folder
/ "cellexplorer/dataset_2/20170504_396um_0um_merge.sessionInfo.mat"
},
),
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
sorter_name = "spykingcircus2"

_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
"waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
"filtering": {"dtype": "float32"},
"detection": {"peak_sign": "neg", "detect_threshold": 5},
Expand Down Expand Up @@ -151,7 +151,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
matching_job_params["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params
recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params
)

if verbose:
Expand Down
42 changes: 25 additions & 17 deletions src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,14 +539,14 @@ def remove_duplicates_via_matching(
method_kwargs={},
job_kwargs={},
tmp_folder=None,
method="circus-omp-svd",
):
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface import get_noise_levels
from spikeinterface.core import BinaryRecordingExtractor
from spikeinterface.core import NumpySorting
from spikeinterface.core import extract_waveforms
from spikeinterface.core import get_global_tmp_folder
from spikeinterface.sortingcomponents.matching.circus import get_scipy_shape
import string, random, shutil, os
from pathlib import Path

Expand Down Expand Up @@ -591,19 +591,12 @@ def remove_duplicates_via_matching(

chunk_size = duration + 3 * margin

dummy_filter = np.empty((num_chans, duration), dtype=np.float32)
dummy_traces = np.empty((num_chans, chunk_size), dtype=np.float32)

fshape, axes = get_scipy_shape(dummy_filter, dummy_traces, axes=1)

method_kwargs.update(
{
"waveform_extractor": waveform_extractor,
"noise_levels": noise_levels,
"amplitudes": [0.95, 1.05],
"omp_min_sps": 0.1,
"templates": None,
"overlaps": None,
}
)

Expand All @@ -618,16 +611,31 @@ def remove_duplicates_via_matching(

method_kwargs.update({"ignored_ids": ignore_ids + [i]})
spikes, computed = find_spikes_from_templates(
sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
)
method_kwargs.update(
{
"overlaps": computed["overlaps"],
"templates": computed["templates"],
"norms": computed["norms"],
"sparsities": computed["sparsities"],
}
sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
)
if method == "circus-omp-svd":
method_kwargs.update(
{
"overlaps": computed["overlaps"],
"templates": computed["templates"],
"norms": computed["norms"],
"temporal": computed["temporal"],
"spatial": computed["spatial"],
"singular": computed["singular"],
"units_overlaps": computed["units_overlaps"],
"unit_overlaps_indices": computed["unit_overlaps_indices"],
"sparsity_mask": computed["sparsity_mask"],
}
)
elif method == "circus-omp":
method_kwargs.update(
{
"overlaps": computed["overlaps"],
"templates": computed["templates"],
"norms": computed["norms"],
"sparsities": computed["sparsities"],
}
)
valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging)
if np.sum(valid) > 0:
if np.sum(valid) == 1:
Expand Down
114 changes: 63 additions & 51 deletions src/spikeinterface/sortingcomponents/clustering/random_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip
from spikeinterface.core import NumpySorting
from spikeinterface.core import extract_waveforms
from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks, EnergyFeature
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever


class RandomProjectionClustering:
Expand All @@ -34,17 +36,17 @@ class RandomProjectionClustering:
"cluster_selection_method": "leaf",
},
"cleaning_kwargs": {},
"waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100},
"radius_um": 100,
"max_spikes_per_unit": 200,
"selection_method": "closest_to_centroid",
"nb_projections": {"ptp": 8, "energy": 2},
"ms_before": 1.5,
"ms_after": 1.5,
"nb_projections": 10,
"ms_before": 1,
"ms_after": 1,
"random_seed": 42,
"shared_memory": False,
"min_values": {"ptp": 0, "energy": 0},
"smoothing_kwargs": {"window_length_ms": 1},
"shared_memory": True,
"tmp_folder": None,
"job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "10M", "verbose": True, "progress_bar": True},
"job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True},
}

@classmethod
Expand Down Expand Up @@ -74,50 +76,60 @@ def main_function(cls, recording, peaks, params):

np.random.seed(d["random_seed"])

features_params = {}
features_list = []

noise_snippets = None

for proj_type in ["ptp", "energy"]:
if d["nb_projections"][proj_type] > 0:
features_list += [f"random_projections_{proj_type}"]

if d["min_values"][proj_type] == "auto":
if noise_snippets is None:
num_segments = recording.get_num_segments()
num_chunks = 3 * d["max_spikes_per_unit"] // num_segments
noise_snippets = get_random_data_chunks(
recording, num_chunks_per_segment=num_chunks, chunk_size=num_samples, seed=42
)
noise_snippets = noise_snippets.reshape(num_chunks, num_samples, num_chans)

if proj_type == "energy":
data = np.linalg.norm(noise_snippets, axis=1)
min_values = np.median(data, axis=0)
elif proj_type == "ptp":
data = np.ptp(noise_snippets, axis=1)
min_values = np.median(data, axis=0)
elif d["min_values"][proj_type] > 0:
min_values = d["min_values"][proj_type]
else:
min_values = None

projections = np.random.randn(num_chans, d["nb_projections"][proj_type])
features_params[f"random_projections_{proj_type}"] = {
"radius_um": params["radius_um"],
"projections": projections,
"min_values": min_values,
}

features_data = compute_features_from_peaks(
recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"]
if params["tmp_folder"] is None:
name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
tmp_folder = get_global_tmp_folder() / name
else:
tmp_folder = Path(params["tmp_folder"]).absolute()

### Then we extract the SVD features
node0 = PeakRetriever(recording, peaks)
node1 = ExtractDenseWaveforms(
recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
)

if len(features_data) > 1:
hdbscan_data = np.hstack((features_data[0], features_data[1]))
else:
hdbscan_data = features_data[0]
node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])

projections = np.random.randn(num_chans, d["nb_projections"])
projections -= projections.mean(0)
projections /= projections.std(0)

nbefore = int(params["ms_before"] * fs / 1000)
nafter = int(params["ms_after"] * fs / 1000)
nsamples = nbefore + nafter

import scipy

x = np.random.randn(100, nsamples, num_chans).astype(np.float32)
x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1)

ptps = np.ptp(x, axis=1)
a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000))
ydata = np.cumsum(a) / a.sum()
xdata = b[1:]

from scipy.optimize import curve_fit

def sigmoid(x, L, x0, k, b):
y = L / (1 + np.exp(-k * (x - x0))) + b
return y

p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess
popt, pcov = curve_fit(sigmoid, xdata, ydata, p0)

node3 = RandomProjectionsFeature(
recording,
parents=[node0, node2],
return_output=True,
projections=projections,
radius_um=params["radius_um"],
)

pipeline_nodes = [node0, node1, node2, node3]

hdbscan_data = run_node_pipeline(
recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features"
)

import sklearn

Expand All @@ -132,7 +144,7 @@ def main_function(cls, recording, peaks, params):

all_indices = np.arange(0, peak_labels.size)

max_spikes = params["max_spikes_per_unit"]
max_spikes = params["waveforms"]["max_spikes_per_unit"]
selection_method = params["selection_method"]

for unit_ind in labels:
Expand Down
27 changes: 15 additions & 12 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,41 +184,44 @@ def __init__(
return_output=True,
parents=None,
projections=None,
radius_um=150.0,
min_values=None,
sigmoid=None,
radius_um=None,
):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

self.projections = projections
self.radius_um = radius_um
self.min_values = min_values

self.sigmoid = sigmoid
self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um

self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values))

self.radius_um = radius_um
self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um))
self._dtype = recording.get_dtype()

def get_dtype(self):
return self._dtype

def _sigmoid(self, x):
L, x0, k, b = self.sigmoid
y = L / (1 + np.exp(-k * (x - x0))) + b
return y

def compute(self, traces, peaks, waveforms):
all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype)

for main_chan in np.unique(peaks["channel_index"]):
(idx,) = np.nonzero(peaks["channel_index"] == main_chan)
(chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
local_projections = self.projections[chan_inds, :]
wf_ptp = (waveforms[idx][:, :, chan_inds]).ptp(axis=1)
wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)

if self.min_values is not None:
wf_ptp = (wf_ptp / self.min_values[chan_inds]) ** 4
if self.sigmoid is not None:
wf_ptp *= self._sigmoid(wf_ptp)

denom = np.sum(wf_ptp, axis=1)
mask = denom != 0

all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis])

return all_projections


Expand Down
Loading

0 comments on commit c2d369a

Please sign in to comment.