From bf952d731747d0f1e1003ce00ae38e7f3467ba9e Mon Sep 17 00:00:00 2001 From: IanCa Date: Mon, 20 May 2024 14:48:26 -0500 Subject: [PATCH 1/5] Move schema scripts over to hed-python --- hed/scripts/__init__.py | 0 hed/scripts/convert_and_update_schema.py | 78 +++++++++ hed/scripts/script_util.py | 151 ++++++++++++++++++ hed/scripts/validate_schemas.py | 22 +++ pyproject.toml | 2 + .../test_schema_attribute_validators.py | 2 +- tests/scripts/__init__.py | 0 .../scripts/test_convert_and_update_schema.py | 90 +++++++++++ tests/scripts/test_script_util.py | 121 ++++++++++++++ 9 files changed, 465 insertions(+), 1 deletion(-) create mode 100644 hed/scripts/__init__.py create mode 100644 hed/scripts/convert_and_update_schema.py create mode 100644 hed/scripts/script_util.py create mode 100644 hed/scripts/validate_schemas.py create mode 100644 tests/scripts/__init__.py create mode 100644 tests/scripts/test_convert_and_update_schema.py create mode 100644 tests/scripts/test_script_util.py diff --git a/hed/scripts/__init__.py b/hed/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/hed/scripts/convert_and_update_schema.py b/hed/scripts/convert_and_update_schema.py new file mode 100644 index 000000000..38d723208 --- /dev/null +++ b/hed/scripts/convert_and_update_schema.py @@ -0,0 +1,78 @@ +from hed.schema import load_schema_version +from hed.scripts.script_util import sort_base_schemas, validate_all_schemas, add_extension +from hed.schema.schema_io.df2schema import load_dataframes +from hed.schema.schema_io.ontology_util import update_dataframes_from_schema, save_dataframes +from hed.schema.hed_schema_io import load_schema, from_dataframes +import argparse + + +def convert_and_update(filenames, set_ids): + """ Validate, convert, and update as needed all schemas listed in filenames + + If any schema fails to validate, no schemas will be updated. + + Parameters: + filenames(list of str): A list of filenames that have been updated + set_ids(bool): If True, assign missing hedIds + """ + # Find and group the changed files + schema_files = sort_base_schemas(filenames) + all_issues = validate_all_schemas(schema_files) + + if all_issues or not schema_files: + print("Did not attempt to update schemas due to validation failures") + return 1 + + updated = [] + # If we are here, we have validated the schemas(and if there's more than one version changed, that they're the same) + for basename, extensions in schema_files.items(): + # Skip any with multiple extensions or not in pre-release + if "prerelease" not in basename: + print(f"Skipping updates on {basename}, not in a prerelease folder.") + continue + source_filename = add_extension(basename, + list(extensions)[0]) # Load any changed schema version, they're all the same + source_df_filename = add_extension(basename, ".tsv") + schema = load_schema(source_filename) + print(f"Trying to convert/update file {source_filename}") + source_dataframes = load_dataframes(source_df_filename) + # todo: We need a more robust system for if some files are missing + # (especially for library schemas which will probably lack some) + if any(value is None for value in source_dataframes.values()): + source_dataframes = schema.get_as_dataframes() + + result = update_dataframes_from_schema(source_dataframes, schema, assign_missing_ids=set_ids) + + schema_reloaded = from_dataframes(result) + schema_reloaded.save_as_mediawiki(basename + ".mediawiki") + schema_reloaded.save_as_xml(basename + ".xml") + + save_dataframes(source_df_filename, result) + updated.append(basename) + + for basename in updated: + print(f"Schema {basename} updated.") + + if not updated: + print("Did not update any schemas") + return 0 + + +def main(): + parser = argparse.ArgumentParser(description='Update other schema formats based on the changed one.') + parser.add_argument('filenames', nargs='*', help='List of files to process') + parser.add_argument('--set-ids', action='store_true', help='Set IDs for each file') + + args = parser.parse_args() + + filenames = args.filenames + set_ids = args.set_ids + + # Trigger a local cache hit (this ensures trying to load withStandard schemas will work properly) + _ = load_schema_version("8.2.0") + + return convert_and_update(filenames, set_ids) + + +if __name__ == "__main__": + exit(main()) diff --git a/hed/scripts/script_util.py b/hed/scripts/script_util.py new file mode 100644 index 000000000..07605ce31 --- /dev/null +++ b/hed/scripts/script_util.py @@ -0,0 +1,151 @@ +import os.path +from collections import defaultdict +from hed.schema import from_string, load_schema +from hed.errors import get_printable_issue_string, HedFileError, SchemaWarnings + +all_extensions = [".tsv", ".mediawiki", ".xml"] + + +def validate_schema(file_path): + """ Validates the given schema, ensuring it can save/load as well as validates. + + This is probably overkill... + """ + validation_issues = [] + try: + base_schema = load_schema(file_path) + issues = base_schema.check_compliance() + issues = [issue for issue in issues if issue["code"] != SchemaWarnings.SCHEMA_PRERELEASE_VERSION_USED] + if issues: + error_message = get_printable_issue_string(issues, title=file_path) + validation_issues.append(error_message) + + mediawiki_string = base_schema.get_as_mediawiki_string() + reloaded_schema = from_string(mediawiki_string, schema_format=".mediawiki") + + if reloaded_schema != base_schema: + error_text = f"Failed to reload {file_path} as mediawiki. " \ + f"There is either a problem with the source file, or the saving/loading code." + validation_issues.append(error_text) + + xml_string = base_schema.get_as_xml_string() + reloaded_schema = from_string(xml_string, schema_format=".xml") + + if reloaded_schema != base_schema: + error_text = f"Failed to reload {file_path} as xml. " \ + f"There is either a problem with the source file, or the saving/loading code." + validation_issues.append(error_text) + except HedFileError as e: + print(f"Saving/loading error: {e.message}") + error_text = e.message + if e.issues: + error_text = get_printable_issue_string(e.issues, title=file_path) + validation_issues.append(error_text) + + return validation_issues + + +def add_extension(basename, extension): + """Generate the final name for a given extension. Only .tsv varies notably.""" + if extension == ".tsv": + parent_path, basename = os.path.split(basename) + return os.path.join(parent_path, "hedtsv", basename) + return basename + extension + + +def sort_base_schemas(filenames): + """ Sort and group the changed files based on basename + + Example input: ["test_schema.mediawiki", "hedtsv/test_schema/test_schema_Tag.tsv", "other_schema.xml"] + + Example output: + { + "test_schema": {".mediawiki", ".tsv"}, + other_schema": {".xml"} + } + + Parameters: + filenames(list or container): The changed filenames + + Returns: + sorted_files(dict): A dictionary where keys are the basename, and the values are a set of extensions modified + Can include tsv, mediawiki, and xml. + """ + schema_files = defaultdict(set) + for file_path in filenames: + basename, extension = os.path.splitext(file_path.lower()) + if extension == ".xml" or extension == ".mediawiki": + schema_files[basename].add(extension) + continue + elif extension == ".tsv": + tsv_basename = basename.rpartition("_")[0] + full_parent_path, real_basename = os.path.split(tsv_basename) + full_parent_path, real_basename2 = os.path.split(full_parent_path) + real_parent_path, hedtsv_folder = os.path.split(full_parent_path) + if hedtsv_folder != "hedtsv": + print(f"Ignoring file {file_path}. .tsv files must be in an 'hedtsv' subfolder.") + continue + if real_basename != real_basename2: + print(f"Ignoring file {file_path}. .tsv files must be in a subfolder with the same name.") + continue + real_name = os.path.join(real_parent_path, real_basename) + schema_files[real_name].add(extension) + else: + print(f"Ignoring file {file_path}") + + return schema_files + + +def validate_all_schema_formats(basename): + """ Validate all 3 versions of the given schema. + + Parameters: + basename(str): a schema to check all 3 formats are identical of. + + Returns: + issue_list(list): A non-empty list if there are any issues. + """ + # Note if more than one is changed, it intentionally checks all 3 even if one wasn't changed. + paths = [add_extension(basename, extension) for extension in all_extensions] + try: + schemas = [load_schema(path) for path in paths] + all_equal = all(obj == schemas[0] for obj in schemas[1:]) + if not all_equal: + return [ + f"Multiple schemas of type {basename} were modified, and are not equal.\n" + f"Only modify one source schema type at a time(mediawiki, xml, tsv), or modify all 3 at once."] + except HedFileError as e: + error_message = f"Error loading schema: {e.message}" + return [error_message] + + return [] + + +def validate_all_schemas(schema_files): + """Validates all the schema files/formats in the schema dict + + If multiple formats were edited, ensures all 3 formats exist and match. + + Parameters: + schema_files(dict of sets): basename:[extensions] dictionary for all files changed + + Returns: + issues(list of str): Any issues found validating or loading schemas. + """ + all_issues = [] + for basename, extensions in schema_files.items(): + single_schema_issues = [] + for extension in extensions: + full_path = add_extension(basename, extension) + single_schema_issues += validate_schema(full_path) + + if len(extensions) > 1 and not single_schema_issues and "prerelease" in basename: + single_schema_issues += validate_all_schema_formats(basename) + + print(f"Validating {basename}...") + if single_schema_issues: + for issue in single_schema_issues: + print(issue) + + all_issues += single_schema_issues + return all_issues diff --git a/hed/scripts/validate_schemas.py b/hed/scripts/validate_schemas.py new file mode 100644 index 000000000..43c2bf173 --- /dev/null +++ b/hed/scripts/validate_schemas.py @@ -0,0 +1,22 @@ +import sys +from hed.schema import load_schema_version +from hed.scripts.script_util import validate_all_schemas, sort_base_schemas + + +def main(arg_list=None): + # Trigger a local cache hit + _ = load_schema_version("8.2.0") + + if not arg_list: + arg_list = sys.argv[1:] + + schema_files = sort_base_schemas(arg_list) + issues = validate_all_schemas(schema_files) + + if issues: + return 1 + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/pyproject.toml b/pyproject.toml index 1dde170f0..1d34b246a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ dependencies = [ run_remodel = "hed.tools.remodeling.cli.run_remodel:main" run_remodel_backup = "hed.tools.remodeling.cli.run_remodel_backup:main" run_remodel_restore = "hed.tools.remodeling.cli.run_remodel_restore:main" +hed_validate_schemas = "hed.scripts.validate_schemas:main" +hed_update_schemas = "hed.scripts.convert_and_update_schema:main" [tool.versioneer] VCS = "git" diff --git a/tests/schema/test_schema_attribute_validators.py b/tests/schema/test_schema_attribute_validators.py index 9d9a6bf18..95bdd5507 100644 --- a/tests/schema/test_schema_attribute_validators.py +++ b/tests/schema/test_schema_attribute_validators.py @@ -90,7 +90,7 @@ def test_deprecatedFrom(self): self.assertTrue(schema_attribute_validators.tag_is_deprecated_check(self.hed_schema, tag_entry, attribute_name)) del tag_entry.attributes["deprecatedFrom"] - unit_class_entry = self.hed_schema.unit_classes["temperatureUnits"] + unit_class_entry = copy.deepcopy(self.hed_schema.unit_classes["temperatureUnits"]) # This should raise an issue because it assumes the attribute is set self.assertTrue(schema_attribute_validators.tag_is_deprecated_check(self.hed_schema, unit_class_entry, attribute_name)) unit_class_entry.attributes["deprecatedFrom"] = "8.1.0" diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/scripts/test_convert_and_update_schema.py b/tests/scripts/test_convert_and_update_schema.py new file mode 100644 index 000000000..48e419a4a --- /dev/null +++ b/tests/scripts/test_convert_and_update_schema.py @@ -0,0 +1,90 @@ +import unittest +import os +import shutil +import copy +from hed import load_schema, load_schema_version +from hed.schema import HedSectionKey, HedKey +from hed.scripts.script_util import add_extension +from hed.scripts.convert_and_update_schema import convert_and_update + + +class TestConvertAndUpdate(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Create a temporary directory for schema files + cls.base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'schemas_update', 'prerelease') + if not os.path.exists(cls.base_path): + os.makedirs(cls.base_path) + + def test_schema_conversion_and_update(self): + # Load a known schema, modify it if necessary, and save it + schema = load_schema_version("8.3.0") + original_name = os.path.join(self.base_path, "test_schema.mediawiki") + schema.save_as_mediawiki(original_name) + + # Assume filenames updated includes just the original schema file for simplicity + filenames = [original_name] + result = convert_and_update(filenames, set_ids=False) + + # Verify no error from convert_and_update and the correct schema version was saved + self.assertEqual(result, 0) + + tsv_filename = add_extension(os.path.join(self.base_path, "test_schema"), ".tsv") + schema_reload1 = load_schema(tsv_filename) + schema_reload2 = load_schema(os.path.join(self.base_path, "test_schema.xml")) + + self.assertEqual(schema, schema_reload1) + self.assertEqual(schema, schema_reload2) + + # Now verify after doing this again with a new schema, they're still the same. + schema = load_schema_version("8.2.0") + schema.save_as_dataframes(tsv_filename) + + filenames = [os.path.join(tsv_filename, "test_schema_Tag.tsv")] + result = convert_and_update(filenames, set_ids=False) + + # Verify no error from convert_and_update and the correct schema version was saved + self.assertEqual(result, 0) + + schema_reload1 = load_schema(os.path.join(self.base_path, "test_schema.mediawiki")) + schema_reload2 = load_schema(os.path.join(self.base_path, "test_schema.xml")) + + self.assertEqual(schema, schema_reload1) + self.assertEqual(schema, schema_reload2) + + def test_schema_adding_tag(self): + schema = load_schema_version("8.3.0") + basename = os.path.join(self.base_path, "test_schema_edited") + schema.save_as_mediawiki(add_extension(basename, ".mediawiki")) + schema.save_as_xml(add_extension(basename, ".xml")) + schema.save_as_dataframes(add_extension(basename, ".tsv")) + + schema_edited = copy.deepcopy(schema) + test_tag_name = "NewTagWithoutID" + new_entry = schema_edited._create_tag_entry(test_tag_name, HedSectionKey.Tags) + schema_edited._add_tag_to_dict(test_tag_name, new_entry, HedSectionKey.Tags) + + schema_edited.save_as_mediawiki(add_extension(basename, ".mediawiki")) + + # Assume filenames updated includes just the original schema file for simplicity + filenames = [add_extension(basename, ".mediawiki")] + result = convert_and_update(filenames, set_ids=False) + self.assertEqual(result, 0) + + schema_reloaded = load_schema(add_extension(basename, ".xml")) + + self.assertEqual(schema_reloaded, schema_edited) + + result = convert_and_update(filenames, set_ids=True) + self.assertEqual(result, 0) + + schema_reloaded = load_schema(add_extension(basename, ".xml")) + + reloaded_entry = schema_reloaded.tags[test_tag_name] + self.assertTrue(reloaded_entry.has_attribute(HedKey.HedID)) + + + @classmethod + def tearDownClass(cls): + # Clean up the directory created for testing + shutil.rmtree(cls.base_path) diff --git a/tests/scripts/test_script_util.py b/tests/scripts/test_script_util.py new file mode 100644 index 000000000..5c8f1fb97 --- /dev/null +++ b/tests/scripts/test_script_util.py @@ -0,0 +1,121 @@ +import unittest +import os +import shutil +from hed import load_schema_version +from hed.scripts.script_util import add_extension, sort_base_schemas, validate_all_schema_formats + + +class TestAddExtension(unittest.TestCase): + + def test_regular_extension(self): + """Test that regular extensions are added correctly.""" + self.assertEqual(add_extension("filename", ".txt"), "filename.txt") + self.assertEqual(add_extension("document", ".pdf"), "document.pdf") + + def test_tsv_extension(self): + """Test that .tsv extensions are handled differently.""" + # Assuming the function correctly handles paths with directories + self.assertEqual(add_extension("path/to/filename", ".tsv"), "path/to/hedtsv/filename") + # Testing with a basename only + self.assertEqual(add_extension("filename", ".tsv"), "hedtsv/filename") + + def test_empty_extension(self): + """Test adding an empty extension.""" + self.assertEqual(add_extension("filename", ""), "filename") + + def test_none_extension(self): + """Test behavior with None as extension.""" + with self.assertRaises(TypeError): + add_extension("filename", None) + +class TestSortBaseSchemas(unittest.TestCase): + def test_mixed_file_types(self): + filenames = [ + "test_schema.mediawiki", + "hedtsv/test_schema/test_schema_Tag.tsv", + "other_schema.xml" + ] + expected = { + "test_schema": {".mediawiki", ".tsv"}, + "other_schema": {".xml"} + } + result = sort_base_schemas(filenames) + self.assertEqual(dict(result), expected) + + def test_tsv_in_correct_subfolder(self): + filenames = [ + "hedtsv/test_schema/test_schema_Tag.tsv", + "hedtsv/test_schema/test_schema_Tag.tsv", + "hedtsv/wrong_folder/wrong_name_Tag.tsv" # Should be ignored + ] + expected = { + "test_schema": {".tsv"} + } + result = sort_base_schemas(filenames) + self.assertEqual(dict(result), expected) + + def test_tsv_in_correct_subfolder2(self): + filenames = [ + "prerelease/hedtsv/test_schema/test_schema_Tag.tsv", + "prerelease/hedtsv/test_schema/test_schema_Tag.tsv", + "prerelease/hedtsv/wrong_folder/wrong_name_Tag.tsv" # Should be ignored + ] + expected = { + "prerelease/test_schema": {".tsv"} + } + result = sort_base_schemas(filenames) + self.assertEqual(dict(result), expected) + + def test_ignored_files(self): + filenames = [ + "test_schema.mediawiki", + "not_hedtsv/test_schema/test_schema_Tag.tsv" # Should be ignored + ] + expected = { + "test_schema": {".mediawiki"} + } + result = sort_base_schemas(filenames) + self.assertEqual(dict(result), expected) + + def test_empty_input(self): + filenames = [] + expected = {} + result = sort_base_schemas(filenames) + self.assertEqual(dict(result), expected) + + +class TestValidateAllSchemaFormats(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Determine the path to save schemas based on the location of this test file + cls.base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'schemas') + if not os.path.exists(cls.base_path): + os.makedirs(cls.base_path) + cls.basename = "test_schema" + + def test_error_no_error(self): + """Test the function with correctly saved schemas in all three formats.""" + # Load specific schema versions and save them correctly + schema = load_schema_version("8.3.0") + schema.save_as_xml(os.path.join(self.base_path, self.basename + ".xml")) + schema.save_as_dataframes(os.path.join(self.base_path, "hedtsv", self.basename)) + issues = validate_all_schema_formats(os.path.join(self.base_path, self.basename)) + self.assertTrue(issues) + self.assertIn("Error loading schema", issues[0]) + + schema.save_as_mediawiki(os.path.join(self.base_path, self.basename + ".mediawiki")) + + self.assertEqual(validate_all_schema_formats(os.path.join(self.base_path, self.basename)), []) + + schema_incorrect = load_schema_version("8.2.0") + schema_incorrect.save_as_dataframes(os.path.join(self.base_path, "hedtsv", self.basename)) + + # Validate and expect errors + issues = validate_all_schema_formats(os.path.join(self.base_path, self.basename)) + self.assertTrue(issues) + self.assertIn("Multiple schemas of type", issues[0]) + + @classmethod + def tearDownClass(cls): + """Remove the entire directory created for testing to ensure a clean state.""" + shutil.rmtree(cls.base_path) # This will delete the directory and all its contents From b9ffbb6754308e85606168874c193fbbe9fb1b7a Mon Sep 17 00:00:00 2001 From: IanCa Date: Mon, 20 May 2024 15:06:14 -0500 Subject: [PATCH 2/5] Add debug print --- hed/scripts/script_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hed/scripts/script_util.py b/hed/scripts/script_util.py index 07605ce31..ab5dad1fe 100644 --- a/hed/scripts/script_util.py +++ b/hed/scripts/script_util.py @@ -142,7 +142,8 @@ def validate_all_schemas(schema_files): if len(extensions) > 1 and not single_schema_issues and "prerelease" in basename: single_schema_issues += validate_all_schema_formats(basename) - print(f"Validating {basename}...") + print(f"Validating: {basename}...") + print(f"Extensions: {extensions}") if single_schema_issues: for issue in single_schema_issues: print(issue) From 56524c996b62a4baae9b27d43f1ed5f0de52d7f9 Mon Sep 17 00:00:00 2001 From: IanCa Date: Mon, 20 May 2024 15:41:11 -0500 Subject: [PATCH 3/5] Update print, better handle non lowercase --- hed/scripts/convert_and_update_schema.py | 9 ++++++++- hed/scripts/script_util.py | 11 ++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/hed/scripts/convert_and_update_schema.py b/hed/scripts/convert_and_update_schema.py index 38d723208..bc34d9e8d 100644 --- a/hed/scripts/convert_and_update_schema.py +++ b/hed/scripts/convert_and_update_schema.py @@ -32,7 +32,14 @@ def convert_and_update(filenames, set_ids): continue source_filename = add_extension(basename, list(extensions)[0]) # Load any changed schema version, they're all the same - source_df_filename = add_extension(basename, ".tsv") + + # todo: more properly decide how we want to handle non lowercase extensions. + tsv_extension = ".tsv" + for extension in extensions: + if extension.lower() == ".tsv": + tsv_extension = extension + + source_df_filename = add_extension(basename, tsv_extension) schema = load_schema(source_filename) print(f"Trying to convert/update file {source_filename}") source_dataframes = load_dataframes(source_df_filename) diff --git a/hed/scripts/script_util.py b/hed/scripts/script_util.py index ab5dad1fe..278415742 100644 --- a/hed/scripts/script_util.py +++ b/hed/scripts/script_util.py @@ -36,7 +36,7 @@ def validate_schema(file_path): f"There is either a problem with the source file, or the saving/loading code." validation_issues.append(error_text) except HedFileError as e: - print(f"Saving/loading error: {e.message}") + print(f"Saving/loading error: {file_path} {e.message}") error_text = e.message if e.issues: error_text = get_printable_issue_string(e.issues, title=file_path) @@ -47,7 +47,7 @@ def validate_schema(file_path): def add_extension(basename, extension): """Generate the final name for a given extension. Only .tsv varies notably.""" - if extension == ".tsv": + if extension.lower() == ".tsv": parent_path, basename = os.path.split(basename) return os.path.join(parent_path, "hedtsv", basename) return basename + extension @@ -73,11 +73,11 @@ def sort_base_schemas(filenames): """ schema_files = defaultdict(set) for file_path in filenames: - basename, extension = os.path.splitext(file_path.lower()) - if extension == ".xml" or extension == ".mediawiki": + basename, extension = os.path.splitext(file_path) + if extension.lower() == ".xml" or extension.lower() == ".mediawiki": schema_files[basename].add(extension) continue - elif extension == ".tsv": + elif extension.lower() == ".tsv": tsv_basename = basename.rpartition("_")[0] full_parent_path, real_basename = os.path.split(tsv_basename) full_parent_path, real_basename2 = os.path.split(full_parent_path) @@ -106,6 +106,7 @@ def validate_all_schema_formats(basename): issue_list(list): A non-empty list if there are any issues. """ # Note if more than one is changed, it intentionally checks all 3 even if one wasn't changed. + # todo: this needs to be updated to handle capital letters in the extension. paths = [add_extension(basename, extension) for extension in all_extensions] try: schemas = [load_schema(path) for path in paths] From c02af7f218dbc63822e03c61e063c577f5f78802 Mon Sep 17 00:00:00 2001 From: IanCa Date: Mon, 20 May 2024 16:11:36 -0500 Subject: [PATCH 4/5] Fix add_extension for None --- hed/scripts/script_util.py | 2 +- tests/scripts/test_script_util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hed/scripts/script_util.py b/hed/scripts/script_util.py index 278415742..f94469b7a 100644 --- a/hed/scripts/script_util.py +++ b/hed/scripts/script_util.py @@ -47,7 +47,7 @@ def validate_schema(file_path): def add_extension(basename, extension): """Generate the final name for a given extension. Only .tsv varies notably.""" - if extension.lower() == ".tsv": + if extension and extension.lower() == ".tsv": parent_path, basename = os.path.split(basename) return os.path.join(parent_path, "hedtsv", basename) return basename + extension diff --git a/tests/scripts/test_script_util.py b/tests/scripts/test_script_util.py index 5c8f1fb97..638ad5a84 100644 --- a/tests/scripts/test_script_util.py +++ b/tests/scripts/test_script_util.py @@ -25,7 +25,7 @@ def test_empty_extension(self): def test_none_extension(self): """Test behavior with None as extension.""" - with self.assertRaises(TypeError): + with self.assertRaises(AttributeError): add_extension("filename", None) class TestSortBaseSchemas(unittest.TestCase): From d601d8b92afcc27a7a317ee2709a42bab0421c2e Mon Sep 17 00:00:00 2001 From: IanCa Date: Mon, 20 May 2024 16:19:47 -0500 Subject: [PATCH 5/5] Restore add_extension --- hed/scripts/script_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hed/scripts/script_util.py b/hed/scripts/script_util.py index f94469b7a..278415742 100644 --- a/hed/scripts/script_util.py +++ b/hed/scripts/script_util.py @@ -47,7 +47,7 @@ def validate_schema(file_path): def add_extension(basename, extension): """Generate the final name for a given extension. Only .tsv varies notably.""" - if extension and extension.lower() == ".tsv": + if extension.lower() == ".tsv": parent_path, basename = os.path.split(basename) return os.path.join(parent_path, "hedtsv", basename) return basename + extension