Commit

change to strings
h-mayorquin committed Dec 17, 2024
1 parent 0bf2b08 commit 212a974
Showing 6 changed files with 35 additions and 26 deletions.
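This commit migrates unit (and channel) identifiers in the core objects from integers to strings. A minimal before/after sketch of the caller-facing change, using a generated sorting as in the tests below (the ids shown are illustrative):

from spikeinterface.core import generate_sorting

sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0)

# before this commit: unit ids were integers
# spike_train = sorting.get_unit_spike_train(unit_id=0, segment_index=0)

# after: unit ids are strings
spike_train = sorting.get_unit_spike_train(unit_id="0", segment_index=0)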
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -135,7 +135,7 @@ def get_total_duration(self) -> float:

def get_unit_spike_train(
self,
unit_id,
unit_id: str | int,
segment_index: Union[int, None] = None,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
7 changes: 5 additions & 2 deletions src/spikeinterface/core/generate.py
@@ -2,7 +2,7 @@
import math
import warnings
import numpy as np
from typing import Literal
from typing import Literal, Optional
from math import ceil

from .basesorting import SpikeVectorSortingSegment
@@ -1138,6 +1138,7 @@ def __init__(
firing_rates=firing_rates,
refractory_period_seconds=self.refractory_period_seconds,
seed=segment_seed,
unit_ids=unit_ids,
t_start=None,
)
self.add_sorting_segment(segment)
@@ -1161,6 +1162,7 @@ def __init__(
firing_rates: float | np.ndarray,
refractory_period_seconds: float | np.ndarray,
seed: int,
unit_ids: list[str],
t_start: Optional[float] = None,
):
self.num_units = num_units
@@ -1177,7 +1179,8 @@ def __init__(
self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64")

self.segment_seed = seed
self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)}
self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids}

self.num_samples = math.ceil(sampling_frequency * duration)
super().__init__(t_start)
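The per-unit seeds are now keyed by the unit ids themselves rather than by range(num_units), and wrapped in abs(), presumably because hash() of a string can be negative while random seeds must be non-negative. A rough sketch of the idea (values illustrative; note that Python salts string hashes per process unless PYTHONHASHSEED is fixed, so such seeds are not stable across interpreter runs):

import numpy as np

segment_seed = 42
unit_ids = ["0", "1", "2"]

# one non-negative seed per unit id, derived from the segment seed
units_seed = {unit_id: abs(segment_seed + hash(unit_id)) for unit_id in unit_ids}

# each unit can then draw its spikes from its own generator
rngs = {unit_id: np.random.default_rng(seed) for unit_id, seed in units_seed.items()}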

10 changes: 5 additions & 5 deletions src/spikeinterface/core/tests/test_basesnippets.py
@@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder):
assert snippets.get_num_segments() == len(duration)
assert snippets.get_num_channels() == num_channels

assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2])
assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None))
assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2])
assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None))

# annotations / properties
snippets.annotate(gre="ta")
@@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder):
)

# missing property
snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1])
snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"])
values = snippets.get_property("string_property")
assert values[2] == ""

@@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder):
snippets.set_property,
key="string_property_nan",
values=["hola", "chabon"],
ids=[0, 1],
ids=["0", "1"],
missing_value=np.nan,
)

# int properties without missing values raise an error
assert_raises(Exception, snippets.set_property, key="int_property", values=[5, 6], ids=[1, 2])

snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200)
snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200)
values = snippets.get_property("int_property")
assert values.dtype.kind == "i"
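Together these assertions pin down set_property's missing-value rules under string ids: a string property set for a subset of ids pads the rest with an empty string, np.nan is rejected as the missing value for a string dtype, and an integer property needs an explicit missing_value to keep its dtype. A condensed restatement (snippets and its ids come from the test setup above):

snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"])
snippets.get_property("string_property")  # e.g. ['ciao', 'bello', ''] -- unset ids padded with ""

snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200)
snippets.get_property("int_property")     # unset ids filled with 200; dtype stays integer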

6 changes: 4 additions & 2 deletions src/spikeinterface/core/tests/test_channelsaggregationrecording.py
@@ -38,10 +38,12 @@ def test_channelsaggregationrecording():

assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg))
assert np.allclose(
traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg)
traces2_0,
recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg),
)
assert np.allclose(
traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg)
traces3_2,
recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg),
)
# all traces
traces1 = recording1.get_traces(segment_index=seg)
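Because channel ids are now strings, the offset ids of the aggregated recording must be computed by converting through int and back, which is exactly what the added int(...) wrapping does. A small sketch of the id arithmetic (num_channels value assumed for illustration):

num_channels = 4
channel_ids = ["0", "1", "2", "3"]  # string ids of one input recording

# the same channel in the second block of the aggregated recording
aggregated_id = str(num_channels + int(channel_ids[0]))  # -> "4"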
14 changes: 7 additions & 7 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -76,8 +76,8 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset):

# test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041
# this bug requires that we have an info.json file so we calculate templates above
select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1])
assert len(select_units_sorting_analyer.unit_ids) == 1
select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=["1"])
assert len(select_units_sorting_analyer.unit_ids) == 1

folder = tmp_path / "test_SortingAnalyzer_binary_folder"
if folder.exists():
@@ -121,11 +121,11 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):

# test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041
# this bug requires that we have an info.json file so we calculate templates above
select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1])
assert len(select_units_sorting_analyer.unit_ids) == 1
remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=[1])
select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=["1"])
assert len(select_units_sorting_analyer.unit_ids) == 1
remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=["1"])
assert len(remove_units_sorting_analyer.unit_ids) == len(sorting_analyzer.unit_ids) - 1
assert 1 not in remove_units_sorting_analyer.unit_ids
assert "1" not in remove_units_sorting_analyer.unit_ids

# test no compression
sorting_analyzer_no_compression = create_sorting_analyzer(
@@ -358,7 +358,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
shutil.rmtree(folder)
else:
folder = None
sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]], format=format, folder=folder)
sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[["0", "1"]], format=format, folder=folder)

if format != "memory":
if format == "zarr":
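All three unit-level operations on the analyzer take string ids after this change. A condensed sketch of the calls exercised above (analyzer construction omitted; the in-memory format is used so no folder is needed):

kept = sorting_analyzer.select_units(unit_ids=["1"])
dropped = sorting_analyzer.remove_units(remove_unit_ids=["1"])
merged = sorting_analyzer.merge_units(merge_unit_groups=[["0", "1"]], format="memory", folder=None)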
22 changes: 13 additions & 9 deletions src/spikeinterface/core/tests/test_unitsselectionsorting.py
@@ -10,39 +10,43 @@
def test_basic_functions():
sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2])
assert np.array_equal(sorting2.unit_ids, [0, 2])
sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"])
assert np.array_equal(sorting2.unit_ids, ["0", "2"])
assert sorting2.get_parent() == sorting

sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"])
sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"])
assert np.array_equal(sorting3.unit_ids, ["a", "b"])

assert np.array_equal(
sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0)
sorting.get_unit_spike_train(unit_id="0", segment_index=0),
sorting2.get_unit_spike_train(unit_id="0", segment_index=0),
)
assert np.array_equal(
sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0)
sorting.get_unit_spike_train(unit_id="0", segment_index=0),
sorting3.get_unit_spike_train(unit_id="a", segment_index=0),
)

assert np.array_equal(
sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0)
sorting.get_unit_spike_train(unit_id="2", segment_index=0),
sorting2.get_unit_spike_train(unit_id="2", segment_index=0),
)
assert np.array_equal(
sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0)
sorting.get_unit_spike_train(unit_id="2", segment_index=0),
sorting3.get_unit_spike_train(unit_id="b", segment_index=0),
)


def test_failure_with_non_unique_unit_ids():
seed = 10
sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed)
with pytest.raises(AssertionError):
sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"])
sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"])


def test_custom_cache_spike_vector():
sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"])
sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"])
cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True)
computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False)
assert np.all(cached_spike_vector == computed_spike_vector)
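The renamed selection keeps the parent's spike trains under new labels, so "b" in the sub-sorting resolves to unit "2" of the parent. A short sketch of that mapping, mirroring the tests above:

sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"])
assert np.array_equal(
    sub_sorting.get_unit_spike_train(unit_id="b", segment_index=0),
    sorting.get_unit_spike_train(unit_id="2", segment_index=0),
)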
