Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend probe constructor #206

Merged
merged 7 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 42 additions & 18 deletions src/probeinterface/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def read_maxwell(file: Union[str, Path], well_name: str = "well000", rec_name: s
prb["channel_groups"][1]["geometry"] = geometry
prb["channel_groups"][1]["channels"] = channels

probe = Probe(ndim=2, si_units="um")
probe = Probe(ndim=2, si_units="um", manufacturer="Maxwell Biosystems")

chans = np.array(prb["channel_groups"][1]["channels"], dtype="int64")
positions = np.array([prb["channel_groups"][1]["geometry"][c] for c in chans], dtype="float64")
Expand Down Expand Up @@ -567,7 +567,7 @@ def read_3brain(file: Union[str, Path], mea_pitch: float = 42, electrode_width:
cols = channels["Col"] - 1
positions = np.vstack((rows, cols)).T * mea_pitch

probe = Probe(ndim=2, si_units="um")
probe = Probe(ndim=2, si_units="um", manufacturer="3Brain")
probe.set_contacts(positions=positions, shapes="square", shape_params={"width": electrode_width})
probe.annotate_contacts(row=rows)
probe.annotate_contacts(col=cols)
Expand Down Expand Up @@ -600,7 +600,7 @@ def write_prb(file, probegroup, total_nb_channels=None, radius=None, group_mode=
assert group_mode in ("by_probe", "by_shank")

if len(probegroup.probes) == 0:
raise ValueError("Bad boy")
raise ValueError("The probe group must have at least one probe")

for probe in probegroup.probes:
if probe.device_channel_indices is None:
Expand Down Expand Up @@ -953,6 +953,8 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
imDatPrb_type = probe_number_to_probe_type[imDatPrb_pn]

probe_description = npx_probe[imDatPrb_type]
probe_name = probe_description["probe_name"]

fields = probe_description["fields_in_imro_table"]
contact_info = {k: [] for k in fields}
for field_values_str in imro_table_values_list: # Imro table values look like '(value, value, value, ... '
Expand Down Expand Up @@ -986,7 +988,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
contact_ids = [f"e{elec_id}" for elec_id in elec_ids]

# construct Probe object
probe = Probe(ndim=2, si_units="um")
probe = Probe(ndim=2, si_units="um", model_name=probe_name, manufacturer="IMEC")
probe.set_contacts(
positions=positions,
shapes="square",
Expand All @@ -1009,10 +1011,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
probe.set_planar_contour(contour)

# this is scalar annotations
probe_name = probe_description["probe_name"]
probe.annotate(
name=probe_name,
manufacturer="IMEC",
probe_type=imDatPrb_type,
)

Expand Down Expand Up @@ -1102,10 +1101,27 @@ def read_spikeglx(file: Union[str, Path]) -> Probe:

assert "imroTbl" in meta, "Could not find imroTbl field in meta file!"
imro_table = meta["imroTbl"]

# read serial number
imDatPrb_serial_number = meta.get("imDatPrb_sn", None)
if imDatPrb_serial_number is None: # this is for Phase3A
imDatPrb_serial_number = meta.get("imProbeSN", None)

# read other metadata
imDatPrb_pn = meta.get("imDatPrb_pn", None)
imDatPrb_port = meta.get("imDatPrb_port", None)
imDatPrb_slot = meta.get("imDatPrb_slot", None)
imDatPrb_part_number = meta.get("imDatPrb_pn", None)

probe = _read_imro_string(imro_str=imro_table, imDatPrb_pn=imDatPrb_pn)

# add serial number and other annotations
probe.annotate(serial_number=imDatPrb_serial_number)
probe.annotate(part_number=imDatPrb_part_number)
probe.annotate(port=imDatPrb_port)
probe.annotate(slot=imDatPrb_slot)
probe.annotate(serial_number=imDatPrb_serial_number)

# sometimes we need to slice the probe when not all channels are saved
saved_chans = get_saved_channel_indices_from_spikeglx_meta(meta_file)
# remove the SYS chans
Expand Down Expand Up @@ -1295,7 +1311,8 @@ def read_openephys(
slot = np_probe.attrib["slot"]
port = np_probe.attrib["port"]
dock = np_probe.attrib["dock"]
np_serial_number = np_probe.attrib["probe_serial_number"]
probe_part_number = np_probe.attrib["probe_part_number"]
probe_serial_number = np_probe.attrib["probe_serial_number"]
# read channels
channels = np_probe.find("CHANNELS")
channel_names = np.array(list(channels.attrib.keys()))
Expand Down Expand Up @@ -1368,14 +1385,15 @@ def read_openephys(
contact_ids.append(f"e{contact_id}")

np_probe_dict = {
"channel_names": channel_names,
"model_name": pname,
"shank_ids": shank_ids,
"contact_ids": contact_ids,
"positions": positions,
"slot": slot,
"port": port,
"dock": dock,
"serial_number": np_serial_number,
"serial_number": probe_serial_number,
"part_number": probe_part_number,
"ptype": ptype,
}
# Sequentially assign probe names
Expand Down Expand Up @@ -1468,7 +1486,7 @@ def read_openephys(
np_probe = np_probes[probe_idx]
positions = np_probe_info["positions"]
shank_ids = np_probe_info["shank_ids"]
pname = np_probe.attrib["probe_name"]
pname = np_probe_info["name"]

ptype = np_probe_info["ptype"]
if ptype in npx_probe:
Expand All @@ -1491,19 +1509,25 @@ def read_openephys(
if contact_ids is not None:
contact_ids = np.array(contact_ids)[chans_saved]

probe = Probe(ndim=2, si_units="um")
probe = Probe(
ndim=2,
si_units="um",
name=np_probe_info["name"],
serial_number=np_probe_info["serial_number"],
manufacturer="IMEC",
model_name=np_probe_info["model_name"],
)
probe.set_contacts(
positions=positions,
shapes="square",
shank_ids=shank_ids,
shape_params={"width": contact_width},
)
probe.annotate(
name=pname,
manufacturer="IMEC",
probe_name=pname,
probe_part_number=np_probe.attrib["probe_part_number"],
probe_serial_number=np_probe.attrib["probe_serial_number"],
part_number=np_probe_info["part_number"],
slot=np_probe_info["slot"],
dock=np_probe_info["dock"],
port=np_probe_info["port"],
)

if contact_ids is not None:
Expand Down Expand Up @@ -1631,7 +1655,7 @@ def read_mearec(file: Union[str, Path]) -> Probe:
description = electrodes_info["description"][()]
mearec_description = description.decode("utf-8") if isinstance(description, bytes) else description

probe = Probe(ndim=2, si_units="um")
probe = Probe(ndim=2, si_units="um", model_name=mearec_name)

plane = "yz" # default
if "plane" in electrodes_info_keys:
Expand Down
8 changes: 7 additions & 1 deletion src/probeinterface/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_from_cache(manufacturer, probe_name):
return probe


def get_probe(manufacturer, probe_name):
def get_probe(manufacturer, probe_name, name=None):
"""
Get probe from ProbeInterface library

Expand All @@ -81,6 +81,8 @@ def get_probe(manufacturer, probe_name):
The probe manufacturer (e.g. 'cambridgeneurotech')
probe_name : str
The probe name
name : str or None
Optional name for the probe

Returns
----------
Expand All @@ -93,5 +95,9 @@ def get_probe(manufacturer, probe_name):
if probe is None:
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
if probe.annotations["manufacturer"] == "":
probe.annotations["manufacturer"] = manufacturer
if name is not None:
probe.annotations["name"] = name

return probe
61 changes: 53 additions & 8 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Probe:

"""

def __init__(self, ndim=2, si_units="um"):
def __init__(self, ndim=2, si_units="um", name=None, serial_number=None, model_name=None, manufacturer=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure to like the idea of overloading the constructor with fixed annotation even if they are important.

When creating a Probe from scratch (wich a end user normally do not do) I think I prefer this:

probe = Probe(ndim=2)
probe.annotate(name=name)
probe.annotate(serial_number=serial_number)
probe.annotate(model_name=model_name)
probe.annotate(manufacturer=manufacturer)

The discussion of which annotations are important are endless.
So lets put annotations outside the constructor.

The other solution would be to put all of then (but I do not like)

probe = Probe(ndim=2, **any_annotations)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or another solution would be to make set_* for these 4 key annotations, so that they are still annotations (i.e., saved to file too), but the have an easier API. Whaat do you think?

"""
Some attributes are protected and have to be set with setters:
* set_contacts(...)
Expand All @@ -27,7 +27,18 @@ def __init__(self, ndim=2, si_units="um"):
Handles 2D or 3D probe
si_units: str
'um', 'mm', 'm'
name: str
The name of the probe
serial_number: str
The serial number of the probe
model_name: str
The model of the probe
manufacturer: str
The manufacturer of the probe

Returns
-------
Probe: instance of Probe
"""

assert ndim in (2, 3)
Expand Down Expand Up @@ -58,7 +69,13 @@ def __init__(self, ndim=2, si_units="um"):

# annotation: a dict that contains all meta information about
# the probe (name, manufacturor, date of production, ...)
self.annotations = dict(name="")
self.annotations = dict()
self.annotate(
name=name if name is not None else "",
serial_number=serial_number if serial_number is not None else "",
model_name=model_name if model_name is not None else "",
manufacturer=manufacturer if manufacturer is not None else "",
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would prefer nothing instead of empty string when None.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, fair enough ;)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

# same idea but handle in vector way for contacts
self.contact_annotations = dict()

Expand Down Expand Up @@ -90,17 +107,43 @@ def contact_ids(self):
def shank_ids(self):
return self._shank_ids

@property
def name(self):
return self.annotations.get("name", "")

@property
def serial_number(self):
return self.annotations.get("serial_number", "")

@property
def model_name(self):
return self.annotations.get("model_name", "")

@property
def manufacturer(self):
return self.annotations.get("manufacturer", "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK for me.


def get_title(self):
if self.contact_positions is None:
txt = "Undefined probe"
else:
n = self.get_contact_count()
name = self.annotations.get("name", "")
manufacturer = self.annotations.get("manufacturer", "")
if len(name) > 0 or len(manufacturer):
txt = f"{manufacturer} - {name} - {n}ch"
name = self.name
serial_number = self.serial_number
model_name = self.model_name
manufacturer = self.manufacturer
txt = ""
if len(name) > 0:
txt += f"{name}"
else:
txt = f"Probe - {n}ch"
txt += f"Probe"
if len(manufacturer) > 0:
txt += f" - {manufacturer}"
if len(model_name) > 0:
txt += f" - {model_name}"
if len(serial_number) > 0:
txt += f" - {serial_number}"
txt += f" - {n}ch"
if self.shank_ids is not None:
num_shank = self.get_shank_count()
txt += f" - {num_shank}shanks"
Expand Down Expand Up @@ -904,7 +947,9 @@ def to_image(self, values, pixel_size=0.5, num_pixel=None, method="linear", xlim
except ImportError:
raise ImportError("to_image() requires the scipy package")
assert self.ndim == 2
assert values.shape == (self.get_contact_count(),), "Bad boy: values must have size equal contact count"
assert values.shape == (
self.get_contact_count(),
), "Shape mismatch: values must have the same size as contact count"

if xlims is None:
x0 = np.min(self.contact_positions[:, 0])
Expand Down
23 changes: 12 additions & 11 deletions tests/test_io/test_openephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_NP2():
# NP2
probe = read_openephys(data_path / "OE_Neuropix-PXI" / "settings.xml")
assert probe.get_shank_count() == 1
assert "2.0 - Single Shank" in probe.annotations["name"]
assert "2.0 - Single Shank" in probe.model_name
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me name of a probe is not the model name.

A name can be something which is mainingfull for a end user : like "probeA"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fact this is now "model_name". The model name is the manufacturer name of the probe. Name can be anything (like ProbeA)



def test_NP1_subset():
Expand All @@ -23,15 +23,15 @@ def test_NP1_subset():
)

assert probe_ap.get_shank_count() == 1
assert "1.0" in probe_ap.annotations["name"]
assert "1.0" in probe_ap.model_name
assert len(probe_ap.contact_positions) == 200

probe_lf = read_openephys(
data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP"
)

assert probe_lf.get_shank_count() == 1
assert "1.0" in probe_lf.annotations["name"]
assert "1.0" in probe_lf.model_name
assert len(probe_lf.contact_positions) == 200

# Not specifying the stream_name should raise an Exception, because both the ProbeA-AP and
Expand All @@ -47,7 +47,7 @@ def test_multiple_probes():
)

assert probeA.get_shank_count() == 1
assert "1.0" in probeA.annotations["name"]
assert "1.0" in probeA.model_name

probeB = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml",
Expand All @@ -69,10 +69,10 @@ def test_multiple_probes():

assert probeD.get_shank_count() == 1

assert probeA.annotations["probe_serial_number"] == "17131307831"
assert probeB.annotations["probe_serial_number"] == "20403311724"
assert probeC.annotations["probe_serial_number"] == "20403311714"
assert probeD.annotations["probe_serial_number"] == "21144108671"
assert probeA.serial_number == "17131307831"
assert probeB.serial_number == "20403311724"
assert probeC.serial_number == "20403311714"
assert probeD.serial_number == "21144108671"

probeA2 = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings_2.xml",
Expand All @@ -89,7 +89,7 @@ def test_multiple_probes():
)

assert probeB2.get_shank_count() == 1
assert "2.0 - Multishank" in probeB2.annotations["name"]
assert "2.0 - Multishank" in probeB2.model_name

ypos = probeB2.contact_positions[:, 1]
assert np.min(ypos) >= 0
Expand All @@ -103,10 +103,11 @@ def test_older_than_06_format():
)

assert probe.get_shank_count() == 4
assert "2.0 - Multishank" in probe.annotations["name"]
assert "2.0 - Multishank" in probe.model_name
ypos = probe.contact_positions[:, 1]
assert np.min(ypos) >= 0


if __name__ == "__main__":
test_NP1_subset()
test_multiple_probes()
test_older_than_06_format()
Loading