Skip to content

Commit

Permalink
Switch to .casefold in most places
Browse files Browse the repository at this point in the history
Rewrite extract_tags
  • Loading branch information
IanCa committed Mar 30, 2024
1 parent 8604bef commit f198b6b
Show file tree
Hide file tree
Showing 30 changed files with 174 additions and 253 deletions.
18 changes: 9 additions & 9 deletions hed/models/def_expand_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,20 +155,20 @@ def _handle_known_definition(self, def_tag, def_expand_group, def_group):

if def_group_contents:
if def_group_contents != def_expand_group:
self.errors.setdefault(def_tag_name.lower(), []).append(def_expand_group.get_first_group())
self.errors.setdefault(def_tag_name.casefold(), []).append(def_expand_group.get_first_group())
return True

has_extension = "/" in def_tag.extension
if not has_extension:
group_tag = def_expand_group.get_first_group()
self.def_dict.defs[def_tag_name.lower()] = DefinitionEntry(name=def_tag_name, contents=group_tag,
self.def_dict.defs[def_tag_name.casefold()] = DefinitionEntry(name=def_tag_name, contents=group_tag,
takes_value=False,
source_context=[])
return True

# this is needed for the cases where we have a definition with errors, but it's not a known definition.
if def_tag_name.lower() in self.errors:
self.errors.setdefault(f"{def_tag_name.lower()}", []).append(def_expand_group.get_first_group())
if def_tag_name.casefold() in self.errors:
self.errors.setdefault(f"{def_tag_name.casefold()}", []).append(def_expand_group.get_first_group())
return True

return False
Expand All @@ -181,20 +181,20 @@ def _handle_ambiguous_definition(self, def_tag, def_expand_group):
def_expand_group (HedGroup): The group containing the def-expand tag.
"""
def_tag_name = def_tag.extension.split('/')[0]
these_defs = self.ambiguous_defs.setdefault(def_tag_name.lower(), AmbiguousDef())
these_defs = self.ambiguous_defs.setdefault(def_tag_name.casefold(), AmbiguousDef())
these_defs.add_def(def_tag, def_expand_group)

try:
if these_defs.validate():
new_contents = these_defs.get_group()
self.def_dict.defs[def_tag_name.lower()] = DefinitionEntry(name=def_tag_name, contents=new_contents,
self.def_dict.defs[def_tag_name.casefold()] = DefinitionEntry(name=def_tag_name, contents=new_contents,
takes_value=True,
source_context=[])
del self.ambiguous_defs[def_tag_name.lower()]
del self.ambiguous_defs[def_tag_name.casefold()]
except ValueError:
for ambiguous_def in these_defs.placeholder_defs:
self.errors.setdefault(def_tag_name.lower(), []).append(ambiguous_def)
del self.ambiguous_defs[def_tag_name.lower()]
self.errors.setdefault(def_tag_name.casefold(), []).append(ambiguous_def)
del self.ambiguous_defs[def_tag_name.casefold()]

return

Expand Down
12 changes: 6 additions & 6 deletions hed/models/definition_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get(self, def_name):
Returns:
DefinitionEntry: Definition entry for the requested definition.
"""
return self.defs.get(def_name.lower())
return self.defs.get(def_name.casefold())

def __iter__(self):
return iter(self.defs)
Expand Down Expand Up @@ -144,14 +144,14 @@ def check_for_definitions(self, hed_string_obj, error_handler=None):
def_issues += new_def_issues
continue

self.defs[def_tag_name.lower()] = DefinitionEntry(name=def_tag_name, contents=group_tag,
self.defs[def_tag_name.casefold()] = DefinitionEntry(name=def_tag_name, contents=group_tag,
takes_value=def_takes_value,
source_context=context)

return def_issues

def _strip_value_placeholder(self, def_tag_name):
def_takes_value = def_tag_name.lower().endswith("/#")
def_takes_value = def_tag_name.endswith("/#")
if def_takes_value:
def_tag_name = def_tag_name[:-len("/#")]
return def_tag_name, def_takes_value
Expand All @@ -162,7 +162,7 @@ def _validate_name_and_context(self, def_tag_name, error_handler):
else:
context = []
new_def_issues = []
if def_tag_name.lower() in self.defs:
if def_tag_name.casefold() in self.defs:
new_def_issues += ErrorHandler.format_error_with_context(error_handler,
DefinitionErrors.DUPLICATE_DEFINITION,
def_name=def_tag_name)
Expand Down Expand Up @@ -263,7 +263,7 @@ def get_definition_entry(self, def_tag):
"""
tag_label, _, placeholder = def_tag.extension.partition('/')

label_tag_lower = tag_label.lower()
label_tag_lower = tag_label.casefold()
def_entry = self.defs.get(label_tag_lower)
return def_entry

Expand All @@ -281,7 +281,7 @@ def _get_definition_contents(self, def_tag):
"""
tag_label, _, placeholder = def_tag.extension.partition('/')

label_tag_lower = tag_label.lower()
label_tag_lower = tag_label.casefold()
def_entry = self.defs.get(label_tag_lower)
if def_entry is None:
# Could raise an error here?
Expand Down
14 changes: 6 additions & 8 deletions hed/models/df_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,20 @@ def sort_dataframe_by_onsets(df):
return df


def replace_ref(text, newvalue, column_ref):
def replace_ref(text, oldvalue, newvalue="n/a"):
""" Replace oldvalue in text with newvalue. If newvalue is n/a, delete extra commas/parentheses.
Parameters:
text (str): The input string containing the ref enclosed in curly braces.
oldvalue (str): The full tag or ref to replace
newvalue (str): The replacement value for the ref.
column_ref (str): The ref to be replaced, without curly braces.
Returns:
str: The modified string with the ref replaced or removed.
"""
# Note: This function could easily be updated to handle non-curly brace values, but it seemed faster this way

# If it's not n/a, we can just replace directly.
if newvalue != "n/a":
return text.replace(f"{{{column_ref}}}", newvalue)
return text.replace(oldvalue, newvalue)

def _remover(match):
p1 = match.group("p1").count("(")
Expand All @@ -162,7 +160,7 @@ def _remover(match):
# c1/c2 contain the comma(and possibly spaces) separating this ref from other tags
# p1/p2 contain the parentheses directly surrounding the tag
# All four groups can have spaces.
pattern = r'(?P<c1>[\s,]*)(?P<p1>[(\s]*)\{' + column_ref + r'\}(?P<p2>[\s)]*)(?P<c2>[\s,]*)'
pattern = r'(?P<c1>[\s,]*)(?P<p1>[(\s]*)' + oldvalue + r'(?P<p2>[\s)]*)(?P<c2>[\s,]*)'
return re.sub(pattern, _remover, text)


Expand Down Expand Up @@ -192,7 +190,7 @@ def _handle_curly_braces_refs(df, refs, column_names):
# column_name_brackets = f"{{{replacing_name}}}"
# df[column_name] = pd.Series(x.replace(column_name_brackets, y) for x, y
# in zip(df[column_name], saved_columns[replacing_name]))
new_df[column_name] = pd.Series(replace_ref(x, y, replacing_name) for x, y
new_df[column_name] = pd.Series(replace_ref(x, f"{{{replacing_name}}}", y) for x, y
in zip(new_df[column_name], saved_columns[replacing_name]))
new_df = new_df[remaining_columns]

Expand Down Expand Up @@ -220,7 +218,7 @@ def split_delay_tags(series, hed_schema, onsets):
return
split_df = pd.DataFrame({"onset": onsets, "HED": series, "original_index": series.index})
delay_strings = [(i, HedString(hed_string, hed_schema)) for (i, hed_string) in series.items() if
"delay/" in hed_string.lower()]
"delay/" in hed_string.casefold()]
delay_groups = []
for i, delay_string in delay_strings:
duration_tags = delay_string.find_top_level_tags({DefTagNames.DELAY_KEY})
Expand Down
16 changes: 10 additions & 6 deletions hed/models/hed_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,10 @@ def lower(self):
""" Convenience function, equivalent to str(self).lower(). """
return str(self).lower()

def casefold(self):
""" Convenience function, equivalent to str(self).casefold(). """
return str(self).casefold()

def get_as_indented(self, tag_attribute="short_tag"):
"""Return the string as a multiline indented format.
Expand Down Expand Up @@ -442,9 +446,9 @@ def find_tags(self, search_tags, recursive=False, include_groups=2):
tags = self.get_all_tags()
else:
tags = self.tags()
search_tags = {tag.lower() for tag in search_tags}
search_tags = {tag.casefold() for tag in search_tags}
for tag in tags:
if tag.short_base_tag.lower() in search_tags:
if tag.short_base_tag.casefold() in search_tags:
found_tags.append((tag, tag._parent))

if include_groups == 0 or include_groups == 1:
Expand All @@ -454,7 +458,7 @@ def find_tags(self, search_tags, recursive=False, include_groups=2):
def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2):
""" Find the tags and their containing groups.
This searches tag.short_tag.lower(), with an implicit wildcard on the end.
This searches tag.short_tag.casefold(), with an implicit wildcard on the end.
e.g. "Eve" will find Event, but not Sensory-event.
Expand All @@ -475,11 +479,11 @@ def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2):
else:
tags = self.tags()

search_tags = {search_tag.lower() for search_tag in search_tags}
search_tags = {search_tag.casefold() for search_tag in search_tags}

for tag in tags:
for search_tag in search_tags:
if tag.short_tag.lower().startswith(search_tag):
if tag.short_tag.casefold().startswith(search_tag):
found_tags.append((tag, tag._parent))
# We can't find the same tag twice
break
Expand Down Expand Up @@ -575,7 +579,7 @@ def find_tags_with_term(self, term, recursive=False, include_groups=2):
else:
tags = self.tags()

search_for = term.lower()
search_for = term.casefold()
for tag in tags:
if search_for in tag.tag_terms:
found_tags.append((tag, tag._parent))
Expand Down
4 changes: 2 additions & 2 deletions hed/models/hed_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,11 @@ def find_top_level_tags(self, anchor_tags, include_groups=2):
Returns:
list: The returned result depends on include_groups.
"""
anchor_tags = {tag.lower() for tag in anchor_tags}
anchor_tags = {tag.casefold() for tag in anchor_tags}
top_level_tags = []
for group in self.groups():
for tag in group.tags():
if tag.short_base_tag.lower() in anchor_tags:
if tag.short_base_tag.casefold() in anchor_tags:
top_level_tags.append((tag, group))
# Only capture a max of 1 per group. These are implicitly unique.
break
Expand Down
12 changes: 8 additions & 4 deletions hed/models/hed_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ def lower(self):
""" Convenience function, equivalent to str(self).lower(). """
return str(self).lower()

def casefold(self):
""" Convenience function, equivalent to str(self).casefold(). """
return str(self).casefold()

def _calculate_to_canonical_forms(self, hed_schema):
""" Update internal state based on schema.
Expand Down Expand Up @@ -617,24 +621,24 @@ def replace_placeholder(self, placeholder_value):
def __hash__(self):
if self._schema_entry:
return hash(
self._namespace + self._schema_entry.short_tag_name.lower() + self._extension_value.lower())
self._namespace + self._schema_entry.short_tag_name.casefold() + self._extension_value.casefold())
else:
return hash(self.lower())
return hash(self.casefold())

def __eq__(self, other):
if self is other:
return True

if isinstance(other, str):
return self.lower() == other.lower()
return self.casefold() == other.casefold()

if not isinstance(other, HedTag):
return False

if self.short_tag == other.short_tag:
return True

if self.org_tag.lower() == other.org_tag.lower():
if self.org_tag.casefold() == other.org_tag.casefold():
return True
return False

Expand Down
2 changes: 1 addition & 1 deletion hed/models/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, expression_string):
"""
self.tokens = []
self.at_token = -1
self.tree = self._parse(expression_string.lower())
self.tree = self._parse(expression_string.casefold())
self._org_string = expression_string

def search(self, hed_string_obj):
Expand Down
4 changes: 2 additions & 2 deletions hed/models/string_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def split_base_tags(hed_string, base_tags, remove_group=False):
- The second HedString object contains the tags from hed_string that match the base_tags.
"""

base_tags = [tag.lower() for tag in base_tags]
base_tags = [tag.casefold() for tag in base_tags]
include_groups = 0
if remove_group:
include_groups = 2
Expand Down Expand Up @@ -70,7 +70,7 @@ def split_def_tags(hed_string, def_names, remove_group=False):
include_groups = 0
if remove_group:
include_groups = 2
wildcard_tags = [f"def/{def_name}".lower() for def_name in def_names]
wildcard_tags = [f"def/{def_name}".casefold() for def_name in def_names]
found_things = hed_string.find_wildcard_tags(wildcard_tags, recursive=True, include_groups=include_groups)
if remove_group:
found_things = [tag if isinstance(group, HedString) else group for tag, group in found_things]
Expand Down
6 changes: 3 additions & 3 deletions hed/schema/hed_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def _find_tag_entry(self, tag, schema_namespace=""):
clean_tag = str(tag)
namespace = schema_namespace
clean_tag = clean_tag[len(namespace):]
working_tag = clean_tag.lower()
working_tag = clean_tag.casefold()

# Most tags are in the schema directly, so test that first
found_entry = self._get_tag_entry(working_tag)
Expand Down Expand Up @@ -699,10 +699,10 @@ def _get_modifiers_for_unit(self, unit):
This is a lower-level one that doesn't rely on the Unit entries being fully set up.
"""
# todo: could refactor this so this unit.lower() part is in HedSchemaUnitSection.get
# todo: could refactor this so this unit.casefold() part is in HedSchemaUnitSection.get
unit_entry = self.get_tag_entry(unit, HedSectionKey.Units)
if unit_entry is None:
unit_entry = self.get_tag_entry(unit.lower(), HedSectionKey.Units)
unit_entry = self.get_tag_entry(unit.casefold(), HedSectionKey.Units)
# Unit symbols must match exactly
if unit_entry is None or unit_entry.has_attribute(HedKey.UnitSymbol):
return []
Expand Down
4 changes: 2 additions & 2 deletions hed/schema/hed_schema_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def get_derivative_unit_entry(self, units):
if possible_match and possible_match.has_attribute(HedKey.UnitSymbol):
return possible_match

possible_match = self.derivative_units.get(units.lower())
possible_match = self.derivative_units.get(units.casefold())
# Unit symbols must match including case; a unit-symbol match at this point is something like M becoming m.
if possible_match and possible_match.has_attribute(HedKey.UnitSymbol):
possible_match = None
Expand Down Expand Up @@ -416,7 +416,7 @@ def finalize_entry(self, schema):
if self._parent_tag:
self._parent_tag.children[self.short_tag_name] = self
self.takes_value_child_entry = schema._get_tag_entry(self.name + "/#")
self.tag_terms = tuple(self.long_tag_name.lower().split("/"))
self.tag_terms = tuple(self.long_tag_name.casefold().split("/"))

self._finalize_inherited_attributes()
self._finalize_takes_value_tag(schema)
16 changes: 8 additions & 8 deletions hed/schema/hed_schema_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _add_to_dict(self, name, new_entry):
""" Add a name to the dictionary for this section. """
name_key = name
if not self.case_sensitive:
name_key = name.lower()
name_key = name.casefold()

return_entry = self._check_if_duplicate(name_key, new_entry)

Expand Down Expand Up @@ -115,7 +115,7 @@ def keys(self):

def __getitem__(self, key):
if not self.case_sensitive:
key = key.lower()
key = key.casefold()
return self.all_names[key]

def get(self, key):
Expand All @@ -126,7 +126,7 @@ def get(self, key):
"""
if not self.case_sensitive:
key = key.lower()
key = key.casefold()
return self.all_names.get(key)

def __eq__(self, other):
Expand All @@ -153,7 +153,7 @@ class HedSchemaUnitSection(HedSchemaSection):
def _check_if_duplicate(self, name_key, new_entry):
"""We need to mark duplicate units (units with unitSymbol are case sensitive, while others are not)."""
if not new_entry.has_attribute(HedKey.UnitSymbol):
name_key = name_key.lower()
name_key = name_key.casefold()
return super()._check_if_duplicate(name_key, new_entry)


Expand Down Expand Up @@ -220,24 +220,24 @@ def _check_if_duplicate(self, name, new_entry):
else:
self.all_names[name] = new_entry
for tag_key in tag_forms:
name_key = tag_key.lower()
name_key = tag_key.casefold()
self.long_form_tags[name_key] = new_entry

return new_entry

def get(self, key):
if not self.case_sensitive:
key = key.lower()
key = key.casefold()
return self.long_form_tags.get(key)

def __getitem__(self, key):
if not self.case_sensitive:
key = key.lower()
key = key.casefold()
return self.long_form_tags[key]

def __contains__(self, key):
if not self.case_sensitive:
key = key.lower()
key = key.casefold()
return key in self.long_form_tags

@staticmethod
Expand Down
Loading

0 comments on commit f198b6b

Please sign in to comment.