Skip to content

Commit

Permalink
to/from_numpy: propagate contact annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Feb 5, 2024
1 parent 06cb118 commit 96f4eea
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
27 changes: 25 additions & 2 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,8 +925,27 @@ def from_numpy(arr: np.ndarray) -> "Probe":
probe : Probe
The instantiated Probe object
"""

fields = list(arr.dtype.fields)
main_fields = [
"x",
"y",
"z",
"contact_shapes",
"shank_ids",
"contact_ids",
"device_channel_indices",
"radius",
"width",
"height",
"plane_axis_x_0",
"plane_axis_x_1",
"plane_axis_y_0",
"plane_axis_y_1",
"plane_axis_z_0",
"plane_axis_z_1",
"si_units",
]
contact_annotation_fields = [f for f in fields if f not in main_fields]

if "z" in fields:
ndim = 3
Expand Down Expand Up @@ -973,12 +992,16 @@ def from_numpy(arr: np.ndarray) -> "Probe":

if "device_channel_indices" in fields:
dev_channel_indices = arr["device_channel_indices"]
probe.set_device_channel_indices(dev_channel_indices)
if not np.all(dev_channel_indices == -1):
probe.set_device_channel_indices(dev_channel_indices)
if "shank_ids" in fields:
probe.set_shank_ids(arr["shank_ids"])
if "contact_ids" in fields:
probe.set_contact_ids(arr["contact_ids"])

# contact annotations
for k in contact_annotation_fields:
probe.annotate_contacts(**{k: arr[k]})
return probe

def add_probe_to_zarr_group(self, group: "zarr.Group") -> None:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,6 @@ def test_save_to_zarr(tmp_path):
if __name__ == "__main__":
test_probe()

test_set_shanks()
tmp_path = Path("tmp")
tmp_path.mkdir(exist_ok=True)
test_save_to_zarr(tmp_path)

0 comments on commit 96f4eea

Please sign in to comment.