Skip to content

Commit

Permalink
switch to | notation
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Sep 15, 2023
1 parent d45bfa3 commit ce3250a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 51 deletions.
32 changes: 16 additions & 16 deletions src/probeinterface/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
This module contains useful helper functions for generating probes.
"""

from __future__ import annotations
import numpy as np

from typing import Optional

from .probe import Probe
from .probegroup import ProbeGroup
from .utils import combine_probes


def generate_dummy_probe(elec_shapes: str = "circle"):
def generate_dummy_probe(elec_shapes:str ="circle") -> Probe:
"""
Generate a dummy probe with 3 columns and 32 contacts.
Mainly used for testing and examples.
Expand Down Expand Up @@ -49,7 +51,7 @@ def generate_dummy_probe(elec_shapes: str = "circle"):
return probe


def generate_dummy_probe_group():
def generate_dummy_probe_group() -> ProbeGroup:
"""
Generate a ProbeGroup with 2 probes.
Mainly used for testing and examples.
Expand All @@ -72,7 +74,7 @@ def generate_dummy_probe_group():
return probegroup


def generate_tetrode(r: float = 10):
def generate_tetrode(r: float = 10)-> Probe:
"""
Generate a tetrode Probe.
Parameters
Expand All @@ -92,14 +94,14 @@ def generate_tetrode(r: float = 10):


def generate_multi_columns_probe(
num_columns: int = 3,
num_contact_per_column: int = 10,
xpitch: float = 20,
ypitch: float = 20,
y_shift_per_column=None,
contact_shapes: str = "circle",
contact_shape_params: dict = {"radius": 6},
):
num_columns:int=3,
num_contact_per_column:int=10,
xpitch:float=20,
ypitch:float=20,
y_shift_per_column: Optional[np.array | list]=None,
contact_shapes:str="circle",
contact_shape_params:dict={"radius": 6},
) -> Probe:
"""Generate a Probe with several columns.
Parameters
Expand Down Expand Up @@ -150,9 +152,7 @@ def generate_multi_columns_probe(
return probe


def generate_linear_probe(
num_elec: int = 16, ypitch: float = 20, contact_shapes: str = "circle", contact_shape_params: dict = {"radius": 6}
):
def generate_linear_probe(num_elec: int =16, ypitch:float=20, contact_shapes:str="circle", contact_shape_params:dict={"radius": 6}) -> Probe:
"""Generate a one-column linear probe.
Parameters
Expand Down Expand Up @@ -186,7 +186,7 @@ def generate_linear_probe(
return probe


def generate_multi_shank(num_shank: int = 2, shank_pitch: list = [150, 0], **kargs):
def generate_multi_shank(num_shank:int=2, shank_pitch:list=[150, 0], **kargs) -> Probe:
"""Generate a multi-shank probe.
Internally, calls generate_multi_columns_probe and combine_probes.
Expand Down
33 changes: 18 additions & 15 deletions src/probeinterface/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* Neurodata Without Borders (.nwb)
"""
from __future__ import annotations
from pathlib import Path
from typing import Union, Optional
import re
Expand All @@ -32,7 +33,7 @@ def _probeinterface_format_check_version(d):
pass


def read_probeinterface(file: Union[str, Path]) -> ProbeGroup:
def read_probeinterface(file: str | Path)-> ProbeGroup:
"""
Read probeinterface JSON-based format.
Expand All @@ -58,7 +59,7 @@ def read_probeinterface(file: Union[str, Path]) -> ProbeGroup:
return ProbeGroup.from_dict(d)


def write_probeinterface(file: Union[str, Path], probe_or_probegroup: Union[Probe, ProbeGroup]):
def write_probeinterface(file: str | Path, probe_or_probegroup: Probe | ProbeGroup):
"""
Write a probeinterface JSON file.
Expand Down Expand Up @@ -103,7 +104,7 @@ def write_probeinterface(file: Union[str, Path], probe_or_probegroup: Union[Prob
tsv_label_map_to_probeinterface = {v: k for k, v in tsv_label_map_to_BIDS.items()}


def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> ProbeGroup:
def read_BIDS_probe(folder: str | Path, prefix: Optional[str] = None) -> ProbeGroup:
"""
Read to BIDS probe format.
Expand Down Expand Up @@ -294,7 +295,7 @@ def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> P
return probegroup


def write_BIDS_probe(folder: Union[str, Path], probe_or_probegroup: Union[Probe, ProbeGroup], prefix: str = ""):
def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup, prefix:str=""):
"""
Write to probe and contact formats as proposed
for ephy BIDS extension (tsv & json based).
Expand Down Expand Up @@ -423,7 +424,7 @@ def write_BIDS_probe(folder: Union[str, Path], probe_or_probegroup: Union[Probe,
json.dump({"ContactId": contacts_dict}, f, indent=4)


def read_prb(file: Union[str, Path]) -> ProbeGroup:
def read_prb(file: str | Path) -> ProbeGroup:
"""
Read a PRB file and return a ProbeGroup object.
Expand Down Expand Up @@ -471,7 +472,7 @@ def read_prb(file: Union[str, Path]) -> ProbeGroup:
return probegroup


def read_maxwell(file: Union[str, Path], well_name: str = "well000", rec_name: str = "rec0000") -> Probe:
def read_maxwell(file: str | Path, well_name: str = "well000", rec_name: str = "rec0000") -> Probe:
"""
Read a maxwell file and return a Probe object. The Maxwell file format can be
either Maxone (and thus just the file name is needed), or MaxTwo. In case
Expand Down Expand Up @@ -538,7 +539,7 @@ def read_maxwell(file: Union[str, Path], well_name: str = "well000", rec_name: s
return probe


def read_3brain(file: Union[str, Path], mea_pitch: float = 42, electrode_width: float = 21) -> Probe:
def read_3brain(file: str| Path, mea_pitch: float = 42, electrode_width: float = 21) -> Probe:
"""
Read a 3brain file and return a Probe object. The 3brain file format can be
either an .h5 file or a .brw
Expand Down Expand Up @@ -666,7 +667,7 @@ def write_prb(
f.write("}\n")


def read_csv(file: Union[str, Path]):
def read_csv(file: str | Path):
"""
Return a 2 or 3 columns csv file with contact positions
"""
Expand Down Expand Up @@ -1053,7 +1054,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
return probe


def write_imro(file, probe):
def write_imro(file: str | Path, probe: Probe):
"""
save imro file (`.imrc`, imec readout) in a file.
https://github.com/open-ephys-plugins/neuropixels-pxi/blob/master/Source/Formats/IMRO.h
Expand All @@ -1063,6 +1064,7 @@ def write_imro(file, probe):
file : Path or str
The file path
probe : Probe object
"""
probe_type = probe.annotations["probe_type"]
data = probe.to_dataframe(complete=True).sort_values("device_channel_indices")
Expand Down Expand Up @@ -1095,7 +1097,7 @@ def write_imro(file, probe):
f.write("".join(ret))


def read_spikeglx(file: Union[str, Path]) -> Probe:
def read_spikeglx(file: str | Path) -> Probe:
"""
Read probe position for the meta file generated by SpikeGLX
Expand Down Expand Up @@ -1143,7 +1145,7 @@ def read_spikeglx(file: Union[str, Path]) -> Probe:
return probe


def parse_spikeglx_meta(meta_file: Union[str, Path]) -> dict:
def parse_spikeglx_meta(meta_file: str | Path) -> dict:
"""
Parse the "meta" file from spikeglx into a dict.
All fiields are kept in txt format and must also parsed themself.
Expand All @@ -1164,7 +1166,7 @@ def parse_spikeglx_meta(meta_file: Union[str, Path]) -> dict:
return meta


def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) -> np.array:
def get_saved_channel_indices_from_spikeglx_meta(meta_file: str | Path) -> np.array:
"""
Utils function to get the saved channels.
Expand All @@ -1174,6 +1176,7 @@ def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) ->
This function come from here Jennifer Colonell
https://github.com/jenniferColonell/ecephys_spike_sorting/blob/master/ecephys_spike_sorting/common/SGLXMetaToCoords.py#L65
"""
meta_file = Path(meta_file)
meta = parse_spikeglx_meta(meta_file)
Expand All @@ -1195,7 +1198,7 @@ def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) ->


def read_openephys(
settings_file: Union[str, Path],
settings_file: str | Path,
stream_name: Optional[str] = None,
probe_name: Optional[str] = None,
serial_number: Optional[str] = None,
Expand Down Expand Up @@ -1553,7 +1556,7 @@ def read_openephys(
return probe


def get_saved_channel_indices_from_openephys_settings(settings_file, stream_name):
def get_saved_channel_indices_from_openephys_settings(settings_file: str | Path, stream_name: str) -> Optional[np.array]:
"""
Returns an array with the subset of saved channels indices (if used)
Expand Down Expand Up @@ -1620,7 +1623,7 @@ def get_saved_channel_indices_from_openephys_settings(settings_file, stream_name
return chans_saved


def read_mearec(file: Union[str, Path]) -> Probe:
def read_mearec(file: str | Path) -> Probe:
"""
Read probe position, and contact shape from a MEArec file.
Expand Down
Loading

0 comments on commit ce3250a

Please sign in to comment.