From ce3250a04e2fda733ea307544f5a38e9e760f4e6 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 15 Sep 2023 09:36:43 -0400 Subject: [PATCH] switch to | notation --- src/probeinterface/generator.py | 32 +++++++++++++------------- src/probeinterface/io.py | 33 ++++++++++++++------------- src/probeinterface/probe.py | 40 +++++++++++++++++---------------- src/probeinterface/utils.py | 2 +- 4 files changed, 56 insertions(+), 51 deletions(-) diff --git a/src/probeinterface/generator.py b/src/probeinterface/generator.py index cd3f8e6..3144963 100644 --- a/src/probeinterface/generator.py +++ b/src/probeinterface/generator.py @@ -2,15 +2,17 @@ This module contains useful helper functions for generating probes. """ - +from __future__ import annotations import numpy as np +from typing import Optional + from .probe import Probe from .probegroup import ProbeGroup from .utils import combine_probes -def generate_dummy_probe(elec_shapes: str = "circle"): +def generate_dummy_probe(elec_shapes:str ="circle") -> Probe: """ Generate a dummy probe with 3 columns and 32 contacts. Mainly used for testing and examples. @@ -49,7 +51,7 @@ def generate_dummy_probe(elec_shapes: str = "circle"): return probe -def generate_dummy_probe_group(): +def generate_dummy_probe_group() -> ProbeGroup: """ Generate a ProbeGroup with 2 probes. Mainly used for testing and examples. @@ -72,7 +74,7 @@ def generate_dummy_probe_group(): return probegroup -def generate_tetrode(r: float = 10): +def generate_tetrode(r: float = 10)-> Probe: """ Generate a tetrode Probe. Parameters @@ -92,14 +94,14 @@ def generate_tetrode(r: float = 10): def generate_multi_columns_probe( - num_columns: int = 3, - num_contact_per_column: int = 10, - xpitch: float = 20, - ypitch: float = 20, - y_shift_per_column=None, - contact_shapes: str = "circle", - contact_shape_params: dict = {"radius": 6}, -): + num_columns:int=3, + num_contact_per_column:int=10, + xpitch:float=20, + ypitch:float=20, + y_shift_per_column: Optional[np.array | list]=None, + contact_shapes:str="circle", + contact_shape_params:dict={"radius": 6}, +) -> Probe: """Generate a Probe with several columns. Parameters @@ -150,9 +152,7 @@ def generate_multi_columns_probe( return probe -def generate_linear_probe( - num_elec: int = 16, ypitch: float = 20, contact_shapes: str = "circle", contact_shape_params: dict = {"radius": 6} -): +def generate_linear_probe(num_elec: int =16, ypitch:float=20, contact_shapes:str="circle", contact_shape_params:dict={"radius": 6}) -> Probe: """Generate a one-column linear probe. Parameters @@ -186,7 +186,7 @@ def generate_linear_probe( return probe -def generate_multi_shank(num_shank: int = 2, shank_pitch: list = [150, 0], **kargs): +def generate_multi_shank(num_shank:int=2, shank_pitch:list=[150, 0], **kargs) -> Probe: """Generate a multi-shank probe. Internally, calls generate_multi_columns_probe and combine_probes. diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index 5499af9..c5b3d5c 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -9,6 +9,7 @@ * Neurodata Without Borders (.nwb) """ +from __future__ import annotations from pathlib import Path from typing import Union, Optional import re @@ -32,7 +33,7 @@ def _probeinterface_format_check_version(d): pass -def read_probeinterface(file: Union[str, Path]) -> ProbeGroup: +def read_probeinterface(file: str | Path)-> ProbeGroup: """ Read probeinterface JSON-based format. @@ -58,7 +59,7 @@ def read_probeinterface(file: Union[str, Path]) -> ProbeGroup: return ProbeGroup.from_dict(d) -def write_probeinterface(file: Union[str, Path], probe_or_probegroup: Union[Probe, ProbeGroup]): +def write_probeinterface(file: str | Path, probe_or_probegroup: Probe | ProbeGroup): """ Write a probeinterface JSON file. @@ -103,7 +104,7 @@ def write_probeinterface(file: Union[str, Path], probe_or_probegroup: Union[Prob tsv_label_map_to_probeinterface = {v: k for k, v in tsv_label_map_to_BIDS.items()} -def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> ProbeGroup: +def read_BIDS_probe(folder: str | Path, prefix: Optional[str] = None) -> ProbeGroup: """ Read to BIDS probe format. @@ -294,7 +295,7 @@ def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> P return probegroup -def write_BIDS_probe(folder: Union[str, Path], probe_or_probegroup: Union[Probe, ProbeGroup], prefix: str = ""): +def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup, prefix:str=""): """ Write to probe and contact formats as proposed for ephy BIDS extension (tsv & json based). @@ -423,7 +424,7 @@ def write_BIDS_probe(folder: Union[str, Path], probe_or_probegroup: Union[Probe, json.dump({"ContactId": contacts_dict}, f, indent=4) -def read_prb(file: Union[str, Path]) -> ProbeGroup: +def read_prb(file: str | Path) -> ProbeGroup: """ Read a PRB file and return a ProbeGroup object. @@ -471,7 +472,7 @@ def read_prb(file: Union[str, Path]) -> ProbeGroup: return probegroup -def read_maxwell(file: Union[str, Path], well_name: str = "well000", rec_name: str = "rec0000") -> Probe: +def read_maxwell(file: str | Path, well_name: str = "well000", rec_name: str = "rec0000") -> Probe: """ Read a maxwell file and return a Probe object. The Maxwell file format can be either Maxone (and thus just the file name is needed), or MaxTwo. In case @@ -538,7 +539,7 @@ def read_maxwell(file: Union[str, Path], well_name: str = "well000", rec_name: s return probe -def read_3brain(file: Union[str, Path], mea_pitch: float = 42, electrode_width: float = 21) -> Probe: +def read_3brain(file: str| Path, mea_pitch: float = 42, electrode_width: float = 21) -> Probe: """ Read a 3brain file and return a Probe object. The 3brain file format can be either an .h5 file or a .brw @@ -666,7 +667,7 @@ def write_prb( f.write("}\n") -def read_csv(file: Union[str, Path]): +def read_csv(file: str | Path): """ Return a 2 or 3 columns csv file with contact positions """ @@ -1053,7 +1054,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe return probe -def write_imro(file, probe): +def write_imro(file: str | Path, probe: Probe): """ save imro file (`.imrc`, imec readout) in a file. https://github.com/open-ephys-plugins/neuropixels-pxi/blob/master/Source/Formats/IMRO.h @@ -1063,6 +1064,7 @@ def write_imro(file, probe): file : Path or str The file path probe : Probe object + """ probe_type = probe.annotations["probe_type"] data = probe.to_dataframe(complete=True).sort_values("device_channel_indices") @@ -1095,7 +1097,7 @@ def write_imro(file, probe): f.write("".join(ret)) -def read_spikeglx(file: Union[str, Path]) -> Probe: +def read_spikeglx(file: str | Path) -> Probe: """ Read probe position for the meta file generated by SpikeGLX @@ -1143,7 +1145,7 @@ def read_spikeglx(file: Union[str, Path]) -> Probe: return probe -def parse_spikeglx_meta(meta_file: Union[str, Path]) -> dict: +def parse_spikeglx_meta(meta_file: str | Path) -> dict: """ Parse the "meta" file from spikeglx into a dict. All fiields are kept in txt format and must also parsed themself. @@ -1164,7 +1166,7 @@ def parse_spikeglx_meta(meta_file: Union[str, Path]) -> dict: return meta -def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) -> np.array: +def get_saved_channel_indices_from_spikeglx_meta(meta_file: str | Path) -> np.array: """ Utils function to get the saved channels. @@ -1174,6 +1176,7 @@ def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) -> This function come from here Jennifer Colonell https://github.com/jenniferColonell/ecephys_spike_sorting/blob/master/ecephys_spike_sorting/common/SGLXMetaToCoords.py#L65 + """ meta_file = Path(meta_file) meta = parse_spikeglx_meta(meta_file) @@ -1195,7 +1198,7 @@ def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) -> def read_openephys( - settings_file: Union[str, Path], + settings_file: str | Path, stream_name: Optional[str] = None, probe_name: Optional[str] = None, serial_number: Optional[str] = None, @@ -1553,7 +1556,7 @@ def read_openephys( return probe -def get_saved_channel_indices_from_openephys_settings(settings_file, stream_name): +def get_saved_channel_indices_from_openephys_settings(settings_file: str | Path, stream_name: str) -> Optional[np.array]: """ Returns an array with the subset of saved channels indices (if used) @@ -1620,7 +1623,7 @@ def get_saved_channel_indices_from_openephys_settings(settings_file, stream_name return chans_saved -def read_mearec(file: Union[str, Path]) -> Probe: +def read_mearec(file: str | Path) -> Probe: """ Read probe position, and contact shape from a MEArec file. diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index a32f1d7..81d8ba9 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -1,5 +1,7 @@ +from __future__ import annotations import numpy as np -from typing import Sequence, Union, Optional +from typing import Optional + from .shank import Shank @@ -280,7 +282,7 @@ def create_auto_shape(self, probe_type: str = "tip", margin: float = 20.0): self.set_planar_contour(polygon) - def set_device_channel_indices(self, channel_indices: Sequence[int]): + def set_device_channel_indices(self, channel_indices: np.array | list): """ Manually set the device channel indices. @@ -315,7 +317,7 @@ def wiring_to_device(self, pathway: str, channel_offset: int = 0): wire_probe(self, pathway, channel_offset=channel_offset) - def set_contact_ids(self, contact_ids: Sequence[Union[int, float, str]]): + def set_contact_ids(self, contact_ids: np.array | list): """ Set contact ids. Channel ids are converted to strings. Contact ids must be **unique** for the **Probe** @@ -330,7 +332,7 @@ def set_contact_ids(self, contact_ids: Sequence[Union[int, float, str]]): contact_ids = np.asarray(contact_ids) if contact_ids.size != self.get_contact_count(): - ValueError("channel_indices do not have the same size as contact") + ValueError(f"channel_indices do not have the same size as number of contacts") if contact_ids.dtype.kind != "U": contact_ids = contact_ids.astype("U") @@ -339,7 +341,7 @@ def set_contact_ids(self, contact_ids: Sequence[Union[int, float, str]]): if self._probe_group is not None: self._probe_group.check_global_device_wiring_and_ids() - def set_shank_ids(self, shank_ids: Sequence[Union[int, float, str]]): + def set_shank_ids(self, shank_ids: np.array | list): """ Set shank ids. @@ -493,7 +495,7 @@ def get_contact_vertices(self) -> list: vertices.append(one_vertice) return vertices - def move(self, translation_vector: Sequence[int]): + def move(self, translation_vector: np.array | list): """ Translate the probe in one direction. @@ -559,7 +561,7 @@ def rotate(self, theta: float, center=None, axis=None): new_vertices = (self.probe_planar_contour - center) @ R + center self.probe_planar_contour = new_vertices - def rotate_contacts(self, thetas: Union[float, Sequence[float]]): + def rotate_contacts(self, thetas: float | np.array[float] | list[float]): """ Rotate each contact of the probe. Internally, it modifies the contact_plane_axes. @@ -633,7 +635,7 @@ def to_dict(self, array_as_list: bool = False) -> dict: return d @staticmethod - def from_dict(d: dict): + def from_dict(d:dict) -> "Probe": """Instantiate a Probe from a dictionary Parameters @@ -844,6 +846,7 @@ def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": ------- df : pandas.DataFrame The dataframe representation of the probe + """ import pandas as pd @@ -868,19 +871,12 @@ def from_dataframe(df: "pandas.DataFrame"): ------- probe : Probe The instantiated Probe object + """ arr = df.to_records(index=False) return Probe.from_numpy(arr) - def to_image( - self, - values: Sequence, - pixel_size: float = 0.5, - num_pixel: Optional[int] = None, - method: str = "linear", - xlims: Optional[tuple] = None, - ylims: Optional[tuple] = None, - ) -> tuple[np.ndarray, tuple, tuple]: + def to_image(self, values: np.array | list, pixel_size:float=0.5, num_pixel:Optional[int]=None, method:str="linear", xlims:Optional[tuple]=None, ylims:Optional[tuple]=None)-> tuple[np.ndarray, tuple, tuple]: """ Generated a 2d (image) from a values vector with an interpolation into a grid mesh. @@ -943,7 +939,7 @@ def to_image( return image, xlims, ylims - def get_slice(self, selection: np.ndarray[Union[bool, int]]): + def get_slice(self, selection:np.ndarray[bool|int]): """ Get a copy of the Probe with a sub selection of contacts. @@ -1003,10 +999,12 @@ def _2d_to_3d(data2d: np.ndarray, axes: str) -> np.ndarray: shape (n, 2) axes: str The axes that define the plane where electrodes lie on. E.g. 'xy', 'yz' or 'xz' + Returns ------- data3d shape (n, 3) + """ data3d = np.zeros((data2d.shape[0], 3), dtype=data2d.dtype) dims = np.array(["xyz".index(axis) for axis in axes]) @@ -1025,10 +1023,12 @@ def select_axes(data: np.ndarray, axes: str = "xy") -> np.ndarray: shape (n, 2) or (n, 3) axes: str, default 'xy' 'xy', 'yz' 'xz' or 'xyz' + Returns ------- data3d shape (n, 3) + """ assert np.all([axes.count(axis) == 1 for axis in axes]), "select_axes : axes must be unique." dims = np.array(["xyz".index(axis) for axis in axes]) @@ -1051,6 +1051,7 @@ def _3d_to_2d(data3d: np.ndarray, axes: str = "xy") -> np.ndarray: ------- reduced_data: np.ndarray The reduced data array + """ assert data3d.shape[1] == 3 assert len(axes) == 2 @@ -1065,6 +1066,7 @@ def _rotation_matrix_2d(theta: float) -> np.ndarray: ---------- theta : float Angle in radians for rotation (anti-clockwise/counterclockwise) + Returns ------- R : np.array @@ -1075,7 +1077,7 @@ def _rotation_matrix_2d(theta: float) -> np.ndarray: return R -def _rotation_matrix_3d(axis: Sequence, theta: float) -> np.ndarray: +def _rotation_matrix_3d(axis: np.array | list, theta:float)->np.ndarray: """ Returns 3D rotation matrix diff --git a/src/probeinterface/utils.py b/src/probeinterface/utils.py index 69e23d9..11bdc8a 100644 --- a/src/probeinterface/utils.py +++ b/src/probeinterface/utils.py @@ -47,7 +47,7 @@ def import_safely(module: str) -> ModuleType: return module_obj -def combine_probes(probes: Probe, connect_shape: bool = True): +def combine_probes(probes:Probe, connect_shape:bool=True) -> Probe: """ Combine several Probe objects into a unique multi-shank Probe object.