Skip to content

Commit

Permalink
Improved tool shed tool API.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmchilton committed Jul 7, 2024
1 parent a1843f9 commit 8784137
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 52 deletions.
23 changes: 16 additions & 7 deletions lib/galaxy/tool_util/parameters/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from typing import (
Any,
Dict,
List,
Optional,
Type,
Union,
)

from pydantic import BaseModel
Expand All @@ -17,9 +19,12 @@
create_request_model,
StateRepresentationT,
ToolParameterBundle,
ToolParameterT,
validate_against_model,
)

HasToolParameters = Union[List[ToolParameterT], ToolParameterBundle]


class ToolState(ABC):
input_state: Dict[str, Any]
Expand All @@ -30,7 +35,7 @@ def __init__(self, input_state: Dict[str, Any]):
def _validate(self, pydantic_model: Type[BaseModel]) -> None:
validate_against_model(pydantic_model, self.input_state)

def validate(self, input_models: ToolParameterBundle) -> None:
def validate(self, input_models: HasToolParameters) -> None:
base_model = self.parameter_model_for(input_models)
if base_model is None:
raise NotImplementedError(
Expand All @@ -44,31 +49,35 @@ def state_representation(self) -> StateRepresentationT:
"""Get state representation of the inputs."""

@classmethod
def parameter_model_for(cls, input_models: HasToolParameters) -> Optional[Type[BaseModel]]:
    """Return the pydantic model used to validate this state representation.

    ``input_models`` may be a ``ToolParameterBundle`` or a plain list of tool
    parameters; a list is wrapped in a ``ToolParameterBundleModel`` before it
    is handed to the subclass hook ``_parameter_model_for``.
    """
    # NOTE(review): ToolParameterBundleModel is not among the imports visible
    # in this diff hunk - confirm it is imported at the top of the module.
    bundle: ToolParameterBundle
    if isinstance(input_models, list):
        bundle = ToolParameterBundleModel(input_models=input_models)
    else:
        bundle = input_models
    # Pass the normalized bundle, not the raw argument: the hooks are typed to
    # receive a ToolParameterBundle, so forwarding ``input_models`` would
    # silently discard the wrapping done above.
    return cls._parameter_model_for(bundle)

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Optional[Type[BaseModel]]:
    """Default hook: no model is available for this state representation.

    Subclasses override this. Returning ``None`` keeps the ``None`` check in
    ``validate`` (which raises ``NotImplementedError``) reachable instead of
    failing with an ``AttributeError`` on a missing hook.
    """
    return None


class RequestToolState(ToolState):
    """Tool state held in the ``"request"`` representation."""

    state_representation: Literal["request"] = "request"

    @classmethod
    def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
        """Return the model produced by ``create_request_model`` for the bundle."""
        request_model = create_request_model(input_models)
        return request_model


class RequestInternalToolState(ToolState):
    """Tool state held in the ``"request_internal"`` representation."""

    state_representation: Literal["request_internal"] = "request_internal"

    @classmethod
    def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
        """Return the model produced by ``create_request_internal_model`` for the bundle."""
        request_internal_model = create_request_internal_model(input_models)
        return request_internal_model


class JobInternalToolState(ToolState):
    """Tool state held in the ``"job_internal"`` representation."""

    state_representation: Literal["job_internal"] = "job_internal"

    @classmethod
    def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
        """Return the model used to validate job-internal state.

        TODO: implement a job model; until then the request-internal model
        is used in its place.
        """
        stand_in_model = create_request_internal_model(input_models)
        return stand_in_model

Expand All @@ -77,6 +86,6 @@ class TestCaseToolState(ToolState):
state_representation: Literal["test_case"] = "test_case"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
    """Return the model used to validate test-case state.

    TODO: implement a test case model; until then the request-internal
    model is used in its place.
    """
    stand_in_model = create_request_internal_model(input_models)
    return stand_in_model
31 changes: 20 additions & 11 deletions lib/tool_shed/managers/model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,46 @@ def hash_model(model_class: Type[BaseModel]) -> str:
return md5_hash_str(json.dumps(model_class.model_json_schema()))


MODEL_HASHES: Dict[Type[BaseModel], str] = {}


M = TypeVar("M", bound=BaseModel)


class ModelCache(Generic[M]):
def ensure_model_has_hash(model_class: Type[BaseModel]) -> None:
    """Record the schema hash of ``model_class`` in ``MODEL_HASHES`` if it is missing.

    The hash is computed by ``hash_model`` only the first time a class is
    seen; later calls for the same class are no-ops.
    """
    if model_class in MODEL_HASHES:
        return
    MODEL_HASHES[model_class] = hash_model(model_class)


class ModelCache:
    """File-system cache of pydantic model instances.

    Each entry is a JSON file stored at
    ``<cache_directory>/<model hash>/<tool_id>/<tool_version>``, where the
    model hash is the value recorded in ``MODEL_HASHES`` for the entry's
    model class.
    """

    _cache_directory: str

    def __init__(self, cache_directory: str):
        """Root the cache at ``cache_directory``, creating the directory if needed."""
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(cache_directory, exist_ok=True)
        self._cache_directory = cache_directory

    def _cache_target(self, model_class: Type[M], tool_id: str, tool_version: str) -> str:
        """Return the path of the cache file for this model class and tool version."""
        ensure_model_has_hash(model_class)
        # consider breaking this into multiple directories...
        return os.path.join(self._cache_directory, MODEL_HASHES[model_class], tool_id, tool_version)

    def get_cache_entry_for(self, model_class: Type[M], tool_id: str, tool_version: str) -> Optional[M]:
        """Load and validate the cached entry, or return ``None`` when there is none."""
        cache_target = self._cache_target(model_class, tool_id, tool_version)
        # EAFP: open directly instead of exists()-then-open, so an entry that
        # disappears between the two steps cannot raise.
        try:
            with open(cache_target, encoding="utf-8") as f:
                raw_entry = json.load(f)
        except FileNotFoundError:
            return None
        return model_class.model_validate(raw_entry)

    def has_cached_entry_for(self, model_class: Type[M], tool_id: str, tool_version: str) -> bool:
        """Return whether a cache file exists for this model class and tool version."""
        return os.path.exists(self._cache_target(model_class, tool_id, tool_version))

    def insert_cache_entry_for(self, model_object: M, tool_id: str, tool_version: str) -> None:
        """Serialize ``model_object`` to JSON and store it as the entry for the tool version."""
        cache_target = self._cache_target(type(model_object), tool_id, tool_version)
        os.makedirs(os.path.dirname(cache_target), exist_ok=True)
        with open(cache_target, "w", encoding="utf-8") as f:
            # model_dump replaces the deprecated pydantic v1 ``.dict()``;
            # mode="json" guarantees the result is serializable by json.dump.
            json.dump(model_object.model_dump(mode="json"), f)
29 changes: 14 additions & 15 deletions lib/tool_shed/managers/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
)
from galaxy.tool_util.parameters import (
input_models_for_tool_source,
tool_parameter_bundle_from_json,
ToolParameterBundleModel,
ToolParameterT,
)
from galaxy.tool_util.parser import (
get_tool_source,
Expand Down Expand Up @@ -54,7 +53,7 @@ class ParsedTool(BaseModel):
version: Optional[str]
name: str
description: Optional[str]
inputs: ToolParameterBundleModel
inputs: List[ToolParameterT]
citations: List[Citation]
license: Optional[str]
profile: Optional[str]
Expand All @@ -68,7 +67,7 @@ def _parse_tool(tool_source: ToolSource) -> ParsedTool:
version = tool_source.parse_version()
name = tool_source.parse_name()
description = tool_source.parse_description()
inputs = input_models_for_tool_source(tool_source)
inputs = input_models_for_tool_source(tool_source).input_models
citations = tool_source.parse_citations()
license = tool_source.parse_license()
profile = tool_source.parse_profile()
Expand Down Expand Up @@ -147,23 +146,23 @@ def get_repository_metadata_tool_dict(
raise ObjectNotFound()


def parsed_tool_model_cached_for(
    trans: ProvidesRepositoriesContext, trs_tool_id: str, tool_version: str, repository_clone_url: Optional[str] = None
) -> ParsedTool:
    """Return the ``ParsedTool`` for a tool version, consulting the app's model cache first.

    On a cache miss the tool is parsed via ``parsed_tool_model_for`` and the
    result is inserted into the cache before being returned.
    """
    cache = trans.app.model_cache
    cached = cache.get_cache_entry_for(ParsedTool, trs_tool_id, tool_version)
    if cached is None:
        freshly_parsed = parsed_tool_model_for(
            trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url
        )
        cache.insert_cache_entry_for(freshly_parsed, trs_tool_id, tool_version)
        return freshly_parsed
    return cached


def parsed_tool_model_for(
    trans: ProvidesRepositoriesContext, trs_tool_id: str, tool_version: str, repository_clone_url: Optional[str] = None
) -> ParsedTool:
    """Locate the tool source for the given tool version and parse it into a ``ParsedTool``."""
    return _parse_tool(
        tool_source_for(trans, trs_tool_id, tool_version, repository_clone_url=repository_clone_url)
    )


def tool_source_for(
Expand Down
15 changes: 7 additions & 8 deletions lib/tool_shed/webapp/api2/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
)
from tool_shed.context import SessionRequestContext
from tool_shed.managers.tools import (
parsed_tool_model_cached_for,
search,
tool_input_models_cached_for,
)
from tool_shed.managers.trs import (
get_tool,
Expand Down Expand Up @@ -144,17 +144,17 @@ def trs_get_versions(
return get_tool(trans, tool_id).versions

@router.get(
    "/api/tools/{tool_id}/versions/{tool_version}",
    # NOTE(review): operation_id still carries the old "parameter_model" name.
    # It is left untouched because renaming it may affect API clients keyed on
    # operation ids - confirm before changing.
    operation_id="tools__parameter_model",
    summary="Return Galaxy's parsed description of the tool, including its input parameter models",
)
def show_tool(
    self,
    trans: SessionRequestContext = DependsOnTrans,
    tool_id: str = TOOL_ID_PATH_PARAM,
    tool_version: str = TOOL_VERSION_PATH_PARAM,
) -> ParsedTool:
    """Return the (cached) ``ParsedTool`` for the requested tool version."""
    # NOTE(review): ParsedTool is not among the imports visible in this diff
    # hunk - confirm it is imported in this module.
    return parsed_tool_model_cached_for(trans, tool_id, tool_version)

@router.get(
"/api/tools/{tool_id}/versions/{tool_version}/parameter_request_schema",
Expand All @@ -168,6 +168,5 @@ def tool_state(
tool_id: str = TOOL_ID_PATH_PARAM,
tool_version: str = TOOL_VERSION_PATH_PARAM,
) -> Response:
return json_schema_response(
RequestToolState.parameter_model_for(tool_input_models_cached_for(trans, tool_id, tool_version))
)
parsed_tool = parsed_tool_model_cached_for(trans, tool_id, tool_version)
return json_schema_response(RequestToolState.parameter_model_for(parsed_tool.inputs))
16 changes: 15 additions & 1 deletion test/unit/tool_shed/test_model_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pydantic import BaseModel, ConfigDict

from tool_shed.managers.model_cache import hash_model
from tool_shed.managers.model_cache import (
hash_model,
ModelCache,
)


class Moo(BaseModel):
Expand Down Expand Up @@ -34,3 +37,14 @@ def test_hash_different_on_updates():
hash_moo_1 = hash_model(Moo)
hash_moo_new = hash_model(NewMoo)
assert hash_moo_1 != hash_moo_new


def test_cache_dict(tmp_path):
    """Round-trip a model through ``ModelCache``: miss, insert, then hit."""
    # ModelCache annotates cache_directory as str, so convert the Path fixture.
    model_cache = ModelCache(str(tmp_path))
    assert not model_cache.has_cached_entry_for(Moo, "moo", "1.0")
    assert model_cache.get_cache_entry_for(Moo, "moo", "1.0") is None
    model_cache.insert_cache_entry_for(Moo(foo=4), "moo", "1.0")
    moo = model_cache.get_cache_entry_for(Moo, "moo", "1.0")
    assert moo
    assert moo.foo == 4
    assert model_cache.has_cached_entry_for(Moo, "moo", "1.0")


# The function was originally named ``cache_dict``, which pytest never
# collects (no ``test_`` prefix); keep the old name as an alias.
cache_dict = test_cache_dict
20 changes: 10 additions & 10 deletions test/unit/tool_shed/test_tool_source.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tool_shed.context import ProvidesRepositoriesContext
from tool_shed.managers.tools import (
tool_input_models_cached_for,
tool_input_models_for,
parsed_tool_model_cached_for,
parsed_tool_model_for,
tool_source_for,
)
from tool_shed.webapp.model import Repository
Expand All @@ -17,22 +17,22 @@ def test_get_tool(provides_repositories: ProvidesRepositoriesContext, new_reposi
repo_path = new_repository.repo_path(app=provides_repositories.app)
tool_source = tool_source_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path)
assert tool_source.parse_id() == "Add_a_column1"
bundle = tool_input_models_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path)
assert len(bundle.input_models) == 3
bundle = parsed_tool_model_for(provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path)
assert len(bundle.inputs) == 3

cached_bundle = tool_input_models_cached_for(
cached_bundle = parsed_tool_model_cached_for(
provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path
)
assert len(cached_bundle.input_models) == 3
assert len(cached_bundle.inputs) == 3

cached_bundle = tool_input_models_cached_for(
cached_bundle = parsed_tool_model_cached_for(
provides_repositories, encoded_id, "1.2.0", repository_clone_url=repo_path
)
assert len(cached_bundle.input_models) == 3
assert len(cached_bundle.inputs) == 3


def test_stock_bundle(provides_repositories: ProvidesRepositoriesContext):
    """The stock ``__ZIP_COLLECTION__`` tool (version 1.0.0) parses to two inputs."""
    parsed_tool = parsed_tool_model_cached_for(
        provides_repositories,
        "__ZIP_COLLECTION__",
        "1.0.0",
        repository_clone_url=None,
    )
    input_count = len(parsed_tool.inputs)
    assert input_count == 2

0 comments on commit 8784137

Please sign in to comment.