diff --git a/lib/galaxy/config/schemas/tool_shed_config_schema.yml b/lib/galaxy/config/schemas/tool_shed_config_schema.yml index 2a1eee2b533f..b7b6bfee049e 100644 --- a/lib/galaxy/config/schemas/tool_shed_config_schema.yml +++ b/lib/galaxy/config/schemas/tool_shed_config_schema.yml @@ -102,12 +102,12 @@ mapping: the repositories and tools within the Tool Shed given that you specify the following two config options. - tool_state_cache_dir: + model_cache_dir: type: str - default: database/tool_state_cache + default: database/model_cache required: false desc: | - Cache directory for tool state. + Cache directory for Pydantic model objects. repo_name_boost: type: float diff --git a/lib/tool_shed/managers/model_cache.py b/lib/tool_shed/managers/model_cache.py new file mode 100644 index 000000000000..9ce8c205e901 --- /dev/null +++ b/lib/tool_shed/managers/model_cache.py @@ -0,0 +1,64 @@ +import json +import os +from typing import ( + Any, + Dict, + Optional, + Type, + TypeVar, +) + +from pydantic import BaseModel + +from galaxy.util.hash_util import md5_hash_str + +RAW_CACHED_JSON = Dict[str, Any] + + +def hash_model(model_class: Type[BaseModel]) -> str: + return md5_hash_str(json.dumps(model_class.model_json_schema())) + + +MODEL_HASHES: Dict[Type[BaseModel], str] = {} + + +M = TypeVar("M", bound=BaseModel) + + +def ensure_model_has_hash(model_class: Type[BaseModel]) -> None: + if model_class not in MODEL_HASHES: + MODEL_HASHES[model_class] = hash_model(model_class) + + +class ModelCache: + _cache_directory: str + + def __init__(self, cache_directory: str): + if not os.path.exists(cache_directory): + os.makedirs(cache_directory) + self._cache_directory = cache_directory + + def _cache_target(self, model_class: Type[M], tool_id: str, tool_version: str) -> str: + ensure_model_has_hash(model_class) + # consider breaking this into multiple directories... + cache_target = os.path.join(self._cache_directory, MODEL_HASHES[model_class], tool_id, tool_version) + return cache_target + + def get_cache_entry_for(self, model_class: Type[M], tool_id: str, tool_version: str) -> Optional[M]: + cache_target = self._cache_target(model_class, tool_id, tool_version) + if not os.path.exists(cache_target): + return None + with open(cache_target) as f: + return model_class.model_validate(json.load(f)) + + def has_cached_entry_for(self, model_class: Type[M], tool_id: str, tool_version: str) -> bool: + cache_target = self._cache_target(model_class, tool_id, tool_version) + return os.path.exists(cache_target) + + def insert_cache_entry_for(self, model_object: M, tool_id: str, tool_version: str) -> None: + cache_target = self._cache_target(model_object.__class__, tool_id, tool_version) + parent_directory = os.path.dirname(cache_target) + if not os.path.exists(parent_directory): + os.makedirs(parent_directory) + with open(cache_target, "w") as f: + json.dump(model_object.dict(), f) diff --git a/lib/tool_shed/managers/tool_state_cache.py b/lib/tool_shed/managers/tool_state_cache.py deleted file mode 100644 index 010ab288a334..000000000000 --- a/lib/tool_shed/managers/tool_state_cache.py +++ /dev/null @@ -1,42 +0,0 @@ -import json -import os -from typing import ( - Any, - Dict, - Optional, -) - -RAW_CACHED_JSON = Dict[str, Any] - - -class ToolStateCache: - _cache_directory: str - - def __init__(self, cache_directory: str): - if not os.path.exists(cache_directory): - os.makedirs(cache_directory) - self._cache_directory = cache_directory - - def _cache_target(self, tool_id: str, tool_version: str): - # consider breaking this into multiple directories... - cache_target = os.path.join(self._cache_directory, tool_id, tool_version) - return cache_target - - def get_cache_entry_for(self, tool_id: str, tool_version: str) -> Optional[RAW_CACHED_JSON]: - cache_target = self._cache_target(tool_id, tool_version) - if not os.path.exists(cache_target): - return None - with open(cache_target) as f: - return json.load(f) - - def has_cached_entry_for(self, tool_id: str, tool_version: str) -> bool: - cache_target = self._cache_target(tool_id, tool_version) - return os.path.exists(cache_target) - - def insert_cache_entry_for(self, tool_id: str, tool_version: str, entry: RAW_CACHED_JSON) -> None: - cache_target = self._cache_target(tool_id, tool_version) - parent_directory = os.path.dirname(cache_target) - if not os.path.exists(parent_directory): - os.makedirs(parent_directory) - with open(cache_target, "w") as f: - json.dump(entry, f) diff --git a/lib/tool_shed/managers/tools.py b/lib/tool_shed/managers/tools.py index a881c13f046d..c2c97e5c7fbd 100644 --- a/lib/tool_shed/managers/tools.py +++ b/lib/tool_shed/managers/tools.py @@ -8,6 +8,8 @@ Tuple, ) +from pydantic import BaseModel + from galaxy import exceptions from galaxy.exceptions import ( InternalServerError, @@ -21,13 +23,16 @@ ) from galaxy.tool_util.parameters import ( input_models_for_tool_source, - tool_parameter_bundle_from_json, - ToolParameterBundleModel, + ToolParameterT, ) from galaxy.tool_util.parser import ( get_tool_source, ToolSource, ) +from galaxy.tool_util.parser.interface import ( + Citation, + XrefDict, +) from galaxy.tools.stock import stock_tool_sources from tool_shed.context import ( ProvidesRepositoriesContext, @@ -41,6 +46,53 @@ STOCK_TOOL_SOURCES: Optional[Dict[str, Dict[str, ToolSource]]] = None +# parse the tool source with galaxy.util abstractions to provide a bit richer +# information about the tool than older tool shed abstractions. +class ParsedTool(BaseModel): + id: str + version: Optional[str] + name: str + description: Optional[str] + inputs: List[ToolParameterT] + citations: List[Citation] + license: Optional[str] + profile: Optional[str] + edam_operations: List[str] + edam_topics: List[str] + xrefs: List[XrefDict] + help: Optional[str] + + +def _parse_tool(tool_source: ToolSource) -> ParsedTool: + id = tool_source.parse_id() + version = tool_source.parse_version() + name = tool_source.parse_name() + description = tool_source.parse_description() + inputs = input_models_for_tool_source(tool_source).input_models + citations = tool_source.parse_citations() + license = tool_source.parse_license() + profile = tool_source.parse_profile() + edam_operations = tool_source.parse_edam_operations() + edam_topics = tool_source.parse_edam_topics() + xrefs = tool_source.parse_xrefs() + help = tool_source.parse_help() + + return ParsedTool( + id=id, + version=version, + name=name, + description=description, + profile=profile, + inputs=inputs, + license=license, + citations=citations, + edam_operations=edam_operations, + edam_topics=edam_topics, + xrefs=xrefs, + help=help, + ) + + def search(trans: SessionRequestContext, q: str, page: int = 1, page_size: int = 10) -> dict: """ Perform the search over TS tools index. @@ -97,23 +149,23 @@ def get_repository_metadata_tool_dict( raise ObjectNotFound() -def tool_input_models_cached_for( +def parsed_tool_model_cached_for( trans: ProvidesRepositoriesContext, trs_tool_id: str, tool_version: str, repository_clone_url: Optional[str] = None -) -> ToolParameterBundleModel: - tool_state_cache = trans.app.tool_state_cache - raw_json = tool_state_cache.get_cache_entry_for(trs_tool_id, tool_version) - if raw_json is not None: - return tool_parameter_bundle_from_json(raw_json) - bundle = tool_input_models_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url) - tool_state_cache.insert_cache_entry_for(trs_tool_id, tool_version, bundle.dict()) - return bundle +) -> ParsedTool: + model_cache = trans.app.model_cache + parsed_tool = model_cache.get_cache_entry_for(ParsedTool, trs_tool_id, tool_version) + if parsed_tool is not None: + return parsed_tool + parsed_tool = parsed_tool_model_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url) + model_cache.insert_cache_entry_for(parsed_tool, trs_tool_id, tool_version) + return parsed_tool -def tool_input_models_for( +def parsed_tool_model_for( trans: ProvidesRepositoriesContext, trs_tool_id: str, tool_version: str, repository_clone_url: Optional[str] = None -) -> ToolParameterBundleModel: +) -> ParsedTool: tool_source = tool_source_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url) - return input_models_for_tool_source(tool_source) + return _parse_tool(tool_source) def tool_source_for( diff --git a/lib/tool_shed/structured_app.py b/lib/tool_shed/structured_app.py index c3eee0c94299..8fd828f9f5e1 100644 --- a/lib/tool_shed/structured_app.py +++ b/lib/tool_shed/structured_app.py @@ -3,7 +3,7 @@ from galaxy.structured_app import BasicSharedApp if TYPE_CHECKING: - from tool_shed.managers.tool_state_cache import ToolStateCache + from tool_shed.managers.model_cache import ModelCache from tool_shed.repository_registry import Registry as RepositoryRegistry from tool_shed.repository_types.registry import Registry as RepositoryTypesRegistry from tool_shed.util.hgweb_config import HgWebConfigManager @@ -17,4 +17,4 @@ class ToolShedApp(BasicSharedApp): repository_registry: "RepositoryRegistry" hgweb_config_manager: "HgWebConfigManager" security_agent: "CommunityRBACAgent" - tool_state_cache: "ToolStateCache" + model_cache: "ModelCache" diff --git a/lib/tool_shed/webapp/api2/tools.py b/lib/tool_shed/webapp/api2/tools.py index 486a88730909..7f549e80d1c0 100644 --- a/lib/tool_shed/webapp/api2/tools.py +++ b/lib/tool_shed/webapp/api2/tools.py @@ -10,12 +10,12 @@ from galaxy.tool_util.parameters import ( RequestToolState, to_json_schema_string, - ToolParameterBundleModel, ) from tool_shed.context import SessionRequestContext from tool_shed.managers.tools import ( + parsed_tool_model_cached_for, + ParsedTool, search, - tool_input_models_cached_for, ) from tool_shed.managers.trs import ( get_tool, @@ -144,17 +144,17 @@ def trs_get_versions( return get_tool(trans, tool_id).versions @router.get( - "/api/tools/{tool_id}/versions/{tool_version}/parameter_model", + "/api/tools/{tool_id}/versions/{tool_version}", operation_id="tools__parameter_model", summary="Return Galaxy's meta model description of the tool's inputs", ) - def tool_parameters_meta_model( + def show_tool( self, trans: SessionRequestContext = DependsOnTrans, tool_id: str = TOOL_ID_PATH_PARAM, tool_version: str = TOOL_VERSION_PATH_PARAM, - ) -> ToolParameterBundleModel: - return tool_input_models_cached_for(trans, tool_id, tool_version) + ) -> ParsedTool: + return parsed_tool_model_cached_for(trans, tool_id, tool_version) @router.get( "/api/tools/{tool_id}/versions/{tool_version}/parameter_request_schema", @@ -168,6 +168,5 @@ def tool_state( tool_id: str = TOOL_ID_PATH_PARAM, tool_version: str = TOOL_VERSION_PATH_PARAM, ) -> Response: - return json_schema_response( - RequestToolState.parameter_model_for(tool_input_models_cached_for(trans, tool_id, tool_version)) - ) + parsed_tool = parsed_tool_model_cached_for(trans, tool_id, tool_version) + return json_schema_response(RequestToolState.parameter_model_for(parsed_tool.inputs)) diff --git a/lib/tool_shed/webapp/app.py b/lib/tool_shed/webapp/app.py index 4083674241a4..e046301497ad 100644 --- a/lib/tool_shed/webapp/app.py +++ b/lib/tool_shed/webapp/app.py @@ -33,7 +33,7 @@ from galaxy.structured_app import BasicSharedApp from galaxy.web_stack import application_stack_instance from tool_shed.grids.repository_grid_filter_manager import RepositoryGridFilterManager -from tool_shed.managers.tool_state_cache import ToolStateCache +from tool_shed.managers.model_cache import ModelCache from tool_shed.structured_app import ToolShedApp from tool_shed.util.hgweb_config import hgweb_config_manager from tool_shed.webapp.model.migrations import verify_database @@ -84,7 +84,7 @@ def __init__(self, **kwd) -> None: self._register_singleton(SharedModelMapping, model) self._register_singleton(mapping.ToolShedModelMapping, model) self._register_singleton(scoped_session, self.model.context) - self.tool_state_cache = ToolStateCache(self.config.tool_state_cache_dir) + self.model_cache = ModelCache(self.config.model_cache_dir) self.user_manager = self._register_singleton(UserManager, UserManager(self, app_type="tool_shed")) self.api_keys_manager = self._register_singleton(ApiKeyManager) # initialize the Tool Shed tag handler. diff --git a/test/unit/tool_shed/_util.py b/test/unit/tool_shed/_util.py index 3002c6a82fab..e20bb17c3c13 100644 --- a/test/unit/tool_shed/_util.py +++ b/test/unit/tool_shed/_util.py @@ -17,8 +17,8 @@ from galaxy.security.idencoding import IdEncodingHelper from galaxy.util import safe_makedirs from tool_shed.context import ProvidesRepositoriesContext +from tool_shed.managers.model_cache import ModelCache from tool_shed.managers.repositories import upload_tar_and_set_metadata -from tool_shed.managers.tool_state_cache import ToolStateCache from tool_shed.managers.users import create_user from tool_shed.repository_types import util as rt_util from tool_shed.repository_types.registry import Registry as RepositoryTypesRegistry @@ -81,7 +81,7 @@ def __init__(self, temp_directory=None): self.config = TestToolShedConfig(temp_directory) self.security = IdEncodingHelper(id_secret=self.config.id_secret) self.repository_registry = tool_shed.repository_registry.Registry(self) - self.tool_state_cache = ToolStateCache(os.path.join(temp_directory, "tool_state_cache")) + self.model_cache = ModelCache(os.path.join(temp_directory, "model_cache")) @property def security_agent(self): diff --git a/test/unit/tool_shed/test_model_cache.py b/test/unit/tool_shed/test_model_cache.py new file mode 100644 index 000000000000..89ebb8d54042 --- /dev/null +++ b/test/unit/tool_shed/test_model_cache.py @@ -0,0 +1,53 @@ +from pydantic import ( + BaseModel, + ConfigDict, +) + +from tool_shed.managers.model_cache import ( + hash_model, + ModelCache, +) + + +class Moo(BaseModel): + foo: int + + +class MooLike(BaseModel): + model_config = ConfigDict(title="Moo") + foo: int + + +class NewMoo(BaseModel): + model_config = ConfigDict(title="Moo") + foo: int + new_prop: str + + +def test_hash(): + hash_moo_1 = hash_model(Moo) + hash_moo_2 = hash_model(Moo) + assert hash_moo_1 == hash_moo_2 + + +def test_hash_by_value(): + hash_moo_1 = hash_model(Moo) + hash_moo_like = hash_model(MooLike) + assert hash_moo_1 == hash_moo_like + + +def test_hash_different_on_updates(): + hash_moo_1 = hash_model(Moo) + hash_moo_new = hash_model(NewMoo) + assert hash_moo_1 != hash_moo_new + + +def cache_dict(tmp_path): + model_cache = ModelCache(tmp_path) + assert not model_cache.has_cached_entry_for(Moo, "moo", "1.0") + assert None is model_cache.get_cache_entry_for(Moo, "moo", "1.0") + model_cache.insert_cache_entry_for(Moo(foo=4), "moo", "1.0") + moo = model_cache.get_cache_entry_for(Moo, "moo", "1.0") + assert moo + assert moo.foo == 4 + assert model_cache.has_cached_entry_for(Moo, "moo", "1.0") diff --git a/test/unit/tool_shed/test_tool_source.py b/test/unit/tool_shed/test_tool_source.py index 4925cd52cd3f..601d4d63df54 100644 --- a/test/unit/tool_shed/test_tool_source.py +++ b/test/unit/tool_shed/test_tool_source.py @@ -1,7 +1,7 @@ from tool_shed.context import ProvidesRepositoriesContext from tool_shed.managers.tools import ( - tool_input_models_cached_for, - tool_input_models_for, + parsed_tool_model_cached_for, + parsed_tool_model_for, tool_source_for, ) from tool_shed.webapp.model import Repository @@ -17,22 +17,22 @@ def test_get_tool(provides_repositories: ProvidesRepositoriesContext, new_reposi repo_path = new_repository.repo_path(app=provides_repositories.app) tool_source = tool_source_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path) assert tool_source.parse_id() == "Add_a_column1" - bundle = tool_input_models_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path) - assert len(bundle.input_models) == 3 + bundle = parsed_tool_model_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path) + assert len(bundle.inputs) == 3 - cached_bundle = tool_input_models_cached_for( + cached_bundle = parsed_tool_model_cached_for( provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path ) - assert len(cached_bundle.input_models) == 3 + assert len(cached_bundle.inputs) == 3 - cached_bundle = tool_input_models_cached_for( + cached_bundle = parsed_tool_model_cached_for( provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path ) - assert len(cached_bundle.input_models) == 3 + assert len(cached_bundle.inputs) == 3 def test_stock_bundle(provides_repositories: ProvidesRepositoriesContext): - cached_bundle = tool_input_models_cached_for( + cached_bundle = parsed_tool_model_cached_for( provides_repositories, "__ZIP_COLLECTION__", "1.0.0", repository_clone_url=None ) - assert len(cached_bundle.input_models) == 2 + assert len(cached_bundle.inputs) == 2