Skip to content

Commit

Permalink
Move schema scripts over to hed-python
Browse files Browse the repository at this point in the history
  • Loading branch information
IanCa committed May 20, 2024
1 parent 7fb7317 commit bf952d7
Show file tree
Hide file tree
Showing 9 changed files with 465 additions and 1 deletion.
Empty file added hed/scripts/__init__.py
Empty file.
78 changes: 78 additions & 0 deletions hed/scripts/convert_and_update_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from hed.schema import load_schema_version
from hed.scripts.script_util import sort_base_schemas, validate_all_schemas, add_extension
from hed.schema.schema_io.df2schema import load_dataframes
from hed.schema.schema_io.ontology_util import update_dataframes_from_schema, save_dataframes
from hed.schema.hed_schema_io import load_schema, from_dataframes
import argparse


def convert_and_update(filenames, set_ids):
""" Validate, convert, and update as needed all schemas listed in filenames
If any schema fails to validate, no schemas will be updated.
Parameters:
filenames(list of str): A list of filenames that have been updated
set_ids(bool): If True, assign missing hedIds
"""
# Find and group the changed files
schema_files = sort_base_schemas(filenames)
all_issues = validate_all_schemas(schema_files)

if all_issues or not schema_files:
print("Did not attempt to update schemas due to validation failures")
return 1

updated = []
# If we are here, we have validated the schemas(and if there's more than one version changed, that they're the same)
for basename, extensions in schema_files.items():
# Skip any with multiple extensions or not in pre-release
if "prerelease" not in basename:
print(f"Skipping updates on {basename}, not in a prerelease folder.")
continue
source_filename = add_extension(basename,
list(extensions)[0]) # Load any changed schema version, they're all the same
source_df_filename = add_extension(basename, ".tsv")
schema = load_schema(source_filename)
print(f"Trying to convert/update file {source_filename}")
source_dataframes = load_dataframes(source_df_filename)
# todo: We need a more robust system for if some files are missing
# (especially for library schemas which will probably lack some)
if any(value is None for value in source_dataframes.values()):
source_dataframes = schema.get_as_dataframes()

result = update_dataframes_from_schema(source_dataframes, schema, assign_missing_ids=set_ids)

schema_reloaded = from_dataframes(result)
schema_reloaded.save_as_mediawiki(basename + ".mediawiki")
schema_reloaded.save_as_xml(basename + ".xml")

save_dataframes(source_df_filename, result)
updated.append(basename)

for basename in updated:
print(f"Schema {basename} updated.")

if not updated:
print("Did not update any schemas")
return 0


def main():
parser = argparse.ArgumentParser(description='Update other schema formats based on the changed one.')
parser.add_argument('filenames', nargs='*', help='List of files to process')
parser.add_argument('--set-ids', action='store_true', help='Set IDs for each file')

args = parser.parse_args()

filenames = args.filenames
set_ids = args.set_ids

# Trigger a local cache hit (this ensures trying to load withStandard schemas will work properly)
_ = load_schema_version("8.2.0")

return convert_and_update(filenames, set_ids)


if __name__ == "__main__":
exit(main())
151 changes: 151 additions & 0 deletions hed/scripts/script_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import os.path
from collections import defaultdict
from hed.schema import from_string, load_schema
from hed.errors import get_printable_issue_string, HedFileError, SchemaWarnings

all_extensions = [".tsv", ".mediawiki", ".xml"]


def validate_schema(file_path):
""" Validates the given schema, ensuring it can save/load as well as validates.
This is probably overkill...
"""
validation_issues = []
try:
base_schema = load_schema(file_path)
issues = base_schema.check_compliance()
issues = [issue for issue in issues if issue["code"] != SchemaWarnings.SCHEMA_PRERELEASE_VERSION_USED]
if issues:
error_message = get_printable_issue_string(issues, title=file_path)
validation_issues.append(error_message)

mediawiki_string = base_schema.get_as_mediawiki_string()
reloaded_schema = from_string(mediawiki_string, schema_format=".mediawiki")

if reloaded_schema != base_schema:
error_text = f"Failed to reload {file_path} as mediawiki. " \
f"There is either a problem with the source file, or the saving/loading code."
validation_issues.append(error_text)

xml_string = base_schema.get_as_xml_string()
reloaded_schema = from_string(xml_string, schema_format=".xml")

if reloaded_schema != base_schema:
error_text = f"Failed to reload {file_path} as xml. " \
f"There is either a problem with the source file, or the saving/loading code."
validation_issues.append(error_text)
except HedFileError as e:
print(f"Saving/loading error: {e.message}")
error_text = e.message
if e.issues:
error_text = get_printable_issue_string(e.issues, title=file_path)
validation_issues.append(error_text)

return validation_issues


def add_extension(basename, extension):
"""Generate the final name for a given extension. Only .tsv varies notably."""
if extension == ".tsv":
parent_path, basename = os.path.split(basename)
return os.path.join(parent_path, "hedtsv", basename)
return basename + extension


def sort_base_schemas(filenames):
""" Sort and group the changed files based on basename
Example input: ["test_schema.mediawiki", "hedtsv/test_schema/test_schema_Tag.tsv", "other_schema.xml"]
Example output:
{
"test_schema": {".mediawiki", ".tsv"},
other_schema": {".xml"}
}
Parameters:
filenames(list or container): The changed filenames
Returns:
sorted_files(dict): A dictionary where keys are the basename, and the values are a set of extensions modified
Can include tsv, mediawiki, and xml.
"""
schema_files = defaultdict(set)
for file_path in filenames:
basename, extension = os.path.splitext(file_path.lower())
if extension == ".xml" or extension == ".mediawiki":
schema_files[basename].add(extension)
continue
elif extension == ".tsv":
tsv_basename = basename.rpartition("_")[0]
full_parent_path, real_basename = os.path.split(tsv_basename)
full_parent_path, real_basename2 = os.path.split(full_parent_path)
real_parent_path, hedtsv_folder = os.path.split(full_parent_path)
if hedtsv_folder != "hedtsv":
print(f"Ignoring file {file_path}. .tsv files must be in an 'hedtsv' subfolder.")
continue
if real_basename != real_basename2:
print(f"Ignoring file {file_path}. .tsv files must be in a subfolder with the same name.")
continue
real_name = os.path.join(real_parent_path, real_basename)
schema_files[real_name].add(extension)
else:
print(f"Ignoring file {file_path}")

return schema_files


def validate_all_schema_formats(basename):
""" Validate all 3 versions of the given schema.
Parameters:
basename(str): a schema to check all 3 formats are identical of.
Returns:
issue_list(list): A non-empty list if there are any issues.
"""
# Note if more than one is changed, it intentionally checks all 3 even if one wasn't changed.
paths = [add_extension(basename, extension) for extension in all_extensions]
try:
schemas = [load_schema(path) for path in paths]
all_equal = all(obj == schemas[0] for obj in schemas[1:])
if not all_equal:
return [
f"Multiple schemas of type {basename} were modified, and are not equal.\n"
f"Only modify one source schema type at a time(mediawiki, xml, tsv), or modify all 3 at once."]
except HedFileError as e:
error_message = f"Error loading schema: {e.message}"
return [error_message]

return []


def validate_all_schemas(schema_files):
"""Validates all the schema files/formats in the schema dict
If multiple formats were edited, ensures all 3 formats exist and match.
Parameters:
schema_files(dict of sets): basename:[extensions] dictionary for all files changed
Returns:
issues(list of str): Any issues found validating or loading schemas.
"""
all_issues = []
for basename, extensions in schema_files.items():
single_schema_issues = []
for extension in extensions:
full_path = add_extension(basename, extension)
single_schema_issues += validate_schema(full_path)

if len(extensions) > 1 and not single_schema_issues and "prerelease" in basename:
single_schema_issues += validate_all_schema_formats(basename)

print(f"Validating {basename}...")
if single_schema_issues:
for issue in single_schema_issues:
print(issue)

all_issues += single_schema_issues
return all_issues
22 changes: 22 additions & 0 deletions hed/scripts/validate_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
from hed.schema import load_schema_version
from hed.scripts.script_util import validate_all_schemas, sort_base_schemas


def main(arg_list=None):
# Trigger a local cache hit
_ = load_schema_version("8.2.0")

if not arg_list:
arg_list = sys.argv[1:]

schema_files = sort_base_schemas(arg_list)
issues = validate_all_schemas(schema_files)

if issues:
return 1
return 0


if __name__ == "__main__":
exit(main())
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ dependencies = [
run_remodel = "hed.tools.remodeling.cli.run_remodel:main"
run_remodel_backup = "hed.tools.remodeling.cli.run_remodel_backup:main"
run_remodel_restore = "hed.tools.remodeling.cli.run_remodel_restore:main"
hed_validate_schemas = "hed.scripts.validate_schemas:main"
hed_update_schemas = "hed.scripts.convert_and_update_schema:main"

[tool.versioneer]
VCS = "git"
Expand Down
2 changes: 1 addition & 1 deletion tests/schema/test_schema_attribute_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_deprecatedFrom(self):
self.assertTrue(schema_attribute_validators.tag_is_deprecated_check(self.hed_schema, tag_entry, attribute_name))
del tag_entry.attributes["deprecatedFrom"]

unit_class_entry = self.hed_schema.unit_classes["temperatureUnits"]
unit_class_entry = copy.deepcopy(self.hed_schema.unit_classes["temperatureUnits"])
# This should raise an issue because it assumes the attribute is set
self.assertTrue(schema_attribute_validators.tag_is_deprecated_check(self.hed_schema, unit_class_entry, attribute_name))
unit_class_entry.attributes["deprecatedFrom"] = "8.1.0"
Expand Down
Empty file added tests/scripts/__init__.py
Empty file.
90 changes: 90 additions & 0 deletions tests/scripts/test_convert_and_update_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import unittest
import os
import shutil
import copy
from hed import load_schema, load_schema_version
from hed.schema import HedSectionKey, HedKey
from hed.scripts.script_util import add_extension
from hed.scripts.convert_and_update_schema import convert_and_update


class TestConvertAndUpdate(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Create a temporary directory for schema files
cls.base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'schemas_update', 'prerelease')
if not os.path.exists(cls.base_path):
os.makedirs(cls.base_path)

def test_schema_conversion_and_update(self):
# Load a known schema, modify it if necessary, and save it
schema = load_schema_version("8.3.0")
original_name = os.path.join(self.base_path, "test_schema.mediawiki")
schema.save_as_mediawiki(original_name)

# Assume filenames updated includes just the original schema file for simplicity
filenames = [original_name]
result = convert_and_update(filenames, set_ids=False)

# Verify no error from convert_and_update and the correct schema version was saved
self.assertEqual(result, 0)

tsv_filename = add_extension(os.path.join(self.base_path, "test_schema"), ".tsv")
schema_reload1 = load_schema(tsv_filename)
schema_reload2 = load_schema(os.path.join(self.base_path, "test_schema.xml"))

self.assertEqual(schema, schema_reload1)
self.assertEqual(schema, schema_reload2)

# Now verify after doing this again with a new schema, they're still the same.
schema = load_schema_version("8.2.0")
schema.save_as_dataframes(tsv_filename)

filenames = [os.path.join(tsv_filename, "test_schema_Tag.tsv")]
result = convert_and_update(filenames, set_ids=False)

# Verify no error from convert_and_update and the correct schema version was saved
self.assertEqual(result, 0)

schema_reload1 = load_schema(os.path.join(self.base_path, "test_schema.mediawiki"))
schema_reload2 = load_schema(os.path.join(self.base_path, "test_schema.xml"))

self.assertEqual(schema, schema_reload1)
self.assertEqual(schema, schema_reload2)

def test_schema_adding_tag(self):
schema = load_schema_version("8.3.0")
basename = os.path.join(self.base_path, "test_schema_edited")
schema.save_as_mediawiki(add_extension(basename, ".mediawiki"))
schema.save_as_xml(add_extension(basename, ".xml"))
schema.save_as_dataframes(add_extension(basename, ".tsv"))

schema_edited = copy.deepcopy(schema)
test_tag_name = "NewTagWithoutID"
new_entry = schema_edited._create_tag_entry(test_tag_name, HedSectionKey.Tags)
schema_edited._add_tag_to_dict(test_tag_name, new_entry, HedSectionKey.Tags)

schema_edited.save_as_mediawiki(add_extension(basename, ".mediawiki"))

# Assume filenames updated includes just the original schema file for simplicity
filenames = [add_extension(basename, ".mediawiki")]
result = convert_and_update(filenames, set_ids=False)
self.assertEqual(result, 0)

schema_reloaded = load_schema(add_extension(basename, ".xml"))

self.assertEqual(schema_reloaded, schema_edited)

result = convert_and_update(filenames, set_ids=True)
self.assertEqual(result, 0)

schema_reloaded = load_schema(add_extension(basename, ".xml"))

reloaded_entry = schema_reloaded.tags[test_tag_name]
self.assertTrue(reloaded_entry.has_attribute(HedKey.HedID))


@classmethod
def tearDownClass(cls):
# Clean up the directory created for testing
shutil.rmtree(cls.base_path)
Loading

0 comments on commit bf952d7

Please sign in to comment.