diff --git a/doc/source/dev/tool_state_state_classes.plantuml.svg b/doc/source/dev/tool_state_state_classes.plantuml.svg index b0c086bf18b0..07270f21f7a4 100644 --- a/doc/source/dev/tool_state_state_classes.plantuml.svg +++ b/doc/source/dev/tool_state_state_classes.plantuml.svg @@ -41,14 +41,36 @@ state_representation = "job_internal" } note bottom: Object references of the form \n{src: "hda", id: }.\n Mapping constructs expanded out.\n (Defaults are inserted?) +class TestCaseToolState { +state_representation = "test_case" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Object references of the form file name and URIs.\n Mapping constructs not allowed.\n + +class WorkflowStepToolState { +state_representation = "workflow_step" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Nearly everything optional except conditional discriminators.\n + +class WorkflowStepLinkedToolState { +state_representation = "workflow_step_linked" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Expect pre-process ``in`` dictionaries and bring in representation\n of links and defaults and validate them in model.\n + ToolState <|- - RequestToolState ToolState <|- - RequestInternalToolState ToolState <|- - JobInternalToolState +ToolState <|- - TestCaseToolState +ToolState <|- - WorkflowStepToolState +ToolState <|- - WorkflowStepLinkedToolState RequestToolState - RequestInternalToolState : decode > RequestInternalToolState o- - JobInternalToolState : expand > +WorkflowStepToolState o- - WorkflowStepLinkedToolState : preprocess_links_and_defaults > } @enduml @@ -132,14 +154,36 @@ state_representation = "job_internal" } note bottom: Object references of the form \n{src: "hda", id: }.\n Mapping constructs expanded out.\n (Defaults are inserted?) +class TestCaseToolState { +state_representation = "test_case" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Object references of the form file name and URIs.\n Mapping constructs not allowed.\n + +class WorkflowStepToolState { +state_representation = "workflow_step" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Nearly everything optional except conditional discriminators.\n + +class WorkflowStepLinkedToolState { +state_representation = "workflow_step_linked" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Expect pre-process ``in`` dictionaries and bring in representation\n of links and defaults and validate them in model.\n + ToolState <|- - RequestToolState ToolState <|- - RequestInternalToolState ToolState <|- - JobInternalToolState +ToolState <|- - TestCaseToolState +ToolState <|- - WorkflowStepToolState +ToolState <|- - WorkflowStepLinkedToolState RequestToolState - RequestInternalToolState : decode > RequestInternalToolState o- - JobInternalToolState : expand > +WorkflowStepToolState o- - WorkflowStepLinkedToolState : preprocess_links_and_defaults > } @enduml diff --git a/doc/source/dev/tool_state_state_classes.plantuml.txt b/doc/source/dev/tool_state_state_classes.plantuml.txt index 612c13d8e683..67da8a30c725 100644 --- a/doc/source/dev/tool_state_state_classes.plantuml.txt +++ b/doc/source/dev/tool_state_state_classes.plantuml.txt @@ -29,13 +29,35 @@ state_representation = "job_internal" } note bottom: Object references of the form \n{src: "hda", id: }.\n Mapping constructs expanded out.\n (Defaults are inserted?) +class TestCaseToolState { +state_representation = "test_case" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Object references of the form file name and URIs.\n Mapping constructs not allowed.\n + +class WorkflowStepToolState { +state_representation = "workflow_step" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Nearly everything optional except conditional discriminators.\n + +class WorkflowStepLinkedToolState { +state_representation = "workflow_step_linked" ++ _to_base_model(input_models: ToolParameterBundle): Type[BaseModel] +} +note bottom: Expect pre-process ``in`` dictionaries and bring in representation\n of links and defaults and validate them in model.\n + ToolState <|-- RequestToolState ToolState <|-- RequestInternalToolState ToolState <|-- JobInternalToolState +ToolState <|-- TestCaseToolState +ToolState <|-- WorkflowStepToolState +ToolState <|-- WorkflowStepLinkedToolState RequestToolState - RequestInternalToolState : decode > RequestInternalToolState o-- JobInternalToolState : expand > +WorkflowStepToolState o-- WorkflowStepLinkedToolState : preprocess_links_and_defaults > } @enduml \ No newline at end of file diff --git a/lib/galaxy/tool_util/parameters/__init__.py b/lib/galaxy/tool_util/parameters/__init__.py index 4c50680fa118..9d8b899832f3 100644 --- a/lib/galaxy/tool_util/parameters/__init__.py +++ b/lib/galaxy/tool_util/parameters/__init__.py @@ -42,6 +42,8 @@ validate_internal_request, validate_request, validate_test_case, + validate_workflow_step, + validate_workflow_step_linked, ) from .state import ( JobInternalToolState, @@ -49,8 +51,14 @@ RequestToolState, TestCaseToolState, ToolState, + WorkflowStepLinkedToolState, + WorkflowStepToolState, +) +from .visitor import ( + flat_state_path, + keys_starting_with, + visit_input_values, ) -from .visitor import visit_input_values __all__ = ( "from_input_source", @@ -89,13 +97,19 @@ "validate_internal_request", "validate_request", "validate_test_case", + "validate_workflow_step", + "validate_workflow_step_linked", "ToolState", "TestCaseToolState", "ToolParameterT", "to_json_schema_string", "RequestToolState", "RequestInternalToolState", + "flat_state_path", + "keys_starting_with", "visit_input_values", "decode", "encode", + "WorkflowStepToolState", + "WorkflowStepLinkedToolState", ) diff --git a/lib/galaxy/tool_util/parameters/_types.py b/lib/galaxy/tool_util/parameters/_types.py index 4a33f6406a50..2b97ee16c200 100644 --- a/lib/galaxy/tool_util/parameters/_types.py +++ b/lib/galaxy/tool_util/parameters/_types.py @@ -20,10 +20,15 @@ ) +def optional(type: Type) -> Type: + return_type: Type = Optional[type] # type: ignore[assignment] + return return_type + + def optional_if_needed(type: Type, is_optional: bool) -> Type: return_type: Type = type if is_optional: - return_type = Optional[type] # type: ignore[assignment] + return_type = optional(type) return return_type diff --git a/lib/galaxy/tool_util/parameters/factory.py b/lib/galaxy/tool_util/parameters/factory.py index cd872e43a974..ba567814733a 100644 --- a/lib/galaxy/tool_util/parameters/factory.py +++ b/lib/galaxy/tool_util/parameters/factory.py @@ -18,6 +18,7 @@ from .models import ( BooleanParameterModel, ColorParameterModel, + cond_test_parameter_default_value, ConditionalParameterModel, ConditionalWhen, CwlBooleanParameterModel, @@ -172,16 +173,7 @@ def _from_input_source_galaxy(input_source: InputSource) -> ToolParameterT: Union[BooleanParameterModel, SelectParameterModel], _from_input_source_galaxy(test_param_input_source) ) whens = [] - default_value = object() - if isinstance(test_parameter, BooleanParameterModel): - default_value = test_parameter.value - elif isinstance(test_parameter, SelectParameterModel): - select_parameter = cast(SelectParameterModel, test_parameter) - select_default_value = select_parameter.default_value - if select_default_value is not None: - default_value = select_default_value - - # TODO: handle select parameter model... + default_value = cond_test_parameter_default_value(test_parameter) for value, case_inputs_sources in input_source.parse_when_input_sources(): if isinstance(test_parameter, BooleanParameterModel): # TODO: investigate truevalue/falsevalue when... diff --git a/lib/galaxy/tool_util/parameters/models.py b/lib/galaxy/tool_util/parameters/models.py index 27efc38feee2..3a9cf59097ac 100644 --- a/lib/galaxy/tool_util/parameters/models.py +++ b/lib/galaxy/tool_util/parameters/models.py @@ -42,6 +42,7 @@ cast_as_type, is_optional, list_type, + optional, optional_if_needed, union_type, ) @@ -56,7 +57,9 @@ # + request: Return info needed to build request pydantic model at runtime. # + request_internal: This is a pydantic model to validate what Galaxy expects to find in the database, # in particular dataset and collection references should be decoded integers. -StateRepresentationT = Literal["request", "request_internal", "job_internal", "test_case"] +StateRepresentationT = Literal[ + "request", "request_internal", "job_internal", "test_case", "workflow_step", "workflow_step_linked" +] # could be made more specific - validators need to be classmethod @@ -73,6 +76,14 @@ class StrictModel(BaseModel): model_config = ConfigDict(extra="forbid") +class ConnectedValue(BaseModel): + discriminator: Literal["ConnectedValue"] = Field(alias="__class__") + + +def allow_connected_value(type: Type): + return union_type([type, ConnectedValue]) + + def allow_batching(job_template: DynamicModelInformation, batch_type: Optional[Type] = None) -> DynamicModelInformation: job_py_type: Type = job_template.definition[0] default_value = job_template.definition[1] @@ -108,11 +119,15 @@ def request_requires_value(self) -> bool: ... -def dynamic_model_information_from_py_type(param_model: ParamModel, py_type: Type): +def dynamic_model_information_from_py_type( + param_model: ParamModel, py_type: Type, requires_value: Optional[bool] = None +): name = param_model.name - initialize = ... if param_model.request_requires_value else None + if requires_value is None: + requires_value = param_model.request_requires_value + initialize = ... if requires_value else None py_type_is_optional = is_optional(py_type) - if not py_type_is_optional and not param_model.request_requires_value: + if not py_type_is_optional and not requires_value: validators = {"not_null": field_validator(name)(Validators.validate_not_none)} else: validators = {} @@ -162,7 +177,10 @@ def py_type(self) -> Type: return optional_if_needed(StrictStr, self.optional) def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation: - return dynamic_model_information_from_py_type(self, self.py_type) + py_type = self.py_type + if state_representation == "workflow_step_linked": + py_type = allow_connected_value(py_type) + return dynamic_model_information_from_py_type(self, py_type) @property def request_requires_value(self) -> bool: @@ -181,7 +199,10 @@ def py_type(self) -> Type: return optional_if_needed(StrictInt, self.optional) def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation: - return dynamic_model_information_from_py_type(self, self.py_type) + py_type = self.py_type + if state_representation == "workflow_step_linked": + py_type = allow_connected_value(py_type) + return dynamic_model_information_from_py_type(self, py_type) @property def request_requires_value(self) -> bool: @@ -199,7 +220,10 @@ def py_type(self) -> Type: return optional_if_needed(union_type([StrictInt, StrictFloat]), self.optional) def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation: - return dynamic_model_information_from_py_type(self, self.py_type) + py_type = self.py_type + if state_representation == "workflow_step_linked": + py_type = allow_connected_value(py_type) + return dynamic_model_information_from_py_type(self, py_type) @property def request_requires_value(self) -> bool: @@ -303,6 +327,10 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam return dynamic_model_information_from_py_type(self, self.py_type_internal) elif state_representation == "test_case": return dynamic_model_information_from_py_type(self, self.py_type_test_case) + elif state_representation == "workflow_step": + return dynamic_model_information_from_py_type(self, type(None), requires_value=False) + elif state_representation == "workflow_step_linked": + return dynamic_model_information_from_py_type(self, ConnectedValue) @property def request_requires_value(self) -> bool: @@ -337,8 +365,14 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam return allow_batching(dynamic_model_information_from_py_type(self, self.py_type)) elif state_representation == "request_internal": return allow_batching(dynamic_model_information_from_py_type(self, self.py_type_internal)) + elif state_representation == "workflow_step": + return dynamic_model_information_from_py_type(self, type(None), requires_value=False) + elif state_representation == "workflow_step_linked": + return dynamic_model_information_from_py_type(self, ConnectedValue) else: - raise NotImplementedError("...") + raise NotImplementedError( + f"Have not implemented data collection parameter models for state representation {state_representation}" + ) @property def request_requires_value(self) -> bool: @@ -353,7 +387,15 @@ def py_type(self) -> Type: return optional_if_needed(StrictStr, self.optional) def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation: - return dynamic_model_information_from_py_type(self, self.py_type) + py_type = self.py_type + requires_value = not self.optional + if state_representation == "workflow_step_linked": + py_type = allow_connected_value(py_type) + elif state_representation == "workflow_step" and not self.optional: + # allow it to be linked in so force allow optional... + py_type = optional(py_type) + requires_value = False + return dynamic_model_information_from_py_type(self, py_type, requires_value=requires_value) @property def request_requires_value(self) -> bool: @@ -389,11 +431,34 @@ def validate_color_str(value) -> str: ensure_color_valid(value) return value + @staticmethod + def validate_color_str_if_value(value) -> str: + if value: + ensure_color_valid(value) + return value + + @staticmethod + def validate_color_str_or_connected_value(value) -> str: + if not isinstance(value, ConnectedValue): + ensure_color_valid(value) + return value + def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation: - validators = {"color_format": field_validator(self.name)(ColorParameterModel.validate_color_str)} + py_type = self.py_type + initialize: Any = ... + if state_representation == "workflow_step_linked": + py_type = allow_connected_value(py_type) + validators = { + "color_format": field_validator(self.name)(ColorParameterModel.validate_color_str_or_connected_value) + } + elif state_representation == "workflow_step": + initialize = None + validators = {"color_format": field_validator(self.name)(ColorParameterModel.validate_color_str_if_value)} + else: + validators = {"color_format": field_validator(self.name)(ColorParameterModel.validate_color_str)} return DynamicModelInformation( self.name, - (self.py_type, ...), + (py_type, initialize), validators, ) @@ -413,7 +478,10 @@ def py_type(self) -> Type: return optional_if_needed(StrictBool, self.optional) def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation: - return dynamic_model_information_from_py_type(self, self.py_type) + py_type = self.py_type + if state_representation == "workflow_step_linked": + py_type = allow_connected_value(py_type) + return dynamic_model_information_from_py_type(self, py_type) @property def request_requires_value(self) -> bool: @@ -461,19 +529,38 @@ class SelectParameterModel(BaseGalaxyToolParameterModelDefinition): options: Optional[List[LabelValue]] = None multiple: bool - @property - def py_type(self) -> Type: + def py_type_if_required(self, allow_connections=False) -> Type: if self.options is not None: literal_options: List[Type] = [cast_as_type(Literal[o.value]) for o in self.options] py_type = union_type(literal_options) else: py_type = StrictStr if self.multiple: - py_type = list_type(py_type) - return optional_if_needed(py_type, self.optional) + if allow_connections: + py_type = list_type(allow_connected_value(py_type)) + else: + py_type = list_type(py_type) + elif allow_connections: + py_type = allow_connected_value(py_type) + return py_type + + @property + def py_type(self) -> Type: + return optional_if_needed(self.py_type_if_required(), self.optional) + + @property + def py_type_workflow_step(self) -> Type: + # this is always optional in this context + return optional(self.py_type_if_required()) def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation: - return dynamic_model_information_from_py_type(self, self.py_type) + if state_representation == "workflow_step": + return dynamic_model_information_from_py_type(self, self.py_type_workflow_step, requires_value=False) + elif state_representation == "workflow_step_linked": + py_type = self.py_type_if_required(allow_connections=True) + return dynamic_model_information_from_py_type(self, optional_if_needed(py_type, self.optional)) + else: + return dynamic_model_information_from_py_type(self, self.py_type) @property def has_selected_static_option(self): @@ -590,6 +677,20 @@ def request_requires_value(self) -> bool: DiscriminatorType = Union[bool, str] +def cond_test_parameter_default_value( + test_parameter: Union["BooleanParameterModel", "SelectParameterModel"] +) -> Optional[DiscriminatorType]: + default_value: Optional[DiscriminatorType] = None + if isinstance(test_parameter, BooleanParameterModel): + default_value = test_parameter.value + elif isinstance(test_parameter, SelectParameterModel): + select_parameter = cast(SelectParameterModel, test_parameter) + select_default_value = select_parameter.default_value + if select_default_value is not None: + default_value = select_default_value + return default_value + + class ConditionalWhen(StrictModel): discriminator: DiscriminatorType parameters: List["ToolParameterT"] @@ -996,6 +1097,14 @@ def create_test_case_model(tool: ToolParameterBundle, name: str = "DynamicModelF return create_field_model(tool.input_models, name, "test_case") +def create_workflow_step_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]: + return create_field_model(tool.input_models, name, "workflow_step") + + +def create_workflow_step_linked_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]: + return create_field_model(tool.input_models, name, "workflow_step_linked") + + def create_field_model( tool_parameter_models: Union[List[ToolParameterModel], List[ToolParameterT]], name: str, @@ -1048,3 +1157,13 @@ def validate_internal_job(tool: ToolParameterBundle, request: Dict[str, Any]) -> def validate_test_case(tool: ToolParameterBundle, request: Dict[str, Any]) -> None: pydantic_model = create_test_case_model(tool) validate_against_model(pydantic_model, request) + + +def validate_workflow_step(tool: ToolParameterBundle, request: Dict[str, Any]) -> None: + pydantic_model = create_workflow_step_model(tool) + validate_against_model(pydantic_model, request) + + +def validate_workflow_step_linked(tool: ToolParameterBundle, request: Dict[str, Any]) -> None: + pydantic_model = create_workflow_step_linked_model(tool) + validate_against_model(pydantic_model, request) diff --git a/lib/galaxy/tool_util/parameters/state.py b/lib/galaxy/tool_util/parameters/state.py index 3991054bbd33..3c5389c9c230 100644 --- a/lib/galaxy/tool_util/parameters/state.py +++ b/lib/galaxy/tool_util/parameters/state.py @@ -6,7 +6,6 @@ Any, Dict, List, - Optional, Type, Union, ) @@ -18,6 +17,8 @@ create_job_internal_model, create_request_internal_model, create_request_model, + create_workflow_step_linked_model, + create_workflow_step_model, StateRepresentationT, ToolParameterBundle, ToolParameterBundleModel, @@ -51,7 +52,7 @@ def state_representation(self) -> StateRepresentationT: """Get state representation of the inputs.""" @classmethod - def parameter_model_for(cls, input_models: HasToolParameters) -> Optional[Type[BaseModel]]: + def parameter_model_for(cls, input_models: HasToolParameters) -> Type[BaseModel]: bundle: ToolParameterBundle if isinstance(input_models, list): bundle = ToolParameterBundleModel(input_models=input_models) @@ -61,7 +62,7 @@ def parameter_model_for(cls, input_models: HasToolParameters) -> Optional[Type[B @classmethod @abstractmethod - def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Optional[Type[BaseModel]]: + def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]: """Return a model type for this tool state kind.""" @@ -96,3 +97,19 @@ class TestCaseToolState(ToolState): def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]: # implement a test case model... return create_request_internal_model(input_models) + + +class WorkflowStepToolState(ToolState): + state_representation: Literal["workflow_step"] = "workflow_step" + + @classmethod + def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]: + return create_workflow_step_model(input_models) + + +class WorkflowStepLinkedToolState(ToolState): + state_representation: Literal["workflow_step_linked"] = "workflow_step_linked" + + @classmethod + def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]: + return create_workflow_step_linked_model(input_models) diff --git a/lib/galaxy/tool_util/parameters/visitor.py b/lib/galaxy/tool_util/parameters/visitor.py index 7b68afa4a0aa..dbd284e795d5 100644 --- a/lib/galaxy/tool_util/parameters/visitor.py +++ b/lib/galaxy/tool_util/parameters/visitor.py @@ -1,7 +1,11 @@ from typing import ( Any, + cast, Dict, Iterable, + Optional, + TypeVar, + Union, ) from typing_extensions import Protocol @@ -54,3 +58,23 @@ def _visit_input_values( else: new_input_values[name] = input_value return new_input_values + + +def flat_state_path(has_name: Union[str, ToolParameterT], prefix: Optional[str] = None) -> str: + """Given a parameter name or model and an optional prefix, give 'flat' name for parameter in tree.""" + if hasattr(has_name, "name"): + name = cast(ToolParameterT, has_name).name + else: + name = cast(str, has_name) + return name if prefix is None else f"{prefix}|{name}" + + +KVT = TypeVar("KVT") + + +def keys_starting_with(flat_tree: Dict[str, KVT], flat_state_path: str) -> Dict[str, KVT]: + subset: Dict[str, KVT] = {} + for key, value in flat_tree.items(): + if key.startswith(flat_state_path): + subset[key] = value + return subset diff --git a/lib/galaxy/tool_util/workflow_state/__init__.py b/lib/galaxy/tool_util/workflow_state/__init__.py new file mode 100644 index 000000000000..a5eafb281bb2 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/__init__.py @@ -0,0 +1,7 @@ +"""Abstractions for reasoning about tool state within Galaxy workflows. + +Like everything else in galaxy-tool-util, this package should be independent of +Galaxy's runtime. It is meant to provide utilities for reasonsing about tool state +(largely building on the abstractions in galaxy.tool_util.parameters) within the +context of workflows. +""" diff --git a/lib/galaxy/tool_util/workflow_state/_types.py b/lib/galaxy/tool_util/workflow_state/_types.py new file mode 100644 index 000000000000..f25dea56a888 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/_types.py @@ -0,0 +1,27 @@ +from typing import ( + Any, + Dict, + Union, +) + +from typing_extensions import ( + Literal, + Protocol, +) + +from galaxy.tool_util.models import ParsedTool + +NativeWorkflowDict = Dict[str, Any] +Format2WorkflowDict = Dict[str, Any] +AnyWorkflowDict = Union[NativeWorkflowDict, Format2WorkflowDict] +WorkflowFormat = Literal["gxformat2", "native"] +NativeStepDict = Dict[str, Any] +Format2StepDict = Dict[str, Any] +NativeToolStateDict = Dict[str, Any] +Format2StateDict = Dict[str, Any] + + +class GetToolInfo(Protocol): + """An interface for fetching tool information for steps in a workflow.""" + + def get_tool_info(self, tool_id: str, tool_version: str) -> ParsedTool: ... diff --git a/lib/galaxy/tool_util/workflow_state/_validation_util.py b/lib/galaxy/tool_util/workflow_state/_validation_util.py new file mode 100644 index 000000000000..068521f579dd --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/_validation_util.py @@ -0,0 +1,12 @@ +from typing import ( + Any, + cast, + Optional, + Union, +) + + +def validate_explicit_conditional_test_value(test_parameter_name: str, value: Any) -> Optional[Union[str, bool]]: + if value is not None and not isinstance(value, (str, bool)): + raise Exception(f"Invalid conditional test value ({value}) for parameter ({test_parameter_name})") + return cast(Optional[Union[str, bool]], value) diff --git a/lib/galaxy/tool_util/workflow_state/convert.py b/lib/galaxy/tool_util/workflow_state/convert.py new file mode 100644 index 000000000000..6f1fe0cdec10 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/convert.py @@ -0,0 +1,127 @@ +from typing import ( + Dict, + List, + Optional, +) + +from pydantic import ( + BaseModel, + Field, +) + +from galaxy.tool_util.models import ParsedTool +from galaxy.tool_util.parameters import ToolParameterT +from .validation_native import ( + get_parsed_tool_for_native_step, + native_tool_state, + validate_native_step_against, +) +from .validation_format2 import validate_step_against +from ._types import ( + GetToolInfo, + NativeStepDict, + Format2StateDict, +) + +Format2InputsDictT = Dict[str, str] + + +class Format2State(BaseModel): + state: Format2StateDict + inputs: Format2InputsDictT = Field(alias="in") + + +class ConversionValidationFailure(Exception): + pass + + +def convert_state_to_format2(native_step_dict: NativeStepDict, get_tool_info: GetToolInfo) -> Format2State: + parsed_tool = get_parsed_tool_for_native_step(native_step_dict, get_tool_info) + return convert_state_to_format2_using(native_step_dict, parsed_tool) + + +def convert_state_to_format2_using(native_step_dict: NativeStepDict, parsed_tool: Optional[ParsedTool]) -> Format2State: + """Create a "clean" gxformat2 workflow tool state from a native workflow step. + + gxformat2 does not know about tool specifications so it cannot reason about the native + tool state attribute and just copies it as is. This native state can be pretty ugly. The purpose + of this function is to build a cleaned up state to replace the gxformat2 copied native tool_state + with that is more readable and has stronger typing by using the tool's inputs to guide + the conversion (the parsed_tool parameter). + + This method validates both the native tool state and the resulting gxformat2 tool state + so that we can be more confident the conversion doesn't corrupt the workflow. If no meta + model to validate against is supplied or if either validation fails this method throws + ConversionValidationFailure to signal the caller to just use the native tool state as is + instead of trying to convert it to a cleaner gxformat2 tool state - under the assumption + it is better to have an "ugly" workflow than a corrupted one during conversion. + """ + if parsed_tool is None: + raise ConversionValidationFailure("Could not resolve tool inputs") + try: + validate_native_step_against(native_step_dict, parsed_tool) + except Exception: + raise ConversionValidationFailure( + "Failed to validate native step - not going to convert a tool state that isn't understood" + ) + result = _convert_valid_state_to_format2(native_step_dict, parsed_tool) + print(result.dict()) + try: + validate_step_against(result.dict(), parsed_tool) + except Exception: + raise ConversionValidationFailure( + "Failed to validate resulting cleaned step - not going to convert to an unvalidated tool state" + ) + return result + + +def _convert_valid_state_to_format2(native_step_dict: NativeStepDict, parsed_tool: ParsedTool) -> Format2State: + format2_state: Format2StateDict = {} + format2_in: Format2InputsDictT = {} + + root_tool_state = native_tool_state(native_step_dict) + tool_inputs = parsed_tool.inputs + _convert_state_level(native_step_dict, tool_inputs, root_tool_state, format2_state, format2_in) + return Format2State(**{ + "state": format2_state, + "in": format2_in, + }) + + +def _convert_state_level( + step: NativeStepDict, + tool_inputs: List[ToolParameterT], + native_state: dict, + format2_state_at_level: dict, + format2_in: Format2InputsDictT, + prefix: Optional[str] = None, +) -> None: + for tool_input in tool_inputs: + _convert_state_at_level(step, tool_input, native_state, format2_state_at_level, format2_in, prefix) + + +def _convert_state_at_level( + step: NativeStepDict, + tool_input: ToolParameterT, + native_state_at_level: dict, + format2_state_at_level: dict, + format2_in: Format2InputsDictT, + prefix: str +) -> None: + parameter_type = tool_input.parameter_type + parameter_name = tool_input.name + value = native_state_at_level.get(parameter_name, None) + state_path = parameter_name if prefix is None else f"{prefix}|{parameter_name}" + if parameter_type == "gx_integer": + # check for runtime input + format2_value = int(value) + format2_state_at_level[parameter_name] = format2_value + elif parameter_type == "gx_data": + input_connections = step.get("input_connections", {}) + print(state_path) + print(input_connections) + if state_path in input_connections: + format2_in[state_path] = "placeholder" + else: + pass + # raise NotImplementedError(f"Unhandled parameter type {parameter_type}") diff --git a/lib/galaxy/tool_util/workflow_state/validation.py b/lib/galaxy/tool_util/workflow_state/validation.py new file mode 100644 index 000000000000..7fe9f53bb4cd --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/validation.py @@ -0,0 +1,22 @@ +from ._types import ( + AnyWorkflowDict, + GetToolInfo, + WorkflowFormat, +) + +from .validation_format2 import validate_workflow_format2 +from .validation_native import validate_workflow_native + + +def validate_workflow(workflow_dict: AnyWorkflowDict, get_tool_info: GetToolInfo): + if _format(workflow_dict) == "gxformat2": + validate_workflow_format2(workflow_dict, get_tool_info) + else: + validate_workflow_native(workflow_dict, get_tool_info) + + +def _format(workflow_dict: AnyWorkflowDict) -> WorkflowFormat: + if workflow_dict.get("a_galaxy_workflow") == "true": + return "native" + else: + return "gxformat2" diff --git a/lib/galaxy/tool_util/workflow_state/validation_format2.py b/lib/galaxy/tool_util/workflow_state/validation_format2.py new file mode 100644 index 000000000000..0ef8a192cbbd --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/validation_format2.py @@ -0,0 +1,170 @@ +from typing import ( + Any, + cast, + Dict, + List, + Optional, +) + +from gxformat2.model import ( + get_native_step_type, + pop_connect_from_step_dict, + setup_connected_values, + steps_as_list, +) + +from galaxy.tool_util.models import ParsedTool +from galaxy.tool_util.parameters import ( + ConditionalParameterModel, + ConditionalWhen, + flat_state_path, + keys_starting_with, + RepeatParameterModel, + ToolParameterT, + WorkflowStepLinkedToolState, + WorkflowStepToolState, +) +from ._types import ( + GetToolInfo, + Format2WorkflowDict, + Format2StepDict, +) +from ._validation_util import validate_explicit_conditional_test_value + + +def validate_workflow_format2(workflow_dict: Format2WorkflowDict, get_tool_info: GetToolInfo): + steps = steps_as_list(workflow_dict) + for step in steps: + validate_step_format2(step, get_tool_info) + + +def validate_step_format2(step_dict: Format2StepDict, get_tool_info: GetToolInfo): + step_type = get_native_step_type(step_dict) + if step_type != "tool": + return + tool_id = step_dict.get("tool_id") + tool_version = step_dict.get("tool_version") + parsed_tool = get_tool_info.get_tool_info(tool_id, tool_version) + if parsed_tool is not None: + validate_step_against(step_dict, parsed_tool) + + +def validate_step_against(step_dict: Format2StepDict, parsed_tool: ParsedTool): + source_tool_state_model = WorkflowStepToolState.parameter_model_for(parsed_tool.inputs) + linked_tool_state_model = WorkflowStepLinkedToolState.parameter_model_for(parsed_tool.inputs) + contains_format2_state = "state" in step_dict + contains_native_state = "tool_state" in step_dict + if contains_format2_state: + assert source_tool_state_model + source_tool_state_model.model_validate(step_dict["state"]) + if not contains_native_state: + if not contains_format2_state: + step_dict["state"] = {} + # setup links and then validate against model... + linked_step = merge_inputs(step_dict, parsed_tool) + linked_tool_state_model.model_validate(linked_step["state"]) + + +def merge_inputs(step_dict: Format2StepDict, parsed_tool: ParsedTool) -> Format2StepDict: + connect = pop_connect_from_step_dict(step_dict) + step_dict = setup_connected_values(step_dict, connect) + tool_inputs = parsed_tool.inputs + + state_at_level = step_dict["state"] + + for tool_input in tool_inputs: + _merge_into_state(connect, tool_input, state_at_level) + + for key in connect: + raise Exception(f"Failed to find parameter definition matching workflow linked key {key}") + return step_dict + + +def _merge_into_state( + connect, tool_input: ToolParameterT, state: dict, prefix: Optional[str] = None, branch_connect=None +): + if branch_connect is None: + branch_connect = connect + + name = tool_input.name + parameter_type = tool_input.parameter_type + state_path = flat_state_path(name, prefix) + if parameter_type == "gx_conditional": + conditional_state = state.get(name, {}) + if name not in state: + state[name] = conditional_state + + conditional = cast(ConditionalParameterModel, tool_input) + when: ConditionalWhen = _select_which_when(conditional, conditional_state) + test_parameter = conditional.test_parameter + conditional_connect = keys_starting_with(branch_connect, state_path) + _merge_into_state( + connect, test_parameter, conditional_state, prefix=state_path, branch_connect=conditional_connect + ) + for when_parameter in when.parameters: + _merge_into_state( + connect, when_parameter, conditional_state, prefix=state_path, branch_connect=conditional_connect + ) + elif parameter_type == "gx_repeat": + repeat_state_array = state.get(name, []) + repeat = cast(RepeatParameterModel, tool_input) + repeat_instance_connects = repeat_inputs_to_array(state_path, connect) + for i, repeat_instance_connect in enumerate(repeat_instance_connects): + while len(repeat_state_array) <= i: + repeat_state_array.append({}) + + repeat_instance_prefix = f"{state_path}_{i}" + for repeat_parameter in repeat.parameters: + _merge_into_state( + connect, + repeat_parameter, + repeat_state_array[i], + prefix=repeat_instance_prefix, + branch_connect=repeat_instance_connect, + ) + if repeat_state_array and name not in state: + state[name] = repeat_state_array + else: + if state_path in branch_connect: + state[name] = {"__class__": "ConnectedValue"} + del connect[state_path] + + +def _select_which_when(conditional: ConditionalParameterModel, state: dict) -> ConditionalWhen: + test_parameter = conditional.test_parameter + test_parameter_name = test_parameter.name + explicit_test_value = state.get(test_parameter_name) + test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value) + for when in conditional.whens: + if test_value is None and when.is_default_when: + return when + elif test_value == when.discriminator: + return when + else: + raise Exception(f"Invalid conditional test value ({explicit_test_value}) for parameter ({test_parameter_name})") + + +def repeat_inputs_to_array(state_path: str, inputs: dict) -> List[Dict[str, Any]]: + repeat_connect = keys_starting_with(inputs, state_path + "_") + highest_count = -1 + for key in repeat_connect.keys(): + repeat_num_str = key[len(state_path) + 1 :].split("|")[0] + try: + repeat_num = int(repeat_num_str) + if repeat_num > highest_count: + highest_count = repeat_num + except ValueError: + continue + + params: List[Dict[str, Any]] = [] + for _ in range(highest_count + 1): + instance_params: Dict[str, Any] = {} + params.append(instance_params) + for key, value in repeat_connect.items(): + repeat_num_str = key[len(state_path) + 1 :].split("|")[0] + try: + repeat_num = int(repeat_num_str) + params[repeat_num][key] = value + except ValueError: + continue + return params diff --git a/lib/galaxy/tool_util/workflow_state/validation_native.py b/lib/galaxy/tool_util/workflow_state/validation_native.py new file mode 100644 index 000000000000..d3518f52a222 --- /dev/null +++ b/lib/galaxy/tool_util/workflow_state/validation_native.py @@ -0,0 +1,206 @@ +import json +from typing import ( + cast, + List, + Optional, +) + +from galaxy.tool_util.models import ParsedTool +from galaxy.tool_util.parameters import ( + ConditionalParameterModel, + ConditionalWhen, + flat_state_path, + RepeatParameterModel, + SelectParameterModel, + ToolParameterT, +) +from ._types import ( + GetToolInfo, + NativeWorkflowDict, + NativeStepDict, + NativeToolStateDict, +) +from ._validation_util import validate_explicit_conditional_test_value +from .validation_format2 import repeat_inputs_to_array + + +def validate_native_step_against(step: NativeStepDict, parsed_tool: ParsedTool): + tool_state_jsonified = step.get("tool_state") + assert tool_state_jsonified + tool_state = json.loads(tool_state_jsonified) + tool_inputs = parsed_tool.inputs + + # merge input connections into ConnectedValues if there aren't already there... + _merge_inputs_into_state_dict(step, tool_inputs, tool_state) + + allowed_extra_keys = ["__page__", "__rerun_remap_job_id__"] + _validate_native_state_level(step, tool_inputs, tool_state, allowed_extra_keys=allowed_extra_keys) + + +def _validate_native_state_level( + step: NativeStepDict, tool_inputs: List[ToolParameterT], state_at_level: dict, allowed_extra_keys=None +): + if allowed_extra_keys is None: + allowed_extra_keys = [] + + keys_processed = set() + for tool_input in tool_inputs: + parameter_name = tool_input.name + keys_processed.add(parameter_name) + _validate_native_state_at_level(step, tool_input, state_at_level) + + for key in state_at_level.keys(): + if key not in keys_processed and key not in allowed_extra_keys: + raise Exception(f"Unknown key found {key}, failing state validation") + + +def _validate_native_state_at_level(step: NativeStepDict, tool_input: ToolParameterT, state_at_level: dict, prefix: Optional[str] = None): + parameter_type = tool_input.parameter_type + parameter_name = tool_input.name + value = state_at_level.get(parameter_name, None) + # state_path = parameter_name if prefix is None else f"{prefix}|{parameter_name}" + if parameter_type == "gx_integer": + try: + int(value) + except ValueError: + raise Exception(f"Invalid integer data found {value}") + elif parameter_type in ["gx_data", "gx_data_collection"]: + if isinstance(value, dict): + assert "__class__" in value + assert value["__class__"] in ["RuntimeValue", "ConnectedValue"] + else: + assert value in [None, "null"] + connections = native_connections_for(step, tool_input, prefix) + optional = tool_input.optional + if not optional and not connections: + raise Exception("Disconnected non-optional input found, not attempting to validate non-practice workflow") + + elif parameter_type == "gx_select": + select = cast(SelectParameterModel, tool_input) + options = select.options + if options is not None: + valid_values = [o.value for o in options] + if value not in valid_values: + raise Exception(f"Invalid select option found {value}") + elif parameter_type == "gx_conditional": + conditional_state = state_at_level.get(parameter_name, None) + conditional = cast(ConditionalParameterModel, tool_input) + when: ConditionalWhen = _select_which_when_native(conditional, conditional_state) + test_parameter = conditional.test_parameter + test_parameter_name = test_parameter.name + _validate_native_state_at_level(step, test_parameter, conditional_state) + _validate_native_state_level( + step, when.parameters, conditional_state, allowed_extra_keys=["__current_case__", test_parameter_name] + ) + else: + raise NotImplementedError(f"Unhandled parameter type ({parameter_type})") + + +def _select_which_when_native(conditional: ConditionalParameterModel, conditional_state: dict) -> ConditionalWhen: + test_parameter = conditional.test_parameter + test_parameter_name = test_parameter.name + explicit_test_value = conditional_state.get(test_parameter_name) + test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value) + target_when = None + for when in conditional.whens: + # deal with native string -> bool issues in here... + if test_value is None and when.is_default_when: + target_when = when + elif test_value == when.discriminator: + target_when = when + + recorded_case = conditional_state.get("__current_case__") + if recorded_case is not None: + if not isinstance(recorded_case, int): + raise Exception(f"Unknown type of value for __current_case__ encountered {recorded_case}") + if recorded_case < 0 or recorded_case >= len(conditional.whens): + raise Exception(f"Unknown index value for __current_case__ encountered {recorded_case}") + recorded_when = conditional.whens[recorded_case] + + if target_when is None: + raise Exception("is this okay? I need more tests") + if target_when and recorded_when and target_when != recorded_when: + raise Exception( + f"Problem parsing out tool state - inferred conflicting tool states for parameter {test_parameter_name}" + ) + return target_when + + +def _merge_inputs_into_state_dict( + step_dict: NativeStepDict, tool_inputs: List[ToolParameterT], state_at_level: dict, prefix: Optional[str] = None +): + for tool_input in tool_inputs: + _merge_into_state(step_dict, tool_input, state_at_level, prefix=prefix) + + +def _merge_into_state(step_dict: NativeStepDict, tool_input: ToolParameterT, state: dict, prefix: Optional[str] = None): + name = tool_input.name + parameter_type = tool_input.parameter_type + state_path = flat_state_path(name, prefix) + if parameter_type == "gx_conditional": + conditional_state = state.get(name, {}) + if name not in state: + state[name] = conditional_state + + conditional = cast(ConditionalParameterModel, tool_input) + when: ConditionalWhen = _select_which_when_native(conditional, conditional_state) + test_parameter = conditional.test_parameter + _merge_into_state(step_dict, test_parameter, conditional_state, prefix=state_path) + for when_parameter in when.parameters: + _merge_into_state(step_dict, when_parameter, conditional_state, prefix=state_path) + elif parameter_type == "gx_repeat": + repeat_state_array = state.get(name, []) + repeat = cast(RepeatParameterModel, tool_input) + repeat_instance_connects = repeat_inputs_to_array(state_path, step_dict) + for i, _ in enumerate(repeat_instance_connects): + while len(repeat_state_array) <= i: + repeat_state_array.append({}) + + repeat_instance_prefix = f"{state_path}_{i}" + for repeat_parameter in repeat.parameters: + _merge_into_state( + step_dict, + repeat_parameter, + repeat_state_array[i], + prefix=repeat_instance_prefix, + ) + if repeat_state_array and name not in state: + state[name] = repeat_state_array + else: + input_connections = step_dict.get("input_connections", {}) + if state_path in input_connections and state.get(name) is None: + state[name] = {"__class__": "ConnectedValue"} + + +def validate_step_native(step: NativeStepDict, get_tool_info: GetToolInfo): + parsed_tool = get_parsed_tool_for_native_step(step, get_tool_info) + if parsed_tool is not None: + validate_native_step_against(step, parsed_tool) + + +def get_parsed_tool_for_native_step(step: NativeStepDict, get_tool_info: GetToolInfo) -> Optional[ParsedTool]: + tool_id = step.get("tool_id") + if not tool_id: + return None + tool_version = step.get("tool_version") + parsed_tool = get_tool_info.get_tool_info(tool_id, tool_version) + return parsed_tool + + +def validate_workflow_native(workflow_dict: NativeWorkflowDict, get_tool_info: GetToolInfo): + for step_def in workflow_dict["steps"].values(): + validate_step_native(step_def, get_tool_info) + + +def native_tool_state(step: NativeStepDict) -> NativeToolStateDict: + tool_state_jsonified = step.get("tool_state") + assert tool_state_jsonified + tool_state = json.loads(tool_state_jsonified) + return tool_state + + +def native_connections_for(step: NativeStepDict, parameter: ToolParameterT, prefix: Optional[str]): + parameter_name = parameter.name + state_path = parameter_name if prefix is None else f"{prefix}|{parameter_name}" + step.get("input_connections", {}) + return step.get(state_path) diff --git a/lib/galaxy/workflow/gx_validator.py b/lib/galaxy/workflow/gx_validator.py new file mode 100644 index 000000000000..3b0c90bb25f7 --- /dev/null +++ b/lib/galaxy/workflow/gx_validator.py @@ -0,0 +1,60 @@ +""""A validator for Galaxy workflows that is hooked up to Galaxy internals. + +The interface is designed to be usable from the tool shed for external tooling, +but for internal tooling - Galaxy should have its own tool available. +""" + +from typing import Dict + +from galaxy.tool_util.models import ( + parse_tool, + ParsedTool, +) +from galaxy.tool_util.version import parse_version +from galaxy.tool_util.version_util import AnyVersionT +from galaxy.tool_util.workflow_state.validation import ( + GetToolInfo, + validate_workflow as validate_workflow_generic, +) +from galaxy.tools.stock import stock_tool_sources + + +class GalaxyGetToolInfo(GetToolInfo): + stock_tools_by_version: Dict[str, Dict[AnyVersionT, ParsedTool]] + stock_tools_latest_version: Dict[str, AnyVersionT] + + def __init__(self): + # todo take in a toolbox in the future... + stock_tools: Dict[str, Dict[str, ParsedTool]] = {} + stock_tools_latest_version: Dict[str, AnyVersionT] = {} + for stock_tool in stock_tool_sources(): + id = stock_tool.parse_id() + version = stock_tool.parse_version() + if version is not None: + version_object = parse_version(version) + if id not in stock_tools: + stock_tools[id] = {} + if version_object is not None: + stock_tools_latest_version[id] = version_object + try: + stock_tools[id][version_object] = parse_tool(stock_tool) + except Exception: + pass + if version_object and version_object > stock_tools_latest_version[id]: + stock_tools_latest_version[id] = version_object + self.stock_tools = stock_tools + self.stock_tools_latest_version = stock_tools_latest_version + + def get_tool_info(self, tool_id: str, tool_version: str) -> ParsedTool: + if tool_version is not None: + return self.stock_tools[tool_id][parse_version(tool_version)] + else: + latest_verison = self.stock_tools_latest_version[tool_id] + return self.stock_tools[tool_id][latest_verison] + + +GET_TOOL_INFO = GalaxyGetToolInfo() + + +def validate_workflow(as_dict): + return validate_workflow_generic(as_dict, get_tool_info=GET_TOOL_INFO) diff --git a/packages/tool_util/setup.cfg b/packages/tool_util/setup.cfg index 7c8fd75feec1..bf882d6a3db0 100644 --- a/packages/tool_util/setup.cfg +++ b/packages/tool_util/setup.cfg @@ -34,6 +34,7 @@ version = 24.2.dev0 include_package_data = True install_requires = galaxy-util>=22.1 + gxformat2 conda-package-streaming lxml!=4.2.2 MarkupSafe diff --git a/test/unit/tool_util/parameter_specification.yml b/test/unit/tool_util/parameter_specification.yml index 639f711a8825..c7e2929a5cdb 100644 --- a/test/unit/tool_util/parameter_specification.yml +++ b/test/unit/tool_util/parameter_specification.yml @@ -14,6 +14,7 @@ # Things to verify: # - non optional, multi-selects require a selection (see TODO below...) +# - https://github.com/galaxyproject/galaxy/issues/18541 gx_int: request_valid: - parameter: 5 @@ -25,12 +26,27 @@ gx_int: - parameter: "null" - parameter: "None" - parameter: { 5 } + - parameter: {__class__: 'ConnectedValue'} test_case_valid: - parameter: 5 - {} test_case_invalid: - parameter: null - parameter: "5" + workflow_step_valid: + - parameter: 5 + - {} + workflow_step_invalid: + - parameterx: 5 + - parameter: 'foobar' + workflow_step_linked_valid: + - parameter: 5 + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - parameter: null + - parameter: 'foobar' + - parameter: {__class__: 'ConnectedValue2'} + gx_boolean: request_valid: @@ -44,6 +60,21 @@ gx_boolean: # Marius and John were on fence here. - parameter: "mytrue" - parameter: null + - parameter: {__class__: 'ConnectedValue'} + workflow_step_valid: + - parameter: True + - {} + workflow_step_invalid: + - parameter: "true" + - parameter: mytrue + - parameter: null + workflow_step_linked_valid: + - parameter: True + - {} + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - parameter: "true" + - parameter: {__class__: 'ConnectedValue3'} gx_int_optional: request_valid: @@ -55,17 +86,57 @@ gx_int_optional: - parameter: "None" - parameter: "null" - parameter: [5] + - parameter: {__class__: 'ConnectedValue'} + workflow_step_valid: + - parameter: 5 + - parameter: null + - {} + workflow_step_invalid: + - parameter: "5" + - parameter: "None" + - parameter: "null" + - parameter: [5] + workflow_step_linked_valid: + - parameter: 5 + - parameter: null + - {} + - parameter: {__class__: 'ConnectedValue'} gx_text: request_valid: - parameter: moocow - parameter: 'some spaces' - parameter: '' + - {} request_invalid: - parameter: 5 - parameter: null - parameter: {} - parameter: { "moo": "cow" } + - parameter: {__class__: 'ConnectedValue'} + workflow_step_valid: + - parameter: moocow + - parameter: 'some spaces' + - parameter: '' + - {} + workflow_step_invalid: + - parameter: 5 + - parameter: null + - parameter: {} + - parameter: { "moo": "cow" } + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_valid: + - parameter: moocow + - parameter: 'some spaces' + - parameter: '' + - {} + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - parameter: 5 + - parameter: null + - parameter: {} + - parameter: { "moo": "cow" } + - parameter: {"class": 'ConnectedValue'} gx_text_optional: request_valid: @@ -77,6 +148,26 @@ gx_text_optional: - parameter: 5 - parameter: {} - parameter: { "moo": "cow" } + workflow_step_valid: + - parameter: moocow + - parameter: 'some spaces' + - parameter: '' + - parameter: null + workflow_step_invalid: + - parameter: 5 + - parameter: {} + - parameter: { "moo": "cow" } + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_valid: + - parameter: moocow + - parameter: 'some spaces' + - parameter: '' + - parameter: null + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - parameter: 5 + - parameter: {} + - parameter: { "moo": "cow" } gx_select: request_valid: @@ -105,6 +196,19 @@ gx_select: test_case_invalid: - parameter: {} - parameter: null + workflow_step_valid: + - parameter: "--ex1" + - {} + workflow_step_invalid: + - parameter: 'foobar' + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_valid: + - parameter: "--ex1" + - parameter: {__class__: 'ConnectedValue'} + - {} + workflow_step_linked_invalid: + - parameter: 'foobar' + - parameter: null gx_select_optional: request_valid: @@ -120,6 +224,28 @@ gx_select_optional: - parameter: ["ex2"] - parameter: {} - parameter: 5 + workflow_step_valid: + - parameter: "--ex1" + - parameter: "ex2" + - parameter: null + - {} + workflow_step_invalid: + - parameter: "Ex1" + - parameter: ["ex2"] + - parameter: {} + - parameter: 5 + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_valid: + - parameter: "--ex1" + - parameter: "ex2" + - parameter: null + - {} + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - parameter: "Ex1" + - parameter: ["ex2"] + - parameter: {} + - parameter: 5 # TODO: confirm null should vaguely not be allowed here gx_select_multiple: @@ -132,6 +258,32 @@ gx_select_multiple: - parameter: {} - parameter: 5 - {} + workflow_step_valid: + - parameter: ["--ex1"] + - parameter: ["ex2"] + - {} # could come in linked... + # ... hmmm? this should maybe be invalid right? + - parameter: null + workflow_step_invalid: + - parameter: ["Ex1"] + - parameter: {} + - parameter: 5 + - parameter: {__class__: 'ConnectedValue'} + - parameter: [{__class__: 'ConnectedValue'}] + workflow_step_linked_valid: + - parameter: ["--ex1"] + - parameter: ["ex2"] + - parameter: [{__class__: 'ConnectedValue'}] + workflow_step_linked_invalid: + - parameter: ["Ex1"] + - parameter: {} + - parameter: 5 + - {} + # might be wrong? I guess we would expect the semantic of this to do like a map-over + # but as far as I am aware that is not implemented https://github.com/galaxyproject/galaxy/issues/18541 + - parameter: {__class__: 'ConnectedValue'} + # they are non-optinoal right? + - parameter: null gx_select_multiple_optional: request_valid: @@ -154,6 +306,19 @@ gx_hidden: - parameter: 5 - parameter: {} - parameter: { "moo": "cow" } + workflow_step_valid: + - parameter: moocow + - {} + workflow_step_invalid: + - parmaeter: 5 + - parameter: {} + - parameter: { "moo": "cow" } + workflow_step_linked_valid: + - parameter: moocow + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - parameter: 5 + - parameter: null gx_hidden_optional: request_valid: @@ -165,7 +330,21 @@ gx_hidden_optional: - parameter: 5 - parameter: {} - parameter: { "moo": "cow" } - + workflow_step_valid: + - parameter: moocow + - {} + - parameter: null + workflow_step_invalid: + - parmaeter: 5 + - parameter: {} + - parameter: { "moo": "cow" } + workflow_step_linked_valid: + - parameter: moocow + - parameter: {__class__: 'ConnectedValue'} + - parameter: null + workflow_step_linked_invalid: + - parameter: 5 + gx_float: request_valid: - parameter: 5 @@ -178,6 +357,30 @@ gx_float: - parameter: "5" - parameter: "5.0" - parameter: { "moo": "cow" } + test_case_valid: + - parameter: 5 + - parameter: 5.0 + - {} + test_case_invalid: + - parameter: null + - parameter: "5.0" + - parameter: "5.1" + workflow_step_valid: + - parameter: 5 + - parameter: 5.0 + - {} + workflow_step_invalid: + - parameterx: 5 + - parameter: 'foobar' + workflow_step_linked_valid: + - parameter: 5 + - parameter: 5.4 + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - parameter: null + - parameter: 'foobar' + - parameter: {__class__: 'ConnectedValue2'} + gx_float_optional: request_valid: @@ -191,6 +394,29 @@ gx_float_optional: - parameter: "5.0" - parameter: {} - parameter: { "moo": "cow" } + test_case_valid: + - parameter: 5 + - parameter: 5.0 + - {} + - parameter: null + test_case_invalid: + - parameter: "5.0" + - parameter: "5.1" + workflow_step_valid: + - parameter: 5 + - parameter: 5.0 + - {} + workflow_step_invalid: + - parameterx: 5 + - parameter: 'foobar' + workflow_step_linked_valid: + - parameter: 5 + - parameter: 5.4 + - parameter: {__class__: 'ConnectedValue'} + - parameter: null + workflow_step_linked_invalid: + - parameter: 'foobar' + - parameter: {__class__: 'ConnectedValue2'} gx_color: request_valid: @@ -200,6 +426,22 @@ gx_color: - parameter: null - parameter: {} - parameter: '#abcd' + workflow_step_valid: + - parameter: '#aabbcc' + - parameter: '#000000' + - {} + workflow_step_invalid: + - parameterx: '#aabbcc' + - parameter: 'foobar' + - parameter: 5 + workflow_step_linked_valid: + - parameter: '#aabbcc' + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - parameter: null + - parameter: 'foobar' + - parameter: 5 + - parameter: {__class__: 'ConnectedValue2'} gx_data: request_valid: @@ -236,6 +478,21 @@ gx_data: # expanded out. - parameter: {__class__: "Batch", values: [{src: hdca, id: 5}]} - parameter: {src: hda, id: abcdabcd} + workflow_step_valid: + - {} + workflow_step_invalid: + - {src: hda, id: abcdabcd} + - {src: hda, id: 7} + - parameter: {__class__: "Batch", values: [{src: hdca, id: 5}]} + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_valid: + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - {} + - {src: hda, id: abcdabcd} + - {src: hda, id: 7} + - parameter: {__class__: "Batch", values: [{src: hdca, id: 5}]} + - parameter: {__class__: 'ConnectedValueX'} gx_data_optional: @@ -265,6 +522,21 @@ gx_data_optional: - parameter: true - parameter: 5 - parameter: "5" + workflow_step_valid: + - {} + workflow_step_invalid: + - {src: hda, id: abcdabcd} + - {src: hda, id: 7} + - parameter: {__class__: "Batch", values: [{src: hdca, id: 5}]} + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_valid: + - {} + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - {src: hda, id: abcdabcd} + - {src: hda, id: 7} + - parameter: {__class__: "Batch", values: [{src: hdca, id: 5}]} + - parameter: {__class__: 'ConnectedValueX'} gx_data_multiple: request_valid: @@ -353,6 +625,16 @@ gx_data_collection: - parameter: true - parameter: 5 - parameter: "5" + workflow_step_valid: + - {} + workflow_step_invalid: + - parameter: {src: hdca, id: abcdabcd} + - parameter: 5 + - parameter: {} + workflow_step_linked_valid: + - parameter: {__class__: 'ConnectedValue'} + workflow_step_linked_invalid: + - {} gx_data_collection_optional: request_valid: diff --git a/test/unit/tool_util/test_parameter_specification.py b/test/unit/tool_util/test_parameter_specification.py index 52fd285a9b8a..0e7b7eb8c687 100644 --- a/test/unit/tool_util/test_parameter_specification.py +++ b/test/unit/tool_util/test_parameter_specification.py @@ -19,6 +19,8 @@ validate_internal_request, validate_request, validate_test_case, + validate_workflow_step, + validate_workflow_step_linked, ) from galaxy.tool_util.parameters.json import to_json_schema_string from galaxy.tool_util.unittest_utils.parameters import parameter_bundle_for_file @@ -68,6 +70,10 @@ def _test_file(file: str, specification=None): "job_internal_invalid": _assert_internal_jobs_invalid, "test_case_valid": _assert_test_cases_validate, "test_case_invalid": _assert_test_cases_invalid, + "workflow_step_valid": _assert_workflow_steps_validate, + "workflow_step_invalid": _assert_workflow_steps_invalid, + "workflow_step_linked_valid": _assert_workflow_steps_linked_validate, + "workflow_step_linked_invalid": _assert_workflow_steps_linked_invalid, } for valid_or_invalid, tests in combos.items(): @@ -158,6 +164,44 @@ def _assert_test_case_invalid(parameters: ToolParameterBundleModel, test_case: R ), f"Parameters {parameters} didn't result in validation error on test_case {test_case} as expected." +def _assert_workflow_step_validates(parameters: ToolParameterBundleModel, workflow_step: RawStateDict) -> None: + try: + validate_workflow_step(parameters, workflow_step: RawStateDict) + except RequestParameterInvalidException as e: + raise AssertionError(f"Parameter {parameter} failed to validate workflow step {workflow_step}. {e}") + + +def _assert_workflow_step_invalid(parameters: ToolParameterBundleModel, workflow_step: RawStateDict) -> None: + exc = None + try: + validate_workflow_step(parameters, workflow_step) + except RequestParameterInvalidException as e: + exc = e + assert ( + exc is not None + ), f"Parameter {parameter} didn't result in validation error on workflow step {workflow_step} as expected." + + +def _assert_workflow_step_linked_validates(parameters: ToolParameterBundleModel, workflow_step_linked: RawStateDict) -> None: + try: + validate_workflow_step_linked(parameters, workflow_step_linked) + except RequestParameterInvalidException as e: + raise AssertionError( + f"Parameter {parameter} failed to validate linked workflow step {workflow_step_linked}. {e}" + ) + + +def _assert_workflow_step_linked_invalid(parameters: ToolParameterBundleModel, workflow_step_linked: RawStateDict) -> None: + exc = None + try: + validate_workflow_step_linked(parameters, workflow_step_linked) + except RequestParameterInvalidException as e: + exc = e + assert ( + exc is not None + ), f"Parameter {parameter} didn't result in validation error on linked workflow step {workflow_step_linked} as expected." + + _assert_requests_validate = partial(_for_each, _assert_request_validates) _assert_requests_invalid = partial(_for_each, _assert_request_invalid) _assert_internal_requests_validate = partial(_for_each, _assert_internal_request_validates) @@ -166,6 +210,10 @@ def _assert_test_case_invalid(parameters: ToolParameterBundleModel, test_case: R _assert_internal_jobs_invalid = partial(_for_each, _assert_internal_job_invalid) _assert_test_cases_validate = partial(_for_each, _assert_test_case_validates) _assert_test_cases_invalid = partial(_for_each, _assert_test_case_invalid) +_assert_workflow_steps_validate = partial(_for_each, _assert_workflow_step_validates) +_assert_workflow_steps_invalid = partial(_for_each, _assert_workflow_step_invalid) +_assert_workflow_steps_linked_validate = partial(_for_each, _assert_workflow_step_linked_validates) +_assert_workflow_steps_linked_invalid = partial(_for_each, _assert_workflow_step_linked_invalid) def decode_val(val: str) -> int: diff --git a/test/unit/tool_util/workflow_state/test_workflow_state_helpers.py b/test/unit/tool_util/workflow_state/test_workflow_state_helpers.py new file mode 100644 index 000000000000..efc3b8c29ae7 --- /dev/null +++ b/test/unit/tool_util/workflow_state/test_workflow_state_helpers.py @@ -0,0 +1,25 @@ +from galaxy.workflow.validator import repeat_inputs_to_array + + +def test_repeat_inputs_to_array(): + rval = repeat_inputs_to_array( + "repeatfoo", + { + "moo": "cow", + }, + ) + assert not rval + rval = repeat_inputs_to_array( + "repeatfoo", + { + "moo": "cow", + "repeatfoo_0|moocow": ["moo"], + "repeatfoo_2|moocow": ["cow"], + }, + ) + assert len(rval) == 3 + assert "repeatfoo_0|moocow" in rval[0] + assert "repeatfoo_0|moocow" not in rval[1] + assert "repeatfoo_0|moocow" not in rval[2] + assert "repeatfoo_2|moocow" not in rval[1] + assert "repeatfoo_2|moocow" in rval[2] diff --git a/test/unit/workflows/invalid/extra_attribute.gxwf.yml b/test/unit/workflows/invalid/extra_attribute.gxwf.yml new file mode 100644 index 000000000000..6ae50799394c --- /dev/null +++ b/test/unit/workflows/invalid/extra_attribute.gxwf.yml @@ -0,0 +1,15 @@ +class: GalaxyWorkflow +inputs: + input: + type: int +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_int + tool_version: "1.0.0" + state: + parameter2: 6 + in: + parameter: input diff --git a/test/unit/workflows/invalid/missing_link.gxwf.yml b/test/unit/workflows/invalid/missing_link.gxwf.yml new file mode 100644 index 000000000000..526b40f6f502 --- /dev/null +++ b/test/unit/workflows/invalid/missing_link.gxwf.yml @@ -0,0 +1,11 @@ +class: GalaxyWorkflow +inputs: + input: + type: data +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_data + tool_version: "1.0.0" diff --git a/test/unit/workflows/invalid/wrong_link_name.gxwf.yml b/test/unit/workflows/invalid/wrong_link_name.gxwf.yml new file mode 100644 index 000000000000..f0e0e8d12004 --- /dev/null +++ b/test/unit/workflows/invalid/wrong_link_name.gxwf.yml @@ -0,0 +1,13 @@ +class: GalaxyWorkflow +inputs: + input: + type: int +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_int + tool_version: "1.0.0" + in: + parameterx: input diff --git a/test/unit/workflows/test_workflow_state_conversion.py b/test/unit/workflows/test_workflow_state_conversion.py new file mode 100644 index 000000000000..8289e1c66972 --- /dev/null +++ b/test/unit/workflows/test_workflow_state_conversion.py @@ -0,0 +1,16 @@ +from galaxy.tool_util.workflow_state.convert import ( + convert_state_to_format2, + Format2InputsDictT, +) +from galaxy.workflow.gx_validator import GET_TOOL_INFO +from .test_workflow_validation import base_package_workflow_as_dict + + +def convert_state(native_step_dict: dict) -> Format2InputsDictT: + return convert_state_to_format2(native_step_dict, GET_TOOL_INFO) + + +def test_simple_convert(): + workflow_dict = base_package_workflow_as_dict("test_workflow_1.ga") + cat_step = workflow_dict["steps"]["2"] + format2_pair = convert_state(cat_step) diff --git a/test/unit/workflows/test_workflow_validation.py b/test/unit/workflows/test_workflow_validation.py new file mode 100644 index 000000000000..51dbd855806c --- /dev/null +++ b/test/unit/workflows/test_workflow_validation.py @@ -0,0 +1,75 @@ +import os +from typing import Optional + +from gxformat2.yaml import ordered_load + +from galaxy.util import galaxy_directory +from galaxy.workflow.gx_validator import validate_workflow + +TEST_WORKFLOW_DIRECTORY = os.path.join(galaxy_directory(), "lib", "galaxy_test", "workflow") +TEST_BASE_DATA_DIRECTORY = os.path.join(galaxy_directory(), "lib", "galaxy_test", "base", "data") +SCRIPT_DIRECTORY = os.path.abspath(os.path.dirname(__file__)) + + +def test_validate_simple_functional_test_case_workflow(): + validate_workflow(framework_test_workflow_as_dict("multiple_versions")) + validate_workflow(framework_test_workflow_as_dict("zip_collection")) + validate_workflow(framework_test_workflow_as_dict("empty_collection_sort")) + validate_workflow(framework_test_workflow_as_dict("flatten_collection")) + validate_workflow(framework_test_workflow_as_dict("flatten_collection_over_execution")) + + +def test_validate_native_workflows(): + validate_workflow(base_package_workflow_as_dict("test_workflow_two_random_lines.ga")) + validate_workflow(base_package_workflow_as_dict("test_workflow_topoambigouity.ga")) + validate_workflow(base_package_workflow_as_dict("test_Workflow_map_reduce_pause.ga")) + validate_workflow(base_package_workflow_as_dict("test_subworkflow_with_integer_inputs.ga")) + validate_workflow(base_package_workflow_as_dict("test_workflow_batch.ga")) + +def test_validate_unit_test_workflows(): + validate_workflow(unit_test_workflow_as_dict("valid/simple_int")) + validate_workflow(unit_test_workflow_as_dict("valid/simple_data")) + + +def test_invalidate_with_extra_attribute(): + e = _assert_validation_failure("invalid/extra_attribute") + assert "parameter2" in str(e) + + +def test_invalidate_with_wrong_link_name(): + e = _assert_validation_failure("invalid/wrong_link_name") + assert "parameterx" in str(e) + + +def test_invalidate_with_missing_link(): + e = _assert_validation_failure("invalid/missing_link") + assert "parameter" in str(e) + assert "type=missing" in str(e) + + +def _assert_validation_failure(workflow_name: str) -> Exception: + as_dict = unit_test_workflow_as_dict(workflow_name) + exc: Optional[Exception] = None + try: + validate_workflow(as_dict) + except Exception as e: + exc = e + assert exc, f"Target workflow ({workflow_name}) did not failure validation as expected." + return exc + + +def base_package_workflow_as_dict(file_name: str) -> dict: + return _load(os.path.join(TEST_BASE_DATA_DIRECTORY, file_name)) + + +def unit_test_workflow_as_dict(workflow_name: str) -> dict: + return _load(os.path.join(SCRIPT_DIRECTORY, f"{workflow_name}.gxwf.yml")) + + +def framework_test_workflow_as_dict(workflow_name: str) -> dict: + return _load(os.path.join(TEST_WORKFLOW_DIRECTORY, f"{workflow_name}.gxwf.yml")) + + +def _load(path: str) -> dict: + with open(path) as f: + return ordered_load(f) diff --git a/test/unit/workflows/test_workflow_validation_helpers.py b/test/unit/workflows/test_workflow_validation_helpers.py new file mode 100644 index 000000000000..af74e2b32a5f --- /dev/null +++ b/test/unit/workflows/test_workflow_validation_helpers.py @@ -0,0 +1,13 @@ +from galaxy.workflow.gx_validator import GET_TOOL_INFO + + +def test_get_tool(): + parsed_tool = GET_TOOL_INFO.get_tool_info("cat1", "1.0.0") + assert parsed_tool + assert parsed_tool.id == "cat1" + assert parsed_tool.version == "1.0.0" + + parsed_tool = GET_TOOL_INFO.get_tool_info("cat1", None) + assert parsed_tool + assert parsed_tool.id == "cat1" + assert parsed_tool.version == "1.0.0" diff --git a/test/unit/workflows/valid/simple_data.gxwf.yml b/test/unit/workflows/valid/simple_data.gxwf.yml new file mode 100644 index 000000000000..44f0a90f3dd9 --- /dev/null +++ b/test/unit/workflows/valid/simple_data.gxwf.yml @@ -0,0 +1,13 @@ +class: GalaxyWorkflow +inputs: + input: + type: data +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_data + tool_version: "1.0.0" + in: + parameter: input diff --git a/test/unit/workflows/valid/simple_int.gxwf.yml b/test/unit/workflows/valid/simple_int.gxwf.yml new file mode 100644 index 000000000000..d7c53f78d0a6 --- /dev/null +++ b/test/unit/workflows/valid/simple_int.gxwf.yml @@ -0,0 +1,13 @@ +class: GalaxyWorkflow +inputs: + input: + type: int +outputs: + output: + outputSource: the_step/output +steps: + the_step: + tool_id: gx_int + tool_version: "1.0.0" + in: + parameter: input