From d32188b70e007de7d3e00e99dc20e5a7a9ad9716 Mon Sep 17 00:00:00 2001 From: Oliver Ruebel Date: Mon, 14 Oct 2024 13:21:09 -0700 Subject: [PATCH] Fix bad validation check for postion in ElectrodeGroup.__init__ (#1770) Co-authored-by: Cody Baker <51133164+CodyCBakerPhD@users.noreply.github.com> Co-authored-by: Steph Prince <40640337+stephprince@users.noreply.github.com> Co-authored-by: Ryan Ly --- CHANGELOG.md | 3 ++ docs/source/conf.py | 4 --- pyproject.toml | 2 +- requirements-min.txt | 2 +- requirements.txt | 2 +- src/pynwb/ecephys.py | 27 ++++++++++++--- tests/integration/hdf5/test_ecephys.py | 3 +- tests/unit/test_ecephys.py | 48 +++++++++++++++++++++++--- 8 files changed, 75 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e5909f577..16e504659 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ ### Performance - Cache global type map to speed import 3X. @sneakers-the-rat [#1931](https://github.com/NeurodataWithoutBorders/pynwb/pull/1931) +### Bug fixes +- Fixed bug in how `ElectrodeGroup.__init__` validates its `position` argument. @oruebel [#1770](https://github.com/NeurodataWithoutBorders/pynwb/pull/1770) + ## PyNWB 2.8.2 (September 9, 2024) ### Enhancements and minor changes diff --git a/docs/source/conf.py b/docs/source/conf.py index 4eaf1a19b..eabca22c7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -244,7 +244,6 @@ def __call__(self, filename): # html_theme = 'default' # html_theme = "sphinxdoc" html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -260,9 +259,6 @@ def __call__(self, filename): 'css/custom.css', ] -# Add any paths that contain custom themes here, relative to this directory. -# html_theme_path = [] - # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". # html_title = None diff --git a/pyproject.toml b/pyproject.toml index f798f2b5a..14a11f5d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ ] dependencies = [ "h5py>=2.10", - "hdmf>=3.14.3", + "hdmf>=3.14.5", "numpy>=1.18", "pandas>=1.1.5", "python-dateutil>=2.7.3", diff --git a/requirements-min.txt b/requirements-min.txt index eef051b25..feed604bc 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -1,6 +1,6 @@ # minimum versions of package dependencies for installing PyNWB h5py==2.10 # support for selection of datasets with list of indices added in 2.10 -hdmf==3.14.3 +hdmf==3.14.5 numpy==1.18 pandas==1.1.5 python-dateutil==2.7.3 diff --git a/requirements.txt b/requirements.txt index 6d7a17623..1e7a5e18d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # pinned dependencies to reproduce an entire development environment to use PyNWB h5py==3.11.0 -hdmf==3.14.3 +hdmf==3.14.5 numpy==2.1.1; python_version > "3.9" # numpy 2.1+ is not compatible with py3.9 numpy==2.0.2; python_version == "3.9" pandas==2.2.2 diff --git a/src/pynwb/ecephys.py b/src/pynwb/ecephys.py index 3187cca4a..e6dadbe97 100644 --- a/src/pynwb/ecephys.py +++ b/src/pynwb/ecephys.py @@ -1,4 +1,5 @@ import warnings +import numpy as np from collections.abc import Iterable from hdmf.common import DynamicTableRegion @@ -26,13 +27,31 @@ class ElectrodeGroup(NWBContainer): {'name': 'location', 'type': str, 'doc': 'description of location of this electrode group'}, {'name': 'device', 'type': Device, 'doc': 'the device that was used to record from this electrode group'}, {'name': 'position', 'type': 'array_data', - 'doc': 'stereotaxic position of this electrode group (x, y, z)', 'default': None}) + 'doc': 'Compound dataset with stereotaxic position of this electrode group (x, y, z). ' + 'The data array must have three elements or the dtype of the ' + 'array must be ``(float, float, float)``', 'default': None}) def __init__(self, **kwargs): args_to_set = popargs_to_dict(('description', 'location', 'device', 'position'), kwargs) super().__init__(**kwargs) - if args_to_set['position'] and len(args_to_set['position']) != 3: - raise ValueError('ElectrodeGroup position argument must have three elements: x, y, z, but received: %s' - % str(args_to_set['position'])) + + # position is a compound dataset, i.e., this must be a scalar with a + # compound data type of three floats or a list/tuple of three entries + position = args_to_set['position'] + if position: + # check position argument is valid + position_dtype_invalid = ( + (hasattr(position, 'dtype') and len(position.dtype) != 3) or + (not hasattr(position, 'dtype') and len(position) != 3) or + (len(np.shape(position)) > 1) + ) + if position_dtype_invalid: + raise ValueError(f"ElectrodeGroup position argument must have three elements: x, y, z," + f"but received: {position}") + + # convert position to scalar with compound data type if needed + if not hasattr(position, 'dtype'): + args_to_set['position'] = np.array(tuple(position), dtype=[('x', float), ('y', float), ('z', float)]) + for key, val in args_to_set.items(): setattr(self, key, val) diff --git a/tests/integration/hdf5/test_ecephys.py b/tests/integration/hdf5/test_ecephys.py index ff67d27c9..c44725277 100644 --- a/tests/integration/hdf5/test_ecephys.py +++ b/tests/integration/hdf5/test_ecephys.py @@ -26,7 +26,8 @@ def setUpContainer(self): eg = ElectrodeGroup(name='elec1', description='a test ElectrodeGroup', location='a nonexistent place', - device=self.dev1) + device=self.dev1, + position=(1., 2., 3.)) return eg def addContainer(self, nwbfile): diff --git a/tests/unit/test_ecephys.py b/tests/unit/test_ecephys.py index dc194af2a..1415c3d30 100644 --- a/tests/unit/test_ecephys.py +++ b/tests/unit/test_ecephys.py @@ -178,16 +178,34 @@ class ElectrodeGroupConstructor(TestCase): def test_init(self): dev1 = Device('dev1') - group = ElectrodeGroup('elec1', 'electrode description', 'electrode location', dev1, (1, 2, 3)) + group = ElectrodeGroup(name='elec1', + description='electrode description', + location='electrode location', + device=dev1, + position=(1, 2, 3)) self.assertEqual(group.name, 'elec1') self.assertEqual(group.description, 'electrode description') self.assertEqual(group.location, 'electrode location') self.assertEqual(group.device, dev1) - self.assertEqual(group.position, (1, 2, 3)) + self.assertEqual(group.position.tolist(), (1, 2, 3)) + + def test_init_position_array(self): + position = np.array((1, 2, 3), dtype=np.dtype([('x', float), ('y', float), ('z', float)])) + dev1 = Device('dev1') + group = ElectrodeGroup('elec1', 'electrode description', 'electrode location', dev1, + position) + self.assertEqual(group.name, 'elec1') + self.assertEqual(group.description, 'electrode description') + self.assertEqual(group.location, 'electrode location') + self.assertEqual(group.device, dev1) + self.assertEqual(group.position, position) def test_init_position_none(self): dev1 = Device('dev1') - group = ElectrodeGroup('elec1', 'electrode description', 'electrode location', dev1) + group = ElectrodeGroup(name='elec1', + description='electrode description', + location='electrode location', + device=dev1) self.assertEqual(group.name, 'elec1') self.assertEqual(group.description, 'electrode description') self.assertEqual(group.location, 'electrode location') @@ -197,7 +215,29 @@ def test_init_position_none(self): def test_init_position_bad(self): dev1 = Device('dev1') with self.assertRaises(ValueError): - ElectrodeGroup('elec1', 'electrode description', 'electrode location', dev1, (1, 2)) + ElectrodeGroup(name='elec1', + description='electrode description', + location='electrode location', + device=dev1, + position=(1, 2)) + with self.assertRaises(ValueError): + ElectrodeGroup(name='elec1', + description='electrode description', + location='electrode location', + device=dev1, + position=[(1, 2), ]) + with self.assertRaises(ValueError): + ElectrodeGroup(name='elec1', + description='electrode description', + location='electrode location', + device=dev1, + position=np.array([(1., 2.)], dtype=np.dtype([('x', float), ('y', float)]))) + with self.assertRaises(ValueError): + ElectrodeGroup(name='elec1', + description='electrode description', + location='electrode location', + device=dev1, + position=[(1, 2, 3), (4, 5, 6), (7, 8, 9)]) class EventDetectionConstructor(TestCase):