Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow mixed case in non unitsymbol units #884

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 2 additions & 18 deletions hed/models/hed_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,33 +580,17 @@ def _get_tag_units_portion(extension_text, tag_unit_classes):
return None, None, None

for unit_class_entry in tag_unit_classes.values():
all_valid_unit_permutations = unit_class_entry.derivative_units

possible_match = HedTag._find_modifier_unit_entry(units, all_valid_unit_permutations)
possible_match = unit_class_entry.get_derivative_unit_entry(units)
if possible_match and not possible_match.has_attribute(HedKey.UnitPrefix):
return value, units, possible_match

# Repeat the above, but as a prefix
possible_match = HedTag._find_modifier_unit_entry(value, all_valid_unit_permutations)
possible_match = unit_class_entry.get_derivative_unit_entry(value)
if possible_match and possible_match.has_attribute(HedKey.UnitPrefix):
return units, value, possible_match

return None, None, None

@staticmethod
def _find_modifier_unit_entry(units, all_valid_unit_permutations):
    """ Look up a unit entry, treating unit symbols as case sensitive.

    Parameters:
        units (str): The unit text to look up.
        all_valid_unit_permutations: Mapping (supports .get) from unit names to unit entries.

    Returns:
        The entry found for the exact text when it is a unit symbol; otherwise the entry
        found for the lowercased text, unless that entry is a unit symbol, in which case None.
    """
    exact_entry = all_valid_unit_permutations.get(units)
    # An exact-case hit on a unit symbol needs no further checking.
    if exact_entry and exact_entry.has_attribute(HedKey.UnitSymbol):
        return exact_entry

    lowered_entry = all_valid_unit_permutations.get(units.lower())
    # Reaching a unit symbol only through lowercasing (e.g. M becoming m) is a case
    # mismatch, and unit symbols must match including case.
    if lowered_entry and lowered_entry.has_attribute(HedKey.UnitSymbol):
        return None

    return lowered_entry

def is_placeholder(self):
if "#" in self.org_tag or "#" in self._extension_value:
return True
Expand Down
11 changes: 8 additions & 3 deletions hed/schema/hed_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from hed.schema.schema_io.schema2wiki import Schema2Wiki
from hed.schema.schema_io.schema2owl import Schema2Owl
from hed.schema.schema_io.owl_constants import ext_to_format
from hed.schema.hed_schema_section import HedSchemaSection, HedSchemaTagSection, HedSchemaUnitClassSection
from hed.schema.hed_schema_section import (HedSchemaSection, HedSchemaTagSection, HedSchemaUnitClassSection,
HedSchemaUnitSection)
from hed.errors import ErrorHandler
from hed.errors.error_types import ValidationErrors
from hed.schema.hed_schema_base import HedSchemaBase
Expand Down Expand Up @@ -747,7 +748,7 @@ def _create_empty_sections():
dictionaries[HedSectionKey.Properties] = HedSchemaSection(HedSectionKey.Properties)
dictionaries[HedSectionKey.Attributes] = HedSchemaSection(HedSectionKey.Attributes)
dictionaries[HedSectionKey.UnitModifiers] = HedSchemaSection(HedSectionKey.UnitModifiers)
dictionaries[HedSectionKey.Units] = HedSchemaSection(HedSectionKey.Units)
dictionaries[HedSectionKey.Units] = HedSchemaUnitSection(HedSectionKey.Units)
dictionaries[HedSectionKey.UnitClasses] = HedSchemaUnitClassSection(HedSectionKey.UnitClasses)
dictionaries[HedSectionKey.ValueClasses] = HedSchemaSection(HedSectionKey.ValueClasses)
dictionaries[HedSectionKey.Tags] = HedSchemaTagSection(HedSectionKey.Tags, case_sensitive=False)
Expand All @@ -767,9 +768,13 @@ def _get_modifiers_for_unit(self, unit):
This is a lower level one that doesn't rely on the Unit entries being fully setup.

"""
# todo: could refactor this so this unit.lower() part is in HedSchemaUnitSection.get
unit_entry = self.get_tag_entry(unit, HedSectionKey.Units)
if unit_entry is None:
return []
unit_entry = self.get_tag_entry(unit.lower(), HedSectionKey.Units)
# Unit symbols must match exactly
if unit_entry is None or unit_entry.has_attribute(HedKey.UnitSymbol):
return []
is_si_unit = unit_entry.has_attribute(HedKey.SIUnit)
is_unit_symbol = unit_entry.has_attribute(HedKey.UnitSymbol)
if not is_si_unit:
Expand Down
30 changes: 26 additions & 4 deletions hed/schema/hed_schema_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,27 @@ def __eq__(self, other):
return False
return True

def get_derivative_unit_entry(self, units):
    """ Return the (derivative) unit entry matching the given text, if any.

    Unit symbols are only accepted on an exact, case-sensitive match.  Any other
    unit is looked up by its lowercased text.

    Parameters:
        units (str): The unit name to check, can be plural or include a modifier.

    Returns:
        unit_entry(UnitEntry or None): The unit entry if it exists
    """
    known_units = self.derivative_units

    exact_hit = known_units.get(units)
    # A unit symbol matched with its exact case is accepted immediately.
    if exact_hit and exact_hit.has_attribute(HedKey.UnitSymbol):
        return exact_hit

    lowered_hit = known_units.get(units.lower())
    # A unit symbol found only after lowercasing (something like M becoming m)
    # does not count, since unit symbols must match including case.
    if lowered_hit and lowered_hit.has_attribute(HedKey.UnitSymbol):
        return None

    return lowered_hit


class UnitEntry(HedSchemaEntry):
""" A single unit entry with modifiers in the HedSchema. """
Expand All @@ -206,15 +227,16 @@ def finalize_entry(self, schema):
self.unit_modifiers = schema._get_modifiers_for_unit(self.name)

derivative_units = {}
base_plural_units = {self.name}
if not self.has_attribute(HedKey.UnitSymbol):
base_plural_units.add(pluralize.plural(self.name))
if self.has_attribute(HedKey.UnitSymbol):
base_plural_units = {self.name}
else:
base_plural_units = {self.name.lower()}
base_plural_units.add(pluralize.plural(self.name.lower()))

for derived_unit in base_plural_units:
derivative_units[derived_unit] = self._get_conversion_factor(None)
for modifier in self.unit_modifiers:
derivative_units[modifier.name + derived_unit] = self._get_conversion_factor(modifier_entry=modifier)

self.derivative_units = derivative_units

def _get_conversion_factor(self, modifier_entry):
Expand Down
8 changes: 8 additions & 0 deletions hed/schema/hed_schema_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ def _finalize_section(self, hed_schema):
entry.finalize_entry(hed_schema)


class HedSchemaUnitSection(HedSchemaSection):
    """ Schema section whose duplicate check is case sensitive only for unit symbols. """

    def _check_if_duplicate(self, name_key, new_entry):
        """ Run the parent duplicate check with a case-normalized key.

        Units with the unitSymbol attribute are case sensitive, so their key is passed
        through unchanged; all other units are lowercased before the check.
        """
        is_unit_symbol = new_entry.has_attribute(HedKey.UnitSymbol)
        lookup_key = name_key if is_unit_symbol else name_key.lower()
        return super()._check_if_duplicate(lookup_key, new_entry)


class HedSchemaUnitClassSection(HedSchemaSection):
def _check_if_duplicate(self, name_key, new_entry):
"""Allow adding units to existing unit classes, using a placeholder one with no attributes."""
Expand Down
2 changes: 1 addition & 1 deletion hed/schema/schema_attribute_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def value_class_exists(hed_schema, tag_entry, attribute_name):
def unit_exists(hed_schema, tag_entry, attribute_name):
issues = []
unit = tag_entry.attributes.get(attribute_name, "")
unit_entry = tag_entry.derivative_units.get(unit)
unit_entry = tag_entry.get_derivative_unit_entry(unit)
if unit and not unit_entry:
issues += ErrorHandler.format_error(SchemaAttributeErrors.SCHEMA_DEFAULT_UNITS_INVALID,
tag_entry.name,
Expand Down
45 changes: 45 additions & 0 deletions tests/validator/test_tag_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from hed.errors.error_types import ValidationErrors, DefinitionErrors
from tests.validator.test_tag_validator_base import TestValidatorBase
from hed import load_schema_version
from functools import partial


Expand All @@ -11,6 +12,7 @@ class TestHed(TestValidatorBase):


class IndividualHedTagsShort(TestHed):
hed_schema = load_schema_version("score_1.1.0")
@staticmethod
def string_obj_func(validator):
return partial(validator._validate_individual_tags_in_hed_string)
Expand Down Expand Up @@ -215,6 +217,20 @@ def test_correct_units(self):
# Update tests - 8.0 currently has no clockTime nodes.
# 'properTime': 'Item/2D shape/Clock face/08:30',
# 'invalidTime': 'Item/2D shape/Clock face/54:54'
'voltsTest1': 'Finding-amplitude/30 v',
'voltsTest2': 'Finding-amplitude/30 Volt',
'voltsTest3': 'Finding-amplitude/30 volts',
'voltsTest4': 'Finding-amplitude/30 VOLTS',
'voltsTest5': 'Finding-amplitude/30 kv',
'voltsTest6': 'Finding-amplitude/30 kiloVolt',
'voltsTest7': 'Finding-amplitude/30 KiloVolt',
'volumeTest1': "Sound-volume/5 dB",
'volumeTest2': "Sound-volume/5 kdB", # Invalid, not SI unit
'volumeTest3': "Sound-volume/5 candela",
'volumeTest4': "Sound-volume/5 kilocandela",
'volumeTest5': "Sound-volume/5 cd",
'volumeTest6': "Sound-volume/5 kcd",
'volumeTest7': "Sound-volume/5 DB", # Invalid, case doesn't match
}
expected_results = {
'correctUnit': True,
Expand All @@ -236,12 +252,27 @@ def test_correct_units(self):
# 'invalidTime': True,
# 'specialAllowedCharCurrency': True,
# 'specialNotAllowedCharCurrency': False,
'voltsTest1': True,
'voltsTest2': True,
'voltsTest3': True,
'voltsTest4': True,
'voltsTest5': True,
'voltsTest6': True,
'voltsTest7': True,
'volumeTest1': True,
'volumeTest2': False,
'volumeTest3': True,
'volumeTest4': True,
'volumeTest5': True,
'volumeTest6': True,
'volumeTest7': False,
}
legal_time_units = ['s', 'second', 'day', 'minute', 'hour']
# legal_clock_time_units = ['hour:min', 'hour:min:sec']
# legal_datetime_units = ['YYYY-MM-DDThh:mm:ss']
legal_freq_units = ['Hz', 'hertz']
# legal_currency_units = ['dollar', "$", "point"]
legal_intensity_units = ["candela", "cd", "dB"]

expected_issues = {
'correctUnit': [],
Expand Down Expand Up @@ -273,6 +304,20 @@ def test_correct_units(self):
# 'specialNotAllowedCharCurrency': self.format_error(ValidationErrors.UNITS_INVALID,
# tag=0,
# units=legal_currency_units),
'voltsTest1': [],
'voltsTest2': [],
'voltsTest3': [],
'voltsTest4': [],
'voltsTest5': [],
'voltsTest6': [],
'voltsTest7': [],
'volumeTest1': [],
'volumeTest2': self.format_error(ValidationErrors.UNITS_INVALID,tag=0, units=legal_intensity_units),
'volumeTest3': [],
'volumeTest4': [],
'volumeTest5': [],
'volumeTest6': [],
'volumeTest7': self.format_error(ValidationErrors.UNITS_INVALID, tag=0, units=legal_intensity_units),
}
self.validator_semantic(test_strings, expected_results, expected_issues, True)

Expand Down
9 changes: 5 additions & 4 deletions tests/validator/test_tag_validator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
#todo: update these tests(TagValidator no longer exists)
class TestHedBase(unittest.TestCase):
schema_file = None
hed_schema = None

@classmethod
def setUpClass(cls):
if cls.schema_file:
if cls.schema_file and not cls.hed_schema:
hed_xml = os.path.join(os.path.dirname(os.path.realpath(__file__)), cls.schema_file)
cls.hed_schema = schema.load_schema(hed_xml)
elif not cls.hed_schema:
Expand Down Expand Up @@ -87,9 +88,9 @@ def validator_base(self, test_strings, expected_results, expected_issues, test_f
error_handler.add_context_and_filter(test_issues)
test_result = not test_issues

# print(test_key)
# print(str(expected_issue))
# print(str(test_issues))
print(test_key)
print(str(expected_issue))
print(str(test_issues))
error_handler.pop_error_context()
self.assertEqual(test_result, expected_result, test_strings[test_key])
self.assertCountEqual(test_issues, expected_issue, test_strings[test_key])
Expand Down