From d4bea7c8612208234ac5dcb247a88dde345dd836 Mon Sep 17 00:00:00 2001 From: IanCa Date: Mon, 18 Mar 2024 20:35:49 -0500 Subject: [PATCH] Add validation/function to Delay and Duration tags. Related minor cleanup/bug fixes/reorg --- hed/models/base_input.py | 46 +---- hed/models/definition_dict.py | 1 - hed/models/df_util.py | 104 +++++++++++- hed/models/hed_group.py | 12 +- hed/models/hed_string.py | 26 +-- hed/models/hed_tag.py | 8 +- hed/models/model_constants.py | 26 +-- hed/models/string_util.py | 2 +- hed/tools/analysis/event_manager.py | 23 +-- hed/tools/analysis/temporal_event.py | 7 +- hed/validator/def_validator.py | 13 +- hed/validator/hed_validator.py | 6 +- hed/validator/onset_validator.py | 34 +--- hed/validator/spreadsheet_validator.py | 62 ++++--- hed/validator/tag_util/group_util.py | 56 +++++- tests/models/test_base_input.py | 59 ------- tests/models/test_df_util.py | 159 +++++++++++++++++- tests/validator/test_onset_validator.py | 5 +- tests/validator/test_spreadsheet_validator.py | 3 +- tests/validator/test_tag_validator.py | 10 +- tests/validator/test_tag_validator_base.py | 6 +- 21 files changed, 425 insertions(+), 243 deletions(-) diff --git a/hed/models/base_input.py b/hed/models/base_input.py index 024f8e27..f77278ae 100644 --- a/hed/models/base_input.py +++ b/hed/models/base_input.py @@ -4,13 +4,12 @@ import os import openpyxl -import pandas +import pandas as pd from hed.models.column_mapper import ColumnMapper from hed.errors.exceptions import HedFileError, HedExceptions -import pandas as pd -from hed.models.df_util import _handle_curly_braces_refs +from hed.models.df_util import _handle_curly_braces_refs, filter_series_by_onset class BaseInput: @@ -118,37 +117,10 @@ def series_filtered(self): """Return the assembled dataframe as a series, with rows that have the same onset combined. Returns: - Series: the assembled dataframe with columns merged, and the rows filtered together. + Series or None: the assembled dataframe with columns merged, and the rows filtered together. """ if self.onsets is not None: - indexed_dict = self._indexed_dict_from_onsets(self.onsets.astype(float)) - return self._filter_by_index_list(self.series_a, indexed_dict=indexed_dict) - - @staticmethod - def _indexed_dict_from_onsets(onsets): - current_onset = -1000000.0 - tol = 1e-9 - from collections import defaultdict - indexed_dict = defaultdict(list) - for i, onset in enumerate(onsets): - if abs(onset - current_onset) > tol: - current_onset = onset - indexed_dict[current_onset].append(i) - - return indexed_dict - - # This would need to store the index list -> So it can optionally apply to other columns on request. - @staticmethod - def _filter_by_index_list(original_series, indexed_dict): - new_series = pd.Series([""] * len(original_series), dtype=str) - - for onset, indices in indexed_dict.items(): - if indices: - first_index = indices[0] # Take the first index of each onset group - # Join the corresponding original series entries and place them at the first index - new_series[first_index] = ",".join([str(original_series[i]) for i in indices]) - - return new_series + return filter_series_by_onset(self.series_a, self.onsets) @property def onsets(self): @@ -161,7 +133,7 @@ def needs_sorting(self): """Return True if this both has an onset column, and it needs sorting.""" onsets = self.onsets if onsets is not None: - onsets = onsets.astype(float) + onsets = pd.to_numeric(self.dataframe['onset'], errors='coerce') return not onsets.is_monotonic_increasing @property @@ -369,9 +341,9 @@ def _get_dataframe_from_worksheet(worksheet, has_headers): # first row is columns cols = next(data) data = list(data) - return pandas.DataFrame(data, columns=cols, dtype=str) + return pd.DataFrame(data, columns=cols, dtype=str) else: - return pandas.DataFrame(worksheet.values, dtype=str) + return pd.DataFrame(worksheet.values, dtype=str) def validate(self, hed_schema, extra_def_dicts=None, name=None, error_handler=None): """Creates a SpreadsheetValidator and returns all issues with this file. @@ -483,14 +455,14 @@ def _open_dataframe_file(self, file, has_column_names, input_type): if not has_column_names: pandas_header = None - if isinstance(file, pandas.DataFrame): + if isinstance(file, pd.DataFrame): self._dataframe = file.astype(str) self._has_column_names = self._dataframe_has_names(self._dataframe) elif not file: raise HedFileError(HedExceptions.FILE_NOT_FOUND, "Empty file passed to BaseInput.", file) elif input_type in self.TEXT_EXTENSION: try: - self._dataframe = pandas.read_csv(file, delimiter='\t', header=pandas_header, + self._dataframe = pd.read_csv(file, delimiter='\t', header=pandas_header, dtype=str, keep_default_na=True, na_values=("", "null")) except Exception as e: raise HedFileError(HedExceptions.INVALID_FILE_FORMAT, str(e), self.name) from e diff --git a/hed/models/definition_dict.py b/hed/models/definition_dict.py index 234cd5b9..60424be9 100644 --- a/hed/models/definition_dict.py +++ b/hed/models/definition_dict.py @@ -23,7 +23,6 @@ def __init__(self, def_dicts=None, hed_schema=None): """ self.defs = {} - self._label_tag_name = DefTagNames.DEF_KEY self._issues = [] if def_dicts: self.add_definitions(def_dicts, hed_schema) diff --git a/hed/models/df_util.py b/hed/models/df_util.py index 0364c539..f68e5791 100644 --- a/hed/models/df_util.py +++ b/hed/models/df_util.py @@ -3,9 +3,10 @@ from functools import partial import pandas as pd from hed.models.hed_string import HedString +from hed.models.model_constants import DefTagNames -def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded=True, return_filtered=False): +def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded=True): """ Create an array of assembled HedString objects (or list of these) of the same length as tabular file input. Parameters: @@ -14,8 +15,6 @@ def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded= extra_def_dicts: list of DefinitionDict, optional Any extra DefinitionDict objects to use when parsing the HED tags. defs_expanded (bool): (Default True) Expands definitions if True, otherwise shrinks them. - return_filtered (bool): If true, combines lines with the same onset. - Further lines with that onset are marked n/a Returns: tuple: hed_strings(list of HedStrings): A list of HedStrings @@ -23,7 +22,7 @@ def get_assembled(tabular_file, hed_schema, extra_def_dicts=None, defs_expanded= """ def_dict = tabular_file.get_def_dict(hed_schema, extra_def_dicts=extra_def_dicts) - series_a = tabular_file.series_a if not return_filtered else tabular_file.series_filtered + series_a = tabular_file.series_a if defs_expanded: return [HedString(x, hed_schema, def_dict).expand_defs() for x in series_a], def_dict else: @@ -217,7 +216,102 @@ def _handle_curly_braces_refs(df, refs, column_names): # df[column_name] = pd.Series(x.replace(column_name_brackets, y) for x, y # in zip(df[column_name], saved_columns[replacing_name])) new_df[column_name] = pd.Series(replace_ref(x, y, replacing_name) for x, y - in zip(new_df[column_name], saved_columns[replacing_name])) + in zip(new_df[column_name], saved_columns[replacing_name])) new_df = new_df[remaining_columns] return new_df + + +# todo: Consider updating this to be a pure string function(or at least, only instantiating the Duration tags) +def split_delay_tags(series, hed_schema, onsets): + """Sorts the series based on Delay tags, so that the onsets are in order after delay is applied. + + Parameters: + series(pd.Series or None): the series of tags to split/sort + hed_schema(HedSchema): The schema to use to identify tags + onsets(pd.Series or None) + + Returns: + sorted_df(pd.Dataframe or None): If we had onsets, a dataframe with 3 columns + "HED": The hed strings(still str) + "onset": the updated onsets + "original_index": the original source line. Multiple lines can have the same original source line. + + Note: This dataframe may be longer than the original series, but it will never be shorter. + """ + if series is None or onsets is None: + return + split_df = pd.DataFrame({"onset": onsets, "HED": series, "original_index": series.index}) + delay_strings = [(i, HedString(hed_string, hed_schema)) for (i, hed_string) in series.items() if + "delay/" in hed_string.lower()] + delay_groups = [] + for i, delay_string in delay_strings: + duration_tags = delay_string.find_top_level_tags({DefTagNames.DELAY_KEY}) + to_remove = [] + for tag, group in duration_tags: + onset_mod = tag.value_as_default_unit() + float(onsets[i]) + to_remove.append(group) + insert_index = split_df['original_index'].index.max() + 1 + split_df.loc[insert_index] = {'HED': str(group), 'onset': onset_mod, 'original_index': i} + delay_string.remove(to_remove) + # update the old string with the removals done + split_df.at[i, "HED"] = str(delay_string) + + for i, onset_mod, group in delay_groups: + insert_index = split_df['original_index'].index.max() + 1 + split_df.loc[insert_index] = {'HED': str(group), 'onset': onset_mod, 'original_index': i} + split_df = sort_dataframe_by_onsets(split_df) + split_df.reset_index(drop=True, inplace=True) + + split_df = filter_series_by_onset(split_df, split_df.onset) + return split_df + + +def filter_series_by_onset(series, onsets): + """Return the series, with rows that have the same onset combined. + + Parameters: + series(pd.Series or pd.Dataframe): the series to filter. If dataframe, it filters the "HED" column + onsets(pd.Series): the onset column to filter by + Returns: + Series or Dataframe: the series with rows filtered together. + """ + indexed_dict = _indexed_dict_from_onsets(onsets.astype(float)) + return _filter_by_index_list(series, indexed_dict=indexed_dict) + + +def _indexed_dict_from_onsets(onsets): + """Finds series of consecutive lines with the same(or close enough) onset""" + current_onset = -1000000.0 + tol = 1e-9 + from collections import defaultdict + indexed_dict = defaultdict(list) + for i, onset in enumerate(onsets): + if abs(onset - current_onset) > tol: + current_onset = onset + indexed_dict[current_onset].append(i) + + return indexed_dict + + +def _filter_by_index_list(original_data, indexed_dict): + """Filters a series or dataframe by the indexed_dict, joining lines as indicated""" + if isinstance(original_data, pd.Series): + data_series = original_data + elif isinstance(original_data, pd.DataFrame): + data_series = original_data["HED"] + else: + raise TypeError("Input must be a pandas Series or DataFrame") + + new_series = pd.Series([""] * len(data_series), dtype=str) + for onset, indices in indexed_dict.items(): + if indices: + first_index = indices[0] + new_series[first_index] = ",".join([str(data_series[i]) for i in indices]) + + if isinstance(original_data, pd.Series): + return new_series + else: + result_df = original_data.copy() + result_df["HED"] = new_series + return result_df diff --git a/hed/models/hed_group.py b/hed/models/hed_group.py index 0a88f56d..842f6369 100644 --- a/hed/models/hed_group.py +++ b/hed/models/hed_group.py @@ -1,5 +1,6 @@ """ A single parenthesized HED string. """ from hed.models.hed_tag import HedTag +from hed.models.model_constants import DefTagNames import copy from typing import Iterable, Union @@ -441,7 +442,7 @@ def find_tags(self, search_tags, recursive=False, include_groups=2): tags = self.get_all_tags() else: tags = self.tags() - + search_tags = {tag.lower() for tag in search_tags} for tag in tags: if tag.short_base_tag.lower() in search_tags: found_tags.append((tag, tag._parent)) @@ -453,7 +454,7 @@ def find_tags(self, search_tags, recursive=False, include_groups=2): def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2): """ Find the tags and their containing groups. - This searches tag.short_tag, with an implicit wildcard on the end. + This searches tag.short_tag.lower(), with an implicit wildcard on the end. e.g. "Eve" will find Event, but not Sensory-event. @@ -474,6 +475,8 @@ def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2): else: tags = self.tags() + search_tags = {search_tag.lower() for search_tag in search_tags} + for tag in tags: for search_tag in search_tags: if tag.short_tag.lower().startswith(search_tag): @@ -539,15 +542,14 @@ def find_def_tags(self, recursive=False, include_groups=3): @staticmethod def _get_def_tags_from_group(group): - from hed.models.definition_dict import DefTagNames def_tags = [] for child in group.children: if isinstance(child, HedTag): - if child.short_base_tag == DefTagNames.DEF_ORG_KEY: + if child.short_base_tag == DefTagNames.DEF_KEY: def_tags.append((child, child, group)) else: for tag in child.tags(): - if tag.short_base_tag == DefTagNames.DEF_EXPAND_ORG_KEY: + if tag.short_base_tag == DefTagNames.DEF_EXPAND_KEY: def_tags.append((tag, child, group)) return def_tags diff --git a/hed/models/hed_string.py b/hed/models/hed_string.py index a15600a3..9af387c3 100644 --- a/hed/models/hed_string.py +++ b/hed/models/hed_string.py @@ -129,7 +129,7 @@ def shrink_defs(self): for def_expand_tag, def_expand_group in self.find_tags({DefTagNames.DEF_EXPAND_KEY}, recursive=True): expanded_parent = def_expand_group._parent if expanded_parent: - def_expand_tag.short_base_tag = DefTagNames.DEF_ORG_KEY + def_expand_tag.short_base_tag = DefTagNames.DEF_KEY def_expand_tag._parent = expanded_parent expanded_parent.replace(def_expand_group, def_expand_tag) @@ -353,6 +353,7 @@ def find_top_level_tags(self, anchor_tags, include_groups=2): Returns: list: The returned result depends on include_groups. """ + anchor_tags = {tag.lower() for tag in anchor_tags} top_level_tags = [] for group in self.groups(): for tag in group.tags(): @@ -365,29 +366,6 @@ def find_top_level_tags(self, anchor_tags, include_groups=2): return [tag[include_groups] for tag in top_level_tags] return top_level_tags - def find_top_level_tags_grouped(self, anchor_tags): - """ Find top level groups with an anchor tag. - - This is an alternate one designed to be easy to use with Delay/Duration tag. - - Parameters: - anchor_tags (container): A list/set/etc. of short_base_tags to find groups by. - Returns: - list of tuples: - list of tags: the tags in the same subgroup - group: the subgroup containing the tags - """ - top_level_tags = [] - for group in self.groups(): - tags = [] - for tag in group.tags(): - if tag.short_base_tag.lower() in anchor_tags: - tags.append(tag) - if tags: - top_level_tags.append((tags, group)) - - return top_level_tags - def remove_refs(self): """ Remove any refs(tags contained entirely inside curly braces) from the string. diff --git a/hed/models/hed_tag.py b/hed/models/hed_tag.py index 63808bd8..83bd9959 100644 --- a/hed/models/hed_tag.py +++ b/hed/models/hed_tag.py @@ -49,7 +49,7 @@ def __init__(self, hed_string, hed_schema, span=None, def_dict=None): self._def_entry = None if def_dict: - if self.short_base_tag in {DefTagNames.DEF_ORG_KEY, DefTagNames.DEF_EXPAND_ORG_KEY}: + if self.short_base_tag in {DefTagNames.DEF_KEY, DefTagNames.DEF_EXPAND_KEY}: self._def_entry = def_dict.get_definition_entry(self) def copy(self): @@ -277,7 +277,7 @@ def expandable(self): self._parent = save_parent if def_contents is not None: self._expandable = def_contents - self._expanded = self.short_base_tag == DefTagNames.DEF_EXPAND_ORG_KEY + self._expanded = self.short_base_tag == DefTagNames.DEF_EXPAND_KEY return self._expandable def is_column_ref(self): @@ -621,12 +621,12 @@ def __eq__(self, other): return True if isinstance(other, str): - return self.lower() == other + return self.lower() == other.lower() if not isinstance(other, HedTag): return False - if self.short_tag.lower() == other.short_tag.lower(): + if self.short_tag == other.short_tag: return True if self.org_tag.lower() == other.org_tag.lower(): diff --git a/hed/models/model_constants.py b/hed/models/model_constants.py index 06317cd0..d8636a1e 100644 --- a/hed/models/model_constants.py +++ b/hed/models/model_constants.py @@ -2,25 +2,15 @@ class DefTagNames: """ Source names for definitions, def labels, and expanded labels. """ - DEF_ORG_KEY = 'Def' - DEF_EXPAND_ORG_KEY = 'Def-expand' - DEFINITION_ORG_KEY = "Definition" - DEF_KEY = DEF_ORG_KEY.lower() - DEF_EXPAND_KEY = DEF_EXPAND_ORG_KEY.lower() - DEFINITION_KEY = DEFINITION_ORG_KEY.lower() - DEF_KEYS = (DEF_KEY, DEF_EXPAND_KEY) + DEF_KEY = 'Def' + DEF_EXPAND_KEY = 'Def-expand' + DEFINITION_KEY = "Definition" - ONSET_ORG_KEY = "Onset" - OFFSET_ORG_KEY = "Offset" - INSET_ORG_KEY = "Inset" - DURATION_ORG_KEY = "Duration" - DELAY_ORG_KEY = "Delay" - - ONSET_KEY = ONSET_ORG_KEY.lower() - OFFSET_KEY = OFFSET_ORG_KEY.lower() - INSET_KEY = INSET_ORG_KEY.lower() - DURATION_KEY = DURATION_ORG_KEY.lower() - DELAY_KEY = DELAY_ORG_KEY.lower() + ONSET_KEY = "Onset" + OFFSET_KEY = "Offset" + INSET_KEY = "Inset" + DURATION_KEY = "Duration" + DELAY_KEY = "Delay" TEMPORAL_KEYS = {ONSET_KEY, OFFSET_KEY, INSET_KEY} DURATION_KEYS = {DURATION_KEY, DELAY_KEY} diff --git a/hed/models/string_util.py b/hed/models/string_util.py index 30916934..73242490 100644 --- a/hed/models/string_util.py +++ b/hed/models/string_util.py @@ -15,7 +15,7 @@ def gather_descriptions(hed_string): The input HedString has its description tags removed. """ - desc_tags = hed_string.find_tags("description", recursive=True, include_groups=0) + desc_tags = hed_string.find_tags({"description"}, recursive=True, include_groups=0) desc_string = " ".join([tag.extension if tag.extension.endswith(".") else tag.extension + "." for tag in desc_tags]) hed_string.remove(desc_tags) diff --git a/hed/tools/analysis/event_manager.py b/hed/tools/analysis/event_manager.py index fb7800e6..645ff450 100644 --- a/hed/tools/analysis/event_manager.py +++ b/hed/tools/analysis/event_manager.py @@ -5,7 +5,7 @@ from hed.errors import HedFileError from hed.models import HedString from hed.models.model_constants import DefTagNames -from hed.models.df_util import get_assembled +from hed.models import df_util from hed.models.string_util import split_base_tags, split_def_tags from hed.tools.analysis.temporal_event import TemporalEvent from hed.tools.analysis.hed_type_defs import HedTypeDefs @@ -29,16 +29,14 @@ def __init__(self, input_data, hed_schema, extra_defs=None): are separated from the rest of the annotations, which are contained in self.hed_strings. """ - - self.event_list = [[] for _ in range(len(input_data.dataframe))] self.hed_schema = hed_schema self.input_data = input_data self.def_dict = input_data.get_def_dict(hed_schema, extra_def_dicts=extra_defs) - onsets = pd.to_numeric(input_data.dataframe['onset'], errors='coerce') - if not onsets.is_monotonic_increasing: + if self.input_data.needs_sorting: raise HedFileError("OnsetsNotOrdered", "The onset values must be non-decreasing", "") - self.onsets = onsets.tolist() - self.hed_strings = None # Remaining HED strings copy.deepcopy(hed_strings) + self.onsets = None + self.hed_strings = None + self.event_list = None self._create_event_list(input_data) def _create_event_list(self, input_data): @@ -53,8 +51,13 @@ def _create_event_list(self, input_data): Notes: """ - hed_strings, def_dict = get_assembled(input_data, self.hed_schema, extra_def_dicts=None, defs_expanded=False, - return_filtered=True) + hed_strings = input_data.series_a + df_util.shrink_defs(hed_strings, self.hed_schema) + delay_df = df_util.split_delay_tags(hed_strings, self.hed_schema, input_data.onsets) + + hed_strings = [HedString(hed_string, self.hed_schema) for hed_string in delay_df.HED] + self.onsets = pd.to_numeric(delay_df.onset, errors='coerce') + self.event_list = [[] for _ in range(len(hed_strings))] onset_dict = {} # Temporary dictionary keeping track of temporal events that haven't ended yet. for event_index, hed in enumerate(hed_strings): self._extract_temporal_events(hed, event_index, onset_dict) @@ -99,7 +102,7 @@ def _extract_temporal_events(self, hed, event_index, onset_dict): for def_tag, group in group_tuples: anchor_tag = group.find_def_tags(recursive=False, include_groups=0)[0] anchor = anchor_tag.extension.lower() - if anchor in onset_dict or def_tag.short_base_tag.lower() == DefTagNames.OFFSET_KEY: + if anchor in onset_dict or def_tag.short_base_tag == DefTagNames.OFFSET_KEY: temporal_event = onset_dict.pop(anchor) temporal_event.set_end(event_index, self.onsets[event_index]) if def_tag == DefTagNames.ONSET_KEY: diff --git a/hed/tools/analysis/temporal_event.py b/hed/tools/analysis/temporal_event.py index a514b511..f32f632f 100644 --- a/hed/tools/analysis/temporal_event.py +++ b/hed/tools/analysis/temporal_event.py @@ -1,5 +1,6 @@ """ A single event process with starting and ending times. """ from hed.models import HedGroup +from hed.models.model_constants import DefTagNames class TemporalEvent: @@ -36,12 +37,12 @@ def _split_group(self, contents): for item in contents.children: if isinstance(item, HedGroup): self.internal_group = item - elif item.short_base_tag.lower() == "onset": + elif item.short_base_tag == DefTagNames.ONSET_KEY: to_remove.append(item) - elif item.short_base_tag.lower() == "duration": + elif item.short_base_tag == DefTagNames.DURATION_KEY: to_remove.append(item) self.end_time = self.start_time + item.value_as_default_unit() - elif item.short_base_tag.lower() == "def": + elif item.short_base_tag == DefTagNames.DEF_KEY: self.anchor = item.short_tag contents.remove(to_remove) if self.internal_group: diff --git a/hed/validator/def_validator.py b/hed/validator/def_validator.py index b6fbe4aa..953a5f92 100644 --- a/hed/validator/def_validator.py +++ b/hed/validator/def_validator.py @@ -1,4 +1,5 @@ from hed.models.hed_group import HedGroup +from hed.models.hed_tag import HedTag from hed.models.definition_dict import DefinitionDict from hed.errors.error_types import ValidationErrors from hed.errors.error_reporter import ErrorHandler @@ -29,10 +30,6 @@ def validate_def_tags(self, hed_string_obj, hed_validator=None): Returns: list: Issues found related to validating defs. Each issue is a dictionary. """ - hed_string_lower = hed_string_obj.lower() - if self._label_tag_name not in hed_string_lower: - return [] - # This is needed primarily to validate the contents of a def-expand matches the default. def_issues = [] # We need to check for labels to expand in ALL groups @@ -104,7 +101,7 @@ def _validate_def_contents(self, def_tag, def_expand_group, hed_validator): def validate_def_value_units(self, def_tag, hed_validator): """Equivalent to HedValidator.validate_units for the special case of a Def or Def-expand tag""" tag_label, _, placeholder = def_tag.extension.partition('/') - is_def_expand_tag = def_tag.short_base_tag == DefTagNames.DEF_EXPAND_ORG_KEY + is_def_expand_tag = def_tag.short_base_tag == DefTagNames.DEF_EXPAND_KEY def_entry = self.defs.get(tag_label.lower()) # These errors will be caught as can't match definition @@ -167,8 +164,12 @@ def validate_onset_offset(self, hed_string_obj): def_group = def_tag children = [child for child in found_group.children if def_group is not child and found_onset is not child] + + # Delay tag is checked for uniqueness elsewhere, so we can safely remove all of them + children = [child for child in children + if not isinstance(child, HedTag) or child.short_base_tag != DefTagNames.DELAY_KEY] max_children = 1 - if found_onset.short_base_tag == DefTagNames.OFFSET_ORG_KEY: + if found_onset.short_base_tag == DefTagNames.OFFSET_KEY: max_children = 0 if len(children) > max_children: onset_issues += ErrorHandler.format_error(TemporalErrors.ONSET_WRONG_NUMBER_GROUPS, diff --git a/hed/validator/hed_validator.py b/hed/validator/hed_validator.py index fe4ddc11..ce21e71e 100644 --- a/hed/validator/hed_validator.py +++ b/hed/validator/hed_validator.py @@ -195,7 +195,7 @@ def _validate_individual_tags_in_hed_string(self, hed_string_obj, allow_placehol for group in hed_string_obj.get_all_groups(): is_definition = group in all_definition_groups for hed_tag in group.tags(): - if not self._definitions_allowed and hed_tag.short_base_tag == DefTagNames.DEFINITION_ORG_KEY: + if not self._definitions_allowed and hed_tag.short_base_tag == DefTagNames.DEFINITION_KEY: validation_issues += ErrorHandler.format_error(DefinitionErrors.BAD_DEFINITION_LOCATION, hed_tag) # todo: unclear if this should be restored at some point # if hed_tag.expandable and not hed_tag.expanded: @@ -208,8 +208,8 @@ def _validate_individual_tags_in_hed_string(self, hed_string_obj, allow_placehol run_individual_tag_validators(hed_tag, allow_placeholders=allow_placeholders, is_definition=is_definition) - if (hed_tag.short_base_tag == DefTagNames.DEF_ORG_KEY or - hed_tag.short_base_tag == DefTagNames.DEF_EXPAND_ORG_KEY): + if (hed_tag.short_base_tag == DefTagNames.DEF_KEY or + hed_tag.short_base_tag == DefTagNames.DEF_EXPAND_KEY): validation_issues += self._def_validator.validate_def_value_units(hed_tag, self) else: validation_issues += self.validate_units(hed_tag) diff --git a/hed/validator/onset_validator.py b/hed/validator/onset_validator.py index 6fc9ca56..105090c6 100644 --- a/hed/validator/onset_validator.py +++ b/hed/validator/onset_validator.py @@ -42,42 +42,14 @@ def validate_temporal_relations(self, hed_string_obj): return onset_issues - def validate_duration_tags(self, hed_string_obj): - """ Validate Duration/Delay tag groups - - Parameters: - hed_string_obj (HedString): The hed string to check. - - Returns: - list: A list of issues found in validating durations (i.e., extra tags or groups present, or a group missing) - """ - duration_issues = [] - for tags, group in hed_string_obj.find_top_level_tags_grouped(anchor_tags=DefTagNames.DURATION_KEYS): - # This implicitly validates the duration/delay tag, as they're the only two allowed in the same group - # It should be impossible to have > 2 tags, but it's a good stopgap. - if len(tags) != len(group.tags()) or len(group.tags()) > 2: - for tag in group.tags(): - if tag not in tags: - duration_issues += ErrorHandler.format_error(TemporalErrors.DURATION_HAS_OTHER_TAGS, tag=tag) - continue - if len(group.groups()) != 1: - duration_issues += ErrorHandler.format_error(TemporalErrors.DURATION_WRONG_NUMBER_GROUPS, - tags[0], - hed_string_obj.groups()) - continue - - # Does anything else need verification here? - # That duration is positive? - return duration_issues - def _handle_onset_or_offset(self, def_tag, onset_offset_tag): - is_onset = onset_offset_tag.short_base_tag == DefTagNames.ONSET_ORG_KEY + is_onset = onset_offset_tag.short_base_tag == DefTagNames.ONSET_KEY full_def_name = def_tag.extension if is_onset: # onset can never fail as it implies an offset self._onsets[full_def_name.lower()] = full_def_name else: - is_offset = onset_offset_tag.short_base_tag == DefTagNames.OFFSET_ORG_KEY + is_offset = onset_offset_tag.short_base_tag == DefTagNames.OFFSET_KEY if full_def_name.lower() not in self._onsets: if is_offset: return ErrorHandler.format_error(TemporalErrors.OFFSET_BEFORE_ONSET, tag=def_tag) @@ -101,6 +73,6 @@ def check_for_banned_tags(hed_string): banned_tag_list = DefTagNames.ALL_TIME_KEYS issues = [] for tag in hed_string.get_all_tags(): - if tag.short_base_tag.lower() in banned_tag_list: + if tag.short_base_tag in banned_tag_list: issues += ErrorHandler.format_error(TemporalErrors.HED_ONSET_WITH_NO_COLUMN, tag) return issues diff --git a/hed/validator/spreadsheet_validator.py b/hed/validator/spreadsheet_validator.py index 405c6aa7..1e07af11 100644 --- a/hed/validator/spreadsheet_validator.py +++ b/hed/validator/spreadsheet_validator.py @@ -8,7 +8,7 @@ from hed.errors.error_reporter import sort_issues, check_for_any_errors from hed.validator.onset_validator import OnsetValidator from hed.validator.hed_validator import HedValidator -from hed.models.df_util import sort_dataframe_by_onsets +from hed.models.df_util import sort_dataframe_by_onsets, split_delay_tags PANDAS_COLUMN_PREFIX_TO_IGNORE = "Unnamed: " @@ -25,6 +25,7 @@ def __init__(self, hed_schema): self._schema = hed_schema self._hed_validator = None self._onset_validator = None + self.invalid_original_rows = set() def validate(self, data, def_dicts=None, name=None, error_handler=None): """ @@ -46,6 +47,8 @@ def validate(self, data, def_dicts=None, name=None, error_handler=None): if not isinstance(data, BaseInput): raise TypeError("Invalid type passed to spreadsheet validator. Can only validate BaseInput objects.") + self.invalid_original_rows = set() + error_handler.push_error_context(ErrorContext.FILE_NAME, name) # Adjust to account for 1 based row_adj = 1 @@ -59,7 +62,8 @@ def validate(self, data, def_dicts=None, name=None, error_handler=None): data_new._dataframe = sort_dataframe_by_onsets(data.dataframe) issues += error_handler.format_error_with_context(ValidationErrors.ONSETS_OUT_OF_ORDER) data = data_new - onset_filtered = data.series_filtered + + onsets = split_delay_tags(data.series_a, self._schema, data.onsets) df = data.dataframe_a self._hed_validator = HedValidator(self._schema, def_dicts=def_dicts) @@ -69,15 +73,18 @@ def validate(self, data, def_dicts=None, name=None, error_handler=None): self._onset_validator = None # Check the rows of the input data - issues += self._run_checks(df, onset_filtered, error_handler=error_handler, row_adj=row_adj) + issues += self._run_checks(df, error_handler=error_handler, row_adj=row_adj, has_onsets=bool(self._onset_validator)) + if self._onset_validator: + issues += self._run_onset_checks(onsets, error_handler=error_handler, row_adj=row_adj) error_handler.pop_error_context() issues = sort_issues(issues) return issues - def _run_checks(self, hed_df, onset_filtered, error_handler, row_adj): + def _run_checks(self, hed_df, error_handler, row_adj, has_onsets): issues = [] columns = list(hed_df.columns) + self.invalid_original_rows = set() for row_number, text_file_row in hed_df.iterrows(): error_handler.push_error_context(ErrorContext.ROW, row_number + row_adj) row_strings = [] @@ -94,32 +101,49 @@ def _run_checks(self, hed_df, onset_filtered, error_handler, row_adj): new_column_issues = self._hed_validator.run_basic_checks(column_hed_string, allow_placeholders=False) error_handler.add_context_and_filter(new_column_issues) - error_handler.pop_error_context() - error_handler.pop_error_context() + error_handler.pop_error_context() # HedString + error_handler.pop_error_context() # column issues += new_column_issues + # We want to do full onset checks on the combined and filtered rows if check_for_any_errors(new_column_issues): - error_handler.pop_error_context() + self.invalid_original_rows.add(row_number) + error_handler.pop_error_context() # Row continue - row_string = None - if onset_filtered is not None: - row_string = HedString(onset_filtered[row_number], self._schema, self._hed_validator._def_validator) - elif row_strings: - row_string = HedString.from_hed_strings(row_strings) + if has_onsets or not row_strings: + error_handler.pop_error_context() # Row + continue + + row_string = HedString.from_hed_strings(row_strings) if row_string: error_handler.push_error_context(ErrorContext.HED_STRING, row_string) new_column_issues = self._hed_validator.run_full_string_checks(row_string) - if self._onset_validator is not None: - new_column_issues += self._onset_validator.validate_temporal_relations(row_string) - new_column_issues += self._onset_validator.validate_duration_tags(row_string) - else: - new_column_issues += OnsetValidator.check_for_banned_tags(row_string) + new_column_issues += OnsetValidator.check_for_banned_tags(row_string) error_handler.add_context_and_filter(new_column_issues) - error_handler.pop_error_context() + error_handler.pop_error_context() # HedString + issues += new_column_issues + error_handler.pop_error_context() # Row + return issues + + def _run_onset_checks(self, onset_filtered, error_handler, row_adj): + issues = [] + for row in onset_filtered[["HED", "original_index"]].itertuples(index=True): + # Skip rows that had issues. + if row.original_index in self.invalid_original_rows: + continue + error_handler.push_error_context(ErrorContext.ROW, row.original_index + row_adj) + row_string = HedString(row.HED, self._schema, self._hed_validator._def_validator) + + if row_string: + error_handler.push_error_context(ErrorContext.HED_STRING, row_string) + new_column_issues = self._hed_validator.run_full_string_checks(row_string) + new_column_issues += self._onset_validator.validate_temporal_relations(row_string) + error_handler.add_context_and_filter(new_column_issues) + error_handler.pop_error_context() # HedString issues += new_column_issues - error_handler.pop_error_context() + error_handler.pop_error_context() # Row return issues def _validate_column_structure(self, base_input, error_handler, row_adj): diff --git a/hed/validator/tag_util/group_util.py b/hed/validator/tag_util/group_util.py index 8513c89d..6e6c92ce 100644 --- a/hed/validator/tag_util/group_util.py +++ b/hed/validator/tag_util/group_util.py @@ -4,7 +4,7 @@ from hed.models.model_constants import DefTagNames from hed.schema import HedKey from hed.models import HedTag -from hed.errors.error_types import ValidationErrors +from hed.errors.error_types import ValidationErrors, TemporalErrors class GroupValidator: @@ -43,6 +43,7 @@ def run_tag_level_validators(self, hed_string_obj): validation_issues += self.check_tag_level_issue(original_tag_group.tags(), is_top_level, is_group) validation_issues += self._check_for_duplicate_groups(hed_string_obj) + validation_issues += self.validate_duration_tags(hed_string_obj) return validation_issues def run_all_tags_validators(self, hed_string_obj): @@ -89,9 +90,9 @@ def check_tag_level_issue(original_tag_list, is_top_level, is_group): for top_level_tag in top_level_tags: if not is_top_level: actual_code = None - if top_level_tag.short_base_tag == DefTagNames.DEFINITION_ORG_KEY: + if top_level_tag.short_base_tag == DefTagNames.DEFINITION_KEY: actual_code = ValidationErrors.DEFINITION_INVALID - elif top_level_tag.short_base_tag.lower() in DefTagNames.ALL_TIME_KEYS: + elif top_level_tag.short_base_tag in DefTagNames.ALL_TIME_KEYS: actual_code = ValidationErrors.TEMPORAL_TAG_ERROR # May split this out if we switch error if actual_code: @@ -102,9 +103,20 @@ def check_tag_level_issue(original_tag_list, is_top_level, is_group): tag=top_level_tag) if is_top_level and len(top_level_tags) > 1: - short_tags = [tag.short_base_tag for tag in top_level_tags] - # Special exception for Duration/Delay pairing - if len(top_level_tags) != 2 or DefTagNames.DURATION_ORG_KEY not in short_tags or DefTagNames.DELAY_ORG_KEY not in short_tags: + validation_issue = False + short_tags = {tag.short_base_tag for tag in top_level_tags} + # Verify there's no duplicates, and that if there's two tags they are a delay and temporal tag. + if len(short_tags) != len(top_level_tags): + validation_issue = True + elif DefTagNames.DELAY_KEY not in short_tags or len(short_tags) != 2: + validation_issue = True + else: + short_tags.remove(DefTagNames.DELAY_KEY) + other_tag = next(iter(short_tags)) + if other_tag not in DefTagNames.ALL_TIME_KEYS: + validation_issue = True + + if validation_issue: validation_issues += ErrorHandler.format_error(ValidationErrors.HED_MULTIPLE_TOP_TAGS, tag=top_level_tags[0], multiple_tags=top_level_tags[1:]) @@ -150,6 +162,38 @@ def check_multiple_unique_tags_exist(self, tags): tag_namespace=unique_prefix) return validation_issues + @staticmethod + def validate_duration_tags(hed_string_obj): + """ Validate Duration/Delay tag groups + + Parameters: + hed_string_obj (HedString): The hed string to check. + + Returns: + list: A list of issues found in validating durations (i.e., extra tags or groups present, or a group missing) + """ + duration_issues = [] + for top_tag, group in hed_string_obj.find_top_level_tags(anchor_tags=DefTagNames.DURATION_KEYS): + top_level_tags = [tag.short_base_tag for tag in group.get_all_tags() if tag.base_tag_has_attribute(HedKey.TopLevelTagGroup)] + # Skip onset/inset/offset + if any(tag in DefTagNames.TEMPORAL_KEYS for tag in top_level_tags): + continue + # This implicitly validates the duration/delay tag, as they're the only two allowed in the same group + # It should be impossible to have > 2 tags, but it's a good stopgap. + if len(top_level_tags) != len(group.tags()): + for tag in group.tags(): + if tag.short_base_tag not in top_level_tags: + duration_issues += ErrorHandler.format_error(TemporalErrors.DURATION_HAS_OTHER_TAGS, + tag=tag) + continue + if len(group.groups()) != 1: + duration_issues += ErrorHandler.format_error(TemporalErrors.DURATION_WRONG_NUMBER_GROUPS, + top_tag, + hed_string_obj.groups()) + continue + + return duration_issues + def _validate_tags_in_hed_string(self, tags): """ Validate the multi-tag properties in a HED string. diff --git a/tests/models/test_base_input.py b/tests/models/test_base_input.py index b6d738e2..9b9e9e53 100644 --- a/tests/models/test_base_input.py +++ b/tests/models/test_base_input.py @@ -186,62 +186,3 @@ def test_combine_dataframe_with_mixed_values(self): expected = pd.Series(['apple, guitar', 'elephant, harmonica', 'cherry, fox', '', '']) self.assertTrue(result.equals(expected)) - -class TestOnsetDict(unittest.TestCase): - def test_empty_and_single_onset(self): - self.assertEqual(BaseInput._indexed_dict_from_onsets([]), {}) - self.assertEqual(BaseInput._indexed_dict_from_onsets([3.5]), {3.5: [0]}) - - def test_identical_and_approx_equal_onsets(self): - self.assertEqual(BaseInput._indexed_dict_from_onsets([3.5, 3.5]), {3.5: [0, 1]}) - self.assertEqual(BaseInput._indexed_dict_from_onsets([3.5, 3.500000001]), {3.5: [0], 3.500000001: [1]}) - self.assertEqual(BaseInput._indexed_dict_from_onsets([3.5, 3.5000000000001]), {3.5: [0, 1]}) - - def test_distinct_and_mixed_onsets(self): - self.assertEqual(BaseInput._indexed_dict_from_onsets([3.5, 4.0, 4.4]), {3.5: [0], 4.0: [1], 4.4: [2]}) - self.assertEqual(BaseInput._indexed_dict_from_onsets([3.5, 3.5, 4.0, 4.4]), {3.5: [0, 1], 4.0: [2], 4.4: [3]}) - self.assertEqual(BaseInput._indexed_dict_from_onsets([4.0, 3.5, 4.4, 4.4]), {4.0: [0], 3.5: [1], 4.4: [2, 3]}) - - def test_complex_onsets(self): - # Negative, zero, and positive onsets - self.assertEqual(BaseInput._indexed_dict_from_onsets([-1.0, 0.0, 1.0]), {-1.0: [0], 0.0: [1], 1.0: [2]}) - - # Very close but distinct onsets - self.assertEqual(BaseInput._indexed_dict_from_onsets([1.0, 1.0 + 1e-8, 1.0 + 2e-8]), - {1.0: [0], 1.0 + 1e-8: [1], 1.0 + 2e-8: [2]}) - # Very close - self.assertEqual(BaseInput._indexed_dict_from_onsets([1.0, 1.0 + 1e-10, 1.0 + 2e-10]), - {1.0: [0, 1, 2]}) - - # Mixed scenario - self.assertEqual(BaseInput._indexed_dict_from_onsets([3.5, 3.5, 4.0, 4.4, 4.4, -1.0]), - {3.5: [0, 1], 4.0: [2], 4.4: [3, 4], -1.0: [5]}) - - def test_empty_and_single_item_series(self): - self.assertTrue(BaseInput._filter_by_index_list(pd.Series([], dtype=str), {}).equals(pd.Series([], dtype=str))) - self.assertTrue(BaseInput._filter_by_index_list(pd.Series(["apple"]), {0: [0]}).equals(pd.Series(["apple"]))) - - def test_two_item_series_with_same_onset(self): - input_series = pd.Series(["apple", "orange"]) - expected_series = pd.Series(["apple,orange", ""]) - self.assertTrue(BaseInput._filter_by_index_list(input_series, {0: [0, 1]}).equals(expected_series)) - - def test_multiple_item_series(self): - input_series = pd.Series(["apple", "orange", "banana", "mango"]) - indexed_dict = {0: [0, 1], 1: [2], 2: [3]} - expected_series = pd.Series(["apple,orange", "", "banana", "mango"]) - self.assertTrue(BaseInput._filter_by_index_list(input_series, indexed_dict).equals(expected_series)) - - def test_complex_scenarios(self): - # Test with negative, zero and positive onsets - original = pd.Series(["negative", "zero", "positive"]) - indexed_dict = {-1: [0], 0: [1], 1: [2]} - expected_series1 = pd.Series(["negative", "zero", "positive"]) - self.assertTrue(BaseInput._filter_by_index_list(original, indexed_dict).equals(expected_series1)) - - # Test with more complex indexed_dict - original2 = ["apple", "orange", "banana", "mango", "grape"] - indexed_dict2= {0: [0, 1], 1: [2], 2: [3, 4]} - expected_series2 = pd.Series(["apple,orange", "", "banana", "mango,grape", ""]) - self.assertTrue(BaseInput._filter_by_index_list(original2, indexed_dict2).equals(expected_series2)) - diff --git a/tests/models/test_df_util.py b/tests/models/test_df_util.py index 1cff6943..47b7eddc 100644 --- a/tests/models/test_df_util.py +++ b/tests/models/test_df_util.py @@ -5,7 +5,7 @@ from hed import load_schema_version from hed.models.df_util import shrink_defs, expand_defs, convert_to_form, process_def_expands from hed import DefinitionDict -from hed.models.df_util import _handle_curly_braces_refs +from hed.models.df_util import _handle_curly_braces_refs, _indexed_dict_from_onsets, _filter_by_index_list, split_delay_tags class TestShrinkDefs(unittest.TestCase): @@ -425,3 +425,160 @@ def test_insert_columns_with_parentheses_na_values(self): }) result = _handle_curly_braces_refs(df, refs=["column2"], column_names=df.columns) pd.testing.assert_frame_equal(result, expected_df) + +class TestOnsetDict(unittest.TestCase): + def test_empty_and_single_onset(self): + self.assertEqual(_indexed_dict_from_onsets([]), {}) + self.assertEqual(_indexed_dict_from_onsets([3.5]), {3.5: [0]}) + + def test_identical_and_approx_equal_onsets(self): + self.assertEqual(_indexed_dict_from_onsets([3.5, 3.5]), {3.5: [0, 1]}) + self.assertEqual(_indexed_dict_from_onsets([3.5, 3.500000001]), {3.5: [0], 3.500000001: [1]}) + self.assertEqual(_indexed_dict_from_onsets([3.5, 3.5000000000001]), {3.5: [0, 1]}) + + def test_distinct_and_mixed_onsets(self): + self.assertEqual(_indexed_dict_from_onsets([3.5, 4.0, 4.4]), {3.5: [0], 4.0: [1], 4.4: [2]}) + self.assertEqual(_indexed_dict_from_onsets([3.5, 3.5, 4.0, 4.4]), {3.5: [0, 1], 4.0: [2], 4.4: [3]}) + self.assertEqual(_indexed_dict_from_onsets([4.0, 3.5, 4.4, 4.4]), {4.0: [0], 3.5: [1], 4.4: [2, 3]}) + + def test_complex_onsets(self): + # Negative, zero, and positive onsets + self.assertEqual(_indexed_dict_from_onsets([-1.0, 0.0, 1.0]), {-1.0: [0], 0.0: [1], 1.0: [2]}) + + # Very close but distinct onsets + self.assertEqual(_indexed_dict_from_onsets([1.0, 1.0 + 1e-8, 1.0 + 2e-8]), + {1.0: [0], 1.0 + 1e-8: [1], 1.0 + 2e-8: [2]}) + # Very close + self.assertEqual(_indexed_dict_from_onsets([1.0, 1.0 + 1e-10, 1.0 + 2e-10]), + {1.0: [0, 1, 2]}) + + # Mixed scenario + self.assertEqual(_indexed_dict_from_onsets([3.5, 3.5, 4.0, 4.4, 4.4, -1.0]), + {3.5: [0, 1], 4.0: [2], 4.4: [3, 4], -1.0: [5]}) + + def test_empty_and_single_item_series(self): + self.assertTrue(_filter_by_index_list(pd.Series([], dtype=str), {}).equals(pd.Series([], dtype=str))) + self.assertTrue(_filter_by_index_list(pd.Series(["apple"]), {0: [0]}).equals(pd.Series(["apple"]))) + + def test_two_item_series_with_same_onset(self): + input_series = pd.Series(["apple", "orange"]) + expected_series = pd.Series(["apple,orange", ""]) + self.assertTrue(_filter_by_index_list(input_series, {0: [0, 1]}).equals(expected_series)) + + def test_multiple_item_series(self): + input_series = pd.Series(["apple", "orange", "banana", "mango"]) + indexed_dict = {0: [0, 1], 1: [2], 2: [3]} + expected_series = pd.Series(["apple,orange", "", "banana", "mango"]) + self.assertTrue(_filter_by_index_list(input_series, indexed_dict).equals(expected_series)) + + def test_complex_scenarios(self): + # Test with negative, zero and positive onsets + original = pd.Series(["negative", "zero", "positive"]) + indexed_dict = {-1: [0], 0: [1], 1: [2]} + expected_series1 = pd.Series(["negative", "zero", "positive"]) + self.assertTrue(_filter_by_index_list(original, indexed_dict).equals(expected_series1)) + + # Test with more complex indexed_dict + original2 = pd.Series(["apple", "orange", "banana", "mango", "grape"]) + indexed_dict2= {0: [0, 1], 1: [2], 2: [3, 4]} + expected_series2 = pd.Series(["apple,orange", "", "banana", "mango,grape", ""]) + self.assertTrue(_filter_by_index_list(original2, indexed_dict2).equals(expected_series2)) + + def test_empty_and_single_item_series_df(self): + self.assertTrue(_filter_by_index_list(pd.DataFrame([], columns=["HED", "Extra"]), {}).equals( + pd.DataFrame([], columns=["HED", "Extra"]))) + self.assertTrue( + _filter_by_index_list(pd.DataFrame([["apple", "extra1"]], columns=["HED", "Extra"]), {0: [0]}).equals( + pd.DataFrame([["apple", "extra1"]], columns=["HED", "Extra"]))) + + def test_two_item_series_with_same_onset_df(self): + input_df = pd.DataFrame([["apple", "extra1"], ["orange", "extra2"]], columns=["HED", "Extra"]) + expected_df = pd.DataFrame([["apple,orange", "extra1"], ["", "extra2"]], columns=["HED", "Extra"]) + self.assertTrue(_filter_by_index_list(input_df, {0: [0, 1]}).equals(expected_df)) + + def test_multiple_item_series_df(self): + input_df = pd.DataFrame([["apple", "extra1"], ["orange", "extra2"], ["banana", "extra3"], ["mango", "extra4"]], + columns=["HED", "Extra"]) + indexed_dict = {0: [0, 1], 1: [2], 2: [3]} + expected_df = pd.DataFrame( + [["apple,orange", "extra1"], ["", "extra2"], ["banana", "extra3"], ["mango", "extra4"]], + columns=["HED", "Extra"]) + self.assertTrue(_filter_by_index_list(input_df, indexed_dict).equals(expected_df)) + + def test_complex_scenarios_df(self): + # Test with negative, zero, and positive onsets + original = pd.DataFrame([["negative", "extra1"], ["zero", "extra2"], ["positive", "extra3"]], + columns=["HED", "Extra"]) + indexed_dict = {-1: [0], 0: [1], 1: [2]} + expected_df = pd.DataFrame([["negative", "extra1"], ["zero", "extra2"], ["positive", "extra3"]], + columns=["HED", "Extra"]) + self.assertTrue(_filter_by_index_list(original, indexed_dict).equals(expected_df)) + + # Test with more complex indexed_dict + original2 = pd.DataFrame( + [["apple", "extra1"], ["orange", "extra2"], ["banana", "extra3"], ["mango", "extra4"], ["grape", "extra5"]], + columns=["HED", "Extra"]) + indexed_dict2 = {0: [0, 1], 1: [2], 2: [3, 4]} + expected_df2 = pd.DataFrame( + [["apple,orange", "extra1"], ["", "extra2"], ["banana", "extra3"], ["mango,grape", "extra4"], + ["", "extra5"]], columns=["HED", "Extra"]) + self.assertTrue(_filter_by_index_list(original2, indexed_dict2).equals(expected_df2)) + + + +class TestSplitDelayTags(unittest.TestCase): + schema = load_schema_version("8.2.0") + def test_empty_series_and_onsets(self): + empty_series = pd.Series([], dtype="object") + empty_onsets = pd.Series([], dtype="float") + result = split_delay_tags(empty_series, self.schema, empty_onsets) + self.assertIsInstance(result, pd.DataFrame) + + def test_None_series_and_onsets(self): + result = split_delay_tags(None, self.schema, None) + self.assertIsNone(result) + + def test_normal_ordered_series(self): + series = pd.Series([ + "Tag1,Tag2", + "Tag3,Tag4" + ]) + onsets = pd.Series([1.0, 2.0]) + result = split_delay_tags(series, self.schema, onsets) + self.assertTrue(result.onset.equals(pd.Series([1.0, 2.0]))) + self.assertTrue(result.HED.equals(pd.Series([ + "Tag1,Tag2", + "Tag3,Tag4" + ]))) + + def test_normal_ordered_series_with_delays(self): + series = pd.Series([ + "Tag1,Tag2,(Delay/3.0 s,(Tag5))", + "Tag3,Tag4" + ]) + onsets = pd.Series([1.0, 2.0]) + result = split_delay_tags(series, self.schema, onsets) + self.assertTrue(result.onset.equals(pd.Series([1.0, 2.0, 4.0]))) + self.assertTrue(result.HED.equals(pd.Series([ + "Tag1,Tag2", + "Tag3,Tag4", + "(Delay/3.0 s,(Tag5))" + ]))) + + def test_normal_ordered_series_with_double_delays(self): + series = pd.Series([ + "Tag1,Tag2,(Delay/3.0 s,(Tag5))", + "Tag6,(Delay/2.0 s,(Tag7))", + "Tag3,Tag4" + ]) + onsets = pd.Series([1.0, 2.0, 3.0]) + result = split_delay_tags(series, self.schema, onsets) + self.assertTrue(result.onset.equals(pd.Series([1.0, 2.0, 3.0, 4.0, 4.0]))) + self.assertTrue(result.HED.equals(pd.Series([ + "Tag1,Tag2", + "Tag6", + "Tag3,Tag4", + "(Delay/3.0 s,(Tag5)),(Delay/2.0 s,(Tag7))", + "" + ]))) + self.assertTrue(result.original_index.equals(pd.Series([0, 1, 2, 0, 1]))) \ No newline at end of file diff --git a/tests/validator/test_onset_validator.py b/tests/validator/test_onset_validator.py index 55014c04..fcb2abfc 100644 --- a/tests/validator/test_onset_validator.py +++ b/tests/validator/test_onset_validator.py @@ -6,6 +6,8 @@ from hed.models import HedString, DefinitionDict from hed import schema from hed.validator import HedValidator, OnsetValidator, DefValidator +from hed.validator.tag_util.group_util import GroupValidator + from tests.validator.test_tag_validator_base import TestHedBase @@ -56,11 +58,12 @@ def _test_issues_base(self, test_strings, test_issues, test_context, placeholder onset_issues += def_validator.validate_onset_offset(test_string) if not onset_issues: onset_issues += onset_validator.validate_temporal_relations(test_string) - onset_issues += onset_validator.validate_duration_tags(test_string) + onset_issues += GroupValidator.validate_duration_tags(test_string) error_handler.add_context_and_filter(onset_issues) test_string.shrink_defs() issues = self.format_errors_fully(error_handler, hed_string=test_string, params=expected_params) + # print(str(test_string)) # print(str(onset_issues)) # print(str(issues)) # print(onset_validator._onsets) diff --git a/tests/validator/test_spreadsheet_validator.py b/tests/validator/test_spreadsheet_validator.py index 80dbf067..b67bff2d 100644 --- a/tests/validator/test_spreadsheet_validator.py +++ b/tests/validator/test_spreadsheet_validator.py @@ -7,6 +7,7 @@ from hed.validator import SpreadsheetValidator from hed import TabularInput, SpreadsheetInput from hed.errors.error_types import ValidationErrors +from hed import DefinitionDict class TestSpreadsheetValidation(unittest.TestCase): @@ -93,4 +94,4 @@ def test_invalid_onset_invalid_column(self): issues = self.validator.validate(TabularInput(self.df_with_onset_has_tags_unordered), def_dicts=def_dict) self.assertEqual(len(issues), 2) self.assertEqual(issues[0]['code'], ValidationErrors.HED_UNKNOWN_COLUMN) - self.assertEqual(issues[1]['code'], ValidationErrors.TEMPORAL_TAG_ERROR) \ No newline at end of file + self.assertEqual(issues[1]['code'], ValidationErrors.TEMPORAL_TAG_ERROR) diff --git a/tests/validator/test_tag_validator.py b/tests/validator/test_tag_validator.py index 6f28b35c..38ec3ac2 100644 --- a/tests/validator/test_tag_validator.py +++ b/tests/validator/test_tag_validator.py @@ -1,6 +1,6 @@ import unittest -from hed.errors.error_types import ValidationErrors, DefinitionErrors +from hed.errors.error_types import ValidationErrors, DefinitionErrors, TemporalErrors from tests.validator.test_tag_validator_base import TestValidatorBase from hed import load_schema_version from functools import partial @@ -468,7 +468,7 @@ def test_topLevelTagGroup_validation(self): 'valid2TwoInOne': '(Duration/5.0 s, Delay, (Event))', 'invalid3InOne': '(Duration/5.0 s, Delay, Onset, (Event))', 'invalidDuration': '(Duration/5.0 s, Onset, (Event))', - 'invalidDelay': '(Delay, Onset, (Event))', + 'validDelay': '(Delay, Onset, (Event))', 'invalidDurationPair': '(Duration/5.0 s, Duration/3.0 s, (Event))', 'invalidDelayPair': '(Delay/3.0 s, Delay, (Event))', } @@ -482,7 +482,7 @@ def test_topLevelTagGroup_validation(self): 'valid2TwoInOne': True, 'invalid3InOne': False, 'invalidDuration': False, - 'invalidDelay': False, + 'validDelay': True, 'invalidDurationPair': False, 'invalidDelayPair': False, } @@ -492,13 +492,13 @@ def test_topLevelTagGroup_validation(self): 'valid1': [], 'valid2': [], 'invalid2': self.format_error(ValidationErrors.HED_TOP_LEVEL_TAG, tag=1, actual_error=ValidationErrors.DEFINITION_INVALID) - + self.format_error(ValidationErrors.HED_TOP_LEVEL_TAG, tag=1), + + self.format_error(ValidationErrors.HED_TOP_LEVEL_TAG, tag=1), 'invalidTwoInOne': self.format_error(ValidationErrors.HED_MULTIPLE_TOP_TAGS, tag=0, multiple_tags="Definition/InvalidDef3".split(", ")), 'invalid2TwoInOne': self.format_error(ValidationErrors.HED_MULTIPLE_TOP_TAGS, tag=0, multiple_tags="Onset".split(", ")), 'valid2TwoInOne': [], 'invalid3InOne': self.format_error(ValidationErrors.HED_MULTIPLE_TOP_TAGS, tag=0, multiple_tags="Delay, Onset".split(", ")), 'invalidDuration': self.format_error(ValidationErrors.HED_MULTIPLE_TOP_TAGS, tag=0, multiple_tags="Onset".split(", ")), - 'invalidDelay': self.format_error(ValidationErrors.HED_MULTIPLE_TOP_TAGS, tag=0, multiple_tags="Onset".split(", ")), + 'validDelay': [], 'invalidDurationPair': self.format_error(ValidationErrors.HED_MULTIPLE_TOP_TAGS, tag=0, multiple_tags="Duration/3.0 s".split(", ")), 'invalidDelayPair': self.format_error(ValidationErrors.HED_MULTIPLE_TOP_TAGS, tag=0, multiple_tags="Delay".split(", ")), } diff --git a/tests/validator/test_tag_validator_base.py b/tests/validator/test_tag_validator_base.py index 568ee0bb..ac362da3 100644 --- a/tests/validator/test_tag_validator_base.py +++ b/tests/validator/test_tag_validator_base.py @@ -88,9 +88,9 @@ def validator_base(self, test_strings, expected_results, expected_issues, test_f error_handler.add_context_and_filter(test_issues) test_result = not test_issues - print(test_key) - print(str(expected_issue)) - print(str(test_issues)) + # print(test_key) + # print(str(expected_issue)) + # print(str(test_issues)) error_handler.pop_error_context() self.assertEqual(test_result, expected_result, test_strings[test_key]) self.assertCountEqual(test_issues, expected_issue, test_strings[test_key])