Skip to content

Commit

Permalink
add to zarr
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Jan 30, 2024
1 parent 3838cbb commit 3c6a10f
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 51 deletions.
200 changes: 149 additions & 51 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def model_name(self):

@model_name.setter
def model_name(self, value):
if value is not None:
if value is not None: # Alessio, why propagate to annotations and not just set the attribute?
self.annotate(model_name=value)

@property
Expand Down Expand Up @@ -960,88 +960,188 @@ def from_numpy(arr: np.ndarray) -> "Probe":
probe.set_contact_ids(arr["contact_ids"])

return probe

def to_zarr(self, folder_path: str | Path):
"""Serialize the Probe object to a Zarr file.

Parameters
----------
folder_path : str, Path
The path to the folder where the serialized data will be stored.
def add_probe_to_zarr_group(self, group: "zarr.Group") -> None:
"""
import zarr
Serialize the probe's data and structure to a specified Zarr group.
# Create a Zarr group
zarr_group = zarr.open(folder_path, mode='w')
This method is used to save the probe's attributes, annotations, and other
related data into a Zarr group, facilitating integration into larger Zarr
structures.
# Top-level attributes
zarr_group.attrs['ndim'] = self.ndim
zarr_group.attrs['si_units'] = self.si_units
Parameters
----------
group : zarr.Group
The target Zarr group where the probe's data will be stored.
"""

# Top-level attributes used to initialize a new Probe instance
group.attrs["ndim"] = self.ndim
group.attrs["si_units"] = self.si_units

# Need this behavior because the following attributes are "" (the empty string) by default
if self.name is not "":
group.attrs["name"] = self.name
if self.manufacturer is not "":
group.attrs["manufacturer"] = self.manufacturer
if self.model_name is not "":
group.attrs["model_name"] = self.model_name
if self.serial_number is not "":
group.attrs["serial_number"] = self.serial_number


# Annotations as a group
annotations_group = zarr_group.create_group('annotations')
annotations_group = group.create_group("annotations")
for key, value in self.annotations.items():
annotations_group.attrs[key] = value

# Contact annotations as a group
contact_annotations_group = zarr_group.create_group('contact_annotations')
contact_annotations_group = group.create_group("contact_annotations")
for key, value in self.contact_annotations.items():
contact_annotations_group.create_dataset(key, data=value, chunks=True)
contact_annotations_group.create_dataset(key, data=value, chunks=False)

# Datasets
# Save the following fields of the probe as top-level datasets
dataset_attrs = [
'_contact_positions', '_contact_plane_axes', '_contact_shapes',
'device_channel_indices', '_contact_ids', '_shank_ids', 'probe_planar_contour'
"_contact_positions",
"_contact_plane_axes",
"_contact_shapes",
"device_channel_indices",
"_contact_ids",
"_shank_ids",
"probe_planar_contour",
]
for attr in dataset_attrs:
value = getattr(self, attr)
if value is not None:
zarr_group.create_dataset(attr, data=value, chunks=True)
group.create_dataset(attr, data=value, chunks=False)

# Handling contact_shape_params
if self._contact_shape_params is not None:
shape_params_json = [json.dumps(d) for d in self._contact_shape_params]
zarr_group.create_dataset('contact_shape_params', data=shape_params_json, dtype=object)

def from_zarr(folder_path: str | Path):
# Open the Zarr group
shape_params_json_np = np.array(shape_params_json, dtype=object)
group.create_dataset("contact_shape_params", data=shape_params_json_np, dtype=str)

def to_zarr(self, folder_path: str | Path) -> None:
"""
Serialize the Probe object to a Zarr file located at the specified folder path.
This method initializes a new Zarr group at the given folder path and calls
`add_probe_to_zarr_group` to serialize the Probe's data into this group, effectively
storing the entire Probe's state in a Zarr archive.
Parameters
----------
folder_path : str | Path
The path to the folder where the Zarr data structure will be created and
where the serialized data will be stored. If the folder does not exist,
it will be created.
"""
import zarr
zarr_group = zarr.open(folder_path, mode='r')

# Initialize a new Probe instance
# Create or open a Zarr group for writing
zarr_group = zarr.open_group(folder_path, mode="w")

# Serialize this Probe object into the Zarr group
self.add_probe_to_zarr_group(zarr_group)

@staticmethod
def from_zarr_group(group: zarr.Group) -> "Probe":
"""
Load a probe instance from a given Zarr group.
Parameters
----------
group : zarr.Group
The Zarr group from which to load the probe.
Returns
-------
Probe
An instance of the Probe class initialized with data from the Zarr group.
"""
# Initialize a new Probe instance with basic attributes
probe = Probe(
ndim=zarr_group.attrs['ndim'],
si_units=zarr_group.attrs['si_units']
ndim=group.attrs["ndim"],
si_units=group.attrs["si_units"],
name=group.attrs.get("name", None),
manufacturer=group.attrs.get("manufacturer", None),
model_name=group.attrs.get("model_name", None),
serial_number=group.attrs.get("serial_number", None),
)

# Load annotations
if 'annotations' in zarr_group:
annotations_group = zarr_group['annotations']
for key in annotations_group.attrs.keys():
probe.annotations[key] = annotations_group.attrs[key]
annotations_group = group.get("annotations", None)
for key in annotations_group.attrs.keys():
# Use the annotate method for each key-value pair
probe.annotate(**{key: annotations_group.attrs[key]})

# Initialize contacts because it is possible to have a probe without contacts (undefined)
if "_contact_positions" in group:
positions = group["_contact_positions"][:]
shapes = group["_contact_shapes"][:]

plane_axes = group.get("_contact_plane_axes", None)
if plane_axes is not None:
plane_axes = plane_axes[:]

contact_ids = group.get("_contact_ids", None)
if contact_ids is not None:
contact_ids = contact_ids[:]

shank_ids = group.get("_shank_ids", None)
if shank_ids is not None:
shank_ids = shank_ids[:]

shape_params = group.get("contact_shape_params", None)
if shape_params is not None:
shape_params = np.array([json.loads(d) for d in shape_params[:]])

probe.set_contacts(
positions=positions,
plane_axes=plane_axes,
shapes=shapes,
shape_params=shape_params,
contact_ids=contact_ids,
shank_ids=shank_ids,
)

# Load contact annotations
if 'contact_annotations' in zarr_group:
contact_annotations_group = zarr_group['contact_annotations']
for key in contact_annotations_group:
probe.contact_annotations[key] = contact_annotations_group[key][:]
contact_annotations_group = group.get("contact_annotations", None)
if contact_annotations_group:
contact_annotations_dict = {key: contact_annotations_group[key][:] for key in contact_annotations_group}
# Use the annotate_contacts method
probe.annotate_contacts(**contact_annotations_dict)

# Load datasets
dataset_attrs = [
'_contact_positions', '_contact_plane_axes', '_contact_shapes',
'device_channel_indices', '_contact_ids', '_shank_ids', 'probe_planar_contour'
]
for attr in dataset_attrs:
if attr in zarr_group:
setattr(probe, attr, zarr_group[attr][:])
if "probe_planar_contour" in group:
# Directly assign since there's no specific setter for probe_planar_contour
probe.probe_planar_contour = group["probe_planar_contour"][:]

# Handling contact_shape_params
if 'contact_shape_params' in zarr_group:
shape_params_json = zarr_group['contact_shape_params'][:]
probe._contact_shape_params = [json.loads(d) for d in shape_params_json]
if "device_channel_indices" in group:
# Use set_device_channel_indices for device_channel_indices
probe.set_device_channel_indices(group["device_channel_indices"][:])

return probe

@staticmethod
def from_zarr(folder_path: str | Path) -> "Probe":
"""
Deserialize the Probe object from a Zarr file located at the given folder path.
Parameters
----------
folder_path : str | Path
The path to the folder where the Zarr file is located.
Returns
-------
Probe
An instance of the Probe class initialized with data from the Zarr file.
"""
import zarr

zarr_group = zarr.open(folder_path, mode="r")
return Probe.from_zarr_group(zarr_group)

def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame":
"""
Export the probe to a pandas dataframe
Expand Down Expand Up @@ -1086,8 +1186,6 @@ def from_dataframe(df: "pandas.DataFrame") -> "Probe":
arr = df.to_records(index=False)
return Probe.from_numpy(arr)



def to_image(
self,
values: np.array | list,
Expand Down
1 change: 1 addition & 0 deletions tests/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def test_probe_equality_dunder():

# Modify probe2
probe2.move([1, 1])
assert probe1 != probe2


def test_set_shanks():
Expand Down

0 comments on commit 3c6a10f

Please sign in to comment.