diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2af48407a3..9a0e242d62 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -135,7 +135,7 @@ def get_total_duration(self) -> float: def get_unit_spike_train( self, - unit_id, + unit_id: str | int, segment_index: Union[int, None] = None, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d03c08b480..5824a75ab8 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import Literal +from typing import Literal, Optional from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -1138,6 +1138,7 @@ def __init__( firing_rates=firing_rates, refractory_period_seconds=self.refractory_period_seconds, seed=segment_seed, + unit_ids=unit_ids, t_start=None, ) self.add_sorting_segment(segment) @@ -1161,6 +1162,7 @@ def __init__( firing_rates: float | np.ndarray, refractory_period_seconds: float | np.ndarray, seed: int, + unit_ids: list[str], t_start: Optional[float] = None, ): self.num_units = num_units @@ -1177,7 +1179,8 @@ def __init__( self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") self.segment_seed = seed - self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)} + self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids} + self.num_samples = math.ceil(sampling_frequency * duration) super().__init__(t_start) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 64f7f76819..f243dd9d9f 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder): assert snippets.get_num_segments() == len(duration) assert snippets.get_num_channels() == num_channels - assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2]) - assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None)) + assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2]) + assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None)) # annotations / properties snippets.annotate(gre="ta") @@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder): ) # missing property - snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1]) + snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"]) values = snippets.get_property("string_property") assert values[2] == "" @@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder): snippets.set_property, key="string_property_nan", values=["hola", "chabon"], - ids=[0, 1], + ids=["0", "1"], missing_value=np.nan, ) # int properties without missing values raise an error assert_raises(Exception, snippets.set_property, key="int_property", values=[5, 6], ids=[1, 2]) - snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200) + snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200) values = snippets.get_property("int_property") assert values.dtype.kind == "i" diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 118b6092a9..99d6890dfd 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -38,10 +38,12 @@ def test_channelsaggregationrecording(): assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) assert np.allclose( - traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg) + traces2_0, + recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), ) assert np.allclose( - traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg) + traces3_2, + recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), ) # all traces traces1 = recording1.get_traces(segment_index=seg) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 35ab18b5f2..899993d840 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -76,8 +76,8 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset): # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above - select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) - assert len(select_units_sorting_analyer.unit_ids) == 1 + select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=["1"]) + assert len(select_units_sorting_analyer.unit_ids) == "1" folder = tmp_path / "test_SortingAnalyzer_binary_folder" if folder.exists(): @@ -121,11 +121,11 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above - select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) - assert len(select_units_sorting_analyer.unit_ids) == 1 - remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=[1]) + select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=["1"]) + assert len(select_units_sorting_analyer.unit_ids) == "1" + remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=["1"]) assert len(remove_units_sorting_analyer.unit_ids) == len(sorting_analyzer.unit_ids) - 1 - assert 1 not in remove_units_sorting_analyer.unit_ids + assert "1" not in remove_units_sorting_analyer.unit_ids # test no compression sorting_analyzer_no_compression = create_sorting_analyzer( @@ -358,7 +358,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): shutil.rmtree(folder) else: folder = None - sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]], format=format, folder=folder) + sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[["0", "1"]], format=format, folder=folder) if format != "memory": if format == "zarr": diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 1e72b0ab28..3ecb702aa2 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -10,25 +10,29 @@ def test_basic_functions(): sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2]) - assert np.array_equal(sorting2.unit_ids, [0, 2]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"]) + assert np.array_equal(sorting2.unit_ids, ["0", "2"]) assert sorting2.get_parent() == sorting - sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"]) + sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"]) assert np.array_equal(sorting3.unit_ids, ["a", "b"]) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting2.get_unit_spike_train(unit_id="0", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting3.get_unit_spike_train(unit_id="a", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting2.get_unit_spike_train(unit_id="2", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting3.get_unit_spike_train(unit_id="b", segment_index=0), ) @@ -36,13 +40,13 @@ def test_failure_with_non_unique_unit_ids(): seed = 10 sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed) with pytest.raises(AssertionError): - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) def test_custom_cache_spike_vector(): sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"]) + sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"]) cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True) computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False) assert np.all(cached_spike_vector == computed_spike_vector)