Skip to content

Commit

Permalink
Allow mixed case in non unitsymbol units
Browse files Browse the repository at this point in the history
  • Loading branch information
IanCa committed Mar 9, 2024
1 parent 6c9b45a commit 0568209
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 30 deletions.
20 changes: 2 additions & 18 deletions hed/models/hed_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,33 +580,17 @@ def _get_tag_units_portion(extension_text, tag_unit_classes):
return None, None, None

for unit_class_entry in tag_unit_classes.values():
all_valid_unit_permutations = unit_class_entry.derivative_units

possible_match = HedTag._find_modifier_unit_entry(units, all_valid_unit_permutations)
possible_match = unit_class_entry.get_derivative_unit_entry(units)
if possible_match and not possible_match.has_attribute(HedKey.UnitPrefix):
return value, units, possible_match

# Repeat the above, but as a prefix
possible_match = HedTag._find_modifier_unit_entry(value, all_valid_unit_permutations)
possible_match = unit_class_entry.get_derivative_unit_entry(value)
if possible_match and possible_match.has_attribute(HedKey.UnitPrefix):
return units, value, possible_match

return None, None, None

@staticmethod
def _find_modifier_unit_entry(units, all_valid_unit_permutations):
    """ Look up a unit entry, treating only unit symbols as case sensitive.

    Parameters:
        units (str): The unit text to look up.
        all_valid_unit_permutations: Mapping (anything supporting .get) from unit text to its entry.

    Returns:
        The matching entry or None.  An entry flagged as a unit symbol is only returned when
        the text matched it exactly; any other entry may be matched through lower-cased text.
    """
    # An exact hit on a unit symbol is final.
    exact_entry = all_valid_unit_permutations.get(units)
    if exact_entry and exact_entry.has_attribute(HedKey.UnitSymbol):
        return exact_entry

    # Fall back to the lower-cased text.  Reaching a unit symbol only through lower-casing
    # (something like M becoming m) is not a valid match, so it is rejected.
    lowered_entry = all_valid_unit_permutations.get(units.lower())
    if lowered_entry and lowered_entry.has_attribute(HedKey.UnitSymbol):
        return None
    return lowered_entry

def is_placeholder(self):
if "#" in self.org_tag or "#" in self._extension_value:
return True
Expand Down
11 changes: 8 additions & 3 deletions hed/schema/hed_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from hed.schema.schema_io.schema2wiki import Schema2Wiki
from hed.schema.schema_io.schema2owl import Schema2Owl
from hed.schema.schema_io.owl_constants import ext_to_format
from hed.schema.hed_schema_section import HedSchemaSection, HedSchemaTagSection, HedSchemaUnitClassSection
from hed.schema.hed_schema_section import (HedSchemaSection, HedSchemaTagSection, HedSchemaUnitClassSection,
HedSchemaUnitSection)
from hed.errors import ErrorHandler
from hed.errors.error_types import ValidationErrors
from hed.schema.hed_schema_base import HedSchemaBase
Expand Down Expand Up @@ -747,7 +748,7 @@ def _create_empty_sections():
dictionaries[HedSectionKey.Properties] = HedSchemaSection(HedSectionKey.Properties)
dictionaries[HedSectionKey.Attributes] = HedSchemaSection(HedSectionKey.Attributes)
dictionaries[HedSectionKey.UnitModifiers] = HedSchemaSection(HedSectionKey.UnitModifiers)
dictionaries[HedSectionKey.Units] = HedSchemaSection(HedSectionKey.Units)
dictionaries[HedSectionKey.Units] = HedSchemaUnitSection(HedSectionKey.Units)
dictionaries[HedSectionKey.UnitClasses] = HedSchemaUnitClassSection(HedSectionKey.UnitClasses)
dictionaries[HedSectionKey.ValueClasses] = HedSchemaSection(HedSectionKey.ValueClasses)
dictionaries[HedSectionKey.Tags] = HedSchemaTagSection(HedSectionKey.Tags, case_sensitive=False)
Expand All @@ -767,9 +768,13 @@ def _get_modifiers_for_unit(self, unit):
This is a lower level one that doesn't rely on the Unit entries being fully setup.
"""
# todo: could refactor this so this unit.lower() part is in HedSchemaUnitSection.get
unit_entry = self.get_tag_entry(unit, HedSectionKey.Units)
if unit_entry is None:
return []
unit_entry = self.get_tag_entry(unit.lower(), HedSectionKey.Units)
# Unit symbols must match exactly
if unit_entry is None or unit_entry.has_attribute(HedKey.UnitSymbol):
return []
is_si_unit = unit_entry.has_attribute(HedKey.SIUnit)
is_unit_symbol = unit_entry.has_attribute(HedKey.UnitSymbol)
if not is_si_unit:
Expand Down
30 changes: 26 additions & 4 deletions hed/schema/hed_schema_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,27 @@ def __eq__(self, other):
return False
return True

def get_derivative_unit_entry(self, units):
    """ Return the (derivative) unit entry matching the given unit text, if any.

    Unit symbols have to match including case; every other unit is also
    looked up through its lower-cased text.

    Parameters:
        units (str): The unit name to check, can be plural or include a modifier.

    Returns:
        unit_entry(UnitEntry or None): The unit entry if it exists
    """
    # An exact hit on a unit symbol is final.
    exact_entry = self.derivative_units.get(units)
    if exact_entry and exact_entry.has_attribute(HedKey.UnitSymbol):
        return exact_entry

    # Retry with the lower-cased text.  A unit symbol found only this way
    # (something like M becoming m) differs in case, so it is not a match.
    lowered_entry = self.derivative_units.get(units.lower())
    if lowered_entry and lowered_entry.has_attribute(HedKey.UnitSymbol):
        return None
    return lowered_entry


class UnitEntry(HedSchemaEntry):
""" A single unit entry with modifiers in the HedSchema. """
Expand All @@ -206,15 +227,16 @@ def finalize_entry(self, schema):
self.unit_modifiers = schema._get_modifiers_for_unit(self.name)

derivative_units = {}
base_plural_units = {self.name}
if not self.has_attribute(HedKey.UnitSymbol):
base_plural_units.add(pluralize.plural(self.name))
if self.has_attribute(HedKey.UnitSymbol):
base_plural_units = {self.name}
else:
base_plural_units = {self.name.lower()}
base_plural_units.add(pluralize.plural(self.name.lower()))

for derived_unit in base_plural_units:
derivative_units[derived_unit] = self._get_conversion_factor(None)
for modifier in self.unit_modifiers:
derivative_units[modifier.name + derived_unit] = self._get_conversion_factor(modifier_entry=modifier)

self.derivative_units = derivative_units

def _get_conversion_factor(self, modifier_entry):
Expand Down
8 changes: 8 additions & 0 deletions hed/schema/hed_schema_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ def _finalize_section(self, hed_schema):
entry.finalize_entry(hed_schema)


class HedSchemaUnitSection(HedSchemaSection):
    """ Schema section for units, with duplicate checks that ignore case except for unit symbols. """

    def _check_if_duplicate(self, name_key, new_entry):
        """ Normalize the key used for duplicate detection, then defer to the parent check.

        Units with the unitSymbol attribute are case sensitive and keep their key as given;
        all other units are case insensitive, so their key is lower-cased first.

        Parameters:
            name_key (str): The name the entry is being added under.
            new_entry: The entry being added; only its unitSymbol attribute is inspected here.

        Returns:
            Whatever the parent class's _check_if_duplicate returns for the normalized key.
        """
        is_unit_symbol = new_entry.has_attribute(HedKey.UnitSymbol)
        lookup_key = name_key if is_unit_symbol else name_key.lower()
        return super()._check_if_duplicate(lookup_key, new_entry)


class HedSchemaUnitClassSection(HedSchemaSection):
def _check_if_duplicate(self, name_key, new_entry):
"""Allow adding units to existing unit classes, using a placeholder one with no attributes."""
Expand Down
2 changes: 1 addition & 1 deletion hed/schema/schema_attribute_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def value_class_exists(hed_schema, tag_entry, attribute_name):
def unit_exists(hed_schema, tag_entry, attribute_name):
issues = []
unit = tag_entry.attributes.get(attribute_name, "")
unit_entry = tag_entry.derivative_units.get(unit)
unit_entry = tag_entry.get_derivative_unit_entry(unit)
if unit and not unit_entry:
issues += ErrorHandler.format_error(SchemaAttributeErrors.SCHEMA_DEFAULT_UNITS_INVALID,
tag_entry.name,
Expand Down
45 changes: 45 additions & 0 deletions tests/validator/test_tag_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from hed.errors.error_types import ValidationErrors, DefinitionErrors
from tests.validator.test_tag_validator_base import TestValidatorBase
from hed import load_schema_version
from functools import partial


Expand All @@ -11,6 +12,7 @@ class TestHed(TestValidatorBase):


class IndividualHedTagsShort(TestHed):
hed_schema = load_schema_version("score_1.1.0")
@staticmethod
def string_obj_func(validator):
return partial(validator._validate_individual_tags_in_hed_string)
Expand Down Expand Up @@ -215,6 +217,20 @@ def test_correct_units(self):
# Update tests - 8.0 currently has no clockTime nodes.
# 'properTime': 'Item/2D shape/Clock face/08:30',
# 'invalidTime': 'Item/2D shape/Clock face/54:54'
'voltsTest1': 'Finding-amplitude/30 v',
'voltsTest2': 'Finding-amplitude/30 Volt',
'voltsTest3': 'Finding-amplitude/30 volts',
'voltsTest4': 'Finding-amplitude/30 VOLTS',
'voltsTest5': 'Finding-amplitude/30 kv',
'voltsTest6': 'Finding-amplitude/30 kiloVolt',
'voltsTest7': 'Finding-amplitude/30 KiloVolt',
'volumeTest1': "Sound-volume/5 dB",
'volumeTest2': "Sound-volume/5 kdB", # Invalid, not SI unit
'volumeTest3': "Sound-volume/5 candela",
'volumeTest4': "Sound-volume/5 kilocandela",
'volumeTest5': "Sound-volume/5 cd",
'volumeTest6': "Sound-volume/5 kcd",
'volumeTest7': "Sound-volume/5 DB", # Invalid, case doesn't match
}
expected_results = {
'correctUnit': True,
Expand All @@ -236,12 +252,27 @@ def test_correct_units(self):
# 'invalidTime': True,
# 'specialAllowedCharCurrency': True,
# 'specialNotAllowedCharCurrency': False,
'voltsTest1': True,
'voltsTest2': True,
'voltsTest3': True,
'voltsTest4': True,
'voltsTest5': True,
'voltsTest6': True,
'voltsTest7': True,
'volumeTest1': True,
'volumeTest2': False,
'volumeTest3': True,
'volumeTest4': True,
'volumeTest5': True,
'volumeTest6': True,
'volumeTest7': False,
}
legal_time_units = ['s', 'second', 'day', 'minute', 'hour']
# legal_clock_time_units = ['hour:min', 'hour:min:sec']
# legal_datetime_units = ['YYYY-MM-DDThh:mm:ss']
legal_freq_units = ['Hz', 'hertz']
# legal_currency_units = ['dollar', "$", "point"]
legal_intensity_units = ["candela", "cd", "dB"]

expected_issues = {
'correctUnit': [],
Expand Down Expand Up @@ -273,6 +304,20 @@ def test_correct_units(self):
# 'specialNotAllowedCharCurrency': self.format_error(ValidationErrors.UNITS_INVALID,
# tag=0,
# units=legal_currency_units),
'voltsTest1': [],
'voltsTest2': [],
'voltsTest3': [],
'voltsTest4': [],
'voltsTest5': [],
'voltsTest6': [],
'voltsTest7': [],
'volumeTest1': [],
'volumeTest2': self.format_error(ValidationErrors.UNITS_INVALID,tag=0, units=legal_intensity_units),
'volumeTest3': [],
'volumeTest4': [],
'volumeTest5': [],
'volumeTest6': [],
'volumeTest7': self.format_error(ValidationErrors.UNITS_INVALID, tag=0, units=legal_intensity_units),
}
self.validator_semantic(test_strings, expected_results, expected_issues, True)

Expand Down
9 changes: 5 additions & 4 deletions tests/validator/test_tag_validator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
# todo: update these tests (TagValidator no longer exists)
class TestHedBase(unittest.TestCase):
schema_file = None
hed_schema = None

@classmethod
def setUpClass(cls):
if cls.schema_file:
if cls.schema_file and not cls.hed_schema:
hed_xml = os.path.join(os.path.dirname(os.path.realpath(__file__)), cls.schema_file)
cls.hed_schema = schema.load_schema(hed_xml)
elif not cls.hed_schema:
Expand Down Expand Up @@ -87,9 +88,9 @@ def validator_base(self, test_strings, expected_results, expected_issues, test_f
error_handler.add_context_and_filter(test_issues)
test_result = not test_issues

# print(test_key)
# print(str(expected_issue))
# print(str(test_issues))
print(test_key)
print(str(expected_issue))
print(str(test_issues))
error_handler.pop_error_context()
self.assertEqual(test_result, expected_result, test_strings[test_key])
self.assertCountEqual(test_issues, expected_issue, test_strings[test_key])
Expand Down

0 comments on commit 0568209

Please sign in to comment.