Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API key client option #486

Merged
merged 2 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
453 changes: 282 additions & 171 deletions temporalio/bridge/Cargo.lock

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions temporalio/bridge/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ crate-type = ["cdylib"]
[dependencies]
futures = "0.3"
log = "0.4"
once_cell = "1.16.0"
parking_lot = "0.12"
prost = "0.11"
prost-types = "0.11"
once_cell = "1.16"
prost = "0.12"
prost-types = "0.12"
pyo3 = { version = "0.19", features = ["extension-module", "abi3-py38"] }
pyo3-asyncio = { version = "0.19", features = ["tokio-runtime"] }
pythonize = "0.19"
Expand All @@ -23,7 +22,7 @@ temporal-sdk-core-api = { version = "0.1.0", path = "./sdk-core/core-api" }
temporal-sdk-core-protos = { version = "0.1.0", path = "./sdk-core/sdk-core-protos" }
tokio = "1.26"
tokio-stream = "0.1"
tonic = "0.9"
tonic = "0.11"
tracing = "0.1"
url = "2.2"

Expand Down
5 changes: 5 additions & 0 deletions temporalio/bridge/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ClientConfig:

target_url: str
metadata: Mapping[str, str]
api_key: Optional[str]
identity: str
tls_config: Optional[ClientTlsConfig]
retry_config: Optional[ClientRetryConfig]
Expand Down Expand Up @@ -102,6 +103,10 @@ def update_metadata(self, metadata: Mapping[str, str]) -> None:
"""Update underlying metadata on Core client."""
self._ref.update_metadata(metadata)

def update_api_key(self, api_key: Optional[str]) -> None:
"""Update underlying API key on Core client."""
self._ref.update_api_key(api_key)

async def call(
self,
*,
Expand Down
2 changes: 1 addition & 1 deletion temporalio/bridge/sdk-core
Submodule sdk-core updated 36 files
+7 −3 Cargo.toml
+8 −5 client/Cargo.toml
+100 −31 client/src/lib.rs
+1 −1 client/src/raw.rs
+3 −3 core-api/Cargo.toml
+20 −20 core/Cargo.toml
+3 −3 core/src/core_tests/mod.rs
+1 −5 core/src/ephemeral_server/mod.rs
+1 −1 core/src/protosext/mod.rs
+21 −251 core/src/telemetry/metrics.rs
+9 −37 core/src/telemetry/mod.rs
+276 −0 core/src/telemetry/otel.rs
+34 −21 core/src/telemetry/prometheus_server.rs
+1 −1 core/src/worker/activities/local_activities.rs
+3 −4 core/src/worker/client/mocks.rs
+6 −6 core/src/worker/workflow/history_update.rs
+2 −2 core/src/worker/workflow/machines/activity_state_machine.rs
+2 −2 core/src/worker/workflow/machines/cancel_workflow_state_machine.rs
+14 −18 core/src/worker/workflow/machines/child_workflow_state_machine.rs
+6 −7 core/src/worker/workflow/machines/transition_coverage.rs
+5 −5 core/src/worker/workflow/machines/workflow_machines.rs
+2 −2 core/src/worker/workflow/machines/workflow_task_state_machine.rs
+4 −4 sdk-core-protos/Cargo.toml
+4 −4 sdk-core-protos/src/lib.rs
+3 −4 sdk/Cargo.toml
+1 −1 sdk/src/lib.rs
+3 −3 test-utils/Cargo.toml
+1 −1 test-utils/src/histfetch.rs
+1 −1 test-utils/src/lib.rs
+2 −2 tests/integ_tests/client_tests.rs
+1 −1 tests/integ_tests/ephemeral_server_tests.rs
+3 −3 tests/integ_tests/metrics_tests.rs
+2 −4 tests/integ_tests/update_tests.rs
+2 −2 tests/integ_tests/visibility_tests.rs
+1 −1 tests/integ_tests/workflow_tests/eager.rs
+2 −2 tests/main.rs
21 changes: 9 additions & 12 deletions temporalio/bridge/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use parking_lot::RwLock;
use pyo3::exceptions::{PyException, PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use temporal_client::{
ClientKeepAliveConfig as CoreClientKeepAliveConfig, ClientOptions, ClientOptionsBuilder,
Expand Down Expand Up @@ -31,6 +29,7 @@ pub struct ClientConfig {
client_name: String,
client_version: String,
metadata: HashMap<String, String>,
api_key: Option<String>,
identity: String,
tls_config: Option<ClientTlsConfig>,
retry_config: Option<ClientRetryConfig>,
Expand Down Expand Up @@ -75,20 +74,12 @@ pub fn connect_client<'a>(
runtime_ref: &runtime::RuntimeRef,
config: ClientConfig,
) -> PyResult<&'a PyAny> {
let headers = if config.metadata.is_empty() {
None
} else {
Some(Arc::new(RwLock::new(config.metadata.clone())))
};
let opts: ClientOptions = config.try_into()?;
let runtime = runtime_ref.runtime.clone();
runtime_ref.runtime.future_into_py(py, async move {
Ok(ClientRef {
retry_client: opts
.connect_no_namespace(
runtime.core.telemetry().get_temporal_metric_meter(),
headers,
)
.connect_no_namespace(runtime.core.telemetry().get_temporal_metric_meter())
.await
.map_err(|err| {
PyRuntimeError::new_err(format!("Failed client connect: {}", err))
Expand All @@ -114,6 +105,10 @@ impl ClientRef {
self.retry_client.get_client().set_headers(headers);
}

fn update_api_key(&self, api_key: Option<String>) {
self.retry_client.get_client().set_api_key(api_key);
}

fn call_workflow_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<&'p PyAny> {
let mut retry_client = self.retry_client.clone();
self.runtime.future_into_py(py, async move {
Expand Down Expand Up @@ -396,7 +391,9 @@ impl TryFrom<ClientConfig> for ClientOptions {
opts.retry_config
.map_or(RetryConfig::default(), |c| c.into()),
)
.keep_alive(opts.keep_alive_config.map(Into::into));
.keep_alive(opts.keep_alive_config.map(Into::into))
.headers(Some(opts.metadata))
.api_key(opts.api_key);
// Builder does not allow us to set option here, so we have to make
// a conditional to even call it
if let Some(tls_config) = opts.tls_config {
Expand Down
21 changes: 21 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ async def connect(
target_host: str,
*,
namespace: str = "default",
api_key: Optional[str] = None,
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
interceptors: Sequence[Interceptor] = [],
default_workflow_query_reject_condition: Optional[
Expand All @@ -116,6 +117,9 @@ async def connect(
target_host: ``host:port`` for the Temporal server. For local
development, this is often "localhost:7233".
namespace: Namespace to use for client calls.
api_key: API key for Temporal. This becomes the "Authorization"
HTTP header with "Bearer " prepended. This is only set if RPC
metadata doesn't already have an "authorization" key.
data_converter: Data converter to use for all data conversions
to/from payloads.
interceptors: Set of interceptors that are chained together to allow
Expand Down Expand Up @@ -152,6 +156,7 @@ async def connect(
"""
connect_config = temporalio.service.ConnectConfig(
target_host=target_host,
api_key=api_key,
tls=tls,
retry_config=retry_config,
keep_alive_config=keep_alive_config,
Expand Down Expand Up @@ -261,6 +266,22 @@ def rpc_metadata(self, value: Mapping[str, str]) -> None:
self.service_client.config.rpc_metadata = value
self.service_client.update_rpc_metadata(value)

@property
def api_key(self) -> Optional[str]:
"""API key for every call made by this client."""
return self.service_client.config.api_key

@api_key.setter
def api_key(self, value: Optional[str]) -> None:
"""Update the API key for this client.

This is only set if RPCmetadata doesn't already have an "authorization"
key.
"""
# Update config and perform update
self.service_client.config.api_key = value
self.service_client.update_api_key(value)

# Overload for no-param workflow
@overload
async def start_workflow(
Expand Down
15 changes: 15 additions & 0 deletions temporalio/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class ConnectConfig:
"""Config for connecting to the server."""

target_host: str
api_key: Optional[str] = None
tls: Union[bool, TLSConfig] = False
retry_config: Optional[RetryConfig] = None
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default
Expand Down Expand Up @@ -161,6 +162,7 @@ def _to_bridge_config(self) -> temporalio.bridge.client.ClientConfig:

return temporalio.bridge.client.ClientConfig(
target_url=target_url,
api_key=self.api_key,
tls_config=tls_config,
retry_config=self.retry_config._to_bridge_config()
if self.retry_config
Expand Down Expand Up @@ -238,6 +240,11 @@ def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
"""Update service client's RPC metadata."""
raise NotImplementedError

@abstractmethod
def update_api_key(self, api_key: Optional[str]) -> None:
"""Update service client's API key."""
raise NotImplementedError

@abstractmethod
async def _rpc_call(
self,
Expand Down Expand Up @@ -740,6 +747,14 @@ def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
if self._bridge_client:
self._bridge_client.update_metadata(metadata)

def update_api_key(self, api_key: Optional[str]) -> None:
"""Update Core client API key."""
# Mutate the bridge config and then only mutate the running client
# metadata if already connected
self._bridge_config.api_key = api_key
if self._bridge_client:
self._bridge_client.update_api_key(api_key)

async def _rpc_call(
self,
rpc: str,
Expand Down
105 changes: 77 additions & 28 deletions tests/api/test_grpc_stub.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import timedelta
from typing import Mapping

from google.protobuf.empty_pb2 import Empty
from google.protobuf.timestamp_pb2 import Timestamp
Expand Down Expand Up @@ -27,12 +28,6 @@
from temporalio.client import Client


def assert_metadata(context: ServicerContext, **kwargs) -> None:
metadata = dict(context.invocation_metadata())
for k, v in kwargs.items():
assert metadata.get(k) == v


def assert_time_remaining(context: ServicerContext, expected: int) -> None:
# Give or take 5 seconds
assert expected - 5 <= context.time_remaining() <= expected + 5
Expand All @@ -41,24 +36,26 @@ def assert_time_remaining(context: ServicerContext, expected: int) -> None:
class SimpleWorkflowServer(WorkflowServiceServicer):
def __init__(self) -> None:
super().__init__()
self.expected_client_key_value = "client_value"
self.last_metadata: Mapping[str, str] = {}

def assert_last_metadata(self, expected: Mapping[str, str]) -> None:
for k, v in expected.items():
assert self.last_metadata.get(k) == v

async def GetSystemInfo( # type: ignore # https://github.com/nipunn1313/mypy-protobuf/issues/216
self,
request: GetSystemInfoRequest,
context: ServicerContext,
) -> GetSystemInfoResponse:
assert_metadata(context, client_key=self.expected_client_key_value)
self.last_metadata = dict(context.invocation_metadata())
return GetSystemInfoResponse()

async def CountWorkflowExecutions( # type: ignore # https://github.com/nipunn1313/mypy-protobuf/issues/216
self,
request: CountWorkflowExecutionsRequest,
context: ServicerContext,
) -> CountWorkflowExecutionsResponse:
assert_metadata(
context, client_key=self.expected_client_key_value, rpc_key="rpc_value"
)
self.last_metadata = dict(context.invocation_metadata())
assert_time_remaining(context, 123)
assert request.namespace == "my namespace"
assert request.query == "my query"
Expand All @@ -71,7 +68,6 @@ async def DeleteNamespace( # type: ignore # https://github.com/nipunn1313/mypy-
request: DeleteNamespaceRequest,
context: ServicerContext,
) -> DeleteNamespaceResponse:
assert_metadata(context, client_key="client_value", rpc_key="rpc_value")
assert_time_remaining(context, 123)
assert request.namespace == "my namespace"
return DeleteNamespaceResponse(deleted_namespace="my namespace response")
Expand All @@ -83,7 +79,6 @@ async def GetCurrentTime( # type: ignore # https://github.com/nipunn1313/mypy-p
request: Empty,
context: ServicerContext,
) -> GetCurrentTimeResponse:
assert_metadata(context, client_key="client_value", rpc_key="rpc_value")
assert_time_remaining(context, 123)
return GetCurrentTimeResponse(time=Timestamp(seconds=123))

Expand All @@ -101,34 +96,88 @@ async def test_python_grpc_stub():
await server.start()

# Use our client to make a call to each service
client = await Client.connect(
f"localhost:{port}", rpc_metadata={"client_key": "client_value"}
)
metadata = {"rpc_key": "rpc_value"}
client = await Client.connect(f"localhost:{port}")
timeout = timedelta(seconds=123)
count_resp = await client.workflow_service.count_workflow_executions(
CountWorkflowExecutionsRequest(namespace="my namespace", query="my query"),
metadata=metadata,
timeout=timeout,
)
assert count_resp.count == 123
del_resp = await client.operator_service.delete_namespace(
DeleteNamespaceRequest(namespace="my namespace"),
metadata=metadata,
timeout=timeout,
)
assert del_resp.deleted_namespace == "my namespace response"
time_resp = await client.test_service.get_current_time(
Empty(), metadata=metadata, timeout=timeout
)
time_resp = await client.test_service.get_current_time(Empty(), timeout=timeout)
assert time_resp.time.seconds == 123

# Make another call to get system info after changing the client-level
# header
new_metadata = dict(client.rpc_metadata)
new_metadata["client_key"] = "changed_value"
client.rpc_metadata = new_metadata
workflow_server.expected_client_key_value = "changed_value"
await server.stop(grace=None)


async def test_grpc_metadata():
# Start server
server = grpc_server()
workflow_server = SimpleWorkflowServer() # type: ignore[abstract]
add_WorkflowServiceServicer_to_server(workflow_server, server)
port = server.add_insecure_port("[::]:0")
await server.start()

# Connect and confirm metadata of get system info call
client = await Client.connect(
f"localhost:{port}",
api_key="my-api-key",
rpc_metadata={"my-meta-key": "my-meta-val"},
)
workflow_server.assert_last_metadata(
{
"authorization": "Bearer my-api-key",
"my-meta-key": "my-meta-val",
}
)

# Overwrite API key via client RPC metadata, confirm there
client.rpc_metadata = {
"authorization": "my-auth-val1",
"my-meta-key": "my-meta-val",
}
await client.workflow_service.get_system_info(GetSystemInfoRequest())
workflow_server.assert_last_metadata(
{
"authorization": "my-auth-val1",
"my-meta-key": "my-meta-val",
}
)
client.rpc_metadata = {"my-meta-key": "my-meta-val"}

# Overwrite API key via call RPC metadata, confirm there
await client.workflow_service.get_system_info(
GetSystemInfoRequest(), metadata={"authorization": "my-auth-val2"}
)
workflow_server.assert_last_metadata(
{
"authorization": "my-auth-val2",
"my-meta-key": "my-meta-val",
}
)

# Update API key, confirm updated
client.api_key = "my-new-api-key"
await client.workflow_service.get_system_info(GetSystemInfoRequest())
workflow_server.assert_last_metadata(
{
"authorization": "Bearer my-new-api-key",
"my-meta-key": "my-meta-val",
}
)

# Remove API key, confirm removed
client.api_key = None
await client.workflow_service.get_system_info(GetSystemInfoRequest())
workflow_server.assert_last_metadata(
{
"my-meta-key": "my-meta-val",
}
)
assert "authorization" not in workflow_server.last_metadata

await server.stop(grace=None)
Loading