Skip to content

Commit

Permalink
Merge pull request #3064 from zm711/bool-magic
Browse files Browse the repository at this point in the history
Add `bool` type hint to functions in core module
  • Loading branch information
alejoe91 authored Jun 24, 2024
2 parents 2dc0b74 + a71eff4 commit f366960
Show file tree
Hide file tree
Showing 13 changed files with 42 additions and 27 deletions.
6 changes: 3 additions & 3 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def check_serializability(self, type):
return False
return self._serializability[type]

def check_if_memory_serializable(self):
def check_if_memory_serializable(self) -> bool:
"""
Check if the object is serializable to memory with pickle, including nested objects.
Expand All @@ -561,7 +561,7 @@ def check_if_memory_serializable(self):
"""
return self.check_serializability("memory")

def check_if_json_serializable(self):
def check_if_json_serializable(self) -> bool:
"""
Check if the object is json serializable, including nested objects.
Expand All @@ -574,7 +574,7 @@ def check_if_json_serializable(self):
# is this needed ??? I think no.
return self.check_serializability("json")

def check_if_pickle_serializable(self):
def check_if_pickle_serializable(self) -> bool:
# is this needed ??? I think no.
return self.check_serializability("pickle")

Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_num_channels(self):
def get_dtype(self):
return self._dtype

def has_scaleable_traces(self):
def has_scaleable_traces(self) -> bool:
if self.get_property("gain_to_uV") is None or self.get_property("offset_to_uV") is None:
return False
else:
Expand All @@ -62,10 +62,10 @@ def has_scaled(self):
)
return self.has_scaleable_traces()

def has_probe(self):
def has_probe(self) -> bool:
return "contact_vector" in self.get_property_keys()

def has_channel_location(self):
def has_channel_location(self) -> bool:
return self.has_probe() or "location" in self.get_property_keys()

def is_filtered(self):
Expand Down Expand Up @@ -366,7 +366,7 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"):
locations = np.asarray(locations)[channel_indices]
return select_axes(locations, axes)

def has_3d_locations(self):
def has_3d_locations(self) -> bool:
return self.get_property("location").shape[1] == 3

def clear_channel_locations(self, channel_ids=None):
Expand Down
10 changes: 4 additions & 6 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import List, Optional, Union
from typing import Optional, Union

import numpy as np

Expand Down Expand Up @@ -73,7 +73,7 @@ def unit_ids(self):
def sampling_frequency(self):
return self._sampling_frequency

def get_unit_ids(self) -> List:
def get_unit_ids(self) -> list:
return self._main_ids

def get_num_units(self) -> int:
Expand Down Expand Up @@ -121,7 +121,7 @@ def get_total_samples(self) -> int:
s += self.get_num_samples(segment_index)
return s

def get_total_duration(self):
def get_total_duration(self) -> float:
"""Returns the total duration in s of the associated recording.
Returns
Expand Down Expand Up @@ -219,7 +219,7 @@ def set_sorting_info(self, recording_dict, params_dict, log_dict):
def has_recording(self):
return self._recording is not None

def has_time_vector(self, segment_index=None):
def has_time_vector(self, segment_index=None) -> bool:
"""
Check if the segment of the registered recording has a time vector.
"""
Expand Down Expand Up @@ -515,8 +515,6 @@ def precompute_spike_trains(self, from_spike_vector=None):
"""
Pre-computes and caches all spike trains for this sorting
Parameters
----------
from_spike_vector : None | bool, default: None
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/binaryfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, folder_path):
assert "num_chan" in self._bin_kwargs, "Cannot find num_channels or num_chan in binary.json"
self._bin_kwargs["num_channels"] = self._bin_kwargs["num_chan"]

def is_binary_compatible(self):
def is_binary_compatible(self) -> bool:
return True

def get_binary_description(self):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/binaryrecordingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def write_recording(recording, file_paths, dtype=None, **job_kwargs):
"""
write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs)

def is_binary_compatible(self):
def is_binary_compatible(self) -> bool:
return True

def get_binary_description(self):
Expand Down
13 changes: 10 additions & 3 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,14 @@ def make_shared_array(shape, dtype):
return arr, shm


def is_dict_extractor(d):
def is_dict_extractor(d: dict) -> bool:
"""
Check if a dict describe an extractor.
Check if a dict describes an extractor.
Returns
-------
is_extractor : bool
Whether the dict describes an extractor
"""
if not isinstance(d, dict):
return False
Expand Down Expand Up @@ -283,6 +288,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool:
Returns
-------
relative_possible: bool
Whether the given input can be made relative to the relative_folder
"""
path_list = _get_paths_list(input_dict)
relative_folder = Path(relative_folder).resolve().absolute()
Expand Down Expand Up @@ -513,7 +519,8 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0):

def retrieve_importing_provenance(a_class):
"""
Retrieve the import provenance of a class, including its import name (that consists of the class name and the module), the top-level module, and the module version.
Retrieve the import provenance of a class, including its import name (that consists of the class name and the module),
the top-level module, and the module version.
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/frameslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, parent_recording_segment, start_frame, end_frame):
self.start_frame = start_frame
self.end_frame = end_frame

def get_num_samples(self):
def get_num_samples(self) -> int:
return self.end_frame - self.start_frame

def get_traces(self, start_frame, end_frame, channel_indices):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ def __init__(
elif self.strategy == "on_the_fly":
pass

def get_num_samples(self):
def get_num_samples(self) -> int:
return self.num_samples

def get_traces(
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, traces, sampling_frequency, t_start):
self._traces = traces
self.num_samples = traces.shape[0]

def get_num_samples(self):
def get_num_samples(self) -> int:
return self.num_samples

def get_traces(self, start_frame, end_frame, channel_indices):
Expand Down
12 changes: 11 additions & 1 deletion src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,17 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"),
def check_probe_do_not_overlap(probes):
"""
When several probes this check that that they do not overlap in space
and so channel positions can be safly concatenated.
and so channel positions can be safely concatenated.
Raises
------
Exception :
If probes are overlapping
Returns
-------
None : None
If the check is successful
"""
for i in range(len(probes)):
probe_i = probes[i]
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,7 @@ def get_computable_extensions(self):
"""
return get_available_analyzer_extensions()

def get_default_extension_params(self, extension_name: str):
def get_default_extension_params(self, extension_name: str) -> dict:
"""
Get the default params for an extension.
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None
return waveforms_by_units


def has_exceeding_spikes(recording, sorting):
def has_exceeding_spikes(recording, sorting) -> bool:
"""
Check if the sorting objects has spikes exceeding the recording number of samples, for all segments
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/sorters/basesorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_
return sorting

@classmethod
def check_compiled(cls):
def check_compiled(cls) -> bool:
"""
Checks if the sorter is running inside an image with matlab-compiled version
Expand All @@ -370,7 +370,7 @@ def check_compiled(cls):
return True

@classmethod
def use_gpu(cls, params):
def use_gpu(cls, params) -> bool:
return cls.gpu_capability != "not-supported"

#############################################
Expand Down Expand Up @@ -436,7 +436,7 @@ def get_job_kwargs(params, verbose):
return job_kwargs


def is_log_ok(output_folder):
def is_log_ok(output_folder) -> bool:
# log is OK when run_time is not None
if (output_folder / "spikeinterface_log.json").is_file():
with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile:
Expand Down

0 comments on commit f366960

Please sign in to comment.