Skip to content

Commit

Permalink
Add validation/function to Delay and Duration tags.
Browse files Browse the repository at this point in the history
Related minor cleanup/bug fixes/reorg
  • Loading branch information
IanCa committed Mar 19, 2024
1 parent 7286dd6 commit d4bea7c
Show file tree
Hide file tree
Showing 21 changed files with 425 additions and 243 deletions.
46 changes: 9 additions & 37 deletions hed/models/base_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import os

import openpyxl
import pandas
import pandas as pd

from hed.models.column_mapper import ColumnMapper
from hed.errors.exceptions import HedFileError, HedExceptions
import pandas as pd

from hed.models.df_util import _handle_curly_braces_refs
from hed.models.df_util import _handle_curly_braces_refs, filter_series_by_onset


class BaseInput:
Expand Down Expand Up @@ -118,37 +117,10 @@ def series_filtered(self):
"""Return the assembled dataframe as a series, with rows that have the same onset combined.
Returns:
Series: the assembled dataframe with columns merged, and the rows filtered together.
Series or None: the assembled dataframe with columns merged, and the rows filtered together.
"""
if self.onsets is not None:
indexed_dict = self._indexed_dict_from_onsets(self.onsets.astype(float))
return self._filter_by_index_list(self.series_a, indexed_dict=indexed_dict)

@staticmethod
def _indexed_dict_from_onsets(onsets):
current_onset = -1000000.0
tol = 1e-9
from collections import defaultdict
indexed_dict = defaultdict(list)
for i, onset in enumerate(onsets):
if abs(onset - current_onset) > tol:
current_onset = onset
indexed_dict[current_onset].append(i)

return indexed_dict

# This would need to store the index list -> So it can optionally apply to other columns on request.
@staticmethod
def _filter_by_index_list(original_series, indexed_dict):
new_series = pd.Series([""] * len(original_series), dtype=str)

for onset, indices in indexed_dict.items():
if indices:
first_index = indices[0] # Take the first index of each onset group
# Join the corresponding original series entries and place them at the first index
new_series[first_index] = ",".join([str(original_series[i]) for i in indices])

return new_series
return filter_series_by_onset(self.series_a, self.onsets)

@property
def onsets(self):
Expand All @@ -161,7 +133,7 @@ def needs_sorting(self):
"""Return True if this both has an onset column, and it needs sorting."""
onsets = self.onsets
if onsets is not None:
onsets = onsets.astype(float)
onsets = pd.to_numeric(self.dataframe['onset'], errors='coerce')
return not onsets.is_monotonic_increasing

@property
Expand Down Expand Up @@ -369,9 +341,9 @@ def _get_dataframe_from_worksheet(worksheet, has_headers):
# first row is columns
cols = next(data)
data = list(data)
return pandas.DataFrame(data, columns=cols, dtype=str)
return pd.DataFrame(data, columns=cols, dtype=str)
else:
return pandas.DataFrame(worksheet.values, dtype=str)
return pd.DataFrame(worksheet.values, dtype=str)

def validate(self, hed_schema, extra_def_dicts=None, name=None, error_handler=None):
"""Creates a SpreadsheetValidator and returns all issues with this file.
Expand Down Expand Up @@ -483,14 +455,14 @@ def _open_dataframe_file(self, file, has_column_names, input_type):
if not has_column_names:
pandas_header = None

if isinstance(file, pandas.DataFrame):
if isinstance(file, pd.DataFrame):
self._dataframe = file.astype(str)
self._has_column_names = self._dataframe_has_names(self._dataframe)
elif not file:
raise HedFileError(HedExceptions.FILE_NOT_FOUND, "Empty file passed to BaseInput.", file)
elif input_type in self.TEXT_EXTENSION:
try:
self._dataframe = pandas.read_csv(file, delimiter='\t', header=pandas_header,
self._dataframe = pd.read_csv(file, delimiter='\t', header=pandas_header,
dtype=str, keep_default_na=True, na_values=("", "null"))
except Exception as e:
raise HedFileError(HedExceptions.INVALID_FILE_FORMAT, str(e), self.name) from e
Expand Down
1 change: 0 additions & 1 deletion hed/models/definition_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(self, def_dicts=None, hed_schema=None):
"""

self.defs = {}
self._label_tag_name = DefTagNames.DEF_KEY
self._issues = []
if def_dicts:
self.add_definitions(def_dicts, hed_schema)
Expand Down
104 changes: 99 additions & 5 deletions hed/models/df_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from functools import partial
import pandas as pd
from hed.models.hed_string import HedString
from hed.models.model_constants import DefTagNames


def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded=True, return_filtered=False):
def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded=True):
""" Create an array of assembled HedString objects (or list of these) of the same length as tabular file input.
Parameters:
Expand All @@ -14,16 +15,14 @@ def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded=
extra_def_dicts: list of DefinitionDict, optional
Any extra DefinitionDict objects to use when parsing the HED tags.
defs_expanded (bool): (Default True) Expands definitions if True, otherwise shrinks them.
return_filtered (bool): If true, combines lines with the same onset.
Further lines with that onset are marked n/a
Returns:
tuple:
hed_strings(list of HedStrings): A list of HedStrings
def_dict(DefinitionDict): The definitions from this Sidecar.
"""

def_dict = tabular_file.get_def_dict(hed_schema, extra_def_dicts=extra_def_dicts)
series_a = tabular_file.series_a if not return_filtered else tabular_file.series_filtered
series_a = tabular_file.series_a
if defs_expanded:
return [HedString(x, hed_schema, def_dict).expand_defs() for x in series_a], def_dict
else:
Expand Down Expand Up @@ -217,7 +216,102 @@ def _handle_curly_braces_refs(df, refs, column_names):
# df[column_name] = pd.Series(x.replace(column_name_brackets, y) for x, y
# in zip(df[column_name], saved_columns[replacing_name]))
new_df[column_name] = pd.Series(replace_ref(x, y, replacing_name) for x, y
in zip(new_df[column_name], saved_columns[replacing_name]))
in zip(new_df[column_name], saved_columns[replacing_name]))
new_df = new_df[remaining_columns]

return new_df


# todo: Consider updating this to be a pure string function (or at least, only instantiating the Delay tags)
def split_delay_tags(series, hed_schema, onsets):
    """Move Delay groups onto their own rows so the onsets are in order after delay is applied.

    Parameters:
        series(pd.Series or None): the series of tags to split/sort
        hed_schema(HedSchema): The schema to use to identify tags
        onsets(pd.Series or None): the onset for each entry of series

    Returns:
        sorted_df(pd.Dataframe or None): None if series or onsets is None, otherwise a dataframe with 3 columns
            "HED": The hed strings(still str)
            "onset": the updated onsets
            "original_index": the original source line.  Multiple lines can have the same original source line.

    Note: This dataframe may be longer than the original series, but it will never be shorter.
    """
    if series is None or onsets is None:
        return None

    split_df = pd.DataFrame({"onset": onsets, "HED": series, "original_index": series.index})

    # Only strings that mention a delay tag are parsed into HedString objects.
    delay_strings = [(i, HedString(hed_string, hed_schema)) for (i, hed_string) in series.items()
                     if "delay/" in hed_string.lower()]

    for i, delay_string in delay_strings:
        delay_tags = delay_string.find_top_level_tags({DefTagNames.DELAY_KEY})
        to_remove = []
        for tag, group in delay_tags:
            # NOTE(review): value_as_default_unit() presumably yields the delay amount in the
            # same unit as the onset column - confirm against HedTag.
            onset_mod = tag.value_as_default_unit() + float(onsets[i])
            to_remove.append(group)
            # Append the delayed group as a new row that remembers its source line.
            insert_index = split_df.index.max() + 1
            split_df.loc[insert_index] = {'HED': str(group), 'onset': onset_mod, 'original_index': i}
        delay_string.remove(to_remove)
        # update the old string with the removals done
        split_df.at[i, "HED"] = str(delay_string)

    split_df = sort_dataframe_by_onsets(split_df)
    split_df.reset_index(drop=True, inplace=True)

    # Combine rows that now share an onset.
    split_df = filter_series_by_onset(split_df, split_df.onset)
    return split_df


def filter_series_by_onset(series, onsets):
    """Combine the rows of a series (or dataframe) that share an onset.

    Parameters:
        series(pd.Series or pd.Dataframe): the data to filter.  For a dataframe, the "HED" column is filtered.
        onsets(pd.Series): the onset column to group by; it is converted to float first.

    Returns:
        Series or Dataframe: the same kind of object that was passed in, with rows filtered together.
    """
    numeric_onsets = onsets.astype(float)
    onset_groups = _indexed_dict_from_onsets(numeric_onsets)
    return _filter_by_index_list(series, indexed_dict=onset_groups)


def _indexed_dict_from_onsets(onsets):
"""Finds series of consecutive lines with the same(or close enough) onset"""
current_onset = -1000000.0
tol = 1e-9
from collections import defaultdict
indexed_dict = defaultdict(list)
for i, onset in enumerate(onsets):
if abs(onset - current_onset) > tol:
current_onset = onset
indexed_dict[current_onset].append(i)

return indexed_dict


def _filter_by_index_list(original_data, indexed_dict):
"""Filters a series or dataframe by the indexed_dict, joining lines as indicated"""
if isinstance(original_data, pd.Series):
data_series = original_data
elif isinstance(original_data, pd.DataFrame):
data_series = original_data["HED"]
else:
raise TypeError("Input must be a pandas Series or DataFrame")

new_series = pd.Series([""] * len(data_series), dtype=str)
for onset, indices in indexed_dict.items():
if indices:
first_index = indices[0]
new_series[first_index] = ",".join([str(data_series[i]) for i in indices])

if isinstance(original_data, pd.Series):
return new_series
else:
result_df = original_data.copy()
result_df["HED"] = new_series
return result_df
12 changes: 7 additions & 5 deletions hed/models/hed_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" A single parenthesized HED string. """
from hed.models.hed_tag import HedTag
from hed.models.model_constants import DefTagNames
import copy
from typing import Iterable, Union

Expand Down Expand Up @@ -441,7 +442,7 @@ def find_tags(self, search_tags, recursive=False, include_groups=2):
tags = self.get_all_tags()
else:
tags = self.tags()

search_tags = {tag.lower() for tag in search_tags}
for tag in tags:
if tag.short_base_tag.lower() in search_tags:
found_tags.append((tag, tag._parent))
Expand All @@ -453,7 +454,7 @@ def find_tags(self, search_tags, recursive=False, include_groups=2):
def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2):
""" Find the tags and their containing groups.
This searches tag.short_tag, with an implicit wildcard on the end.
This searches tag.short_tag.lower(), with an implicit wildcard on the end.
e.g. "Eve" will find Event, but not Sensory-event.
Expand All @@ -474,6 +475,8 @@ def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2):
else:
tags = self.tags()

search_tags = {search_tag.lower() for search_tag in search_tags}

for tag in tags:
for search_tag in search_tags:
if tag.short_tag.lower().startswith(search_tag):
Expand Down Expand Up @@ -539,15 +542,14 @@ def find_def_tags(self, recursive=False, include_groups=3):

@staticmethod
def _get_def_tags_from_group(group):
from hed.models.definition_dict import DefTagNames
def_tags = []
for child in group.children:
if isinstance(child, HedTag):
if child.short_base_tag == DefTagNames.DEF_ORG_KEY:
if child.short_base_tag == DefTagNames.DEF_KEY:
def_tags.append((child, child, group))
else:
for tag in child.tags():
if tag.short_base_tag == DefTagNames.DEF_EXPAND_ORG_KEY:
if tag.short_base_tag == DefTagNames.DEF_EXPAND_KEY:
def_tags.append((tag, child, group))
return def_tags

Expand Down
26 changes: 2 additions & 24 deletions hed/models/hed_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def shrink_defs(self):
for def_expand_tag, def_expand_group in self.find_tags({DefTagNames.DEF_EXPAND_KEY}, recursive=True):
expanded_parent = def_expand_group._parent
if expanded_parent:
def_expand_tag.short_base_tag = DefTagNames.DEF_ORG_KEY
def_expand_tag.short_base_tag = DefTagNames.DEF_KEY
def_expand_tag._parent = expanded_parent
expanded_parent.replace(def_expand_group, def_expand_tag)

Expand Down Expand Up @@ -353,6 +353,7 @@ def find_top_level_tags(self, anchor_tags, include_groups=2):
Returns:
list: The returned result depends on include_groups.
"""
anchor_tags = {tag.lower() for tag in anchor_tags}
top_level_tags = []
for group in self.groups():
for tag in group.tags():
Expand All @@ -365,29 +366,6 @@ def find_top_level_tags(self, anchor_tags, include_groups=2):
return [tag[include_groups] for tag in top_level_tags]
return top_level_tags

def find_top_level_tags_grouped(self, anchor_tags):
    """ Find top level groups with an anchor tag.

        This is an alternate one designed to be easy to use with Delay/Duration tag.

    Parameters:
        anchor_tags (container): A list/set/etc. of short_base_tags to find groups by.
            Matching is case-insensitive.

    Returns:
        list of tuples:
            list of tags: the tags in the same subgroup
            group: the subgroup containing the tags
    """
    # Tags are compared in lower case below, so normalize the anchors the same way.
    anchor_tags = {anchor.lower() for anchor in anchor_tags}
    top_level_tags = []
    for group in self.groups():
        tags = [tag for tag in group.tags() if tag.short_base_tag.lower() in anchor_tags]
        # Groups without any anchor tag are left out of the result.
        if tags:
            top_level_tags.append((tags, group))

    return top_level_tags

def remove_refs(self):
""" Remove any refs(tags contained entirely inside curly braces) from the string.
Expand Down
8 changes: 4 additions & 4 deletions hed/models/hed_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, hed_string, hed_schema, span=None, def_dict=None):

self._def_entry = None
if def_dict:
if self.short_base_tag in {DefTagNames.DEF_ORG_KEY, DefTagNames.DEF_EXPAND_ORG_KEY}:
if self.short_base_tag in {DefTagNames.DEF_KEY, DefTagNames.DEF_EXPAND_KEY}:
self._def_entry = def_dict.get_definition_entry(self)

def copy(self):
Expand Down Expand Up @@ -277,7 +277,7 @@ def expandable(self):
self._parent = save_parent
if def_contents is not None:
self._expandable = def_contents
self._expanded = self.short_base_tag == DefTagNames.DEF_EXPAND_ORG_KEY
self._expanded = self.short_base_tag == DefTagNames.DEF_EXPAND_KEY
return self._expandable

def is_column_ref(self):
Expand Down Expand Up @@ -621,12 +621,12 @@ def __eq__(self, other):
return True

if isinstance(other, str):
return self.lower() == other
return self.lower() == other.lower()

if not isinstance(other, HedTag):
return False

if self.short_tag.lower() == other.short_tag.lower():
if self.short_tag == other.short_tag:
return True

if self.org_tag.lower() == other.org_tag.lower():
Expand Down
Loading

0 comments on commit d4bea7c

Please sign in to comment.