Skip to content

Commit

Permalink
Merge pull request #206 from SpikeInterface/extend-probe-constructor
Browse files Browse the repository at this point in the history
Extend probe constructor
  • Loading branch information
alejoe91 authored Oct 16, 2023
2 parents f00902e + c7a83de commit 86cdc9e
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 55 deletions.
60 changes: 42 additions & 18 deletions src/probeinterface/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def read_maxwell(file: str | Path, well_name: str = "well000", rec_name: str = "
prb["channel_groups"][1]["geometry"] = geometry
prb["channel_groups"][1]["channels"] = channels

probe = Probe(ndim=2, si_units="um")
probe = Probe(ndim=2, si_units="um", manufacturer="Maxwell Biosystems")

chans = np.array(prb["channel_groups"][1]["channels"], dtype="int64")
positions = np.array([prb["channel_groups"][1]["geometry"][c] for c in chans], dtype="float64")
Expand Down Expand Up @@ -572,7 +572,7 @@ def read_3brain(file: str | Path, mea_pitch: float = 42, electrode_width: float
cols = channels["Col"] - 1
positions = np.vstack((rows, cols)).T * mea_pitch

probe = Probe(ndim=2, si_units="um")
probe = Probe(ndim=2, si_units="um", manufacturer="3Brain")
probe.set_contacts(positions=positions, shapes="square", shape_params={"width": electrode_width})
probe.annotate_contacts(row=rows)
probe.annotate_contacts(col=cols)
Expand Down Expand Up @@ -624,7 +624,7 @@ def write_prb(
assert group_mode in ("by_probe", "by_shank")

if len(probegroup.probes) == 0:
raise ValueError("Bad boy")
raise ValueError("The probe group must have at least one probe")

for probe in probegroup.probes:
if probe.device_channel_indices is None:
Expand Down Expand Up @@ -977,6 +977,8 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
imDatPrb_type = probe_number_to_probe_type[imDatPrb_pn]

probe_description = npx_probe[imDatPrb_type]
probe_name = probe_description["probe_name"]

fields = probe_description["fields_in_imro_table"]
contact_info = {k: [] for k in fields}
for field_values_str in imro_table_values_list: # Imro table values look like '(value, value, value, ... '
Expand Down Expand Up @@ -1013,7 +1015,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
positions = np.stack((x_pos, y_pos), axis=1)

# construct Probe object
probe = Probe(ndim=2, si_units="um")
probe = Probe(ndim=2, si_units="um", model_name=probe_name, manufacturer="IMEC")
probe.set_contacts(
positions=positions,
shapes="square",
Expand All @@ -1036,10 +1038,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
probe.set_planar_contour(contour)

# this is scalar annotations
probe_name = probe_description["probe_name"]
probe.annotate(
name=probe_name,
manufacturer="IMEC",
probe_type=imDatPrb_type,
)

Expand Down Expand Up @@ -1130,10 +1129,27 @@ def read_spikeglx(file: str | Path) -> Probe:

assert "imroTbl" in meta, "Could not find imroTbl field in meta file!"
imro_table = meta["imroTbl"]

# read serial number
imDatPrb_serial_number = meta.get("imDatPrb_sn", None)
if imDatPrb_serial_number is None: # this is for Phase3A
imDatPrb_serial_number = meta.get("imProbeSN", None)

# read other metadata
imDatPrb_pn = meta.get("imDatPrb_pn", None)
imDatPrb_port = meta.get("imDatPrb_port", None)
imDatPrb_slot = meta.get("imDatPrb_slot", None)
imDatPrb_part_number = meta.get("imDatPrb_pn", None)

probe = _read_imro_string(imro_str=imro_table, imDatPrb_pn=imDatPrb_pn)

# add serial number and other annotations
probe.annotate(serial_number=imDatPrb_serial_number)
probe.annotate(part_number=imDatPrb_part_number)
probe.annotate(port=imDatPrb_port)
probe.annotate(slot=imDatPrb_slot)
probe.annotate(serial_number=imDatPrb_serial_number)

# sometimes we need to slice the probe when not all channels are saved
saved_chans = get_saved_channel_indices_from_spikeglx_meta(meta_file)
# remove the SYS chans
Expand Down Expand Up @@ -1324,7 +1340,8 @@ def read_openephys(
slot = np_probe.attrib["slot"]
port = np_probe.attrib["port"]
dock = np_probe.attrib["dock"]
np_serial_number = np_probe.attrib["probe_serial_number"]
probe_part_number = np_probe.attrib["probe_part_number"]
probe_serial_number = np_probe.attrib["probe_serial_number"]
# read channels
channels = np_probe.find("CHANNELS")
channel_names = np.array(list(channels.attrib.keys()))
Expand Down Expand Up @@ -1397,14 +1414,15 @@ def read_openephys(
contact_ids.append(f"e{contact_id}")

np_probe_dict = {
"channel_names": channel_names,
"model_name": pname,
"shank_ids": shank_ids,
"contact_ids": contact_ids,
"positions": positions,
"slot": slot,
"port": port,
"dock": dock,
"serial_number": np_serial_number,
"serial_number": probe_serial_number,
"part_number": probe_part_number,
"ptype": ptype,
}
# Sequentially assign probe names
Expand Down Expand Up @@ -1497,7 +1515,7 @@ def read_openephys(
np_probe = np_probes[probe_idx]
positions = np_probe_info["positions"]
shank_ids = np_probe_info["shank_ids"]
pname = np_probe.attrib["probe_name"]
pname = np_probe_info["name"]

ptype = np_probe_info["ptype"]
if ptype in npx_probe:
Expand All @@ -1522,19 +1540,25 @@ def read_openephys(
if contact_ids is not None:
contact_ids = np.array(contact_ids)[chans_saved]

probe = Probe(ndim=2, si_units="um")
probe = Probe(
ndim=2,
si_units="um",
name=np_probe_info["name"],
serial_number=np_probe_info["serial_number"],
manufacturer="IMEC",
model_name=np_probe_info["model_name"],
)
probe.set_contacts(
positions=positions,
shapes="square",
shank_ids=shank_ids,
shape_params={"width": contact_width},
)
probe.annotate(
name=pname,
manufacturer="IMEC",
probe_name=pname,
probe_part_number=np_probe.attrib["probe_part_number"],
probe_serial_number=np_probe.attrib["probe_serial_number"],
part_number=np_probe_info["part_number"],
slot=np_probe_info["slot"],
dock=np_probe_info["dock"],
port=np_probe_info["port"],
)

if contact_ids is not None:
Expand Down Expand Up @@ -1664,7 +1688,7 @@ def read_mearec(file: str | Path) -> Probe:
description = electrodes_info["description"][()]
mearec_description = description.decode("utf-8") if isinstance(description, bytes) else description

probe = Probe(ndim=2, si_units="um")
probe = Probe(ndim=2, si_units="um", model_name=mearec_name)

plane = "yz" # default
if "plane" in electrodes_info_keys:
Expand Down
8 changes: 7 additions & 1 deletion src/probeinterface/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]:
return probe


def get_probe(manufacturer: str, probe_name: str) -> "Probe":
def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> "Probe":
"""
Get probe from ProbeInterface library
Expand All @@ -82,6 +82,8 @@ def get_probe(manufacturer: str, probe_name: str) -> "Probe":
The probe manufacturer (e.g. 'cambridgeneurotech', 'neuronexus')
probe_name : str
The probe name
name : str or None
Optional name for the probe
Returns
----------
Expand All @@ -94,5 +96,9 @@ def get_probe(manufacturer: str, probe_name: str) -> "Probe":
if probe is None:
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
if probe.manufacturer == "":
probe.manufacturer = manufacturer
if name is not None:
probe.name = name

return probe
90 changes: 82 additions & 8 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ class Probe:
"""

def __init__(self, ndim: int = 2, si_units: int = "um"):
def __init__(
self,
ndim: int = 2,
si_units: str = "um",
name: Optional[str] = None,
serial_number: Optional[str] = None,
model_name: Optional[str] = None,
manufacturer: Optional[str] = None,
):
"""
Some attributes are protected and have to be set with setters:
* set_contacts(...)
Expand All @@ -30,7 +38,18 @@ def __init__(self, ndim: int = 2, si_units: int = "um"):
Handles 2D or 3D probe
si_units: str
'um', 'mm', 'm'
name: str
The name of the probe
serial_number: str
The serial number of the probe
model_name: str
The model of the probe
manufacturer: str
The manufacturer of the probe
Returns
-------
Probe: instance of Probe
"""

assert ndim in (2, 3)
Expand Down Expand Up @@ -61,7 +80,14 @@ def __init__(self, ndim: int = 2, si_units: int = "um"):

# annotation: a dict that contains all meta information about
# the probe (name, manufacturor, date of production, ...)
self.annotations = dict(name="")
self.annotations = dict()

# set key properties
self.name = name
self.serial_number = serial_number
self.model_name = model_name
self.manufacturer = manufacturer

# same idea but handle in vector way for contacts
self.contact_annotations = dict()

Expand Down Expand Up @@ -93,17 +119,63 @@ def contact_ids(self):
def shank_ids(self):
return self._shank_ids

@property
def name(self):
return self.annotations.get("name", "")

@name.setter
def name(self, value):
if value is not None:
self.annotate(name=value)

@property
def serial_number(self):
return self.annotations.get("serial_number", "")

@serial_number.setter
def serial_number(self, value):
if value is not None:
self.annotate(serial_number=value)

@property
def model_name(self):
return self.annotations.get("model_name", "")

@model_name.setter
def model_name(self, value):
if value is not None:
self.annotate(model_name=value)

@property
def manufacturer(self):
return self.annotations.get("manufacturer", "")

@manufacturer.setter
def manufacturer(self, value):
if value is not None:
self.annotate(manufacturer=value)

def get_title(self) -> str:
if self.contact_positions is None:
txt = "Undefined probe"
else:
n = self.get_contact_count()
name = self.annotations.get("name", "")
manufacturer = self.annotations.get("manufacturer", "")
if len(name) > 0 or len(manufacturer):
txt = f"{manufacturer} - {name} - {n}ch"
name = self.name
serial_number = self.serial_number
model_name = self.model_name
manufacturer = self.manufacturer
txt = ""
if len(name) > 0:
txt += f"{name}"
else:
txt = f"Probe - {n}ch"
txt += f"Probe"
if len(manufacturer) > 0:
txt += f" - {manufacturer}"
if len(model_name) > 0:
txt += f" - {model_name}"
if len(serial_number) > 0:
txt += f" - {serial_number}"
txt += f" - {n}ch"
if self.shank_ids is not None:
num_shank = self.get_shank_count()
txt += f" - {num_shank}shanks"
Expand Down Expand Up @@ -919,7 +991,9 @@ def to_image(
except ImportError:
raise ImportError("to_image() requires the scipy package")
assert self.ndim == 2
assert values.shape == (self.get_contact_count(),), "Bad boy: values must have size equal contact count"
assert values.shape == (
self.get_contact_count(),
), "Shape mismatch: values must have the same size as contact count"

if xlims is None:
x0 = np.min(self.contact_positions[:, 0])
Expand Down
23 changes: 12 additions & 11 deletions tests/test_io/test_openephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_NP2():
# NP2
probe = read_openephys(data_path / "OE_Neuropix-PXI" / "settings.xml")
assert probe.get_shank_count() == 1
assert "2.0 - Single Shank" in probe.annotations["name"]
assert "2.0 - Single Shank" in probe.model_name


def test_NP1_subset():
Expand All @@ -23,15 +23,15 @@ def test_NP1_subset():
)

assert probe_ap.get_shank_count() == 1
assert "1.0" in probe_ap.annotations["name"]
assert "1.0" in probe_ap.model_name
assert len(probe_ap.contact_positions) == 200

probe_lf = read_openephys(
data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP"
)

assert probe_lf.get_shank_count() == 1
assert "1.0" in probe_lf.annotations["name"]
assert "1.0" in probe_lf.model_name
assert len(probe_lf.contact_positions) == 200

# Not specifying the stream_name should raise an Exception, because both the ProbeA-AP and
Expand All @@ -47,7 +47,7 @@ def test_multiple_probes():
)

assert probeA.get_shank_count() == 1
assert "1.0" in probeA.annotations["name"]
assert "1.0" in probeA.model_name

probeB = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml",
Expand All @@ -69,10 +69,10 @@ def test_multiple_probes():

assert probeD.get_shank_count() == 1

assert probeA.annotations["probe_serial_number"] == "17131307831"
assert probeB.annotations["probe_serial_number"] == "20403311724"
assert probeC.annotations["probe_serial_number"] == "20403311714"
assert probeD.annotations["probe_serial_number"] == "21144108671"
assert probeA.serial_number == "17131307831"
assert probeB.serial_number == "20403311724"
assert probeC.serial_number == "20403311714"
assert probeD.serial_number == "21144108671"

probeA2 = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings_2.xml",
Expand All @@ -89,7 +89,7 @@ def test_multiple_probes():
)

assert probeB2.get_shank_count() == 1
assert "2.0 - Multishank" in probeB2.annotations["name"]
assert "2.0 - Multishank" in probeB2.model_name

ypos = probeB2.contact_positions[:, 1]
assert np.min(ypos) >= 0
Expand All @@ -103,10 +103,11 @@ def test_older_than_06_format():
)

assert probe.get_shank_count() == 4
assert "2.0 - Multishank" in probe.annotations["name"]
assert "2.0 - Multishank" in probe.model_name
ypos = probe.contact_positions[:, 1]
assert np.min(ypos) >= 0


if __name__ == "__main__":
test_NP1_subset()
test_multiple_probes()
test_older_than_06_format()
Loading

0 comments on commit 86cdc9e

Please sign in to comment.