diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index 58003fb..6f45669 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -925,8 +925,27 @@ def from_numpy(arr: np.ndarray) -> "Probe": probe : Probe The instantiated Probe object """ - fields = list(arr.dtype.fields) + main_fields = [ + "x", + "y", + "z", + "contact_shapes", + "shank_ids", + "contact_ids", + "device_channel_indices", + "radius", + "width", + "height", + "plane_axis_x_0", + "plane_axis_x_1", + "plane_axis_y_0", + "plane_axis_y_1", + "plane_axis_z_0", + "plane_axis_z_1", + "si_units", + ] + contact_annotation_fields = [f for f in fields if f not in main_fields] if "z" in fields: ndim = 3 @@ -973,12 +992,16 @@ def from_numpy(arr: np.ndarray) -> "Probe": if "device_channel_indices" in fields: dev_channel_indices = arr["device_channel_indices"] - probe.set_device_channel_indices(dev_channel_indices) + if not np.all(dev_channel_indices == -1): + probe.set_device_channel_indices(dev_channel_indices) if "shank_ids" in fields: probe.set_shank_ids(arr["shank_ids"]) if "contact_ids" in fields: probe.set_contact_ids(arr["contact_ids"]) + # contact annotations + for k in contact_annotation_fields: + probe.annotate_contacts(**{k: arr[k]}) return probe def add_probe_to_zarr_group(self, group: "zarr.Group") -> None: diff --git a/tests/test_probe.py b/tests/test_probe.py index 8dd2405..8f05bda 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -183,4 +183,6 @@ def test_save_to_zarr(tmp_path): if __name__ == "__main__": test_probe() - test_set_shanks() + tmp_path = Path("tmp") + tmp_path.mkdir(exist_ok=True) + test_save_to_zarr(tmp_path)