Skip to content

Commit

Permalink
Add Client.count_workflows
Browse files Browse the repository at this point in the history
Fixes #294
  • Loading branch information
cretz committed Apr 15, 2024
1 parent b45447e commit 3d87944
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 1 deletion.
105 changes: 105 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,30 @@ def list_workflows(
)
)

async def count_workflows(
self,
query: Optional[str] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> WorkflowExecutionCount:
"""Count workflows.
Args:
query: A Temporal visibility filter. See Temporal documentation
concerning visibility list filters.
rpc_metadata: Headers used on each RPC call. Keys here override
client-level RPC metadata keys.
rpc_timeout: Optional RPC deadline to set for each RPC call.
Returns:
Count of workflows.
"""
return await self._impl.count_workflows(
CountWorkflowsInput(
query=query, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout
)
)

@overload
def get_async_activity_handle(
self, *, workflow_id: str, run_id: Optional[str], activity_id: str
Expand Down Expand Up @@ -2310,6 +2334,57 @@ class WorkflowExecutionStatus(IntEnum):
)


@dataclass
class WorkflowExecutionCount:
"""Representation of a count from a count workflows call."""

count: int
"""Approximate number of workflows matching the original query.
If the query had a group-by clause, this is simply the sum of all the counts
in py:attr:`groups`.
"""

groups: Sequence[WorkflowExecutionCountAggregationGroup]
"""Groups if the query had a group-by clause, or empty if not."""

@staticmethod
def _from_raw(
raw: temporalio.api.workflowservice.v1.CountWorkflowExecutionsResponse,
) -> WorkflowExecutionCount:
return WorkflowExecutionCount(
count=raw.count,
groups=[
WorkflowExecutionCountAggregationGroup._from_raw(g) for g in raw.groups
],
)


@dataclass
class WorkflowExecutionCountAggregationGroup:
"""Aggregation group if the workflow count query had a group-by clause."""

count: int
"""Approximate number of workflows matching the original query for this
group.
"""

group_values: Sequence[temporalio.common.SearchAttributeValue]
"""Search attribute values for this group."""

@staticmethod
def _from_raw(
raw: temporalio.api.workflowservice.v1.CountWorkflowExecutionsResponse.AggregationGroup,
) -> WorkflowExecutionCountAggregationGroup:
return WorkflowExecutionCountAggregationGroup(
count=raw.count,
group_values=[
temporalio.converter._decode_search_attribute_value(v)
for v in raw.group_values
],
)


class WorkflowExecutionAsyncIterator:
"""Asynchronous iterator for :py:class:`WorkflowExecution` values.
Expand Down Expand Up @@ -4373,6 +4448,15 @@ class ListWorkflowsInput:
rpc_timeout: Optional[timedelta]


@dataclass
class CountWorkflowsInput:
"""Input for :py:meth:`OutboundInterceptor.count_workflows`."""

query: Optional[str]
rpc_metadata: Mapping[str, str]
rpc_timeout: Optional[timedelta]


@dataclass
class QueryWorkflowInput:
"""Input for :py:meth:`OutboundInterceptor.query_workflow`."""
Expand Down Expand Up @@ -4669,6 +4753,12 @@ def list_workflows(
"""Called for every :py:meth:`Client.list_workflows` call."""
return self.next.list_workflows(input)

async def count_workflows(
self, input: CountWorkflowsInput
) -> WorkflowExecutionCount:
"""Called for every :py:meth:`Client.count_workflows` call."""
return await self.next.count_workflows(input)

async def query_workflow(self, input: QueryWorkflowInput) -> Any:
"""Called for every :py:meth:`WorkflowHandle.query` call."""
return await self.next.query_workflow(input)
Expand Down Expand Up @@ -4928,6 +5018,21 @@ def list_workflows(
) -> WorkflowExecutionAsyncIterator:
return WorkflowExecutionAsyncIterator(self._client, input)

async def count_workflows(
self, input: CountWorkflowsInput
) -> WorkflowExecutionCount:
return WorkflowExecutionCount._from_raw(
await self._client.workflow_service.count_workflow_executions(
temporalio.api.workflowservice.v1.CountWorkflowExecutionsRequest(
namespace=self._client.namespace,
query=input.query or "",
),
retry=True,
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
)

async def query_workflow(self, input: QueryWorkflowInput) -> Any:
req = temporalio.api.workflowservice.v1.QueryWorkflowRequest(
namespace=self._client.namespace,
Expand Down
9 changes: 9 additions & 0 deletions temporalio/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,15 @@ def decode_typed_search_attributes(
return temporalio.common.TypedSearchAttributes(pairs)


def _decode_search_attribute_value(
payload: temporalio.api.common.v1.Payload,
) -> temporalio.common.SearchAttributeValue:
val = default().payload_converter.from_payload(payload)
if isinstance(val, str) and payload.metadata.get("type") == b"Datetime":
val = _get_iso_datetime_parser()(val)
return val # type: ignore


def value_to_type(
hint: Type,
value: Any,
Expand Down
57 changes: 56 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, cast

import pytest
from google.protobuf import json_format
Expand Down Expand Up @@ -60,6 +60,8 @@
TaskReachabilityType,
TerminateWorkflowInput,
WorkflowContinuedAsNewError,
WorkflowExecutionCount,
WorkflowExecutionCountAggregationGroup,
WorkflowExecutionStatus,
WorkflowFailureError,
WorkflowHandle,
Expand Down Expand Up @@ -569,6 +571,59 @@ async def test_list_workflows_and_fetch_history(
assert actual_id_and_input == expected_id_and_input


@workflow.defn
class CountableWorkflow:
@workflow.run
async def run(self, wait_forever: bool) -> None:
await workflow.wait_condition(lambda: not wait_forever)


async def test_count_workflows(client: Client, env: WorkflowEnvironment):
if env.supports_time_skipping:
pytest.skip("Java test server doesn't support newer workflow listing")

# 3 workflows that complete, 2 that don't
async with new_worker(client, CountableWorkflow) as worker:
for _ in range(3):
await client.execute_workflow(
CountableWorkflow.run,
False,
id=f"id-{uuid.uuid4()}",
task_queue=worker.task_queue,
)
for _ in range(2):
await client.start_workflow(
CountableWorkflow.run,
True,
id=f"id-{uuid.uuid4()}",
task_queue=worker.task_queue,
)

async def fetch_count() -> WorkflowExecutionCount:
resp = await client.count_workflows(
f"TaskQueue = '{worker.task_queue}' GROUP BY ExecutionStatus"
)
cast(List[WorkflowExecutionCountAggregationGroup], resp.groups).sort(
key=lambda g: g.count
)
return resp

await assert_eq_eventually(
WorkflowExecutionCount(
count=5,
groups=[
WorkflowExecutionCountAggregationGroup(
count=2, group_values=["Running"]
),
WorkflowExecutionCountAggregationGroup(
count=3, group_values=["Completed"]
),
],
),
fetch_count,
)


def test_history_from_json():
# Take proto, make JSON, convert to dict, alter some enums, confirm that it
# alters the enums back and matches original history
Expand Down

0 comments on commit 3d87944

Please sign in to comment.