Skip to content

Commit

Permalink
adjust async methods
Browse files — browse the repository at this point in the history
  • Loading branch information
EDsCODE committed Nov 22, 2023
1 parent 4236775 commit dd1098d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
8 changes: 3 additions & 5 deletions posthog/temporal/data_imports/external_data_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# TODO: remove dependency
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.heartbeat import HeartbeatDetails
from posthog.temporal.common.heartbeat import HeartbeatDetails
from temporalio import activity, workflow, exceptions
from temporalio.common import RetryPolicy
from asgiref.sync import sync_to_async
Expand Down Expand Up @@ -84,7 +84,7 @@ class MoveDraftToProductionExternalDataJobInputs:

@activity.defn
async def move_draft_to_production_activity(inputs: MoveDraftToProductionExternalDataJobInputs) -> None:
await sync_to_async(move_draft_to_production)(
await move_draft_to_production(
team_id=inputs.team_id,
external_data_source_id=inputs.external_data_source_id,
)
Expand All @@ -108,10 +108,8 @@ async def run_external_data_job(inputs: ExternalDataJobInputs) -> List[SourceSch
)
job_fn = PIPELINE_TYPE_RUN_MAPPING[model.source_type]

async_job_fn = sync_to_async(job_fn)

heartbeat_details = HeartbeatDetails()
func = heartbeat_details.make_activity_heartbeat_while_running(async_job_fn, dt.timedelta(seconds=10))
func = heartbeat_details.make_activity_heartbeat_while_running(job_fn, dt.timedelta(seconds=10))

return await func(job_inputs)

Expand Down
16 changes: 10 additions & 6 deletions posthog/warehouse/data_load/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from posthog.warehouse.models import ExternalDataSource, DataWarehouseTable
import s3fs
from asgiref.sync import sync_to_async


@dataclass
Expand Down Expand Up @@ -56,7 +57,8 @@ class StripeJobInputs(PipelineInputs):
}


# TODO: add heartbeat
# Run pipeline on separate thread. No db clients used
@sync_to_async(thread_sensitive=False)
def run_stripe_pipeline(inputs: StripeJobInputs) -> List[SourceSchema]:
pipeline = create_pipeline(inputs)

Expand Down Expand Up @@ -102,21 +104,23 @@ def get_schema(pipeline: dlt.pipeline) -> List[SourceSchema]:


def get_s3fs():
return s3fs.S3FileSystem(key=settings.AIRBYTE_BUCKET_KEY, secret=settings.AIRBYTE_BUCKET_SECRET)
return s3fs.S3FileSystem(key=settings.AIRBYTE_BUCKET_KEY, secret=settings.AIRBYTE_BUCKET_SECRET, asynchronous=True)


def move_draft_to_production(team_id: int, external_data_source_id: str):
async def move_draft_to_production(team_id: int, external_data_source_id: str):
model = ExternalDataSource.objects.get(team_id=team_id, id=external_data_source_id)
bucket_name = settings.BUCKET_URL
s3 = get_s3fs()
s3.copy(
await s3._copy(
f"{bucket_name}/{model.draft_folder_path}", f"{bucket_name}/{model.draft_folder_path}_success", recursive=True
)
try:
s3.delete(f"{bucket_name}/{model.folder_path}", recursive=True)
except:
pass
s3.copy(f"{bucket_name}/{model.draft_folder_path}_success", f"{bucket_name}/{model.folder_path}", recursive=True)
await s3._copy(
f"{bucket_name}/{model.draft_folder_path}_success", f"{bucket_name}/{model.folder_path}", recursive=True
)
s3.delete(f"{bucket_name}/{model.draft_folder_path}_success", recursive=True)
s3.delete(f"{bucket_name}/{model.draft_folder_path}", recursive=True)

Expand All @@ -129,6 +133,6 @@ def move_draft_to_production(team_id: int, external_data_source_id: str):
f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{model.draft_folder_path}/{schema_name.lower()}/*.parquet"
)

DataWarehouseTable.objects.filter(
await sync_to_async(DataWarehouseTable.objects.filter)(
name=table_name, team_id=model.team_id, url_pattern=url_pattern, format="Parquet"
).update(url_pattern=url_pattern.replace("_draft", ""))
1 change: 1 addition & 0 deletions posthog/warehouse/data_load/sync_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self):
super().__init__(f"Schema validation failed")


# TODO: make async
def is_schema_valid(source_schemas: List[SourceSchema], external_data_source_id: str, create: bool = False) -> bool:
resource = ExternalDataSource.objects.get(pk=external_data_source_id)
credential, _ = DataWarehouseCredential.objects.get_or_create(
Expand Down

0 comments on commit dd1098d

Please sign in to comment.