-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Implement encryption codec for Temporal (#15566)
Co-authored-by: Harry Waye <[email protected]>
- Loading branch information
1 parent
81d8c21
commit ebe7388
Showing
6 changed files
with
147 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import base64 | ||
from typing import Iterable | ||
|
||
from cryptography.fernet import Fernet | ||
from temporalio.api.common.v1 import Payload | ||
from temporalio.converter import PayloadCodec | ||
|
||
|
||
class EncryptionCodec(PayloadCodec): | ||
"""A PayloadCodec that encrypts/decrypts all Payloads. | ||
Args: | ||
settings: Django settings to obtain the SECRET_KEY to use for encryption. | ||
""" | ||
|
||
def __init__(self, settings) -> None: | ||
super().__init__() | ||
|
||
# Fernet requires a URL safe, base64 encoded, 32 byte key. So, we pad the SECRET_KEY | ||
# if it's not long enough (like in TEST environments) or we truncate it if it's too long. | ||
padded_key = b"\0" * max(32 - len(settings.SECRET_KEY), 0) + settings.SECRET_KEY.encode() | ||
encoded_key = base64.urlsafe_b64encode(padded_key[:32]) | ||
self.fernet = Fernet(encoded_key) | ||
|
||
async def encode(self, payloads: Iterable[Payload]) -> list[Payload]: | ||
"""Encrypt all payloads during encoding.""" | ||
return [ | ||
Payload( | ||
metadata={ | ||
"encoding": b"binary/encrypted", | ||
}, | ||
data=self.encrypt(p.SerializeToString()), | ||
) | ||
for p in payloads | ||
] | ||
|
||
async def decode(self, payloads: Iterable[Payload]) -> list[Payload]: | ||
"""Decode all payloads decrypting those with expected encoding.""" | ||
ret: list[Payload] = [] | ||
for p in payloads: | ||
# Ignore ones without our expected encoding | ||
if p.metadata.get("encoding", b"").decode() != "binary/encrypted": | ||
ret.append(p) | ||
continue | ||
|
||
ret.append(Payload.FromString(self.decrypt(p.data))) | ||
return ret | ||
|
||
def encrypt(self, data: bytes) -> bytes: | ||
"""Return data encrypted.""" | ||
return self.fernet.encrypt(data) | ||
|
||
def decrypt(self, data: bytes) -> bytes: | ||
"""Return data decrypted.""" | ||
return self.fernet.decrypt(data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import dataclasses | ||
import uuid | ||
|
||
import pytest | ||
import temporalio.converter | ||
from django.conf import settings | ||
from temporalio.api.enums.v1 import EventType | ||
from temporalio.client import Client | ||
from temporalio.worker import UnsandboxedWorkflowRunner, Worker | ||
|
||
from posthog.temporal.codec import EncryptionCodec | ||
from posthog.temporal.workflows.noop import NoOpWorkflow, noop_activity | ||
|
||
|
||
def get_history_event_payloads(event): | ||
"""Return a history event's payloads if it has any. | ||
Depending on the event_type, each event has a different attribute to store the payloads (ugh). | ||
""" | ||
match event.event_type: | ||
case EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED: | ||
return event.workflow_execution_started_event_attributes.input.payloads | ||
case EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED: | ||
return event.workflow_execution_completed_event_attributes.result.payloads | ||
case EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED: | ||
return event.activity_task_scheduled_event_attributes.input.payloads | ||
case EventType.EVENT_TYPE_ACTIVITY_TASK_COMPLETED: | ||
return event.activity_task_completed_event_attributes.result.payloads | ||
case _: | ||
return None | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_payloads_are_encrypted(): | ||
"""Test the payloads of a Workflow are encrypted when running with EncryptionCodec.""" | ||
codec = EncryptionCodec(settings=settings) | ||
client = await Client.connect( | ||
f"{settings.TEMPORAL_HOST}:{settings.TEMPORAL_PORT}", | ||
namespace=settings.TEMPORAL_NAMESPACE, | ||
data_converter=dataclasses.replace(temporalio.converter.default(), payload_codec=codec), | ||
) | ||
|
||
workflow_id = uuid.uuid4() | ||
input_str = str(uuid.uuid4()) | ||
no_op_result_str = f"OK - {input_str}" | ||
no_op_activity_input_str = f'{{"time":"{input_str}"}}' | ||
# The no-op Workflow can only produce a limited set of results, so we'll check if the events match any of these. | ||
# Either it's the final result (no_op_result_str), the input to an activity (no_op_activity_input_str), or the | ||
# input to the workflow (input_str). In all cases, data is encoded. | ||
expected_results = (f'"{no_op_result_str}"'.encode(), f'"{input_str}"'.encode(), no_op_activity_input_str.encode()) | ||
|
||
async with Worker( | ||
client, | ||
task_queue=settings.TEMPORAL_TASK_QUEUE, | ||
workflows=[NoOpWorkflow], | ||
activities=[noop_activity], | ||
workflow_runner=UnsandboxedWorkflowRunner(), | ||
) as worker: | ||
handle = await client.start_workflow( | ||
NoOpWorkflow.run, | ||
input_str, | ||
id=f"workflow-{workflow_id}", | ||
task_queue=worker.task_queue, | ||
) | ||
|
||
result = await handle.result() | ||
assert result == no_op_result_str | ||
|
||
async for event in handle.fetch_history_events(): | ||
payloads = get_history_event_payloads(event) | ||
|
||
if not payloads: | ||
continue | ||
|
||
payload = payloads[0] | ||
assert payload.metadata["encoding"] == b"binary/encrypted" | ||
|
||
decoded_payloads = await codec.decode([payload]) | ||
assert decoded_payloads[0].data in expected_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters