diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eb0d89f..839bb25 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,4 +9,4 @@ repos: rev: 24.1.1 hooks: - id: black - files: ^src/ + files: ^src/|^tests/ diff --git a/pyproject.toml b/pyproject.toml index c8a12d0..cf40085 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ test = [ "scipy", "pandas", "h5py", -] + ] docs = [ "pillow", diff --git a/tests/test_generator.py b/tests/test_generator.py index 2836e1c..d14ed9a 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,6 +1,11 @@ -from probeinterface import (generate_dummy_probe, generate_dummy_probe_group, - generate_tetrode, generate_linear_probe, generate_multi_columns_probe, - generate_multi_shank) +from probeinterface import ( + generate_dummy_probe, + generate_dummy_probe_group, + generate_tetrode, + generate_linear_probe, + generate_multi_columns_probe, + generate_multi_shank, +) from pathlib import Path @@ -15,20 +20,19 @@ def test_generate(): tetrode = generate_tetrode() - multi_columns = generate_multi_columns_probe(num_columns=3, - num_contact_per_column=[10, 12, 10], - xpitch=22, ypitch=20, - y_shift_per_column=[0, -10, 0]) + multi_columns = generate_multi_columns_probe( + num_columns=3, num_contact_per_column=[10, 12, 10], xpitch=22, ypitch=20, y_shift_per_column=[0, -10, 0] + ) - linear = generate_linear_probe(num_elec=16, ypitch=20, - contact_shapes='square', contact_shape_params={'width': 15}) + linear = generate_linear_probe(num_elec=16, ypitch=20, contact_shapes="square", contact_shape_params={"width": 15}) multi_shank = generate_multi_shank() - #~ from probeinterface.plotting import plot_probe_group, plot_probe - #~ import matplotlib.pyplot as plt - #~ plot_probe(multi_shank, with_contact_id=True,) - #~ plt.show() + # ~ from probeinterface.plotting import plot_probe_group, plot_probe + # ~ import matplotlib.pyplot as plt + # ~ plot_probe(multi_shank, with_contact_id=True,) + # ~ plt.show() -if __name__ == '__main__': + +if __name__ == "__main__": test_generate() diff --git a/tests/test_io/test_io.py b/tests/test_io/test_io.py index 578b29a..f038c52 100644 --- a/tests/test_io/test_io.py +++ b/tests/test_io/test_io.py @@ -43,6 +43,7 @@ def test_probeinterface_format(tmp_path): # ~ plot_probe_group(probegroup2, with_contact_id=True, same_axes=False) # ~ plt.show() + def test_writeprobeinterface(tmp_path): probe = generate_dummy_probe() file_path = tmp_path / "test.prb" @@ -61,7 +62,6 @@ def test_writeprobeinterface_raises_error_with_bad_input(tmp_path): write_probeinterface(file_path, probe) - def test_BIDS_format(tmp_path): folder_path = tmp_path / "test_BIDS" folder_path.mkdir() @@ -77,9 +77,7 @@ def test_BIDS_format(tmp_path): # with BIDS specifications n_els = sum([p.get_contact_count() for p in probegroup.probes]) # using np.random.choice to ensure uniqueness of contact ids - el_ids = np.random.choice( - np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els - ).astype(str) + el_ids = np.random.choice(np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els).astype(str) for probe in probegroup.probes: probe_el_ids, el_ids = np.split(el_ids, [probe.get_contact_count()]) probe.set_contact_ids(probe_el_ids) @@ -102,12 +100,7 @@ def test_BIDS_format(tmp_path): assert all(np.isin(probe_orig.contact_ids, probe_read.contact_ids)) # the transformation of contact order between the two probes - t = np.array( - [ - list(probe_read.contact_ids).index(elid) - for elid in probe_orig.contact_ids - ] - ) + t = np.array([list(probe_read.contact_ids).index(elid) for elid in probe_orig.contact_ids]) assert all(probe_orig.contact_ids == probe_read.contact_ids[t]) assert all(probe_orig.shank_ids == probe_read.shank_ids[t]) @@ -116,21 +109,14 @@ def test_BIDS_format(tmp_path): assert probe_orig.si_units == probe_read.si_units for i in range(len(probe_orig.probe_planar_contour)): - assert all( - probe_orig.probe_planar_contour[i] == probe_read.probe_planar_contour[i] - ) + assert all(probe_orig.probe_planar_contour[i] == probe_read.probe_planar_contour[i]) for sid, shape_params in enumerate(probe_orig.contact_shape_params): assert shape_params == probe_read.contact_shape_params[t][sid] for i in range(len(probe_orig.contact_positions)): - assert all( - probe_orig.contact_positions[i] == probe_read.contact_positions[t][i] - ) + assert all(probe_orig.contact_positions[i] == probe_read.contact_positions[t][i]) for i in range(len(probe.contact_plane_axes)): for dim in range(len(probe.contact_plane_axes[i])): - assert all( - probe_orig.contact_plane_axes[i][dim] - == probe_read.contact_plane_axes[t][i][dim] - ) + assert all(probe_orig.contact_plane_axes[i][dim] == probe_read.contact_plane_axes[t][i][dim]) def test_BIDS_format_empty(tmp_path): @@ -218,6 +204,7 @@ def test_prb(tmp_path): # plot_probe(probe) # plt.show() + if __name__ == "__main__": # test_probeinterface_format() # test_BIDS_format() diff --git a/tests/test_io/test_openephys.py b/tests/test_io/test_openephys.py index 20e58e6..7f390c5 100644 --- a/tests/test_io/test_openephys.py +++ b/tests/test_io/test_openephys.py @@ -61,20 +61,15 @@ def test_NP_Ultra(): assert len(np.unique(probeD.contact_positions[:, 0])) == 1 - def test_NP1_subset(): # NP1 - 200 channels selected by recording_state in Record Node - probe_ap = read_openephys( - data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP" - ) + probe_ap = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP") assert probe_ap.get_shank_count() == 1 assert "1.0" in probe_ap.model_name assert probe_ap.get_contact_count() == 200 - probe_lf = read_openephys( - data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP" - ) + probe_lf = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP") assert probe_lf.get_shank_count() == 1 assert "1.0" in probe_lf.model_name @@ -88,9 +83,7 @@ def test_NP1_subset(): def test_multiple_probes(): # multiple probes - probeA = read_openephys( - data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA" - ) + probeA = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA") assert probeA.get_shank_count() == 1 assert "1.0" in probeA.model_name @@ -109,9 +102,7 @@ def test_multiple_probes(): assert probeC.get_shank_count() == 1 - probeD = read_openephys( - data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD" - ) + probeD = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD") assert probeD.get_shank_count() == 1 @@ -148,13 +139,10 @@ def test_np_opto_with_sync(): assert probe.get_contact_count() == 384 - def test_older_than_06_format(): ## Test with the open ephys < 0.6 format - probe = read_openephys( - data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0" - ) + probe = read_openephys(data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0") assert probe.get_shank_count() == 4 assert "2.0 - Four Shank" in probe.model_name diff --git a/tests/test_io/test_spikeglx.py b/tests/test_io/test_spikeglx.py index 48281ba..f32ee6e 100644 --- a/tests/test_io/test_spikeglx.py +++ b/tests/test_io/test_spikeglx.py @@ -11,6 +11,7 @@ data_path = Path(__file__).absolute().parent.parent / "data" / "spikeglx" + def test_parse_meta(): for meta_file in [ "doppio-checkerboard_t0.imec0.ap.meta", @@ -19,19 +20,17 @@ def test_parse_meta(): ]: meta = parse_spikeglx_meta(data_path / meta_file) + def test_get_saved_channel_indices_from_spikeglx_meta(): # all channel saved + 1 synchro - chan_inds = get_saved_channel_indices_from_spikeglx_meta( - data_path / "Noise_g0_t0.imec0.ap.meta" - ) + chan_inds = get_saved_channel_indices_from_spikeglx_meta(data_path / "Noise_g0_t0.imec0.ap.meta") assert chan_inds.size == 385 # example by Pierre Yger NP1.0 with 384 but only 151 channels are saved + 1 synchro - chan_inds = get_saved_channel_indices_from_spikeglx_meta( - data_path / "NP1_saved_only_subset_of_channels.meta" - ) + chan_inds = get_saved_channel_indices_from_spikeglx_meta(data_path / "NP1_saved_only_subset_of_channels.meta") assert chan_inds.size == 152 + def test_NP1(): probe = read_spikeglx(data_path / "Noise_g0_t0.imec0.ap.meta") assert "1.0" in probe.model_name @@ -187,6 +186,7 @@ def tes_NP1_384_channels(): assert probe.get_contact_count() == 151 assert 152 not in probe.contact_annotations["channel_ids"] + def test_NPH_long_staggered(): # Data provided by Nate Dolensek probe = read_spikeglx(data_path / "non_human_primate_long_staggered.imec0.ap.meta") @@ -242,6 +242,7 @@ def test_NPH_long_staggered(): assert np.allclose(references, 0) assert np.allclose(filters, 1) + def test_NPH_short_linear_probe_type_0(): # Data provided by Jonathan A Michaels probe = read_spikeglx(data_path / "non_human_primate_short_linear_probe_type_0.meta") @@ -254,7 +255,6 @@ def test_NPH_short_linear_probe_type_0(): assert probe.get_shank_count() == 1 assert probe.get_contact_count() == 384 - # Test contact geometry x_pitch = 56.0 y_pitch = 20.0 @@ -320,6 +320,7 @@ def test_ultra_probe(): unique_y_values = np.unique(y) assert unique_y_values.size == expected_electode_rows + def test_CatGT_NP1(): probe = read_spikeglx(data_path / "catgt.meta") assert "1.0" in probe.model_name diff --git a/tests/test_library.py b/tests/test_library.py index 33875f6..8d4059d 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -1,6 +1,5 @@ from probeinterface import Probe -from probeinterface.library import (download_probeinterface_file, - get_from_cache, get_probe) +from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe from pathlib import Path @@ -9,19 +8,20 @@ import pytest -manufacturer = 'neuronexus' -probe_name = 'A1x32-Poly3-10mm-50-177' +manufacturer = "neuronexus" +probe_name = "A1x32-Poly3-10mm-50-177" def test_download_probeinterface_file(): download_probeinterface_file(manufacturer, probe_name) + def test_get_from_cache(): download_probeinterface_file(manufacturer, probe_name) probe = get_from_cache(manufacturer, probe_name) assert isinstance(probe, Probe) - probe = get_from_cache('yep', 'yop') + probe = get_from_cache("yep", "yop") assert probe is None @@ -31,7 +31,7 @@ def test_get_probe(): assert probe.get_contact_count() == 32 -if __name__ == '__main__': +if __name__ == "__main__": test_download_probeinterface_file() test_get_from_cache() test_get_probe() diff --git a/tests/test_plotting.py b/tests/test_plotting.py index e79fa42..3496bf2 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -13,7 +13,7 @@ def test_plot_probe(): plot_probe(probe) plot_probe(probe, with_contact_id=True) plot_probe(probe, with_device_index=True) - plot_probe(probe, text_on_contact=['abcde'[i%5] for i in range(probe.get_contact_count())]) + plot_probe(probe, text_on_contact=["abcde"[i % 5] for i in range(probe.get_contact_count())]) # with color n = probe.get_contact_count() @@ -21,7 +21,7 @@ def test_plot_probe(): plot_probe(probe, contacts_colors=contacts_colors) # 3d - probe_3d = probe.to_3d(axes='xz') + probe_3d = probe.to_3d(axes="xz") plot_probe(probe_3d) # on click @@ -43,7 +43,7 @@ def test_plot_probe_group(): plot_probe_group(probegroup_3d, same_axes=True) -if __name__ == '__main__': +if __name__ == "__main__": test_plot_probe() # test_plot_probe_group() plt.show() diff --git a/tests/test_probe.py b/tests/test_probe.py index 94a0483..1d47d6e 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -4,6 +4,7 @@ import pytest + def _dummy_position(): n = 24 positions = np.zeros((n, 2)) @@ -19,10 +20,10 @@ def _dummy_position(): def test_probe(): positions = _dummy_position() - probe = Probe(ndim=2, si_units='um') - probe.set_contacts(positions=positions, shapes='circle', shape_params={'radius': 5}) - probe.set_contacts(positions=positions, shapes='square', shape_params={'width': 5}) - probe.set_contacts(positions=positions, shapes='rect', shape_params={'width': 8, 'height':5 }) + probe = Probe(ndim=2, si_units="um") + probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) + probe.set_contacts(positions=positions, shapes="square", shape_params={"width": 5}) + probe.set_contacts(positions=positions, shapes="rect", shape_params={"width": 8, "height": 5}) assert probe.get_contact_count() == 24 @@ -34,20 +35,20 @@ def test_probe(): probe.create_auto_shape() # annotation - probe.annotate(manufacturer='me') - assert 'manufacturer' in probe.annotations - probe.annotate_contacts(impedance=np.random.rand(24)*1000) - assert 'impedance' in probe.contact_annotations + probe.annotate(manufacturer="me") + assert "manufacturer" in probe.annotations + probe.annotate_contacts(impedance=np.random.rand(24) * 1000) + assert "impedance" in probe.contact_annotations # device channel - chans = np.arange(0, 24, dtype='int') + chans = np.arange(0, 24, dtype="int") np.random.shuffle(chans) probe.set_device_channel_indices(chans) # contact_ids int or str elec_ids = np.arange(24) probe.set_contact_ids(elec_ids) - elec_ids = [f'elec #{e}' for e in range(24)] + elec_ids = [f"elec #{e}" for e in range(24)] probe.set_contact_ids(elec_ids) # copy @@ -59,18 +60,17 @@ def test_probe(): # make annimage values = np.random.randn(24) - image, xlims, ylims = probe.to_image(values, method='cubic') - - image2, xlims, ylims = probe.to_image(values, method='cubic', num_pixel=16) + image, xlims, ylims = probe.to_image(values, method="cubic") - #~ from probeinterface.plotting import plot_probe_group, plot_probe - #~ import matplotlib.pyplot as plt - #~ fig, ax = plt.subplots() - #~ plot_probe(probe, ax=ax) - #~ ax.imshow(image, extent=xlims+ylims, origin='lower') - #~ ax.imshow(image2, extent=xlims+ylims, origin='lower') - #~ plt.show() + image2, xlims, ylims = probe.to_image(values, method="cubic", num_pixel=16) + # ~ from probeinterface.plotting import plot_probe_group, plot_probe + # ~ import matplotlib.pyplot as plt + # ~ fig, ax = plt.subplots() + # ~ plot_probe(probe, ax=ax) + # ~ ax.imshow(image, extent=xlims+ylims, origin='lower') + # ~ ax.imshow(image2, extent=xlims+ylims, origin='lower') + # ~ plt.show() # 3d probe_3d = probe.to_3d() @@ -81,10 +81,10 @@ def test_probe(): probe_2d = probe_3d.to_2d(axes="xz") assert np.allclose(probe_2d.contact_positions, probe_3d.contact_positions[:, [0, 2]]) - #~ from probeinterface.plotting import plot_probe_group, plot_probe - #~ import matplotlib.pyplot as plt - #~ plot_probe(probe_3d) - #~ plt.show() + # ~ from probeinterface.plotting import plot_probe_group, plot_probe + # ~ import matplotlib.pyplot as plt + # ~ plot_probe(probe_3d) + # ~ plt.show() # get shanks for shank in probe.get_shanks(): @@ -110,40 +110,36 @@ def test_probe(): df = probe.to_dataframe(complete=False) other2 = Probe.from_dataframe(df) df = probe_3d.to_dataframe(complete=True) - # print(df.index) + # print(df.index) other_3d = Probe.from_dataframe(df) assert other_3d.ndim == 3 # slice handling - selection = np.arange(0,18,2) + selection = np.arange(0, 18, 2) # print(selection.dtype.kind) sliced_probe = probe.get_slice(selection) assert sliced_probe.get_contact_count() == 9 - assert sliced_probe.contact_annotations['impedance'].shape == (9, ) + assert sliced_probe.contact_annotations["impedance"].shape == (9,) - #~ from probeinterface.plotting import plot_probe_group, plot_probe - #~ import matplotlib.pyplot as plt - #~ plot_probe(probe) - #~ plot_probe(sliced_probe) + # ~ from probeinterface.plotting import plot_probe_group, plot_probe + # ~ import matplotlib.pyplot as plt + # ~ plot_probe(probe) + # ~ plot_probe(sliced_probe) - selection = np.ones(24, dtype='bool') + selection = np.ones(24, dtype="bool") selection[::2] = False sliced_probe = probe.get_slice(selection) assert sliced_probe.get_contact_count() == 12 - assert sliced_probe.contact_annotations['impedance'].shape == (12, ) + assert sliced_probe.contact_annotations["impedance"].shape == (12,) - #~ plot_probe(probe) - #~ plot_probe(sliced_probe) - #~ plt.show() + # ~ plot_probe(probe) + # ~ plot_probe(sliced_probe) + # ~ plt.show() def test_set_shanks(): - probe = Probe(ndim=2, si_units='um') - probe.set_contacts( - positions= np.arange(20).reshape(10, 2), - shapes='circle', - shape_params={'radius' : 5}) - + probe = Probe(ndim=2, si_units="um") + probe.set_contacts(positions=np.arange(20).reshape(10, 2), shapes="circle", shape_params={"radius": 5}) # for simplicity each contact is on separate shank shank_ids = np.arange(10) @@ -152,7 +148,7 @@ def test_set_shanks(): assert all(probe.shank_ids == shank_ids.astype(str)) -if __name__ == '__main__': +if __name__ == "__main__": test_probe() test_set_shanks() diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index a6a4eb7..6479721 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -12,16 +12,15 @@ def test_probegroup(): nchan = 0 for i in range(3): probe = generate_dummy_probe() - probe.move([i*100, i*80]) + probe.move([i * 100, i * 80]) n = probe.get_contact_count() probe.set_device_channel_indices(np.arange(n)[::-1] + nchan) shank_ids = np.ones(n) - shank_ids[:n//2] *= i * 2 - shank_ids[n//2:] *= i * 2 +1 + shank_ids[: n // 2] *= i * 2 + shank_ids[n // 2 :] *= i * 2 + 1 probe.set_shank_ids(shank_ids) probegroup.add_probe(probe) - nchan += n indices = probegroup.get_global_device_channel_indices() @@ -29,7 +28,7 @@ def test_probegroup(): ids = probegroup.get_global_contact_ids() df = probegroup.to_dataframe() - #~ print(df['global_contact_ids']) + # ~ print(df['global_contact_ids']) arr = probegroup.to_numpy(complete=False) other = ProbeGroup.from_numpy(arr) @@ -39,11 +38,11 @@ def test_probegroup(): d = probegroup.to_dict() other = ProbeGroup.from_dict(d) - #~ from probeinterface.plotting import plot_probe_group, plot_probe - #~ import matplotlib.pyplot as plt - #~ plot_probe_group(probegroup) - #~ plot_probe_group(other) - #~ plt.show() + # ~ from probeinterface.plotting import plot_probe_group, plot_probe + # ~ import matplotlib.pyplot as plt + # ~ plot_probe_group(probegroup) + # ~ plot_probe_group(other) + # ~ plt.show() # checking automatic generation of ids with new dummy probes probegroup.probes = [] @@ -54,19 +53,20 @@ def test_probegroup(): for p in probegroup.probes: assert p.contact_ids is not None - assert 'probe_id' in p.annotations + assert "probe_id" in p.annotations + def test_probegroup_3d(): probegroup = ProbeGroup() for i in range(3): probe = generate_dummy_probe().to_3d() - probe.move([i*100, i*80, i*30]) + probe.move([i * 100, i * 80, i * 30]) probegroup.add_probe(probe) assert probegroup.ndim == 3 -if __name__ == '__main__': +if __name__ == "__main__": test_probegroup() - #~ test_probegroup_3d() + # ~ test_probegroup_3d() diff --git a/tests/test_shank.py b/tests/test_shank.py index ad6914c..295920b 100644 --- a/tests/test_shank.py +++ b/tests/test_shank.py @@ -10,13 +10,13 @@ def testing_shank(): num_columns = 1 num_contact_per_column = 6 contact_shapes = "square" - contact_shape_params = {'width': 6} + contact_shape_params = {"width": 6} multi_shank_probe = generate_multi_shank( num_shank=num_shank, num_columns=num_columns, num_contact_per_column=num_contact_per_column, contact_shapes=contact_shapes, - contact_shape_params=contact_shape_params + contact_shape_params=contact_shape_params, ) shank = multi_shank_probe.get_shanks()[0] @@ -55,12 +55,8 @@ def test_contact_shapes(testing_shank): def test_contact_shape_parameters(testing_shank): probe = testing_shank.probe - expected_contact_shape_params = probe.contact_shape_params[ - testing_shank.get_indices() - ] - assert np.array_equal( - testing_shank.contact_shape_params, expected_contact_shape_params - ) + expected_contact_shape_params = probe.contact_shape_params[testing_shank.get_indices()] + assert np.array_equal(testing_shank.contact_shape_params, expected_contact_shape_params) def test_device_channel_indices(testing_shank): diff --git a/tests/test_utils.py b/tests/test_utils.py index 19e3465..e1580ca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,13 +1,13 @@ - import pytest from probeinterface.utils import import_safely + def test_good_import(): + np = import_safely("numpy") + assert np.__name__ == "numpy" - np = import_safely('numpy') - assert np.__name__ == 'numpy' def test_handle_import_error(): with pytest.raises(ImportError): - import_safely('not_a_real_package') + import_safely("not_a_real_package") diff --git a/tests/test_wiring.py b/tests/test_wiring.py index bfcae1e..0a3a6d5 100644 --- a/tests/test_wiring.py +++ b/tests/test_wiring.py @@ -7,26 +7,26 @@ import pytest -def test_wire_probe(): - manufacturer = 'neuronexus' - probe_name = 'A1x32-Poly3-10mm-50-177' +def test_wire_probe(): + manufacturer = "neuronexus" + probe_name = "A1x32-Poly3-10mm-50-177" probe = get_probe(manufacturer, probe_name) - probe.wiring_to_device('H32>RHD2132') + probe.wiring_to_device("H32>RHD2132") plot_probe(probe, with_contact_id=True) - manufacturer = 'cambridgeneurotech' - probe_name = 'ASSY-156-P-1' + manufacturer = "cambridgeneurotech" + probe_name = "ASSY-156-P-1" probe = get_probe(manufacturer, probe_name) - probe.wiring_to_device('ASSY-156>RHD2164') + probe.wiring_to_device("ASSY-156>RHD2164") plot_probe(probe, with_contact_id=True) -if __name__ == '__main__': +if __name__ == "__main__": test_wire_probe() plt.show()