Skip to content

Commit

Permalink
Remove channel_index concept.
Browse files Browse the repository at this point in the history
Make contact_ids in generator.
Remove assertion for unique contact_ids in ProbeGroup.
  • Loading branch information
samuelgarcia committed Oct 30, 2023
1 parent 45a6782 commit 98449bb
Show file tree
Hide file tree
Showing 17 changed files with 49 additions and 67 deletions.
2 changes: 1 addition & 1 deletion doc/generate_format_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
print(d.keys())

fig, ax = plt.subplots(figsize=(8, 8))
plot_probe(probe, with_channel_index=True, ax=ax)
plot_probe(probe, ax=ax)
ax.set_xlim(-50, 200)
ax.set_ylim(-150, 120)

Expand Down
2 changes: 1 addition & 1 deletion examples/ex_03_generate_probe_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@
##############################################################################
#  or in separate axes:

plot_probe_group(probegroup, same_axes=False, with_channel_index=True)
plot_probe_group(probegroup, same_axes=False, with_contact_id=True)

plt.show()
6 changes: 3 additions & 3 deletions examples/ex_05_device_channel_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
xpitch=75, ypitch=75, y_shift_per_column=[0, -37.5, 0],
contact_shapes='circle', contact_shape_params={'radius': 12})

plot_probe(probe, with_channel_index=True)
plot_probe(probe, with_contact_id=True)

##############################################################################
# The Probe is not connected to any device yet:
Expand All @@ -51,7 +51,7 @@
# * the prbXX is the contact index ordered from 0 to N
# * the devXX is the channel index on the device (with the second half reversed)

plot_probe(probe, with_channel_index=True, with_device_index=True)
plot_probe(probe, with_contact_id=True, with_device_index=True)

##############################################################################
# Very often we have several probes on the device and this can lead to even
Expand Down Expand Up @@ -85,6 +85,6 @@
# The indices of the probe group can also be plotted:

fig, ax = plt.subplots()
plot_probe_group(probegroup, with_channel_index=True, same_axes=True, ax=ax)
plot_probe_group(probegroup, with_contact_id=True, same_axes=True, ax=ax)

plt.show()
2 changes: 1 addition & 1 deletion examples/ex_06_import_export_to_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@
f.write(prb_two_tetrodes)

two_tetrode = read_prb('two_tetrodes.prb')
plot_probe_group(two_tetrode, same_axes=False, with_channel_index=True)
plot_probe_group(two_tetrode, same_axes=False, with_contact_id=True)

plt.show()
6 changes: 3 additions & 3 deletions examples/ex_07_probe_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
df = probegroup.to_dataframe()
df

plot_probe_group(probegroup, with_channel_index=True, same_axes=True)
plot_probe_group(probegroup, with_contact_id=True, same_axes=True)

##############################################################################
# Generate a linear probe:
Expand All @@ -44,7 +44,7 @@
from probeinterface import generate_linear_probe

linear_probe = generate_linear_probe(num_elec=16, ypitch=20)
plot_probe(linear_probe, with_channel_index=True)
plot_probe(linear_probe, with_contact_id=True)

##############################################################################
# Generate a multi-column probe:
Expand All @@ -57,7 +57,7 @@
xpitch=22, ypitch=20,
y_shift_per_column=[0, -10, 0],
contact_shapes='square', contact_shape_params={'width': 12})
plot_probe(multi_columns, with_channel_index=True, )
plot_probe(multi_columns, with_contact_id=True, )

##############################################################################
# Generate a square probe:
Expand Down
2 changes: 1 addition & 1 deletion examples/ex_10_get_probe_from_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# When plotting, the channel indices are automatically displayed with
# one-based notation (even if internally everything is still zero based):

plot_probe(probe, with_channel_index=True)
plot_probe(probe, with_contact_id=True)

##############################################################################

Expand Down
2 changes: 1 addition & 1 deletion examples/ex_11_automatic_wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
# * the lower "devXX" is the channel on the Intan device (zero-based)

fig, ax = plt.subplots(figsize=(5, 15))
plot_probe(probe, with_channel_index=True, with_device_index=True, ax=ax)
plot_probe(probe, with_contact_id=True, with_device_index=True, ax=ax)


plt.show()
Expand Down
18 changes: 8 additions & 10 deletions resources/generate_cambridgeneurotech_libray.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def convert_contact_shape(listCoord):
listCoord = [float(s) for s in listCoord.split(' ')]
return listCoord

def get_channel_index(connector, probe_type):
def get_contact_order(connector, probe_type):
"""
Get the channel index given a connector and a probe_type.
This will help to re-order the probe contact later on.
Expand Down Expand Up @@ -179,7 +179,7 @@ def create_CN_figure(probe_name, probe):
plot_probe(probe, ax=ax,
contacts_colors = ['#5bc5f2'] * n, # made change to default color
probe_shape_kwargs = dict(facecolor='#6f6f6e', edgecolor='k', lw=0.5, alpha=0.3), # made change to default color
with_channel_index=True)
with_contact_id=True)

ax.set_xlabel(u'Width (\u03bcm)') #modif to legend
ax.set_ylabel(u'Height (\u03bcm)') #modif to legend
Expand Down Expand Up @@ -244,18 +244,16 @@ def generate_all_probes():
#~ continue
print(' ', probe_name)

channelIndex = get_channel_index(connector = connector, probe_type = probe_info['part'])
contact_order = get_contact_order(connector = connector, probe_type = probe_info['part'])

order = np.argsort(channelIndex)
probe = probe_unordered.get_slice(order)
sorted_indices = np.argsort(contact_order)
probe = probe_unordered.get_slice(sorted_indices)

probe.annotate(name=probe_name,
manufacturer='cambridgeneurotech',
first_index=1)
probe.annotate(name=probe_name, manufacturer='cambridgeneurotech')

# one based in cambridge neurotech
contact_ids = np.arange(order.size) + 1
contact_ids =contact_ids.astype(str)
contact_ids = np.arange(sorted_indices.size) + 1
contact_ids = contact_ids.astype(str)
probe.set_contact_ids(contact_ids)

export_one_probe(probe_name, probe)
Expand Down
1 change: 1 addition & 0 deletions src/probeinterface/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def generate_multi_columns_probe(
probe = Probe(ndim=2, si_units="um")
probe.set_contacts(positions=positions, shapes=contact_shapes, shape_params=contact_shape_params)
probe.create_auto_shape(probe_type="tip", margin=25)
probe.set_contact_ids(np.arange(positions.shape[0]).astype('str'))

return probe

Expand Down
2 changes: 1 addition & 1 deletion src/probeinterface/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
x_pos = x_idx * x_pitch + stagger
y_pos = y_idx * y_pitch

if imDatPrb_type == 24:
if probe_description["shank_number"] > 1:
shank_ids = np.array(contact_info["shank_id"])
shank_pitch = probe_description["shank_pitch"]
contact_ids = [f"s{shank_id}e{elec_id}" for shank_id, elec_id in zip(shank_ids, elec_ids)]
Expand Down
20 changes: 1 addition & 19 deletions src/probeinterface/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@ def plot_probe(
probe,
ax=None,
contacts_colors=None,
with_channel_index=False,
with_contact_id=False,
with_device_index=False,
text_on_contact=None,
first_index="auto",
contacts_values=None,
cmap="viridis",
title=True,
Expand All @@ -39,16 +37,12 @@ def plot_probe(
The axis to plot the probe on. If None, an axis is created, by default None
contacts_colors : matplotlib color, optional
The color of the contacts, by default None
with_channel_index : bool, optional
If True, channel indices are displayed on top of the channels, by default False
with_contact_id : bool, optional
If True, channel ids are displayed on top of the channels, by default False
with_device_index : bool, optional
If True, device channel indices are displayed on top of the channels, by default False
text_on_contact: None or list or numpy.array
Addintional text to plot on each contact
first_index : str, optional
The first index of the contacts, by default 'auto' (taken from channel ids)
contacts_values : np.array, optional
Values to color the contacts with, by default None
cmap : str, optional
Expand Down Expand Up @@ -92,16 +86,6 @@ def plot_probe(
else:
fig = ax.get_figure()

if first_index == "auto":
if "first_index" in probe.annotations:
first_index = probe.annotations["first_index"]
elif probe.annotations.get("manufacturer", None) == "neuronexus":
# neuronexus is one based indexing
first_index = 1
else:
first_index = 0
assert first_index in (0, 1)

_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
_probe_shape_kwargs.update(probe_shape_kwargs)

Expand Down Expand Up @@ -154,13 +138,11 @@ def on_press(event):
text_on_contact = np.asarray(text_on_contact)
assert text_on_contact.size == probe.get_contact_count()

if with_channel_index or with_contact_id or with_device_index or text_on_contact is not None:
if with_contact_id or with_device_index or text_on_contact is not None:
if probe.ndim == 3:
raise NotImplementedError("Channel index is 2d only")
for i in range(n):
txt = []
if with_channel_index:
txt.append(f"{i + first_index}")
if with_contact_id and probe.contact_ids is not None:
contact_id = probe.contact_ids[i]
txt.append(f"id{contact_id}")
Expand Down
17 changes: 15 additions & 2 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def get_shank_count(self) -> int:
n = len(np.unique(self.shank_ids))
return n

def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None, shank_ids=None):
def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None,
contact_ids=None, shank_ids=None):
"""Sets contacts to a Probe.
This sets four attributes of the probe:
Expand All @@ -241,6 +242,8 @@ def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10},
plane_axes : np.array (num_contacts, 2, ndim)
Defines the two axes of the contact plane for each electrode.
The third dimension corresponds to the probe `ndim` (2d or 3d).
contact_ids: None or array of str
Defines the contact ids for the contacts. If None, contact ids are not assigned.
shank_ids : None or array of str
Defines the shank ids for the contacts. If None, then
these are assigned to a unique Shank.
Expand All @@ -264,6 +267,9 @@ def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10},
plane_axes = np.array(plane_axes)
self._contact_plane_axes = plane_axes

if contact_ids is not None:
self.set_contact_ids(contact_ids)

if shank_ids is None:
self._shank_ids = np.zeros(n, dtype=str)
else:
Expand All @@ -286,6 +292,8 @@ def set_contacts(self, positions, shapes="circle", shape_params={"radius": 10},
shape_params = [shape_params] * n
self._contact_shape_params = np.array(shape_params)



def set_planar_contour(self, contour_polygon: list):
"""Set the planar countour (the shape) of the probe.
Expand Down Expand Up @@ -402,9 +410,14 @@ def set_contact_ids(self, contact_ids: np.array | list):
"""
contact_ids = np.asarray(contact_ids)
if np.all(contact_ids == ""):
self._contact_ids = None
return

assert np.unique(contact_ids).size == contact_ids.size, "Contact ids have to be unique within a Probe"

if contact_ids.size != self.get_contact_count():
ValueError(f"channel_indices do not have the same size as number of contacts")
ValueError(f"contact_ids do not have the same size as number of contacts")

if contact_ids.dtype.kind != "U":
contact_ids = contact_ids.astype("U")
Expand Down
21 changes: 5 additions & 16 deletions src/probeinterface/probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def add_probe(self, probe):

def _check_compatible(self, probe):
if probe._probe_group is not None:
raise ValueError("This probe is already attached to another ProbeGroup")
raise ValueError("This probe is already attached to another ProbeGroup. Use probe.copy() to attach it to another ProbeGroup")

if probe.ndim != self.probes[-1].ndim:
raise ValueError("ndim are not compatible")
Expand All @@ -38,7 +38,7 @@ def _check_compatible(self, probe):
def ndim(self):
return self.probes[0].ndim

def get_channel_count(self):
def get_contact_count(self):
"""
Total number of channels.
"""
Expand Down Expand Up @@ -144,7 +144,7 @@ def get_global_device_channel_indices(self):
Note:
channel -1 means not connected
"""
total_chan = self.get_channel_count()
total_chan = self.get_contact_count()
channels = np.zeros(total_chan, dtype=[("probe_index", "int64"), ("device_channel_indices", "int64")])
arr = self.to_numpy(complete=True)
channels["probe_index"] = arr["probe_index"]
Expand All @@ -156,7 +156,7 @@ def set_global_device_channel_indices(self, channels):
Set global indices for all probes
"""
channels = np.asarray(channels)
if channels.size != self.get_channel_count():
if channels.size != self.get_contact_count():
raise ValueError("Wrong channels size")

# first reset previsous indices
Expand Down Expand Up @@ -187,14 +187,6 @@ def check_global_device_wiring_and_ids(self):
if valid_chans.size != np.unique(valid_chans).size:
raise ValueError("channel device index are not unique across probes")

# check unique ids for != ''
all_ids = self.get_global_contact_ids()
keep = [e != "" for e in all_ids]
valid_ids = all_ids[keep]

if valid_ids.size != np.unique(valid_ids).size:
raise ValueError("contact_ids are not unique across probes")

def auto_generate_probe_ids(self, *args, **kwargs):
"""
Annotate all probes with unique probe_id values.
Expand Down Expand Up @@ -230,13 +222,10 @@ def auto_generate_contact_ids(self, *args, **kwargs):
`probeinterface.utils.generate_unique_ids`
"""

if any(p.contact_ids is not None for p in self.probes):
raise ValueError("Some contacts already have contact ids " "assigned.")

if not args:
args = 1e7, 1e8
# 3rd argument has to be the number of probes
args = args[:2] + (self.get_channel_count(),)
args = args[:2] + (self.get_contact_count(),)

contact_ids = generate_unique_ids(*args, **kwargs).astype(str)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_generate():

#~ from probeinterface.plotting import plot_probe_group, plot_probe
#~ import matplotlib.pyplot as plt
#~ plot_probe(multi_shank, with_channel_index=True,)
#~ plot_probe(multi_shank, with_contact_id=True,)
#~ plt.show()

if __name__ == '__main__':
Expand Down
6 changes: 3 additions & 3 deletions tests/test_io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def test_probeinterface_format(tmp_path):

# ~ from probeinterface.plotting import plot_probe_group
# ~ import matplotlib.pyplot as plt
# ~ plot_probe_group(probegroup, with_channel_index=True, same_axes=False)
# ~ plot_probe_group(probegroup2, with_channel_index=True, same_axes=False)
# ~ plot_probe_group(probegroup, with_contact_id=True, same_axes=False)
# ~ plot_probe_group(probegroup2, with_contact_id=True, same_axes=False)
# ~ plt.show()

def test_writeprobeinterface(tmp_path):
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_prb(tmp_path):

# ~ from probeinterface.plotting import plot_probe_group
# ~ import matplotlib.pyplot as plt
# ~ plot_probe_group(probegroup, with_channel_index=True, same_axes=False)
# ~ plot_probe_group(probegroup, with_contact_id=True, same_axes=False)
# ~ plt.show()

# from probeinterface.plotting import plot_probe
Expand Down
3 changes: 1 addition & 2 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
def test_plot_probe():
probe = generate_dummy_probe()
plot_probe(probe)
plot_probe(probe, with_channel_index=True)
plot_probe(probe, with_contact_id=True)
plot_probe(probe, with_device_index=True)
plot_probe(probe, text_on_contact=['abcde'[i%5] for i in range(probe.get_contact_count())])
Expand All @@ -33,7 +32,7 @@ def test_plot_probe():
def test_plot_probe_group():
probegroup = generate_dummy_probe_group()

plot_probe_group(probegroup, same_axes=True, with_channel_index=True)
plot_probe_group(probegroup, same_axes=True, with_contact_id=True)
plot_probe_group(probegroup, same_axes=False)

# 3d
Expand Down
Loading

0 comments on commit 98449bb

Please sign in to comment.