diff --git a/hed/models/hed_tag.py b/hed/models/hed_tag.py index d5afb132..63808bd8 100644 --- a/hed/models/hed_tag.py +++ b/hed/models/hed_tag.py @@ -580,33 +580,17 @@ def _get_tag_units_portion(extension_text, tag_unit_classes): return None, None, None for unit_class_entry in tag_unit_classes.values(): - all_valid_unit_permutations = unit_class_entry.derivative_units - - possible_match = HedTag._find_modifier_unit_entry(units, all_valid_unit_permutations) + possible_match = unit_class_entry.get_derivative_unit_entry(units) if possible_match and not possible_match.has_attribute(HedKey.UnitPrefix): return value, units, possible_match # Repeat the above, but as a prefix - possible_match = HedTag._find_modifier_unit_entry(value, all_valid_unit_permutations) + possible_match = unit_class_entry.get_derivative_unit_entry(value) if possible_match and possible_match.has_attribute(HedKey.UnitPrefix): return units, value, possible_match return None, None, None - @staticmethod - def _find_modifier_unit_entry(units, all_valid_unit_permutations): - possible_match = all_valid_unit_permutations.get(units) - # If we have a match that's a unit symbol, we're done, return it. - if possible_match and possible_match.has_attribute(HedKey.UnitSymbol): - return possible_match - - possible_match = all_valid_unit_permutations.get(units.lower()) - # Unit symbols must match including case, a match of a unit symbol now is something like M becoming m. - if possible_match and possible_match.has_attribute(HedKey.UnitSymbol): - possible_match = None - - return possible_match - def is_placeholder(self): if "#" in self.org_tag or "#" in self._extension_value: return True diff --git a/hed/schema/hed_schema.py b/hed/schema/hed_schema.py index eb871e10..dbba8046 100644 --- a/hed/schema/hed_schema.py +++ b/hed/schema/hed_schema.py @@ -8,7 +8,8 @@ from hed.schema.schema_io.schema2wiki import Schema2Wiki from hed.schema.schema_io.schema2owl import Schema2Owl from hed.schema.schema_io.owl_constants import ext_to_format -from hed.schema.hed_schema_section import HedSchemaSection, HedSchemaTagSection, HedSchemaUnitClassSection +from hed.schema.hed_schema_section import (HedSchemaSection, HedSchemaTagSection, HedSchemaUnitClassSection, + HedSchemaUnitSection) from hed.errors import ErrorHandler from hed.errors.error_types import ValidationErrors from hed.schema.hed_schema_base import HedSchemaBase @@ -747,7 +748,7 @@ def _create_empty_sections(): dictionaries[HedSectionKey.Properties] = HedSchemaSection(HedSectionKey.Properties) dictionaries[HedSectionKey.Attributes] = HedSchemaSection(HedSectionKey.Attributes) dictionaries[HedSectionKey.UnitModifiers] = HedSchemaSection(HedSectionKey.UnitModifiers) - dictionaries[HedSectionKey.Units] = HedSchemaSection(HedSectionKey.Units) + dictionaries[HedSectionKey.Units] = HedSchemaUnitSection(HedSectionKey.Units) dictionaries[HedSectionKey.UnitClasses] = HedSchemaUnitClassSection(HedSectionKey.UnitClasses) dictionaries[HedSectionKey.ValueClasses] = HedSchemaSection(HedSectionKey.ValueClasses) dictionaries[HedSectionKey.Tags] = HedSchemaTagSection(HedSectionKey.Tags, case_sensitive=False) @@ -767,9 +768,13 @@ def _get_modifiers_for_unit(self, unit): This is a lower level one that doesn't rely on the Unit entries being fully setup. """ + # todo: could refactor this so this unit.lower() part is in HedSchemaUnitSection.get unit_entry = self.get_tag_entry(unit, HedSectionKey.Units) if unit_entry is None: - return [] + unit_entry = self.get_tag_entry(unit.lower(), HedSectionKey.Units) + # Unit symbols must match exactly + if unit_entry is None or unit_entry.has_attribute(HedKey.UnitSymbol): + return [] is_si_unit = unit_entry.has_attribute(HedKey.SIUnit) is_unit_symbol = unit_entry.has_attribute(HedKey.UnitSymbol) if not is_si_unit: diff --git a/hed/schema/hed_schema_entry.py b/hed/schema/hed_schema_entry.py index b5693a17..3f23838d 100644 --- a/hed/schema/hed_schema_entry.py +++ b/hed/schema/hed_schema_entry.py @@ -187,6 +187,27 @@ def __eq__(self, other): return False return True + def get_derivative_unit_entry(self, units): + """ Gets the (derivative) unit entry if it exists + + Parameters: + units (str): The unit name to check, can be plural or include a modifier. + + Returns: + unit_entry(UnitEntry or None): The unit entry if it exists + """ + possible_match = self.derivative_units.get(units) + # If we have a match that's a unit symbol, we're done, return it. + if possible_match and possible_match.has_attribute(HedKey.UnitSymbol): + return possible_match + + possible_match = self.derivative_units.get(units.lower()) + # Unit symbols must match including case, a match of a unit symbol now is something like M becoming m. + if possible_match and possible_match.has_attribute(HedKey.UnitSymbol): + possible_match = None + + return possible_match + class UnitEntry(HedSchemaEntry): """ A single unit entry with modifiers in the HedSchema. """ @@ -206,15 +227,16 @@ def finalize_entry(self, schema): self.unit_modifiers = schema._get_modifiers_for_unit(self.name) derivative_units = {} - base_plural_units = {self.name} - if not self.has_attribute(HedKey.UnitSymbol): - base_plural_units.add(pluralize.plural(self.name)) + if self.has_attribute(HedKey.UnitSymbol): + base_plural_units = {self.name} + else: + base_plural_units = {self.name.lower()} + base_plural_units.add(pluralize.plural(self.name.lower())) for derived_unit in base_plural_units: derivative_units[derived_unit] = self._get_conversion_factor(None) for modifier in self.unit_modifiers: derivative_units[modifier.name + derived_unit] = self._get_conversion_factor(modifier_entry=modifier) - self.derivative_units = derivative_units def _get_conversion_factor(self, modifier_entry): diff --git a/hed/schema/hed_schema_section.py b/hed/schema/hed_schema_section.py index f7934a21..dc3c64fe 100644 --- a/hed/schema/hed_schema_section.py +++ b/hed/schema/hed_schema_section.py @@ -148,6 +148,14 @@ def _finalize_section(self, hed_schema): entry.finalize_entry(hed_schema) +class HedSchemaUnitSection(HedSchemaSection): + def _check_if_duplicate(self, name_key, new_entry): + """We need to mark duplicate units(units with unitSymbol are case sensitive, while others are not.""" + if not new_entry.has_attribute(HedKey.UnitSymbol): + name_key = name_key.lower() + return super()._check_if_duplicate(name_key, new_entry) + + class HedSchemaUnitClassSection(HedSchemaSection): def _check_if_duplicate(self, name_key, new_entry): """Allow adding units to existing unit classes, using a placeholder one with no attributes.""" diff --git a/hed/schema/schema_attribute_validators.py b/hed/schema/schema_attribute_validators.py index 4dd39d02..a053b962 100644 --- a/hed/schema/schema_attribute_validators.py +++ b/hed/schema/schema_attribute_validators.py @@ -148,7 +148,7 @@ def value_class_exists(hed_schema, tag_entry, attribute_name): def unit_exists(hed_schema, tag_entry, attribute_name): issues = [] unit = tag_entry.attributes.get(attribute_name, "") - unit_entry = tag_entry.derivative_units.get(unit) + unit_entry = tag_entry.get_derivative_unit_entry(unit) if unit and not unit_entry: issues += ErrorHandler.format_error(SchemaAttributeErrors.SCHEMA_DEFAULT_UNITS_INVALID, tag_entry.name, diff --git a/tests/validator/test_tag_validator.py b/tests/validator/test_tag_validator.py index 9c7aa307..6f28b35c 100644 --- a/tests/validator/test_tag_validator.py +++ b/tests/validator/test_tag_validator.py @@ -2,6 +2,7 @@ from hed.errors.error_types import ValidationErrors, DefinitionErrors from tests.validator.test_tag_validator_base import TestValidatorBase +from hed import load_schema_version from functools import partial @@ -11,6 +12,7 @@ class TestHed(TestValidatorBase): class IndividualHedTagsShort(TestHed): + hed_schema = load_schema_version("score_1.1.0") @staticmethod def string_obj_func(validator): return partial(validator._validate_individual_tags_in_hed_string) @@ -215,6 +217,20 @@ def test_correct_units(self): # Update tests - 8.0 currently has no clockTime nodes. # 'properTime': 'Item/2D shape/Clock face/08:30', # 'invalidTime': 'Item/2D shape/Clock face/54:54' + 'voltsTest1': 'Finding-amplitude/30 v', + 'voltsTest2': 'Finding-amplitude/30 Volt', + 'voltsTest3': 'Finding-amplitude/30 volts', + 'voltsTest4': 'Finding-amplitude/30 VOLTS', + 'voltsTest5': 'Finding-amplitude/30 kv', + 'voltsTest6': 'Finding-amplitude/30 kiloVolt', + 'voltsTest7': 'Finding-amplitude/30 KiloVolt', + 'volumeTest1': "Sound-volume/5 dB", + 'volumeTest2': "Sound-volume/5 kdB", # Invalid, not SI unit + 'volumeTest3': "Sound-volume/5 candela", + 'volumeTest4': "Sound-volume/5 kilocandela", + 'volumeTest5': "Sound-volume/5 cd", + 'volumeTest6': "Sound-volume/5 kcd", + 'volumeTest7': "Sound-volume/5 DB", # Invalid, case doesn't match } expected_results = { 'correctUnit': True, @@ -236,12 +252,27 @@ def test_correct_units(self): # 'invalidTime': True, # 'specialAllowedCharCurrency': True, # 'specialNotAllowedCharCurrency': False, + 'voltsTest1': True, + 'voltsTest2': True, + 'voltsTest3': True, + 'voltsTest4': True, + 'voltsTest5': True, + 'voltsTest6': True, + 'voltsTest7': True, + 'volumeTest1': True, + 'volumeTest2': False, + 'volumeTest3': True, + 'volumeTest4': True, + 'volumeTest5': True, + 'volumeTest6': True, + 'volumeTest7': False, } legal_time_units = ['s', 'second', 'day', 'minute', 'hour'] # legal_clock_time_units = ['hour:min', 'hour:min:sec'] # legal_datetime_units = ['YYYY-MM-DDThh:mm:ss'] legal_freq_units = ['Hz', 'hertz'] # legal_currency_units = ['dollar', "$", "point"] + legal_intensity_units = ["candela", "cd", "dB"] expected_issues = { 'correctUnit': [], @@ -273,6 +304,20 @@ def test_correct_units(self): # 'specialNotAllowedCharCurrency': self.format_error(ValidationErrors.UNITS_INVALID, # tag=0, # units=legal_currency_units), + 'voltsTest1': [], + 'voltsTest2': [], + 'voltsTest3': [], + 'voltsTest4': [], + 'voltsTest5': [], + 'voltsTest6': [], + 'voltsTest7': [], + 'volumeTest1': [], + 'volumeTest2': self.format_error(ValidationErrors.UNITS_INVALID,tag=0, units=legal_intensity_units), + 'volumeTest3': [], + 'volumeTest4': [], + 'volumeTest5': [], + 'volumeTest6': [], + 'volumeTest7': self.format_error(ValidationErrors.UNITS_INVALID, tag=0, units=legal_intensity_units), } self.validator_semantic(test_strings, expected_results, expected_issues, True) diff --git a/tests/validator/test_tag_validator_base.py b/tests/validator/test_tag_validator_base.py index 5b2930cb..568ee0bb 100644 --- a/tests/validator/test_tag_validator_base.py +++ b/tests/validator/test_tag_validator_base.py @@ -10,10 +10,11 @@ #todo: update these tests(TagValidator no longer exists) class TestHedBase(unittest.TestCase): schema_file = None + hed_schema = None @classmethod def setUpClass(cls): - if cls.schema_file: + if cls.schema_file and not cls.hed_schema: hed_xml = os.path.join(os.path.dirname(os.path.realpath(__file__)), cls.schema_file) cls.hed_schema = schema.load_schema(hed_xml) elif not cls.hed_schema: @@ -87,9 +88,9 @@ def validator_base(self, test_strings, expected_results, expected_issues, test_f error_handler.add_context_and_filter(test_issues) test_result = not test_issues - # print(test_key) - # print(str(expected_issue)) - # print(str(test_issues)) + print(test_key) + print(str(expected_issue)) + print(str(test_issues)) error_handler.pop_error_context() self.assertEqual(test_result, expected_result, test_strings[test_key]) self.assertCountEqual(test_issues, expected_issue, test_strings[test_key])