Skip to content

Commit

Permalink
add equality dunder method and test
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Jan 30, 2024
1 parent 2cdf024 commit f139bd6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
32 changes: 32 additions & 0 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,38 @@ def get_shanks(self):
shanks.append(shank)
return shanks

def __eq__(self, other):
if not isinstance(other, Probe):
return False

if not (
self.ndim == other.ndim
and self.si_units == other.si_units
and self.name == other.name
and self.serial_number == other.serial_number
and self.model_name == other.model_name
and self.manufacturer == other.manufacturer
and np.array_equal(self._contact_positions, other._contact_positions)
and np.array_equal(self._contact_plane_axes, other._contact_plane_axes)
and np.array_equal(self._contact_shapes, other._contact_shapes)
and np.array_equal(self._contact_shape_params, other._contact_shape_params)
and np.array_equal(self.probe_planar_contour, other.probe_planar_contour)
and np.array_equal(self._shank_ids, other._shank_ids)
and np.array_equal(self.device_channel_indices, other.device_channel_indices)
and np.array_equal(self._contact_ids, other._contact_ids)
and self.annotations == other.annotations
):
return False

# Compare contact_annotations dictionaries
if self.contact_annotations.keys() != other.contact_annotations.keys():
return False
for key in self.contact_annotations:
if not np.array_equal(self.contact_annotations[key], other.contact_annotations[key]):
return False

return True

def copy(self):
"""
Copy to another Probe instance.
Expand Down
15 changes: 15 additions & 0 deletions tests/test_probe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from probeinterface import Probe
from probeinterface.generator import generate_dummy_probe

import numpy as np

Expand Down Expand Up @@ -137,6 +138,20 @@ def test_probe():
#~ plt.show()


def test_probe_equality_dunder():

probe1 = generate_dummy_probe()
probe2 = generate_dummy_probe()

assert probe1 == probe1
assert probe2 == probe2
assert probe1 == probe2

# Modify probe2
probe2.move([1, 1])



def test_set_shanks():
probe = Probe(ndim=2, si_units='um')
probe.set_contacts(
Expand Down

0 comments on commit f139bd6

Please sign in to comment.