Commit
Drop casting to int in units table and introduce fine control over null values in `units_table` (#989)

Co-authored-by: Cody Baker <[email protected]>
Co-authored-by: Ben Dichter <[email protected]>
3 people authored Aug 14, 2024
1 parent b898587 commit 40d786a
Showing 4 changed files with 154 additions and 69 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -11,10 +11,14 @@
* Add Plexon2 support [PR #918](https://github.com/catalystneuro/neuroconv/pull/918)
* Converter working with multiple VideoInterface instances [PR #914](https://github.com/catalystneuro/neuroconv/pull/914)
* Added helper function `neuroconv.tools.data_transfers.submit_aws_batch_job` for basic automated submission of AWS batch jobs. [PR #384](https://github.com/catalystneuro/neuroconv/pull/384)
* Introduced `null_values_for_properties` to `add_units_table` to give users control over null-value behavior [PR #989](https://github.com/catalystneuro/neuroconv/pull/989)


### Bug fixes
* Fixed the default naming of multiple electrical series in the `SpikeGLXConverterPipe`. [PR #957](https://github.com/catalystneuro/neuroconv/pull/957)
* Writing new properties to the electrode table now uses the global identifiers `channel_name` and `group` [PR #984](https://github.com/catalystneuro/neuroconv/pull/984)
* Fixed a bug where int64 values were lossily cast to float [PR #989](https://github.com/catalystneuro/neuroconv/pull/989)


### Improvements
* The `OpenEphysBinaryRecordingInterface` now uses `lxml` for extracting the session start time from the settings.xml file and does not depend on `pyopenephys` anymore. [PR #971](https://github.com/catalystneuro/neuroconv/pull/971)
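For context, a minimal sketch of the new `null_values_for_properties` argument in action. The `add_units_table` call matches the diff below; the sorting setup, the `mock_NWBFile` helper, and the `amplitude_max` property are illustrative stand-ins, not part of this commit:

```python
import numpy as np
from pynwb.testing.mock.file import mock_NWBFile
from spikeinterface.core import generate_sorting

from neuroconv.tools.spikeinterface import add_units_table

# Hypothetical sorting with an integer-valued property
sorting = generate_sorting(num_units=4, durations=[1.0])
sorting.set_property("amplitude_max", np.array([10, 20, 30, 40], dtype="int64"))

nwbfile = mock_NWBFile()

# int64 data is no longer lossily cast to float, and the null value used when a
# property is missing from some rows can now be chosen explicitly per property
add_units_table(
    sorting=sorting,
    nwbfile=nwbfile,
    null_values_for_properties={"amplitude_max": -1},
)
```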
3 changes: 3 additions & 0 deletions src/neuroconv/tools/spikeinterface/__init__.py
@@ -6,9 +6,12 @@
add_recording,
add_sorting,
add_units_table,
add_sorting_analyzer,
check_if_recording_traces_fit_into_memory,
get_nwb_metadata,
write_recording,
write_sorting,
write_sorting_analyzer,
add_waveforms,
write_waveforms,
)
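With this change the analyzer helpers are part of the public namespace alongside the existing functions, so they can be imported directly:

```python
from neuroconv.tools.spikeinterface import (
    add_sorting_analyzer,
    write_sorting_analyzer,
)
```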
121 changes: 69 additions & 52 deletions src/neuroconv/tools/spikeinterface/spikeinterface.py
@@ -1,7 +1,6 @@
import uuid
import warnings
from collections import defaultdict
from numbers import Real
from typing import Any, List, Literal, Optional, Union

import numpy as np
@@ -399,8 +398,6 @@ def add_electrodes(
exclude: tuple
An iterable containing the string names of channel properties in the RecordingExtractor
object to ignore when writing to the NWBFile.
null_values_for_properties: dict
A dictionary mapping channel properties to null values to use when the property is not present
"""
assert isinstance(
nwbfile, pynwb.NWBFile
@@ -807,7 +804,6 @@ def add_electrical_series(

# Create a region for the electrodes table
electrode_table_indices = _get_electrode_table_indices_for_recording(recording=recording, nwbfile=nwbfile)
electrode_table_indices = _get_electrode_table_indices_for_recording(recording=recording, nwbfile=nwbfile)
electrode_table_region = nwbfile.create_electrode_table_region(
region=electrode_table_indices,
description="electrode_table_region",
@@ -1107,6 +1103,7 @@ def add_units_table(
waveform_means: Optional[np.ndarray] = None,
waveform_sds: Optional[np.ndarray] = None,
unit_electrode_indices=None,
null_values_for_properties: Optional[dict] = None,
):
"""
Add sorting data to a NWBFile object as a Units table.
@@ -1142,12 +1139,16 @@
Waveform standard deviation for each unit. Shape: (num_units, num_samples, num_channels).
unit_electrode_indices : list of lists of int, optional
For each unit, a list of electrode indices corresponding to waveform data.
null_values_for_properties : dict, optional
A dictionary mapping properties to the null values to use when a property is not present.
"""

assert isinstance(
nwbfile, pynwb.NWBFile
), f"'nwbfile' should be of type pynwb.NWBFile but is of type {type(nwbfile)}"

null_values_for_properties = dict() if null_values_for_properties is None else null_values_for_properties

if not write_in_processing_module and units_table_name != "units":
raise ValueError("When writing to the nwbfile.units table, the name of the table must be 'units'!")

@@ -1203,8 +1204,7 @@
# Extract properties
for property in properties_to_extract:
data = sorting.get_property(property)
if isinstance(data[0], (bool, np.bool_)):
data = data.astype(str)

index = isinstance(data[0], (list, np.ndarray, tuple))
if index and isinstance(data[0], np.ndarray):
index = data[0].ndim
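The removed lines above dropped the cast of boolean properties to strings; the ragged-column detection that remains works as in this small illustration (plain numpy, names are illustrative):

```python
import numpy as np

scalar_property = np.array([0.1, 0.2, 0.3])           # one value per unit
ragged_property = np.array([[1, 2], [3, 4], [5, 6]])  # one array per unit

# Mirrors the logic above: a ragged property gets an index column in the table
index = isinstance(ragged_property[0], (list, np.ndarray, tuple))
if index and isinstance(ragged_property[0], np.ndarray):
    index = ragged_property[0].ndim  # -> 1
```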
@@ -1222,20 +1222,25 @@
unit_name_array = unit_ids.astype("str", copy=False)
data_to_add["unit_name"].update(description="Unique reference for each unit.", data=unit_name_array)

units_table_previous_properties = set(units_table.colnames) - {"spike_times"}
extracted_properties = set(data_to_add)
properties_to_add_by_rows = units_table_previous_properties | {"id"}
properties_to_add_by_columns = extracted_properties - properties_to_add_by_rows

# Find default values for properties / columns already in the table
type_to_default_value = {list: [], np.ndarray: np.array(np.nan), str: "", Real: np.nan}
property_to_default_values = {"id": None}
for property in units_table_previous_properties:
# Find a matching data type and get the default value
sample_data = units_table[property].data[0]
matching_type = next(type for type in type_to_default_value if isinstance(sample_data, type))
default_value = type_to_default_value[matching_type]
property_to_default_values.update({property: default_value})
units_table_previous_properties = set(units_table.colnames).difference({"spike_times"})
properties_to_add = set(data_to_add)
properties_to_add_by_rows = units_table_previous_properties.union({"id"})
properties_to_add_by_columns = properties_to_add - properties_to_add_by_rows

# Properties that were added before require null values to add by rows if data is missing
properties_requiring_null_values = units_table_previous_properties.difference(properties_to_add)
null_values_for_row = {}
for property in properties_requiring_null_values - {"electrodes"}: # TODO, fix electrodes
sample_data = units_table[property][:][0]
null_value = _get_null_value_for_property(
property=property,
sample_data=sample_data,
null_values_for_properties=null_values_for_properties,
)
null_values_for_row[property] = null_value

# Special case
null_values_for_row["id"] = None

# Add data by rows excluding the rows with previously added unit names
unit_names_used_previously = []
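`_get_null_value_for_property` itself is defined in a part of the file that is not expanded here. A plausible sketch, inferred from its call sites and from the per-type defaults the old code used (not the actual implementation):

```python
import numpy as np
from numbers import Real

def _get_null_value_for_property(property: str, sample_data, null_values_for_properties: dict):
    """Sketch: prefer a user-supplied null value, fall back to a type-based default."""
    if property in null_values_for_properties:
        return null_values_for_properties[property]

    type_to_default_value = {list: [], np.ndarray: np.array(np.nan), str: "", Real: np.nan}
    matching_type = next(
        (type for type in type_to_default_value if isinstance(sample_data, type)), None
    )
    if matching_type is None:
        raise ValueError(f"No default null value found for property '{property}'")
    return type_to_default_value[matching_type]
```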
@@ -1259,7 +1264,7 @@
rows_to_add.append(index)

for row in rows_to_add:
unit_kwargs = dict(property_to_default_values)
unit_kwargs = null_values_for_row
for property in properties_with_data:
unit_kwargs[property] = data_to_add[property]["data"][row]
spike_times = []
@@ -1278,9 +1283,9 @@
if unit_electrode_indices is not None:
unit_kwargs["electrodes"] = unit_electrode_indices[row]
units_table.add_unit(spike_times=spike_times, **unit_kwargs, enforce_unique_id=True)
# added_unit_table_ids = units_table.id[-len(rows_to_add) :] # TODO - this line is unused?

# Add unit_name as a column and fill previously existing rows with unit_name equal to str(ids)
unit_table_size = len(units_table.id[:])
previous_table_size = len(units_table.id[:]) - len(unit_name_array)
if "unit_name" in properties_to_add_by_columns:
cols_args = data_to_add["unit_name"]
@@ -1299,41 +1304,47 @@
unit_name: table_df.query(f"unit_name=='{unit_name}'").index[0] for unit_name in unit_name_array
}

indexes_for_new_data = [unit_name_to_electrode_index[unit_name] for unit_name in unit_name_array]
indexes_for_default_values = table_df.index.difference(indexes_for_new_data).values
indices_for_new_data = [unit_name_to_electrode_index[unit_name] for unit_name in unit_name_array]
indices_for_null_values = table_df.index.difference(indices_for_new_data).values
extending_column = len(indices_for_null_values) > 0

# Add properties as columns
for property in properties_to_add_by_columns - {"unit_name"}:
cols_args = data_to_add[property]
data = cols_args["data"]
if np.issubdtype(data.dtype, np.integer):
data = data.astype("float")

# Find first matching data-type
sample_data = data[0]
matching_type = next(type for type in type_to_default_value if isinstance(sample_data, type))
default_value = type_to_default_value[matching_type]
# This is the simple case, early return
if not extending_column:
units_table.add_column(property, **cols_args)
continue

# Extending the columns is done differently for ragged arrays
adding_ragged_array = cols_args["index"]
if not adding_ragged_array:
sample_data = data[0]
dtype = data.dtype
extended_data = np.empty(shape=unit_table_size, dtype=dtype)
extended_data[indices_for_new_data] = data

null_value = _get_null_value_for_property(
property=property,
sample_data=sample_data,
null_values_for_properties=null_values_for_properties,
)
extended_data[indices_for_null_values] = null_value
else:

if "index" in cols_args and cols_args["index"]:
dtype = np.ndarray
extended_data = np.empty(shape=len(units_table.id[:]), dtype=dtype)
extended_data = np.empty(shape=unit_table_size, dtype=dtype)
for index, value in enumerate(data):
index_in_extended_data = indexes_for_new_data[index]
index_in_extended_data = indices_for_new_data[index]
extended_data[index_in_extended_data] = value.tolist()

for index in indexes_for_default_values:
default_value = []
extended_data[index] = default_value

else:
dtype = data.dtype
extended_data = np.empty(shape=len(units_table.id[:]), dtype=dtype)
extended_data[indexes_for_new_data] = data
extended_data[indexes_for_default_values] = default_value

if np.issubdtype(extended_data.dtype, np.object_):
extended_data = extended_data.astype("str", copy=False)
for index in indices_for_null_values:
null_value = []
extended_data[index] = null_value

# Add the data
cols_args["data"] = extended_data
units_table.add_column(property, **cols_args)

@@ -1561,6 +1572,7 @@ def add_sorting_analyzer(
The name of the units table. If write_as=='units', then units_name must also be 'units'.
units_description : str, default: 'Autogenerated by neuroconv.'
"""

# TODO: move into add_units
assert write_as in [
"units",
@@ -1669,7 +1681,7 @@ def write_sorting_analyzer(
Controls the unit_ids that will be written to the nwb file. If None (default), all
units are written.
write_electrical_series : bool, default: False
If True, the recording object associated to t is written as an electrical series.
If True, the recording object associated to the analyzer is written as an electrical series.
add_electrical_series_kwargs: dict, optional
Keyword arguments to control the `add_electrical_series()` function in case write_electrical_series=True
property_descriptions: dict, optional
@@ -1688,16 +1700,17 @@
"""
metadata = metadata if metadata is not None else dict()

if sorting_analyzer.has_recording():
recording = sorting_analyzer.recording
assert recording is not None, (
"recording not found. To add the electrode table, the sorting_analyzer "
"needs to have a recording attached or the 'recording' argument needs to be used."
)

# try:
with make_or_load_nwbfile(
nwbfile_path=nwbfile_path, nwbfile=nwbfile, metadata=metadata, overwrite=overwrite, verbose=verbose
) as nwbfile_out:
if sorting_analyzer.has_recording():
recording = sorting_analyzer.recording
assert recording is not None, (
"recording not found. To add the electrode table, the sorting_analyzer "
"needs to have a recording attached or the 'recording' argument needs to be used."
)

if write_electrical_series:
add_electrical_series_kwargs = add_electrical_series_kwargs or dict()
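A hedged end-to-end sketch of the corrected `write_electrical_series` path (the ground-truth generators are standard spikeinterface utilities used here only for illustration; the file name is arbitrary):

```python
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording

from neuroconv.tools.spikeinterface import write_sorting_analyzer

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=4)
sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording)

# With write_electrical_series=True, the recording attached to the analyzer is
# written as an ElectricalSeries alongside the units table
write_sorting_analyzer(
    sorting_analyzer,
    nwbfile_path="example.nwb",
    overwrite=True,
    write_electrical_series=True,
)
```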
@@ -1719,6 +1732,7 @@
)


# TODO: Remove February 2025
def write_waveforms(
waveform_extractor,
nwbfile_path: Optional[FilePathType] = None,
@@ -1746,6 +1760,9 @@ def write_waveforms(
)


# TODO: Remove February 2025


def add_waveforms(
waveform_extractor,
nwbfile: Optional[pynwb.NWBFile] = None,
