Review notes (text before the first "diff --git" header is ignored by git apply):

* This copy of the patch was reconstructed from a newline-collapsed paste.
  Indentation of context lines is inferred from Python syntax - confirm it
  against the working tree (or apply with --ignore-whitespace).
* Fixes relative to the patch as submitted:
  1. state.py: ToolState.parameter_model_for() built `bundle` but passed the
     raw `input_models` on to _parameter_model_for(); it now passes `bundle`.
  2. state.py: ToolParameterBundleModel is now imported (it was used but not
     imported, so the list branch raised NameError). Assumes the existing
     import list is the one that exposes the model classes - confirm.
  3. webapp/api2/tools.py: ParsedTool is now imported for the new return
     annotation. Assumes it was not already imported elsewhere - confirm.
  4. test_model_cache.py: `cache_dict` renamed to `test_cache_dict` so pytest
     actually collects it.
  5. model_cache.py: pydantic v1 `.dict()` replaced with v2 `.model_dump()`,
     matching the v2 APIs already used in that module.
* The post-image blob ids on the "index" lines of the four touched files are
  those of the original submission and no longer match the new content.

diff --git a/lib/galaxy/tool_util/parameters/state.py b/lib/galaxy/tool_util/parameters/state.py
index 52a929a19383..5745c6267fa2 100644
--- a/lib/galaxy/tool_util/parameters/state.py
+++ b/lib/galaxy/tool_util/parameters/state.py
@@ -5,8 +5,10 @@
 from typing import (
     Any,
     Dict,
+    List,
     Optional,
     Type,
+    Union,
 )
 
 from pydantic import BaseModel
@@ -17,9 +19,13 @@
     create_request_model,
     StateRepresentationT,
     ToolParameterBundle,
+    ToolParameterBundleModel,
+    ToolParameterT,
     validate_against_model,
 )
 
+HasToolParameters = Union[List[ToolParameterT], ToolParameterBundle]
+
 
 class ToolState(ABC):
     input_state: Dict[str, Any]
@@ -30,7 +36,7 @@ def __init__(self, input_state: Dict[str, Any]):
     def _validate(self, pydantic_model: Type[BaseModel]) -> None:
         validate_against_model(pydantic_model, self.input_state)
 
-    def validate(self, input_models: ToolParameterBundle) -> None:
+    def validate(self, input_models: HasToolParameters) -> None:
         base_model = self.parameter_model_for(input_models)
         if base_model is None:
             raise NotImplementedError(
@@ -44,15 +50,19 @@ def state_representation(self) -> StateRepresentationT:
         """Get state representation of the inputs."""
 
     @classmethod
-    def parameter_model_for(cls, input_models: ToolParameterBundle) -> Optional[Type[BaseModel]]:
-        return None
+    def parameter_model_for(cls, input_models: HasToolParameters) -> Optional[Type[BaseModel]]:
+        if isinstance(input_models, list):
+            bundle: ToolParameterBundle = ToolParameterBundleModel(input_models=input_models)
+        else:
+            bundle = input_models
+        return cls._parameter_model_for(bundle)
 
 
 class RequestToolState(ToolState):
     state_representation: Literal["request"] = "request"
 
     @classmethod
-    def parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
+    def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
         return create_request_model(input_models)
 
 
@@ -60,7 +70,7 @@ class RequestInternalToolState(ToolState):
     state_representation: Literal["request_internal"] = "request_internal"
 
     @classmethod
-    def parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
+    def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
         return create_request_internal_model(input_models)
 
 
@@ -68,7 +78,7 @@ class JobInternalToolState(ToolState):
     state_representation: Literal["job_internal"] = "job_internal"
 
     @classmethod
-    def parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
+    def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
         # implement a job model...
         return create_request_internal_model(input_models)
 
@@ -77,6 +87,6 @@ class TestCaseToolState(ToolState):
     state_representation: Literal["test_case"] = "test_case"
 
     @classmethod
-    def parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
+    def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
         # implement a test case model...
         return create_request_internal_model(input_models)
diff --git a/lib/tool_shed/managers/model_cache.py b/lib/tool_shed/managers/model_cache.py
index ef7d11962ba6..1649ce1b1051 100644
--- a/lib/tool_shed/managers/model_cache.py
+++ b/lib/tool_shed/managers/model_cache.py
@@ -20,10 +20,18 @@ def hash_model(model_class: Type[BaseModel]) -> str:
     return md5_hash_str(json.dumps(model_class.model_json_schema()))
 
 
+MODEL_HASHES: Dict[Type[BaseModel], str] = {}
+
+
 M = TypeVar("M", bound=BaseModel)
 
 
-class ModelCache(Generic[M]):
+def ensure_model_has_hash(model_class: Type[BaseModel]) -> None:
+    if model_class not in MODEL_HASHES:
+        MODEL_HASHES[model_class] = hash_model(model_class)
+
+
+class ModelCache:
     _cache_directory: str
 
     def __init__(self, cache_directory: str):
@@ -31,26 +39,27 @@ def __init__(self, cache_directory: str):
             os.makedirs(cache_directory)
         self._cache_directory = cache_directory
 
-    def _cache_target(self, tool_id: str, tool_version: str):
+    def _cache_target(self, model_class: Type[M], tool_id: str, tool_version: str) -> str:
+        ensure_model_has_hash(model_class)
         # consider breaking this into multiple directories...
-        cache_target = os.path.join(self._cache_directory, tool_id, tool_version)
+        cache_target = os.path.join(self._cache_directory, MODEL_HASHES[model_class], tool_id, tool_version)
         return cache_target
 
-    def get_cache_entry_for(self, tool_id: str, tool_version: str) -> Optional[RAW_CACHED_JSON]:
-        cache_target = self._cache_target(tool_id, tool_version)
+    def get_cache_entry_for(self, model_class: Type[M], tool_id: str, tool_version: str) -> Optional[M]:
+        cache_target = self._cache_target(model_class, tool_id, tool_version)
         if not os.path.exists(cache_target):
             return None
         with open(cache_target) as f:
-            return json.load(f)
+            return model_class.model_validate(json.load(f))
 
-    def has_cached_entry_for(self, tool_id: str, tool_version: str) -> bool:
-        cache_target = self._cache_target(tool_id, tool_version)
+    def has_cached_entry_for(self, model_class: Type[M], tool_id: str, tool_version: str) -> bool:
+        cache_target = self._cache_target(model_class, tool_id, tool_version)
         return os.path.exists(cache_target)
 
-    def insert_cache_entry_for(self, tool_id: str, tool_version: str, entry: RAW_CACHED_JSON) -> None:
-        cache_target = self._cache_target(tool_id, tool_version)
+    def insert_cache_entry_for(self, model_object: M, tool_id: str, tool_version: str) -> None:
+        cache_target = self._cache_target(model_object.__class__, tool_id, tool_version)
         parent_directory = os.path.dirname(cache_target)
         if not os.path.exists(parent_directory):
             os.makedirs(parent_directory)
         with open(cache_target, "w") as f:
-            json.dump(entry, f)
+            json.dump(model_object.model_dump(), f)
diff --git a/lib/tool_shed/managers/tools.py b/lib/tool_shed/managers/tools.py
index a8b873afd6ca..6a5926068dea 100644
--- a/lib/tool_shed/managers/tools.py
+++ b/lib/tool_shed/managers/tools.py
@@ -23,8 +23,7 @@
 )
 from galaxy.tool_util.parameters import (
     input_models_for_tool_source,
-    tool_parameter_bundle_from_json,
-    ToolParameterBundleModel,
+    ToolParameterT,
 )
 from galaxy.tool_util.parser import (
     get_tool_source,
@@ -54,7 +53,7 @@ class ParsedTool(BaseModel):
     version: Optional[str]
     name: str
     description: Optional[str]
-    inputs: ToolParameterBundleModel
+    inputs: List[ToolParameterT]
     citations: List[Citation]
     license: Optional[str]
     profile: Optional[str]
@@ -68,7 +67,7 @@ def _parse_tool(tool_source: ToolSource) -> ParsedTool:
     version = tool_source.parse_version()
     name = tool_source.parse_name()
     description = tool_source.parse_description()
-    inputs = input_models_for_tool_source(tool_source)
+    inputs = input_models_for_tool_source(tool_source).input_models
     citations = tool_source.parse_citations()
     license = tool_source.parse_license()
     profile = tool_source.parse_profile()
@@ -147,23 +146,23 @@ def get_repository_metadata_tool_dict(
     raise ObjectNotFound()
 
 
-def tool_input_models_cached_for(
+def parsed_tool_model_cached_for(
     trans: ProvidesRepositoriesContext, trs_tool_id: str, tool_version: str, repository_clone_url: Optional[str] = None
-) -> ToolParameterBundleModel:
+) -> ParsedTool:
     model_cache = trans.app.model_cache
-    raw_json = model_cache.get_cache_entry_for(trs_tool_id, tool_version)
-    if raw_json is not None:
-        return tool_parameter_bundle_from_json(raw_json)
-    bundle = tool_input_models_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url)
-    model_cache.insert_cache_entry_for(trs_tool_id, tool_version, bundle.dict())
-    return bundle
+    parsed_tool = model_cache.get_cache_entry_for(ParsedTool, trs_tool_id, tool_version)
+    if parsed_tool is not None:
+        return parsed_tool
+    parsed_tool = parsed_tool_model_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url)
+    model_cache.insert_cache_entry_for(parsed_tool, trs_tool_id, tool_version)
+    return parsed_tool
 
 
-def tool_input_models_for(
+def parsed_tool_model_for(
     trans: ProvidesRepositoriesContext, trs_tool_id: str, tool_version: str, repository_clone_url: Optional[str] = None
-) -> ToolParameterBundleModel:
+) -> ParsedTool:
     tool_source = tool_source_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url)
-    return input_models_for_tool_source(tool_source)
+    return _parse_tool(tool_source)
 
 
 def tool_source_for(
diff --git a/lib/tool_shed/webapp/api2/tools.py b/lib/tool_shed/webapp/api2/tools.py
index 486a88730909..3c127d384b4d 100644
--- a/lib/tool_shed/webapp/api2/tools.py
+++ b/lib/tool_shed/webapp/api2/tools.py
@@ -14,8 +14,9 @@
 )
 from tool_shed.context import SessionRequestContext
 from tool_shed.managers.tools import (
+    parsed_tool_model_cached_for,
+    ParsedTool,
     search,
-    tool_input_models_cached_for,
 )
 from tool_shed.managers.trs import (
     get_tool,
@@ -144,17 +145,17 @@ def trs_get_versions(
         return get_tool(trans, tool_id).versions
 
     @router.get(
-        "/api/tools/{tool_id}/versions/{tool_version}/parameter_model",
+        "/api/tools/{tool_id}/versions/{tool_version}",
         operation_id="tools__parameter_model",
         summary="Return Galaxy's meta model description of the tool's inputs",
     )
-    def tool_parameters_meta_model(
+    def show_tool(
         self,
         trans: SessionRequestContext = DependsOnTrans,
         tool_id: str = TOOL_ID_PATH_PARAM,
         tool_version: str = TOOL_VERSION_PATH_PARAM,
-    ) -> ToolParameterBundleModel:
-        return tool_input_models_cached_for(trans, tool_id, tool_version)
+    ) -> ParsedTool:
+        return parsed_tool_model_cached_for(trans, tool_id, tool_version)
 
     @router.get(
         "/api/tools/{tool_id}/versions/{tool_version}/parameter_request_schema",
@@ -168,6 +169,5 @@ def tool_state(
         tool_id: str = TOOL_ID_PATH_PARAM,
         tool_version: str = TOOL_VERSION_PATH_PARAM,
     ) -> Response:
-        return json_schema_response(
-            RequestToolState.parameter_model_for(tool_input_models_cached_for(trans, tool_id, tool_version))
-        )
+        parsed_tool = parsed_tool_model_cached_for(trans, tool_id, tool_version)
+        return json_schema_response(RequestToolState.parameter_model_for(parsed_tool.inputs))
diff --git a/test/unit/tool_shed/test_model_cache.py b/test/unit/tool_shed/test_model_cache.py
index 3be2f9cef9f1..30dac2a3a620 100644
--- a/test/unit/tool_shed/test_model_cache.py
+++ b/test/unit/tool_shed/test_model_cache.py
@@ -1,6 +1,9 @@
 from pydantic import BaseModel, ConfigDict
 
-from tool_shed.managers.model_cache import hash_model
+from tool_shed.managers.model_cache import (
+    hash_model,
+    ModelCache,
+)
 
 
 class Moo(BaseModel):
@@ -34,3 +37,14 @@ def test_hash_different_on_updates():
     hash_moo_1 = hash_model(Moo)
     hash_moo_new = hash_model(NewMoo)
     assert hash_moo_1 != hash_moo_new
+
+
+def test_cache_dict(tmp_path):
+    model_cache = ModelCache(str(tmp_path))
+    assert not model_cache.has_cached_entry_for(Moo, "moo", "1.0")
+    assert None is model_cache.get_cache_entry_for(Moo, "moo", "1.0")
+    model_cache.insert_cache_entry_for(Moo(foo=4), "moo", "1.0")
+    moo = model_cache.get_cache_entry_for(Moo, "moo", "1.0")
+    assert moo
+    assert moo.foo == 4
+    assert model_cache.has_cached_entry_for(Moo, "moo", "1.0")
diff --git a/test/unit/tool_shed/test_tool_source.py b/test/unit/tool_shed/test_tool_source.py
index 4925cd52cd3f..601d4d63df54 100644
--- a/test/unit/tool_shed/test_tool_source.py
+++ b/test/unit/tool_shed/test_tool_source.py
@@ -1,7 +1,7 @@
 from tool_shed.context import ProvidesRepositoriesContext
 from tool_shed.managers.tools import (
-    tool_input_models_cached_for,
-    tool_input_models_for,
+    parsed_tool_model_cached_for,
+    parsed_tool_model_for,
     tool_source_for,
 )
 from tool_shed.webapp.model import Repository
@@ -17,22 +17,22 @@ def test_get_tool(provides_repositories: ProvidesRepositoriesContext, new_reposi
     repo_path = new_repository.repo_path(app=provides_repositories.app)
     tool_source = tool_source_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path)
     assert tool_source.parse_id() == "Add_a_column1"
-    bundle = tool_input_models_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path)
-    assert len(bundle.input_models) == 3
+    bundle = parsed_tool_model_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path)
+    assert len(bundle.inputs) == 3
 
-    cached_bundle = tool_input_models_cached_for(
+    cached_bundle = parsed_tool_model_cached_for(
         provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path
     )
-    assert len(cached_bundle.input_models) == 3
+    assert len(cached_bundle.inputs) == 3
 
-    cached_bundle = tool_input_models_cached_for(
+    cached_bundle = parsed_tool_model_cached_for(
         provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path
     )
-    assert len(cached_bundle.input_models) == 3
+    assert len(cached_bundle.inputs) == 3
 
 
 def test_stock_bundle(provides_repositories: ProvidesRepositoriesContext):
-    cached_bundle = tool_input_models_cached_for(
+    cached_bundle = parsed_tool_model_cached_for(
         provides_repositories, "__ZIP_COLLECTION__", "1.0.0", repository_clone_url=None
     )
-    assert len(cached_bundle.input_models) == 2
+    assert len(cached_bundle.inputs) == 2