Skip to content

Commit

Permalink
First pass at workflow step models - linked and unlinked.
Browse files Browse the repository at this point in the history
Work scoped out in 18536.
  • Loading branch information
jmchilton committed Aug 11, 2024
1 parent bfdbf05 commit 5efb22b
Show file tree
Hide file tree
Showing 9 changed files with 585 additions and 44 deletions.
29 changes: 18 additions & 11 deletions doc/source/dev/tool_state_state_classes.plantuml.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 17 additions & 1 deletion lib/galaxy/tool_util/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,24 @@
validate_internal_request,
validate_request,
validate_test_case,
validate_workflow_step,
validate_workflow_step_linked,
)
from .state import (
JobInternalToolState,
RequestInternalToolState,
RequestToolState,
TestCaseToolState,
ToolState,
WorkflowStepLinkedToolState,
WorkflowStepToolState,
)
from .visitor import (
flat_state_path,
keys_starting_with,
repeat_inputs_to_array,
visit_input_values,
)
from .visitor import visit_input_values

__all__ = (
"from_input_source",
Expand Down Expand Up @@ -89,13 +98,20 @@
"validate_internal_request",
"validate_request",
"validate_test_case",
"validate_workflow_step",
"validate_workflow_step_linked",
"ToolState",
"TestCaseToolState",
"ToolParameterT",
"to_json_schema_string",
"RequestToolState",
"RequestInternalToolState",
"flat_state_path",
"keys_starting_with",
"visit_input_values",
"repeat_inputs_to_array",
"decode",
"encode",
"WorkflowStepToolState",
"WorkflowStepLinkedToolState",
)
7 changes: 6 additions & 1 deletion lib/galaxy/tool_util/parameters/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@
)


def optional(type: Type) -> Type:
return_type: Type = Optional[type] # type: ignore[assignment]
return return_type


def optional_if_needed(type: Type, is_optional: bool) -> Type:
return_type: Type = type
if is_optional:
return_type = Optional[type] # type: ignore[assignment]
return_type = optional(type)
return return_type


Expand Down
12 changes: 2 additions & 10 deletions lib/galaxy/tool_util/parameters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BaseUrlParameterModel,
BooleanParameterModel,
ColorParameterModel,
cond_test_parameter_default_value,
ConditionalParameterModel,
ConditionalWhen,
CwlBooleanParameterModel,
Expand Down Expand Up @@ -211,16 +212,7 @@ def _from_input_source_galaxy(input_source: InputSource) -> ToolParameterT:
Union[BooleanParameterModel, SelectParameterModel], _from_input_source_galaxy(test_param_input_source)
)
whens = []
default_value = object()
if isinstance(test_parameter, BooleanParameterModel):
default_value = test_parameter.value
elif isinstance(test_parameter, SelectParameterModel):
select_parameter = cast(SelectParameterModel, test_parameter)
select_default_value = select_parameter.default_value
if select_default_value is not None:
default_value = select_default_value

# TODO: handle select parameter model...
default_value = cond_test_parameter_default_value(test_parameter)
for value, case_inputs_sources in input_source.parse_when_input_sources():
if isinstance(test_parameter, BooleanParameterModel):
# TODO: investigate truevalue/falsevalue when...
Expand Down
153 changes: 136 additions & 17 deletions lib/galaxy/tool_util/parameters/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
cast_as_type,
is_optional,
list_type,
optional,
optional_if_needed,
union_type,
)
Expand All @@ -55,7 +56,9 @@
# + request: Return info needed to build request pydantic model at runtime.
# + request_internal: This is a pydantic model to validate what Galaxy expects to find in the database,
# in particular dataset and collection references should be decoded integers.
StateRepresentationT = Literal["request", "request_internal", "job_internal", "test_case"]
StateRepresentationT = Literal[
"request", "request_internal", "job_internal", "test_case", "workflow_step", "workflow_step_linked"
]


# could be made more specific - validators need to be classmethod
Expand All @@ -72,6 +75,14 @@ class StrictModel(BaseModel):
model_config = ConfigDict(extra="forbid")


class ConnectedValue(BaseModel):
discriminator: Literal["ConnectedValue"] = Field(alias="__class__")


def allow_connected_value(type: Type):
return union_type([type, ConnectedValue])


def allow_batching(job_template: DynamicModelInformation, batch_type: Optional[Type] = None) -> DynamicModelInformation:
job_py_type: Type = job_template.definition[0]
default_value = job_template.definition[1]
Expand Down Expand Up @@ -107,11 +118,15 @@ def request_requires_value(self) -> bool:
...


def dynamic_model_information_from_py_type(param_model: ParamModel, py_type: Type):
def dynamic_model_information_from_py_type(
param_model: ParamModel, py_type: Type, requires_value: Optional[bool] = None
):
name = param_model.name
initialize = ... if param_model.request_requires_value else None
if requires_value is None:
requires_value = param_model.request_requires_value
initialize = ... if requires_value else None
py_type_is_optional = is_optional(py_type)
if not py_type_is_optional and not param_model.request_requires_value:
if not py_type_is_optional and not requires_value:
validators = {"not_null": field_validator(name)(Validators.validate_not_none)}
else:
validators = {}
Expand Down Expand Up @@ -161,7 +176,10 @@ def py_type(self) -> Type:
return optional_if_needed(StrictStr, self.optional)

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
return dynamic_model_information_from_py_type(self, self.py_type)
py_type = self.py_type
if state_representation == "workflow_step_linked":
py_type = allow_connected_value(py_type)
return dynamic_model_information_from_py_type(self, py_type)

@property
def request_requires_value(self) -> bool:
Expand All @@ -180,7 +198,10 @@ def py_type(self) -> Type:
return optional_if_needed(StrictInt, self.optional)

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
return dynamic_model_information_from_py_type(self, self.py_type)
py_type = self.py_type
if state_representation == "workflow_step_linked":
py_type = allow_connected_value(py_type)
return dynamic_model_information_from_py_type(self, py_type)

@property
def request_requires_value(self) -> bool:
Expand All @@ -198,7 +219,10 @@ def py_type(self) -> Type:
return optional_if_needed(union_type([StrictInt, StrictFloat]), self.optional)

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
return dynamic_model_information_from_py_type(self, self.py_type)
py_type = self.py_type
if state_representation == "workflow_step_linked":
py_type = allow_connected_value(py_type)
return dynamic_model_information_from_py_type(self, py_type)

@property
def request_requires_value(self) -> bool:
Expand Down Expand Up @@ -302,6 +326,10 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam
return dynamic_model_information_from_py_type(self, self.py_type_internal)
elif state_representation == "test_case":
return dynamic_model_information_from_py_type(self, self.py_type_test_case)
elif state_representation == "workflow_step":
return dynamic_model_information_from_py_type(self, type(None), requires_value=False)
elif state_representation == "workflow_step_linked":
return dynamic_model_information_from_py_type(self, ConnectedValue)

@property
def request_requires_value(self) -> bool:
Expand Down Expand Up @@ -337,8 +365,14 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam
return allow_batching(dynamic_model_information_from_py_type(self, self.py_type))
elif state_representation == "request_internal":
return allow_batching(dynamic_model_information_from_py_type(self, self.py_type_internal))
elif state_representation == "workflow_step":
return dynamic_model_information_from_py_type(self, type(None), requires_value=False)
elif state_representation == "workflow_step_linked":
return dynamic_model_information_from_py_type(self, ConnectedValue)
else:
raise NotImplementedError("...")
raise NotImplementedError(
f"Have not implemented data collection parameter models for state representation {state_representation}"
)

@property
def request_requires_value(self) -> bool:
Expand All @@ -354,7 +388,15 @@ def py_type(self) -> Type:
return optional_if_needed(StrictStr, self.optional)

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
return dynamic_model_information_from_py_type(self, self.py_type)
py_type = self.py_type
requires_value = not self.optional and self.value is None
if state_representation == "workflow_step_linked":
py_type = allow_connected_value(py_type)
elif state_representation == "workflow_step" and not self.optional:
# allow it to be linked in so force allow optional...
py_type = optional(py_type)
requires_value = False
return dynamic_model_information_from_py_type(self, py_type, requires_value=requires_value)

@property
def request_requires_value(self) -> bool:
Expand Down Expand Up @@ -390,11 +432,34 @@ def validate_color_str(value) -> str:
ensure_color_valid(value)
return value

@staticmethod
def validate_color_str_if_value(value) -> str:
if value:
ensure_color_valid(value)
return value

@staticmethod
def validate_color_str_or_connected_value(value) -> str:
if not isinstance(value, ConnectedValue):
ensure_color_valid(value)
return value

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
validators = {"color_format": field_validator(self.name)(ColorParameterModel.validate_color_str)}
py_type = self.py_type
initialize: Any = ...
if state_representation == "workflow_step_linked":
py_type = allow_connected_value(py_type)
validators = {
"color_format": field_validator(self.name)(ColorParameterModel.validate_color_str_or_connected_value)
}
elif state_representation == "workflow_step":
initialize = None
validators = {"color_format": field_validator(self.name)(ColorParameterModel.validate_color_str_if_value)}
else:
validators = {"color_format": field_validator(self.name)(ColorParameterModel.validate_color_str)}
return DynamicModelInformation(
self.name,
(self.py_type, ...),
(py_type, initialize),
validators,
)

Expand All @@ -414,7 +479,10 @@ def py_type(self) -> Type:
return optional_if_needed(StrictBool, self.optional)

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
return dynamic_model_information_from_py_type(self, self.py_type)
py_type = self.py_type
if state_representation == "workflow_step_linked":
py_type = allow_connected_value(py_type)
return dynamic_model_information_from_py_type(self, py_type)

@property
def request_requires_value(self) -> bool:
Expand Down Expand Up @@ -468,8 +536,7 @@ class SelectParameterModel(BaseGalaxyToolParameterModelDefinition):
options: Optional[List[LabelValue]] = None
multiple: bool

@property
def py_type(self) -> Type:
def py_type_if_required(self, allow_connections=False) -> Type:
if self.options is not None:
if len(self.options) > 0:
literal_options: List[Type] = [cast_as_type(Literal[o.value]) for o in self.options]
Expand All @@ -479,11 +546,31 @@ def py_type(self) -> Type:
else:
py_type = StrictStr
if self.multiple:
py_type = list_type(py_type)
return optional_if_needed(py_type, self.optional)
if allow_connections:
py_type = list_type(allow_connected_value(py_type))
else:
py_type = list_type(py_type)
elif allow_connections:
py_type = allow_connected_value(py_type)
return py_type

@property
def py_type(self) -> Type:
return optional_if_needed(self.py_type_if_required(), self.optional)

@property
def py_type_workflow_step(self) -> Type:
# this is always optional in this context
return optional(self.py_type_if_required())

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
return dynamic_model_information_from_py_type(self, self.py_type)
if state_representation == "workflow_step":
return dynamic_model_information_from_py_type(self, self.py_type_workflow_step, requires_value=False)
elif state_representation == "workflow_step_linked":
py_type = self.py_type_if_required(allow_connections=True)
return dynamic_model_information_from_py_type(self, optional_if_needed(py_type, self.optional))
else:
return dynamic_model_information_from_py_type(self, self.py_type)

@property
def has_selected_static_option(self):
Expand Down Expand Up @@ -658,6 +745,20 @@ def request_requires_value(self) -> bool:
DiscriminatorType = Union[bool, str]


def cond_test_parameter_default_value(
test_parameter: Union["BooleanParameterModel", "SelectParameterModel"]
) -> Optional[DiscriminatorType]:
default_value: Optional[DiscriminatorType] = None
if isinstance(test_parameter, BooleanParameterModel):
default_value = test_parameter.value
elif isinstance(test_parameter, SelectParameterModel):
select_parameter = cast(SelectParameterModel, test_parameter)
select_default_value = select_parameter.default_value
if select_default_value is not None:
default_value = select_default_value
return default_value


class ConditionalWhen(StrictModel):
discriminator: DiscriminatorType
parameters: List["ToolParameterT"]
Expand Down Expand Up @@ -1068,6 +1169,14 @@ def create_test_case_model(tool: ToolParameterBundle, name: str = "DynamicModelF
return create_field_model(tool.input_models, name, "test_case")


def create_workflow_step_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "workflow_step")


def create_workflow_step_linked_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "workflow_step_linked")


def create_field_model(
tool_parameter_models: Union[List[ToolParameterModel], List[ToolParameterT]],
name: str,
Expand Down Expand Up @@ -1120,3 +1229,13 @@ def validate_internal_job(tool: ToolParameterBundle, request: Dict[str, Any]) ->
def validate_test_case(tool: ToolParameterBundle, request: Dict[str, Any]) -> None:
pydantic_model = create_test_case_model(tool)
validate_against_model(pydantic_model, request)


def validate_workflow_step(tool: ToolParameterBundle, request: Dict[str, Any]) -> None:
pydantic_model = create_workflow_step_model(tool)
validate_against_model(pydantic_model, request)


def validate_workflow_step_linked(tool: ToolParameterBundle, request: Dict[str, Any]) -> None:
pydantic_model = create_workflow_step_linked_model(tool)
validate_against_model(pydantic_model, request)
Loading

0 comments on commit 5efb22b

Please sign in to comment.