feat: Support S3 batch export encryption #17401

Merged
merged 9 commits on Sep 14, 2023
34 changes: 31 additions & 3 deletions frontend/src/scenes/batch_exports/BatchExportEditForm.tsx
@@ -180,6 +180,12 @@ export function BatchExportsEditForm(props: BatchExportsEditLogicProps): JSX.Ele
]}
/>
</Field>
</div>
<Field name="prefix" label="Key prefix">
<LemonInput placeholder="e.g. posthog-events/" />
</Field>

<div className="flex gap-4">
<Field name="compression" label="Compression" className="flex-1">
<LemonSelect
options={[
@@ -188,16 +194,25 @@ export function BatchExportsEditForm(props: BatchExportsEditLogicProps): JSX.Ele
]}
/>
</Field>

<Field name="encryption" label="Encryption" className="flex-1">
<LemonSelect
options={[
{ value: 'AES256', label: 'AES256' },
{ value: 'aws:kms', label: 'aws:kms' },
{ value: null, label: 'No encryption' },
]}
/>
</Field>
</div>
<Field name="prefix" label="Key prefix">
<LemonInput placeholder="e.g. posthog-events/" />
</Field>

<div className="flex gap-4">
<Field name="aws_access_key_id" label="AWS Access Key ID" className="flex-1">
<LemonInput
placeholder={isNew ? 'e.g. AKIAIOSFODNN7EXAMPLE' : 'leave unchanged'}
/>
</Field>

<Field
name="aws_secret_access_key"
label="AWS Secret Access Key"
@@ -208,7 +223,20 @@ export function BatchExportsEditForm(props: BatchExportsEditLogicProps): JSX.Ele
type="password"
/>
</Field>

{batchExportConfigForm.encryption == 'aws:kms' && (
<Field name="kms_key_id" label="AWS KMS Key ID" className="flex-1">
<LemonInput
placeholder={
isNew
? 'e.g. 1234abcd-12ab-34cd-56ef-1234567890ab'
: 'leave unchanged'
}
/>
</Field>
)}
</div>

<Field name="exclude_events" label="Events to exclude" className="flex-1">
<LemonSelectMultiple
mode="multiple-custom"
2 changes: 2 additions & 0 deletions frontend/src/scenes/batch_exports/BatchExports.stories.tsx
@@ -34,6 +34,8 @@ export default {
aws_secret_access_key: '',
compression: null,
exclude_events: [],
encryption: null,
kms_key_id: null,
},
},
start_at: null,
2 changes: 2 additions & 0 deletions frontend/src/scenes/batch_exports/batchExportEditLogic.ts
@@ -70,6 +70,8 @@ const formFields = (
aws_access_key_id: isNew ? (!config.aws_access_key_id ? 'This field is required' : '') : '',
aws_secret_access_key: isNew ? (!config.aws_secret_access_key ? 'This field is required' : '') : '',
compression: '',
encryption: '',
kms_key_id: !config.kms_key_id && config.encryption == 'aws:kms' ? 'This field is required' : '',
exclude_events: '',
}
: destination === 'BigQuery'
2 changes: 2 additions & 0 deletions frontend/src/types.ts
@@ -3097,6 +3097,8 @@ export type BatchExportDestinationS3 = {
aws_secret_access_key: string
exclude_events: string[]
compression: string | null
encryption: string | null
kms_key_id: string | null
}
}

2 changes: 2 additions & 0 deletions posthog/batch_exports/service.py
@@ -52,6 +52,8 @@ class S3BatchExportInputs:
data_interval_end: str | None = None
compression: str | None = None
exclude_events: list[str] | None = None
encryption: str | None = None
kms_key_id: str | None = None


@dataclass
171 changes: 171 additions & 0 deletions posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
@@ -3,11 +3,13 @@
import gzip
import itertools
import json
import os
from random import randint
from unittest import mock
from uuid import uuid4

import boto3
import botocore.exceptions
import brotli
import pytest
from django.conf import settings
@@ -40,6 +42,18 @@

TEST_ROOT_BUCKET = "test-batch-exports"


def check_valid_credentials() -> bool:
"""Check if there are valid AWS credentials in the environment."""
sts = boto3.client("sts")
try:
sts.get_caller_identity()
except botocore.exceptions.ClientError:
return False
else:
return True


create_test_client = functools.partial(boto3.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT)


@@ -422,6 +436,163 @@ async def test_s3_export_workflow_with_minio_bucket(
assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)


@pytest.mark.skipif(
"S3_TEST_BUCKET" not in os.environ or not check_valid_credentials(),
reason="AWS credentials not set in environment or missing S3_TEST_BUCKET variable",
)
@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize(
"interval,compression,encryption,exclude_events",
itertools.product(["hour", "day"], [None, "gzip", "brotli"], [None, "AES256"], [None, ["test-exclude"]]),
)
async def test_s3_export_workflow_with_s3_bucket(interval, compression, encryption, exclude_events):
"""Test S3 Export Workflow end-to-end by using an S3 bucket.

The S3_TEST_BUCKET environment variable is used to set the name of the bucket for this test.
This test will be skipped if no valid AWS credentials exist, or if the S3_TEST_BUCKET environment
variable is not set.

The workflow should update the batch export run status to completed and produce the expected
records to the S3 bucket.
"""
bucket_name = os.getenv("S3_TEST_BUCKET")
prefix = f"posthog-events-{str(uuid4())}"
destination_data = {
"type": "S3",
"config": {
"bucket_name": bucket_name,
"region": "us-east-1",
"prefix": prefix,
"aws_access_key_id": "object_storage_root_user",
"aws_secret_access_key": "object_storage_root_password",
"compression": compression,
"exclude_events": exclude_events,
"encryption": encryption,
},
}

batch_export_data = {
"name": "my-production-s3-bucket-destination",
"destination": destination_data,
"interval": interval,
}

organization = await acreate_organization("test")
team = await acreate_team(organization=organization)
batch_export = await acreate_batch_export(
team_id=team.pk,
name=batch_export_data["name"],
destination_data=batch_export_data["destination"],
interval=batch_export_data["interval"],
)

events: list[EventValues] = [
{
"uuid": str(uuid4()),
"event": "test",
"timestamp": "2023-04-25 13:30:00.000000",
"created_at": "2023-04-25 13:30:00.000000",
"inserted_at": "2023-04-25 13:30:00.000000",
"_timestamp": "2023-04-25 13:30:00",
"person_id": str(uuid4()),
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"team_id": team.pk,
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"distinct_id": str(uuid4()),
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
},
{
"uuid": str(uuid4()),
"event": "test-exclude",
"timestamp": "2023-04-25 14:29:00.000000",
"created_at": "2023-04-25 14:29:00.000000",
"inserted_at": "2023-04-25 14:29:00.000000",
"_timestamp": "2023-04-25 14:29:00",
"person_id": str(uuid4()),
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"team_id": team.pk,
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"distinct_id": str(uuid4()),
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
},
]

if interval == "day":
# Add an event outside the hour range but within the day range to ensure it's exported too.
events_outside_hour: list[EventValues] = [
{
"uuid": str(uuid4()),
"event": "test",
"timestamp": "2023-04-25 00:30:00.000000",
"created_at": "2023-04-25 00:30:00.000000",
"inserted_at": "2023-04-25 00:30:00.000000",
"_timestamp": "2023-04-25 00:30:00",
"person_id": str(uuid4()),
"person_properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"team_id": team.pk,
"properties": {"$browser": "Chrome", "$os": "Mac OS X"},
"distinct_id": str(uuid4()),
"elements_chain": "this is a comman, separated, list, of css selectors(?)",
}
]
events += events_outside_hour

ch_client = ClickHouseClient(
url=settings.CLICKHOUSE_HTTP_URL,
user=settings.CLICKHOUSE_USER,
password=settings.CLICKHOUSE_PASSWORD,
database=settings.CLICKHOUSE_DATABASE,
)

# Insert some data into the `sharded_events` table.
await insert_events(
client=ch_client,
events=events,
)

workflow_id = str(uuid4())
inputs = S3BatchExportInputs(
team_id=team.pk,
batch_export_id=str(batch_export.id),
data_interval_end="2023-04-25 14:30:00.000000",
interval=interval,
**batch_export.destination.config,
)

s3_client = boto3.client("s3")

def create_s3_client(*args, **kwargs):
"""Mock function to return an already initialized S3 client."""
return s3_client

async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[S3BatchExportWorkflow],
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_s3_client):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)

runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
assert len(runs) == 1

run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip"])
34 changes: 27 additions & 7 deletions posthog/temporal/workflows/s3_batch_export.py
@@ -85,15 +85,20 @@ class S3MultiPartUploadState(typing.NamedTuple):
parts: list[dict[str, str | int]]


Part = dict[str, str | int]


class S3MultiPartUpload:
"""An S3 multi-part upload."""

def __init__(self, s3_client, bucket_name, key):
def __init__(self, s3_client, bucket_name: str, key: str, encryption: str | None, kms_key_id: str | None):
self.s3_client = s3_client
self.bucket_name = bucket_name
self.key = key
self.upload_id = None
self.parts = []
self.encryption = encryption
self.kms_key_id = kms_key_id
self.upload_id: str | None = None
self.parts: list[Part] = []

def to_state(self) -> S3MultiPartUploadState:
"""Produce state tuple that can be used to resume this S3MultiPartUpload."""
@@ -119,10 +124,21 @@ def start(self) -> str:
if self.is_upload_in_progress() is True:
raise UploadAlreadyInProgressError(self.upload_id)

multipart_response = self.s3_client.create_multipart_upload(Bucket=self.bucket_name, Key=self.key)
self.upload_id = multipart_response["UploadId"]
optional_kwargs = {}
if self.encryption:
optional_kwargs["ServerSideEncryption"] = self.encryption
if self.kms_key_id:
optional_kwargs["SSEKMSKeyId"] = self.kms_key_id

return self.upload_id
multipart_response = self.s3_client.create_multipart_upload(
Bucket=self.bucket_name,
Key=self.key,
**optional_kwargs,
)
upload_id: str = multipart_response["UploadId"]
self.upload_id = upload_id
Comment on lines +138 to +139

Contributor:

why not

Suggested change:
-            upload_id: str = multipart_response["UploadId"]
-            self.upload_id = upload_id
+            self.upload_id = multipart_response["UploadId"]

Contributor Author (@tomasfarias, Sep 14, 2023):

Type checking fails as self.upload_id is str | None and this function returns str, not None, so we need the extra upload_id: str to return.

Contributor:

Shouldn't that work just fine? If it's str | None you can assign it str, I'm confused.

Contributor Author (@tomasfarias, Sep 14, 2023):

You can assign it str, but then you cannot return it, so you need the extra variable. It's not the only way to solve this though: you could add some isinstance checks, but the extra variable seemed like the easiest.
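For readers following this thread, here is a minimal sketch of the behaviour being discussed, assuming a mypy-style checker; the class and names below are hypothetical illustrations, not the PR's code:

from typing import Any


class Upload:
    def __init__(self) -> None:
        # Declared as optional, like S3MultiPartUpload.upload_id.
        self.upload_id: str | None = None

    def start(self, response: dict[str, Any]) -> str:
        # response["UploadId"] is typed Any (boto3 responses are untyped), so
        # assigning it to the attribute does not narrow `str | None` to `str`:
        #
        #     self.upload_id = response["UploadId"]
        #     return self.upload_id  # error: got "str | None", expected "str"
        #
        # An annotated local pins the value to `str`, satisfying both the
        # attribute assignment and the return type. An isinstance or assert
        # check on self.upload_id before returning would also work.
        upload_id: str = response["UploadId"]
        self.upload_id = upload_id
        return upload_id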


return upload_id

def continue_from_state(self, state: S3MultiPartUploadState):
"""Continue this S3MultiPartUpload from a previous state."""
@@ -230,6 +246,8 @@ class S3InsertInputs:
aws_secret_access_key: str | None = None
compression: str | None = None
exclude_events: list[str] | None = None
encryption: str | None = None
kms_key_id: str | None = None


def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]:
@@ -241,7 +259,7 @@ def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3Mu
aws_access_key_id=inputs.aws_access_key_id,
aws_secret_access_key=inputs.aws_secret_access_key,
)
s3_upload = S3MultiPartUpload(s3_client, inputs.bucket_name, key)
s3_upload = S3MultiPartUpload(s3_client, inputs.bucket_name, key, inputs.encryption, inputs.kms_key_id)

details = activity.info().heartbeat_details

@@ -442,6 +460,8 @@ async def run(self, inputs: S3BatchExportInputs):
data_interval_end=data_interval_end.isoformat(),
compression=inputs.compression,
exclude_events=inputs.exclude_events,
encryption=inputs.encryption,
kms_key_id=inputs.kms_key_id,
)
try:
await workflow.execute_activity(
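As a closing illustration, a hypothetical usage sketch (not part of the diff) of how the new encryption parameters reach boto3 when a KMS-encrypted multi-part upload is started; the bucket, key, and KMS key ID values are placeholders:

import boto3

from posthog.temporal.workflows.s3_batch_export import S3MultiPartUpload

s3_client = boto3.client("s3")
upload = S3MultiPartUpload(
    s3_client,
    bucket_name="my-export-bucket",                     # placeholder bucket
    key="posthog-events/2023-04-25.jsonl",              # placeholder key
    encryption="aws:kms",
    kms_key_id="1234abcd-12ab-34cd-56ef-1234567890ab",  # placeholder key ID
)
# start() forwards ServerSideEncryption and SSEKMSKeyId to create_multipart_upload
# only when they are set; with encryption=None the call is unchanged.
upload.start()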