Skip to content

Commit

Permalink
Fix it so it caches the local hed files when asking for all versions
Browse files Browse the repository at this point in the history
  • Loading branch information
IanCa committed Jul 26, 2024
1 parent 3b6d897 commit 97f29bc
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 18 deletions.
24 changes: 17 additions & 7 deletions hed/schema/hed_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,23 @@ def get_hed_versions(local_hed_directory=None, library_name=None, check_prerelea
library_name = None

all_hed_versions = {}
local_directory = local_hed_directory
if check_prerelease and not local_directory.endswith(prerelease_suffix):
local_directory += prerelease_suffix
try:
hed_files = os.listdir(local_directory)
except FileNotFoundError:
hed_files = []
local_directories = [local_hed_directory]
if check_prerelease and not local_hed_directory.endswith(prerelease_suffix):
local_directories.append(os.path.join(local_hed_directory, "prerelease"))

hed_files = []
for hed_dir in local_directories:
try:
hed_files += os.listdir(hed_dir)
except FileNotFoundError:
pass
if not hed_files:
cache_local_versions(local_hed_directory)
for hed_dir in local_directories:
try:
hed_files += os.listdir(hed_dir)
except FileNotFoundError:
pass
for hed_file in hed_files:
expression_match = version_pattern.match(hed_file)
if expression_match is not None:
Expand Down
18 changes: 8 additions & 10 deletions hed/schema/schema_io/base2schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def _load(self):
"""
self._loading_merged = True
# Do a full load of the standard schema if this is a partnered schema
# todo: this could simply cache the schema, rather than a full load, if it's not a merged schema.
if not self.appending_to_schema and self._schema.with_standard:
if not self.appending_to_schema and self._schema.with_standard and not self._schema.merged:
from hed.schema.hed_schema_io import load_schema_version
saved_attr = self._schema.header_attributes
saved_format = self._schema.source_format
Expand All @@ -118,14 +117,13 @@ def _load(self):
raise HedFileError(HedExceptions.BAD_WITH_STANDARD,
message=f"Cannot load withStandard schema '{self._schema.with_standard}'",
filename=e.filename)
if not self._schema.merged:
# Copy the non-alterable cached schema
self._schema = copy.deepcopy(base_version)
self._schema.filename = self.filename
self._schema.name = self.name # Manually set name here as we don't want to pass it to load_schema_version
self._schema.header_attributes = saved_attr
self._schema.source_format = saved_format
self._loading_merged = False
# Copy the non-alterable cached schema
self._schema = copy.deepcopy(base_version)
self._schema.filename = self.filename
self._schema.name = self.name # Manually set name here as we don't want to pass it to load_schema_version
self._schema.header_attributes = saved_attr
self._schema.source_format = saved_format
self._loading_merged = False

self._parse_data()
self._schema.finalize_dictionaries()
Expand Down
7 changes: 6 additions & 1 deletion spec_tests/test_hed_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def test_cache_again(self):
# this should fail to cache, since it was cached too recently.
self.assertEqual(time_since_update, -1)


def test_get_cache_directory(self):
from hed.schema import get_cache_directory
cache_dir = get_cache_directory()
Expand Down Expand Up @@ -76,6 +75,12 @@ def test_get_hed_versions_library(self):
self.assertIsInstance(cached_versions, list)
self.assertTrue(len(cached_versions) > 0)

def test_get_hed_versions_library_prerelease(self):
# Todo: improve this code to actually test it.
cached_versions = hed_cache.get_hed_versions(self.hed_cache_dir, library_name="score", check_prerelease=True)
self.assertIsInstance(cached_versions, list)
self.assertTrue(len(cached_versions) > 0)

def test_sort_version_list(self):
valid_versions = ["8.1.0", "8.0.0", "8.0.0-alpha.1", "7.1.1", "1.0.0"]
for shuffled_versions in itertools.permutations(valid_versions):
Expand Down

0 comments on commit 97f29bc

Please sign in to comment.