From 97f29bcb9207a59e81372e88eec80906b9d740ce Mon Sep 17 00:00:00 2001 From: IanCa Date: Fri, 26 Jul 2024 12:34:45 -0500 Subject: [PATCH] Fix it so it caches the local hed files when asking for all versions --- hed/schema/hed_cache.py | 24 +++++++++++++++++------- hed/schema/schema_io/base2schema.py | 18 ++++++++---------- spec_tests/test_hed_cache.py | 7 ++++++- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/hed/schema/hed_cache.py b/hed/schema/hed_cache.py index 6d829a2f..cf4cbb03 100644 --- a/hed/schema/hed_cache.py +++ b/hed/schema/hed_cache.py @@ -86,13 +86,23 @@ def get_hed_versions(local_hed_directory=None, library_name=None, check_prerelea library_name = None all_hed_versions = {} - local_directory = local_hed_directory - if check_prerelease and not local_directory.endswith(prerelease_suffix): - local_directory += prerelease_suffix - try: - hed_files = os.listdir(local_directory) - except FileNotFoundError: - hed_files = [] + local_directories = [local_hed_directory] + if check_prerelease and not local_hed_directory.endswith(prerelease_suffix): + local_directories.append(os.path.join(local_hed_directory, "prerelease")) + + hed_files = [] + for hed_dir in local_directories: + try: + hed_files += os.listdir(hed_dir) + except FileNotFoundError: + pass + if not hed_files: + cache_local_versions(local_hed_directory) + for hed_dir in local_directories: + try: + hed_files += os.listdir(hed_dir) + except FileNotFoundError: + pass for hed_file in hed_files: expression_match = version_pattern.match(hed_file) if expression_match is not None: diff --git a/hed/schema/schema_io/base2schema.py b/hed/schema/schema_io/base2schema.py index b1e50ed6..2aebb055 100644 --- a/hed/schema/schema_io/base2schema.py +++ b/hed/schema/schema_io/base2schema.py @@ -107,8 +107,7 @@ def _load(self): """ self._loading_merged = True # Do a full load of the standard schema if this is a partnered schema - # todo: this could simply cache the schema, rather than a full load, if it's not a merged schema. - if not self.appending_to_schema and self._schema.with_standard: + if not self.appending_to_schema and self._schema.with_standard and not self._schema.merged: from hed.schema.hed_schema_io import load_schema_version saved_attr = self._schema.header_attributes saved_format = self._schema.source_format @@ -118,14 +117,13 @@ def _load(self): raise HedFileError(HedExceptions.BAD_WITH_STANDARD, message=f"Cannot load withStandard schema '{self._schema.with_standard}'", filename=e.filename) - if not self._schema.merged: - # Copy the non-alterable cached schema - self._schema = copy.deepcopy(base_version) - self._schema.filename = self.filename - self._schema.name = self.name # Manually set name here as we don't want to pass it to load_schema_version - self._schema.header_attributes = saved_attr - self._schema.source_format = saved_format - self._loading_merged = False + # Copy the non-alterable cached schema + self._schema = copy.deepcopy(base_version) + self._schema.filename = self.filename + self._schema.name = self.name # Manually set name here as we don't want to pass it to load_schema_version + self._schema.header_attributes = saved_attr + self._schema.source_format = saved_format + self._loading_merged = False self._parse_data() self._schema.finalize_dictionaries() diff --git a/spec_tests/test_hed_cache.py b/spec_tests/test_hed_cache.py index 79ffb83d..cf86833e 100644 --- a/spec_tests/test_hed_cache.py +++ b/spec_tests/test_hed_cache.py @@ -44,7 +44,6 @@ def test_cache_again(self): # this should fail to cache, since it was cached too recently. self.assertEqual(time_since_update, -1) - def test_get_cache_directory(self): from hed.schema import get_cache_directory cache_dir = get_cache_directory() @@ -76,6 +75,12 @@ def test_get_hed_versions_library(self): self.assertIsInstance(cached_versions, list) self.assertTrue(len(cached_versions) > 0) + def test_get_hed_versions_library_prerelease(self): + # Todo: improve this code to actually test it. + cached_versions = hed_cache.get_hed_versions(self.hed_cache_dir, library_name="score", check_prerelease=True) + self.assertIsInstance(cached_versions, list) + self.assertTrue(len(cached_versions) > 0) + def test_sort_version_list(self): valid_versions = ["8.1.0", "8.0.0", "8.0.0-alpha.1", "7.1.1", "1.0.0"] for shuffled_versions in itertools.permutations(valid_versions):