diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index dd44443..5fea947 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -144,7 +144,7 @@ def model_name(self): @model_name.setter def model_name(self, value): - if value is not None: + if value is not None: # Alessio, why propagate to annotations and not just set the attribute? self.annotate(model_name=value) @property @@ -960,88 +960,188 @@ def from_numpy(arr: np.ndarray) -> "Probe": probe.set_contact_ids(arr["contact_ids"]) return probe - - def to_zarr(self, folder_path: str | Path): - """Serialize the Probe object to a Zarr file. - Parameters - ---------- - folder_path : str, Path - The path to the folder where the serialized data will be stored. + def add_probe_to_zarr_group(self, group: "zarr.Group") -> None: """ - import zarr + Serialize the probe's data and structure to a specified Zarr group. - # Create a Zarr group - zarr_group = zarr.open(folder_path, mode='w') + This method is used to save the probe's attributes, annotations, and other + related data into a Zarr group, facilitating integration into larger Zarr + structures. - # Top-level attributes - zarr_group.attrs['ndim'] = self.ndim - zarr_group.attrs['si_units'] = self.si_units + Parameters + ---------- + group : zarr.Group + The target Zarr group where the probe's data will be stored. + """ + # Top-level attributes used to initialize a new Probe instance + group.attrs["ndim"] = self.ndim + group.attrs["si_units"] = self.si_units + + # Need this behavior because the following attributes are "" (the empty string) by default + if self.name is not "": + group.attrs["name"] = self.name + if self.manufacturer is not "": + group.attrs["manufacturer"] = self.manufacturer + if self.model_name is not "": + group.attrs["model_name"] = self.model_name + if self.serial_number is not "": + group.attrs["serial_number"] = self.serial_number + + # Annotations as a group - annotations_group = zarr_group.create_group('annotations') + annotations_group = group.create_group("annotations") for key, value in self.annotations.items(): annotations_group.attrs[key] = value # Contact annotations as a group - contact_annotations_group = zarr_group.create_group('contact_annotations') + contact_annotations_group = group.create_group("contact_annotations") for key, value in self.contact_annotations.items(): - contact_annotations_group.create_dataset(key, data=value, chunks=True) + contact_annotations_group.create_dataset(key, data=value, chunks=False) - # Datasets + # Save the following fields of the probe as top-level datasets dataset_attrs = [ - '_contact_positions', '_contact_plane_axes', '_contact_shapes', - 'device_channel_indices', '_contact_ids', '_shank_ids', 'probe_planar_contour' + "_contact_positions", + "_contact_plane_axes", + "_contact_shapes", + "device_channel_indices", + "_contact_ids", + "_shank_ids", + "probe_planar_contour", ] for attr in dataset_attrs: value = getattr(self, attr) if value is not None: - zarr_group.create_dataset(attr, data=value, chunks=True) + group.create_dataset(attr, data=value, chunks=False) # Handling contact_shape_params if self._contact_shape_params is not None: shape_params_json = [json.dumps(d) for d in self._contact_shape_params] - zarr_group.create_dataset('contact_shape_params', data=shape_params_json, dtype=object) - - def from_zarr(folder_path: str | Path): - # Open the Zarr group + shape_params_json_np = np.array(shape_params_json, dtype=object) + group.create_dataset("contact_shape_params", data=shape_params_json_np, dtype=str) + + def to_zarr(self, folder_path: str | Path) -> None: + """ + Serialize the Probe object to a Zarr file located at the specified folder path. + + This method initializes a new Zarr group at the given folder path and calls + `add_probe_to_zarr_group` to serialize the Probe's data into this group, effectively + storing the entire Probe's state in a Zarr archive. + + Parameters + ---------- + folder_path : str | Path + The path to the folder where the Zarr data structure will be created and + where the serialized data will be stored. If the folder does not exist, + it will be created. + """ import zarr - zarr_group = zarr.open(folder_path, mode='r') - # Initialize a new Probe instance + # Create or open a Zarr group for writing + zarr_group = zarr.open_group(folder_path, mode="w") + + # Serialize this Probe object into the Zarr group + self.add_probe_to_zarr_group(zarr_group) + + @staticmethod + def from_zarr_group(group: zarr.Group) -> "Probe": + """ + Load a probe instance from a given Zarr group. + + Parameters + ---------- + group : zarr.Group + The Zarr group from which to load the probe. + + Returns + ------- + Probe + An instance of the Probe class initialized with data from the Zarr group. + """ + # Initialize a new Probe instance with basic attributes probe = Probe( - ndim=zarr_group.attrs['ndim'], - si_units=zarr_group.attrs['si_units'] + ndim=group.attrs["ndim"], + si_units=group.attrs["si_units"], + name=group.attrs.get("name", None), + manufacturer=group.attrs.get("manufacturer", None), + model_name=group.attrs.get("model_name", None), + serial_number=group.attrs.get("serial_number", None), ) # Load annotations - if 'annotations' in zarr_group: - annotations_group = zarr_group['annotations'] - for key in annotations_group.attrs.keys(): - probe.annotations[key] = annotations_group.attrs[key] + annotations_group = group.get("annotations", None) + for key in annotations_group.attrs.keys(): + # Use the annotate method for each key-value pair + probe.annotate(**{key: annotations_group.attrs[key]}) + + # Initialize contacts because it is possible to have a probe without contacts (undefined) + if "_contact_positions" in group: + positions = group["_contact_positions"][:] + shapes = group["_contact_shapes"][:] + + plane_axes = group.get("_contact_plane_axes", None) + if plane_axes is not None: + plane_axes = plane_axes[:] + + contact_ids = group.get("_contact_ids", None) + if contact_ids is not None: + contact_ids = contact_ids[:] + + shank_ids = group.get("_shank_ids", None) + if shank_ids is not None: + shank_ids = shank_ids[:] + + shape_params = group.get("contact_shape_params", None) + if shape_params is not None: + shape_params = np.array([json.loads(d) for d in shape_params[:]]) + + probe.set_contacts( + positions=positions, + plane_axes=plane_axes, + shapes=shapes, + shape_params=shape_params, + contact_ids=contact_ids, + shank_ids=shank_ids, + ) # Load contact annotations - if 'contact_annotations' in zarr_group: - contact_annotations_group = zarr_group['contact_annotations'] - for key in contact_annotations_group: - probe.contact_annotations[key] = contact_annotations_group[key][:] + contact_annotations_group = group.get("contact_annotations", None) + if contact_annotations_group: + contact_annotations_dict = {key: contact_annotations_group[key][:] for key in contact_annotations_group} + # Use the annotate_contacts method + probe.annotate_contacts(**contact_annotations_dict) - # Load datasets - dataset_attrs = [ - '_contact_positions', '_contact_plane_axes', '_contact_shapes', - 'device_channel_indices', '_contact_ids', '_shank_ids', 'probe_planar_contour' - ] - for attr in dataset_attrs: - if attr in zarr_group: - setattr(probe, attr, zarr_group[attr][:]) + if "probe_planar_contour" in group: + # Directly assign since there's no specific setter for probe_planar_contour + probe.probe_planar_contour = group["probe_planar_contour"][:] - # Handling contact_shape_params - if 'contact_shape_params' in zarr_group: - shape_params_json = zarr_group['contact_shape_params'][:] - probe._contact_shape_params = [json.loads(d) for d in shape_params_json] + if "device_channel_indices" in group: + # Use set_device_channel_indices for device_channel_indices + probe.set_device_channel_indices(group["device_channel_indices"][:]) return probe + @staticmethod + def from_zarr(folder_path: str | Path) -> "Probe": + """ + Deserialize the Probe object from a Zarr file located at the given folder path. + + Parameters + ---------- + folder_path : str | Path + The path to the folder where the Zarr file is located. + + Returns + ------- + Probe + An instance of the Probe class initialized with data from the Zarr file. + """ + import zarr + + zarr_group = zarr.open(folder_path, mode="r") + return Probe.from_zarr_group(zarr_group) + def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": """ Export the probe to a pandas dataframe @@ -1086,8 +1186,6 @@ def from_dataframe(df: "pandas.DataFrame") -> "Probe": arr = df.to_records(index=False) return Probe.from_numpy(arr) - - def to_image( self, values: np.array | list, diff --git a/tests/test_probe.py b/tests/test_probe.py index 135eb92..7922525 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -149,6 +149,7 @@ def test_probe_equality_dunder(): # Modify probe2 probe2.move([1, 1]) + assert probe1 != probe2 def test_set_shanks():