refactor(s3-batch-exports): Swap to asyncio s3 client
Temporal activities are not heartbeating. I believe this is because the heartbeat
spawns an async task, but since our code never awaits anything, we never yield
control for the event loop to run that task.

By swapping to aioboto3 we now await each upload part, which should allow us to
heartbeat while the part is uploaded.
tomasfarias committed Sep 28, 2023
1 parent 1832f4d commit 16e5434
Showing 5 changed files with 187 additions and 144 deletions.
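A minimal, self-contained sketch of the failure mode described in the commit message (purely illustrative: the real heartbeat task comes from the Temporal SDK, and the sleeps below stand in for `upload_part` calls):

```python
import asyncio
import time


async def blocking_upload() -> int:
    """Old behaviour: the heartbeat task is scheduled, but the blocking call
    never yields to the event loop, so the task never gets to run."""
    beats = 0

    async def heartbeat() -> None:
        nonlocal beats
        beats += 1

    asyncio.get_running_loop().create_task(heartbeat())
    time.sleep(0.1)  # stand-in for a blocking boto3 upload_part call
    return beats  # 0: the scheduled heartbeat never ran


async def awaited_upload() -> int:
    """New behaviour: awaiting the upload yields control, so the scheduled
    heartbeat task runs while the part uploads."""
    beats = 0

    async def heartbeat() -> None:
        nonlocal beats
        beats += 1

    asyncio.get_running_loop().create_task(heartbeat())
    await asyncio.sleep(0.1)  # stand-in for an awaited aioboto3 upload_part call
    return beats  # 1: the heartbeat ran during the await


print(asyncio.run(blocking_upload()), asyncio.run(awaited_upload()))
```

The only difference between the two coroutines is the `await`: the blocking call never gives the event loop a chance to run the scheduled heartbeat task, while the awaited call does.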
163 changes: 88 additions & 75 deletions posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import datetime as dt
import functools
import gzip
@@ -9,7 +10,7 @@
from unittest import mock
from uuid import uuid4

import boto3
import aioboto3
import botocore.exceptions
import brotli
import pytest
@@ -55,18 +56,20 @@
TEST_ROOT_BUCKET = "test-batch-exports"


def check_valid_credentials() -> bool:
async def check_valid_credentials() -> bool:
"""Check if there are valid AWS credentials in the environment."""
sts = boto3.client("sts")
session = aioboto3.Session()
sts = await session.client("sts")
try:
sts.get_caller_identity()
await sts.get_caller_identity()
except botocore.exceptions.ClientError:
return False
else:
return True


create_test_client = functools.partial(boto3.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT)
SESSION = aioboto3.Session()
create_test_client = functools.partial(SESSION.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT)


@pytest.fixture
@@ -75,48 +78,47 @@ def bucket_name() -> str:
return f"{TEST_ROOT_BUCKET}-{str(uuid4())}"


@pytest.fixture
def s3_client(bucket_name):
@pytest_asyncio.fixture
async def s3_client(bucket_name):
"""Manage a testing S3 client to interact with a testing S3 bucket.
Yields the test S3 client after creating a testing S3 bucket. Upon resuming, we delete
the contents and the bucket itself.
"""
s3_client = create_test_client(
async with create_test_client(
"s3",
aws_access_key_id="object_storage_root_user",
aws_secret_access_key="object_storage_root_password",
)
) as s3_client:
await s3_client.create_bucket(Bucket=bucket_name)

s3_client.create_bucket(Bucket=bucket_name)
yield s3_client

yield s3_client
response = await s3_client.list_objects_v2(Bucket=bucket_name)

response = s3_client.list_objects_v2(Bucket=bucket_name)
if "Contents" in response:
for obj in response["Contents"]:
if "Key" in obj:
await s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"])

if "Contents" in response:
for obj in response["Contents"]:
if "Key" in obj:
s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"])
await s3_client.delete_bucket(Bucket=bucket_name)

s3_client.delete_bucket(Bucket=bucket_name)


def assert_events_in_s3(
async def assert_events_in_s3(
s3_client, bucket_name, key_prefix, events, compression: str | None = None, exclude_events: list[str] | None = None
):
"""Assert provided events written to JSON in key_prefix in S3 bucket_name."""
# List the objects in the bucket with the prefix.
objects = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)
objects = await s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)

# Check that there is only one object.
assert len(objects.get("Contents", [])) == 1

# Get the object.
key = objects["Contents"][0].get("Key")
assert key
object = s3_client.get_object(Bucket=bucket_name, Key=key)
data = object["Body"].read()
s3_object = await s3_client.get_object(Bucket=bucket_name, Key=key)
data = await s3_object["Body"].read()

# Check that the data is correct.
match compression:
@@ -306,10 +308,12 @@ async def test_insert_into_s3_activity_puts_data_into_s3(
with override_settings(
BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2
): # 5MB, the minimum for Multipart uploads
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.run(insert_into_s3_activity, insert_inputs)

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)


@pytest.mark.django_db
@@ -436,7 +440,9 @@ async def test_s3_export_workflow_with_minio_bucket(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -452,7 +458,7 @@ async def test_s3_export_workflow_with_minio_bucket(
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)


@pytest.mark.skipif(
@@ -581,45 +587,46 @@ async def test_s3_export_workflow_with_s3_bucket(interval, compression, encrypti
**batch_export.destination.config,
)

s3_client = boto3.client("s3")
async with aioboto3.Session().client("s3") as s3_client:

def create_s3_client(*args, **kwargs):
"""Mock function to return an already initialized S3 client."""
return s3_client
@contextlib.asynccontextmanager
async def create_s3_client(*args, **kwargs):
"""Mock function to return an already initialized S3 client."""
yield s3_client

async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[S3BatchExportWorkflow],
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_s3_client):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)
async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[S3BatchExportWorkflow],
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_s3_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)

runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
assert len(runs) == 1
runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
assert len(runs) == 1

run = runs[0]
assert run.status == "Completed"
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip"])
async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
client: HttpClient, s3_client, bucket_name, compression
):
async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(s3_client, bucket_name, compression):
"""Test the full S3 workflow targetting a MinIO bucket.
The workflow should update the batch export run status to completed and produce the expected
Expand Down Expand Up @@ -700,7 +707,9 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -716,15 +725,15 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events, compression)
await assert_events_in_s3(
s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events, compression
)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
client: HttpClient, s3_client, bucket_name, compression
):
async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(s3_client, bucket_name, compression):
"""Test the full S3 workflow targetting a MinIO bucket.
In this scenario we assert that when inserted_at is NULL, we default to _timestamp.
@@ -818,7 +827,9 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -834,15 +845,13 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
client: HttpClient, s3_client, bucket_name, compression
):
async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(s3_client, bucket_name, compression):
"""Test the S3BatchExport Workflow utilizing a custom key prefix.
We will be asserting that exported events land in the appropiate S3 key according to the prefix.
@@ -921,7 +930,9 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -940,20 +951,18 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
expected_key_prefix = prefix.format(
table="events", year="2023", month="04", day="25", hour="14", minute="30", second="00"
)
objects = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=expected_key_prefix)
objects = await s3_client.list_objects_v2(Bucket=bucket_name, Prefix=expected_key_prefix)
key = objects["Contents"][0].get("Key")
assert len(objects.get("Contents", [])) == 1
assert key.startswith(expected_key_prefix)

assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events, compression)
await assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events, compression)


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
client: HttpClient, s3_client, bucket_name, compression
):
async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(s3_client, bucket_name, compression):
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
In this particular instance of the test, we assert no duplicates are exported to S3.
@@ -1065,7 +1074,9 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -1080,7 +1091,7 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
run = runs[0]
assert run.status == "Completed"

assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)


@pytest_asyncio.fixture
@@ -1537,9 +1548,11 @@ def assert_heartbeat_details(*details):
)

with override_settings(BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2):
with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
with mock.patch(
"posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
):
await activity_environment.run(insert_into_s3_activity, insert_inputs)

# This checks that the assert_heartbeat_details function was actually called
assert current_part_number > 1
assert_events_in_s3(s3_client, bucket_name, prefix, events, None, None)
await assert_events_in_s3(s3_client, bucket_name, prefix, events, None, None)
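Because `aioboto3.Session.client` returns an async context manager rather than a plain client, every `mock.patch(..., side_effect=...)` in these tests has to supply a callable with the same shape; that is why the S3-bucket test wraps its pre-built client in `contextlib.asynccontextmanager`. A self-contained sketch of that mocking pattern (the fake client and its method are illustrative, not from the codebase):

```python
import asyncio
import contextlib
from unittest import mock

import aioboto3


class FakeS3Client:
    """Illustrative stand-in for an already-initialized S3 client."""

    async def list_buckets(self):
        return {"Buckets": []}


@contextlib.asynccontextmanager
async def fake_client(*args, **kwargs):
    # Whatever replaces Session.client must support `async with`, just like
    # the create_test_client / create_s3_client side effects in these tests.
    yield FakeS3Client()


async def main() -> None:
    with mock.patch.object(aioboto3.Session, "client", side_effect=fake_client):
        async with aioboto3.Session().client("s3") as s3:
            print(await s3.list_buckets())


asyncio.run(main())
```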
19 changes: 3 additions & 16 deletions posthog/temporal/workflows/batch_exports.py
@@ -14,7 +14,6 @@

import brotli
from asgiref.sync import sync_to_async
from django.conf import settings
from temporalio import activity, workflow

from posthog.batch_exports.service import (
@@ -24,7 +23,6 @@
)
from posthog.kafka_client.client import KafkaProducer
from posthog.kafka_client.topics import KAFKA_LOG_ENTRIES
from posthog.temporal.client import connect

SELECT_QUERY_TEMPLATE = Template(
"""
@@ -297,6 +295,9 @@ def __exit__(self, exc, value, tb):
"""Context-manager protocol exit method."""
return self._file.__exit__(exc, value, tb)

def __iter__(self):
yield from self._file

@property
def brotli_compressor(self):
if self._brotli_compressor is None:
@@ -609,17 +610,3 @@ class UpdateBatchExportRunStatusInputs:
async def update_export_run_status(inputs: UpdateBatchExportRunStatusInputs):
"""Activity that updates the status of an BatchExportRun."""
await sync_to_async(update_batch_export_run_status)(run_id=uuid.UUID(inputs.id), status=inputs.status, latest_error=inputs.latest_error) # type: ignore


async def heartbeat(task_token: bytes, *details):
"""Async heartbeat function for batch export activities."""
client = await connect(
settings.TEMPORAL_HOST,
settings.TEMPORAL_PORT,
settings.TEMPORAL_NAMESPACE,
settings.TEMPORAL_CLIENT_ROOT_CA,
settings.TEMPORAL_CLIENT_CERT,
settings.TEMPORAL_CLIENT_KEY,
)
handle = client.get_async_activity_handle(task_token=task_token)
await handle.heartbeat(*details)
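The helper removed above heartbeated out-of-band by opening a separate Temporal client connection for every heartbeat. With the activity now awaiting each upload, the SDK's in-context heartbeat can be used instead; a hedged sketch of that pattern, not the production activity:

```python
import asyncio

from temporalio import activity


@activity.defn
async def upload_parts_sketch(parts: list[bytes]) -> None:
    """Sketch only: each awaited upload yields the event loop, so the heartbeat
    recorded here can actually be delivered while parts upload."""
    for part_number, part in enumerate(parts, start=1):
        # Placeholder for the awaited aioboto3 call, e.g.
        # await s3_client.upload_part(..., PartNumber=part_number, Body=part)
        await asyncio.sleep(0)
        activity.heartbeat(part_number)
```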