refactor(s3-batch-exports): Swap to asyncio s3 client (#17673)
tomasfarias authored Sep 29, 2023
1 parent 1162f7d commit 5e192a0
Showing 5 changed files with 187 additions and 123 deletions.
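
The change throughout is the standard move from a synchronous boto3 client to aioboto3: the client is created from a Session, used as an async context manager, and every S3 call is awaited. A minimal sketch of that pattern, using illustrative names, credentials, and endpoints rather than anything taken from the PostHog code:

import asyncio

import aioboto3


# Before (synchronous): boto3.client("s3") blocks the event loop on every call.
# After (asynchronous): aioboto3 clients are async context managers and each API call is awaited.
async def demo(endpoint_url: str) -> None:
    session = aioboto3.Session()
    async with session.client(
        "s3",
        endpoint_url=endpoint_url,
        aws_access_key_id="example-key",          # illustrative credentials
        aws_secret_access_key="example-secret",
    ) as s3:
        await s3.create_bucket(Bucket="example-bucket")
        response = await s3.list_objects_v2(Bucket="example-bucket")
        print(response.get("Contents", []))


asyncio.run(demo("http://localhost:9000"))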
163 changes: 88 additions & 75 deletions posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import datetime as dt
import functools
import gzip
@@ -9,7 +10,7 @@
from unittest import mock
from uuid import uuid4

import boto3
import aioboto3
import botocore.exceptions
import brotli
import pytest
@@ -55,18 +56,20 @@
TEST_ROOT_BUCKET = "test-batch-exports"


def check_valid_credentials() -> bool:
async def check_valid_credentials() -> bool:
"""Check if there are valid AWS credentials in the environment."""
sts = boto3.client("sts")
session = aioboto3.Session()
sts = await session.client("sts")
try:
sts.get_caller_identity()
await sts.get_caller_identity()
except botocore.exceptions.ClientError:
return False
else:
return True


create_test_client = functools.partial(boto3.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT)
SESSION = aioboto3.Session()
create_test_client = functools.partial(SESSION.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT)


@pytest.fixture
@@ -75,48 +78,47 @@ def bucket_name() -> str:
return f"{TEST_ROOT_BUCKET}-{str(uuid4())}"


@pytest.fixture
def s3_client(bucket_name):
@pytest_asyncio.fixture
async def s3_client(bucket_name):
"""Manage a testing S3 client to interact with a testing S3 bucket.
Yields the test S3 client after creating a testing S3 bucket. Upon resuming, we delete
the contents and the bucket itself.
"""
s3_client = create_test_client(
async with create_test_client(
"s3",
aws_access_key_id="object_storage_root_user",
aws_secret_access_key="object_storage_root_password",
)
) as s3_client:
await s3_client.create_bucket(Bucket=bucket_name)

s3_client.create_bucket(Bucket=bucket_name)
yield s3_client

yield s3_client
response = await s3_client.list_objects_v2(Bucket=bucket_name)

response = s3_client.list_objects_v2(Bucket=bucket_name)
if "Contents" in response:
for obj in response["Contents"]:
if "Key" in obj:
await s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"])

if "Contents" in response:
for obj in response["Contents"]:
if "Key" in obj:
s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"])
await s3_client.delete_bucket(Bucket=bucket_name)

s3_client.delete_bucket(Bucket=bucket_name)


def assert_events_in_s3(
async def assert_events_in_s3(
s3_client, bucket_name, key_prefix, events, compression: str | None = None, exclude_events: list[str] | None = None
):
"""Assert provided events written to JSON in key_prefix in S3 bucket_name."""
# List the objects in the bucket with the prefix.
objects = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)
objects = await s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)

# Check that there is only one object.
assert len(objects.get("Contents", [])) == 1

# Get the object.
key = objects["Contents"][0].get("Key")
assert key
object = s3_client.get_object(Bucket=bucket_name, Key=key)
data = object["Body"].read()
s3_object = await s3_client.get_object(Bucket=bucket_name, Key=key)
data = await s3_object["Body"].read()

# Check that the data is correct.
match compression:
@@ -306,10 +308,12 @@ async def test_insert_into_s3_activity_puts_data_into_s3(
with override_settings(
BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2
): # 5MB, the minimum for Multipart uploads
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.run(insert_into_s3_activity, insert_inputs)

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)


@pytest.mark.django_db
@@ -436,7 +440,9 @@ async def test_s3_export_workflow_with_minio_bucket(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -452,7 +458,7 @@ async def test_s3_export_workflow_with_minio_bucket(
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)


@pytest.mark.skipif(
@@ -581,45 +587,46 @@ async def test_s3_export_workflow_with_s3_bucket(interval, compression, encrypti
**batch_export.destination.config,
)

s3_client = boto3.client("s3")
async with aioboto3.Session().client("s3") as s3_client:

def create_s3_client(*args, **kwargs):
"""Mock function to return an already initialized S3 client."""
return s3_client
@contextlib.asynccontextmanager
async def create_s3_client(*args, **kwargs):
"""Mock function to return an already initialized S3 client."""
yield s3_client

async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[S3BatchExportWorkflow],
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_s3_client):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)
async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[S3BatchExportWorkflow],
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_s3_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)

runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
assert len(runs) == 1
runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
assert len(runs) == 1

run = runs[0]
assert run.status == "Completed"
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip"])
async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
client: HttpClient, s3_client, bucket_name, compression
):
async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(s3_client, bucket_name, compression):
"""Test the full S3 workflow targetting a MinIO bucket.
The workflow should update the batch export run status to completed and produce the expected
@@ -700,7 +707,9 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -716,15 +725,15 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events, compression)
await assert_events_in_s3(
s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events, compression
)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
client: HttpClient, s3_client, bucket_name, compression
):
async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(s3_client, bucket_name, compression):
"""Test the full S3 workflow targetting a MinIO bucket.
In this scenario we assert that when inserted_at is NULL, we default to _timestamp.
@@ -818,7 +827,9 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -834,15 +845,13 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
client: HttpClient, s3_client, bucket_name, compression
):
async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(s3_client, bucket_name, compression):
"""Test the S3BatchExport Workflow utilizing a custom key prefix.
We will be asserting that exported events land in the appropriate S3 key according to the prefix.
@@ -921,7 +930,9 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -940,20 +951,18 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
expected_key_prefix = prefix.format(
table="events", year="2023", month="04", day="25", hour="14", minute="30", second="00"
)
objects = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=expected_key_prefix)
objects = await s3_client.list_objects_v2(Bucket=bucket_name, Prefix=expected_key_prefix)
key = objects["Contents"][0].get("Key")
assert len(objects.get("Contents", [])) == 1
assert key.startswith(expected_key_prefix)

assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events, compression)
await assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events, compression)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
client: HttpClient, s3_client, bucket_name, compression
):
async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(s3_client, bucket_name, compression):
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
In this particular instance of the test, we assert no duplicates are exported to S3.
@@ -1065,7 +1074,9 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -1080,7 +1091,7 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)


@pytest_asyncio.fixture
@@ -1537,9 +1548,11 @@ def assert_heartbeat_details(*details):
)

with override_settings(BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.run(insert_into_s3_activity, insert_inputs)

# This checks that the assert_heartbeat_details function was actually called
assert current_part_number > 1
assert_events_in_s3(s3_client, bucket_name, prefix, events, None, None)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, None, None)
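
A detail worth noting in these tests: because aioboto3's Session.client returns an async context manager rather than a bare client, any patched factory must do the same. The MinIO tests patch it with a functools.partial over a real Session.client, while the real-S3 test wraps an already-open client in contextlib.asynccontextmanager. A rough, self-contained sketch of that second pattern, with illustrative names and credentials:

import contextlib
from unittest import mock

import aioboto3


async def run_with_reused_client() -> None:
    session = aioboto3.Session()
    async with session.client(
        "s3",
        region_name="us-east-1",                  # illustrative region/credentials
        aws_access_key_id="example-key",
        aws_secret_access_key="example-secret",
    ) as s3_client:

        @contextlib.asynccontextmanager
        async def fake_client(*args, **kwargs):
            # Yield the already-open client so code doing
            # `async with aioboto3.Session().client("s3") as c:` keeps working.
            yield s3_client

        with mock.patch("aioboto3.Session.client", side_effect=fake_client):
            ...  # code under test runs here and receives the reused client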
3 changes: 3 additions & 0 deletions posthog/temporal/workflows/batch_exports.py
@@ -295,6 +295,9 @@ def __exit__(self, exc, value, tb):
"""Context-manager protocol exit method."""
return self._file.__exit__(exc, value, tb)

def __iter__(self):
yield from self._file

@property
def brotli_compressor(self):
if self._brotli_compressor is None:
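
The three added lines in batch_exports.py are the new __iter__, which simply delegates iteration to the wrapped file object so callers can iterate the wrapper directly instead of reaching into _file. A minimal sketch of that delegation pattern (class and attribute names here are illustrative, not the PostHog class):

import tempfile


class FileWrapper:
    """Forward iteration to an underlying temporary file."""

    def __init__(self):
        self._file = tempfile.NamedTemporaryFile("w+b")

    def __iter__(self):
        # Yields exactly what iterating the underlying file yields (its lines/chunks).
        yield from self._file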
