Skip to content

Commit

Permalink
Merge pull request #485 from verotel/task_all_params_fix
Browse files Browse the repository at this point in the history
Pass all variables to task when `TaskConfig.variables_to_fetch==[]`
  • Loading branch information
dimastbk authored Sep 26, 2024
2 parents 1b3e48d + e3a88da commit 614559d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 12 deletions.
21 changes: 19 additions & 2 deletions pyzeebe/task/task_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import functools
import inspect
import logging
from typing import Any, Sequence, Tuple, TypeVar
from inspect import Parameter
from typing import Any, Callable, Dict, Sequence, Tuple, TypeVar

from typing_extensions import ParamSpec

Expand Down Expand Up @@ -66,13 +68,19 @@ async def run_original_task_function(
) -> Tuple[Variables, bool]:
try:
if task_config.variables_to_fetch is None:
variables = {}
variables: Dict[str, Any] = {}
elif task_wants_all_variables(task_config):
if only_job_is_required_in_task_function(task_function):
variables = {}
else:
variables = {**job.variables}
else:
variables = {
k: v
for k, v in job.variables.items()
if k in task_config.variables_to_fetch or k == task_config.job_parameter_name
}

if task_config.job_parameter_name:
variables[task_config.job_parameter_name] = job

Expand All @@ -89,6 +97,15 @@ async def run_original_task_function(
return job.variables, False


def only_job_is_required_in_task_function(task_function: DictFunction[...]) -> bool:
function_signature = inspect.signature(task_function)
return all(param.annotation == Job for param in function_signature.parameters.values())


def task_wants_all_variables(task_config: TaskConfig) -> bool:
return task_config.variables_to_fetch == []


def create_decorator_runner(decorators: Sequence[AsyncTaskDecorator]) -> DecoratorRunner:
async def decorator_runner(job: Job) -> Job:
for decorator in decorators:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def first_active_job(task, job_from_task, grpc_servicer) -> str:


@pytest.fixture
def task_config(task_type):
def task_config(task_type, variables_to_fetch=None):
return TaskConfig(
type=task_type,
exception_handler=AsyncMock(),
timeout_ms=10000,
max_jobs_to_activate=32,
max_running_jobs=32,
variables_to_fetch=[],
variables_to_fetch=variables_to_fetch or [],
single_value=False,
variable_name="",
before=[],
Expand Down
52 changes: 44 additions & 8 deletions tests/unit/task/task_builder_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import copy
from typing import Callable
from unittest.mock import MagicMock

import pytest

from pyzeebe import Job, JobController, TaskDecorator
from pyzeebe.function_tools.parameter_tools import get_parameters_from_function
from pyzeebe.job.job_status import JobStatus
from pyzeebe.task import task_builder
from pyzeebe.task.task import Task
Expand Down Expand Up @@ -108,29 +110,41 @@ async def test_parameters_are_provided_to_task(
async def test_parameters_are_provided_to_task_with_only_job(
self, original_task_function: Callable, task_config: TaskConfig, mocked_job_controller: JobController
):
task_with_only_job_called = False

def task_with_only_job(job: Job):
nonlocal task_with_only_job_called
task_with_only_job_called = True

job = random_job(variables={"x": 1})
task_config.job_parameter_name = "job"
task_config.variables_to_fetch = []

job_handler = task_builder.build_job_handler(original_task_function, task_config)
job_handler = task_builder.build_job_handler(task_with_only_job, task_config)

await job_handler(job, mocked_job_controller)

original_task_function.assert_called_with(job=job)
assert task_with_only_job_called, "Task was called ok"

@pytest.mark.asyncio
async def test_parameters_are_provided_to_task_with_arg_and_job(
self, original_task_function: Callable, task_config: TaskConfig, mocked_job_controller: JobController
self, task_config: TaskConfig, mocked_job_controller: JobController
):
call_params = None

def task_with_arg_and_job(job: Job, x):
nonlocal call_params
call_params = {"job": job, "x": x}

job = random_job(variables={"x": 1})
task_config.job_parameter_name = "job"
task_config.variables_to_fetch = ["x"]

job_handler = task_builder.build_job_handler(original_task_function, task_config)
job_handler = task_builder.build_job_handler(task_with_arg_and_job, task_config)

await job_handler(job, mocked_job_controller)

original_task_function.assert_called_with(job=job, x=1)
assert call_params == {"job": job, "x": 1}

@pytest.mark.asyncio
async def test_variables_are_added_to_result(
Expand Down Expand Up @@ -175,6 +189,27 @@ async def test_returned_task_runs_original_function(

original_task_function.assert_called_once()

@pytest.mark.asyncio
async def test_empty_variables_to_fetch_results_in_all_vars_passed_to_task(
self, mocked_job_controller: JobController, task_config: TaskConfig
):
task_config.variables_to_fetch = []
task_config.job_parameter_name = "job"

call_params = None

def task_with_params(job: Job, *args, **kwargs):
nonlocal call_params
call_params = {"job": job, "args": args, "kwargs": kwargs}

job = random_job(variables={"a": 1, "b": 2})

job_handler = task_builder.build_job_handler(task_with_params, task_config)

await job_handler(job, mocked_job_controller)

assert call_params == {"job": job, "args": (), "kwargs": {"a": 1, "b": 2}}

@pytest.mark.asyncio
async def test_before_decorator_called(
self,
Expand Down Expand Up @@ -217,18 +252,19 @@ async def test_after_decorator_can_access_task_result(
async def task_function():
return {"result": 1}

self.task_result = dict()
task_result = dict()

async def after_decorator(job: Job):
self.task_result = job.task_result
nonlocal task_result
task_result = job.task_result
return job

task_config.after.append(after_decorator)
job_handler = task_builder.build_job_handler(task_function, task_config)

await job_handler(job, mocked_job_controller)

assert self.task_result == {"result": 1}
assert task_result == {"result": 1}

@pytest.mark.asyncio
async def test_failing_decorator_continues(
Expand Down

0 comments on commit 614559d

Please sign in to comment.