# Patch commentary (free text placed before the first "diff --git" header; the patch text below is unchanged).
#
# Purpose: move the external-data import path from sync_to_async wrappers at the call sites
# to functions that are awaitable themselves.
#
# posthog/temporal/data_imports/external_data_job.py
#   - HeartbeatDetails is imported from posthog.temporal.common.heartbeat instead of posthog.temporal.heartbeat.
#   - move_draft_to_production_activity awaits move_draft_to_production(...) directly; the sync_to_async wrapper is removed.
#   - run_external_data_job hands job_fn straight to make_activity_heartbeat_while_running; the
#     intermediate `async_job_fn = sync_to_async(job_fn)` is removed.
# posthog/warehouse/data_load/pipeline.py
#   - Adds `from asgiref.sync import sync_to_async`.
#   - run_stripe_pipeline is decorated with @sync_to_async(thread_sensitive=False); the old "TODO: add heartbeat" comment is replaced.
#   - get_s3fs passes asynchronous=True to s3fs.S3FileSystem.
#   - move_draft_to_production becomes `async def`; both s3.copy(...) calls become `await s3._copy(...)`.
#   - The DataWarehouseTable update is rewritten as `await sync_to_async(DataWarehouseTable.objects.filter)(...).update(...)`.
# posthog/warehouse/data_load/sync_table.py
#   - Adds a "# TODO: make async" marker above is_schema_valid.
#
# NOTE(review): in `await sync_to_async(DataWarehouseTable.objects.filter)(...).update(...)` the trailing
#   `.update(...)` is part of the expression that `await` applies to, so it is looked up on the un-awaited
#   result of the sync_to_async call rather than on a queryset. Presumably the intent was something like
#   `await sync_to_async(DataWarehouseTable.objects.filter(...).update)(...)` -- confirm.
# NOTE(review): ExternalDataSource.objects.get(...) stays a plain synchronous ORM call inside the new
#   `async def`; presumably this needs the same sync_to_async treatment -- verify against Django's async-safety rules.
# NOTE(review): the three s3.delete(...) calls stay synchronous while get_s3fs now builds the filesystem with
#   asynchronous=True; presumably they should be awaited async counterparts like the `_copy` calls -- verify against fsspec.
#   The bare `except: pass` around the first delete (a context line, not changed here) would hide any failure of that call.
# NOTE(review): run_external_data_job now awaits whatever PIPELINE_TYPE_RUN_MAPPING returns without wrapping it;
#   only run_stripe_pipeline is visibly made awaitable in this patch -- confirm every mapped function is.
diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py index 7234ed570699a..3782211273b97 100644 --- a/posthog/temporal/data_imports/external_data_job.py +++ b/posthog/temporal/data_imports/external_data_job.py @@ -21,7 +21,7 @@ # TODO: remove dependency from posthog.temporal.batch_exports.base import PostHogWorkflow -from posthog.temporal.heartbeat import HeartbeatDetails +from posthog.temporal.common.heartbeat import HeartbeatDetails from temporalio import activity, workflow, exceptions from temporalio.common import RetryPolicy from asgiref.sync import sync_to_async @@ -84,7 +84,7 @@ class MoveDraftToProductionExternalDataJobInputs: @activity.defn async def move_draft_to_production_activity(inputs: MoveDraftToProductionExternalDataJobInputs) -> None: - await sync_to_async(move_draft_to_production)( + await move_draft_to_production( team_id=inputs.team_id, external_data_source_id=inputs.external_data_source_id, ) @@ -108,10 +108,8 @@ async def run_external_data_job(inputs: ExternalDataJobInputs) -> List[SourceSch ) job_fn = PIPELINE_TYPE_RUN_MAPPING[model.source_type] - async_job_fn = sync_to_async(job_fn) - heartbeat_details = HeartbeatDetails() - func = heartbeat_details.make_activity_heartbeat_while_running(async_job_fn, dt.timedelta(seconds=10)) + func = heartbeat_details.make_activity_heartbeat_while_running(job_fn, dt.timedelta(seconds=10)) return await func(job_inputs) diff --git a/posthog/warehouse/data_load/pipeline.py b/posthog/warehouse/data_load/pipeline.py index 4b18b27d05cea..89a0ae686a749 100644 --- a/posthog/warehouse/data_load/pipeline.py +++ b/posthog/warehouse/data_load/pipeline.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from posthog.warehouse.models import ExternalDataSource, DataWarehouseTable import s3fs +from asgiref.sync import sync_to_async @dataclass @@ -56,7 +57,8 @@ class StripeJobInputs(PipelineInputs): } -# TODO: add heartbeat +# Run pipeline on separate thread. 
No db clients used +@sync_to_async(thread_sensitive=False) def run_stripe_pipeline(inputs: StripeJobInputs) -> List[SourceSchema]: pipeline = create_pipeline(inputs) @@ -102,21 +104,23 @@ def get_schema(pipeline: dlt.pipeline) -> List[SourceSchema]: def get_s3fs(): - return s3fs.S3FileSystem(key=settings.AIRBYTE_BUCKET_KEY, secret=settings.AIRBYTE_BUCKET_SECRET) + return s3fs.S3FileSystem(key=settings.AIRBYTE_BUCKET_KEY, secret=settings.AIRBYTE_BUCKET_SECRET, asynchronous=True) -def move_draft_to_production(team_id: int, external_data_source_id: str): +async def move_draft_to_production(team_id: int, external_data_source_id: str): model = ExternalDataSource.objects.get(team_id=team_id, id=external_data_source_id) bucket_name = settings.BUCKET_URL s3 = get_s3fs() - s3.copy( + await s3._copy( f"{bucket_name}/{model.draft_folder_path}", f"{bucket_name}/{model.draft_folder_path}_success", recursive=True ) try: s3.delete(f"{bucket_name}/{model.folder_path}", recursive=True) except: pass - s3.copy(f"{bucket_name}/{model.draft_folder_path}_success", f"{bucket_name}/{model.folder_path}", recursive=True) + await s3._copy( + f"{bucket_name}/{model.draft_folder_path}_success", f"{bucket_name}/{model.folder_path}", recursive=True + ) s3.delete(f"{bucket_name}/{model.draft_folder_path}_success", recursive=True) s3.delete(f"{bucket_name}/{model.draft_folder_path}", recursive=True) @@ -129,6 +133,6 @@ def move_draft_to_production(team_id: int, external_data_source_id: str): f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{model.draft_folder_path}/{schema_name.lower()}/*.parquet" ) - DataWarehouseTable.objects.filter( + await sync_to_async(DataWarehouseTable.objects.filter)( name=table_name, team_id=model.team_id, url_pattern=url_pattern, format="Parquet" ).update(url_pattern=url_pattern.replace("_draft", "")) diff --git a/posthog/warehouse/data_load/sync_table.py b/posthog/warehouse/data_load/sync_table.py index 20b875f02b2f7..f874c10f8ad06 100644 --- 
a/posthog/warehouse/data_load/sync_table.py +++ b/posthog/warehouse/data_load/sync_table.py @@ -14,6 +14,7 @@ def __init__(self): super().__init__(f"Schema validation failed") +# TODO: make async def is_schema_valid(source_schemas: List[SourceSchema], external_data_source_id: str, create: bool = False) -> bool: resource = ExternalDataSource.objects.get(pk=external_data_source_id) credential, _ = DataWarehouseCredential.objects.get_or_create(