Add nullable values specifiers for add_electrodes (#985)
Co-authored-by: Cody Baker <[email protected]>
h-mayorquin and CodyCBakerPhD authored Aug 13, 2024
1 parent b7ae085 commit b9507bf
Showing 3 changed files with 175 additions and 40 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,7 +6,9 @@
### Features
* Added MedPCInterface for operant behavioral output files. [PR #883](https://github.com/catalystneuro/neuroconv/pull/883)
* Support `SortingAnalyzer` in the `SpikeGLXConverterPipe`. [PR #821](https://github.com/catalystneuro/neuroconv/pull/821)
* Add a `null_values_for_properties` argument to `add_electrodes` for fine-grained control over how missing values are handled. As a side effect, integer properties are no longer implicitly cast to float when written to the electrodes table (see the usage sketch after this file's diff) [PR #985](https://github.com/catalystneuro/neuroconv/pull/985)
* Add Plexon2 support [PR #918](https://github.com/catalystneuro/neuroconv/pull/918)
* Converter working with multiple VideoInterface instances [PR #914](https://github.com/catalystneuro/neuroconv/pull/914)
* Added helper function `neuroconv.tools.data_transfers.submit_aws_batch_job` for basic automated submission of AWS batch jobs. [PR #384](https://github.com/catalystneuro/neuroconv/pull/384)

### Bug fixes
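To illustrate the entry above: a minimal usage sketch (not part of the commit) of the new `null_values_for_properties` argument. It assumes an in-memory `NWBFile` and SpikeInterface's `generate_recording` test helper; the property name `probe_depth` and the null value `-1` are illustrative only. The pattern mirrors the new tests further down this page.

```python
from pynwb.testing.mock.file import mock_NWBFile
from spikeinterface.core import generate_recording

from neuroconv.tools.spikeinterface import add_electrodes

nwbfile = mock_NWBFile()

# The first recording defines an integer property on the electrodes table.
recording1 = generate_recording(num_channels=2, durations=[1.0]).rename_channels(new_channel_ids=["a", "b"])
recording1.set_property(key="probe_depth", values=[100, 200])  # "probe_depth" is a hypothetical property
add_electrodes(recording=recording1, nwbfile=nwbfile)

# The second recording lacks that property. Integers have no built-in null value,
# so one must be supplied explicitly; otherwise add_electrodes raises a ValueError.
recording2 = generate_recording(num_channels=2, durations=[1.0]).rename_channels(new_channel_ids=["c", "d"])
add_electrodes(
    recording=recording2,
    nwbfile=nwbfile,
    null_values_for_properties={"probe_depth": -1},
)

print(list(nwbfile.electrodes["probe_depth"].data))  # expected: [100, 200, -1, -1]
```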
148 changes: 109 additions & 39 deletions src/neuroconv/tools/spikeinterface/spikeinterface.py
@@ -2,7 +2,7 @@
import warnings
from collections import defaultdict
from numbers import Real
from typing import List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Union

import numpy as np
import psutil
@@ -301,11 +301,69 @@ def _get_electrode_table_indices_for_recording(recording: BaseRecording, nwbfile
return electrode_table_indices


def _get_null_value_for_property(property: str, sample_data: Any, null_values_for_properties: dict[str, Any]) -> Any:
"""
Retrieve the null value for a given property based on its data type or a provided mapping.
Also performs type checking to ensure the default value matches the type of the existing data.
Parameters
----------
sample_data : Any
The sample data for which the default value is being determined. This can be of any data type.
null_values_for_properties : dict of str to Any
A dictionary mapping properties to their respective default values. If a property is not found in this
dictionary, a sensible default value based on the type of `sample_data` will be used.
Returns
-------
Any
The default value for the specified property. The type of the default value will match the type of `sample_data`
or the type specified in `null_values_for_properties`.
Raises
------
ValueError
If a sensible default value cannot be determined for the given property and data type, or if the type of the
provided default value does not match the type of the existing data.
"""

type_to_default_value = {list: [], np.ndarray: np.array(np.nan), str: "", float: np.nan, complex: np.nan}

# Convert numpy scalar types to their Python built-in equivalents
sample_data = sample_data.item() if isinstance(sample_data, np.generic) else sample_data

sample_data_type = type(sample_data)
default_value = null_values_for_properties.get(property, None)

if default_value is None:
default_value = type_to_default_value.get(sample_data_type, None)
if default_value is None:
error_msg = (
f"Could not find a sensible default value for property '{property}' of type {sample_data_type} \n"
"This can be fixed by by modifying the recording property or setting a sensible default value "
"using the `add_electrodes` function argument `null_values_for_properties` as in: \n"
"null_values_for_properties = {{property}': default_value}"
)
raise ValueError(error_msg)
if type(default_value) != sample_data_type:
error_msg = (
f"Default value for property '{property}' in null_values_for_properties dict has a "
f"different type {type(default_value)} than the currently existing data type {sample_data_type}. \n"
"Modify the recording property or the default value to match"
)
raise ValueError(error_msg)

return default_value
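Not part of the diff: a quick sketch of how the helper above resolves null values. Built-in fallbacks exist only for `list`, `np.ndarray`, `str`, `float`, and `complex`, so integer (or boolean) data needs an explicit entry in `null_values_for_properties`, and a mismatched type is rejected per the docstring. The property names are made up for the example.

```python
# Exercises the private helper defined above (import path taken from this file's location).
from neuroconv.tools.spikeinterface.spikeinterface import _get_null_value_for_property

# A type with a built-in fallback resolves without any user-provided mapping (str -> "").
assert _get_null_value_for_property(property="label", sample_data="shank0", null_values_for_properties={}) == ""

# Integers have no entry in type_to_default_value, so a ValueError is raised ...
try:
    _get_null_value_for_property(property="depth", sample_data=3, null_values_for_properties={})
except ValueError:
    print("no sensible default for int; pass one explicitly")

# ... unless the caller supplies a null value of the matching type.
assert _get_null_value_for_property(property="depth", sample_data=3, null_values_for_properties={"depth": -1}) == -1
```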


def add_electrodes(
recording: BaseRecording,
nwbfile: pynwb.NWBFile,
metadata: dict = None,
metadata: Optional[dict] = None,
exclude: tuple = (),
null_values_for_properties: Optional[dict] = None,
):
"""
Add channels from recording object as electrodes to nwbfile object.
@@ -341,11 +399,15 @@ def add_electrodes(
exclude: tuple
An iterable containing the string names of channel properties in the RecordingExtractor
object to ignore when writing to the NWBFile.
null_values_for_properties: dict
A dictionary mapping channel property names to the null values to use for electrodes that are missing
the property (for example, rows added by a previous call to this function).
"""
assert isinstance(
nwbfile, pynwb.NWBFile
), f"'nwbfile' should be of type pynwb.NWBFile but is of type {type(nwbfile)}"

null_values_for_properties = dict() if null_values_for_properties is None else null_values_for_properties

# Test that metadata has the expected structure
electrodes_metadata = list()
if metadata is not None:
@@ -391,10 +453,6 @@ add_electrodes(
if index and isinstance(data[0], np.ndarray):
index = data[0].ndim

# booleans are parsed as strings
if isinstance(data[0], (bool, np.bool_)):
data = data.astype(str)

# Fill with provided custom descriptions
description = property_descriptions.get(property, "no description")
data_to_add[property] = dict(description=description, data=data, index=index)
@@ -436,20 +494,24 @@ def add_electrodes(
data_to_add["group"] = dict(description="the ElectrodeGroup object", data=group_list, index=False)

schema_properties = {"group", "group_name", "location"}
properties_to_add = set(data_to_add)
electrode_table_previous_properties = set(nwbfile.electrodes.colnames) if nwbfile.electrodes else set()
extracted_properties = set(data_to_add)
properties_to_add_by_rows = schema_properties.union(electrode_table_previous_properties)
properties_to_add_by_columns = extracted_properties.difference(properties_to_add_by_rows)

# Properties that were added before but we are not adding now require default values
properties_to_default_value = dict()
type_to_default_value = {list: [], np.ndarray: np.array(np.nan), str: "", Real: np.nan}
for property in electrode_table_previous_properties.difference(data_to_add):
# Find a matching data type and get the default value
sample_data = nwbfile.electrodes[property].data[0]
matching_type = next(type for type in type_to_default_value if isinstance(sample_data, type))
default_value = type_to_default_value[matching_type]
properties_to_default_value[property] = default_value
# The schema properties are always added by rows because they are required
properties_to_add_by_rows = schema_properties.union(electrode_table_previous_properties)
properties_to_add_by_columns = properties_to_add.difference(properties_to_add_by_rows)

# Properties already in the table but missing from this recording need null values when adding rows
properties_requiring_null_values = electrode_table_previous_properties.difference(properties_to_add)
null_values_for_rows = dict()
for property in properties_requiring_null_values:
sample_data = nwbfile.electrodes[property][:][0]
null_value = _get_null_value_for_property(
property=property,
sample_data=sample_data,
null_values_for_properties=null_values_for_properties,
)
null_values_for_rows[property] = null_value

# We only add new electrodes to the table
existing_global_ids = _get_electrodes_table_global_ids(nwbfile=nwbfile)
@@ -458,12 +520,13 @@ def add_electrodes(

properties_with_data = properties_to_add_by_rows.intersection(data_to_add)
for channel_index in channel_indices_to_add:
electrode_kwargs = dict(properties_to_default_value)
electrode_kwargs = dict(null_values_for_rows)  # copy so the shared null-value dict is not mutated below
data_dict = {property: data_to_add[property]["data"][channel_index] for property in properties_with_data}
electrode_kwargs.update(**data_dict)
nwbfile.add_electrode(**electrode_kwargs, enforce_unique_id=True)

# Add channel_name as a column and fill previously existing rows with channel_name equal to str(ids)
# Add the channel_name column, since (channel_name, group_name) is used as a unique identifier
# Rows that already existed in the table get their electrode table id as the channel_name
electrode_table_size = len(nwbfile.electrodes.id[:])
previous_table_size = electrode_table_size - recording.get_num_channels()

@@ -478,40 +541,46 @@ def add_electrodes(
cols_args["data"] = extended_data
nwbfile.add_electrode_column("channel_name", **cols_args)

# To fill the new data, get their indices in the electrode table
all_indices = np.arange(electrode_table_size)
indices_for_new_data = _get_electrode_table_indices_for_recording(recording=recording, nwbfile=nwbfile)
indices_for_default_values = [index for index in all_indices if index not in indices_for_new_data]
indices_for_null_values = [index for index in all_indices if index not in indices_for_new_data]
extending_column = len(indices_for_null_values) > 0

# Add properties as columns
for property in properties_to_add_by_columns - {"channel_name"}:
cols_args = data_to_add[property]
data = cols_args["data"]
if np.issubdtype(data.dtype, np.integer):
data = data.astype("float")
default_value = np.nan

else: # Find first matching data-type for custom column
# Simple case: no pre-existing rows to back-fill, add the column directly and continue
if not extending_column:
nwbfile.add_electrode_column(property, **cols_args)
continue

adding_ragged_array = cols_args["index"]
if not adding_ragged_array:
sample_data = data[0]
matching_type = next(type for type in type_to_default_value if isinstance(sample_data, type))
default_value = type_to_default_value[matching_type]
dtype = data.dtype
extended_data = np.empty(shape=electrode_table_size, dtype=dtype)
extended_data[indices_for_new_data] = data

null_value = _get_null_value_for_property(
property=property,
sample_data=sample_data,
null_values_for_properties=null_values_for_properties,
)
extended_data[indices_for_null_values] = null_value
else:

if "index" in cols_args and cols_args["index"]:
dtype = np.ndarray
extended_data = np.empty(shape=len(nwbfile.electrodes.id[:]), dtype=dtype)
extended_data = np.empty(shape=electrode_table_size, dtype=dtype)
for index, value in enumerate(data):
index_in_extended_data = indices_for_new_data[index]
extended_data[index_in_extended_data] = value.tolist()

for index in indices_for_default_values:
default_value = []
extended_data[index] = default_value

else:
dtype = data.dtype
extended_data = np.empty(shape=len(nwbfile.electrodes.id[:]), dtype=dtype)
extended_data[indices_for_new_data] = data
extended_data[indices_for_default_values] = default_value
for index in indices_for_null_values:
null_value = []
extended_data[index] = null_value

cols_args["data"] = extended_data
nwbfile.add_electrode_column(property, **cols_args)
@@ -738,6 +807,7 @@ def add_electrical_series(

# Create a region for the electrodes table
electrode_table_indices = _get_electrode_table_indices_for_recording(recording=recording, nwbfile=nwbfile)
electrode_table_region = nwbfile.create_electrode_table_region(
region=electrode_table_indices,
description="electrode_table_region",
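One consequence of this file's changes worth calling out: with the boolean-to-string cast removed (see the removed `astype(str)` lines above and the updated assertion in the test diff below), boolean channel properties are now written to the electrodes table as real booleans. A hedged sketch, with the property name `is_reference` chosen only for illustration:

```python
import numpy as np
from pynwb.testing.mock.file import mock_NWBFile
from spikeinterface.core import generate_recording

from neuroconv.tools.spikeinterface import add_electrodes

nwbfile = mock_NWBFile()
recording = generate_recording(num_channels=2, durations=[1.0])
recording.set_property(key="is_reference", values=[True, False])  # hypothetical boolean property
add_electrodes(recording=recording, nwbfile=nwbfile)

values = nwbfile.electrodes["is_reference"][:]
# Before this commit these values were the strings "True"/"False".
assert all(isinstance(value, (bool, np.bool_)) for value in values)
```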
65 changes: 64 additions & 1 deletion tests/test_ecephys/test_tools_spikeinterface.py
@@ -552,7 +552,7 @@ def test_write_bool_properties(self):
nwbfile=self.nwbfile,
)
self.assertIn("test_bool", self.nwbfile.electrodes.colnames)
assert all(tb in ["False", "True"] for tb in self.nwbfile.electrodes["test_bool"][:])
assert all(tb in [False, True] for tb in self.nwbfile.electrodes["test_bool"][:])


class TestAddElectrodes(TestCase):
@@ -948,6 +948,69 @@ def test_property_metadata_mismatch(self):
expected_property_2_values = ["", "", "value_1", "value_2", "value_3", "value_4"]
self.assertListEqual(actual_property_2_values, expected_property_2_values)

def test_missing_int_values(self):

recording1 = generate_recording(num_channels=2, durations=[1.0])
recording1 = recording1.rename_channels(new_channel_ids=["a", "b"])
recording1.set_property(key="complete_int_property", values=[1, 2])
add_electrodes(recording=recording1, nwbfile=self.nwbfile)

expected_property = np.asarray([1, 2])
extracted_property = self.nwbfile.electrodes["complete_int_property"].data
assert np.array_equal(extracted_property, expected_property)

recording2 = generate_recording(num_channels=2, durations=[1.0])
recording2 = recording2.rename_channels(new_channel_ids=["c", "d"])

recording2.set_property(key="incomplete_int_property", values=[10, 11])
with self.assertRaises(ValueError):
add_electrodes(recording=recording2, nwbfile=self.nwbfile)

null_values_for_properties = {"complete_int_property": -1, "incomplete_int_property": -3}
add_electrodes(
recording=recording2, nwbfile=self.nwbfile, null_values_for_properties=null_values_for_properties
)

expected_complete_property = np.asarray([1, 2, -1, -1])
expected_incomplete_property = np.asarray([-3, -3, 10, 11])

extracted_complete_property = self.nwbfile.electrodes["complete_int_property"].data
extracted_incomplete_property = self.nwbfile.electrodes["incomplete_int_property"].data

assert np.array_equal(extracted_complete_property, expected_complete_property)
assert np.array_equal(extracted_incomplete_property, expected_incomplete_property)

def test_missing_bool_values(self):
recording1 = generate_recording(num_channels=2)
recording1 = recording1.rename_channels(new_channel_ids=["a", "b"])
recording1.set_property(key="complete_bool_property", values=[True, False])
add_electrodes(recording=recording1, nwbfile=self.nwbfile)

expected_property = np.asarray([True, False])
extracted_property = self.nwbfile.electrodes["complete_bool_property"].data.astype(bool)
assert np.array_equal(extracted_property, expected_property)

recording2 = generate_recording(num_channels=2)
recording2 = recording2.rename_channels(new_channel_ids=["c", "d"])

recording2.set_property(key="incomplete_bool_property", values=[True, False])
with self.assertRaises(ValueError):
add_electrodes(recording=recording2, nwbfile=self.nwbfile)

null_values_for_properties = {"complete_bool_property": False, "incomplete_bool_property": False}
add_electrodes(
recording=recording2, nwbfile=self.nwbfile, null_values_for_properties=null_values_for_properties
)

expected_complete_property = np.asarray([True, False, False, False])
expected_incomplete_property = np.asarray([False, False, True, False])

extracted_complete_property = self.nwbfile.electrodes["complete_bool_property"].data.astype(bool)
extracted_incomplete_property = self.nwbfile.electrodes["incomplete_bool_property"].data.astype(bool)

assert np.array_equal(extracted_complete_property, expected_complete_property)
assert np.array_equal(extracted_incomplete_property, expected_incomplete_property)


class TestAddUnitsTable(TestCase):
@classmethod
