From 352b9265d6fdd0655e633fc2667cd86e24c4c1ed Mon Sep 17 00:00:00 2001 From: John Chilton Date: Thu, 29 Aug 2024 19:25:09 -0400 Subject: [PATCH] Tool Request API. --- .github/workflows/framework_tools.yaml | 3 +- lib/galaxy/app.py | 9 +- lib/galaxy/celery/tasks.py | 17 +- lib/galaxy/managers/jobs.py | 63 ++++- lib/galaxy/model/__init__.py | 28 +++ ...3d5d144_implement_structured_tool_state.py | 69 ++++++ lib/galaxy/schema/jobs.py | 13 ++ lib/galaxy/schema/schema.py | 17 ++ lib/galaxy/schema/tasks.py | 13 ++ lib/galaxy/tool_util/parameters/__init__.py | 6 + lib/galaxy/tool_util/parameters/convert.py | 102 +++++++- lib/galaxy/tool_util/parameters/factory.py | 4 +- lib/galaxy/tool_util/parameters/models.py | 25 +- lib/galaxy/tool_util/parameters/visitor.py | 90 +++++++- lib/galaxy/tool_util/verify/_types.py | 11 +- lib/galaxy/tool_util/verify/interactor.py | 218 ++++++++++++++---- lib/galaxy/tool_util/verify/parse.py | 43 +++- lib/galaxy/tools/__init__.py | 109 ++++++++- lib/galaxy/tools/execute.py | 106 ++++++++- lib/galaxy/tools/parameters/basic.py | 10 +- lib/galaxy/tools/parameters/meta.py | 77 ++++++- lib/galaxy/webapps/galaxy/api/histories.py | 12 + lib/galaxy/webapps/galaxy/api/jobs.py | 25 +- lib/galaxy/webapps/galaxy/api/tools.py | 86 ++++++- lib/galaxy/webapps/galaxy/services/base.py | 20 +- .../webapps/galaxy/services/histories.py | 9 + lib/galaxy/webapps/galaxy/services/jobs.py | 134 ++++++++++- lib/galaxy/webapps/galaxy/services/tools.py | 79 ++++--- lib/galaxy_test/base/populators.py | 22 +- scripts/gen_typescript_artifacts.py | 20 ++ test/functional/test_toolbox_pytest.py | 10 +- .../tool_util/parameter_specification.yml | 22 +- test/unit/tool_util/test_parameter_covert.py | 99 ++++++++ 33 files changed, 1415 insertions(+), 156 deletions(-) create mode 100644 lib/galaxy/model/migrations/alembic/versions_gxy/7ffd33d5d144_implement_structured_tool_state.py create mode 100644 scripts/gen_typescript_artifacts.py create mode 100644 test/unit/tool_util/test_parameter_covert.py diff --git a/.github/workflows/framework_tools.yaml b/.github/workflows/framework_tools.yaml index a06f6e469aad..1652cb85b005 100644 --- a/.github/workflows/framework_tools.yaml +++ b/.github/workflows/framework_tools.yaml @@ -26,6 +26,7 @@ jobs: strategy: matrix: python-version: ['3.8'] + use-legacy-api: ['if_needed', 'always'] services: postgres: image: postgres:13 @@ -66,7 +67,7 @@ jobs: path: 'galaxy root/.venv' key: gxy-venv-${{ runner.os }}-${{ steps.full-python-version.outputs.version }}-${{ hashFiles('galaxy root/requirements.txt') }}-framework-tools - name: Run tests - run: ./run_tests.sh --coverage --framework-tools + run: GALAXY_TEST_USE_LEGACY_TOOL_API="${{ matrix.use-legacy-api }}" ./run_tests.sh --coverage --framework-tools working-directory: 'galaxy root' - uses: codecov/codecov-action@v3 with: diff --git a/lib/galaxy/app.py b/lib/galaxy/app.py index dc69fdd91f1f..c864e3248833 100644 --- a/lib/galaxy/app.py +++ b/lib/galaxy/app.py @@ -672,6 +672,10 @@ def __init__(self, configure_logging=True, use_converters=True, use_display_appl self._register_singleton(Registry, self.datatypes_registry) galaxy.model.set_datatypes_registry(self.datatypes_registry) self.configure_sentry_client() + # Load dbkey / genome build manager + self._configure_genome_builds(data_table_name="__dbkeys__", load_old_style=True) + # Tool Data Tables + self._configure_tool_data_tables(from_shed_config=False) self._configure_tool_shed_registry() self._register_singleton(tool_shed_registry.Registry, self.tool_shed_registry) @@ -750,11 +754,6 @@ def __init__(self, **kwargs) -> None: ) self.api_keys_manager = self._register_singleton(ApiKeyManager) - # Tool Data Tables - self._configure_tool_data_tables(from_shed_config=False) - # Load dbkey / genome build manager - self._configure_genome_builds(data_table_name="__dbkeys__", load_old_style=True) - # Genomes self.genomes = self._register_singleton(Genomes) # Data providers registry. diff --git a/lib/galaxy/celery/tasks.py b/lib/galaxy/celery/tasks.py index 3b2e4c6272a7..6939c7799e61 100644 --- a/lib/galaxy/celery/tasks.py +++ b/lib/galaxy/celery/tasks.py @@ -28,6 +28,7 @@ DatasetManager, ) from galaxy.managers.hdas import HDAManager +from galaxy.managers.jobs import JobSubmitter from galaxy.managers.lddas import LDDAManager from galaxy.managers.markdown_util import generate_branded_pdf from galaxy.managers.model_stores import ModelStoreManager @@ -54,6 +55,7 @@ MaterializeDatasetInstanceTaskRequest, PrepareDatasetCollectionDownload, PurgeDatasetsTaskRequest, + QueueJobs, SetupHistoryExportJob, WriteHistoryContentTo, WriteHistoryTo, @@ -75,9 +77,9 @@ def setup_data_table_manager(app): @lru_cache -def cached_create_tool_from_representation(app, raw_tool_source): +def cached_create_tool_from_representation(app, raw_tool_source, tool_dir=""): return create_tool_from_representation( - app=app, raw_tool_source=raw_tool_source, tool_dir="", tool_source_class="XmlToolSource" + app=app, raw_tool_source=raw_tool_source, tool_dir=tool_dir, tool_source_class="XmlToolSource" ) @@ -335,6 +337,17 @@ def fetch_data( return abort_when_job_stops(_fetch_data, session=sa_session, job_id=job_id, setup_return=setup_return) +@galaxy_task(action="queuing up submitted jobs") +def queue_jobs(request: QueueJobs, app: MinimalManagerApp, job_submitter: JobSubmitter): + tool = cached_create_tool_from_representation( + app, request.tool_source.raw_tool_source, tool_dir=request.tool_source.tool_dir + ) + job_submitter.queue_jobs( + tool, + request, + ) + + @galaxy_task(ignore_result=True, action="setting up export history job") def export_history( model_store_manager: ModelStoreManager, diff --git a/lib/galaxy/managers/jobs.py b/lib/galaxy/managers/jobs.py index b7221cf0fef9..3e1b4cf1c4ac 100644 --- a/lib/galaxy/managers/jobs.py +++ b/lib/galaxy/managers/jobs.py @@ -48,12 +48,15 @@ ) from galaxy.managers.datasets import DatasetManager from galaxy.managers.hdas import HDAManager +from galaxy.managers.histories import HistoryManager from galaxy.managers.lddas import LDDAManager +from galaxy.managers.users import UserManager from galaxy.model import ( ImplicitCollectionJobs, ImplicitCollectionJobsJobAssociation, Job, JobParameter, + ToolRequest, User, Workflow, WorkflowInvocation, @@ -70,8 +73,13 @@ JobIndexQueryPayload, JobIndexSortByEnum, ) +from galaxy.schema.tasks import QueueJobs from galaxy.security.idencoding import IdEncodingHelper -from galaxy.structured_app import StructuredApp +from galaxy.structured_app import ( + MinimalManagerApp, + StructuredApp, +) +from galaxy.tools import Tool from galaxy.tools._types import ( ToolStateDumpedToJsonInternalT, ToolStateJobInstancePopulatedT, @@ -86,6 +94,7 @@ parse_filters_structured, RawTextTerm, ) +from galaxy.work.context import WorkRequestContext log = logging.getLogger(__name__) @@ -134,6 +143,8 @@ def index_query(self, trans: ProvidesUserContext, payload: JobIndexQueryPayload) workflow_id = payload.workflow_id invocation_id = payload.invocation_id implicit_collection_jobs_id = payload.implicit_collection_jobs_id + tool_request_id = payload.tool_request_id + search = payload.search order_by = payload.order_by @@ -150,6 +161,7 @@ def build_and_apply_filters(stmt, objects, filter_func): def add_workflow_jobs(): wfi_step = select(WorkflowInvocationStep) + if workflow_id is not None: wfi_step = ( wfi_step.join(WorkflowInvocation).join(Workflow).where(Workflow.stored_workflow_id == workflow_id) @@ -164,6 +176,7 @@ def add_workflow_jobs(): ImplicitCollectionJobsJobAssociation.implicit_collection_jobs_id == wfi_step_sq.c.implicit_collection_jobs_id, ) + # Ensure the result is models, not tuples sq = stmt1.union(stmt2).subquery() # SQLite won't recognize Job.foo as a valid column for the ORDER BY clause due to the UNION clause, so we'll use the subquery `columns` collection (`sq.c`). @@ -241,6 +254,9 @@ def add_search_criteria(stmt): if history_id is not None: stmt = stmt.where(Job.history_id == history_id) + if tool_request_id is not None: + stmt = stmt.filter(model.Job.tool_request_id == tool_request_id) + order_by_columns = Job if workflow_id or invocation_id: stmt, order_by_columns = add_workflow_jobs() @@ -1150,3 +1166,48 @@ def get_jobs_to_check_at_startup(session: galaxy_scoped_session, track_jobs_in_d def get_job(session, *where_clauses): stmt = select(Job).where(*where_clauses).limit(1) return session.scalars(stmt).first() + + +class JobSubmitter: + def __init__( + self, + history_manager: HistoryManager, + user_manager: UserManager, + app: MinimalManagerApp, + ): + self.history_manager = history_manager + self.user_manager = user_manager + self.app = app + + def queue_jobs(self, tool: Tool, request: QueueJobs) -> None: + user = self.user_manager.by_id(request.user.user_id) + sa_session = self.app.model.context + tool_request: ToolRequest = cast(ToolRequest, sa_session.query(ToolRequest).get(request.tool_request_id)) + if tool_request is None: + raise Exception(f"Problem fetching request with ID {request.tool_request_id}") + try: + target_history = tool_request.history + use_cached_jobs = request.use_cached_jobs + rerun_remap_job_id = request.rerun_remap_job_id + trans = WorkRequestContext( + self.app, + user, + history=target_history, + ) + tool.handle_input_async( + trans, + tool_request, + history=target_history, + use_cached_job=use_cached_jobs, + rerun_remap_job_id=rerun_remap_job_id, + ) + tool_request.state = ToolRequest.states.SUBMITTED + sa_session.add(tool_request) + with transaction(sa_session): + sa_session.commit() + except Exception as e: + tool_request.state = ToolRequest.states.FAILED + tool_request.state_message = str(e) + sa_session.add(tool_request) + with transaction(sa_session): + sa_session.commit() diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index b01620b59c2e..1f4b8e9e5f4e 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -176,6 +176,7 @@ DatasetValidatedState, InvocationsStateCounts, JobState, + ToolRequestState, ) from galaxy.schema.workflow.comments import WorkflowCommentModel from galaxy.security import get_permitted_actions @@ -1336,6 +1337,30 @@ def __init__(self, user, token=None): self.expiration_time = now() + timedelta(hours=24) +class ToolSource(Base, Dictifiable, RepresentById): + __tablename__ = "tool_source" + + id: Mapped[int] = mapped_column(primary_key=True) + hash: Mapped[Optional[str]] = mapped_column(Unicode(255)) + source: Mapped[dict] = mapped_column(JSONType) + + +class ToolRequest(Base, Dictifiable, RepresentById): + __tablename__ = "tool_request" + + states: TypeAlias = ToolRequestState + + id: Mapped[int] = mapped_column(primary_key=True) + tool_source_id: Mapped[int] = mapped_column(ForeignKey("tool_source.id"), index=True) + history_id: Mapped[Optional[int]] = mapped_column(ForeignKey("history.id"), index=True) + request: Mapped[dict] = mapped_column(JSONType) + state: Mapped[Optional[str]] = mapped_column(TrimmedString(32), index=True) + state_message: Mapped[Optional[str]] = mapped_column(JSONType, index=True) + + tool_source: Mapped["ToolSource"] = relationship() + history: Mapped[Optional["History"]] = relationship(back_populates="tool_requests") + + class DynamicTool(Base, Dictifiable, RepresentById): __tablename__ = "dynamic_tool" @@ -1462,7 +1487,9 @@ class Job(Base, JobLike, UsesCreateAndUpdateTime, Dictifiable, Serializable): handler: Mapped[Optional[str]] = mapped_column(TrimmedString(255), index=True) preferred_object_store_id: Mapped[Optional[str]] = mapped_column(String(255)) object_store_id_overrides: Mapped[Optional[STR_TO_STR_DICT]] = mapped_column(JSONType) + tool_request_id: Mapped[Optional[int]] = mapped_column(ForeignKey("tool_request.id"), index=True) + tool_request: Mapped[Optional["ToolRequest"]] = relationship() user: Mapped[Optional["User"]] = relationship() galaxy_session: Mapped[Optional["GalaxySession"]] = relationship() history: Mapped[Optional["History"]] = relationship(back_populates="jobs") @@ -3185,6 +3212,7 @@ class History(Base, HasTags, Dictifiable, UsesAnnotations, HasName, Serializable ) user: Mapped[Optional["User"]] = relationship(back_populates="histories") jobs: Mapped[List["Job"]] = relationship(back_populates="history", cascade_backrefs=False) + tool_requests: Mapped[List["ToolRequest"]] = relationship(back_populates="history") update_time = column_property( select(func.max(HistoryAudit.update_time)).where(HistoryAudit.history_id == id).scalar_subquery(), diff --git a/lib/galaxy/model/migrations/alembic/versions_gxy/7ffd33d5d144_implement_structured_tool_state.py b/lib/galaxy/model/migrations/alembic/versions_gxy/7ffd33d5d144_implement_structured_tool_state.py new file mode 100644 index 000000000000..fe76f3c199a1 --- /dev/null +++ b/lib/galaxy/model/migrations/alembic/versions_gxy/7ffd33d5d144_implement_structured_tool_state.py @@ -0,0 +1,69 @@ +"""implement structured tool state + +Revision ID: 7ffd33d5d144 +Revises: eee9229a9765 +Create Date: 2022-11-09 15:53:11.451185 + +""" + +from sqlalchemy import ( + Column, + ForeignKey, + Integer, + String, +) + +from galaxy.model.custom_types import JSONType +from galaxy.model.database_object_names import build_index_name +from galaxy.model.migrations.util import ( + _is_sqlite, + add_column, + create_table, + drop_column, + drop_index, + drop_table, + transaction, +) + +# revision identifiers, used by Alembic. +revision = "7ffd33d5d144" +down_revision = "eee9229a9765" +branch_labels = None +depends_on = None + +job_table_name = "job" +request_column_name = "tool_request_id" +job_request_index_name = build_index_name(job_table_name, request_column_name) + + +def upgrade(): + with transaction(): + create_table( + "tool_source", + Column("id", Integer, primary_key=True), + Column("hash", String(255), index=True), + Column("source", JSONType), + ) + create_table( + "tool_request", + Column("id", Integer, primary_key=True), + Column("request", JSONType), + Column("state", String(32)), + Column("state_message", JSONType), + Column("tool_source_id", Integer, ForeignKey("tool_source.id"), index=True), + Column("history_id", Integer, ForeignKey("history.id"), index=True), + ) + index = not _is_sqlite() + add_column( + job_table_name, + Column(request_column_name, Integer, ForeignKey("tool_request.id"), default=None, index=index), + ) + + +def downgrade(): + with transaction(): + if not _is_sqlite(): + drop_index(job_request_index_name, job_table_name) + drop_column(job_table_name, request_column_name) + drop_table("tool_request") + drop_table("tool_source") diff --git a/lib/galaxy/schema/jobs.py b/lib/galaxy/schema/jobs.py index fe6316262983..283c330e24de 100644 --- a/lib/galaxy/schema/jobs.py +++ b/lib/galaxy/schema/jobs.py @@ -82,6 +82,19 @@ class JobOutputAssociation(JobAssociation): ) +class JobOutputCollectionAssociation(Model): + name: str = Field( + default=..., + title="name", + description="Name of the job parameter.", + ) + dataset_collection_instance: EncodedDataItemSourceId = Field( + default=..., + title="dataset_collection_instance", + description="Reference to the associated item.", + ) + + class ReportJobErrorPayload(Model): dataset_id: DecodedDatabaseIdField = Field( default=..., diff --git a/lib/galaxy/schema/schema.py b/lib/galaxy/schema/schema.py index 427777fa319d..06490242b75f 100644 --- a/lib/galaxy/schema/schema.py +++ b/lib/galaxy/schema/schema.py @@ -1533,6 +1533,7 @@ class JobIndexQueryPayload(Model): workflow_id: Optional[DecodedDatabaseIdField] = None invocation_id: Optional[DecodedDatabaseIdField] = None implicit_collection_jobs_id: Optional[DecodedDatabaseIdField] = None + tool_request_id: Optional[DecodedDatabaseIdField] = None order_by: JobIndexSortByEnum = JobIndexSortByEnum.update_time search: Optional[str] = None limit: int = 500 @@ -3732,6 +3733,22 @@ class AsyncTaskResultSummary(Model): ) +ToolRequestIdField = Field(title="ID", description="Encoded ID of the role") + + +class ToolRequestState(str, Enum): + NEW = "new" + SUBMITTED = "submitted" + FAILED = "failed" + + +class ToolRequestModel(Model): + id: EncodedDatabaseIdField = ToolRequestIdField + request: Dict[str, Any] + state: ToolRequestState + state_message: Optional[str] + + class AsyncFile(Model): storage_request_id: UUID task: AsyncTaskResultSummary diff --git a/lib/galaxy/schema/tasks.py b/lib/galaxy/schema/tasks.py index 022d82666aed..ad81ff1b7324 100644 --- a/lib/galaxy/schema/tasks.py +++ b/lib/galaxy/schema/tasks.py @@ -119,3 +119,16 @@ class ComputeDatasetHashTaskRequest(Model): class PurgeDatasetsTaskRequest(Model): dataset_ids: List[int] + + +class ToolSource(Model): + raw_tool_source: str + tool_dir: str + + +class QueueJobs(Model): + tool_source: ToolSource + tool_request_id: int # links to request ("incoming") and history + user: RequestUser # TODO: test anonymous users through this submission path + use_cached_jobs: bool + rerun_remap_job_id: Optional[int] # link to a job to rerun & remap diff --git a/lib/galaxy/tool_util/parameters/__init__.py b/lib/galaxy/tool_util/parameters/__init__.py index 1ad01e4eb328..3f607d94476d 100644 --- a/lib/galaxy/tool_util/parameters/__init__.py +++ b/lib/galaxy/tool_util/parameters/__init__.py @@ -2,6 +2,7 @@ from .convert import ( decode, encode, + encode_test, ) from .factory import ( from_input_source, @@ -26,7 +27,9 @@ CwlStringParameterModel, CwlUnionParameterModel, DataCollectionParameterModel, + DataCollectionRequest, DataParameterModel, + DataRequest, FloatParameterModel, HiddenParameterModel, IntegerParameterModel, @@ -76,6 +79,8 @@ "JobInternalToolState", "ToolParameterBundle", "ToolParameterBundleModel", + "DataRequest", + "DataCollectionRequest", "ToolParameterModel", "IntegerParameterModel", "BooleanParameterModel", @@ -122,6 +127,7 @@ "VISITOR_NO_REPLACEMENT", "decode", "encode", + "encode_test", "WorkflowStepToolState", "WorkflowStepLinkedToolState", ) diff --git a/lib/galaxy/tool_util/parameters/convert.py b/lib/galaxy/tool_util/parameters/convert.py index 14caed47e92c..77423034d82e 100644 --- a/lib/galaxy/tool_util/parameters/convert.py +++ b/lib/galaxy/tool_util/parameters/convert.py @@ -1,24 +1,38 @@ """Utilities for converting between request states. """ +import logging from typing import ( Any, Callable, + cast, + List, ) +from galaxy.tool_util.parser.interface import ( + JsonTestCollectionDefDict, + JsonTestDatasetDefDict, +) from .models import ( + DataCollectionRequest, + DataParameterModel, + DataRequest, + SelectParameterModel, ToolParameterBundle, ToolParameterT, ) from .state import ( RequestInternalToolState, RequestToolState, + TestCaseToolState, ) from .visitor import ( visit_input_values, VISITOR_NO_REPLACEMENT, ) +log = logging.getLogger(__name__) + def decode( external_state: RequestToolState, input_models: ToolParameterBundle, decode_id: Callable[[str], int] @@ -27,13 +41,24 @@ def decode( external_state.validate(input_models) + def decode_src_dict(src_dict: dict): + assert "id" in src_dict + decoded_dict = src_dict.copy() + decoded_dict["id"] = decode_id(src_dict["id"]) + return decoded_dict + def decode_callback(parameter: ToolParameterT, value: Any): if parameter.parameter_type == "gx_data": + data_parameter = cast(DataParameterModel, parameter) + if data_parameter.multiple: + assert isinstance(value, list), str(value) + return list(map(decode_src_dict, value)) + else: + assert isinstance(value, dict), str(value) + return decode_src_dict(value) + elif parameter.parameter_type == "gx_data_collection": assert isinstance(value, dict), str(value) - assert "id" in value - decoded_dict = value.copy() - decoded_dict["id"] = decode_id(value["id"]) - return decoded_dict + return decode_src_dict(value) else: return VISITOR_NO_REPLACEMENT @@ -53,13 +78,24 @@ def encode( ) -> RequestToolState: """Prepare an external representation of tool state (request) for storing in the database (request_internal).""" + def encode_src_dict(src_dict: dict): + assert "id" in src_dict + encoded_dict = src_dict.copy() + encoded_dict["id"] = encode_id(src_dict["id"]) + return encoded_dict + def encode_callback(parameter: ToolParameterT, value: Any): if parameter.parameter_type == "gx_data": + data_parameter = cast(DataParameterModel, parameter) + if data_parameter.multiple: + assert isinstance(value, list), str(value) + return list(map(encode_src_dict, value)) + else: + assert isinstance(value, dict), str(value) + return encode_src_dict(value) + elif parameter.parameter_type == "gx_data_collection": assert isinstance(value, dict), str(value) - assert "id" in value - encoded_dict = value.copy() - encoded_dict["id"] = encode_id(value["id"]) - return encoded_dict + return encode_src_dict(value) else: return VISITOR_NO_REPLACEMENT @@ -71,3 +107,53 @@ def encode_callback(parameter: ToolParameterT, value: Any): request_state = RequestToolState(request_state_dict) request_state.validate(input_models) return request_state + + +# interfaces for adapting test data dictionaries to tool request dictionaries +# e.g. {class: File, path: foo.bed} => {src: hda, id: ab1235cdfea3} +AdaptDatasets = Callable[[JsonTestDatasetDefDict], DataRequest] +AdaptCollections = Callable[[JsonTestCollectionDefDict], DataCollectionRequest] + + +def encode_test( + test_case_state: TestCaseToolState, + input_models: ToolParameterBundle, + adapt_datasets: AdaptDatasets, + adapt_collections: AdaptCollections, +): + + def encode_callback(parameter: ToolParameterT, value: Any): + if parameter.parameter_type == "gx_data": + data_parameter = cast(DataParameterModel, parameter) + if value is not None: + if data_parameter.multiple: + assert isinstance(value, list), str(value) + test_datasets = cast(List[JsonTestDatasetDefDict], value) + return [d.model_dump() for d in map(adapt_datasets, test_datasets)] + else: + assert isinstance(value, dict), str(value) + test_dataset = cast(JsonTestDatasetDefDict, value) + return adapt_datasets(test_dataset).model_dump() + elif parameter.parameter_type == "gx_data_collection": + # data_collection_parameter = cast(DataCollectionParameterModel, parameter) + if value is not None: + assert isinstance(value, dict), str(value) + test_collection = cast(JsonTestCollectionDefDict, value) + return adapt_collections(test_collection).model_dump() + elif parameter.parameter_type == "gx_select": + select_parameter = cast(SelectParameterModel, parameter) + if select_parameter.multiple and value is not None: + return [v.strip() for v in value.split(",")] + else: + return VISITOR_NO_REPLACEMENT + + return VISITOR_NO_REPLACEMENT + + request_state_dict = visit_input_values( + input_models, + test_case_state, + encode_callback, + ) + request_state = RequestToolState(request_state_dict) + request_state.validate(input_models) + return request_state diff --git a/lib/galaxy/tool_util/parameters/factory.py b/lib/galaxy/tool_util/parameters/factory.py index a636b13a17e6..a2534455cf33 100644 --- a/lib/galaxy/tool_util/parameters/factory.py +++ b/lib/galaxy/tool_util/parameters/factory.py @@ -229,7 +229,9 @@ def _from_input_source_galaxy(input_source: InputSource) -> ToolParameterT: if typed_value == default_test_value: is_default_when = True whens.append( - ConditionalWhen(discriminator=value, parameters=tool_parameter_models, is_default_when=is_default_when) + ConditionalWhen( + discriminator=typed_value, parameters=tool_parameter_models, is_default_when=is_default_when + ) ) return ConditionalParameterModel( name=input_source.parse_name(), diff --git a/lib/galaxy/tool_util/parameters/models.py b/lib/galaxy/tool_util/parameters/models.py index 80a658708bd1..5af4ea05239b 100644 --- a/lib/galaxy/tool_util/parameters/models.py +++ b/lib/galaxy/tool_util/parameters/models.py @@ -272,15 +272,6 @@ class MultiDataInstanceInternal(StrictModel): id: StrictInt -class DataTestCaseValue(StrictModel): - src: TestCaseDataSrcT - path: str - - -class MultipleDataTestCaseValue(RootModel): - root: List[DataTestCaseValue] - - MultiDataRequestInternal: Type = union_type([MultiDataInstanceInternal, List[MultiDataInstanceInternal]]) @@ -368,6 +359,8 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam return allow_batching(dynamic_model_information_from_py_type(self, self.py_type)) elif state_representation == "request_internal": return allow_batching(dynamic_model_information_from_py_type(self, self.py_type_internal)) + elif state_representation == "job_internal": + return dynamic_model_information_from_py_type(self, self.py_type_internal) elif state_representation == "workflow_step": return dynamic_model_information_from_py_type(self, type(None), requires_value=False) elif state_representation == "workflow_step_linked": @@ -641,11 +634,13 @@ def request_requires_value(self) -> bool: DrillDownHierarchyT = Literal["recurse", "exact"] -def drill_down_possible_values(options: List[DrillDownOptionsDict], multiple: bool) -> List[str]: +def drill_down_possible_values( + options: List[DrillDownOptionsDict], multiple: bool, hierarchy: DrillDownHierarchyT +) -> List[str]: possible_values = [] def add_value(option: str, is_leaf: bool): - if not multiple and not is_leaf: + if not multiple and not is_leaf and hierarchy == "recurse": return possible_values.append(option) @@ -673,7 +668,8 @@ class DrillDownParameterModel(BaseGalaxyToolParameterModelDefinition): def py_type(self) -> Type: if self.options is not None: literal_options: List[Type] = [ - cast_as_type(Literal[o]) for o in drill_down_possible_values(self.options, self.multiple) + cast_as_type(Literal[o]) + for o in drill_down_possible_values(self.options, self.multiple, self.hierarchy) ] py_type = union_type(literal_options) else: @@ -819,6 +815,7 @@ class ConditionalParameterModel(BaseGalaxyToolParameterModelDefinition): whens: List[ConditionalWhen] def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation: + is_boolean = isinstance(self.test_parameter, BooleanParameterModel) test_param_name = self.test_parameter.name test_info = self.test_parameter.pydantic_template(state_representation) extra_validators = test_info.validators @@ -832,7 +829,7 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam initialize_test = ... else: initialize_test = None - + tag = str(discriminator) if not is_boolean else str(discriminator).lower() extra_kwd = {test_param_name: (Union[str, bool], initialize_test)} when_types.append( cast( @@ -845,7 +842,7 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam extra_kwd=extra_kwd, extra_validators=extra_validators, ), - Tag(str(discriminator)), + Tag(tag), ], ) ) diff --git a/lib/galaxy/tool_util/parameters/visitor.py b/lib/galaxy/tool_util/parameters/visitor.py index 5b8e059f2895..35d4bc176a3e 100644 --- a/lib/galaxy/tool_util/parameters/visitor.py +++ b/lib/galaxy/tool_util/parameters/visitor.py @@ -12,14 +12,27 @@ from typing_extensions import Protocol from .models import ( + ConditionalParameterModel, + ConditionalWhen, + RepeatParameterModel, + SectionParameterModel, simple_input_models, ToolParameterBundle, ToolParameterT, ) from .state import ToolState -VISITOR_NO_REPLACEMENT = object() -VISITOR_UNDEFINED = object() + +class VisitorNoReplacement: + pass + + +class VisitorUndefined: + pass + + +VISITOR_NO_REPLACEMENT = VisitorNoReplacement() +VISITOR_UNDEFINED = VisitorUndefined() class Callback(Protocol): @@ -47,20 +60,79 @@ def _visit_input_values( callback: Callback, no_replacement_value=VISITOR_NO_REPLACEMENT, ) -> Dict[str, Any]: - new_input_values = {} + + def _callback(name: str, old_values: Dict[str, Any], new_values: Dict[str, Any]): + input_value = old_values.get(name, VISITOR_UNDEFINED) + if input_value is VISITOR_UNDEFINED: + return + replacement = callback(model, input_value) + if replacement != no_replacement_value: + new_values[name] = replacement + else: + new_values[name] = input_value + + new_input_values: Dict[str, Any] = {} for model in input_models: name = model.name + parameter_type = model.parameter_type input_value = input_values.get(name, VISITOR_UNDEFINED) - replacement = callback(model, input_value) - if replacement != no_replacement_value: - new_input_values[name] = replacement - elif replacement is VISITOR_UNDEFINED: - pass + if input_value is VISITOR_UNDEFINED: + continue + + if parameter_type == "gx_repeat": + repeat_parameter = cast(RepeatParameterModel, model) + repeat_parameters = repeat_parameter.parameters + repeat_values = cast(list, input_value) + new_repeat_values = [] + for repeat_instance_values in repeat_values: + new_repeat_values.append( + _visit_input_values( + repeat_parameters, repeat_instance_values, callback, no_replacement_value=no_replacement_value + ) + ) + new_input_values[name] = new_repeat_values + elif parameter_type == "gx_section": + section_parameter = cast(SectionParameterModel, model) + section_parameters = section_parameter.parameters + section_values = cast(dict, input_value) + new_section_values = _visit_input_values( + section_parameters, section_values, callback, no_replacement_value=no_replacement_value + ) + new_input_values[name] = new_section_values + elif parameter_type == "gx_conditional": + conditional_parameter = cast(ConditionalParameterModel, model) + test_parameter = conditional_parameter.test_parameter + test_parameter_name = test_parameter.name + + conditional_values = cast(dict, input_value) + when: ConditionalWhen = _select_which_when(conditional_parameter, conditional_values) + new_conditional_values = _visit_input_values( + when.parameters, conditional_values, callback, no_replacement_value=no_replacement_value + ) + if test_parameter_name in conditional_values: + _callback(test_parameter_name, conditional_values, new_conditional_values) + new_input_values[name] = new_conditional_values else: - new_input_values[name] = input_value + _callback(name, input_values, new_input_values) return new_input_values +def _select_which_when(conditional: ConditionalParameterModel, state: dict) -> ConditionalWhen: + test_parameter = conditional.test_parameter + test_parameter_name = test_parameter.name + explicit_test_value = state.get(test_parameter_name) + test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value) + for when in conditional.whens: + print(when.discriminator) + print(type(when.discriminator)) + if test_value is None and when.is_default_when: + return when + elif test_value == when.discriminator: + return when + else: + raise Exception(f"Invalid conditional test value ({explicit_test_value}) for parameter ({test_parameter_name})") + + def flat_state_path(has_name: Union[str, ToolParameterT], prefix: Optional[str] = None) -> str: """Given a parameter name or model and an optional prefix, give 'flat' name for parameter in tree.""" if hasattr(has_name, "name"): diff --git a/lib/galaxy/tool_util/verify/_types.py b/lib/galaxy/tool_util/verify/_types.py index e5aa85f1ddb7..c532dab9aa69 100644 --- a/lib/galaxy/tool_util/verify/_types.py +++ b/lib/galaxy/tool_util/verify/_types.py @@ -19,10 +19,15 @@ ToolSourceTestOutputs, ) -# inputs that have been processed with parse.py and expanded out +# legacy inputs for working with POST /api/tools +# + inputs that have been processed with parse.py and expanded out ExpandedToolInputs = Dict[str, Any] -# ExpandedToolInputs where any model objects have been json-ified with to_dict() +# + ExpandedToolInputs where any model objects have been json-ified with to_dict() ExpandedToolInputsJsonified = Dict[str, Any] + +# modern inputs for working with POST /api/jobs* +RawTestToolRequest = Dict[str, Any] + ExtraFileInfoDictT = Dict[str, Any] RequiredFileTuple = Tuple[str, ExtraFileInfoDictT] RequiredFilesT = List[RequiredFileTuple] @@ -36,6 +41,8 @@ class ToolTestDescriptionDict(TypedDict): name: str test_index: int inputs: ExpandedToolInputsJsonified + request: NotRequired[Optional[Dict[str, Any]]] + request_schema: NotRequired[Optional[Dict[str, Any]]] outputs: ToolSourceTestOutputs output_collections: List[TestSourceTestOutputColllection] stdout: Optional[AssertionList] diff --git a/lib/galaxy/tool_util/verify/interactor.py b/lib/galaxy/tool_util/verify/interactor.py index 9e1dd5bf87d4..970a6e03ae94 100644 --- a/lib/galaxy/tool_util/verify/interactor.py +++ b/lib/galaxy/tool_util/verify/interactor.py @@ -35,8 +35,18 @@ ) from galaxy import util +from galaxy.tool_util.parameters import ( + DataCollectionRequest, + DataRequest, + encode_test, + input_models_from_json, + TestCaseToolState, + ToolParameterBundle, +) from galaxy.tool_util.parser.interface import ( AssertionList, + JsonTestCollectionDefDict, + JsonTestDatasetDefDict, TestCollectionDef, TestCollectionOutputDef, TestSourceTestOutputColllection, @@ -53,6 +63,7 @@ from ._types import ( ExpandedToolInputs, ExpandedToolInputsJsonified, + RawTestToolRequest, RequiredDataTablesT, RequiredFilesT, RequiredLocFileT, @@ -63,6 +74,9 @@ log = getLogger(__name__) +UseLegacyApiT = Literal["always", "never", "if_needed"] +DEFAULT_USE_LEGACY_API: UseLegacyApiT = "always" + # Off by default because it can pound the database pretty heavily # and result in sqlite errors on larger tests or larger numbers of # tests. @@ -102,6 +116,8 @@ def __getitem__(self, item): class ValidToolTestDict(TypedDict): inputs: ExpandedToolInputs + request: NotRequired[Optional[RawTestToolRequest]] + request_schema: NotRequired[Optional[Dict[str, Any]]] outputs: ToolSourceTestOutputs output_collections: List[TestSourceTestOutputColllection] stdout: NotRequired[AssertionList] @@ -148,7 +164,7 @@ def stage_data_in_history( # Upload any needed files upload_waits = [] - assert tool_id + assert tool_id, "Tool id not set" if UPLOAD_ASYNC: for test_data in all_test_data: @@ -236,6 +252,15 @@ def get_tests_summary(self): assert response.status_code == 200, f"Non 200 response from tool tests available API. [{response.content}]" return response.json() + def get_tool_inputs(self, tool_id: str, tool_version: Optional[str] = None) -> ToolParameterBundle: + url = f"tools/{tool_id}/inputs" + params = {"tool_version": tool_version} if tool_version else None + response = self._get(url, data=params) + assert response.status_code == 200, f"Non 200 response from tool inputs API. [{response.content}]" + raw_inputs_array = response.json() + tool_parameter_bundle = input_models_from_json(raw_inputs_array) + return tool_parameter_bundle + def get_tool_tests(self, tool_id: str, tool_version: Optional[str] = None) -> List[ToolTestDescriptionDict]: url = f"tools/{tool_id}/test_data" params = {"tool_version": tool_version} if tool_version else None @@ -366,9 +391,27 @@ def wait_for_content(): def wait_for_job(self, job_id: str, history_id: Optional[str] = None, maxseconds=DEFAULT_TOOL_TEST_WAIT) -> None: self.wait_for(lambda: self.__job_ready(job_id, history_id), maxseconds=maxseconds) + def wait_on_tool_request(self, tool_request_id: str): + def state(): + state_response = self._get(f"tool_requests/{tool_request_id}/state") + state_response.raise_for_status() + return state_response.json() + + def is_ready(): + is_complete = state() in ["submitted", "failed"] + return True if is_complete else None + + self.wait_for(is_ready, "waiting for tool request to submit") + return state() == "submitted" + + def get_tool_request(self, tool_request_id: str): + response_raw = self._get(f"tool_requests/{tool_request_id}") + response_raw.raise_for_status() + return response_raw.json() + def wait_for(self, func: Callable, what: str = "tool test run", **kwd) -> None: walltime_exceeded = int(kwd.get("maxseconds", DEFAULT_TOOL_TEST_WAIT)) - wait_on(func, what, walltime_exceeded) + return wait_on(func, what, walltime_exceeded) def get_job_stdio(self, job_id: str) -> Dict[str, Any]: return self.__get_job_stdio(job_id).json() @@ -562,8 +605,9 @@ def stage_data_async( else: file_content = self.test_data_download(tool_id, fname, is_output=False, tool_version=tool_version) files = {"files_0|file_data": file_content} + # upload1 will always be the legacy API... submit_response_object = self.__submit_tool( - history_id, "upload1", tool_input, extra_data={"type": "upload_dataset"}, files=files + history_id, "upload1", tool_input, extra_data={"type": "upload_dataset"}, files=files, use_legacy_api=True ) submit_response = ensure_tool_run_response_okay(submit_response_object, f"upload dataset {name}") assert ( @@ -589,39 +633,68 @@ def _ensure_valid_location_in(self, test_data: dict) -> Optional[str]: return location def run_tool( - self, testdef: "ToolTestDescription", history_id: str, resource_parameters: Optional[Dict[str, Any]] = None + self, + testdef: "ToolTestDescription", + history_id: str, + resource_parameters: Optional[Dict[str, Any]] = None, + use_legacy_api: UseLegacyApiT = DEFAULT_USE_LEGACY_API, ) -> RunToolResponse: # We need to handle the case where we've uploaded a valid compressed file since the upload # tool will have uncompressed it on the fly. resource_parameters = resource_parameters or {} - inputs_tree = testdef.inputs.copy() - for key, value in inputs_tree.items(): - values = [value] if not isinstance(value, list) else value - new_values = [] - for value in values: - if isinstance(value, TestCollectionDef): - hdca_id = self._create_collection(history_id, value) - new_values = [dict(src="hdca", id=hdca_id)] - elif value in self.uploads: - new_values.append(self.uploads[value]) - else: - new_values.append(value) - inputs_tree[key] = new_values + request = testdef.request + request_schema = testdef.request_schema + submit_with_legacy_api = use_legacy_api == "always" or (use_legacy_api == "if_needed" and request is None) + if submit_with_legacy_api: + inputs_tree = testdef.inputs.copy() + for key, value in inputs_tree.items(): + values = [value] if not isinstance(value, list) else value + new_values = [] + for value in values: + if isinstance(value, TestCollectionDef): + hdca_id = self._create_collection(history_id, value) + new_values = [dict(src="hdca", id=hdca_id)] + elif value in self.uploads: + new_values.append(self.uploads[value]) + else: + new_values.append(value) + inputs_tree[key] = new_values + + # HACK: Flatten single-value lists. Required when using expand_grouping + for key, value in inputs_tree.items(): + if isinstance(value, list) and len(value) == 1: + inputs_tree[key] = value[0] + else: + assert request is not None, "Request not set" + assert request_schema is not None, "Request schema not set" + parameters = request_schema["parameters"] + + def adapt_datasets(test_input: JsonTestDatasetDefDict) -> DataRequest: + return DataRequest(**self.uploads[test_input["path"]]) + + def adapt_collections(test_input: JsonTestCollectionDefDict) -> DataCollectionRequest: + test_collection_def = TestCollectionDef.from_dict(test_input) + hdca_id = self._create_collection(history_id, test_collection_def) + return DataCollectionRequest(src="hdca", id=hdca_id) + + test_case_state = TestCaseToolState(input_state=request) + inputs_tree = encode_test( + test_case_state, input_models_from_json(parameters), adapt_datasets, adapt_collections + ).input_state if resource_parameters: inputs_tree["__job_resource|__job_resource__select"] = "yes" for key, value in resource_parameters.items(): inputs_tree[f"__job_resource|{key}"] = value - # HACK: Flatten single-value lists. Required when using expand_grouping - for key, value in inputs_tree.items(): - if isinstance(value, list) and len(value) == 1: - inputs_tree[key] = value[0] - submit_response = None for _ in range(DEFAULT_TOOL_TEST_WAIT): submit_response = self.__submit_tool( - history_id, tool_id=testdef.tool_id, tool_input=inputs_tree, tool_version=testdef.tool_version + history_id, + tool_id=testdef.tool_id, + tool_input=inputs_tree, + tool_version=testdef.tool_version, + use_legacy_api=submit_with_legacy_api, ) if _are_tool_inputs_not_ready(submit_response): print("Tool inputs not ready yet") @@ -630,12 +703,37 @@ def run_tool( else: break submit_response_object = ensure_tool_run_response_okay(submit_response, "execute tool", inputs_tree) + if not submit_with_legacy_api: + tool_request_id = submit_response_object["tool_request_id"] + successful = self.wait_on_tool_request(tool_request_id) + if not successful: + request = self.get_tool_request(tool_request_id) or {} + raise Exception( + f"Tool request failure - state {request.get('state')}, message: {request.get('state_message')}" + ) + jobs = self.jobs_for_tool_request(tool_request_id) + outputs = OutputsDict() + output_collections = {} + if len(jobs) != 1: + raise Exception(f"Found incorrect number of jobs for tool request - was expecting a single job {jobs}") + assert len(jobs) == 1, jobs + job_id = jobs[0]["id"] + job_outputs = self.job_outputs(job_id) + for job_output in job_outputs: + if "dataset" in job_output: + outputs[job_output["name"]] = job_output["dataset"] + else: + output_collections[job_output["name"]] = job_output["dataset_collection_instance"] + else: + outputs = self.__dictify_outputs(submit_response_object) + output_collections = self.__dictify_output_collections(submit_response_object) + jobs = submit_response_object["jobs"] try: return RunToolResponse( inputs=inputs_tree, - outputs=self.__dictify_outputs(submit_response_object), - output_collections=self.__dictify_output_collections(submit_response_object), - jobs=submit_response_object["jobs"], + outputs=outputs, + output_collections=output_collections, + jobs=jobs, ) except KeyError: message = ( @@ -773,14 +871,24 @@ def format_for_summary(self, blob, empty_message, prefix="| "): contents = "\n".join(f"{prefix}{line.strip()}" for line in io.StringIO(blob).readlines() if line.rstrip("\n\r")) return contents or f"{prefix}*{empty_message}*" - def _dataset_provenance(self, history_id, id): + def _dataset_provenance(self, history_id: str, id: str): provenance = self._get(f"histories/{history_id}/contents/{id}/provenance").json() return provenance - def _dataset_info(self, history_id, id): + def _dataset_info(self, history_id: str, id: str): dataset_json = self._get(f"histories/{history_id}/contents/{id}").json() return dataset_json + def jobs_for_tool_request(self, tool_request_id: str) -> List[Dict[str, Any]]: + job_list_response = self._get("jobs", data={"tool_request_id": tool_request_id}) + job_list_response.raise_for_status() + return job_list_response.json() + + def job_outputs(self, job_id: str) -> List[Dict[str, Any]]: + outputs = self._get(f"jobs/{job_id}/outputs") + outputs.raise_for_status() + return outputs.json() + def __contents(self, history_id): history_contents_response = self._get(f"histories/{history_id}/contents") history_contents_response.raise_for_status() @@ -797,12 +905,33 @@ def _state_ready(self, job_id: str, error_msg: str): ) return None - def __submit_tool(self, history_id, tool_id, tool_input, extra_data=None, files=None, tool_version=None): + def __submit_tool( + self, + history_id, + tool_id, + tool_input, + extra_data=None, + files=None, + tool_version=None, + use_legacy_api: bool = True, + ): extra_data = extra_data or {} - data = dict( - history_id=history_id, tool_id=tool_id, inputs=dumps(tool_input), tool_version=tool_version, **extra_data - ) - return self._post("tools", files=files, data=data) + if use_legacy_api: + data = dict( + history_id=history_id, + tool_id=tool_id, + inputs=dumps(tool_input), + tool_version=tool_version, + **extra_data, + ) + return self._post("tools", files=files, data=data) + else: + assert files is None + data = dict( + history_id=history_id, tool_id=tool_id, inputs=tool_input, tool_version=tool_version, **extra_data + ) + submit_tool_request_response = self._post("jobs", data=data, json=True) + return submit_tool_request_response def ensure_user_with_email(self, email, password=None): admin_key = self.master_api_key @@ -1313,6 +1442,7 @@ def verify_tool( register_job_data: Optional[JobDataCallbackT] = None, test_index: int = 0, tool_version: Optional[str] = None, + use_legacy_api: UseLegacyApiT = DEFAULT_USE_LEGACY_API, quiet: bool = False, test_history: Optional[str] = None, no_history_cleanup: bool = False, @@ -1329,11 +1459,7 @@ def verify_tool( if client_test_config is None: client_test_config = NullClientTestConfig() tool_test_dicts = _tool_test_dicts or galaxy_interactor.get_tool_tests(tool_id, tool_version=tool_version) - tool_test_dict = tool_test_dicts[test_index] - if "test_index" not in tool_test_dict: - tool_test_dict["test_index"] = test_index - if "tool_id" not in tool_test_dict: - tool_test_dict["tool_id"] = tool_id + tool_test_dict: ToolTestDescriptionDict = tool_test_dicts[test_index] if tool_version is None and "tool_version" in tool_test_dict: tool_version = tool_test_dict.get("tool_version") @@ -1398,7 +1524,9 @@ def verify_tool( input_staging_exception = e raise try: - tool_response = galaxy_interactor.run_tool(testdef, test_history, resource_parameters=resource_parameters) + tool_response = galaxy_interactor.run_tool( + testdef, test_history, resource_parameters=resource_parameters, use_legacy_api=use_legacy_api + ) data_list, jobs, tool_inputs = tool_response.outputs, tool_response.jobs, tool_response.inputs data_collection_list = tool_response.output_collections except RunToolException as e: @@ -1683,6 +1811,8 @@ def adapt_tool_source_dict(processed_dict: ToolTestDict) -> ToolTestDescriptionD expect_test_failure: bool = DEFAULT_EXPECT_TEST_FAILURE inputs: ExpandedToolInputsJsonified = {} maxseconds: Optional[int] = None + request: Optional[Dict[str, Any]] = None + request_schema: Optional[Dict[str, Any]] = None if not error_in_test_definition: processed_test_dict = cast(ValidToolTestDict, processed_dict) @@ -1708,6 +1838,8 @@ def adapt_tool_source_dict(processed_dict: ToolTestDict) -> ToolTestDescriptionD expect_failure = processed_test_dict.get("expect_failure", DEFAULT_EXPECT_FAILURE) expect_test_failure = processed_test_dict.get("expect_test_failure", DEFAULT_EXPECT_TEST_FAILURE) inputs = processed_test_dict.get("inputs", {}) + request = processed_test_dict.get("request", None) + request_schema = processed_test_dict.get("request_schema", None) else: invalid_test_dict = cast(InvalidToolTestDict, processed_dict) maxseconds = DEFAULT_TOOL_TEST_WAIT @@ -1735,6 +1867,8 @@ def adapt_tool_source_dict(processed_dict: ToolTestDict) -> ToolTestDescriptionD expect_failure=expect_failure, expect_test_failure=expect_test_failure, inputs=inputs, + request=request, + request_schema=request_schema, ) @@ -1797,6 +1931,8 @@ class ToolTestDescription: expect_test_failure: bool exception: Optional[str] inputs: ExpandedToolInputs + request: Optional[Dict[str, Any]] + request_schema: Optional[Dict[str, Any]] outputs: ToolSourceTestOutputs output_collections: List[TestCollectionOutputDef] maxseconds: Optional[int] @@ -1825,6 +1961,8 @@ def __init__(self, json_dict: ToolTestDescriptionDict): self.expect_failure = json_dict.get("expect_failure", DEFAULT_EXPECT_FAILURE) self.expect_test_failure = json_dict.get("expect_test_failure", DEFAULT_EXPECT_TEST_FAILURE) self.inputs = expanded_inputs_from_json(json_dict.get("inputs", {})) + self.request = json_dict.get("request", None) + self.request_schema = json_dict.get("request_schema", None) self.tool_id = json_dict["tool_id"] self.tool_version = json_dict.get("tool_version") self.maxseconds = _get_maxseconds(json_dict) @@ -1857,6 +1995,8 @@ def to_dict(self) -> ToolTestDescriptionDict: "required_files": self.required_files, "required_data_tables": self.required_data_tables, "required_loc_files": self.required_loc_files, + "request": self.request, + "request_schema": self.request_schema, "error": self.error, "exception": self.exception, "maxseconds": self.maxseconds, diff --git a/lib/galaxy/tool_util/verify/parse.py b/lib/galaxy/tool_util/verify/parse.py index a3aee97eef0c..46f9853e1675 100644 --- a/lib/galaxy/tool_util/verify/parse.py +++ b/lib/galaxy/tool_util/verify/parse.py @@ -1,7 +1,9 @@ import logging import os +from dataclasses import dataclass from typing import ( Any, + Dict, Iterable, List, Optional, @@ -14,6 +16,8 @@ from galaxy.tool_util.parameters import ( input_models_for_tool_source, test_case_state as case_state, + TestCaseToolState, + ToolParameterBundleModel, ) from galaxy.tool_util.parser.interface import ( InputSource, @@ -64,15 +68,18 @@ def parse_tool_test_descriptions( profile = tool_source.parse_profile() for i, raw_test_dict in enumerate(raw_tests_dict.get("tests", [])): validation_exception: Optional[Exception] = None - if validate_on_load: + request_and_schema: Optional[TestRequestAndSchema] = None + try: tool_parameter_bundle = input_models_for_tool_source(tool_source) - try: - case_state(raw_test_dict, tool_parameter_bundle.parameters, profile, validate=True) - except Exception as e: - # TOOD: restrict types of validation exceptions a bit probably? - validation_exception = e + validated_test_case = case_state(raw_test_dict, tool_parameter_bundle.parameters, profile, validate=True) + request_and_schema = TestRequestAndSchema( + validated_test_case.tool_state, + tool_parameter_bundle, + ) + except Exception as e: + validation_exception = e - if validation_exception: + if validation_exception and validate_on_load: tool_id, tool_version = _tool_id_and_version(tool_source, tool_guid) test = ToolTestDescription.from_tool_source_dict( InvalidToolTestDict( @@ -88,13 +95,23 @@ def parse_tool_test_descriptions( ) ) else: - test = _description_from_tool_source(tool_source, raw_test_dict, i, tool_guid) + test = _description_from_tool_source(tool_source, raw_test_dict, i, tool_guid, request_and_schema) tests.append(test) return tests +@dataclass +class TestRequestAndSchema: + request: TestCaseToolState + request_schema: ToolParameterBundleModel + + def _description_from_tool_source( - tool_source: ToolSource, raw_test_dict: ToolSourceTest, test_index: int, tool_guid: Optional[str] + tool_source: ToolSource, + raw_test_dict: ToolSourceTest, + test_index: int, + tool_guid: Optional[str], + request_and_schema: Optional[TestRequestAndSchema], ) -> ToolTestDescription: required_files: RequiredFilesT = [] required_data_tables: RequiredDataTablesT = [] @@ -107,6 +124,12 @@ def _description_from_tool_source( if maxseconds is not None: maxseconds = int(maxseconds) + request: Optional[Dict[str, Any]] = None + request_schema: Optional[Dict[str, Any]] = None + if request_and_schema: + request = request_and_schema.request.input_state + request_schema = request_and_schema.request_schema.dict() + tool_id, tool_version = _tool_id_and_version(tool_source, tool_guid) processed_test_dict: Union[ValidToolTestDict, InvalidToolTestDict] try: @@ -121,6 +144,8 @@ def _description_from_tool_source( processed_test_dict = ValidToolTestDict( { "inputs": processed_inputs, + "request": request, + "request_schema": request_schema, "outputs": raw_test_dict["outputs"], "output_collections": raw_test_dict["output_collections"], "num_outputs": num_outputs, diff --git a/lib/galaxy/tools/__init__.py b/lib/galaxy/tools/__init__.py index e1dc1a4365d9..d346465c9b85 100644 --- a/lib/galaxy/tools/__init__.py +++ b/lib/galaxy/tools/__init__.py @@ -49,6 +49,7 @@ from galaxy.model import ( Job, StoredWorkflow, + ToolRequest, ) from galaxy.model.base import transaction from galaxy.model.dataset_collections.matching import MatchingCollections @@ -71,6 +72,12 @@ expand_ontology_data, ) from galaxy.tool_util.output_checker import DETECTED_JOB_STATE +from galaxy.tool_util.parameters import ( + input_models_for_pages, + JobInternalToolState, + RequestInternalToolState, + ToolParameterBundle, +) from galaxy.tool_util.parser import ( get_tool_source, get_tool_source_from_representation, @@ -150,7 +157,10 @@ UploadDataset, ) from galaxy.tools.parameters.input_translation import ToolInputTranslator -from galaxy.tools.parameters.meta import expand_meta_parameters +from galaxy.tools.parameters.meta import ( + expand_meta_parameters, + expand_meta_parameters_async, +) from galaxy.tools.parameters.workflow_utils import workflow_building_modes from galaxy.tools.parameters.wrapped_json import json_wrap from galaxy.util import ( @@ -204,7 +214,8 @@ DEFAULT_RERUN_REMAP_JOB_ID, DEFAULT_SET_OUTPUT_HID, DEFAULT_USE_CACHED_JOB, - execute as execute_job, + execute as execute_sync, + execute_async, ExecutionSlice, JobCallbackT, MappingParameters, @@ -753,7 +764,7 @@ class _Options(Bunch): refresh: str -class Tool(UsesDictVisibleKeys): +class Tool(UsesDictVisibleKeys, ToolParameterBundle): """ Represents a computational tool that can be executed through Galaxy. """ @@ -1423,6 +1434,11 @@ def parse_inputs(self, tool_source: ToolSource): self.inputs: Dict[str, Union[Group, ToolParameter]] = {} pages = tool_source.parse_input_pages() enctypes: Set[str] = set() + try: + parameters = input_models_for_pages(pages) + self.parameters = parameters + except Exception: + pass if pages.inputs_defined: if hasattr(pages, "input_elem"): input_elem = pages.input_elem @@ -1814,6 +1830,53 @@ def visit_inputs(self, values, callback): if self.check_values: visit_input_values(self.inputs, values, callback) + def expand_incoming_async( + self, + request_context: WorkRequestContext, + tool_request_internal_state: RequestInternalToolState, + rerun_remap_job_id: Optional[int], + ) -> Tuple[ + List[ToolStateJobInstancePopulatedT], + List[ToolStateJobInstancePopulatedT], + Optional[MatchingCollections], + List[JobInternalToolState], + ]: + """The tool request API+tasks version of expand_incoming. + + This is responsible for breaking the map over job requests into individual jobs for execution. + """ + if self.input_translator: + raise exceptions.RequestParameterInvalidException( + "Failure executing tool request with id '%s' (cannot validate inputs from this type of data source tool - please POST to /api/tools).", + self.id, + ) + + set_dataset_matcher_factory(request_context, self) + + job_tool_states: List[JobInternalToolState] + collection_info: Optional[MatchingCollections] + job_tool_states, collection_info = expand_meta_parameters_async( + request_context.app, self, tool_request_internal_state + ) + + self._ensure_expansion_is_valid(job_tool_states, rerun_remap_job_id) + + # Process incoming data + validation_timer = self.app.execution_timer_factory.get_timer( + "internals.galaxy.tools.validation", + "Validated and populated state for tool request", + ) + all_errors = [] + all_params: List[ToolStateJobInstancePopulatedT] = [] + for expanded_incoming in job_tool_states: + params, errors = self._populate(request_context, expanded_incoming.input_state, "21.01") + all_errors.append(errors) + all_params.append(params) + unset_dataset_matcher_factory(request_context) + + log.info(validation_timer) + return all_params, all_errors, collection_info, job_tool_states + def expand_incoming( self, request_context: WorkRequestContext, incoming: ToolRequestT, input_format: InputFormatT = "legacy" ) -> Tuple[ @@ -1851,7 +1914,9 @@ def expand_incoming( return all_params, all_errors, rerun_remap_job_id, collection_info def _ensure_expansion_is_valid( - self, expanded_incomings: List[ToolStateJobInstanceT], rerun_remap_job_id: Optional[int] + self, + expanded_incomings: Union[List[JobInternalToolState], List[ToolStateJobInstanceT]], + rerun_remap_job_id: Optional[int], ) -> None: """If the request corresponds to multiple jobs but this doesn't work with request configuration - raise an error. @@ -1929,6 +1994,38 @@ def completed_jobs( completed_jobs[i] = None return completed_jobs + def handle_input_async( + self, + trans, + tool_request: ToolRequest, + history: Optional[model.History] = None, + use_cached_job: bool = DEFAULT_USE_CACHED_JOB, + preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID, + rerun_remap_job_id: Optional[int] = None, + input_format: str = "legacy", + ): + """The tool request API+tasks version of handle_input.""" + request_context = proxy_work_context_for_history(trans, history=history) + tool_request_state = RequestInternalToolState(tool_request.request) + all_params, all_errors, collection_info, job_tool_states = self.expand_incoming_async( + request_context, tool_request_state, rerun_remap_job_id + ) + self.handle_incoming_errors(all_errors) + + mapping_params = MappingParameters(tool_request.request, all_params, tool_request_state, job_tool_states) + completed_jobs: Dict[int, Optional[model.Job]] = self.completed_jobs(trans, use_cached_job, all_params) + execute_async( + request_context, + self, + mapping_params, + request_context.history, + tool_request, + completed_jobs, + rerun_remap_job_id=rerun_remap_job_id, + preferred_object_store_id=preferred_object_store_id, + collection_info=collection_info, + ) + def handle_input( self, trans, @@ -1954,9 +2051,9 @@ def handle_input( # If there were errors, we stay on the same page and display them self.handle_incoming_errors(all_errors) - mapping_params = MappingParameters(incoming, all_params) + mapping_params = MappingParameters(incoming, all_params, None, None) completed_jobs: Dict[int, Optional[model.Job]] = self.completed_jobs(trans, use_cached_job, all_params) - execution_tracker = execute_job( + execution_tracker = execute_sync( trans, self, mapping_params, diff --git a/lib/galaxy/tools/execute.py b/lib/galaxy/tools/execute.py index d6e65f592a6e..6b5c7e4a18a1 100644 --- a/lib/galaxy/tools/execute.py +++ b/lib/galaxy/tools/execute.py @@ -24,12 +24,17 @@ from galaxy import model from galaxy.exceptions import ToolInputsNotOKException +from galaxy.model import ToolRequest from galaxy.model.base import transaction from galaxy.model.dataset_collections.matching import MatchingCollections from galaxy.model.dataset_collections.structure import ( get_structure, tool_output_to_structure, ) +from galaxy.tool_util.parameters.state import ( + JobInternalToolState, + RequestInternalToolState, +) from galaxy.tool_util.parser import ToolOutputCollectionPart from galaxy.tools.execution_helpers import ( filter_output, @@ -69,8 +74,58 @@ def __init__(self, execution_tracker: "ExecutionTracker"): class MappingParameters(NamedTuple): + # the raw request - might correspond to multiple jobs param_template: ToolRequestT + # parameters corresponding to individual job param_combinations: List[ToolStateJobInstancePopulatedT] + # schema driven parameters + # model validated tool request - might correspond to multiple jobs + validated_param_template: Optional[RequestInternalToolState] = None + # validated job parameters for individual jobs + validated_param_combinations: Optional[List[JobInternalToolState]] = None + + def ensure_validated(self): + assert self.validated_param_template is not None + assert self.validated_param_combinations is not None + + +def execute_async( + trans, + tool: "Tool", + mapping_params: MappingParameters, + history: model.History, + tool_request: ToolRequest, + completed_jobs: Optional[CompletedJobsT] = None, + rerun_remap_job_id: Optional[int] = None, + preferred_object_store_id: Optional[str] = None, + collection_info: Optional[MatchingCollections] = None, + workflow_invocation_uuid: Optional[str] = None, + invocation_step: Optional[model.WorkflowInvocationStep] = None, + max_num_jobs: Optional[int] = None, + job_callback: Optional[Callable] = None, + workflow_resource_parameters: Optional[Dict[str, Any]] = None, + validate_outputs: bool = False, +) -> "ExecutionTracker": + """The tool request/async version of execute.""" + completed_jobs = completed_jobs or {} + mapping_params.ensure_validated() + return _execute( + trans, + tool, + mapping_params, + history, + tool_request, + rerun_remap_job_id, + preferred_object_store_id, + collection_info, + workflow_invocation_uuid, + invocation_step, + max_num_jobs, + job_callback, + completed_jobs, + workflow_resource_parameters, + validate_outputs, + ) def execute( @@ -88,12 +143,48 @@ def execute( completed_jobs: Optional[CompletedJobsT] = None, workflow_resource_parameters: Optional[WorkflowResourceParametersT] = None, validate_outputs: bool = False, -): +) -> "ExecutionTracker": """ Execute a tool and return object containing summary (output data, number of failures, etc...). """ completed_jobs = completed_jobs or {} + return _execute( + trans, + tool, + mapping_params, + history, + None, + rerun_remap_job_id, + preferred_object_store_id, + collection_info, + workflow_invocation_uuid, + invocation_step, + max_num_jobs, + job_callback, + completed_jobs, + workflow_resource_parameters, + validate_outputs, + ) + + +def _execute( + trans, + tool: "Tool", + mapping_params: MappingParameters, + history: model.History, + tool_request: Optional[model.ToolRequest], + rerun_remap_job_id: Optional[int], + preferred_object_store_id: Optional[str], + collection_info: Optional[MatchingCollections], + workflow_invocation_uuid: Optional[str], + invocation_step: Optional[model.WorkflowInvocationStep], + max_num_jobs: Optional[int], + job_callback: Optional[Callable], + completed_jobs: Dict[int, Optional[model.Job]], + workflow_resource_parameters: Optional[Dict[str, Any]], + validate_outputs: bool, +) -> "ExecutionTracker": if max_num_jobs is not None: assert invocation_step is not None if rerun_remap_job_id: @@ -118,8 +209,9 @@ def execute_single_job(execution_slice: "ExecutionSlice", completed_job: Optiona "internals.galaxy.tools.execute.job_single", SINGLE_EXECUTION_SUCCESS_MESSAGE ) params = execution_slice.param_combination - if "__data_manager_mode" in mapping_params.param_template: - params["__data_manager_mode"] = mapping_params.param_template["__data_manager_mode"] + request_state = mapping_params.param_template + if "__data_manager_mode" in request_state: + params["__data_manager_mode"] = request_state["__data_manager_mode"] if workflow_invocation_uuid: params["__workflow_invocation_uuid__"] = workflow_invocation_uuid elif "__workflow_invocation_uuid__" in params: @@ -148,6 +240,8 @@ def execute_single_job(execution_slice: "ExecutionSlice", completed_job: Optiona skip=skip, ) if job: + if tool_request: + job.tool_request = tool_request log.debug(job_timer.to_str(tool_id=tool.id, job_id=job.id)) execution_tracker.record_success(execution_slice, job, result) # associate dataset instances with the job that creates them @@ -188,7 +282,11 @@ def execute_single_job(execution_slice: "ExecutionSlice", completed_job: Optiona has_remaining_jobs = True break else: - skip = execution_slice.param_combination.pop("__when_value__", None) is False + slice_params = execution_slice.param_combination + if isinstance(slice_params, JobInternalToolState): + slice_params = slice_params.input_state + + skip = slice_params.pop("__when_value__", None) is False execute_single_job(execution_slice, completed_jobs[i], skip=skip) history = execution_slice.history or history jobs_executed += 1 diff --git a/lib/galaxy/tools/parameters/basic.py b/lib/galaxy/tools/parameters/basic.py index 9669b62771e9..27f3dbedf44b 100644 --- a/lib/galaxy/tools/parameters/basic.py +++ b/lib/galaxy/tools/parameters/basic.py @@ -13,6 +13,7 @@ from collections.abc import MutableMapping from typing import ( Any, + cast, Dict, List, Optional, @@ -41,6 +42,7 @@ ) from galaxy.model.dataset_collections import builder from galaxy.schema.fetch_data import FilesPayload +from galaxy.tool_util.parameters.factory import get_color_value from galaxy.tool_util.parser import get_input_source as ensure_input_source from galaxy.tool_util.parser.util import ( boolean_is_checked, @@ -649,6 +651,7 @@ def legal_values(self): return [self.truevalue, self.falsevalue] +# Used only by upload1, deprecated. class FileToolParameter(ToolParameter): """ Parameter that takes an uploaded file as a value. @@ -848,7 +851,7 @@ class ColorToolParameter(ToolParameter): def __init__(self, tool, input_source): input_source = ensure_input_source(input_source) super().__init__(tool, input_source) - self.value = input_source.get("value", "#000000") + self.value = get_color_value(input_source) self.rgb = input_source.get_bool("rgb", False) def get_initial_value(self, trans, other_values): @@ -2484,7 +2487,10 @@ def from_json(self, value, trans=None, other_values=None): rval = value elif isinstance(value, MutableMapping) and "src" in value and "id" in value: if value["src"] == "hdca": - rval = session.get(HistoryDatasetCollectionAssociation, trans.security.decode_id(value["id"])) + rval = cast( + HistoryDatasetCollectionAssociation, + src_id_to_item(sa_session=trans.sa_session, value=value, security=trans.security), + ) elif isinstance(value, list): if len(value) > 0: value = value[0] diff --git a/lib/galaxy/tools/parameters/meta.py b/lib/galaxy/tools/parameters/meta.py index f2d8ba1a68d1..5fe669e6cc73 100644 --- a/lib/galaxy/tools/parameters/meta.py +++ b/lib/galaxy/tools/parameters/meta.py @@ -19,6 +19,11 @@ matching, subcollections, ) +from galaxy.tool_util.parameters import ( + JobInternalToolState, + RequestInternalToolState, + ToolParameterBundle, +) from galaxy.util import permutations from . import visit_input_values from .wrapped import process_key @@ -229,8 +234,50 @@ def classifier(input_key): return expanded_incomings, collection_info +Expanded2T = Tuple[List[JobInternalToolState], Optional[matching.MatchingCollections]] + + +def expand_meta_parameters_async(app, tool: ToolParameterBundle, incoming: RequestInternalToolState) -> Expanded2T: + # TODO: Tool State 2.0 Follow Up: rework this to only test permutation at actual input value roots. + + def classifier(input_key): + value = incoming.input_state[input_key] + if isinstance(value, dict) and "values" in value: + # Explicit meta wrapper for inputs... + is_batch = value.get("__class__", "Batch") + is_linked = value.get("linked", True) + if is_batch and is_linked: + classification = permutations.input_classification.MATCHED + elif is_batch: + classification = permutations.input_classification.MULTIPLIED + else: + classification = permutations.input_classification.SINGLE + if __collection_multirun_parameter(value): + collection_value = value["values"][0] + values = __expand_collection_parameter_async( + app, input_key, collection_value, collections_to_match, linked=is_linked + ) + else: + values = value["values"] + else: + classification = permutations.input_classification.SINGLE + values = value + return classification, values + + collections_to_match = matching.CollectionsToMatch() + expanded_incoming_dicts = permutations.expand_multi_inputs(incoming.input_state, classifier) + if collections_to_match.has_collections(): + collection_info = app.dataset_collection_manager.match_collections(collections_to_match) + else: + collection_info = None + expanded_incomings = [JobInternalToolState(d) for d in expanded_incoming_dicts] + for expanded_state in expanded_incomings: + expanded_state.validate(tool) + return expanded_incomings, collection_info + + def __expand_collection_parameter(trans, input_key, incoming_val, collections_to_match, linked=False): - # If subcollectin multirun of data_collection param - value will + # If subcollection multirun of data_collection param - value will # be "hdca_id|subcollection_type" else it will just be hdca_id if "|" in incoming_val: encoded_hdc_id, subcollection_type = incoming_val.split("|", 1) @@ -261,8 +308,34 @@ def __expand_collection_parameter(trans, input_key, incoming_val, collections_to return hdas +def __expand_collection_parameter_async(app, input_key, incoming_val, collections_to_match, linked=False): + # If subcollection multirun of data_collection param - value will + # be "hdca_id|subcollection_type" else it will just be hdca_id + try: + src = incoming_val["src"] + if src != "hdca": + raise exceptions.ToolMetaParameterException(f"Invalid dataset collection source type {src}") + hdc_id = incoming_val["id"] + subcollection_type = incoming_val.get("map_over_type", None) + except TypeError: + hdc_id = incoming_val + subcollection_type = None + hdc = app.model.context.get(HistoryDatasetCollectionAssociation, hdc_id) + collections_to_match.add(input_key, hdc, subcollection_type=subcollection_type, linked=linked) + if subcollection_type is not None: + subcollection_elements = subcollections.split_dataset_collection_instance(hdc, subcollection_type) + return subcollection_elements + else: + hdas = [] + for element in hdc.collection.dataset_elements: + hda = element.dataset_instance + hda.element_identifier = element.element_identifier + hdas.append(hda) + return hdas + + def __collection_multirun_parameter(value): - is_batch = value.get("batch", False) + is_batch = value.get("batch", False) or value.get("__class__", None) == "Batch" if not is_batch: return False diff --git a/lib/galaxy/webapps/galaxy/api/histories.py b/lib/galaxy/webapps/galaxy/api/histories.py index 57b18c1f1c17..e1cbdee66d7e 100644 --- a/lib/galaxy/webapps/galaxy/api/histories.py +++ b/lib/galaxy/webapps/galaxy/api/histories.py @@ -61,6 +61,7 @@ ShareWithPayload, SharingStatus, StoreExportPayload, + ToolRequestModel, UpdateHistoryPayload, WriteStoreToPayload, ) @@ -374,6 +375,17 @@ def citations( ) -> List[Any]: return self.service.citations(trans, history_id) + @router.get( + "/api/histories/{history_id}/tool_requests", + summary="Return all the tool requests for the tools submitted to this history.", + ) + def tool_requests( + self, + history_id: HistoryIDPathParam, + trans: ProvidesHistoryContext = DependsOnTrans, + ) -> List[ToolRequestModel]: + return self.service.tool_requests(trans, history_id) + @router.post( "/api/histories", summary="Creates a new history.", diff --git a/lib/galaxy/webapps/galaxy/api/jobs.py b/lib/galaxy/webapps/galaxy/api/jobs.py index 9eb5efb40938..b99ae5e2cfba 100644 --- a/lib/galaxy/webapps/galaxy/api/jobs.py +++ b/lib/galaxy/webapps/galaxy/api/jobs.py @@ -44,6 +44,7 @@ JobInputAssociation, JobInputSummary, JobOutputAssociation, + JobOutputCollectionAssociation, ReportJobErrorPayload, SearchJobsPayload, ShowFullJobResponse, @@ -67,11 +68,14 @@ ) from galaxy.webapps.galaxy.api.common import query_parameter_as_list from galaxy.webapps.galaxy.services.jobs import ( + JobCreateResponse, JobIndexPayload, JobIndexViewEnum, + JobRequest, JobsService, ) from galaxy.work.context import proxy_work_context_for_history +from .tools import validate_not_protected log = logging.getLogger(__name__) @@ -155,6 +159,12 @@ description="Limit listing of jobs to those that match the specified implicit collection job ID. If none, jobs from any implicit collection execution (or from no implicit collection execution) may be returned.", ) +ToolRequestIdQueryParam: Optional[DecodedDatabaseIdField] = Query( + default=None, + title="Tool Request ID", + description="Limit listing of jobs to those that were created from the supplied tool request ID. If none, jobs from any tool request (or from no workflows) may be returned.", +) + SortByQueryParam: JobIndexSortByEnum = Query( default=JobIndexSortByEnum.update_time, title="Sort By", @@ -207,6 +217,13 @@ class FastAPIJobs: service: JobsService = depends(JobsService) + @router.post("/api/jobs") + def create( + self, trans: ProvidesHistoryContext = DependsOnTrans, job_request: JobRequest = Body(...) + ) -> JobCreateResponse: + validate_not_protected(job_request.tool_id) + return self.service.create(trans, job_request) + @router.get("/api/jobs") def index( self, @@ -223,6 +240,7 @@ def index( workflow_id: Optional[DecodedDatabaseIdField] = WorkflowIdQueryParam, invocation_id: Optional[DecodedDatabaseIdField] = InvocationIdQueryParam, implicit_collection_jobs_id: Optional[DecodedDatabaseIdField] = ImplicitCollectionJobsIdQueryParam, + tool_request_id: Optional[DecodedDatabaseIdField] = ToolRequestIdQueryParam, order_by: JobIndexSortByEnum = SortByQueryParam, search: Optional[str] = SearchQueryParam, limit: int = LimitQueryParam, @@ -241,6 +259,7 @@ def index( workflow_id=workflow_id, invocation_id=invocation_id, implicit_collection_jobs_id=implicit_collection_jobs_id, + tool_request_id=tool_request_id, order_by=order_by, search=search, limit=limit, @@ -361,12 +380,14 @@ def outputs( self, job_id: JobIdPathParam, trans: ProvidesUserContext = DependsOnTrans, - ) -> List[JobOutputAssociation]: + ) -> List[Union[JobOutputAssociation, JobOutputCollectionAssociation]]: job = self.service.get_job(trans=trans, job_id=job_id) associations = self.service.dictify_associations(trans, job.output_datasets, job.output_library_datasets) - output_associations = [] + output_associations: List[Union[JobOutputAssociation, JobOutputCollectionAssociation]] = [] for association in associations: output_associations.append(JobOutputAssociation(name=association.name, dataset=association.dataset)) + + output_associations.extend(self.service.dictify_output_collection_associations(trans, job)) return output_associations @router.get( diff --git a/lib/galaxy/webapps/galaxy/api/tools.py b/lib/galaxy/webapps/galaxy/api/tools.py index f5dacdc64541..9b6ea943f0cb 100644 --- a/lib/galaxy/webapps/galaxy/api/tools.py +++ b/lib/galaxy/webapps/galaxy/api/tools.py @@ -12,6 +12,8 @@ from fastapi import ( Body, Depends, + Path, + Query, Request, UploadFile, ) @@ -27,10 +29,14 @@ from galaxy.managers.context import ProvidesHistoryContext from galaxy.managers.hdas import HDAManager from galaxy.managers.histories import HistoryManager +from galaxy.model import ToolRequest from galaxy.schema.fetch_data import ( FetchDataFormPayload, FetchDataPayload, ) +from galaxy.schema.fields import DecodedDatabaseIdField +from galaxy.schema.schema import ToolRequestModel +from galaxy.tool_util.parameters import ToolParameterT from galaxy.tool_util.verify import ToolTestDescriptionDict from galaxy.tools.evaluation import global_tool_errors from galaxy.util.zipstream import ZipstreamWrapper @@ -42,7 +48,11 @@ ) from galaxy.webapps.base.controller import UsesVisualizationMixin from galaxy.webapps.base.webapp import GalaxyWebTransaction -from galaxy.webapps.galaxy.services.tools import ToolsService +from galaxy.webapps.galaxy.services.base import tool_request_to_model +from galaxy.webapps.galaxy.services.tools import ( + ToolRunReference, + ToolsService, +) from . import ( APIContentTypeRoute, as_form, @@ -74,6 +84,14 @@ class JsonApiRoute(APIContentTypeRoute): FetchDataForm = as_form(FetchDataFormPayload) +ToolIDPathParam: str = Path( + ..., + title="Tool ID", + description="The tool ID for the lineage stored in Galaxy's toolbox.", +) +ToolVersionQueryParam: Optional[str] = Query(default=None, title="Tool Version", description="") + + @router.cbv class FetchTools: service: ToolsService = depends(ToolsService) @@ -104,6 +122,57 @@ async def fetch_form( files2.append(value) return self.service.create_fetch(trans, payload, files2) + @router.get( + "/api/tool_requests/{id}", + summary="Get tool request state.", + ) + def get_tool_request( + self, + id: DecodedDatabaseIdField, + trans: ProvidesHistoryContext = DependsOnTrans, + ) -> ToolRequestModel: + tool_request = self._get_tool_request_or_raise_not_found(trans, id) + return tool_request_to_model(tool_request) + + @router.get( + "/api/tool_requests/{id}/state", + summary="Get tool request state.", + ) + def tool_request_state( + self, + id: DecodedDatabaseIdField, + trans: ProvidesHistoryContext = DependsOnTrans, + ) -> str: + tool_request = self._get_tool_request_or_raise_not_found(trans, id) + state = tool_request.state + if not state: + raise exceptions.InconsistentDatabase() + return cast(str, state) + + def _get_tool_request_or_raise_not_found( + self, trans: ProvidesHistoryContext, id: DecodedDatabaseIdField + ) -> ToolRequest: + tool_request: Optional[ToolRequest] = cast( + Optional[ToolRequest], trans.app.model.context.query(ToolRequest).get(id) + ) + if tool_request is None: + raise exceptions.ObjectNotFound() + assert tool_request + return tool_request + + @router.get( + "/api/tools/{tool_id}/inputs", + summary="Get tool inputs.", + ) + def tool_inputs( + self, + tool_id: str = ToolIDPathParam, + tool_version: Optional[str] = ToolVersionQueryParam, + trans: ProvidesHistoryContext = DependsOnTrans, + ) -> List[ToolParameterT]: + tool_run_ref = ToolRunReference(tool_id=tool_id, tool_version=tool_version, tool_uuid=None) + return self.service.inputs(trans, tool_run_ref) + class ToolsController(BaseGalaxyAPIController, UsesVisualizationMixin): """ @@ -584,16 +653,17 @@ def create(self, trans: GalaxyWebTransaction, payload, **kwd): :type input_format: str """ tool_id = payload.get("tool_id") - tool_uuid = payload.get("tool_uuid") - if tool_id in PROTECTED_TOOLS: - raise exceptions.RequestParameterInvalidException( - f"Cannot execute tool [{tool_id}] directly, must use alternative endpoint." - ) - if tool_id is None and tool_uuid is None: - raise exceptions.RequestParameterInvalidException("Must specify a valid tool_id to use this endpoint.") + validate_not_protected(tool_id) return self.service._create(trans, payload, **kwd) +def validate_not_protected(tool_id: Optional[str]): + if tool_id in PROTECTED_TOOLS: + raise exceptions.RequestParameterInvalidException( + f"Cannot execute tool [{tool_id}] directly, must use alternative endpoint." + ) + + def _kwd_or_payload(kwd: Dict[str, Any]) -> Dict[str, Any]: if "payload" in kwd: kwd = cast(Dict[str, Any], kwd.get("payload")) diff --git a/lib/galaxy/webapps/galaxy/services/base.py b/lib/galaxy/webapps/galaxy/services/base.py index dcf91e80f2f2..423df8f4b96d 100644 --- a/lib/galaxy/webapps/galaxy/services/base.py +++ b/lib/galaxy/webapps/galaxy/services/base.py @@ -23,13 +23,19 @@ ) from galaxy.managers.context import ProvidesUserContext from galaxy.managers.model_stores import create_objects_from_store -from galaxy.model import User +from galaxy.model import ( + ToolRequest, + User, +) from galaxy.model.store import ( get_export_store_factory, ModelExportStore, ) from galaxy.schema.fields import EncodedDatabaseIdField -from galaxy.schema.schema import AsyncTaskResultSummary +from galaxy.schema.schema import ( + AsyncTaskResultSummary, + ToolRequestModel, +) from galaxy.security.idencoding import IdEncodingHelper from galaxy.short_term_storage import ( ShortTermStorageAllocator, @@ -193,3 +199,13 @@ def async_task_summary(async_result: AsyncResult) -> AsyncTaskResultSummary: name=name, queue=queue, ) + + +def tool_request_to_model(tool_request: ToolRequest) -> ToolRequestModel: + as_dict = { + "id": tool_request.id, + "request": tool_request.request, + "state": tool_request.state, + "state_message": tool_request.state_message, + } + return ToolRequestModel.model_validate(as_dict) diff --git a/lib/galaxy/webapps/galaxy/services/histories.py b/lib/galaxy/webapps/galaxy/services/histories.py index 32e8a9fa8a1c..764ce5a748a1 100644 --- a/lib/galaxy/webapps/galaxy/services/histories.py +++ b/lib/galaxy/webapps/galaxy/services/histories.py @@ -70,6 +70,7 @@ ShareHistoryWithStatus, ShareWithPayload, StoreExportPayload, + ToolRequestModel, WriteStoreToPayload, ) from galaxy.schema.tasks import ( @@ -87,6 +88,7 @@ model_store_storage_target, ServesExportStores, ServiceBase, + tool_request_to_model, ) from galaxy.webapps.galaxy.services.notifications import NotificationService from galaxy.webapps.galaxy.services.sharable import ShareableService @@ -533,6 +535,13 @@ def published( ] return rval + def tool_requests( + self, trans: ProvidesHistoryContext, history_id: DecodedDatabaseIdField + ) -> List[ToolRequestModel]: + history = self.manager.get_accessible(history_id, trans.user, current_history=trans.history) + tool_requests = history.tool_requests + return [tool_request_to_model(tr) for tr in tool_requests] + def citations(self, trans: ProvidesHistoryContext, history_id: DecodedDatabaseIdField): """ Return all the citations for the tools used to produce the datasets in diff --git a/lib/galaxy/webapps/galaxy/services/jobs.py b/lib/galaxy/webapps/galaxy/services/jobs.py index 5c39175567bf..7e860ffbb85d 100644 --- a/lib/galaxy/webapps/galaxy/services/jobs.py +++ b/lib/galaxy/webapps/galaxy/services/jobs.py @@ -1,3 +1,4 @@ +import logging from enum import Enum from typing import ( Any, @@ -6,24 +7,83 @@ Optional, ) +from pydantic import ( + BaseModel, + Field, +) + from galaxy import ( exceptions, model, ) +from galaxy.celery.tasks import queue_jobs from galaxy.managers import hdas from galaxy.managers.base import security_check -from galaxy.managers.context import ProvidesUserContext +from galaxy.managers.context import ( + ProvidesHistoryContext, + ProvidesUserContext, +) +from galaxy.managers.histories import HistoryManager from galaxy.managers.jobs import ( JobManager, JobSearch, view_show_job, ) -from galaxy.model import Job -from galaxy.schema.fields import DecodedDatabaseIdField -from galaxy.schema.jobs import JobAssociation -from galaxy.schema.schema import JobIndexQueryPayload +from galaxy.model import ( + Job, + ToolRequest, + ToolSource as ToolSourceModel, +) +from galaxy.model.base import transaction +from galaxy.schema.fields import ( + DecodedDatabaseIdField, + EncodedDatabaseIdField, +) +from galaxy.schema.jobs import ( + JobAssociation, + JobOutputCollectionAssociation, +) +from galaxy.schema.schema import ( + AsyncTaskResultSummary, + JobIndexQueryPayload, +) +from galaxy.schema.tasks import ( + QueueJobs, + ToolSource, +) from galaxy.security.idencoding import IdEncodingHelper -from galaxy.webapps.galaxy.services.base import ServiceBase +from galaxy.tool_util.parameters import ( + decode, + RequestToolState, +) +from galaxy.webapps.galaxy.services.base import ( + async_task_summary, + ServiceBase, +) +from .tools import ( + ToolRunReference, + validate_tool_for_running, +) + +log = logging.getLogger(__name__) + + +class JobRequest(BaseModel): + tool_id: Optional[str] = Field(default=None, title="tool_id", description="TODO") + tool_uuid: Optional[str] = Field(default=None, title="tool_uuid", description="TODO") + tool_version: Optional[str] = Field(default=None, title="tool_version", description="TODO") + history_id: Optional[DecodedDatabaseIdField] = Field(default=None, title="history_id", description="TODO") + inputs: Optional[Dict[str, Any]] = Field(default_factory=lambda: {}, title="Inputs", description="TODO") + use_cached_jobs: Optional[bool] = Field(default=None, title="use_cached_jobs") + rerun_remap_job_id: Optional[DecodedDatabaseIdField] = Field( + default=None, title="rerun_remap_job_id", description="TODO" + ) + send_email_notification: bool = Field(default=False, title="Send Email Notification", description="TODO") + + +class JobCreateResponse(BaseModel): + tool_request_id: EncodedDatabaseIdField + task_result: AsyncTaskResultSummary class JobIndexViewEnum(str, Enum): @@ -39,6 +99,7 @@ class JobsService(ServiceBase): job_manager: JobManager job_search: JobSearch hda_manager: hdas.HDAManager + history_manager: HistoryManager def __init__( self, @@ -46,11 +107,13 @@ def __init__( job_manager: JobManager, job_search: JobSearch, hda_manager: hdas.HDAManager, + history_manager: HistoryManager, ): super().__init__(security=security) self.job_manager = job_manager self.job_search = job_search self.hda_manager = hda_manager + self.history_manager = history_manager def show( self, @@ -146,3 +209,62 @@ def __dictify_association(self, trans, job_dataset_association) -> JobAssociatio else: dataset_dict = {"src": "ldda", "id": dataset.id} return JobAssociation(name=job_dataset_association.name, dataset=dataset_dict) + + def dictify_output_collection_associations(self, trans, job: model.Job) -> List[JobOutputCollectionAssociation]: + output_associations: List[JobOutputCollectionAssociation] = [] + for job_output_collection_association in job.output_dataset_collection_instances: + ref_dict = {"src": "hdca", "id": job_output_collection_association.id} + output_associations.append( + JobOutputCollectionAssociation( + name=job_output_collection_association.name, + dataset_collection_instance=ref_dict, + ) + ) + return output_associations + + def create(self, trans: ProvidesHistoryContext, job_request: JobRequest) -> JobCreateResponse: + tool_run_reference = ToolRunReference(job_request.tool_id, job_request.tool_uuid, job_request.tool_version) + tool = validate_tool_for_running(trans, tool_run_reference) + history_id = job_request.history_id + target_history = None + if history_id is not None: + target_history = self.history_manager.get_owned(history_id, trans.user, current_history=trans.history) + inputs = job_request.inputs + request_state = RequestToolState(inputs or {}) + request_state.validate(tool) + request_internal_state = decode(request_state, tool, trans.security.decode_id) + tool_request = ToolRequest() + # TODO: hash and such... + tool_source_model = ToolSourceModel( + source=[p.model_dump() for p in tool.parameters], + hash="TODO", + ) + tool_request.request = request_internal_state.input_state + tool_request.tool_source = tool_source_model + tool_request.state = ToolRequest.states.NEW + tool_request.history = target_history + sa_session = trans.sa_session + sa_session.add(tool_source_model) + sa_session.add(tool_request) + with transaction(sa_session): + sa_session.commit() + tool_request_id = tool_request.id + tool_source = ToolSource( + raw_tool_source=tool.tool_source.to_string(), + tool_dir=tool.tool_dir, + ) + task_request = QueueJobs( + user=trans.async_request_user, + history_id=target_history and target_history.id, + tool_source=tool_source, + tool_request_id=tool_request_id, + use_cached_jobs=job_request.use_cached_jobs or False, + rerun_remap_job_id=job_request.rerun_remap_job_id, + ) + result = queue_jobs.delay(request=task_request) + return JobCreateResponse( + **{ + "tool_request_id": tool_request_id, + "task_result": async_task_summary(result), + } + ) diff --git a/lib/galaxy/webapps/galaxy/services/tools.py b/lib/galaxy/webapps/galaxy/services/tools.py index 9e2298eae134..5d0059320089 100644 --- a/lib/galaxy/webapps/galaxy/services/tools.py +++ b/lib/galaxy/webapps/galaxy/services/tools.py @@ -4,8 +4,10 @@ from json import dumps from typing import ( Any, + cast, Dict, List, + NamedTuple, Optional, Union, ) @@ -34,7 +36,9 @@ FilesPayload, ) from galaxy.security.idencoding import IdEncodingHelper +from galaxy.tool_util.parameters import ToolParameterT from galaxy.tools import Tool +from galaxy.tools._types import InputFormatT from galaxy.tools.search import ToolBoxSearch from galaxy.webapps.galaxy.services._fetch_util import validate_and_normalize_targets from galaxy.webapps.galaxy.services.base import ServiceBase @@ -42,6 +46,39 @@ log = logging.getLogger(__name__) +class ToolRunReference(NamedTuple): + tool_id: Optional[str] + tool_uuid: Optional[str] + tool_version: Optional[str] + + +def get_tool(trans: ProvidesHistoryContext, tool_ref: ToolRunReference) -> Tool: + get_kwds = dict( + tool_id=tool_ref.tool_id, + tool_uuid=tool_ref.tool_uuid, + tool_version=tool_ref.tool_version, + ) + + tool = trans.app.toolbox.get_tool(**get_kwds) + if not tool: + log.debug(f"Not found tool with kwds [{tool_ref}]") + raise exceptions.ToolMissingException("Tool not found.") + return tool + + +def validate_tool_for_running(trans: ProvidesHistoryContext, tool_ref: ToolRunReference) -> Tool: + if trans.user_is_bootstrap_admin: + raise exceptions.RealUserRequiredException("Only real users can execute tools or run jobs.") + + if tool_ref.tool_id is None and tool_ref.tool_uuid is None: + raise exceptions.RequestParameterMissingException("Must specify a valid tool_id to use this endpoint.") + + tool = get_tool(trans, tool_ref) + if not tool.allow_user_access(trans.user): + raise exceptions.ItemAccessibilityException("Tool not accessible.") + return tool + + class ToolsService(ServiceBase): def __init__( self, @@ -55,6 +92,14 @@ def __init__( self.toolbox_search = toolbox_search self.history_manager = history_manager + def inputs( + self, + trans: ProvidesHistoryContext, + tool_ref: ToolRunReference, + ) -> List[ToolParameterT]: + tool = get_tool(trans, tool_ref) + return tool.parameters + def create_fetch( self, trans: ProvidesHistoryContext, @@ -100,37 +145,14 @@ def create_fetch( return self._create(trans, create_payload) def _create(self, trans: ProvidesHistoryContext, payload, **kwd): - if trans.user_is_bootstrap_admin: - raise exceptions.RealUserRequiredException("Only real users can execute tools or run jobs.") action = payload.get("action") if action == "rerun": raise Exception("'rerun' action has been deprecated") - # Get tool. - tool_version = payload.get("tool_version") - tool_id = payload.get("tool_id") - tool_uuid = payload.get("tool_uuid") - get_kwds = dict( - tool_id=tool_id, - tool_uuid=tool_uuid, - tool_version=tool_version, + tool_run_reference = ToolRunReference( + payload.get("tool_id"), payload.get("tool_uuid"), payload.get("tool_version") ) - if tool_id is None and tool_uuid is None: - raise exceptions.RequestParameterMissingException("Must specify either a tool_id or a tool_uuid.") - - tool = trans.app.toolbox.get_tool(**get_kwds) - if not tool: - log.debug(f"Not found tool with kwds [{get_kwds}]") - raise exceptions.ToolMissingException("Tool not found.") - if not tool.allow_user_access(trans.user): - raise exceptions.ItemAccessibilityException("Tool not accessible.") - if self.config.user_activation_on: - if not trans.user: - log.warning("Anonymous user attempts to execute tool, but account activation is turned on.") - elif not trans.user.active: - log.warning( - f'User "{trans.user.email}" attempts to execute tool, but account activation is turned on and user account is not active.' - ) + tool = validate_tool_for_running(trans, tool_run_reference) # Set running history from payload parameters. # History not set correctly as part of this API call for @@ -166,7 +188,10 @@ def _create(self, trans: ProvidesHistoryContext, payload, **kwd): inputs.get("use_cached_job", "false") ) preferred_object_store_id = payload.get("preferred_object_store_id") - input_format = str(payload.get("input_format", "legacy")) + input_format_raw = str(payload.get("input_format", "legacy")) + if input_format_raw not in ["legacy", "21.01"]: + raise exceptions.RequestParameterInvalidException(f"invalid input format {input_format_raw}") + input_format = cast(InputFormatT, input_format_raw) if "data_manager_mode" in payload: incoming["__data_manager_mode"] = payload["data_manager_mode"] vars = tool.handle_input( diff --git a/lib/galaxy_test/base/populators.py b/lib/galaxy_test/base/populators.py index ef71ec8d6132..efd9e3a6da61 100644 --- a/lib/galaxy_test/base/populators.py +++ b/lib/galaxy_test/base/populators.py @@ -1435,8 +1435,28 @@ def is_ready(): wait_on(is_ready, "waiting for download to become ready") assert is_ready() + def wait_on_tool_request(self, tool_request_id: str): + # should this to defer to interactor's copy of this method? + + def state(): + state_response = self._get(f"tool_requests/{tool_request_id}/state") + state_response.raise_for_status() + return state_response.json() + + def is_ready(): + is_complete = state() in ["submitted", "failed"] + return True if is_complete else None + + wait_on(is_ready, "waiting for tool request to submit") + return state() == "submitted" + def wait_on_task(self, async_task_response: Response): - task_id = async_task_response.json()["id"] + response_json = async_task_response.json() + self.wait_on_task_object(response_json) + + def wait_on_task_object(self, async_task_json: Dict[str, Any]): + assert "id" in async_task_json, f"Task response {async_task_json} does not contain expected 'id' field." + task_id = async_task_json["id"] return self.wait_on_task_id(task_id) def wait_on_task_id(self, task_id: str): diff --git a/scripts/gen_typescript_artifacts.py b/scripts/gen_typescript_artifacts.py new file mode 100644 index 000000000000..a9da728b9459 --- /dev/null +++ b/scripts/gen_typescript_artifacts.py @@ -0,0 +1,20 @@ +import os +import sys + +try: + from pydantic2ts import generate_typescript_defs +except ImportError: + generate_typescript_defs = None + + +sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "lib"))) + + +def main(): + if generate_typescript_defs is None: + raise Exception("Please install pydantic-to-typescript into Galaxy's environment") + generate_typescript_defs("galaxy.tool_util.parser.parameters", "client/src/components/Tool/parameterModels.ts") + + +if __name__ == "__main__": + main() diff --git a/test/functional/test_toolbox_pytest.py b/test/functional/test_toolbox_pytest.py index 896e3609913e..cd6314a9fc85 100644 --- a/test/functional/test_toolbox_pytest.py +++ b/test/functional/test_toolbox_pytest.py @@ -1,11 +1,16 @@ import os from typing import ( + cast, List, NamedTuple, ) import pytest +from galaxy.tool_util.verify.interactor import ( + DEFAULT_USE_LEGACY_API, + UseLegacyApiT, +) from galaxy_test.api._framework import ApiTestCase from galaxy_test.driver.driver_util import GalaxyTestDriver @@ -61,4 +66,7 @@ class TestFrameworkTools(ApiTestCase): @pytest.mark.parametrize("testcase", cases(), ids=idfn) def test_tool(self, testcase: ToolTest): - self._test_driver.run_tool_test(testcase.tool_id, testcase.test_index, tool_version=testcase.tool_version) + use_legacy_api = cast(UseLegacyApiT, os.environ.get("GALAXY_TEST_USE_LEGACY_TOOL_API", DEFAULT_USE_LEGACY_API)) + self._test_driver.run_tool_test( + testcase.tool_id, testcase.test_index, tool_version=testcase.tool_version, use_legacy_api=use_legacy_api + ) diff --git a/test/unit/tool_util/parameter_specification.yml b/test/unit/tool_util/parameter_specification.yml index 8497843a427c..d4d3e07c0218 100644 --- a/test/unit/tool_util/parameter_specification.yml +++ b/test/unit/tool_util/parameter_specification.yml @@ -1017,9 +1017,9 @@ gx_drill_down_exact: - parameter: aa - parameter: bbb - parameter: ba - request_invalid: - # not multiple so cannot choose a non-leaf + # non-leaf nodes seem to be selectable in exact mode - parameter: a + request_invalid: - parameter: c - parameter: {} # no implicit default currently - see test_drill_down_first_by_default in API test test_tools.py. @@ -1032,13 +1032,27 @@ gx_drill_down_exact_with_selection: - parameter: bbb - parameter: ba # - {} - request_invalid: - # not multiple so cannot choose a non-leaf + # non-leaf nodes seem to be selectable in exact mode - parameter: a + request_invalid: - parameter: c - parameter: {} - parameter: null +gx_drill_down_recurse: + request_valid: + - parameter: bba + request_invalid: + - parameter: a + - parameter: c + +gx_drill_down_recurse_multiple: + request_valid: + - parameter: [bba] + - parameter: [a] + request_invalid: + - parameter: c + gx_data_column: request_valid: - { ref_parameter: {src: hda, id: abcdabcd}, parameter: 0 } diff --git a/test/unit/tool_util/test_parameter_covert.py b/test/unit/tool_util/test_parameter_covert.py new file mode 100644 index 000000000000..434032c1c1bc --- /dev/null +++ b/test/unit/tool_util/test_parameter_covert.py @@ -0,0 +1,99 @@ +from typing import Dict + +from galaxy.tool_util.parameters import ( + decode, + encode, + input_models_for_tool_source, + RequestToolState, +) +from .test_parameter_test_cases import tool_source_for + +EXAMPLE_ID_1_ENCODED = "123456789abcde" +EXAMPLE_ID_1 = 13 +EXAMPLE_ID_2_ENCODED = "123456789abcd2" +EXAMPLE_ID_2 = 14 + +ID_MAP: Dict[int, str] = { + EXAMPLE_ID_1: EXAMPLE_ID_1_ENCODED, + EXAMPLE_ID_2: EXAMPLE_ID_2_ENCODED, +} + + +def test_encode_data(): + tool_source = tool_source_for("parameters/gx_data") + bundle = input_models_for_tool_source(tool_source) + request_state = RequestToolState({"parameter": {"src": "hda", "id": EXAMPLE_ID_1_ENCODED}}) + request_state.validate(bundle) + decoded_state = decode(request_state, bundle, _fake_decode) + assert decoded_state.input_state["parameter"]["src"] == "hda" + assert decoded_state.input_state["parameter"]["id"] == EXAMPLE_ID_1 + + +def test_encode_collection(): + tool_source = tool_source_for("parameters/gx_data_collection") + bundle = input_models_for_tool_source(tool_source) + request_state = RequestToolState({"parameter": {"src": "hdca", "id": EXAMPLE_ID_1_ENCODED}}) + request_state.validate(bundle) + decoded_state = decode(request_state, bundle, _fake_decode) + assert decoded_state.input_state["parameter"]["src"] == "hdca" + assert decoded_state.input_state["parameter"]["id"] == EXAMPLE_ID_1 + + +def test_encode_repeat(): + tool_source = tool_source_for("parameters/gx_repeat_data") + bundle = input_models_for_tool_source(tool_source) + request_state = RequestToolState({"parameter": [{"data_parameter": {"src": "hda", "id": EXAMPLE_ID_1_ENCODED}}]}) + request_state.validate(bundle) + decoded_state = decode(request_state, bundle, _fake_decode) + assert decoded_state.input_state["parameter"][0]["data_parameter"]["src"] == "hda" + assert decoded_state.input_state["parameter"][0]["data_parameter"]["id"] == EXAMPLE_ID_1 + + +def test_encode_section(): + tool_source = tool_source_for("parameters/gx_section_data") + bundle = input_models_for_tool_source(tool_source) + request_state = RequestToolState({"parameter": {"data_parameter": {"src": "hda", "id": EXAMPLE_ID_1_ENCODED}}}) + request_state.validate(bundle) + decoded_state = decode(request_state, bundle, _fake_decode) + assert decoded_state.input_state["parameter"]["data_parameter"]["src"] == "hda" + assert decoded_state.input_state["parameter"]["data_parameter"]["id"] == EXAMPLE_ID_1 + + +def test_encode_conditional(): + tool_source = tool_source_for("identifier_in_conditional") + bundle = input_models_for_tool_source(tool_source) + request_state = RequestToolState( + {"outer_cond": {"multi_input": False, "input1": {"src": "hda", "id": EXAMPLE_ID_1_ENCODED}}} + ) + request_state.validate(bundle) + decoded_state = decode(request_state, bundle, _fake_decode) + assert decoded_state.input_state["outer_cond"]["input1"]["src"] == "hda" + assert decoded_state.input_state["outer_cond"]["input1"]["id"] == EXAMPLE_ID_1 + + +def test_multi_data(): + tool_source = tool_source_for("parameters/gx_data_multiple") + bundle = input_models_for_tool_source(tool_source) + request_state = RequestToolState( + {"parameter": [{"src": "hda", "id": EXAMPLE_ID_1_ENCODED}, {"src": "hda", "id": EXAMPLE_ID_2_ENCODED}]} + ) + request_state.validate(bundle) + decoded_state = decode(request_state, bundle, _fake_decode) + assert decoded_state.input_state["parameter"][0]["src"] == "hda" + assert decoded_state.input_state["parameter"][0]["id"] == EXAMPLE_ID_1 + assert decoded_state.input_state["parameter"][1]["src"] == "hda" + assert decoded_state.input_state["parameter"][1]["id"] == EXAMPLE_ID_2 + + encoded_state = encode(decoded_state, bundle, _fake_encode) + assert encoded_state.input_state["parameter"][0]["src"] == "hda" + assert encoded_state.input_state["parameter"][0]["id"] == EXAMPLE_ID_1_ENCODED + assert encoded_state.input_state["parameter"][1]["src"] == "hda" + assert encoded_state.input_state["parameter"][1]["id"] == EXAMPLE_ID_2_ENCODED + + +def _fake_decode(input: str) -> int: + return next(key for key, value in ID_MAP.items() if value == input) + + +def _fake_encode(input: int) -> str: + return ID_MAP[input]