diff --git a/temporalio/testing/_workflow.py b/temporalio/testing/_workflow.py index b63e0b57..3015a86d 100644 --- a/temporalio/testing/_workflow.py +++ b/temporalio/testing/_workflow.py @@ -562,5 +562,5 @@ def _client_with_interceptors( config = client.config() config_interceptors = list(config["interceptors"]) config_interceptors.extend(interceptors) - config["interceptors"] = interceptors + config["interceptors"] = config_interceptors return temporalio.client.Client(**config) diff --git a/tests/testing/test_workflow.py b/tests/testing/test_workflow.py index 2dc0eb8c..e0c1117b 100644 --- a/tests/testing/test_workflow.py +++ b/tests/testing/test_workflow.py @@ -3,12 +3,19 @@ import uuid from datetime import datetime, timedelta, timezone from time import monotonic -from typing import Optional, Union +from typing import Any, List, Optional, Union import pytest from temporalio import activity, workflow -from temporalio.client import Client, WorkflowFailureError +from temporalio.client import ( + Client, + Interceptor, + OutboundInterceptor, + StartWorkflowInput, + WorkflowFailureError, + WorkflowHandle, +) from temporalio.common import RetryPolicy from temporalio.exceptions import ( ActivityError, @@ -176,7 +183,36 @@ def some_signal(self) -> None: assert "foo" == "bar" +class SimpleClientInterceptor(Interceptor): + def __init__(self) -> None: + self.events: List[str] = [] + + def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: + return SimpleClientOutboundInterceptor(self, super().intercept_client(next)) + + +class SimpleClientOutboundInterceptor(OutboundInterceptor): + def __init__( + self, root: SimpleClientInterceptor, next: OutboundInterceptor + ) -> None: + super().__init__(next) + self.root = root + + async def start_workflow( + self, input: StartWorkflowInput + ) -> WorkflowHandle[Any, Any]: + self.root.events.append(f"start: {input.workflow}") + return await super().start_workflow(input) + + async def test_workflow_env_assert(client: Client): + # Set the interceptor on the client. This used to fail for being + # accidentally overridden. + client_config = client.config() + interceptor = SimpleClientInterceptor() + client_config["interceptors"] = [interceptor] + client = Client(**client_config) + def assert_proper_error(err: Optional[BaseException]) -> None: assert isinstance(err, ApplicationError) # In unsandboxed workflows, this message has extra diff info appended @@ -195,6 +231,7 @@ def assert_proper_error(err: Optional[BaseException]) -> None: task_queue=worker.task_queue, ) assert_proper_error(err.value.cause) + assert interceptor.events # Start a new one and check signal handle = await env.client.start_workflow(