Skip to content

Commit

Permalink
Add recording attributes check, docs, and warning
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed May 14, 2024
1 parent e7fa1e4 commit dae7813
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 16 deletions.
32 changes: 32 additions & 0 deletions src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,3 +910,35 @@ def get_rec_attributes(recording):
dtype=recording.get_dtype(),
)
return rec_attributes


def check_recording_attributes_match(recording1, recording2_attributes, skip_properties=True):
"""
Check if two recordings have the same attributes
Parameters
----------
recording1 : BaseRecording
The first recording object
recording2 : BaseRecording
The second recording object
Returns
-------
bool
True if the recordings have the same attributes
"""
recording1_attributes = get_rec_attributes(recording1)
recording1_attributes["probegroup"] = recording1.get_probegroup()
recording2_attributes = deepcopy(recording2_attributes)
if skip_properties:
recording1_attributes.pop("properties")
recording2_attributes.pop("properties")
return (
np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"])
and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]
and recording1_attributes["num_channels"] == recording2_attributes["num_channels"]
and recording1_attributes["num_samples"] == recording2_attributes["num_samples"]
and recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]
and recording1_attributes["dtype"] == recording2_attributes["dtype"]
)
30 changes: 27 additions & 3 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .basesorting import BaseSorting

from .base import load_extractor
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, check_recording_attributes_match
from .core_tools import check_json, retrieve_importing_provenance
from .job_tools import split_job_kwargs
from .numpyextractors import NumpySorting
Expand Down Expand Up @@ -588,9 +588,33 @@ def load_from_zarr(cls, folder, recording=None):

return sorting_analyzer

def set_recording(self, recording):
def set_temporary_recording(self, recording: BaseRecording):
"""
Sets a temporary recording object. This function can be useful to temporarily set
a "cached" recording object that is not saved in the SortingAnalyzer object to speed up
computations. Upon reloading, the SortingAnalyzer object will try to reload the recording
from the original location in a lazy way.
Parameters
----------
recording : BaseRecording
The recording object to set as temporary recording.
Raises
------
ValueError
_description_
"""
# check that recording is compatible
assert check_recording_attributes_match(
recording, self.rec_attributes, skip_properties=True
), "Recording attributes do not match."
assert np.array_equal(
recording.get_channel_locations(), self.get_channel_locations()
), "Recording channel locations do not match."
if self._recording is not None:
raise ValueError("Recording is already set")
warnings.warn("SortingAnalyzer recording is already set. This will overwrite the current recording.")
self._recording = recording

def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer":
Expand Down
44 changes: 31 additions & 13 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np


def get_dataset():
def _get_dataset():
recording, sorting = generate_ground_truth_recording(
durations=[30.0],
sampling_frequency=16000.0,
Expand All @@ -28,8 +28,13 @@ def get_dataset():
return recording, sorting


def test_SortingAnalyzer_memory(tmp_path):
recording, sorting = get_dataset()
@pytest.fixture(scope="module")
def get_dataset():
return _get_dataset()


def test_SortingAnalyzer_memory(tmp_path, get_dataset):
recording, sorting = get_dataset
sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None)
_check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path)

Expand All @@ -48,8 +53,8 @@ def test_SortingAnalyzer_memory(tmp_path):
assert not sorting_analyzer.return_scaled


def test_SortingAnalyzer_binary_folder(tmp_path):
recording, sorting = get_dataset()
def test_SortingAnalyzer_binary_folder(tmp_path, get_dataset):
recording, sorting = get_dataset

folder = tmp_path / "test_SortingAnalyzer_binary_folder"
if folder.exists():
Expand Down Expand Up @@ -78,8 +83,8 @@ def test_SortingAnalyzer_binary_folder(tmp_path):
_check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path)


def test_SortingAnalyzer_zarr(tmp_path):
recording, sorting = get_dataset()
def test_SortingAnalyzer_zarr(tmp_path, get_dataset):
recording, sorting = get_dataset

folder = tmp_path / "test_SortingAnalyzer_zarr.zarr"
if folder.exists():
Expand All @@ -99,10 +104,21 @@ def test_SortingAnalyzer_zarr(tmp_path):
)


def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
def test_SortingAnalyzer_tmp_recording(get_dataset):
recording, sorting = get_dataset
recording_cached = recording.save(mode="memory")

print()
print(sorting_analyzer)
sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None)
sorting_analyzer.set_temporary_recording(recording_cached)

recording_sliced = recording.channel_slice(recording.channel_ids[:-1])

# wrong channels
with pytest.raises(AssertionError):
sorting_analyzer.set_temporary_recording(recording_sliced)


def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

register_result_extension(DummyAnalyzerExtension)

Expand Down Expand Up @@ -257,8 +273,10 @@ def test_extension():

if __name__ == "__main__":
tmp_path = Path("test_SortingAnalyzer")
test_SortingAnalyzer_memory(tmp_path)
test_SortingAnalyzer_binary_folder(tmp_path)
test_SortingAnalyzer_zarr(tmp_path)
dataset = _get_dataset()
test_SortingAnalyzer_memory(tmp_path, dataset)
test_SortingAnalyzer_binary_folder(tmp_path, dataset)
test_SortingAnalyzer_zarr(tmp_path, dataset)
test_SortingAnalyzer_tmp_recording(dataset)
test_extension()
test_extension_params()

0 comments on commit dae7813

Please sign in to comment.