Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Jan 30, 2024
1 parent 3d9c2e3 commit 3838cbb
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 1 deletion.
86 changes: 85 additions & 1 deletion src/probeinterface/probe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations
import numpy as np
from typing import Optional

from pathlib import Path
import json

from .shank import Shank

Expand Down Expand Up @@ -959,6 +960,87 @@ def from_numpy(arr: np.ndarray) -> "Probe":
probe.set_contact_ids(arr["contact_ids"])

return probe

def to_zarr(self, folder_path: str | Path):
"""Serialize the Probe object to a Zarr file.
Parameters
----------
folder_path : str, Path
The path to the folder where the serialized data will be stored.
"""
import zarr

# Create a Zarr group
zarr_group = zarr.open(folder_path, mode='w')

# Top-level attributes
zarr_group.attrs['ndim'] = self.ndim
zarr_group.attrs['si_units'] = self.si_units

# Annotations as a group
annotations_group = zarr_group.create_group('annotations')
for key, value in self.annotations.items():
annotations_group.attrs[key] = value

# Contact annotations as a group
contact_annotations_group = zarr_group.create_group('contact_annotations')
for key, value in self.contact_annotations.items():
contact_annotations_group.create_dataset(key, data=value, chunks=True)

# Datasets
dataset_attrs = [
'_contact_positions', '_contact_plane_axes', '_contact_shapes',
'device_channel_indices', '_contact_ids', '_shank_ids', 'probe_planar_contour'
]
for attr in dataset_attrs:
value = getattr(self, attr)
if value is not None:
zarr_group.create_dataset(attr, data=value, chunks=True)

# Handling contact_shape_params
if self._contact_shape_params is not None:
shape_params_json = [json.dumps(d) for d in self._contact_shape_params]
zarr_group.create_dataset('contact_shape_params', data=shape_params_json, dtype=object)

def from_zarr(folder_path: str | Path):
# Open the Zarr group
import zarr
zarr_group = zarr.open(folder_path, mode='r')

# Initialize a new Probe instance
probe = Probe(
ndim=zarr_group.attrs['ndim'],
si_units=zarr_group.attrs['si_units']
)

# Load annotations
if 'annotations' in zarr_group:
annotations_group = zarr_group['annotations']
for key in annotations_group.attrs.keys():
probe.annotations[key] = annotations_group.attrs[key]

# Load contact annotations
if 'contact_annotations' in zarr_group:
contact_annotations_group = zarr_group['contact_annotations']
for key in contact_annotations_group:
probe.contact_annotations[key] = contact_annotations_group[key][:]

# Load datasets
dataset_attrs = [
'_contact_positions', '_contact_plane_axes', '_contact_shapes',
'device_channel_indices', '_contact_ids', '_shank_ids', 'probe_planar_contour'
]
for attr in dataset_attrs:
if attr in zarr_group:
setattr(probe, attr, zarr_group[attr][:])

# Handling contact_shape_params
if 'contact_shape_params' in zarr_group:
shape_params_json = zarr_group['contact_shape_params'][:]
probe._contact_shape_params = [json.loads(d) for d in shape_params_json]

return probe

def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame":
"""
Expand Down Expand Up @@ -1004,6 +1086,8 @@ def from_dataframe(df: "pandas.DataFrame") -> "Probe":
arr = df.to_records(index=False)
return Probe.from_numpy(arr)



def to_image(
self,
values: np.array | list,
Expand Down
17 changes: 17 additions & 0 deletions tests/test_probe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from probeinterface import Probe
from probeinterface.generator import generate_dummy_probe
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -161,6 +162,22 @@ def test_set_shanks():
assert all(probe.shank_ids == shank_ids.astype(str))


def test_save_to_zarr(tmp_path):
# Generate a dummy probe instance
probe = generate_dummy_probe()

# Define file path in the temporary directory
folder_path = Path(tmp_path) / "probe.zarr"

# Save the probe object to Zarr format
probe.to_zarr(folder_path=folder_path)

# Reload the probe object from the saved Zarr file
reloaded_probe = Probe.from_zarr(folder_path=folder_path)

# Assert that the reloaded probe is equal to the original
assert probe == reloaded_probe, "Reloaded Probe object does not match the original"

if __name__ == "__main__":
test_probe()

Expand Down

0 comments on commit 3838cbb

Please sign in to comment.