Skip to content

Commit

Permalink
Merge pull request #249 from h-mayorquin/black_test
Browse files Browse the repository at this point in the history
Black format tests
  • Loading branch information
alejoe91 authored Jan 31, 2024
2 parents e4f36a1 + bc23643 commit ee744f8
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 146 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ repos:
rev: 24.1.1
hooks:
- id: black
files: ^src/
files: ^src/|^tests/
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ test = [
"scipy",
"pandas",
"h5py",
]
]

docs = [
"pillow",
Expand Down
32 changes: 18 additions & 14 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from probeinterface import (generate_dummy_probe, generate_dummy_probe_group,
generate_tetrode, generate_linear_probe, generate_multi_columns_probe,
generate_multi_shank)
from probeinterface import (
generate_dummy_probe,
generate_dummy_probe_group,
generate_tetrode,
generate_linear_probe,
generate_multi_columns_probe,
generate_multi_shank,
)


from pathlib import Path
Expand All @@ -15,20 +20,19 @@ def test_generate():

tetrode = generate_tetrode()

multi_columns = generate_multi_columns_probe(num_columns=3,
num_contact_per_column=[10, 12, 10],
xpitch=22, ypitch=20,
y_shift_per_column=[0, -10, 0])
multi_columns = generate_multi_columns_probe(
num_columns=3, num_contact_per_column=[10, 12, 10], xpitch=22, ypitch=20, y_shift_per_column=[0, -10, 0]
)

linear = generate_linear_probe(num_elec=16, ypitch=20,
contact_shapes='square', contact_shape_params={'width': 15})
linear = generate_linear_probe(num_elec=16, ypitch=20, contact_shapes="square", contact_shape_params={"width": 15})

multi_shank = generate_multi_shank()

#~ from probeinterface.plotting import plot_probe_group, plot_probe
#~ import matplotlib.pyplot as plt
#~ plot_probe(multi_shank, with_contact_id=True,)
#~ plt.show()
# ~ from probeinterface.plotting import plot_probe_group, plot_probe
# ~ import matplotlib.pyplot as plt
# ~ plot_probe(multi_shank, with_contact_id=True,)
# ~ plt.show()

if __name__ == '__main__':

if __name__ == "__main__":
test_generate()
27 changes: 7 additions & 20 deletions tests/test_io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_probeinterface_format(tmp_path):
# ~ plot_probe_group(probegroup2, with_contact_id=True, same_axes=False)
# ~ plt.show()


def test_writeprobeinterface(tmp_path):
probe = generate_dummy_probe()
file_path = tmp_path / "test.prb"
Expand All @@ -61,7 +62,6 @@ def test_writeprobeinterface_raises_error_with_bad_input(tmp_path):
write_probeinterface(file_path, probe)



def test_BIDS_format(tmp_path):
folder_path = tmp_path / "test_BIDS"
folder_path.mkdir()
Expand All @@ -77,9 +77,7 @@ def test_BIDS_format(tmp_path):
# with BIDS specifications
n_els = sum([p.get_contact_count() for p in probegroup.probes])
# using np.random.choice to ensure uniqueness of contact ids
el_ids = np.random.choice(
np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els
).astype(str)
el_ids = np.random.choice(np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els).astype(str)
for probe in probegroup.probes:
probe_el_ids, el_ids = np.split(el_ids, [probe.get_contact_count()])
probe.set_contact_ids(probe_el_ids)
Expand All @@ -102,12 +100,7 @@ def test_BIDS_format(tmp_path):
assert all(np.isin(probe_orig.contact_ids, probe_read.contact_ids))

# the transformation of contact order between the two probes
t = np.array(
[
list(probe_read.contact_ids).index(elid)
for elid in probe_orig.contact_ids
]
)
t = np.array([list(probe_read.contact_ids).index(elid) for elid in probe_orig.contact_ids])

assert all(probe_orig.contact_ids == probe_read.contact_ids[t])
assert all(probe_orig.shank_ids == probe_read.shank_ids[t])
Expand All @@ -116,21 +109,14 @@ def test_BIDS_format(tmp_path):
assert probe_orig.si_units == probe_read.si_units

for i in range(len(probe_orig.probe_planar_contour)):
assert all(
probe_orig.probe_planar_contour[i] == probe_read.probe_planar_contour[i]
)
assert all(probe_orig.probe_planar_contour[i] == probe_read.probe_planar_contour[i])
for sid, shape_params in enumerate(probe_orig.contact_shape_params):
assert shape_params == probe_read.contact_shape_params[t][sid]
for i in range(len(probe_orig.contact_positions)):
assert all(
probe_orig.contact_positions[i] == probe_read.contact_positions[t][i]
)
assert all(probe_orig.contact_positions[i] == probe_read.contact_positions[t][i])
for i in range(len(probe.contact_plane_axes)):
for dim in range(len(probe.contact_plane_axes[i])):
assert all(
probe_orig.contact_plane_axes[i][dim]
== probe_read.contact_plane_axes[t][i][dim]
)
assert all(probe_orig.contact_plane_axes[i][dim] == probe_read.contact_plane_axes[t][i][dim])


def test_BIDS_format_empty(tmp_path):
Expand Down Expand Up @@ -218,6 +204,7 @@ def test_prb(tmp_path):
# plot_probe(probe)
# plt.show()


if __name__ == "__main__":
# test_probeinterface_format()
# test_BIDS_format()
Expand Down
22 changes: 5 additions & 17 deletions tests/test_io/test_openephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,15 @@ def test_NP_Ultra():
assert len(np.unique(probeD.contact_positions[:, 0])) == 1



def test_NP1_subset():
# NP1 - 200 channels selected by recording_state in Record Node
probe_ap = read_openephys(
data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP"
)
probe_ap = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP")

assert probe_ap.get_shank_count() == 1
assert "1.0" in probe_ap.model_name
assert probe_ap.get_contact_count() == 200

probe_lf = read_openephys(
data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP"
)
probe_lf = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP")

assert probe_lf.get_shank_count() == 1
assert "1.0" in probe_lf.model_name
Expand All @@ -88,9 +83,7 @@ def test_NP1_subset():

def test_multiple_probes():
# multiple probes
probeA = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA"
)
probeA = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA")

assert probeA.get_shank_count() == 1
assert "1.0" in probeA.model_name
Expand All @@ -109,9 +102,7 @@ def test_multiple_probes():

assert probeC.get_shank_count() == 1

probeD = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD"
)
probeD = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD")

assert probeD.get_shank_count() == 1

Expand Down Expand Up @@ -148,13 +139,10 @@ def test_np_opto_with_sync():
assert probe.get_contact_count() == 384



def test_older_than_06_format():
## Test with the open ephys < 0.6 format

probe = read_openephys(
data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0"
)
probe = read_openephys(data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0")

assert probe.get_shank_count() == 4
assert "2.0 - Four Shank" in probe.model_name
Expand Down
15 changes: 8 additions & 7 deletions tests/test_io/test_spikeglx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

data_path = Path(__file__).absolute().parent.parent / "data" / "spikeglx"


def test_parse_meta():
for meta_file in [
"doppio-checkerboard_t0.imec0.ap.meta",
Expand All @@ -19,19 +20,17 @@ def test_parse_meta():
]:
meta = parse_spikeglx_meta(data_path / meta_file)


def test_get_saved_channel_indices_from_spikeglx_meta():
# all channel saved + 1 synchro
chan_inds = get_saved_channel_indices_from_spikeglx_meta(
data_path / "Noise_g0_t0.imec0.ap.meta"
)
chan_inds = get_saved_channel_indices_from_spikeglx_meta(data_path / "Noise_g0_t0.imec0.ap.meta")
assert chan_inds.size == 385

# example by Pierre Yger NP1.0 with 384 but only 151 channels are saved + 1 synchro
chan_inds = get_saved_channel_indices_from_spikeglx_meta(
data_path / "NP1_saved_only_subset_of_channels.meta"
)
chan_inds = get_saved_channel_indices_from_spikeglx_meta(data_path / "NP1_saved_only_subset_of_channels.meta")
assert chan_inds.size == 152


def test_NP1():
probe = read_spikeglx(data_path / "Noise_g0_t0.imec0.ap.meta")
assert "1.0" in probe.model_name
Expand Down Expand Up @@ -187,6 +186,7 @@ def tes_NP1_384_channels():
assert probe.get_contact_count() == 151
assert 152 not in probe.contact_annotations["channel_ids"]


def test_NPH_long_staggered():
# Data provided by Nate Dolensek
probe = read_spikeglx(data_path / "non_human_primate_long_staggered.imec0.ap.meta")
Expand Down Expand Up @@ -242,6 +242,7 @@ def test_NPH_long_staggered():
assert np.allclose(references, 0)
assert np.allclose(filters, 1)


def test_NPH_short_linear_probe_type_0():
# Data provided by Jonathan A Michaels
probe = read_spikeglx(data_path / "non_human_primate_short_linear_probe_type_0.meta")
Expand All @@ -254,7 +255,6 @@ def test_NPH_short_linear_probe_type_0():
assert probe.get_shank_count() == 1
assert probe.get_contact_count() == 384


# Test contact geometry
x_pitch = 56.0
y_pitch = 20.0
Expand Down Expand Up @@ -320,6 +320,7 @@ def test_ultra_probe():
unique_y_values = np.unique(y)
assert unique_y_values.size == expected_electode_rows


def test_CatGT_NP1():
probe = read_spikeglx(data_path / "catgt.meta")
assert "1.0" in probe.model_name
Expand Down
12 changes: 6 additions & 6 deletions tests/test_library.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from probeinterface import Probe
from probeinterface.library import (download_probeinterface_file,
get_from_cache, get_probe)
from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe


from pathlib import Path
Expand All @@ -9,19 +8,20 @@
import pytest


manufacturer = 'neuronexus'
probe_name = 'A1x32-Poly3-10mm-50-177'
manufacturer = "neuronexus"
probe_name = "A1x32-Poly3-10mm-50-177"


def test_download_probeinterface_file():
download_probeinterface_file(manufacturer, probe_name)


def test_get_from_cache():
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
assert isinstance(probe, Probe)

probe = get_from_cache('yep', 'yop')
probe = get_from_cache("yep", "yop")
assert probe is None


Expand All @@ -31,7 +31,7 @@ def test_get_probe():
assert probe.get_contact_count() == 32


if __name__ == '__main__':
if __name__ == "__main__":
test_download_probeinterface_file()
test_get_from_cache()
test_get_probe()
6 changes: 3 additions & 3 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def test_plot_probe():
plot_probe(probe)
plot_probe(probe, with_contact_id=True)
plot_probe(probe, with_device_index=True)
plot_probe(probe, text_on_contact=['abcde'[i%5] for i in range(probe.get_contact_count())])
plot_probe(probe, text_on_contact=["abcde"[i % 5] for i in range(probe.get_contact_count())])

# with color
n = probe.get_contact_count()
contacts_colors = np.random.rand(n, 3)
plot_probe(probe, contacts_colors=contacts_colors)

# 3d
probe_3d = probe.to_3d(axes='xz')
probe_3d = probe.to_3d(axes="xz")
plot_probe(probe_3d)

# on click
Expand All @@ -43,7 +43,7 @@ def test_plot_probe_group():
plot_probe_group(probegroup_3d, same_axes=True)


if __name__ == '__main__':
if __name__ == "__main__":
test_plot_probe()
# test_plot_probe_group()
plt.show()
Loading

0 comments on commit ee744f8

Please sign in to comment.