Workflow init #634

Closed

wants to merge 9 commits
57 changes: 40 additions & 17 deletions temporalio/worker/_workflow_instance.py
@@ -201,6 +201,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
self._payload_converter = det.payload_converter_class()
self._failure_converter = det.failure_converter_class()
self._defn = det.defn
+ self._workflow_input: Optional[ExecuteWorkflowInput] = None
self._info = det.info
self._extern_functions = det.extern_functions
self._disable_eager_activity_execution = det.disable_eager_activity_execution
@@ -318,8 +319,9 @@ def get_thread_id(self) -> Optional[int]:
return self._current_thread_id

#### Activation functions ####
- # These are in alphabetical order and besides "activate", all other calls
- # are "_apply_" + the job field name.
+ # These are in alphabetical order and besides "activate", and
+ # "_make_workflow_input", all other calls are "_apply_" + the job field
+ # name.

def activate(
self, act: temporalio.bridge.proto.workflow_activation.WorkflowActivation
@@ -342,6 +344,7 @@ def activate(
try:
# Split into job sets with patches, then signals + updates, then
# non-queries, then queries
+ start_job = None
job_sets: List[
List[temporalio.bridge.proto.workflow_activation.WorkflowActivationJob]
] = [[], [], [], []]
@@ -351,10 +354,15 @@
elif job.HasField("signal_workflow") or job.HasField("do_update"):
job_sets[1].append(job)
elif not job.HasField("query_workflow"):
if job.HasField("start_workflow"):
start_job = job.start_workflow
job_sets[2].append(job)
else:
job_sets[3].append(job)

+ if start_job:
+ self._workflow_input = self._make_workflow_input(start_job)
+
# Apply every job set, running after each set
for index, job_set in enumerate(job_sets):
if not job_set:
@@ -863,34 +871,41 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None:
return
raise

+ if not self._workflow_input:
+ raise RuntimeError(
+ "Expected workflow input to be set. This is an SDK Python bug."
+ )
+ self._primary_task = self.create_task(
+ self._run_top_level_workflow_function(run_workflow(self._workflow_input)),
+ name="run",
+ )
+
+ def _apply_update_random_seed(
+ self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
+ ) -> None:
+ self._random.seed(job.randomness_seed)
+
+ def _make_workflow_input(
+ self, start_job: temporalio.bridge.proto.workflow_activation.StartWorkflow
+ ) -> ExecuteWorkflowInput:
# Set arg types, using raw values for dynamic
arg_types = self._defn.arg_types
if not self._defn.name:
# Dynamic is just the raw value for each input value
- arg_types = [temporalio.common.RawValue] * len(job.arguments)
- args = self._convert_payloads(job.arguments, arg_types)
+ arg_types = [temporalio.common.RawValue] * len(start_job.arguments)
+ args = self._convert_payloads(start_job.arguments, arg_types)
# Put args in a list if dynamic
if not self._defn.name:
args = [args]

- # Schedule it
- input = ExecuteWorkflowInput(
+ return ExecuteWorkflowInput(
type=self._defn.cls,
# TODO(cretz): Remove cast when https://github.com/python/mypy/issues/5485 fixed
run_fn=cast(Callable[..., Awaitable[Any]], self._defn.run_fn),
args=args,
- headers=job.headers,
- )
- self._primary_task = self.create_task(
- self._run_top_level_workflow_function(run_workflow(input)),
- name="run",
+ headers=start_job.headers,
)

- def _apply_update_random_seed(
- self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
- ) -> None:
- self._random.seed(job.randomness_seed)

#### _Runtime direct workflow call overrides ####
# These are in alphabetical order and all start with "workflow_".

@@ -1617,6 +1632,14 @@ def _convert_payloads(
except Exception as err:
raise RuntimeError("Failed decoding arguments") from err

+ def _instantiate_workflow_object(self) -> Any:
+ if not self._workflow_input:
+ raise RuntimeError("Expected workflow input. This is a Python SDK bug.")
+ if hasattr(self._defn.cls.__init__, "__temporal_workflow_init"):
+ return self._defn.cls(*self._workflow_input.args)
+ else:
+ return self._defn.cls()

def _is_workflow_failure_exception(self, err: BaseException) -> bool:
# An exception is a failure instead of a task fail if it's already a
# failure error or if it is an instance of any of the failure types in
@@ -1752,7 +1775,7 @@ def _run_once(self, *, check_conditions: bool) -> None:
# We instantiate the workflow class _inside_ here because __init__
# needs to run with this event loop set
if not self._object:
- self._object = self._defn.cls()
+ self._object = self._instantiate_workflow_object()

# Run while there is anything ready
while self._ready:
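As a standalone illustration of the rule that the new _instantiate_workflow_object helper applies above (the function and class names below are invented for this sketch and are not part of the SDK or this PR): the worker forwards the workflow input to __init__ positionally only when __init__ carries the __temporal_workflow_init marker that @workflow.init sets.

# Illustrative sketch only, not SDK code.
from typing import Any, Sequence, Type


def instantiate_like_worker(cls: Type[Any], workflow_args: Sequence[Any]) -> Any:
    # Pass the workflow input to __init__ only when it was marked by the decorator.
    if hasattr(cls.__init__, "__temporal_workflow_init"):
        return cls(*workflow_args)
    return cls()


class WithInit:
    def __init__(self, name: str) -> None:
        self.name = name

    # Stand-in for what @workflow.init does: mark __init__ as accepting input.
    __init__.__temporal_workflow_init = True  # type: ignore[attr-defined]


class WithoutInit:
    def __init__(self) -> None:
        self.name = "default"


assert instantiate_like_worker(WithInit, ["world"]).name == "world"
assert instantiate_like_worker(WithoutInit, ["world"]).name == "default"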
45 changes: 36 additions & 9 deletions temporalio/workflow.py
@@ -143,10 +143,38 @@ def decorator(cls: ClassType) -> ClassType:
return decorator


+ def init(
+ init_fn: CallableType,
+ ) -> CallableType:
+ """Decorator for the workflow init method.
+
+ This may be used on the __init__ method of the workflow class to specify
+ that it accepts the same workflow input arguments as the ``@workflow.run``
+ method. It may not be used on any other method.
+
+ If used, the workflow will be instantiated as
+ ``MyWorkflow(**workflow_input_args)``. If not used, the workflow will be
+ instantiated as ``MyWorkflow()``.
+
+ Note that the ``@workflow.run`` method is always called as
+ ``my_workflow.my_run_method(**workflow_input_args)``. If you use the
+ ``@workflow.init`` decorator, your __init__ method and your
+ ``@workflow.run`` method will typically have exactly the same parameters.
+
+ Args:
+ init_fn: The __init__ function to decorate.
+ """
+ if init_fn.__name__ != "__init__":
+ raise ValueError("@workflow.init may only be used on the __init__ method")
+
+ setattr(init_fn, "__temporal_workflow_init", True)
+ return init_fn


def run(fn: CallableAsyncType) -> CallableAsyncType:
"""Decorator for the workflow run method.

- This must be set on one and only one async method defined on the same class
+ This must be used on one and only one async method defined on the same class
as ``@workflow.defn``. This can be defined on a base class method but must
then be explicitly overridden and defined on the workflow class.

@@ -238,7 +266,7 @@ def signal(
):
"""Decorator for a workflow signal method.

- This is set on any async or non-async method that you wish to be called upon
+ This is used on any async or non-async method that you wish to be called upon
receiving a signal. If a function overrides one with this decorator, it too
must be decorated.

@@ -309,7 +337,7 @@ def query(
):
"""Decorator for a workflow query method.

- This is set on any non-async method that expects to handle a query. If a
+ This is used on any non-async method that expects to handle a query. If a
function overrides one with this decorator, it too must be decorated.

Query methods can only have positional parameters. Best practice for
@@ -983,7 +1011,7 @@ def update(
):
"""Decorator for a workflow update handler method.

- This is set on any async or non-async method that you wish to be called upon
+ This is used on any async or non-async method that you wish to be called upon
receiving an update. If a function overrides one with this decorator, it too
must be decorated.

@@ -1307,13 +1335,12 @@ def _apply_to_class(
issues: List[str] = []

# Collect run fn and all signal/query/update fns
- members = inspect.getmembers(cls)
run_fn: Optional[Callable[..., Awaitable[Any]]] = None
seen_run_attr = False
signals: Dict[Optional[str], _SignalDefinition] = {}
queries: Dict[Optional[str], _QueryDefinition] = {}
updates: Dict[Optional[str], _UpdateDefinition] = {}
- for name, member in members:
+ for name, member in inspect.getmembers(cls):
if hasattr(member, "__temporal_workflow_run"):
seen_run_attr = True
if not _is_unbound_method_on_cls(member, cls):
@@ -1406,9 +1433,9 @@ def _apply_to_class(

if not seen_run_attr:
issues.append("Missing @workflow.run method")
- if len(issues) == 1:
- raise ValueError(f"Invalid workflow class: {issues[0]}")
- elif issues:
+ if issues:
+ if len(issues) == 1:
+ raise ValueError(f"Invalid workflow class: {issues[0]}")
raise ValueError(
f"Invalid workflow class for {len(issues)} reasons: {', '.join(issues)}"
)
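For context on how the new decorator is meant to be used, here is a hedged usage sketch; the GreetingWorkflow class and its methods are hypothetical and not part of this PR. With @workflow.init, __init__ is declared with the same parameters as the @workflow.run method and receives the same workflow input when the worker constructs the instance; without the decorator, the class is constructed with no arguments.

from temporalio import workflow


@workflow.defn
class GreetingWorkflow:
    @workflow.init
    def __init__(self, name: str) -> None:
        # Because of @workflow.init, the worker calls GreetingWorkflow(name)
        # with the same input argument that run() receives; without the
        # decorator it would call GreetingWorkflow() with no arguments.
        self.greeting = f"Hello, {name}!"

    @workflow.run
    async def run(self, name: str) -> str:
        # The run method always receives the workflow input arguments.
        return self.greeting

    @workflow.query
    def current_greeting(self) -> str:
        # Handlers can rely on state initialized from the workflow input.
        return self.greeting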