Skip to content

Commit

Permalink
[Bug] Pass customized backend to run_conversion (#885)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
CodyCBakerPhD and pre-commit-ci[bot] authored Jun 4, 2024
1 parent 09b1d81 commit 57c9dad
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 74 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
* Fixed a bug when adding ragged arrays to the electrode and units table. [PR #870](https://github.com/catalystneuro/neuroconv/pull/870)
* Fixed a bug where `write_recording` will call an empty nwbfile when passing a path. [PR #877](https://github.com/catalystneuro/neuroconv/pull/877)
* Fixed a bug that failed to properly include time alignment information in the output NWB file for objects added from any `RecordingInterface` in combination with `stub_test=True`. [PR #884](https://github.com/catalystneuro/neuroconv/pull/884)
* Fixed a bug that prevented passing `nwbfile=None` and a `backend_configuration` to `NWBConverter.run_conversion`. [PR #885](https://github.com/catalystneuro/neuroconv/pull/885)

### Improvements
* Fixed docstrings related to backend configurations for various methods. [PR #822](https://github.com/catalystneuro/neuroconv/pull/822)
Expand Down
5 changes: 3 additions & 2 deletions src/neuroconv/basedatainterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def run_conversion(
" use DataInterface.add_to_nwbfile."
)
backend = _resolve_backend(backend, backend_configuration)
no_nwbfile_provided = nwbfile is None # Otherwise, variable reference may mutate later on inside the context

if metadata is None:
metadata = self.get_metadata()
Expand All @@ -165,10 +166,10 @@ def run_conversion(
backend=backend,
verbose=getattr(self, "verbose", False),
) as nwbfile_out:
if backend_configuration is None:
# In this case, assume the relevant data has not been added to the NWBFile
if no_nwbfile_provided:
self.add_to_nwbfile(nwbfile=nwbfile_out, metadata=metadata, **conversion_options)

if backend_configuration is None:
backend_configuration = self.get_default_backend_configuration(nwbfile=nwbfile_out, backend=backend)

configure_backend(nwbfile=nwbfile_out, backend_configuration=backend_configuration)
Expand Down
7 changes: 4 additions & 3 deletions src/neuroconv/nwbconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def run_conversion(
conversion specification is requested.
"""
backend = _resolve_backend(backend, backend_configuration)
no_nwbfile_provided = nwbfile is None # Otherwise, variable reference may mutate later on inside the context

if metadata is None:
metadata = self.get_metadata()
Expand All @@ -220,10 +221,10 @@ def run_conversion(
backend=backend,
verbose=getattr(self, "verbose", False),
) as nwbfile_out:
if backend_configuration is None:
# Otherwise assume the data has already been added to the NWBFile
self.add_to_nwbfile(nwbfile_out, metadata=metadata, conversion_options=conversion_options)
if no_nwbfile_provided:
self.add_to_nwbfile(nwbfile=nwbfile_out, metadata=metadata, conversion_options=conversion_options)

if backend_configuration is None:
backend_configuration = self.get_default_backend_configuration(nwbfile=nwbfile_out, backend=backend)

configure_backend(nwbfile=nwbfile_out, backend_configuration=backend_configuration)
Expand Down
146 changes: 77 additions & 69 deletions tests/test_ecephys/test_mock_nidq_interface.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,96 @@
import pathlib
from datetime import datetime

from hdmf.testing import TestCase
from numpy.testing import assert_array_almost_equal
from pynwb import NWBHDF5IO
from pynwb.testing.mock.file import mock_NWBFile

from neuroconv.tools.testing import MockSpikeGLXNIDQInterface


class TestMockSpikeGLXNIDQInterface(TestCase):
maxDiff = None

def test_current_default_inferred_ttl_times(self):
interface = MockSpikeGLXNIDQInterface()

channel_names = ["nidq#XA0", "nidq#XA1", "nidq#XA2", "nidq#XA3", "nidq#XA4", "nidq#XA5", "nidq#XA6", "nidq#XA7"]
inferred_starting_times = list()
for channel_index, channel_name in enumerate(channel_names):
inferred_starting_times.append(interface.get_event_times_from_ttl(channel_name=channel_name))

expected_ttl_times = [[1.0 * (1 + 2 * period) + 0.1 * channel for period in range(3)] for channel in range(8)]
for channel_index, channel_name in enumerate(channel_names):
inferred_ttl_times = interface.get_event_times_from_ttl(channel_name=channel_name)
assert_array_almost_equal(x=inferred_ttl_times, y=expected_ttl_times[channel_index], decimal=4)

def test_explicit_original_default_inferred_ttl_times(self):
interface = MockSpikeGLXNIDQInterface(signal_duration=7.0, ttl_times=None, ttl_duration=1.0)

channel_names = ["nidq#XA0", "nidq#XA1", "nidq#XA2", "nidq#XA3", "nidq#XA4", "nidq#XA5", "nidq#XA6", "nidq#XA7"]
expected_ttl_times = [[1.0 * (1 + 2 * period) + 0.1 * channel for period in range(3)] for channel in range(8)]
for channel_index, channel_name in enumerate(channel_names):
inferred_ttl_times = interface.get_event_times_from_ttl(channel_name=channel_name)
assert_array_almost_equal(x=inferred_ttl_times, y=expected_ttl_times[channel_index], decimal=4)

def test_custom_inferred_ttl_times(self):
custom_ttl_times = [[1.2], [3.6], [0.7, 4.5], [5.1]]
interface = MockSpikeGLXNIDQInterface(ttl_times=custom_ttl_times)

channel_names = ["nidq#XA0", "nidq#XA1", "nidq#XA2", "nidq#XA3"]
for channel_index, channel_name in enumerate(channel_names):
inferred_ttl_times = interface.get_event_times_from_ttl(channel_name=channel_name)
assert_array_almost_equal(x=inferred_ttl_times, y=custom_ttl_times[channel_index], decimal=4)

def test_mock_metadata(self):
interface = MockSpikeGLXNIDQInterface()

metadata = interface.get_metadata()

expected_ecephys_metadata = {
"Ecephys": {
"Device": [
{"description": "no description", "manufacturer": "Imec", "name": "Neuropixel-Imec"},
],
"ElectrodeGroup": [
{
"name": "NIDQChannelGroup",
"description": "A group representing the NIDQ channels.",
"device": "Neuropixel-Imec",
"location": "unknown",
},
],
"Electrodes": [
{"name": "group_name", "description": "Name of the ElectrodeGroup this electrode is a part of."}
],
"ElectricalSeriesNIDQ": {
"name": "ElectricalSeriesNIDQ",
"description": "Raw acquisition traces from the NIDQ (.nidq.bin) channels.",
def test_current_default_inferred_ttl_times():
interface = MockSpikeGLXNIDQInterface()

channel_names = ["nidq#XA0", "nidq#XA1", "nidq#XA2", "nidq#XA3", "nidq#XA4", "nidq#XA5", "nidq#XA6", "nidq#XA7"]
inferred_starting_times = list()
for channel_index, channel_name in enumerate(channel_names):
inferred_starting_times.append(interface.get_event_times_from_ttl(channel_name=channel_name))

expected_ttl_times = [[1.0 * (1 + 2 * period) + 0.1 * channel for period in range(3)] for channel in range(8)]
for channel_index, channel_name in enumerate(channel_names):
inferred_ttl_times = interface.get_event_times_from_ttl(channel_name=channel_name)
assert_array_almost_equal(x=inferred_ttl_times, y=expected_ttl_times[channel_index], decimal=4)


def test_explicit_original_default_inferred_ttl_times():
interface = MockSpikeGLXNIDQInterface(signal_duration=7.0, ttl_times=None, ttl_duration=1.0)

channel_names = ["nidq#XA0", "nidq#XA1", "nidq#XA2", "nidq#XA3", "nidq#XA4", "nidq#XA5", "nidq#XA6", "nidq#XA7"]
expected_ttl_times = [[1.0 * (1 + 2 * period) + 0.1 * channel for period in range(3)] for channel in range(8)]
for channel_index, channel_name in enumerate(channel_names):
inferred_ttl_times = interface.get_event_times_from_ttl(channel_name=channel_name)
assert_array_almost_equal(x=inferred_ttl_times, y=expected_ttl_times[channel_index], decimal=4)


def test_custom_inferred_ttl_times():
custom_ttl_times = [[1.2], [3.6], [0.7, 4.5], [5.1]]
interface = MockSpikeGLXNIDQInterface(ttl_times=custom_ttl_times)

channel_names = ["nidq#XA0", "nidq#XA1", "nidq#XA2", "nidq#XA3"]
for channel_index, channel_name in enumerate(channel_names):
inferred_ttl_times = interface.get_event_times_from_ttl(channel_name=channel_name)
assert_array_almost_equal(x=inferred_ttl_times, y=custom_ttl_times[channel_index], decimal=4)


def test_mock_metadata():
interface = MockSpikeGLXNIDQInterface()

metadata = interface.get_metadata()

expected_ecephys_metadata = {
"Ecephys": {
"Device": [
{"description": "no description", "manufacturer": "Imec", "name": "Neuropixel-Imec"},
],
"ElectrodeGroup": [
{
"name": "NIDQChannelGroup",
"description": "A group representing the NIDQ channels.",
"device": "Neuropixel-Imec",
"location": "unknown",
},
}
],
"Electrodes": [
{"name": "group_name", "description": "Name of the ElectrodeGroup this electrode is a part of."}
],
"ElectricalSeriesNIDQ": {
"name": "ElectricalSeriesNIDQ",
"description": "Raw acquisition traces from the NIDQ (.nidq.bin) channels.",
},
}
print(metadata["Ecephys"])
self.assertDictEqual(d1=metadata["Ecephys"], d2=expected_ecephys_metadata["Ecephys"])
}
print(metadata["Ecephys"])
assert metadata["Ecephys"] == expected_ecephys_metadata["Ecephys"]

expected_start_time = datetime(2020, 11, 3, 10, 35, 10)
assert metadata["NWBFile"]["session_start_time"] == expected_start_time


expected_start_time = datetime(2020, 11, 3, 10, 35, 10)
assert metadata["NWBFile"]["session_start_time"] == expected_start_time
def test_mock_run_conversion(tmpdir: pathlib.Path):
interface = MockSpikeGLXNIDQInterface()

def test_mock_run_conversion(self):
interface = MockSpikeGLXNIDQInterface()
metadata = interface.get_metadata()

metadata = interface.get_metadata()
test_directory = pathlib.Path(tmpdir) / "TestMockSpikeGLXNIDQInterface"
test_directory.mkdir(exist_ok=True)
nwbfile_path = test_directory / "test_mock_run_conversion.nwb"
interface.run_conversion(nwbfile_path=nwbfile_path, metadata=metadata, overwrite=True)

nwbfile = mock_NWBFile()
interface.run_conversion(nwbfile=nwbfile, metadata=metadata)
with NWBHDF5IO(path=nwbfile_path, mode="r") as io:
nwbfile = io.read()

assert "Neuropixel-Imec" in nwbfile.devices
assert "NIDQChannelGroup" in nwbfile.electrode_groups
assert nwbfile.electrodes.id[:] == [0, 1, 2, 3, 4, 5, 6, 7]
assert list(nwbfile.electrodes.id[:]) == [0, 1, 2, 3, 4, 5, 6, 7]
assert "ElectricalSeriesNIDQ" in nwbfile.acquisition

0 comments on commit 57c9dad

Please sign in to comment.