Skip to content

Commit

Permalink
feat: Implement encryption codec for Temporal (#15566)
Browse files Browse the repository at this point in the history
Co-authored-by: Harry Waye <[email protected]>
  • Loading branch information
tomasfarias and Harry Waye authored Jun 8, 2023
1 parent 81d8c21 commit ebe7388
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 4 deletions.
8 changes: 8 additions & 0 deletions posthog/temporal/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import dataclasses

import temporalio.converter
from asgiref.sync import async_to_sync
from django.conf import settings
from temporalio.client import Client, TLSConfig

from posthog.temporal.codec import EncryptionCodec


async def connect(host, port, namespace, server_root_ca_cert=None, client_cert=None, client_key=None):
tls: TLSConfig | bool = False
Expand All @@ -15,6 +20,9 @@ async def connect(host, port, namespace, server_root_ca_cert=None, client_cert=N
f"{host}:{port}",
namespace=namespace,
tls=tls,
data_converter=dataclasses.replace(
temporalio.converter.default(), payload_codec=EncryptionCodec(settings=settings)
),
)
return client

Expand Down
55 changes: 55 additions & 0 deletions posthog/temporal/codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import base64
from typing import Iterable

from cryptography.fernet import Fernet
from temporalio.api.common.v1 import Payload
from temporalio.converter import PayloadCodec


class EncryptionCodec(PayloadCodec):
    """PayloadCodec that encrypts every outgoing Payload and decrypts the ones it produced.

    Encryption is done with Fernet, using a key built from the Django ``SECRET_KEY``.

    Args:
        settings: Django settings object; only its ``SECRET_KEY`` attribute is read.
    """

    def __init__(self, settings) -> None:
        super().__init__()

        # Fernet needs a URL-safe base64 encoding of exactly 32 bytes. A short SECRET_KEY
        # (e.g. in TEST environments) is left-padded with NUL bytes, a long one is cut to
        # its first 32 bytes. The pad length is computed from the key's character count.
        secret = settings.SECRET_KEY
        padding = b"\0" * max(32 - len(secret), 0)
        key_material = (padding + secret.encode())[:32]
        self.fernet = Fernet(base64.urlsafe_b64encode(key_material))

    async def encode(self, payloads: Iterable[Payload]) -> list[Payload]:
        """Wrap each payload in a new Payload holding its encrypted serialized form."""
        encrypted: list[Payload] = []
        for payload in payloads:
            ciphertext = self.encrypt(payload.SerializeToString())
            encrypted.append(Payload(metadata={"encoding": b"binary/encrypted"}, data=ciphertext))
        return encrypted

    async def decode(self, payloads: Iterable[Payload]) -> list[Payload]:
        """Decrypt payloads tagged with our encoding; pass any other payload through untouched."""
        return [
            Payload.FromString(self.decrypt(payload.data))
            if payload.metadata.get("encoding", b"").decode() == "binary/encrypted"
            else payload
            for payload in payloads
        ]

    def encrypt(self, data: bytes) -> bytes:
        """Encrypt data with this codec's Fernet instance and return the resulting token."""
        return self.fernet.encrypt(data)

    def decrypt(self, data: bytes) -> bytes:
        """Decrypt a token with this codec's Fernet instance and return the original bytes."""
        return self.fernet.decrypt(data)
79 changes: 79 additions & 0 deletions posthog/temporal/tests/test_encryption_codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import dataclasses
import uuid

import pytest
import temporalio.converter
from django.conf import settings
from temporalio.api.enums.v1 import EventType
from temporalio.client import Client
from temporalio.worker import UnsandboxedWorkflowRunner, Worker

from posthog.temporal.codec import EncryptionCodec
from posthog.temporal.workflows.noop import NoOpWorkflow, noop_activity


def get_history_event_payloads(event):
    """Return the payloads stored in a history event, or None if the event type is not handled.

    The attribute that holds the payloads depends on the event's ``event_type``. Only
    workflow started/completed and activity task scheduled/completed events are handled.
    """
    kind = event.event_type

    if kind == EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED:
        return event.workflow_execution_started_event_attributes.input.payloads
    if kind == EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED:
        return event.workflow_execution_completed_event_attributes.result.payloads
    if kind == EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED:
        return event.activity_task_scheduled_event_attributes.input.payloads
    if kind == EventType.EVENT_TYPE_ACTIVITY_TASK_COMPLETED:
        return event.activity_task_completed_event_attributes.result.payloads

    return None


@pytest.mark.asyncio
async def test_payloads_are_encrypted():
    """Check that payloads recorded in a Workflow's history are encrypted when using EncryptionCodec."""
    codec = EncryptionCodec(settings=settings)
    data_converter = dataclasses.replace(temporalio.converter.default(), payload_codec=codec)
    client = await Client.connect(
        f"{settings.TEMPORAL_HOST}:{settings.TEMPORAL_PORT}",
        namespace=settings.TEMPORAL_NAMESPACE,
        data_converter=data_converter,
    )

    workflow_id = uuid.uuid4()
    workflow_input = str(uuid.uuid4())
    workflow_result = f"OK - {workflow_input}"
    activity_input = f'{{"time":"{workflow_input}"}}'
    # Once decrypted, a payload from the no-op Workflow's history is expected to be one of:
    # the final result, the Workflow's own input, or the input handed to the activity.
    expected_results = (
        f'"{workflow_result}"'.encode(),
        f'"{workflow_input}"'.encode(),
        activity_input.encode(),
    )

    worker_context = Worker(
        client,
        task_queue=settings.TEMPORAL_TASK_QUEUE,
        workflows=[NoOpWorkflow],
        activities=[noop_activity],
        workflow_runner=UnsandboxedWorkflowRunner(),
    )
    async with worker_context as worker:
        handle = await client.start_workflow(
            NoOpWorkflow.run,
            workflow_input,
            id=f"workflow-{workflow_id}",
            task_queue=worker.task_queue,
        )

        assert await handle.result() == workflow_result

        async for event in handle.fetch_history_events():
            payloads = get_history_event_payloads(event)

            # Events of other types, or without payloads, have nothing to verify.
            if payloads:
                first_payload = payloads[0]
                assert first_payload.metadata["encoding"] == b"binary/encrypted"

                decoded = await codec.decode([first_payload])
                assert decoded[0].data in expected_results
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.black]
line-length = 120
target-version = ['py38']
target-version = ['py310']

[tool.isort]
multi_line_output = 3
Expand Down Expand Up @@ -95,4 +95,3 @@ max-complexity = 10
"./posthog/management/commands/test_migrations_are_safe.py" = ["T201"]
"./posthog/management/commands/api_keys.py" = ["T201"]
"./posthog/demo/matrix/manager.py" = ["T201"]

1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ celery==4.4.7
celery-redbeat==2.0.0
clickhouse-driver==0.2.4
clickhouse-pool==0.5.3
cryptography==37.0.2
defusedxml==0.6.0
dj-database-url==0.5.0
Django==3.2.18
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
# This file is autogenerated by pip-compile with python 3.10
# To update, run:
#
# pip-compile requirements.in
#
Expand Down Expand Up @@ -75,6 +75,7 @@ clickhouse-pool==0.5.3
# via -r requirements.in
cryptography==37.0.2
# via
# -r requirements.in
# kafka-helper
# pyopenssl
# social-auth-core
Expand Down

0 comments on commit ebe7388

Please sign in to comment.