Skip to content

Commit

Permalink
Merge branch 'main' into improve_relative_to
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia authored Dec 5, 2023
2 parents 59cd083 + bf0b055 commit ba0d29b
Show file tree
Hide file tree
Showing 17 changed files with 232 additions and 102 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class BaseRecordingSnippets(BaseExtractor):

def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype):
BaseExtractor.__init__(self, channel_ids)
self._sampling_frequency = sampling_frequency
self._sampling_frequency = float(sampling_frequency)
self._dtype = np.dtype(dtype)

@property
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BaseSorting(BaseExtractor):

def __init__(self, sampling_frequency: float, unit_ids: List):
BaseExtractor.__init__(self, unit_ids)
self._sampling_frequency = sampling_frequency
self._sampling_frequency = float(sampling_frequency)
self._sorting_segments: List[BaseSortingSegment] = []
# this weak link is to handle times from a recording object
self._recording = None
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(

if not channel_from_template:
channel_distance = get_channel_distances(recording)
self.neighbours_mask = channel_distance < radius_um
self.neighbours_mask = channel_distance <= radius_um
self.peak_sign = peak_sign

# precompute segment slice
Expand Down Expand Up @@ -367,7 +367,7 @@ def __init__(
self.radius_um = radius_um
self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um
self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1))

def get_trace_margin(self):
Expand Down
31 changes: 15 additions & 16 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ class NwbRecordingExtractor(BaseRecording):
samples_for_rate_estimation: int, default: 100000
The number of timestamp samples to use to estimate the rate.
Used if "rate" is not specified in the ElectricalSeries.
stream_mode: str or None, default: None
Specify the stream mode: "fsspec" or "ros3".
stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache: bool, default: False
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
Expand Down Expand Up @@ -376,6 +376,9 @@ def __init__(
for column in electrodes_table.colnames:
if isinstance(electrodes_table[column][electrode_table_index], ElectrodeGroup):
continue
elif column == "channel_name":
# channel_names are already set as channel ids!
continue
elif column == "group_name":
group = unique_electrode_group_names.index(electrodes_table[column][electrode_table_index])
if "group" not in properties:
Expand Down Expand Up @@ -412,12 +415,11 @@ def __init__(
else:
self.set_property(property_name, values)

if stream_mode not in ["fsspec", "ros3", "remfile"]:
if file_path is not None:
file_path = str(Path(file_path).absolute())
if stream_mode == "fsspec":
if stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())
if stream_mode is None and file_path is not None:
file_path = str(Path(file_path).resolve())

if stream_mode == "fsspec" and stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())

self.extra_requirements.extend(["pandas", "pynwb", "hdmf"])
self._electrical_series = electrical_series
Expand Down Expand Up @@ -493,8 +495,8 @@ class NwbSortingExtractor(BaseSorting):
samples_for_rate_estimation: int, default: 100000
The number of timestamp samples to use to estimate the rate.
Used if "rate" is not specified in the ElectricalSeries.
stream_mode: str or None, default: None
Specify the stream mode: "fsspec" or "ros3".
stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache: bool, default: False
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
Expand Down Expand Up @@ -591,12 +593,9 @@ def __init__(
for prop_name, values in properties.items():
self.set_property(prop_name, np.array(values))

if stream_mode not in ["fsspec", "ros3"]:
file_path = str(Path(file_path).absolute())
if stream_mode == "fsspec":
# only add stream_cache_path to kwargs if it was passed as an argument
if stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())
if stream_mode is None and file_path is not None:
file_path = str(Path(file_path).resolve())

self._kwargs = {
"file_path": file_path,
"electrical_series_name": self._electrical_series_name,
Expand Down
32 changes: 32 additions & 0 deletions src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,38 @@ def test_sorting_s3_nwb_fsspec(tmp_path, cache):
check_sortings_equal(reloaded_sorting, sorting)


@pytest.mark.streaming_extractors
def test_sorting_s3_nwb_remfile(tmp_path):
file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
# We provide the 'sampling_frequency' because the NWB file does not have the electrical series
sorting = NwbSortingExtractor(
file_path,
sampling_frequency=30000.0,
stream_mode="remfile",
)

num_seg = sorting.get_num_segments()
assert num_seg == 1
num_units = len(sorting.unit_ids)
assert num_units == 64

for segment_index in range(num_seg):
for unit in sorting.unit_ids:
spike_train = sorting.get_unit_spike_train(unit_id=unit, segment_index=segment_index)
assert len(spike_train) > 0
assert spike_train.dtype == "int64"
assert np.all(spike_train >= 0)

tmp_file = tmp_path / "test_remfile_sorting.pkl"
with open(tmp_file, "wb") as f:
pickle.dump(sorting, f)

with open(tmp_file, "rb") as f:
reloaded_sorting = pickle.load(f)

check_sortings_equal(reloaded_sorting, sorting)


if __name__ == "__main__":
test_recording_s3_nwb_ros3()
test_recording_s3_nwb_fsspec()
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_extension_function():


def _make_bins(sorting, window_ms, bin_ms):
fs = sorting.get_sampling_frequency()
fs = sorting.sampling_frequency

window_size = int(round(fs * window_ms / 2 * 1e-3))
bin_size = int(round(fs * bin_ms * 1e-3))
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/unit_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def get_grid_convolution_templates_and_weights(

# mask to get nearest template given a channel
dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions)
nearest_template_mask = dist < radius_um
nearest_template_mask = dist <= radius_um

weights = np.zeros((len(sigma_um), len(contact_locations), nb_templates), dtype=np.float32)
for count, sigma in enumerate(sigma_um):
Expand Down
75 changes: 51 additions & 24 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
from __future__ import annotations
import warnings

import numpy as np
from typing import Literal

from .filter import highpass_filter
from ..core import get_random_data_chunks, order_channels_by_depth
from ..core import get_random_data_chunks, order_channels_by_depth, BaseRecording


def detect_bad_channels(
recording,
method="coherence+psd",
std_mad_threshold=5,
psd_hf_threshold=0.02,
dead_channel_threshold=-0.5,
noisy_channel_threshold=1.0,
outside_channel_threshold=-0.75,
n_neighbors=11,
nyquist_threshold=0.8,
direction="y",
chunk_duration_s=0.3,
num_random_chunks=100,
welch_window_ms=10.0,
highpass_filter_cutoff=300,
neighborhood_r2_threshold=0.9,
neighborhood_r2_radius_um=30.0,
seed=None,
recording: BaseRecording,
method: str = "coherence+psd",
std_mad_threshold: float = 5,
psd_hf_threshold: float = 0.02,
dead_channel_threshold: float = -0.5,
noisy_channel_threshold: float = 1.0,
outside_channel_threshold: float = -0.75,
outside_channels_location: Literal["top", "bottom", "both"] = "top",
n_neighbors: int = 11,
nyquist_threshold: float = 0.8,
direction: Literal["x", "y", "z"] = "y",
chunk_duration_s: float = 0.3,
num_random_chunks: int = 100,
welch_window_ms: float = 10.0,
highpass_filter_cutoff: float = 300,
neighborhood_r2_threshold: float = 0.9,
neighborhood_r2_radius_um: float = 30.0,
seed: int | None = None,
):
"""
Perform bad channel detection.
Expand Down Expand Up @@ -65,6 +68,11 @@ def detect_bad_channels(
outside_channel_threshold (coeherence+psd) : float, default: -0.75
Threshold for channel coherence above which channels at the edge of the recording are marked as outside
of the brain
outside_channels_location (coeherence+psd) : "top" | "bottom" | "both", default: "top"
Location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
n_neighbors (coeherence+psd) : int, default: 11
Number of channel neighbors to compute median filter (needs to be odd)
nyquist_threshold (coeherence+psd) : float, default: 0.8
Expand Down Expand Up @@ -190,6 +198,7 @@ def detect_bad_channels(
n_neighbors=n_neighbors,
nyquist_threshold=nyquist_threshold,
welch_window_ms=welch_window_ms,
outside_channels_location=outside_channels_location,
)
chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels

Expand Down Expand Up @@ -275,6 +284,7 @@ def detect_bad_channels_ibl(
n_neighbors=11,
nyquist_threshold=0.8,
welch_window_ms=0.3,
outside_channels_location="top",
):
"""
Bad channels detection for Neuropixel probes developed by IBL
Expand All @@ -300,6 +310,11 @@ def detect_bad_channels_ibl(
Threshold on Nyquist frequency to calculate HF noise band
welch_window_ms: float, default: 0.3
Window size for the scipy.signal.welch that will be converted to nperseg
outside_channels_location : "top" | "bottom" | "both", default: "top"
Location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
Returns
-------
Expand Down Expand Up @@ -332,12 +347,24 @@ def detect_bad_channels_ibl(
ichannels[inoisy] = 2

# the channels outside of the brains are the contiguous channels below the threshold on the trend coherency
# the chanels outide need to be at either extremes of the probe
ioutside = np.where(xcorr_distant < outside_channel_thr)[0]
if ioutside.size > 0 and (ioutside[-1] == (nc - 1) or ioutside[0] == 0):
a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
ioutside = ioutside[a == np.max(a)]
ichannels[ioutside] = 3
# the chanels outside need to be at the extreme of the probe
(ioutside,) = np.where(xcorr_distant < outside_channel_thr)
a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
if ioutside.size > 0:
if outside_channels_location == "top":
# channels are sorted bottom to top, so the last channel needs to be (nc - 1)
if ioutside[-1] == (nc - 1):
ioutside = ioutside[(a == np.max(a)) & (a > 0)]
ichannels[ioutside] = 3
elif outside_channels_location == "bottom":
# outside channels are at the bottom of the probe, so the first channel needs to be 0
if ioutside[0] == 0:
ioutside = ioutside[(a == np.min(a)) & (a < np.max(a))]
ichannels[ioutside] = 3
else: # both extremes are considered
if ioutside[-1] == (nc - 1) or ioutside[0] == 0:
ioutside = ioutside[(a == np.max(a)) | (a == np.min(a))]
ichannels[ioutside] = 3

return ichannels

Expand Down
Loading

0 comments on commit ba0d29b

Please sign in to comment.