Skip to content

Commit

Permalink
Implement async steps.
Browse files Browse the repository at this point in the history
  • Loading branch information
maafy6 committed Apr 16, 2023
1 parent 8fd0fd5 commit c292b9a
Show file tree
Hide file tree
Showing 5 changed files with 770 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/pytest_bdd/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pytest_bdd.steps import async_given, async_then, async_when

__all__ = ["async_given", "async_when", "async_then"]
80 changes: 78 additions & 2 deletions src/pytest_bdd/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
"""
from __future__ import annotations

import asyncio
import contextlib
import functools
import inspect
import logging
import os
import re
Expand All @@ -34,7 +37,6 @@

from .parser import Feature, Scenario, ScenarioTemplate, Step


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -156,7 +158,14 @@ def _execute_step_function(

request.config.hook.pytest_bdd_before_step_call(**kw)
# Execute the step as if it was a pytest fixture, so that we can allow "yield" statements in it
return_value = call_fixture_func(fixturefunc=context.step_func, request=request, kwargs=kwargs)
step_func = context.step_func
if context.is_async:
if inspect.isasyncgenfunction(context.step_func):
step_func = _wrap_asyncgen(request, context.step_func)
elif inspect.iscoroutinefunction(context.step_func):
step_func = _wrap_coroutine(context.step_func)

return_value = call_fixture_func(fixturefunc=step_func, request=request, kwargs=kwargs)
except Exception as exception:
request.config.hook.pytest_bdd_step_error(exception=exception, **kw)
raise
Expand All @@ -167,6 +176,73 @@ def _execute_step_function(
request.config.hook.pytest_bdd_after_step(**kw)


def _wrap_asyncgen(request: FixtureRequest, func: Callable) -> Callable:
"""Wrapper for an async_generator function.
This will wrap the function in a synchronized method to return the first
yielded value from the generator. A finalizer will be added to the fixture
to ensure that no other values are yielded and that the loop is closed.
:param request: The fixture request.
:param func: The function to wrap.
:returns: The wrapped function.
"""

@functools.wraps(func)
def _wrapper(*args, **kwargs):
try:
loop, created = asyncio.get_running_loop(), False
except RuntimeError:
loop, created = asyncio.get_event_loop_policy().new_event_loop(), True

async_obj = func(*args, **kwargs)

def _finalizer() -> None:
"""Ensure no more values are yielded and close the loop."""
try:
loop.run_until_complete(async_obj.__anext__())
except StopAsyncIteration:
pass
else:
raise ValueError("Async generator must only yield once.")

if created:
loop.close()

value = loop.run_until_complete(async_obj.__anext__())
request.addfinalizer(_finalizer)

return value

return _wrapper


def _wrap_coroutine(func: Callable) -> Callable:
"""Wrapper for a coroutine function.
:param func: The function to wrap.
:returns: The wrapped function.
"""

@functools.wraps(func)
def _wrapper(*args, **kwargs):
try:
loop, created = asyncio.get_running_loop(), False
except RuntimeError:
loop, created = asyncio.get_event_loop_policy().new_event_loop(), True

try:
async_obj = func(*args, **kwargs)
return loop.run_until_complete(async_obj)
finally:
if created:
loop.close()

return _wrapper


def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequest) -> None:
"""Execute the scenario.
Expand Down
80 changes: 77 additions & 3 deletions src/pytest_bdd/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class StepFunctionContext:
parser: StepParser
converters: dict[str, Callable[..., Any]] = field(default_factory=dict)
target_fixture: str | None = None
is_async: bool = False


def get_step_fixture_name(step: Step) -> str:
Expand All @@ -78,6 +79,7 @@ def given(
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
is_async: bool = False,
) -> Callable:
"""Given step decorator.
Expand All @@ -86,17 +88,62 @@ def given(
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:param is_async: True if the step is asynchronous. (Default: False)
:return: Decorator function for the step.
"""
return step(name, GIVEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return step(
name, GIVEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
)


def async_given(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
"""Async Given step decorator.
:param name: Step name or a parser object.
:param converters: Optional `dict` of the argument or parameter converters in form
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:return: Decorator function for the step.
"""
return given(name, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)


def when(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
is_async: bool = False,
) -> Callable:
"""When step decorator.
:param name: Step name or a parser object.
:param converters: Optional `dict` of the argument or parameter converters in form
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:param is_async: True if the step is asynchronous. (Default: False)
:return: Decorator function for the step.
"""
return step(
name, WHEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
)


def async_when(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
"""When step decorator.
Expand All @@ -108,14 +155,15 @@ def when(
:return: Decorator function for the step.
"""
return step(name, WHEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return when(name, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)


def then(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
is_async: bool = False,
) -> Callable:
"""Then step decorator.
Expand All @@ -124,10 +172,32 @@ def then(
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:param is_async: True if the step is asynchronous. (Default: False)
:return: Decorator function for the step.
"""
return step(name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return step(
name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
)


def async_then(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
"""Then step decorator.
:param name: Step name or a parser object.
:param converters: Optional `dict` of the argument or parameter converters in form
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:return: Decorator function for the step.
"""
return step(name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)


def step(
Expand All @@ -136,6 +206,7 @@ def step(
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
is_async: bool = False,
) -> Callable[[TCallable], TCallable]:
"""Generic step decorator.
Expand All @@ -144,6 +215,7 @@ def step(
:param converters: Optional step arguments converters mapping.
:param target_fixture: Optional fixture name to replace by step definition.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:param is_async: True if the step is asynchronous. (Default: False)
:return: Decorator function for the step.
Expand All @@ -165,6 +237,7 @@ def decorator(func: TCallable) -> TCallable:
parser=parser,
converters=converters,
target_fixture=target_fixture,
is_async=is_async,
)

def step_function_marker() -> StepFunctionContext:
Expand All @@ -177,6 +250,7 @@ def step_function_marker() -> StepFunctionContext:
f"{StepNamePrefix.step_def.value}_{type_ or '*'}_{parser.name}", seen=caller_locals.keys()
)
caller_locals[fixture_step_name] = pytest.fixture(name=fixture_step_name)(step_function_marker)

return func

return decorator
Expand Down
Loading

0 comments on commit c292b9a

Please sign in to comment.