From 51dd5d78c93e393ef87a996d9b572423f3e8adc3 Mon Sep 17 00:00:00 2001 From: Eric Duong Date: Thu, 14 Dec 2023 22:28:09 -0500 Subject: [PATCH] chore(data-warehouse): make data pipeline async (#19338) * working async * cleanup * update test * remove * add good test * adjust env var handling * typing --- .../data_imports/external_data_job.py | 1 - .../data_imports/pipelines/stripe/helpers.py | 115 +++++----- .../pipelines/stripe/stripe_pipeline.py | 38 ++-- .../temporal/tests/test_external_data_job.py | 200 ++++++++++++++---- requirements.in | 2 +- requirements.txt | 8 +- 6 files changed, 238 insertions(+), 126 deletions(-) diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py index 49b94f268279a..0648ed01df59e 100644 --- a/posthog/temporal/data_imports/external_data_job.py +++ b/posthog/temporal/data_imports/external_data_job.py @@ -190,7 +190,6 @@ async def run(self, inputs: ExternalDataWorkflowInputs): schemas=schemas, ) - # TODO: can make this a child workflow for separate worker pool await workflow.execute_activity( run_external_data_job, job_inputs, diff --git a/posthog/temporal/data_imports/pipelines/stripe/helpers.py b/posthog/temporal/data_imports/pipelines/stripe/helpers.py index 2e9bba272d5fc..9f71a490dcbdd 100644 --- a/posthog/temporal/data_imports/pipelines/stripe/helpers.py +++ b/posthog/temporal/data_imports/pipelines/stripe/helpers.py @@ -1,6 +1,6 @@ """Stripe analytics source helpers""" -from typing import Any, Dict, Optional, Union, Iterable +from typing import Any, Dict, Optional, Union, Iterable, Tuple import stripe import dlt @@ -20,6 +20,34 @@ def transform_date(date: Union[str, DateTime, int]) -> int: return date +def stripe_get_data( + api_key: str, + resource: str, + start_date: Optional[Any] = None, + end_date: Optional[Any] = None, + **kwargs: Any, +) -> Dict[Any, Any]: + if start_date: + start_date = transform_date(start_date) + if end_date: + end_date = transform_date(end_date) + + if resource == "Subscription": + kwargs.update({"status": "all"}) + + _resource = getattr(stripe, resource) + + resource_dict = _resource.list( + api_key=api_key, + created={"gte": start_date, "lt": end_date}, + limit=100, + **kwargs, + ) + response = dict(resource_dict) + + return response + + def stripe_pagination( api_key: str, endpoint: str, @@ -39,45 +67,8 @@ def stripe_pagination( Iterable[TDataItem]: Data items retrieved from the endpoint. 
""" - should_continue = True - - def stripe_get_data( - api_key: str, - resource: str, - start_date: Optional[Any] = None, - end_date: Optional[Any] = None, - **kwargs: Any, - ) -> Dict[Any, Any]: - nonlocal should_continue - nonlocal starting_after - - if start_date: - start_date = transform_date(start_date) - if end_date: - end_date = transform_date(end_date) - - if resource == "Subscription": - kwargs.update({"status": "all"}) - - _resource = getattr(stripe, resource) - resource_dict = _resource.list( - api_key=api_key, - created={"gte": start_date, "lt": end_date}, - limit=100, - **kwargs, - ) - response = dict(resource_dict) - - if not response["has_more"]: - should_continue = False - - if len(response["data"]) > 0: - starting_after = response["data"][-1]["id"] - - return response["data"] - - while should_continue: - yield stripe_get_data( + while True: + response = stripe_get_data( api_key, endpoint, start_date=start_date, @@ -85,29 +76,37 @@ def stripe_get_data( starting_after=starting_after, ) + if len(response["data"]) > 0: + starting_after = response["data"][-1]["id"] + yield response["data"] + + if not response["has_more"]: + break + @dlt.source def stripe_source( api_key: str, - endpoint: str, + endpoints: Tuple[str, ...], start_date: Optional[Any] = None, end_date: Optional[Any] = None, starting_after: Optional[str] = None, ) -> Iterable[DltResource]: - return dlt.resource( - stripe_pagination, - name=endpoint, - write_disposition="append", - columns={ - "metadata": { - "data_type": "complex", - "nullable": True, - } - }, - )( - api_key=api_key, - endpoint=endpoint, - start_date=start_date, - end_date=end_date, - starting_after=starting_after, - ) + for endpoint in endpoints: + yield dlt.resource( + stripe_pagination, + name=endpoint, + write_disposition="append", + columns={ + "metadata": { + "data_type": "complex", + "nullable": True, + } + }, + )( + api_key=api_key, + endpoint=endpoint, + start_date=start_date, + end_date=end_date, + starting_after=starting_after, + ) diff --git a/posthog/temporal/data_imports/pipelines/stripe/stripe_pipeline.py b/posthog/temporal/data_imports/pipelines/stripe/stripe_pipeline.py index b99a14f438bce..a1138c74aa10e 100644 --- a/posthog/temporal/data_imports/pipelines/stripe/stripe_pipeline.py +++ b/posthog/temporal/data_imports/pipelines/stripe/stripe_pipeline.py @@ -10,8 +10,9 @@ from posthog.temporal.data_imports.pipelines.stripe.helpers import stripe_source from posthog.temporal.data_imports.pipelines.stripe.settings import ENDPOINTS from posthog.temporal.common.logger import bind_temporal_worker_logger - +import asyncio import os +from posthog.settings.base_variables import TEST @dataclass @@ -47,18 +48,28 @@ class StripeJobInputs(PipelineInputs): def create_pipeline(inputs: PipelineInputs): pipeline_name = f"{inputs.job_type}_pipeline_{inputs.team_id}_run_{inputs.run_id}" pipelines_dir = f"{os.getcwd()}/.dlt/{inputs.team_id}/{inputs.run_id}/{inputs.job_type}" + return dlt.pipeline( pipeline_name=pipeline_name, - pipelines_dir=pipelines_dir, # workers can be created and destroyed so it doesn't matter where the metadata gets put temporarily - destination="filesystem", + pipelines_dir=pipelines_dir, + destination=dlt.destinations.filesystem( + credentials={ + "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY, + "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET, + "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT if TEST else None, + }, + bucket_url=settings.BUCKET_URL, # type: ignore + ), dataset_name=inputs.dataset_name, - 
credentials={ - "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY, - "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET, - }, ) +def _run_pipeline(inputs: StripeJobInputs): + pipeline = create_pipeline(inputs) + source = stripe_source(inputs.stripe_secret_key, tuple(inputs.schemas)) + pipeline.run(source, loader_file_format="parquet") + + # a temporal activity async def run_stripe_pipeline(inputs: StripeJobInputs) -> None: logger = await bind_temporal_worker_logger(team_id=inputs.team_id) @@ -67,14 +78,11 @@ async def run_stripe_pipeline(inputs: StripeJobInputs) -> None: logger.info(f"No schemas found for source id {inputs.source_id}") return - for endpoint in schemas: - pipeline = create_pipeline(inputs) - try: - source = stripe_source(inputs.stripe_secret_key, endpoint) - pipeline.run(source, table_name=endpoint.lower(), loader_file_format="parquet") - except PipelineStepFailed: - logger.error(f"Data import failed for endpoint {endpoint}") - raise + try: + await asyncio.to_thread(_run_pipeline, inputs) + except PipelineStepFailed: + logger.error(f"Data import failed for endpoint") + raise PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING = {ExternalDataSource.Type.STRIPE: ENDPOINTS} diff --git a/posthog/temporal/tests/test_external_data_job.py b/posthog/temporal/tests/test_external_data_job.py index 44d153967cd7d..e519b334693f0 100644 --- a/posthog/temporal/tests/test_external_data_job.py +++ b/posthog/temporal/tests/test_external_data_job.py @@ -15,9 +15,6 @@ update_external_data_job_model, validate_schema_activity, ) -from posthog.temporal.data_imports.pipelines.stripe.stripe_pipeline import ( - StripeJobInputs, -) from posthog.temporal.data_imports.external_data_job import ( ExternalDataJobWorkflow, ExternalDataJobInputs, @@ -39,13 +36,59 @@ from temporalio.common import RetryPolicy from temporalio.worker import UnsandboxedWorkflowRunner, Worker from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE +import pytest_asyncio +import aioboto3 +import functools +from django.conf import settings +import asyncio + +BUCKET_NAME = "test-external-data-jobs" +SESSION = aioboto3.Session() +create_test_client = functools.partial(SESSION.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT) AWS_BUCKET_MOCK_SETTINGS = { - "AIRBYTE_BUCKET_KEY": "test-key", - "AIRBYTE_BUCKET_SECRET": "test-secret", + "AIRBYTE_BUCKET_KEY": settings.OBJECT_STORAGE_ACCESS_KEY_ID, + "AIRBYTE_BUCKET_SECRET": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, } +async def delete_all_from_s3(minio_client, bucket_name: str, key_prefix: str): + """Delete all objects in bucket_name under key_prefix.""" + response = await minio_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix) + + if "Contents" in response: + for obj in response["Contents"]: + if "Key" in obj: + await minio_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) + + +@pytest.fixture +def bucket_name(request) -> str: + """Name for a test S3 bucket.""" + return BUCKET_NAME + + +@pytest_asyncio.fixture +async def minio_client(bucket_name): + """Manage an S3 client to interact with a MinIO bucket. + + Yields the client after creating a bucket. Upon resuming, we delete + the contents and the bucket itself. 
+ """ + async with create_test_client( + "s3", + aws_access_key_id=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + aws_secret_access_key=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + ) as minio_client: + await minio_client.create_bucket(Bucket=bucket_name) + + yield minio_client + + await delete_all_from_s3(minio_client, bucket_name, key_prefix="/") + + await minio_client.delete_bucket(Bucket=bucket_name) + + @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_create_external_job_activity(activity_environment, team, **kwargs): @@ -143,51 +186,118 @@ async def test_update_external_job_activity(activity_environment, team, **kwargs @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio -async def test_run_stripe_job(activity_environment, team, **kwargs): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), - team=team, - status="running", - source_type="Stripe", - job_inputs={"stripe_secret_key": "test-key"}, - ) # type: ignore +async def test_run_stripe_job(activity_environment, team, minio_client, **kwargs): + async def setup_job_1(): + new_source = await sync_to_async(ExternalDataSource.objects.create)( + source_id=uuid.uuid4(), + connection_id=uuid.uuid4(), + destination_id=uuid.uuid4(), + team=team, + status="running", + source_type="Stripe", + job_inputs={"stripe_secret_key": "test-key"}, + ) # type: ignore + + new_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)( # type: ignore + team_id=team.id, + pipeline_id=new_source.pk, + status=ExternalDataJob.Status.RUNNING, + rows_synced=0, + ) - new_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)( # type: ignore - team_id=team.id, - pipeline_id=new_source.pk, - status=ExternalDataJob.Status.RUNNING, - rows_synced=0, - ) + new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)() # type: ignore - new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)() # type: ignore + schemas = ["Customer"] + inputs = ExternalDataJobInputs( + team_id=team.id, + run_id=new_job.pk, + source_id=new_source.pk, + schemas=schemas, + ) - inputs = ExternalDataJobInputs( - team_id=team.id, - run_id=new_job.pk, - source_id=new_source.pk, - schemas=["test-1", "test-2", "test-3", "test-4", "test-5"], - ) + return new_job, inputs + + async def setup_job_2(): + new_source = await sync_to_async(ExternalDataSource.objects.create)( + source_id=uuid.uuid4(), + connection_id=uuid.uuid4(), + destination_id=uuid.uuid4(), + team=team, + status="running", + source_type="Stripe", + job_inputs={"stripe_secret_key": "test-key"}, + ) # type: ignore + + new_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)( # type: ignore + team_id=team.id, + pipeline_id=new_source.pk, + status=ExternalDataJob.Status.RUNNING, + rows_synced=0, + ) - with mock.patch( - "posthog.temporal.data_imports.pipelines.stripe.stripe_pipeline.create_pipeline", - ) as mock_create_pipeline: # noqa: B015 - await activity_environment.run(run_external_data_job, inputs) - - assert mock_create_pipeline.call_count == 5 - - mock_create_pipeline.assert_called_with( - StripeJobInputs( - source_id=new_source.pk, - run_id=new_job.pk, - job_type="Stripe", - team_id=team.id, - stripe_secret_key="test-key", - dataset_name=new_job.folder_path, - schemas=["test-1", "test-2", "test-3", "test-4", "test-5"], - ) + new_job = 
await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)() # type: ignore + + schemas = ["Customer", "Invoice"] + inputs = ExternalDataJobInputs( + team_id=team.id, + run_id=new_job.pk, + source_id=new_source.pk, + schemas=schemas, + ) + + return new_job, inputs + + job_1, job_1_inputs = await setup_job_1() + job_2, job_2_inputs = await setup_job_2() + + with mock.patch("stripe.Customer.list") as mock_customer_list, mock.patch( + "stripe.Invoice.list" + ) as mock_invoice_list, override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + ): + mock_customer_list.return_value = { + "data": [ + { + "id": "cus_123", + "name": "John Doe", + } + ], + "has_more": False, + } + + mock_invoice_list.return_value = { + "data": [ + { + "id": "inv_123", + "customer": "cus_1", + } + ], + "has_more": False, + } + await asyncio.gather( + activity_environment.run(run_external_data_job, job_1_inputs), + activity_environment.run(run_external_data_job, job_2_inputs), + ) + + job_1_customer_objects = await minio_client.list_objects_v2( + Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/customer/" + ) + job_1_invoice_objects = await minio_client.list_objects_v2( + Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/invoice/" + ) + assert len(job_1_customer_objects["Contents"]) == 1 + assert job_1_invoice_objects.get("Contents", None) is None + + job_2_customer_objects = await minio_client.list_objects_v2( + Bucket=BUCKET_NAME, Prefix=f"{job_2.folder_path}/customer/" + ) + job_2_invoice_objects = await minio_client.list_objects_v2( + Bucket=BUCKET_NAME, Prefix=f"{job_2.folder_path}/invoice/" ) + assert len(job_2_customer_objects["Contents"]) == 1 + assert len(job_2_invoice_objects["Contents"]) == 1 @pytest.mark.django_db(transaction=True) diff --git a/requirements.in b/requirements.in index 655539e4497e2..c45e32b576ef3 100644 --- a/requirements.in +++ b/requirements.in @@ -36,7 +36,7 @@ djangorestframework==3.14.0 djangorestframework-csv==2.1.1 djangorestframework-dataclasses==1.2.0 django-fernet-encrypted-fields==0.1.3 -dlt==0.3.24 +dlt==0.4.1a2 dnspython==2.2.1 drf-exceptions-hog==0.4.0 drf-extensions==0.7.0 diff --git a/requirements.txt b/requirements.txt index 388b8362d75bf..f0459a383be9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -136,8 +136,6 @@ defusedxml==0.6.0 # -r requirements.in # python3-openid # social-auth-core -deprecated==1.2.14 - # via dlt dj-database-url==0.5.0 # via -r requirements.in django==3.2.19 @@ -210,7 +208,7 @@ djangorestframework-csv==2.1.1 # via -r requirements.in djangorestframework-dataclasses==1.2.0 # via -r requirements.in -dlt==0.3.24 +dlt==0.4.1a2 # via -r requirements.in dnspython==2.2.1 # via -r requirements.in @@ -640,9 +638,7 @@ wheel==0.42.0 whitenoise==6.5.0 # via -r requirements.in wrapt==1.15.0 - # via - # aiobotocore - # deprecated + # via aiobotocore wsproto==1.1.0 # via trio-websocket xmlsec==1.3.13
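
The core of the patch is that run_stripe_pipeline no longer calls the blocking dlt pipeline directly: the synchronous _run_pipeline is handed to a worker thread with asyncio.to_thread, so the Temporal worker's event loop stays free and several import jobs can proceed concurrently (the updated test exercises exactly this with asyncio.gather over two jobs). A minimal sketch of that pattern, independent of dlt and Temporal; run_pipeline_blocking and the job ids are illustrative stand-ins, not names from the patch:

    import asyncio
    import time


    def run_pipeline_blocking(job_id: str) -> str:
        # Stand-in for the synchronous dlt pipeline.run(...) call, which can
        # take minutes and would otherwise starve the event loop.
        time.sleep(1)
        return f"job {job_id} loaded"


    async def run_job(job_id: str) -> str:
        # asyncio.to_thread pushes the blocking call onto the default thread
        # pool and yields control back to the event loop until it finishes.
        return await asyncio.to_thread(run_pipeline_blocking, job_id)


    async def main() -> None:
        # The two jobs now overlap instead of running back to back, which is
        # what the rewritten test checks by gathering two activities.
        results = await asyncio.gather(run_job("job-1"), run_job("job-2"))
        print(results)


    asyncio.run(main())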
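
helpers.py also drops the nested stripe_get_data closure that mutated nonlocal flags; stripe_pagination is now a plain generator that fetches a page, advances the starting_after cursor, yields the page, and stops once Stripe reports has_more is false. A generic sketch of that cursor-pagination shape, with fetch_page as a hypothetical stand-in for the stripe.<Resource>.list call:

    from typing import Any, Dict, Iterator, List, Optional


    def fetch_page(cursor: Optional[str]) -> Dict[str, Any]:
        # Hypothetical stand-in for stripe.<Resource>.list(starting_after=cursor, ...);
        # it returns the same shape Stripe does: {"data": [...], "has_more": bool}.
        return {"data": [{"id": f"obj_after_{cursor or 'start'}"}], "has_more": False}


    def paginate() -> Iterator[List[Dict[str, Any]]]:
        cursor: Optional[str] = None
        while True:
            response = fetch_page(cursor)

            if response["data"]:
                # Advance the cursor to the last item before yielding the page,
                # mirroring how stripe_pagination tracks starting_after.
                cursor = response["data"][-1]["id"]
                yield response["data"]

            if not response["has_more"]:
                break


    for page in paginate():
        print(page)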