diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py
index 0648ed01df59e..fc75f6beed4de 100644
--- a/posthog/temporal/data_imports/external_data_job.py
+++ b/posthog/temporal/data_imports/external_data_job.py
@@ -9,12 +9,10 @@
 
 # TODO: remove dependency
 from posthog.temporal.batch_exports.base import PostHogWorkflow
-from posthog.temporal.data_imports.pipelines.stripe.stripe_pipeline import (
-    PIPELINE_TYPE_INPUTS_MAPPING,
-    PIPELINE_TYPE_RUN_MAPPING,
-    PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING,
-)
+
 from posthog.warehouse.data_load.validate_schema import validate_schema_and_update_table
+from posthog.temporal.data_imports.pipelines.schemas import PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING
+from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline, PipelineInputs
 from posthog.warehouse.external_data_source.jobs import (
     create_external_data_job,
     get_external_data_job,
@@ -47,6 +45,8 @@ async def create_external_data_job_model(inputs: CreateExternalDataJobInputs) ->
     source = await sync_to_async(ExternalDataSource.objects.get)(  # type: ignore
         team_id=inputs.team_id, id=inputs.external_data_source_id
     )
+    source.status = "Running"
+    await sync_to_async(source.save)()  # type: ignore
 
     # Sync schemas if they have changed
     await sync_to_async(sync_old_schemas_with_new_schemas)(  # type: ignore
@@ -133,19 +133,29 @@ async def run_external_data_job(inputs: ExternalDataJobInputs) -> None:
         team_id=inputs.team_id,
         run_id=inputs.run_id,
     )
+    logger = await bind_temporal_worker_logger(team_id=inputs.team_id)
 
-    job_inputs = PIPELINE_TYPE_INPUTS_MAPPING[model.pipeline.source_type](
+    job_inputs = PipelineInputs(
         source_id=inputs.source_id,
         schemas=inputs.schemas,
         run_id=inputs.run_id,
         team_id=inputs.team_id,
         job_type=model.pipeline.source_type,
         dataset_name=model.folder_path,
-        **model.pipeline.job_inputs,
     )
 
-    job_fn = PIPELINE_TYPE_RUN_MAPPING[model.pipeline.source_type]
-    await job_fn(job_inputs)
+    source = None
+    if model.pipeline.source_type == ExternalDataSource.Type.STRIPE:
+        from posthog.temporal.data_imports.pipelines.stripe.helpers import stripe_source
+
+        stripe_secret_key = model.pipeline.job_inputs.get("stripe_secret_key", None)
+        if not stripe_secret_key:
+            raise ValueError(f"Stripe secret key not found for job {model.id}")
+        source = stripe_source(api_key=stripe_secret_key, endpoints=tuple(inputs.schemas))
+    else:
+        raise ValueError(f"Source type {model.pipeline.source_type} not supported")
+
+    await DataImportPipeline(job_inputs, source, logger).run()
 
 
 # TODO: update retry policies
diff --git a/posthog/temporal/data_imports/pipelines/pipeline.py b/posthog/temporal/data_imports/pipelines/pipeline.py
new file mode 100644
index 0000000000000..ad6d53aa3a9e6
--- /dev/null
+++ b/posthog/temporal/data_imports/pipelines/pipeline.py
@@ -0,0 +1,89 @@
+from dataclasses import dataclass
+from uuid import UUID
+
+import dlt
+from django.conf import settings
+from dlt.pipeline.exceptions import PipelineStepFailed
+
+import asyncio
+import os
+from posthog.settings.base_variables import TEST
+from structlog.typing import FilteringBoundLogger
+from dlt.sources import DltResource
+
+
+@dataclass
+class PipelineInputs:
+    source_id: UUID
+    run_id: str
+    schemas: list[str]
+    dataset_name: str
+    job_type: str
+    team_id: int
+
+
+class DataImportPipeline:
+    loader_file_format = "parquet"
+
+    def __init__(self, inputs: PipelineInputs, source: DltResource, logger: FilteringBoundLogger):
+        self.inputs = inputs
+        self.logger = logger
+        self.source = source
+
+    def _get_pipeline_name(self):
+        return f"{self.inputs.job_type}_pipeline_{self.inputs.team_id}_run_{self.inputs.run_id}"
+
+    def _get_pipelines_dir(self):
+        return f"{os.getcwd()}/.dlt/{self.inputs.team_id}/{self.inputs.run_id}/{self.inputs.job_type}"
+
+    def _get_destination(self):
+        if TEST:
+            credentials = {
+                "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
+                "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
+                "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT,
+            }
+        else:
+            credentials = {
+                "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
+                "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
+            }
+
+        return dlt.destinations.filesystem(
+            credentials=credentials,
+            bucket_url=settings.BUCKET_URL,  # type: ignore
+        )
+
+    def _create_pipeline(self):
+        pipeline_name = self._get_pipeline_name()
+        pipelines_dir = self._get_pipelines_dir()
+        destination = self._get_destination()
+
+        return dlt.pipeline(
+            pipeline_name=pipeline_name,
+            pipelines_dir=pipelines_dir,
+            destination=destination,
+            dataset_name=self.inputs.dataset_name,
+        )
+
+    def _get_schemas(self):
+        if not self.inputs.schemas:
+            self.logger.info(f"No schemas found for source id {self.inputs.source_id}")
+            return None
+
+        return self.inputs.schemas
+
+    def _run(self):
+        pipeline = self._create_pipeline()
+        pipeline.run(self.source, loader_file_format=self.loader_file_format)
+
+    async def run(self) -> None:
+        schemas = self._get_schemas()
+        if not schemas:
+            return
+
+        try:
+            await asyncio.to_thread(self._run)
+        except PipelineStepFailed:
+            self.logger.error(f"Data import failed for endpoint")
+            raise
diff --git a/posthog/temporal/data_imports/pipelines/schemas.py b/posthog/temporal/data_imports/pipelines/schemas.py
new file mode 100644
index 0000000000000..a62db7d664e40
--- /dev/null
+++ b/posthog/temporal/data_imports/pipelines/schemas.py
@@ -0,0 +1,4 @@
+from posthog.warehouse.models import ExternalDataSource
+from posthog.temporal.data_imports.pipelines.stripe.settings import ENDPOINTS
+
+PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING = {ExternalDataSource.Type.STRIPE: ENDPOINTS}
diff --git a/posthog/temporal/data_imports/pipelines/stripe/stripe_pipeline.py b/posthog/temporal/data_imports/pipelines/stripe/stripe_pipeline.py
deleted file mode 100644
index a1138c74aa10e..0000000000000
--- a/posthog/temporal/data_imports/pipelines/stripe/stripe_pipeline.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict
-from uuid import UUID
-
-import dlt
-from django.conf import settings
-from dlt.pipeline.exceptions import PipelineStepFailed
-
-from posthog.warehouse.models import ExternalDataSource
-from posthog.temporal.data_imports.pipelines.stripe.helpers import stripe_source
-from posthog.temporal.data_imports.pipelines.stripe.settings import ENDPOINTS
-from posthog.temporal.common.logger import bind_temporal_worker_logger
-import asyncio
-import os
-from posthog.settings.base_variables import TEST
-
-
-@dataclass
-class PipelineInputs:
-    source_id: UUID
-    run_id: str
-    schemas: list[str]
-    dataset_name: str
-    job_type: str
-    team_id: int
-
-
-@dataclass
-class SourceColumnType:
-    name: str
-    data_type: str
-    nullable: bool
-
-
-@dataclass
-class SourceSchema:
-    resource: str
-    name: str
-    columns: Dict[str, SourceColumnType]
-    write_disposition: str
-
-
-@dataclass
-class StripeJobInputs(PipelineInputs):
-    stripe_secret_key: str
-
-
-def create_pipeline(inputs: PipelineInputs):
-    pipeline_name = f"{inputs.job_type}_pipeline_{inputs.team_id}_run_{inputs.run_id}"
-    pipelines_dir = f"{os.getcwd()}/.dlt/{inputs.team_id}/{inputs.run_id}/{inputs.job_type}"
-
-    return dlt.pipeline(
-        pipeline_name=pipeline_name,
-        pipelines_dir=pipelines_dir,
-        destination=dlt.destinations.filesystem(
-            credentials={
-                "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
-                "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
-                "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT if TEST else None,
-            },
-            bucket_url=settings.BUCKET_URL,  # type: ignore
-        ),
-        dataset_name=inputs.dataset_name,
-    )
-
-
-def _run_pipeline(inputs: StripeJobInputs):
-    pipeline = create_pipeline(inputs)
-    source = stripe_source(inputs.stripe_secret_key, tuple(inputs.schemas))
-    pipeline.run(source, loader_file_format="parquet")
-
-
-# a temporal activity
-async def run_stripe_pipeline(inputs: StripeJobInputs) -> None:
-    logger = await bind_temporal_worker_logger(team_id=inputs.team_id)
-    schemas = inputs.schemas
-    if not schemas:
-        logger.info(f"No schemas found for source id {inputs.source_id}")
-        return
-
-    try:
-        await asyncio.to_thread(_run_pipeline, inputs)
-    except PipelineStepFailed:
-        logger.error(f"Data import failed for endpoint")
-        raise
-
-
-PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING = {ExternalDataSource.Type.STRIPE: ENDPOINTS}
-PIPELINE_TYPE_INPUTS_MAPPING = {ExternalDataSource.Type.STRIPE: StripeJobInputs}
-PIPELINE_TYPE_RUN_MAPPING = {ExternalDataSource.Type.STRIPE: run_stripe_pipeline}
diff --git a/posthog/temporal/tests/test_external_data_job.py b/posthog/temporal/tests/test_external_data_job.py
index e519b334693f0..1af196f368831 100644
--- a/posthog/temporal/tests/test_external_data_job.py
+++ b/posthog/temporal/tests/test_external_data_job.py
@@ -28,10 +28,10 @@
     ExternalDataSchema,
 )
 
-from posthog.temporal.data_imports.pipelines.stripe.stripe_pipeline import (
-    PIPELINE_TYPE_RUN_MAPPING,
+from posthog.temporal.data_imports.pipelines.schemas import (
     PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING,
 )
+from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline
 from temporalio.testing import WorkflowEnvironment
 from temporalio.common import RetryPolicy
 from temporalio.worker import UnsandboxedWorkflowRunner, Worker
@@ -449,7 +449,7 @@ async def mock_async_func(inputs):
 
     with mock.patch(
         "posthog.warehouse.models.table.DataWarehouseTable.get_columns", return_value={"id": "string"}
-    ), mock.patch.dict(PIPELINE_TYPE_RUN_MAPPING, {ExternalDataSource.Type.STRIPE: mock_async_func}):
+    ), mock.patch.object(DataImportPipeline, "run", mock_async_func):
        with override_settings(AIRBYTE_BUCKET_KEY="test-key", AIRBYTE_BUCKET_SECRET="test-secret"):
             async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
                 async with Worker(
diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py
index 4dadbd33ab7fc..c72cbd14e6c9d 100644
--- a/posthog/warehouse/api/external_data_source.py
+++ b/posthog/warehouse/api/external_data_source.py
@@ -21,7 +21,7 @@
 )
 from posthog.warehouse.models import ExternalDataSource, ExternalDataSchema, ExternalDataJob
 from posthog.warehouse.api.external_data_schema import ExternalDataSchemaSerializer
-from posthog.temporal.data_imports.pipelines.stripe.stripe_pipeline import (
+from posthog.temporal.data_imports.pipelines.schemas import (
     PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING,
 )
 
diff --git a/posthog/warehouse/api/test/test_external_data_source.py b/posthog/warehouse/api/test/test_external_data_source.py
index f05ade40513c3..2ad741b453a29 100644
--- a/posthog/warehouse/api/test/test_external_data_source.py
+++ b/posthog/warehouse/api/test/test_external_data_source.py
@@ -2,7 +2,7 @@
 from posthog.warehouse.models import ExternalDataSource, ExternalDataSchema
 import uuid
 from unittest.mock import patch
-from posthog.temporal.data_imports.pipelines.stripe.stripe_pipeline import (
+from posthog.temporal.data_imports.pipelines.schemas import (
     PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING,
 )
 
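For reference, a minimal sketch of how the new DataImportPipeline is driven outside Temporal (it mirrors run_external_data_job above). The concrete values here (run id, schema names, dataset name, API key, logger) are hypothetical placeholders, and running it for real would still require Django settings (AIRBYTE_BUCKET_KEY, AIRBYTE_BUCKET_SECRET, BUCKET_URL) to be configured:

    import asyncio
    import uuid

    import structlog

    from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline, PipelineInputs
    from posthog.temporal.data_imports.pipelines.stripe.helpers import stripe_source

    # Hypothetical inputs, only to illustrate the call shape.
    job_inputs = PipelineInputs(
        source_id=uuid.uuid4(),
        run_id="example-run-id",
        schemas=["Invoice", "Charge"],  # assumed subset of the Stripe ENDPOINTS
        dataset_name="team_1_stripe",
        job_type="Stripe",
        team_id=1,
    )

    # The source is built exactly as run_external_data_job does for Stripe.
    source = stripe_source(api_key="sk_test_placeholder", endpoints=tuple(job_inputs.schemas))
    logger = structlog.get_logger()  # any structlog bound logger with .info/.error works here

    # run() resolves the schemas, builds the dlt pipeline, and loads each endpoint
    # to the filesystem/S3 destination as parquet in a worker thread.
    asyncio.run(DataImportPipeline(job_inputs, source, logger).run())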