From 5970894070e8f7e6f3a3b138fafe3c8b80890aea Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 1 Sep 2023 15:31:26 -0400 Subject: [PATCH] first draft of typing, typo-fixes, docstring updates --- src/probeinterface/generator.py | 77 ++++++++++-------- src/probeinterface/io.py | 31 +++++--- src/probeinterface/library.py | 11 +-- src/probeinterface/probe.py | 133 ++++++++++++++++++-------------- src/probeinterface/utils.py | 10 +-- src/probeinterface/wiring.py | 10 +-- 6 files changed, 158 insertions(+), 114 deletions(-) diff --git a/src/probeinterface/generator.py b/src/probeinterface/generator.py index 709402f..f690b1d 100644 --- a/src/probeinterface/generator.py +++ b/src/probeinterface/generator.py @@ -10,15 +10,15 @@ from .utils import combine_probes -def generate_dummy_probe(elec_shapes="circle"): +def generate_dummy_probe(elec_shapes:str ="circle"): """ Generate a dummy probe with 3 columns and 32 contacts. Mainly used for testing and examples. Parameters ---------- - elec_shapes : str, optional - Shape of the electrodes, by default 'circle' + elec_shapes : str, , by default 'circle' + Shape of the electrodes with possibilities of ('circle', 'square', 'rect') Returns ------- @@ -72,10 +72,13 @@ def generate_dummy_probe_group(): return probegroup -def generate_tetrode(r=10): +def generate_tetrode(r:float=10): """ Generate a tetrode Probe. - + Parameters + ---------- + r: float + The distance multiplier for the positions Returns ------- probe : Probe @@ -89,32 +92,35 @@ def generate_tetrode(r=10): def generate_multi_columns_probe( - num_columns=3, - num_contact_per_column=10, - xpitch=20, - ypitch=20, + num_columns:int=3, + num_contact_per_column:int =10, + xpitch:float=20, + ypitch:float=20, y_shift_per_column=None, - contact_shapes="circle", - contact_shape_params={"radius": 6}, + contact_shapes:str="circle", + contact_shape_params:dict={"radius": 6}, ): """Generate a Probe with several columns. Parameters ---------- - num_columns : int, optional - Number of columns, by default 3 - num_contact_per_column : int, optional - Number of contacts per column, by default 10 - xpitch : float, optional - Pitch in x direction, by default 20 - ypitch : float, optional - Pitch in y direction, by default 20 + num_columns : int, by default 3 + Number of columns + num_contact_per_column : int, by default 10 + Number of contacts per column + xpitch : float, by default 20 + Pitch in x direction + ypitch : float, by default 20 + Pitch in y direction y_shift_per_column : array-like, optional Shift in y direction per column. It needs to have the same length as num_columns, by default None - contact_shapes : str, optional - Shape of the contacts ('circle', 'rect', 'square'), by default 'circle' - contact_shape_params : dict, optional - Parameters for the shape, by default {'radius': 6} + contact_shapes : str, by default 'circle' + Shape of the contacts ('circle', 'rect', 'square') + contact_shape_params : dict, default {'radius': 6} + Parameters for the shape. + For circle: {"radius": float} + For square: {"width": float} + For rectangle: {"width": float, "height": float} Returns ------- @@ -144,19 +150,22 @@ def generate_multi_columns_probe( return probe -def generate_linear_probe(num_elec=16, ypitch=20, contact_shapes="circle", contact_shape_params={"radius": 6}): +def generate_linear_probe(num_elec: int =16, ypitch:float=20, contact_shapes:str="circle", contact_shape_params:dict={"radius": 6}): """Generate a one-column linear probe. Parameters ---------- - num_elec : int, optional + num_elec : int Number of electrodes, by default 16 - ypitch : float, optional + ypitch : float Pitch in y direction, by default 20 - contact_shapes : str, optional - Shape of the contacts ('circle', 'rect', 'square'), by default 'circle' - contact_shape_params : dict, optional - Parameters for the shape, by default {'radius': 6} + contact_shapes : str, default 'circle' + Shape of the contacts ('circle', 'rect', 'square') + contact_shape_params : dict, default {'radius': 6} + Parameters for the shape. + For circle: {"radius": float} + For square: {"width": float} + For rectangle: {"width": float, "height": float} Returns ------- @@ -175,15 +184,15 @@ def generate_linear_probe(num_elec=16, ypitch=20, contact_shapes="circle", conta return probe -def generate_multi_shank(num_shank=2, shank_pitch=[150, 0], **kargs): +def generate_multi_shank(num_shank:int=2, shank_pitch:list=[150, 0], **kargs): """Generate a multi-shank probe. Internally, calls generate_multi_columns_probe and combine_probes. Parameters ---------- - num_shank : int, optional - Number of shanks, by default 2 - shank_pitch : list, optional + num_shank : int, default 2 + Number of shanks + shank_pitch : list, default [150,0] Distance between shanks, by default [150, 0] Returns diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index cd855af..f60dca4 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -32,7 +32,7 @@ def _probeinterface_format_check_version(d): pass -def read_probeinterface(file): +def read_probeinterface(file:Union[str, Path])-> ProbeGroup: """ Read probeinterface JSON-based format. @@ -80,7 +80,7 @@ def write_probeinterface(file: Union[str, Path], probe_or_probegroup: Union[Prob elif isinstance(probe_or_probegroup, ProbeGroup): probegroup = probe_or_probegroup else: - raise ValueError("write_probeinterface : need probe or probegroup") + raise ValueError("write_probeinterface : needs a probe or probegroup") file = Path(file) @@ -175,23 +175,23 @@ def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> P if "contact_shapes" not in df_probe: df_probe["contact_shapes"] = "circle" df_probe["radius"] = 1 - print(f"There is no contact shape provided for probe {probe_id}, a " f"dummy circle with 1um is created") + print(f"There is no contact shape provided for probe {probe_id}, a " f"dummy circle with 1um radius will be used.") if "x" not in df_probe: df_probe["x"] = np.arange(len(df_probe.index), dtype=float) print( - f"There is no x coordinate provided for probe {probe_id}, a " f"dummy linear x coordinate is created." + f"There is no x coordinate provided for probe {probe_id}, a " f"dummy linear x coordinate will be used." ) if "y" not in df_probe: df_probe["y"] = 0.0 print( - f"There is no y coordinate provided for probe {probe_id}, a " f"dummy constant y coordinate is created." + f"There is no y coordinate provided for probe {probe_id}, a " f"dummy constant y coordinate will be used." ) if "si_units" not in df_probe: df_probe["si_units"] = "um" - print(f"There is no SI units provided for probe {probe_id}, a " f"dummy SI unit (um) is created.") + print(f"There is no SI unit provided for probe {probe_id}, a " f"dummy SI unit (um) will be used") # create probe object and register with probegroup probe = Probe.from_dataframe(df=df_probe) @@ -290,7 +290,7 @@ def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> P return probegroup -def write_BIDS_probe(folder: Union[str, Path], probe_or_probegroup: Union[Probe, ProbeGroup], prefix=""): +def write_BIDS_probe(folder: Union[str, Path], probe_or_probegroup: Union[Probe, ProbeGroup], prefix:str=""): """ Write to probe and contact formats as proposed for ephy BIDS extension (tsv & json based). @@ -577,7 +577,7 @@ def read_3brain(file: Union[str, Path], mea_pitch: float = 42, electrode_width: return probe -def write_prb(file, probegroup, total_nb_channels=None, radius=None, group_mode="by_probe"): +def write_prb(file:str, probegroup:ProbeGroup, total_nb_channels:Optional[int]=None, radius:Optional[float]=None, group_mode:str="by_probe"): """ Write ProbeGroup into a prb file. @@ -596,6 +596,19 @@ def write_prb(file, probegroup, total_nb_channels=None, radius=None, group_mode= * "radius" is needed by spyking-circus * "graph" is not handled + Parameters + ---------- + file: str + The name of the file to be written + probegroup: ProbeGroup + The Probegroup to be used for writing + total_nb_channels: Optional[int], default None + ***to do + radius: Optional[float], default None + *** to do + group_mode: str + One of "by_probe" or "by_shank + """ assert group_mode in ("by_probe", "by_shank") @@ -643,7 +656,7 @@ def write_prb(file, probegroup, total_nb_channels=None, radius=None, group_mode= f.write("}\n") -def read_csv(file): +def read_csv(file:Union[str, Path]): """ Return a 2 or 3 columns csv file with contact positions """ diff --git a/src/probeinterface/library.py b/src/probeinterface/library.py index e5686d6..20c0b15 100644 --- a/src/probeinterface/library.py +++ b/src/probeinterface/library.py @@ -11,6 +11,7 @@ import os from pathlib import Path from urllib.request import urlopen +from typing import Optional from .io import read_probeinterface @@ -24,7 +25,7 @@ cache_folder = Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library" -def download_probeinterface_file(manufacturer, probe_name): +def download_probeinterface_file(manufacturer:str, probe_name:str): """Download the probeinterface file to the cache directory. Note that the file is itself a ProbeGroup but on the repo each file represents one probe. @@ -44,14 +45,14 @@ def download_probeinterface_file(manufacturer, probe_name): f.write(dist.read()) -def get_from_cache(manufacturer, probe_name): +def get_from_cache(manufacturer:str, probe_name:str)-> Optional["Probe"]: """ Get Probe from local cache Parameters ---------- manufacturer : str - The probe manufacturer (e.g. 'cambridgeneurotech') + The probe manufacturer (e.g. 'cambridgeneurotech', 'neuronexus') probe_name : str The probe name @@ -71,14 +72,14 @@ def get_from_cache(manufacturer, probe_name): return probe -def get_probe(manufacturer, probe_name): +def get_probe(manufacturer:str, probe_name:str)-> "Probe": """ Get probe from ProbeInterface library Parameters ---------- manufacturer : str - The probe manufacturer (e.g. 'cambridgeneurotech') + The probe manufacturer (e.g. 'cambridgeneurotech', 'neuronexus') probe_name : str The probe name diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index ba2548c..ff11258 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -1,4 +1,5 @@ import numpy as np +from typing import Sequence, Union, Optional from .shank import Shank @@ -15,7 +16,7 @@ class Probe: """ - def __init__(self, ndim=2, si_units="um"): + def __init__(self, ndim:int=2, si_units:int="um"): """ Some attributes are protected and have to be set with setters: * set_contacts(...) @@ -90,7 +91,7 @@ def contact_ids(self): def shank_ids(self): return self._shank_ids - def get_title(self): + def get_title(self)->str: if self.contact_positions is None: txt = "Undefined probe" else: @@ -131,14 +132,14 @@ def check_annotations(self): if "first_index" in d: assert d["first_index"] in (0, 1) - def get_contact_count(self): + def get_contact_count(self)->int: """ Return the number of contacts on the probe. """ assert self.contact_positions is not None return len(self.contact_positions) - def get_shank_count(self): + def get_shank_count(self)->int: """ Return the number of shanks for this probe. """ @@ -194,7 +195,7 @@ def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10}, else: self._shank_ids = np.asarray(shank_ids).astype(str) if self.shank_ids.size != n: - raise ValueError("shan_ids have wring size") + raise ValueError("shank_ids have wrong size") # shape if isinstance(shapes, str): @@ -203,7 +204,7 @@ def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10}, if not np.all(np.in1d(shapes, _possible_contact_shapes)): raise ValueError(f"contacts shape must be in {_possible_contact_shapes}") if shapes.shape[0] != n: - raise ValueError("contacts shape must have same length as posistions") + raise ValueError("contacts shape must have same length as positions") self._contact_shapes = np.array(shapes) # shape params @@ -211,7 +212,7 @@ def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10}, shape_params = [shape_params] * n self._contact_shape_params = np.array(shape_params) - def set_planar_contour(self, contour_polygon): + def set_planar_contour(self, contour_polygon:list): """Set the planar countour (the shape) of the probe. Parameters @@ -224,15 +225,15 @@ def set_planar_contour(self, contour_polygon): raise ValueError("contour_polygon.shape[1] and ndim do not match!") self.probe_planar_contour = contour_polygon - def create_auto_shape(self, probe_type="tip", margin=20.0): + def create_auto_shape(self, probe_type:str="tip", margin:float=20.0): """Create planar contour automatically based on probe contact positions. Parameters ---------- - probe_type : str, optional - The probe type ('tip' or 'rect'), by default 'tip' - margin : float, optional - The margin to add to the contact positions, by default 20 + probe_type : str, by default 'tip' + The probe type ('tip' or 'rect') + margin : float, by default 20.0 + The margin to add to the contact positions """ if self.ndim != 2: @@ -279,7 +280,7 @@ def create_auto_shape(self, probe_type="tip", margin=20.0): self.set_planar_contour(polygon) - def set_device_channel_indices(self, channel_indices): + def set_device_channel_indices(self, channel_indices:Sequence[int]): """ Manually set the device channel indices. @@ -297,7 +298,7 @@ def set_device_channel_indices(self, channel_indices): if self._probe_group is not None: self._probe_group.check_global_device_wiring_and_ids() - def wiring_to_device(self, pathway, channel_offset=0): + def wiring_to_device(self, pathway:str, channel_offset:int=0): """ Automatically set device_channel_indices based on a pathway. @@ -305,15 +306,16 @@ def wiring_to_device(self, pathway, channel_offset=0): Parameters ---------- - pathway : str The pathway. E.g. 'H32>RHD' + channel_offset: int, default 0 + An optional offset to add to the device_channel_indices """ from .wiring import wire_probe wire_probe(self, pathway, channel_offset=channel_offset) - def set_contact_ids(self, contact_ids): + def set_contact_ids(self, contact_ids:Sequence[Union[int, float, str]]): """ Set contact ids. Channel ids are converted to strings. Contact ids must be **unique** for the **Probe** @@ -328,7 +330,7 @@ def set_contact_ids(self, contact_ids): contact_ids = np.asarray(contact_ids) if contact_ids.size != self.get_contact_count(): - ValueError("channel_indices have not the same size as contact") + ValueError("channel_indices do not have the same size as contact") if contact_ids.dtype.kind != "U": contact_ids = contact_ids.astype("U") @@ -337,14 +339,14 @@ def set_contact_ids(self, contact_ids): if self._probe_group is not None: self._probe_group.check_global_device_wiring_and_ids() - def set_shank_ids(self, shank_ids): + def set_shank_ids(self, shank_ids:Sequence[Union[int, float, str]]): """ Set shank ids. Parameters ---------- shank_ids : list or array - Array with shank ids + Array with shank ids, if int or float converted to strings """ shank_ids = np.asarray(shank_ids).astype(str) if shank_ids.size != self.get_contact_count(): @@ -383,7 +385,7 @@ def copy(self): # channel_indices are not copied return other - def to_3d(self, axes="xz"): + def to_3d(self, axes:str="xz"): """ Transform 2d probe to 3d probe. @@ -391,7 +393,7 @@ def to_3d(self, axes="xz"): Parameters ---------- - axes : str + axes : str, default "xz" The axes that define the plane on which the 2D probe is defined. 'xy', 'yz' ', xz' """ assert self.ndim == 2 @@ -420,7 +422,7 @@ def to_3d(self, axes="xz"): return probe3d - def to_2d(self, axes="xy"): + def to_2d(self, axes:str="xy"): """ Transform 3d probe to 2d probe. @@ -451,7 +453,7 @@ def to_2d(self, axes="xy"): return probe2d - def get_contact_vertices(self): + def get_contact_vertices(self)->list: """ Return a list of contact vertices. """ @@ -491,7 +493,7 @@ def get_contact_vertices(self): vertices.append(one_vertice) return vertices - def move(self, translation_vector): + def move(self, translation_vector:Sequence[int]): """ Translate the probe in one direction. @@ -509,7 +511,7 @@ def move(self, translation_vector): if self.probe_planar_contour is not None: self.probe_planar_contour += translation_vector - def rotate(self, theta, center=None, axis=None): + def rotate(self, theta:float, center=None, axis=None): """ Rotate the probe around a specified axis. @@ -557,15 +559,15 @@ def rotate(self, theta, center=None, axis=None): new_vertices = (self.probe_planar_contour - center) @ R + center self.probe_planar_contour = new_vertices - def rotate_contacts(self, thetas): + def rotate_contacts(self, thetas:Union[float, Sequence[float]]): """ Rotate each contact of the probe. - Internaly, it modifies the contact_plane_axes. + Internally, it modifies the contact_plane_axes. Parameters ---------- thetas : array of float - Rotation angle in degree. + Rotation angle in degrees. If scalar, then it is applied to all contacts. """ @@ -600,14 +602,14 @@ def rotate_contacts(self, thetas): "_shank_ids", ] - def to_dict(self, array_as_list=False): + def to_dict(self, array_as_list:bool=False)->dict: """Create a dictionary of all necessary attributes. Useful for dumping and saving to json. Parameters ---------- - array_as_list : bool, optional - If True, arrays are converted to lists, by default False + array_as_list : bool, default False + If True, arrays are converted to lists Returns ------- @@ -631,7 +633,7 @@ def to_dict(self, array_as_list=False): return d @staticmethod - def from_dict(d): + def from_dict(d:dict): """Instantiate a Probe from a dictionary Parameters @@ -676,7 +678,7 @@ def from_dict(d): return probe - def to_numpy(self, complete=False): + def to_numpy(self, complete:bool=False)->np.array: """ Export to a numpy vector (structured array). This vector handles all contact attributes. @@ -687,9 +689,9 @@ def to_numpy(self, complete=False): Parameters ---------- - complete : bool + complete : bool, default False If True, export complete information about the probe, - including contact_plane_axes/si_units/device_channel_indices (default False) + including contact_plane_axes/si_units/device_channel_indices returns --------- @@ -757,7 +759,7 @@ def to_numpy(self, complete=False): return arr @staticmethod - def from_numpy(arr): + def from_numpy(arr:np.ndarray): """ Create Probe from a complex numpy array see Probe.to_numpy() @@ -780,8 +782,8 @@ def from_numpy(arr): else: ndim = 2 - assert "x" in fields - assert "y" in fields + assert "x" in fields, "arr must contain a .dtype.fields of x" + assert "y" in fields, "arr must contain a .dtype.fields of y" if "si_units" in fields: assert np.unique(arr["si_units"]).size == 1 si_units = np.unique(arr["si_units"])[0] @@ -828,13 +830,13 @@ def from_numpy(arr): return probe - def to_dataframe(self, complete=False): + def to_dataframe(self, complete:bool=False)-> "pandas.DataFrame": """ Export the probe to a pandas dataframe Parameters ---------- - complete : bool + complete : bool, default False If True, export complete information about the probe, including the probe plane axis. @@ -852,7 +854,7 @@ def to_dataframe(self, complete=False): return df @staticmethod - def from_dataframe(df): + def from_dataframe(df:"pandas.DataFrame"): """ Create Probe from a pandas.DataFrame see Probe.to_dataframe() @@ -870,9 +872,9 @@ def from_dataframe(df): arr = df.to_records(index=False) return Probe.from_numpy(arr) - def to_image(self, values, pixel_size=0.5, num_pixel=None, method="linear", xlims=None, ylims=None): + def to_image(self, values:Sequence, pixel_size:float=0.5, num_pixel:Optional[int]=None, method:str="linear", xlims:Optional[tuple]=None, ylims:Optional[tuple]=None)-> tuple[np.ndarray, tuple, tuple]: """ - Generated a 2d (image) from a values vector which an interpolation + Generated a 2d (image) from a values vector with an interpolation into a grid mesh. Parameters @@ -883,10 +885,11 @@ def to_image(self, values, pixel_size=0.5, num_pixel=None, method="linear", xlim size of one pixel in micrometers num_pixel : alternative to pixel_size give pixel number of the image width - method : 'linear' or 'nearest' or 'cubic' - xlims : tuple or None + method : str, default 'linear' + One of the options: 'linear' or 'nearest' or 'cubic' + xlims : Optional[tuple], default None Force image xlims - ylims : tuple or None + ylims : Optional[tuple], default None Force image ylims Returns @@ -932,7 +935,7 @@ def to_image(self, values, pixel_size=0.5, num_pixel=None, method="linear", xlim return image, xlims, ylims - def get_slice(self, selection): + def get_slice(self, selection:np.ndarray[Union[bool, int]]): """ Get a copy of the Probe with a sub selection of contacts. @@ -941,6 +944,13 @@ def get_slice(self, selection): Parameters ---------- selection : np.array of bool or int (for index) + Either an np.array of bool for desired selection of contacts + or the indices of the desired contacts + + Returns + ------- + sliced_probe: Probe + The sliced probe """ @@ -975,7 +985,7 @@ def get_slice(self, selection): return sliced_probe -def _2d_to_3d(data2d, axes): +def _2d_to_3d(data2d:np.ndarray, axes:str)-> np.ndarray: """ Add a third dimension @@ -997,7 +1007,7 @@ def _2d_to_3d(data2d, axes): return data3d -def select_axes(data, axes="xy"): +def select_axes(data:np.ndarray, axes:str="xy")->np.ndarray: """ Select axes in a 3d or 2d array. @@ -1005,7 +1015,7 @@ def select_axes(data, axes="xy"): ---------- data: np.array shape (n, 2) or (n, 3) - axes: str + axes: str, default 'xy' 'xy', 'yz' 'xz' or 'xyz' Returns ------- @@ -1018,24 +1028,35 @@ def select_axes(data, axes="xy"): return data[:, dims] -def _3d_to_2d(data3d, axes="xy"): +def _3d_to_2d(data3d:np.ndarray, axes:str="xy")-> np.ndarray: """ Reduce 3d array to 2d array on given axes. + + Parameters + ---------- + data: np.ndarray + The data with shape (n,3) + axes: str, default 'xy' + The axes over which to reduce the 2d array + + Returns + ------- + reduced_data: np.ndarray + The reduced data array """ assert data3d.shape[1] == 3 assert len(axes) == 2 return select_axes(data3d, axes=axes) -def _rotation_matrix_2d(theta): +def _rotation_matrix_2d(theta:float)->np.ndarray: """ Returns 2D rotation matrix Parameters ---------- theta : float - Angle in radians for rotation (anti-clockwise) - + Angle in radians for rotation (anti-clockwise/counterclockwise) Returns ------- R : np.array @@ -1046,7 +1067,7 @@ def _rotation_matrix_2d(theta): return R -def _rotation_matrix_3d(axis, theta): +def _rotation_matrix_3d(axis:Sequence, theta:float)->np.ndarray: """ Returns 3D rotation matrix @@ -1057,7 +1078,7 @@ def _rotation_matrix_3d(axis, theta): axis : np.array or list 3D axis of rotation theta : float - Angle in radians for rotation anti-clockwise + Angle in radians for rotation anti-clockwise/counterclockwise Returns ------- diff --git a/src/probeinterface/utils.py b/src/probeinterface/utils.py index 89bad84..d484580 100644 --- a/src/probeinterface/utils.py +++ b/src/probeinterface/utils.py @@ -47,7 +47,7 @@ def import_safely(module: str) -> ModuleType: return module_obj -def combine_probes(probes, connect_shape=True): +def combine_probes(probes:Probe, connect_shape:bool=True): """ Combine several Probe objects into a unique multi-shank Probe object. @@ -63,8 +63,8 @@ def combine_probes(probes, connect_shape=True): ---------- probes : list List of Probe objects - connect_shape : bool (default True) - Connect all shapes togother. + connect_shape : bool, default True + Connect all shapes together. Be careful, as this can lead to strange probe shape.... Return @@ -102,7 +102,7 @@ def combine_probes(probes, connect_shape=True): return multi_shank -def generate_unique_ids(min, max, n, trials=20): +def generate_unique_ids(min:int, max:int, n:int, trials:int=20)-> np.array: """ Create n unique identifiers. Creates `n` unique integer identifiers between `min` and `max` within a @@ -116,7 +116,7 @@ def generate_unique_ids(min, max, n, trials=20): Maximum value permitted for an identifier n : int Number of identifiers to create - trials : int + trials : int, default 20 Maximal number of attempts for generating the set of identifiers diff --git a/src/probeinterface/wiring.py b/src/probeinterface/wiring.py index e8397b6..105c8dc 100644 --- a/src/probeinterface/wiring.py +++ b/src/probeinterface/wiring.py @@ -50,7 +50,7 @@ # fmt: on -def get_available_pathways(): +def get_available_pathways()->list: """Return available pathways Returns @@ -61,7 +61,7 @@ def get_available_pathways(): return list(pathways.keys()) -def wire_probe(probe, pathway, channel_offset=0): +def wire_probe(probe:"Probe", pathway:str, channel_offset:int=0): """Inplace wiring for a Probe using a pathway Parameters @@ -70,10 +70,10 @@ def wire_probe(probe, pathway, channel_offset=0): The probe to wire pathway : str The pathway to use - channel_offset : int, optional - An optional offset to add to the device_channel_indices, by default 0 + channel_offset : int, default 0 + An optional offset to add to the device_channel_indices """ - assert pathway in pathways + assert pathway in pathways, f"{pathway} is not a currently supported pathway" chan_indices = np.array(pathways[pathway], dtype="int64") + channel_offset assert chan_indices.size == probe.get_contact_count() probe.set_device_channel_indices(chan_indices)