chore(data-warehouse): make data pipeline async (#19338)
* working async

* cleanup

* update test

* remove

* add good test

* adjust env var handling

* typing
EDsCODE authored Dec 15, 2023
1 parent 7cc6e81 commit 51dd5d7
Showing 6 changed files with 238 additions and 126 deletions.
1 change: 0 additions & 1 deletion posthog/temporal/data_imports/external_data_job.py
@@ -190,7 +190,6 @@ async def run(self, inputs: ExternalDataWorkflowInputs):
schemas=schemas,
)

# TODO: can make this a child workflow for separate worker pool
await workflow.execute_activity(
run_external_data_job,
job_inputs,
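
For context, a minimal sketch of the Temporal pattern the hunk above relies on: an async workflow method that awaits an activity via workflow.execute_activity. The names SimpleImportWorkflow and do_import are illustrative stand-ins, not taken from the repository.

from datetime import timedelta
from temporalio import activity, workflow

@activity.defn
async def do_import(job_id: str) -> None:
    # Placeholder for the real import work (e.g. run_external_data_job).
    ...

@workflow.defn
class SimpleImportWorkflow:
    @workflow.run
    async def run(self, job_id: str) -> None:
        # execute_activity schedules the activity on a worker and awaits its result.
        await workflow.execute_activity(
            do_import,
            job_id,
            start_to_close_timeout=timedelta(minutes=10),
        )
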
115 changes: 57 additions & 58 deletions posthog/temporal/data_imports/pipelines/stripe/helpers.py
@@ -1,6 +1,6 @@
"""Stripe analytics source helpers"""

from typing import Any, Dict, Optional, Union, Iterable
from typing import Any, Dict, Optional, Union, Iterable, Tuple

import stripe
import dlt
@@ -20,6 +20,34 @@ def transform_date(date: Union[str, DateTime, int]) -> int:
return date


def stripe_get_data(
api_key: str,
resource: str,
start_date: Optional[Any] = None,
end_date: Optional[Any] = None,
**kwargs: Any,
) -> Dict[Any, Any]:
if start_date:
start_date = transform_date(start_date)
if end_date:
end_date = transform_date(end_date)

if resource == "Subscription":
kwargs.update({"status": "all"})

_resource = getattr(stripe, resource)

resource_dict = _resource.list(
api_key=api_key,
created={"gte": start_date, "lt": end_date},
limit=100,
**kwargs,
)
response = dict(resource_dict)

return response


def stripe_pagination(
api_key: str,
endpoint: str,
@@ -39,75 +67,46 @@ def stripe_pagination(
Iterable[TDataItem]: Data items retrieved from the endpoint.
"""

should_continue = True

def stripe_get_data(
api_key: str,
resource: str,
start_date: Optional[Any] = None,
end_date: Optional[Any] = None,
**kwargs: Any,
) -> Dict[Any, Any]:
nonlocal should_continue
nonlocal starting_after

if start_date:
start_date = transform_date(start_date)
if end_date:
end_date = transform_date(end_date)

if resource == "Subscription":
kwargs.update({"status": "all"})

_resource = getattr(stripe, resource)
resource_dict = _resource.list(
api_key=api_key,
created={"gte": start_date, "lt": end_date},
limit=100,
**kwargs,
)
response = dict(resource_dict)

if not response["has_more"]:
should_continue = False

if len(response["data"]) > 0:
starting_after = response["data"][-1]["id"]

return response["data"]

while should_continue:
yield stripe_get_data(
while True:
response = stripe_get_data(
api_key,
endpoint,
start_date=start_date,
end_date=end_date,
starting_after=starting_after,
)

if len(response["data"]) > 0:
starting_after = response["data"][-1]["id"]
yield response["data"]

if not response["has_more"]:
break


@dlt.source
def stripe_source(
api_key: str,
endpoint: str,
endpoints: Tuple[str, ...],
start_date: Optional[Any] = None,
end_date: Optional[Any] = None,
starting_after: Optional[str] = None,
) -> Iterable[DltResource]:
return dlt.resource(
stripe_pagination,
name=endpoint,
write_disposition="append",
columns={
"metadata": {
"data_type": "complex",
"nullable": True,
}
},
)(
api_key=api_key,
endpoint=endpoint,
start_date=start_date,
end_date=end_date,
starting_after=starting_after,
)
for endpoint in endpoints:
yield dlt.resource(
stripe_pagination,
name=endpoint,
write_disposition="append",
columns={
"metadata": {
"data_type": "complex",
"nullable": True,
}
},
)(
api_key=api_key,
endpoint=endpoint,
start_date=start_date,
end_date=end_date,
starting_after=starting_after,
)
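
A hedged usage sketch of the reworked source: stripe_source now takes a tuple of endpoint names and yields one dlt resource per endpoint, each backed by the cursor-based stripe_pagination generator, which keeps paging via starting_after until has_more is False. The destination (duckdb), the key, and the dates below are illustrative assumptions, not part of this commit.

import dlt
from posthog.temporal.data_imports.pipelines.stripe.helpers import stripe_source

# Build one source that fans out into a resource per Stripe endpoint.
source = stripe_source(
    api_key="sk_test_...",                  # hypothetical key
    endpoints=("Invoice", "Subscription"),  # Stripe resource class names
    start_date="2023-01-01",
    end_date="2023-12-01",
)

# A single pipeline run loads every yielded resource.
pipeline = dlt.pipeline(
    pipeline_name="stripe_demo",
    destination="duckdb",
    dataset_name="stripe",
)
pipeline.run(source)
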
38 changes: 23 additions & 15 deletions posthog/temporal/data_imports/pipelines/stripe/stripe_pipeline.py
@@ -10,8 +10,9 @@
from posthog.temporal.data_imports.pipelines.stripe.helpers import stripe_source
from posthog.temporal.data_imports.pipelines.stripe.settings import ENDPOINTS
from posthog.temporal.common.logger import bind_temporal_worker_logger

import asyncio
import os
from posthog.settings.base_variables import TEST


@dataclass
@@ -47,18 +48,28 @@ class StripeJobInputs(PipelineInputs):
def create_pipeline(inputs: PipelineInputs):
pipeline_name = f"{inputs.job_type}_pipeline_{inputs.team_id}_run_{inputs.run_id}"
pipelines_dir = f"{os.getcwd()}/.dlt/{inputs.team_id}/{inputs.run_id}/{inputs.job_type}"

return dlt.pipeline(
pipeline_name=pipeline_name,
pipelines_dir=pipelines_dir, # workers can be created and destroyed so it doesn't matter where the metadata gets put temporarily
destination="filesystem",
pipelines_dir=pipelines_dir,
destination=dlt.destinations.filesystem(
credentials={
"aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
"aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
"endpoint_url": settings.OBJECT_STORAGE_ENDPOINT if TEST else None,
},
bucket_url=settings.BUCKET_URL, # type: ignore
),
dataset_name=inputs.dataset_name,
credentials={
"aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
"aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
},
)


def _run_pipeline(inputs: StripeJobInputs):
pipeline = create_pipeline(inputs)
source = stripe_source(inputs.stripe_secret_key, tuple(inputs.schemas))
pipeline.run(source, loader_file_format="parquet")


# a temporal activity
async def run_stripe_pipeline(inputs: StripeJobInputs) -> None:
logger = await bind_temporal_worker_logger(team_id=inputs.team_id)
@@ -67,14 +78,11 @@ async def run_stripe_pipeline(inputs: StripeJobInputs) -> None:
logger.info(f"No schemas found for source id {inputs.source_id}")
return

for endpoint in schemas:
pipeline = create_pipeline(inputs)
try:
source = stripe_source(inputs.stripe_secret_key, endpoint)
pipeline.run(source, table_name=endpoint.lower(), loader_file_format="parquet")
except PipelineStepFailed:
logger.error(f"Data import failed for endpoint {endpoint}")
raise
try:
await asyncio.to_thread(_run_pipeline, inputs)
except PipelineStepFailed:
logger.error(f"Data import failed for endpoint")
raise


PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING = {ExternalDataSource.Type.STRIPE: ENDPOINTS}
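
The core of the "async" change above is offloading the blocking dlt run to a worker thread so the Temporal activity's event loop stays responsive. A minimal sketch of that pattern, with load_stripe_data standing in for _run_pipeline:

import asyncio

def load_stripe_data() -> None:
    # Synchronous, blocking work: Stripe API calls plus dlt's normalize/load steps.
    ...

async def run_activity() -> None:
    # asyncio.to_thread runs the callable in the default thread pool and awaits
    # its completion without blocking other coroutines on the event loop.
    await asyncio.to_thread(load_stripe_data)

if __name__ == "__main__":
    asyncio.run(run_activity())
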