Skip to content

Commit

Permalink
Merge pull request #2784 from h-mayorquin/add_name_to_repr
Browse files Browse the repository at this point in the history
Add name as an extractor attribute for `__repr__` purposes
  • Loading branch information
alejoe91 authored Jul 9, 2024
2 parents a1958ce + 4e0f587 commit 588ff5f
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 15 deletions.
15 changes: 14 additions & 1 deletion src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class BaseExtractor:
# This replaces the old key_properties
# These are annotations/properties that always need to be
# dumped (for instance locations, groups, is_fileterd, etc.)
_main_annotations = []
_main_annotations = ["name"]
_main_properties = []

# these properties are skipped by default in copy_metadata
Expand Down Expand Up @@ -79,6 +79,19 @@ def __init__(self, main_ids: Sequence) -> None:
# preferred context for multiprocessing
self._preferred_mp_context = None

@property
def name(self):
name = self._annotations.get("name", None)
return name if name is not None else self.__class__.__name__

@name.setter
def name(self, value):
if value is not None:
self.annotate(name=value)
else:
# we remove the annotation if it exists
_ = self._annotations.pop("name", None)

def get_num_segments(self) -> int:
# This is implemented in BaseRecording or BaseSorting
raise NotImplementedError
Expand Down
17 changes: 7 additions & 10 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BaseRecording(BaseRecordingSnippets):
Internally handle list of RecordingSegment
"""

_main_annotations = ["is_filtered"]
_main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"]
_main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"]
_main_features = [] # recording do not handle features

Expand All @@ -45,9 +45,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
self.annotate(is_filtered=False)

def __repr__(self):

class_name = self.__class__.__name__
name_to_display = class_name
num_segments = self.get_num_segments()

txt = self._repr_header()
Expand All @@ -57,7 +54,7 @@ def __repr__(self):
split_index = txt.rfind("-", 0, 100) # Find the last "-" before character 100
if split_index != -1:
first_line = txt[:split_index]
recording_string_space = len(name_to_display) + 2 # Length of name_to_display plus ": "
recording_string_space = len(self.name) + 2 # Length of self.name plus ": "
white_space_to_align_with_first_line = " " * recording_string_space
second_line = white_space_to_align_with_first_line + txt[split_index + 1 :].lstrip()
txt = first_line + "\n" + second_line
Expand Down Expand Up @@ -97,21 +94,21 @@ def list_to_string(lst, max_size=6):
return txt

def _repr_header(self):
class_name = self.__class__.__name__
name_to_display = class_name
num_segments = self.get_num_segments()
num_channels = self.get_num_channels()
sf_khz = self.get_sampling_frequency() / 1000.0
sf_hz = self.get_sampling_frequency()
sf_khz = sf_hz / 1000
dtype = self.get_dtype()

total_samples = self.get_total_samples()
total_duration = self.get_total_duration()
total_memory_size = self.get_total_memory_size()
sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz"

txt = (
f"{name_to_display}: "
f"{self.name}: "
f"{num_channels} channels - "
f"{sf_khz:0.1f}kHz - "
f"{sampling_frequency_repr} - "
f"{num_segments} segments - "
f"{total_samples:,} samples - "
f"{convert_seconds_to_str(total_duration)} - "
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/core/basesnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class BaseSnippets(BaseRecordingSnippets):
Abstract class representing several multichannel snippets.
"""

_main_annotations = []
_main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"]
_main_features = []

Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List):
self._cached_spike_trains = {}

def __repr__(self):
clsname = self.__class__.__name__
nseg = self.get_num_segments()
nunits = self.get_num_units()
sf_khz = self.get_sampling_frequency() / 1000.0
txt = f"{clsname}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz"
txt = f"{self.name}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz"
if "file_path" in self._kwargs:
txt += "\n file_path: {}".format(self._kwargs["file_path"])
return txt
Expand Down
5 changes: 5 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def generate_recording(
probe.set_device_channel_indices(np.arange(num_channels))
recording.set_probe(probe, in_place=True)

recording.name = "SyntheticRecording"

return recording


Expand Down Expand Up @@ -2122,4 +2124,7 @@ def generate_ground_truth_recording(
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)

recording.name = "GroundTruthRecording"
sorting.name = "GroundTruthSorting"

return recording, sorting
31 changes: 30 additions & 1 deletion src/spikeinterface/core/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""

from typing import Sequence
import numpy as np
from spikeinterface.core.base import BaseExtractor
from spikeinterface.core import generate_recording, concatenate_recordings
from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings


class DummyDictExtractor(BaseExtractor):
Expand Down Expand Up @@ -65,6 +66,34 @@ def test_check_if_serializable():
assert not extractor.check_serializability("json")


def test_name_and_repr():
test_recording, test_sorting = generate_ground_truth_recording(seed=0, durations=[2])
assert test_recording.name == "GroundTruthRecording"
assert test_sorting.name == "GroundTruthSorting"

# set a different name
test_recording.name = "MyRecording"
assert test_recording.name == "MyRecording"

# to/from dict
test_recording_dict = test_recording.to_dict()
test_recording2 = BaseExtractor.from_dict(test_recording_dict)
assert test_recording2.name == "MyRecording"

# repr
rec_str = str(test_recording2)
assert "MyRecording" in rec_str
test_recording2.name = None
assert "MyRecording" not in str(test_recording2)
assert test_recording2.__class__.__name__ in str(test_recording2)
# above 10khz, sampling frequency is printed in kHz
assert f"kHz" in rec_str
# below 10khz sampling frequency is printed in Hz
test_rec_low_fs = generate_recording(seed=0, durations=[2], sampling_frequency=5000)
rec_str = str(test_rec_low_fs)
assert "Hz" in rec_str


if __name__ == "__main__":
test_check_if_memory_serializable()
test_check_if_serializable()

0 comments on commit 588ff5f

Please sign in to comment.