diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index 3d523d5..dd44443 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -1,7 +1,8 @@ from __future__ import annotations import numpy as np from typing import Optional - +from pathlib import Path +import json from .shank import Shank @@ -959,6 +960,87 @@ def from_numpy(arr: np.ndarray) -> "Probe": probe.set_contact_ids(arr["contact_ids"]) return probe + + def to_zarr(self, folder_path: str | Path): + """Serialize the Probe object to a Zarr file. + + Parameters + ---------- + folder_path : str, Path + The path to the folder where the serialized data will be stored. + """ + import zarr + + # Create a Zarr group + zarr_group = zarr.open(folder_path, mode='w') + + # Top-level attributes + zarr_group.attrs['ndim'] = self.ndim + zarr_group.attrs['si_units'] = self.si_units + + # Annotations as a group + annotations_group = zarr_group.create_group('annotations') + for key, value in self.annotations.items(): + annotations_group.attrs[key] = value + + # Contact annotations as a group + contact_annotations_group = zarr_group.create_group('contact_annotations') + for key, value in self.contact_annotations.items(): + contact_annotations_group.create_dataset(key, data=value, chunks=True) + + # Datasets + dataset_attrs = [ + '_contact_positions', '_contact_plane_axes', '_contact_shapes', + 'device_channel_indices', '_contact_ids', '_shank_ids', 'probe_planar_contour' + ] + for attr in dataset_attrs: + value = getattr(self, attr) + if value is not None: + zarr_group.create_dataset(attr, data=value, chunks=True) + + # Handling contact_shape_params + if self._contact_shape_params is not None: + shape_params_json = [json.dumps(d) for d in self._contact_shape_params] + zarr_group.create_dataset('contact_shape_params', data=shape_params_json, dtype=object) + + def from_zarr(folder_path: str | Path): + # Open the Zarr group + import zarr + zarr_group = zarr.open(folder_path, mode='r') + + # Initialize a new Probe instance + probe = Probe( + ndim=zarr_group.attrs['ndim'], + si_units=zarr_group.attrs['si_units'] + ) + + # Load annotations + if 'annotations' in zarr_group: + annotations_group = zarr_group['annotations'] + for key in annotations_group.attrs.keys(): + probe.annotations[key] = annotations_group.attrs[key] + + # Load contact annotations + if 'contact_annotations' in zarr_group: + contact_annotations_group = zarr_group['contact_annotations'] + for key in contact_annotations_group: + probe.contact_annotations[key] = contact_annotations_group[key][:] + + # Load datasets + dataset_attrs = [ + '_contact_positions', '_contact_plane_axes', '_contact_shapes', + 'device_channel_indices', '_contact_ids', '_shank_ids', 'probe_planar_contour' + ] + for attr in dataset_attrs: + if attr in zarr_group: + setattr(probe, attr, zarr_group[attr][:]) + + # Handling contact_shape_params + if 'contact_shape_params' in zarr_group: + shape_params_json = zarr_group['contact_shape_params'][:] + probe._contact_shape_params = [json.loads(d) for d in shape_params_json] + + return probe def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": """ @@ -1004,6 +1086,8 @@ def from_dataframe(df: "pandas.DataFrame") -> "Probe": arr = df.to_records(index=False) return Probe.from_numpy(arr) + + def to_image( self, values: np.array | list, diff --git a/tests/test_probe.py b/tests/test_probe.py index ef859d7..135eb92 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -1,5 +1,6 @@ from probeinterface import Probe from probeinterface.generator import generate_dummy_probe +from pathlib import Path import numpy as np @@ -161,6 +162,22 @@ def test_set_shanks(): assert all(probe.shank_ids == shank_ids.astype(str)) +def test_save_to_zarr(tmp_path): + # Generate a dummy probe instance + probe = generate_dummy_probe() + + # Define file path in the temporary directory + folder_path = Path(tmp_path) / "probe.zarr" + + # Save the probe object to Zarr format + probe.to_zarr(folder_path=folder_path) + + # Reload the probe object from the saved Zarr file + reloaded_probe = Probe.from_zarr(folder_path=folder_path) + + # Assert that the reloaded probe is equal to the original + assert probe == reloaded_probe, "Reloaded Probe object does not match the original" + if __name__ == "__main__": test_probe()