Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass all variables to task when TaskConfig.variables_to_fetch==[] #485

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions pyzeebe/task/task_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import functools
import inspect
import logging
from typing import Any, Sequence, Tuple, TypeVar
from inspect import Parameter
from typing import Any, Callable, Dict, Sequence, Tuple, TypeVar

from typing_extensions import ParamSpec

Expand Down Expand Up @@ -66,13 +68,19 @@ async def run_original_task_function(
) -> Tuple[Variables, bool]:
try:
if task_config.variables_to_fetch is None:
variables = {}
variables: Dict[str, Any] = {}
elif task_wants_all_variables(task_config):
if only_job_is_required_in_task_function(task_function):
variables = {}
else:
variables = {**job.variables}
else:
variables = {
k: v
for k, v in job.variables.items()
if k in task_config.variables_to_fetch or k == task_config.job_parameter_name
}

if task_config.job_parameter_name:
variables[task_config.job_parameter_name] = job

Expand All @@ -89,6 +97,15 @@ async def run_original_task_function(
return job.variables, False


def only_job_is_required_in_task_function(task_function: DictFunction[...]) -> bool:
function_signature = inspect.signature(task_function)
return all(param.annotation == Job for param in function_signature.parameters.values())


def task_wants_all_variables(task_config: TaskConfig) -> bool:
return task_config.variables_to_fetch == []


def create_decorator_runner(decorators: Sequence[AsyncTaskDecorator]) -> DecoratorRunner:
async def decorator_runner(job: Job) -> Job:
for decorator in decorators:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def first_active_job(task, job_from_task, grpc_servicer) -> str:


@pytest.fixture
def task_config(task_type):
def task_config(task_type, variables_to_fetch=None):
return TaskConfig(
type=task_type,
exception_handler=AsyncMock(),
timeout_ms=10000,
max_jobs_to_activate=32,
max_running_jobs=32,
variables_to_fetch=[],
variables_to_fetch=variables_to_fetch or [],
single_value=False,
variable_name="",
before=[],
Expand Down
52 changes: 44 additions & 8 deletions tests/unit/task/task_builder_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import copy
from typing import Callable
from unittest.mock import MagicMock

import pytest

from pyzeebe import Job, JobController, TaskDecorator
from pyzeebe.function_tools.parameter_tools import get_parameters_from_function
from pyzeebe.job.job_status import JobStatus
from pyzeebe.task import task_builder
from pyzeebe.task.task import Task
Expand Down Expand Up @@ -108,29 +110,41 @@ async def test_parameters_are_provided_to_task(
async def test_parameters_are_provided_to_task_with_only_job(
self, original_task_function: Callable, task_config: TaskConfig, mocked_job_controller: JobController
):
task_with_only_job_called = False

def task_with_only_job(job: Job):
nonlocal task_with_only_job_called
task_with_only_job_called = True

job = random_job(variables={"x": 1})
task_config.job_parameter_name = "job"
task_config.variables_to_fetch = []

job_handler = task_builder.build_job_handler(original_task_function, task_config)
job_handler = task_builder.build_job_handler(task_with_only_job, task_config)

await job_handler(job, mocked_job_controller)

original_task_function.assert_called_with(job=job)
assert task_with_only_job_called, "Task was called ok"

@pytest.mark.asyncio
async def test_parameters_are_provided_to_task_with_arg_and_job(
self, original_task_function: Callable, task_config: TaskConfig, mocked_job_controller: JobController
self, task_config: TaskConfig, mocked_job_controller: JobController
):
call_params = None

def task_with_arg_and_job(job: Job, x):
nonlocal call_params
call_params = {"job": job, "x": x}

job = random_job(variables={"x": 1})
task_config.job_parameter_name = "job"
task_config.variables_to_fetch = ["x"]

job_handler = task_builder.build_job_handler(original_task_function, task_config)
job_handler = task_builder.build_job_handler(task_with_arg_and_job, task_config)

await job_handler(job, mocked_job_controller)

original_task_function.assert_called_with(job=job, x=1)
assert call_params == {"job": job, "x": 1}

@pytest.mark.asyncio
async def test_variables_are_added_to_result(
Expand Down Expand Up @@ -175,6 +189,27 @@ async def test_returned_task_runs_original_function(

original_task_function.assert_called_once()

@pytest.mark.asyncio
async def test_empty_variables_to_fetch_results_in_all_vars_passed_to_task(
self, mocked_job_controller: JobController, task_config: TaskConfig
):
task_config.variables_to_fetch = []
task_config.job_parameter_name = "job"

call_params = None

def task_with_params(job: Job, *args, **kwargs):
nonlocal call_params
call_params = {"job": job, "args": args, "kwargs": kwargs}

job = random_job(variables={"a": 1, "b": 2})

job_handler = task_builder.build_job_handler(task_with_params, task_config)

await job_handler(job, mocked_job_controller)

assert call_params == {"job": job, "args": (), "kwargs": {"a": 1, "b": 2}}

@pytest.mark.asyncio
async def test_before_decorator_called(
self,
Expand Down Expand Up @@ -217,18 +252,19 @@ async def test_after_decorator_can_access_task_result(
async def task_function():
return {"result": 1}

self.task_result = dict()
task_result = dict()

async def after_decorator(job: Job):
self.task_result = job.task_result
nonlocal task_result
task_result = job.task_result
return job

task_config.after.append(after_decorator)
job_handler = task_builder.build_job_handler(task_function, task_config)

await job_handler(job, mocked_job_controller)

assert self.task_result == {"result": 1}
assert task_result == {"result": 1}

@pytest.mark.asyncio
async def test_failing_decorator_continues(
Expand Down