Skip to content

Commit

Permalink
Merge pull request #214 from zm711/typing
Browse files Browse the repository at this point in the history
Add Typing and Update Docstrings
  • Loading branch information
alejoe91 authored Sep 15, 2023
2 parents 8e1fac4 + bbf25fa commit 8bc5b71
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 129 deletions.
89 changes: 51 additions & 38 deletions src/probeinterface/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,25 @@
This module contains useful helper functions for generating probes.
"""

from __future__ import annotations
import numpy as np

from typing import Optional

from .probe import Probe
from .probegroup import ProbeGroup
from .utils import combine_probes


def generate_dummy_probe(elec_shapes="circle"):
def generate_dummy_probe(elec_shapes: str = "circle") -> Probe:
"""
Generate a dummy probe with 3 columns and 32 contacts.
Mainly used for testing and examples.
Parameters
----------
elec_shapes : str, optional
Shape of the electrodes, by default 'circle'
elec_shapes : str, , by default 'circle'
Shape of the electrodes with possibilities of ('circle', 'square', 'rect')
Returns
-------
Expand Down Expand Up @@ -49,7 +51,7 @@ def generate_dummy_probe(elec_shapes="circle"):
return probe


def generate_dummy_probe_group():
def generate_dummy_probe_group() -> ProbeGroup:
"""
Generate a ProbeGroup with 2 probes.
Mainly used for testing and examples.
Expand All @@ -72,10 +74,13 @@ def generate_dummy_probe_group():
return probegroup


def generate_tetrode(r=10):
def generate_tetrode(r: float = 10) -> Probe:
"""
Generate a tetrode Probe.
Parameters
----------
r: float
The distance multiplier for the positions
Returns
-------
probe : Probe
Expand All @@ -89,32 +94,35 @@ def generate_tetrode(r=10):


def generate_multi_columns_probe(
num_columns=3,
num_contact_per_column=10,
xpitch=20,
ypitch=20,
y_shift_per_column=None,
contact_shapes="circle",
contact_shape_params={"radius": 6},
):
num_columns: int = 3,
num_contact_per_column: int = 10,
xpitch: float = 20,
ypitch: float = 20,
y_shift_per_column: Optional[np.array | list] = None,
contact_shapes: str = "circle",
contact_shape_params: dict = {"radius": 6},
) -> Probe:
"""Generate a Probe with several columns.
Parameters
----------
num_columns : int, optional
Number of columns, by default 3
num_contact_per_column : int, optional
Number of contacts per column, by default 10
xpitch : float, optional
Pitch in x direction, by default 20
ypitch : float, optional
Pitch in y direction, by default 20
num_columns : int, by default 3
Number of columns
num_contact_per_column : int, by default 10
Number of contacts per column
xpitch : float, by default 20
Pitch in x direction
ypitch : float, by default 20
Pitch in y direction
y_shift_per_column : array-like, optional
Shift in y direction per column. It needs to have the same length as num_columns, by default None
contact_shapes : str, optional
Shape of the contacts ('circle', 'rect', 'square'), by default 'circle'
contact_shape_params : dict, optional
Parameters for the shape, by default {'radius': 6}
contact_shapes : str, by default 'circle'
Shape of the contacts ('circle', 'rect', 'square')
contact_shape_params : dict, default {'radius': 6}
Parameters for the shape.
For circle: {"radius": float}
For square: {"width": float}
For rectangle: {"width": float, "height": float}
Returns
-------
Expand Down Expand Up @@ -144,19 +152,24 @@ def generate_multi_columns_probe(
return probe


def generate_linear_probe(num_elec=16, ypitch=20, contact_shapes="circle", contact_shape_params={"radius": 6}):
def generate_linear_probe(
num_elec: int = 16, ypitch: float = 20, contact_shapes: str = "circle", contact_shape_params: dict = {"radius": 6}
) -> Probe:
"""Generate a one-column linear probe.
Parameters
----------
num_elec : int, optional
num_elec : int
Number of electrodes, by default 16
ypitch : float, optional
ypitch : float
Pitch in y direction, by default 20
contact_shapes : str, optional
Shape of the contacts ('circle', 'rect', 'square'), by default 'circle'
contact_shape_params : dict, optional
Parameters for the shape, by default {'radius': 6}
contact_shapes : str, default 'circle'
Shape of the contacts ('circle', 'rect', 'square')
contact_shape_params : dict, default {'radius': 6}
Parameters for the shape.
For circle: {"radius": float}
For square: {"width": float}
For rectangle: {"width": float, "height": float}
Returns
-------
Expand All @@ -175,15 +188,15 @@ def generate_linear_probe(num_elec=16, ypitch=20, contact_shapes="circle", conta
return probe


def generate_multi_shank(num_shank=2, shank_pitch=[150, 0], **kargs):
def generate_multi_shank(num_shank: int = 2, shank_pitch: list = [150, 0], **kargs) -> Probe:
"""Generate a multi-shank probe.
Internally, calls generate_multi_columns_probe and combine_probes.
Parameters
----------
num_shank : int, optional
Number of shanks, by default 2
shank_pitch : list, optional
num_shank : int, default 2
Number of shanks
shank_pitch : list, default [150,0]
Distance between shanks, by default [150, 0]
Returns
Expand Down
70 changes: 49 additions & 21 deletions src/probeinterface/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* Neurodata Without Borders (.nwb)
"""
from __future__ import annotations
from pathlib import Path
from typing import Union, Optional
import re
Expand All @@ -32,7 +33,7 @@ def _probeinterface_format_check_version(d):
pass


def read_probeinterface(file):
def read_probeinterface(file: str | Path) -> ProbeGroup:
"""
Read probeinterface JSON-based format.
Expand All @@ -58,7 +59,7 @@ def read_probeinterface(file):
return ProbeGroup.from_dict(d)


def write_probeinterface(file: Union[str, Path], probe_or_probegroup: Union[Probe, ProbeGroup]):
def write_probeinterface(file: str | Path, probe_or_probegroup: Probe | ProbeGroup):
"""
Write a probeinterface JSON file.
Expand All @@ -80,7 +81,7 @@ def write_probeinterface(file: Union[str, Path], probe_or_probegroup: Union[Prob
elif isinstance(probe_or_probegroup, ProbeGroup):
probegroup = probe_or_probegroup
else:
raise ValueError("write_probeinterface : need probe or probegroup")
raise ValueError("write_probeinterface : needs a probe or probegroup")

file = Path(file)

Expand All @@ -103,7 +104,7 @@ def write_probeinterface(file: Union[str, Path], probe_or_probegroup: Union[Prob
tsv_label_map_to_probeinterface = {v: k for k, v in tsv_label_map_to_BIDS.items()}


def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> ProbeGroup:
def read_BIDS_probe(folder: str | Path, prefix: Optional[str] = None) -> ProbeGroup:
"""
Read to BIDS probe format.
Expand Down Expand Up @@ -175,23 +176,27 @@ def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> P
if "contact_shapes" not in df_probe:
df_probe["contact_shapes"] = "circle"
df_probe["radius"] = 1
print(f"There is no contact shape provided for probe {probe_id}, a " f"dummy circle with 1um is created")
print(
f"There is no contact shape provided for probe {probe_id}, a "
f"dummy circle with 1um radius will be used."
)

if "x" not in df_probe:
df_probe["x"] = np.arange(len(df_probe.index), dtype=float)
print(
f"There is no x coordinate provided for probe {probe_id}, a " f"dummy linear x coordinate is created."
f"There is no x coordinate provided for probe {probe_id}, a " f"dummy linear x coordinate will be used."
)

if "y" not in df_probe:
df_probe["y"] = 0.0
print(
f"There is no y coordinate provided for probe {probe_id}, a " f"dummy constant y coordinate is created."
f"There is no y coordinate provided for probe {probe_id}, a "
f"dummy constant y coordinate will be used."
)

if "si_units" not in df_probe:
df_probe["si_units"] = "um"
print(f"There is no SI units provided for probe {probe_id}, a " f"dummy SI unit (um) is created.")
print(f"There is no SI unit provided for probe {probe_id}, a " f"dummy SI unit (um) will be used")

# create probe object and register with probegroup
probe = Probe.from_dataframe(df=df_probe)
Expand Down Expand Up @@ -290,7 +295,7 @@ def read_BIDS_probe(folder: Union[str, Path], prefix: Optional[str] = None) -> P
return probegroup


def write_BIDS_probe(folder: Union[str, Path], probe_or_probegroup: Union[Probe, ProbeGroup], prefix=""):
def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup, prefix: str = ""):
"""
Write to probe and contact formats as proposed
for ephy BIDS extension (tsv & json based).
Expand Down Expand Up @@ -419,7 +424,7 @@ def write_BIDS_probe(folder: Union[str, Path], probe_or_probegroup: Union[Probe,
json.dump({"ContactId": contacts_dict}, f, indent=4)


def read_prb(file: Union[str, Path]) -> ProbeGroup:
def read_prb(file: str | Path) -> ProbeGroup:
"""
Read a PRB file and return a ProbeGroup object.
Expand Down Expand Up @@ -467,7 +472,7 @@ def read_prb(file: Union[str, Path]) -> ProbeGroup:
return probegroup


def read_maxwell(file: Union[str, Path], well_name: str = "well000", rec_name: str = "rec0000") -> Probe:
def read_maxwell(file: str | Path, well_name: str = "well000", rec_name: str = "rec0000") -> Probe:
"""
Read a maxwell file and return a Probe object. The Maxwell file format can be
either Maxone (and thus just the file name is needed), or MaxTwo. In case
Expand Down Expand Up @@ -534,7 +539,7 @@ def read_maxwell(file: Union[str, Path], well_name: str = "well000", rec_name: s
return probe


def read_3brain(file: Union[str, Path], mea_pitch: float = 42, electrode_width: float = 21) -> Probe:
def read_3brain(file: str | Path, mea_pitch: float = 42, electrode_width: float = 21) -> Probe:
"""
Read a 3brain file and return a Probe object. The 3brain file format can be
either an .h5 file or a .brw
Expand Down Expand Up @@ -577,7 +582,13 @@ def read_3brain(file: Union[str, Path], mea_pitch: float = 42, electrode_width:
return probe


def write_prb(file, probegroup, total_nb_channels=None, radius=None, group_mode="by_probe"):
def write_prb(
file: str,
probegroup: ProbeGroup,
total_nb_channels: Optional[int] = None,
radius: Optional[float] = None,
group_mode: str = "by_probe",
):
"""
Write ProbeGroup into a prb file.
Expand All @@ -596,6 +607,19 @@ def write_prb(file, probegroup, total_nb_channels=None, radius=None, group_mode=
* "radius" is needed by spyking-circus
* "graph" is not handled
Parameters
----------
file: str
The name of the file to be written
probegroup: ProbeGroup
The Probegroup to be used for writing
total_nb_channels: Optional[int], default None
***to do
radius: Optional[float], default None
*** to do
group_mode: str
One of "by_probe" or "by_shank
"""
assert group_mode in ("by_probe", "by_shank")

Expand Down Expand Up @@ -643,7 +667,7 @@ def write_prb(file, probegroup, total_nb_channels=None, radius=None, group_mode=
f.write("}\n")


def read_csv(file):
def read_csv(file: str | Path):
"""
Return a 2 or 3 columns csv file with contact positions
"""
Expand Down Expand Up @@ -1030,7 +1054,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
return probe


def write_imro(file, probe):
def write_imro(file: str | Path, probe: Probe):
"""
save imro file (`.imrc`, imec readout) in a file.
https://github.com/open-ephys-plugins/neuropixels-pxi/blob/master/Source/Formats/IMRO.h
Expand All @@ -1040,6 +1064,7 @@ def write_imro(file, probe):
file : Path or str
The file path
probe : Probe object
"""
probe_type = probe.annotations["probe_type"]
data = probe.to_dataframe(complete=True).sort_values("device_channel_indices")
Expand Down Expand Up @@ -1072,7 +1097,7 @@ def write_imro(file, probe):
f.write("".join(ret))


def read_spikeglx(file: Union[str, Path]) -> Probe:
def read_spikeglx(file: str | Path) -> Probe:
"""
Read probe position for the meta file generated by SpikeGLX
Expand Down Expand Up @@ -1120,7 +1145,7 @@ def read_spikeglx(file: Union[str, Path]) -> Probe:
return probe


def parse_spikeglx_meta(meta_file: Union[str, Path]) -> dict:
def parse_spikeglx_meta(meta_file: str | Path) -> dict:
"""
Parse the "meta" file from spikeglx into a dict.
All fiields are kept in txt format and must also parsed themself.
Expand All @@ -1141,7 +1166,7 @@ def parse_spikeglx_meta(meta_file: Union[str, Path]) -> dict:
return meta


def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) -> np.array:
def get_saved_channel_indices_from_spikeglx_meta(meta_file: str | Path) -> np.array:
"""
Utils function to get the saved channels.
Expand All @@ -1151,6 +1176,7 @@ def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) ->
This function come from here Jennifer Colonell
https://github.com/jenniferColonell/ecephys_spike_sorting/blob/master/ecephys_spike_sorting/common/SGLXMetaToCoords.py#L65
"""
meta_file = Path(meta_file)
meta = parse_spikeglx_meta(meta_file)
Expand All @@ -1172,7 +1198,7 @@ def get_saved_channel_indices_from_spikeglx_meta(meta_file: Union[str, Path]) ->


def read_openephys(
settings_file: Union[str, Path],
settings_file: str | Path,
stream_name: Optional[str] = None,
probe_name: Optional[str] = None,
serial_number: Optional[str] = None,
Expand Down Expand Up @@ -1530,7 +1556,9 @@ def read_openephys(
return probe


def get_saved_channel_indices_from_openephys_settings(settings_file, stream_name):
def get_saved_channel_indices_from_openephys_settings(
settings_file: str | Path, stream_name: str
) -> Optional[np.array]:
"""
Returns an array with the subset of saved channels indices (if used)
Expand Down Expand Up @@ -1597,7 +1625,7 @@ def get_saved_channel_indices_from_openephys_settings(settings_file, stream_name
return chans_saved


def read_mearec(file: Union[str, Path]) -> Probe:
def read_mearec(file: str | Path) -> Probe:
"""
Read probe position, and contact shape from a MEArec file.
Expand Down
Loading

0 comments on commit 8bc5b71

Please sign in to comment.