Skip to content

Commit

Permalink
(TODO: fix client code, only halfway done) Work on schedule action se…
Browse files Browse the repository at this point in the history
…arch attributes
  • Loading branch information
cretz committed Aug 24, 2023
1 parent 100753d commit 14d9f71
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 98 deletions.
93 changes: 61 additions & 32 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2757,11 +2757,27 @@ class ScheduleActionStartWorkflow(ScheduleAction):
"""Search attributes for the schedule.
.. deprecated::
Use :py:attr:`typed_search_attributes` instead.
Use :py:attr:`typed_search_attributes` instead when creating a new
schedule.
.. warning::
During update, if :py:attr:`typed_search_attributes` is present, it will
be used and this will be ignored. This means once users begin using
:py:attr:`typed_search_attributes` they should not go back to untyped.
"""

typed_search_attributes: temporalio.common.TypedSearchAttributes
"""Search attributes for the schedule."""
typed_search_attributes: Optional[temporalio.common.TypedSearchAttributes]
"""Search attributes for the schedule.
This will only be set if there is at least one search attribute and this was
previously used on create/update before.
.. warning::
If :py:attr:`search_attributes` was used before, this value may not have
all search attributes. Users should either update and replace with this
:py:attr:`typed_search_attributes` or continue to use
:py:attr:`search_attributes`.
"""

headers: Optional[Mapping[str, temporalio.api.common.v1.Payload]] = None

Expand All @@ -2782,9 +2798,8 @@ def __init__(
task_timeout: Optional[timedelta] = None,
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Union[
temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes
] = temporalio.common.TypedSearchAttributes.empty,
search_attributes: temporalio.common.SearchAttributes = {},
typed_search_attributes: Optional[temporalio.common.TypedSearchAttributes] = None,
) -> None:
...

Expand All @@ -2802,9 +2817,8 @@ def __init__(
task_timeout: Optional[timedelta] = None,
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Union[
temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes
] = temporalio.common.TypedSearchAttributes.empty,
search_attributes: temporalio.common.SearchAttributes = {},
typed_search_attributes: Optional[temporalio.common.TypedSearchAttributes] = None,
) -> None:
...

Expand All @@ -2824,9 +2838,8 @@ def __init__(
task_timeout: Optional[timedelta] = None,
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Union[
temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes
] = temporalio.common.TypedSearchAttributes.empty,
search_attributes: temporalio.common.SearchAttributes = {},
typed_search_attributes: Optional[temporalio.common.TypedSearchAttributes] = None,
) -> None:
...

Expand All @@ -2845,9 +2858,8 @@ def __init__(
task_timeout: Optional[timedelta] = None,
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Union[
temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes
] = temporalio.common.TypedSearchAttributes.empty,
search_attributes: temporalio.common.SearchAttributes = {},
typed_search_attributes: temporalio.common.TypedSearchAttributes = temporalio.common.TypedSearchAttributes.empty,
) -> None:
...

Expand All @@ -2874,9 +2886,9 @@ def __init__(
task_timeout: Optional[timedelta] = None,
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Union[
temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes
] = temporalio.common.TypedSearchAttributes.empty,
search_attributes: temporalio.common.SearchAttributes = {},
typed_search_attributes: temporalio.common.TypedSearchAttributes = temporalio.common.TypedSearchAttributes.empty,
headers: Optional[Mapping[str, temporalio.api.common.v1.Payload]] = None,
raw_info: Optional[temporalio.api.workflow.v1.NewWorkflowExecutionInfo] = None,
) -> None:
"""Create a start-workflow action.
Expand Down Expand Up @@ -2920,6 +2932,7 @@ def __init__(
raw_info.search_attributes
)
)
self.headers = raw_info.header.fields
else:
if not id:
raise ValueError("ID required")
Expand All @@ -2942,15 +2955,9 @@ def __init__(
self.task_timeout = task_timeout
self.retry_policy = retry_policy
self.memo = memo
temporalio.common._warn_on_deprecated_search_attributes(search_attributes)
if isinstance(search_attributes, temporalio.common.TypedSearchAttributes):
self.search_attributes = {}
self.typed_search_attributes = search_attributes
else:
self.search_attributes = search_attributes
self.typed_search_attributes = (
temporalio.common.TypedSearchAttributes.empty
)
self.search_attributes = search_attributes
self.typed_search_attributes = typed_search_attributes
self.headers = headers

async def _to_proto(
self, client: Client
Expand Down Expand Up @@ -3005,11 +3012,33 @@ async def _to_proto(
else temporalio.api.common.v1.Header(fields=self.headers),
),
)
temporalio.converter._encode_maybe_typed_search_attributes(
self.search_attributes,
self.typed_search_attributes,
action.start_workflow.search_attributes,
)

# If typed is present, we have to use it


# Must unset unmodified attr on untyped SA if present
if "__temporal_search_attributes_unmodified" in self.search_attributes:
del self.search_attributes["__temporal_search_attributes_unmodified"]

# If this is for update, if typed search attributes are modified, use
# them

# If typed is unmodified, it should not be included in the update. But
# if it is present with values, it needs to overwrite search attributes
# blindly. We have no way to check whether untyped was unmodified
# because we can't store an arbitrary attr on a dict that's not a key.
# So we have to accept that both present means typed wins since typed
# may have been updated and there is no check for accidentally
# explicitly setting both at this time. This tradeoff is preferred over
# the complexity of diffing. This also means that typed search
# attributes cannot be used to unset all attributes.
if getattr(self.typed_search_attributes, "__temporal_search_attributes_unmodified", False) or not self.typed_search_attributes or not self.typed_search_attributes.search_attributes:
print("!!! PERSISTING: ", self.search_attributes)
temporalio.converter.encode_search_attributes(self.search_attributes, action.start_workflow.search_attributes)
else:
temporalio.converter.encode_search_attributes(self.typed_search_attributes, action.start_workflow.search_attributes)
if self.headers:
temporalio.common._apply_headers(self.headers, action.start_workflow.header.fields)
return action


Expand Down
26 changes: 20 additions & 6 deletions temporalio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,17 +345,17 @@ def _guess_from_untyped_values(
return None
elif len(vals) > 1:
if isinstance(vals[0], str):
return temporalio.common.SearchAttributeKey.for_keyword_list(name)
return SearchAttributeKey.for_keyword_list(name)
elif isinstance(vals[0], str):
return temporalio.common.SearchAttributeKey.for_keyword(name)
return SearchAttributeKey.for_keyword(name)
elif isinstance(vals[0], int):
return temporalio.common.SearchAttributeKey.for_int(name)
return SearchAttributeKey.for_int(name)
elif isinstance(vals[0], float):
return temporalio.common.SearchAttributeKey.for_float(name)
return SearchAttributeKey.for_float(name)
elif isinstance(vals[0], bool):
return temporalio.common.SearchAttributeKey.for_bool(name)
return SearchAttributeKey.for_bool(name)
elif isinstance(vals[0], datetime):
return temporalio.common.SearchAttributeKey.for_datetime(name)
return SearchAttributeKey.for_datetime(name)
return None


Expand Down Expand Up @@ -496,6 +496,20 @@ def get(
return self.__getitem__(key)
except KeyError:
return default

def updated(self, *search_attributes: SearchAttributePair) -> TypedSearchAttributes:
"""Copy this collection, replacing attributes with matching key names or
adding if key name not present.
"""
attrs = list(self.search_attributes)
# Go over each update, replacing matching keys by index or adding
for attr in search_attributes:
existing_index = next((i for i, attr in enumerate(attrs) if attr.key.name == attr.key.name), None)
if existing_index is None:
attrs.append(attr)
else:
attrs[existing_index] = attr
return TypedSearchAttributes(attrs)


TypedSearchAttributes.empty = TypedSearchAttributes(search_attributes=[])
Expand Down
52 changes: 28 additions & 24 deletions temporalio/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ def decode_search_attributes(
api: API message with search attribute values to convert.
Returns:
Converted search attribute values.
Converted search attribute values (new mapping every time).
"""
conv = default().payload_converter
ret = {}
Expand All @@ -1288,38 +1288,42 @@ def decode_typed_search_attributes(
api: API message with search attribute values to convert.
Returns:
Typed search attribute collection.
Typed search attribute collection (new object every time).
"""
conv = default().payload_converter
pairs: List[temporalio.common.SearchAttributePair] = []
for k, v in api.indexed_fields.items():
# We want the "type" metadata, but if it is not present or an unknown
# type, we will just ignore
metadata_type = v.metadata.get("type")
if not metadata_type:
continue
key = temporalio.common.SearchAttributeKey._from_metadata_type(
k, v.metadata.get("type").decode()
)
if key:
val = conv.from_payload(v)
# If the value is a list but the type is not keyword list, pull out
# single item or consider this an invalid value and ignore
if (
key.indexed_value_type
!= temporalio.common.SearchAttributeIndexedValueType.KEYWORD_LIST
and isinstance(val, list)
):
if len(val) != 1:
continue
val = val[0]
if (
key.indexed_value_type
== temporalio.common.SearchAttributeIndexedValueType.DATETIME
):
parser = _get_iso_datetime_parser()
# We will let this throw
val = parser(val)
# If the value isn't the right type, we need to ignore
if isinstance(val, key.origin_value_type):
pairs.append(temporalio.common.SearchAttributePair(key, val))
if not key:
continue
val = conv.from_payload(v)
# If the value is a list but the type is not keyword list, pull out
# single item or consider this an invalid value and ignore
if (
key.indexed_value_type
!= temporalio.common.SearchAttributeIndexedValueType.KEYWORD_LIST
and isinstance(val, list)
):
if len(val) != 1:
continue
val = val[0]
if (
key.indexed_value_type
== temporalio.common.SearchAttributeIndexedValueType.DATETIME
):
parser = _get_iso_datetime_parser()
# We will let this throw
val = parser(val)
# If the value isn't the right type, we need to ignore
if isinstance(val, key.origin_value_type):
pairs.append(temporalio.common.SearchAttributePair(key, val))
return temporalio.common.TypedSearchAttributes(pairs)


Expand Down
25 changes: 25 additions & 0 deletions tests/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

from temporalio.client import BuildIdOpAddNewDefault, Client
from temporalio.service import RPCError, RPCStatusCode
from temporalio.common import SearchAttributeKey
from temporalio.worker import Worker, WorkflowRunner
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
from temporalio.api.enums.v1 import IndexedValueType
from temporalio.api.operatorservice.v1 import ListSearchAttributesRequest, AddSearchAttributesRequest


def new_worker(
Expand Down Expand Up @@ -63,3 +66,25 @@ async def worker_versioning_enabled(client: Client) -> bool:
if e.status in [RPCStatusCode.PERMISSION_DENIED, RPCStatusCode.UNIMPLEMENTED]:
return False
raise

async def ensure_search_attributes_present(client: Client, *keys: SearchAttributeKey) -> None:
"""Ensure all search attributes are present or attempt to add all."""
async def search_attributes_present() -> bool:
resp = await client.operator_service.list_search_attributes(
ListSearchAttributesRequest(namespace=client.namespace)
)
return sorted(resp.custom_attributes.keys()) == sorted([key.name for key in keys])

# Add search attributes if not already present
if not await search_attributes_present():
await client.operator_service.add_search_attributes(
AddSearchAttributesRequest(
namespace=client.namespace,
search_attributes={
key.name: IndexedValueType.ValueType(key.indexed_value_type)
for key in keys
},
),
)
# Confirm now present
assert await search_attributes_present()
Loading

0 comments on commit 14d9f71

Please sign in to comment.