From 14d9f716ce8905ce01308ea740d68f1176959226 Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Thu, 24 Aug 2023 10:17:28 -0500 Subject: [PATCH] (TODO: fix client code, only halfway done) Work on schedule action search attributes --- temporalio/client.py | 93 +++++++++++++++++--------- temporalio/common.py | 26 ++++++-- temporalio/converter.py | 52 ++++++++------- tests/helpers/__init__.py | 25 +++++++ tests/test_client.py | 122 +++++++++++++++++++++++++++++++++- tests/worker/test_workflow.py | 44 +++--------- 6 files changed, 264 insertions(+), 98 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 780e0d69..4d06d41b 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -2757,11 +2757,27 @@ class ScheduleActionStartWorkflow(ScheduleAction): """Search attributes for the schedule. .. deprecated:: - Use :py:attr:`typed_search_attributes` instead. + Use :py:attr:`typed_search_attributes` instead when creating a new + schedule. + + .. warning:: + During update, if :py:attr:`typed_search_attributes` is present, it will + be used and this will be ignored. This means once users begin using + :py:attr:`typed_search_attributes` they should not go back to untyped. """ - typed_search_attributes: temporalio.common.TypedSearchAttributes - """Search attributes for the schedule.""" + typed_search_attributes: Optional[temporalio.common.TypedSearchAttributes] + """Search attributes for the schedule. + + This will only be set if there is at least one search attribute and this was + previously used on create/update before. + + .. warning:: + If :py:attr:`search_attributes` was used before, this value may not have + all search attributes. Users should either update and replace with this + :py:attr:`typed_search_attributes` or continue to use + :py:attr:`search_attributes`. + """ headers: Optional[Mapping[str, temporalio.api.common.v1.Payload]] = None @@ -2782,9 +2798,8 @@ def __init__( task_timeout: Optional[timedelta] = None, retry_policy: Optional[temporalio.common.RetryPolicy] = None, memo: Optional[Mapping[str, Any]] = None, - search_attributes: Union[ - temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes - ] = temporalio.common.TypedSearchAttributes.empty, + search_attributes: temporalio.common.SearchAttributes = {}, + typed_search_attributes: Optional[temporalio.common.TypedSearchAttributes] = None, ) -> None: ... @@ -2802,9 +2817,8 @@ def __init__( task_timeout: Optional[timedelta] = None, retry_policy: Optional[temporalio.common.RetryPolicy] = None, memo: Optional[Mapping[str, Any]] = None, - search_attributes: Union[ - temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes - ] = temporalio.common.TypedSearchAttributes.empty, + search_attributes: temporalio.common.SearchAttributes = {}, + typed_search_attributes: Optional[temporalio.common.TypedSearchAttributes] = None, ) -> None: ... @@ -2824,9 +2838,8 @@ def __init__( task_timeout: Optional[timedelta] = None, retry_policy: Optional[temporalio.common.RetryPolicy] = None, memo: Optional[Mapping[str, Any]] = None, - search_attributes: Union[ - temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes - ] = temporalio.common.TypedSearchAttributes.empty, + search_attributes: temporalio.common.SearchAttributes = {}, + typed_search_attributes: Optional[temporalio.common.TypedSearchAttributes] = None, ) -> None: ... @@ -2845,9 +2858,8 @@ def __init__( task_timeout: Optional[timedelta] = None, retry_policy: Optional[temporalio.common.RetryPolicy] = None, memo: Optional[Mapping[str, Any]] = None, - search_attributes: Union[ - temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes - ] = temporalio.common.TypedSearchAttributes.empty, + search_attributes: temporalio.common.SearchAttributes = {}, + typed_search_attributes: temporalio.common.TypedSearchAttributes = temporalio.common.TypedSearchAttributes.empty, ) -> None: ... @@ -2874,9 +2886,9 @@ def __init__( task_timeout: Optional[timedelta] = None, retry_policy: Optional[temporalio.common.RetryPolicy] = None, memo: Optional[Mapping[str, Any]] = None, - search_attributes: Union[ - temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes - ] = temporalio.common.TypedSearchAttributes.empty, + search_attributes: temporalio.common.SearchAttributes = {}, + typed_search_attributes: temporalio.common.TypedSearchAttributes = temporalio.common.TypedSearchAttributes.empty, + headers: Optional[Mapping[str, temporalio.api.common.v1.Payload]] = None, raw_info: Optional[temporalio.api.workflow.v1.NewWorkflowExecutionInfo] = None, ) -> None: """Create a start-workflow action. @@ -2920,6 +2932,7 @@ def __init__( raw_info.search_attributes ) ) + self.headers = raw_info.header.fields else: if not id: raise ValueError("ID required") @@ -2942,15 +2955,9 @@ def __init__( self.task_timeout = task_timeout self.retry_policy = retry_policy self.memo = memo - temporalio.common._warn_on_deprecated_search_attributes(search_attributes) - if isinstance(search_attributes, temporalio.common.TypedSearchAttributes): - self.search_attributes = {} - self.typed_search_attributes = search_attributes - else: - self.search_attributes = search_attributes - self.typed_search_attributes = ( - temporalio.common.TypedSearchAttributes.empty - ) + self.search_attributes = search_attributes + self.typed_search_attributes = typed_search_attributes + self.headers = headers async def _to_proto( self, client: Client @@ -3005,11 +3012,33 @@ async def _to_proto( else temporalio.api.common.v1.Header(fields=self.headers), ), ) - temporalio.converter._encode_maybe_typed_search_attributes( - self.search_attributes, - self.typed_search_attributes, - action.start_workflow.search_attributes, - ) + + # If typed is present, we have to use it + + + # Must unset unmodified attr on untyped SA if present + if "__temporal_search_attributes_unmodified" in self.search_attributes: + del self.search_attributes["__temporal_search_attributes_unmodified"] + + # If this is for update, if typed search attributes are modified, use + # them + + # If typed is unmodified, it should not be included in the update. But + # if it is present with values, it needs to overwrite search attributes + # blindly. We have no way to check whether untyped was unmodified + # because we can't store an arbitrary attr on a dict that's not a key. + # So we have to accept that both present means typed wins since typed + # may have been updated and there is no check for accidentally + # explicitly setting both at this time. This tradeoff is preferred over + # the complexity of diffing. This also means that typed search + # attributes cannot be used to unset all attributes. + if getattr(self.typed_search_attributes, "__temporal_search_attributes_unmodified", False) or not self.typed_search_attributes or not self.typed_search_attributes.search_attributes: + print("!!! PERSISTING: ", self.search_attributes) + temporalio.converter.encode_search_attributes(self.search_attributes, action.start_workflow.search_attributes) + else: + temporalio.converter.encode_search_attributes(self.typed_search_attributes, action.start_workflow.search_attributes) + if self.headers: + temporalio.common._apply_headers(self.headers, action.start_workflow.header.fields) return action diff --git a/temporalio/common.py b/temporalio/common.py index d92de25a..599ebaf9 100644 --- a/temporalio/common.py +++ b/temporalio/common.py @@ -345,17 +345,17 @@ def _guess_from_untyped_values( return None elif len(vals) > 1: if isinstance(vals[0], str): - return temporalio.common.SearchAttributeKey.for_keyword_list(name) + return SearchAttributeKey.for_keyword_list(name) elif isinstance(vals[0], str): - return temporalio.common.SearchAttributeKey.for_keyword(name) + return SearchAttributeKey.for_keyword(name) elif isinstance(vals[0], int): - return temporalio.common.SearchAttributeKey.for_int(name) + return SearchAttributeKey.for_int(name) elif isinstance(vals[0], float): - return temporalio.common.SearchAttributeKey.for_float(name) + return SearchAttributeKey.for_float(name) elif isinstance(vals[0], bool): - return temporalio.common.SearchAttributeKey.for_bool(name) + return SearchAttributeKey.for_bool(name) elif isinstance(vals[0], datetime): - return temporalio.common.SearchAttributeKey.for_datetime(name) + return SearchAttributeKey.for_datetime(name) return None @@ -496,6 +496,20 @@ def get( return self.__getitem__(key) except KeyError: return default + + def updated(self, *search_attributes: SearchAttributePair) -> TypedSearchAttributes: + """Copy this collection, replacing attributes with matching key names or + adding if key name not present. + """ + attrs = list(self.search_attributes) + # Go over each update, replacing matching keys by index or adding + for attr in search_attributes: + existing_index = next((i for i, attr in enumerate(attrs) if attr.key.name == attr.key.name), None) + if existing_index is None: + attrs.append(attr) + else: + attrs[existing_index] = attr + return TypedSearchAttributes(attrs) TypedSearchAttributes.empty = TypedSearchAttributes(search_attributes=[]) diff --git a/temporalio/converter.py b/temporalio/converter.py index 1ef5032f..27e787c6 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1262,7 +1262,7 @@ def decode_search_attributes( api: API message with search attribute values to convert. Returns: - Converted search attribute values. + Converted search attribute values (new mapping every time). """ conv = default().payload_converter ret = {} @@ -1288,38 +1288,42 @@ def decode_typed_search_attributes( api: API message with search attribute values to convert. Returns: - Typed search attribute collection. + Typed search attribute collection (new object every time). """ conv = default().payload_converter pairs: List[temporalio.common.SearchAttributePair] = [] for k, v in api.indexed_fields.items(): # We want the "type" metadata, but if it is not present or an unknown # type, we will just ignore + metadata_type = v.metadata.get("type") + if not metadata_type: + continue key = temporalio.common.SearchAttributeKey._from_metadata_type( k, v.metadata.get("type").decode() ) - if key: - val = conv.from_payload(v) - # If the value is a list but the type is not keyword list, pull out - # single item or consider this an invalid value and ignore - if ( - key.indexed_value_type - != temporalio.common.SearchAttributeIndexedValueType.KEYWORD_LIST - and isinstance(val, list) - ): - if len(val) != 1: - continue - val = val[0] - if ( - key.indexed_value_type - == temporalio.common.SearchAttributeIndexedValueType.DATETIME - ): - parser = _get_iso_datetime_parser() - # We will let this throw - val = parser(val) - # If the value isn't the right type, we need to ignore - if isinstance(val, key.origin_value_type): - pairs.append(temporalio.common.SearchAttributePair(key, val)) + if not key: + continue + val = conv.from_payload(v) + # If the value is a list but the type is not keyword list, pull out + # single item or consider this an invalid value and ignore + if ( + key.indexed_value_type + != temporalio.common.SearchAttributeIndexedValueType.KEYWORD_LIST + and isinstance(val, list) + ): + if len(val) != 1: + continue + val = val[0] + if ( + key.indexed_value_type + == temporalio.common.SearchAttributeIndexedValueType.DATETIME + ): + parser = _get_iso_datetime_parser() + # We will let this throw + val = parser(val) + # If the value isn't the right type, we need to ignore + if isinstance(val, key.origin_value_type): + pairs.append(temporalio.common.SearchAttributePair(key, val)) return temporalio.common.TypedSearchAttributes(pairs) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index d15dcf88..47e7c829 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -6,8 +6,11 @@ from temporalio.client import BuildIdOpAddNewDefault, Client from temporalio.service import RPCError, RPCStatusCode +from temporalio.common import SearchAttributeKey from temporalio.worker import Worker, WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +from temporalio.api.enums.v1 import IndexedValueType +from temporalio.api.operatorservice.v1 import ListSearchAttributesRequest, AddSearchAttributesRequest def new_worker( @@ -63,3 +66,25 @@ async def worker_versioning_enabled(client: Client) -> bool: if e.status in [RPCStatusCode.PERMISSION_DENIED, RPCStatusCode.UNIMPLEMENTED]: return False raise + +async def ensure_search_attributes_present(client: Client, *keys: SearchAttributeKey) -> None: + """Ensure all search attributes are present or attempt to add all.""" + async def search_attributes_present() -> bool: + resp = await client.operator_service.list_search_attributes( + ListSearchAttributesRequest(namespace=client.namespace) + ) + return sorted(resp.custom_attributes.keys()) == sorted([key.name for key in keys]) + + # Add search attributes if not already present + if not await search_attributes_present(): + await client.operator_service.add_search_attributes( + AddSearchAttributesRequest( + namespace=client.namespace, + search_attributes={ + key.name: IndexedValueType.ValueType(key.indexed_value_type) + for key in keys + }, + ), + ) + # Confirm now present + assert await search_attributes_present() diff --git a/tests/test_client.py b/tests/test_client.py index 913f3aef..1f746baa 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,8 +1,9 @@ import json import os import uuid +import dataclasses from datetime import datetime, timedelta, timezone -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Optional import pytest from google.protobuf import json_format @@ -65,11 +66,11 @@ WorkflowQueryRejectedError, _history_from_json, ) -from temporalio.common import RetryPolicy +from temporalio.common import RetryPolicy, TypedSearchAttributes, SearchAttributePair, SearchAttributeKey from temporalio.converter import DataConverter from temporalio.exceptions import WorkflowAlreadyStartedError from temporalio.testing import WorkflowEnvironment -from tests.helpers import assert_eq_eventually, new_worker, worker_versioning_enabled +from tests.helpers import assert_eq_eventually, new_worker, worker_versioning_enabled, ensure_search_attributes_present from tests.helpers.worker import ( ExternalWorker, KSAction, @@ -983,6 +984,121 @@ async def test_schedule_create_limited_actions_validation( await client.create_schedule(f"schedule-{uuid.uuid4()}", sched) assert "are remaining actions set" in str(err.value) +async def test_schedule_search_attribute_update(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Java test server doesn't support schedules") + await assert_no_schedules(client) + + # Put search attribute on server + text_attr_key = SearchAttributeKey.for_text(f"python-test-schedule-text") + await ensure_search_attributes_present(client, text_attr_key) + + # Create a schedule with search attributes on the schedule and on the + # workflow + handle = await client.create_schedule( + f"schedule-{uuid.uuid4()}", + Schedule( + action=ScheduleActionStartWorkflow( + "some workflow", + [], + id=f"workflow-{uuid.uuid4()}", + task_queue=f"tq-{uuid.uuid4()}", + typed_search_attributes=TypedSearchAttributes([ + SearchAttributePair(text_attr_key, "some-workflow-attr1") + ]) + ), + spec=ScheduleSpec(), + ), + search_attributes=TypedSearchAttributes([ + SearchAttributePair(text_attr_key, "some-schedule-attr1") + ]) + ) + + # Do update of typed attrs + def update_schedule_typed_attrs(input: ScheduleUpdateInput) -> Optional[ScheduleUpdate]: + assert isinstance( + input.description.schedule.action, ScheduleActionStartWorkflow + ) + + # Make sure the search attributes are present in all forms + assert input.description.search_attributes[text_attr_key.name] == ["some-schedule-attr1"] + assert input.description.typed_search_attributes[text_attr_key] == "some-schedule-attr1" + assert input.description.schedule.action.search_attributes[text_attr_key.name] == ["some-workflow-attr1"] + assert input.description.schedule.action.typed_search_attributes[text_attr_key] == "some-workflow-attr1" + + # Update the workflow search attribute with a new value + return ScheduleUpdate(dataclasses.replace( + input.description.schedule, + action=dataclasses.replace( + input.description.schedule.action, + typed_search_attributes=input.description.schedule.action.typed_search_attributes.updated( + SearchAttributePair(text_attr_key, "some-workflow-attr2") + ) + ) + )) + await handle.update(update_schedule_typed_attrs) + + # Check that it changed + desc = await handle.describe() + assert isinstance(desc.schedule.action, ScheduleActionStartWorkflow) + assert desc.schedule.action.search_attributes[text_attr_key.name] == ["some-workflow-attr2"] + assert desc.schedule.action.typed_search_attributes[text_attr_key] == "some-workflow-attr2" + + # Do update of untyped attrs + def update_schedule_untyped_attrs(input: ScheduleUpdateInput) -> Optional[ScheduleUpdate]: + assert isinstance( + input.description.schedule.action, ScheduleActionStartWorkflow + ) + return ScheduleUpdate(dataclasses.replace( + input.description.schedule, + action=dataclasses.replace( + input.description.schedule.action, + search_attributes={text_attr_key.name: ["some-workflow-attr3"]}, + ) + )) + await handle.update(update_schedule_untyped_attrs) + + # Check that it changed + desc = await handle.describe() + assert isinstance(desc.schedule.action, ScheduleActionStartWorkflow) + assert desc.schedule.action.search_attributes[text_attr_key.name] == ["some-workflow-attr3"] + # XXX: Note how the attribute is not in the typed search attribute + # collection because it has no type + assert not desc.schedule.action.typed_search_attributes.get(text_attr_key) + + # Do update of typed attrs again + def update_schedule_typed_attrs_again(input: ScheduleUpdateInput) -> Optional[ScheduleUpdate]: + assert isinstance( + input.description.schedule.action, ScheduleActionStartWorkflow + ) + return ScheduleUpdate(dataclasses.replace( + input.description.schedule, + action=dataclasses.replace( + input.description.schedule.action, + typed_search_attributes=input.description.schedule.action.typed_search_attributes.updated( + SearchAttributePair(text_attr_key, "some-workflow-attr4") + ) + ) + )) + await handle.update(update_schedule_typed_attrs_again) + + # Check that it changed + desc = await handle.describe() + assert isinstance(desc.schedule.action, ScheduleActionStartWorkflow) + assert desc.schedule.action.search_attributes[text_attr_key.name] == ["some-workflow-attr4"] + assert desc.schedule.action.typed_search_attributes[text_attr_key] == "some-workflow-attr4" + + # Normal update with no attr change + def update_schedule_no_attr_change(input: ScheduleUpdateInput) -> Optional[ScheduleUpdate]: + return ScheduleUpdate(input.description.schedule) + await handle.update(update_schedule_no_attr_change) + + # Check that it did not change + desc = await handle.describe() + assert isinstance(desc.schedule.action, ScheduleActionStartWorkflow) + assert desc.schedule.action.search_attributes[text_attr_key.name] == ["some-workflow-attr4"] + assert desc.schedule.action.typed_search_attributes[text_attr_key] == "some-workflow-attr4" + async def assert_no_schedules(client: Client) -> None: # Listing appears eventually consistent diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index add5e4f0..18a0ce04 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -83,7 +83,7 @@ WorkflowInstanceDetails, WorkflowRunner, ) -from tests.helpers import assert_eq_eventually, new_worker +from tests.helpers import assert_eq_eventually, new_worker, ensure_search_attributes_present @workflow.defn @@ -1629,41 +1629,19 @@ def do_search_attribute_update_typed(self) -> None: ) -async def ensure_search_attributes_on_server(client: Client) -> None: - async def search_attributes_present() -> bool: - resp = await client.operator_service.list_search_attributes( - ListSearchAttributesRequest(namespace=client.namespace) - ) - return any(k for k in resp.custom_attributes.keys() if k.startswith(sa_prefix)) - - # Add search attributes if not already present - if not await search_attributes_present(): - attrs: List[SearchAttributeKey] = [ - SearchAttributeWorkflow.text_attribute, - SearchAttributeWorkflow.keyword_attribute, - SearchAttributeWorkflow.keyword_list_attribute, - SearchAttributeWorkflow.int_attribute, - SearchAttributeWorkflow.float_attribute, - SearchAttributeWorkflow.bool_attribute, - SearchAttributeWorkflow.datetime_attribute, - ] - await client.operator_service.add_search_attributes( - AddSearchAttributesRequest( - namespace=client.namespace, - search_attributes={ - attr.name: IndexedValueType.ValueType(attr.indexed_value_type) - for attr in attrs - }, - ), - ) - # Confirm now present - assert await search_attributes_present() - - async def test_workflow_search_attributes(client: Client, env_type: str): if env_type != "local": pytest.skip("Only testing search attributes on local which disables cache") - await ensure_search_attributes_on_server(client) + await ensure_search_attributes_present( + client, + SearchAttributeWorkflow.text_attribute, + SearchAttributeWorkflow.keyword_attribute, + SearchAttributeWorkflow.keyword_list_attribute, + SearchAttributeWorkflow.int_attribute, + SearchAttributeWorkflow.float_attribute, + SearchAttributeWorkflow.bool_attribute, + SearchAttributeWorkflow.datetime_attribute, + ) initial_attrs_untyped: SearchAttributes = { SearchAttributeWorkflow.text_attribute.name: ["text1"],