diff --git a/mypy-baseline.txt b/mypy-baseline.txt
index 93573c831c0fe..a2ab36ff3afea 100644
--- a/mypy-baseline.txt
+++ b/mypy-baseline.txt
@@ -617,13 +617,12 @@ posthog/warehouse/api/external_data_schema.py:0: note: def [_T] get(self, Type,
 posthog/warehouse/api/table.py:0: error: Unused "type: ignore" comment [unused-ignore]
 posthog/warehouse/api/table.py:0: error: Unused "type: ignore" comment [unused-ignore]
 posthog/warehouse/api/table.py:0: error: Unused "type: ignore" comment [unused-ignore]
-posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: error: Argument 1 has incompatible type "str"; expected "Type" [arg-type]
 posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: error: No overload variant of "get" of "dict" matches argument types "str", "tuple[()]" [call-overload]
 posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: Possible overload variants:
 posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: def get(self, Type, /) -> Sequence[str] | None
 posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: def get(self, Type, Sequence[str], /) -> Sequence[str]
 posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: def [_T] get(self, Type, _T, /) -> Sequence[str] | _T
-posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: error: Argument "source_id" has incompatible type "str"; expected "UUID" [arg-type]
+posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: error: Argument "source_id" to "sync_old_schemas_with_new_schemas" has incompatible type "str"; expected "UUID" [arg-type]
 posthog/tasks/exports/test/test_csv_exporter.py:0: error: Function is missing a return type annotation [no-untyped-def]
 posthog/tasks/exports/test/test_csv_exporter.py:0: error: Function is missing a type annotation [no-untyped-def]
 posthog/tasks/exports/test/test_csv_exporter.py:0: error: Function is missing a type annotation for one or more arguments [no-untyped-def]
@@ -796,6 +795,11 @@ posthog/temporal/tests/batch_exports/test_batch_exports.py:0: error: TypedDict k
 posthog/temporal/data_modeling/run_workflow.py:0: error: Dict entry 20 has incompatible type "str": "Literal['complex']"; expected "str": "Literal['text', 'double', 'bool', 'timestamp', 'bigint', 'binary', 'json', 'decimal', 'wei', 'date', 'time']" [dict-item]
 posthog/temporal/data_modeling/run_workflow.py:0: error: Dict entry 21 has incompatible type "str": "Literal['complex']"; expected "str": "Literal['text', 'double', 'bool', 'timestamp', 'bigint', 'binary', 'json', 'decimal', 'wei', 'date', 'time']" [dict-item]
 posthog/temporal/data_modeling/run_workflow.py:0: error: Dict entry 22 has incompatible type "str": "Literal['complex']"; expected "str": "Literal['text', 'double', 'bool', 'timestamp', 'bigint', 'binary', 'json', 'decimal', 'wei', 'date', 'time']" [dict-item]
+posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: "FilesystemDestinationClientConfiguration" has no attribute "delta_jobs_per_write" [attr-defined]
+posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: "type[FilesystemDestinationClientConfiguration]" has no attribute "delta_jobs_per_write" [attr-defined]
+posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "DataWarehouseCredential | Combinable | None") [assignment]
+posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "str | int | Combinable") [assignment]
+posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "dict[str, dict[str, str | bool]] | dict[str, str]", variable has type "dict[str, dict[str, str]]") [assignment]
 posthog/session_recordings/session_recording_api.py:0: error: Argument "team_id" to "get_realtime_snapshots" has incompatible type "int"; expected "str" [arg-type]
 posthog/session_recordings/session_recording_api.py:0: error: Value of type variable "SupportsRichComparisonT" of "sorted" cannot be "str | None" [type-var]
 posthog/session_recordings/session_recording_api.py:0: error: Argument 1 to "get" of "dict" has incompatible type "str | None"; expected "str" [arg-type]
@@ -826,12 +830,6 @@ posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0:
 posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: List item 0 has incompatible type "tuple[str, str, int, int, int, int, str, int]"; expected "tuple[str, str, int, int, str, str, str, str]" [list-item]
 posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py:0: error: "tuple[Any, ...]" has no attribute "last_uploaded_part_timestamp" [attr-defined]
 posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py:0: error: "tuple[Any, ...]" has no attribute "upload_state" [attr-defined]
-posthog/temporal/data_imports/workflow_activities/import_data.py:0: error: Argument "job_type" to "PipelineInputs" has incompatible type "str"; expected "Type" [arg-type]
-posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: "FilesystemDestinationClientConfiguration" has no attribute "delta_jobs_per_write" [attr-defined]
-posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: "type[FilesystemDestinationClientConfiguration]" has no attribute "delta_jobs_per_write" [attr-defined]
-posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "DataWarehouseCredential | Combinable | None") [assignment]
-posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "str | int | Combinable") [assignment]
-posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "dict[str, dict[str, str | bool]] | dict[str, str]", variable has type "dict[str, dict[str, str]]") [assignment]
 posthog/migrations/0237_remove_timezone_from_teams.py:0: error: Argument 2 to "RunPython" has incompatible type "Callable[[Migration, Any], None]"; expected "_CodeCallable | None" [arg-type]
 posthog/migrations/0228_fix_tile_layouts.py:0: error: Argument 2 to "RunPython" has incompatible type "Callable[[Migration, Any], None]"; expected "_CodeCallable | None" [arg-type]
 posthog/api/plugin_log_entry.py:0: error: Name "timezone.datetime" is not defined [name-defined]
@@ -839,6 +837,7 @@ posthog/api/plugin_log_entry.py:0: error: Module "django.utils.timezone" does no
 posthog/api/plugin_log_entry.py:0: error: Name "timezone.datetime" is not defined [name-defined]
 posthog/api/plugin_log_entry.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined]
 posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py:0: error: Incompatible types in assignment (expression has type "str | int", variable has type "int") [assignment]
+posthog/temporal/data_imports/external_data_job.py:0: error: Argument "status" to "update_external_job_status" has incompatible type "str"; expected "Status" [arg-type]
 posthog/api/sharing.py:0: error: Item "None" of "list[Any] | None" has no attribute "__iter__" (not iterable) [union-attr]
 posthog/api/test/batch_exports/conftest.py:0: error: Signature of "run" incompatible with supertype "Worker" [override]
 posthog/api/test/batch_exports/conftest.py:0: note: Superclass:
@@ -850,10 +849,10 @@ posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid
 posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index]
 posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index]
 posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index]
+posthog/temporal/tests/data_imports/test_end_to_end.py:0: error: Unused "type: ignore" comment [unused-ignore]
 posthog/api/test/test_team.py:0: error: "HttpResponse" has no attribute "json" [attr-defined]
 posthog/api/test/test_team.py:0: error: "HttpResponse" has no attribute "json" [attr-defined]
 posthog/test/test_middleware.py:0: error: Incompatible types in assignment (expression has type "_MonkeyPatchedWSGIResponse", variable has type "_MonkeyPatchedResponse") [assignment]
-posthog/temporal/tests/data_imports/test_end_to_end.py:0: error: Unused "type: ignore" comment [unused-ignore]
 posthog/management/commands/test/test_create_batch_export_from_app.py:0: error: Incompatible return value type (got "dict[str, Collection[str]]", expected "dict[str, str]") [return-value]
 posthog/management/commands/test/test_create_batch_export_from_app.py:0: error: Incompatible types in assignment (expression has type "dict[str, Collection[str]]", variable has type "dict[str, str]") [assignment]
 posthog/management/commands/test/test_create_batch_export_from_app.py:0: error: Unpacked dict entry 1 has incompatible type "str"; expected "SupportsKeysAndGetItem[str, str]" [dict-item]
diff --git a/posthog/temporal/common/heartbeat_sync.py b/posthog/temporal/common/heartbeat_sync.py
index 35ac79515b9f4..cf775c3bf5cb0 100644
--- a/posthog/temporal/common/heartbeat_sync.py
+++ b/posthog/temporal/common/heartbeat_sync.py
@@ -11,6 +11,8 @@ def __init__(self, details: tuple[Any, ...] = (), factor: int = 12, logger: Opti
         self.details: tuple[Any, ...] = details
         self.factor = factor
         self.logger = logger
+        self.stop_event: Optional[threading.Event] = None
+        self.heartbeat_thread: Optional[threading.Thread] = None
 
     def log_debug(self, message: str, exc_info: Optional[Any] = None) -> None:
         if self.logger:
diff --git a/posthog/temporal/data_imports/__init__.py b/posthog/temporal/data_imports/__init__.py
index cabeaf433d4e1..c59f20b05d8cf 100644
--- a/posthog/temporal/data_imports/__init__.py
+++ b/posthog/temporal/data_imports/__init__.py
@@ -2,10 +2,8 @@
     ExternalDataJobWorkflow,
     create_external_data_job_model_activity,
     create_source_templates,
-    import_data_activity,
     import_data_activity_sync,
     update_external_data_job_model,
-    check_schedule_activity,
     check_billing_limits_activity,
     sync_new_schemas_activity,
 )
@@ -15,10 +13,8 @@
 ACTIVITIES = [
     create_external_data_job_model_activity,
     update_external_data_job_model,
-    import_data_activity,
     import_data_activity_sync,
     create_source_templates,
-    check_schedule_activity,
     check_billing_limits_activity,
     sync_new_schemas_activity,
 ]
diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py
index 1820f462093ca..0bccbf9b95fa9 100644
--- a/posthog/temporal/data_imports/external_data_job.py
+++ b/posthog/temporal/data_imports/external_data_job.py
@@ -8,7 +8,6 @@
 
 # TODO: remove dependency
 from posthog.temporal.batch_exports.base import PostHogWorkflow
-from posthog.temporal.data_imports.util import is_posthog_team
 from posthog.temporal.data_imports.workflow_activities.check_billing_limits import (
     CheckBillingLimitsActivityInputs,
     check_billing_limits_activity,
@@ -23,28 +22,19 @@
     CreateExternalDataJobModelActivityInputs,
     create_external_data_job_model_activity,
 )
-from posthog.temporal.data_imports.workflow_activities.import_data import ImportDataActivityInputs, import_data_activity
+from posthog.temporal.data_imports.workflow_activities.import_data_sync import ImportDataActivityInputs
 from posthog.utils import get_machine_id
 
-from posthog.warehouse.data_load.service import (
-    a_delete_external_data_schedule,
-    a_external_data_workflow_exists,
-    a_sync_external_data_job_workflow,
-    a_trigger_external_data_workflow,
-)
 from posthog.warehouse.data_load.source_templates import create_warehouse_templates_for_source
 from posthog.warehouse.external_data_source.jobs import (
-    aget_running_job_for_schema,
-    aupdate_external_job_status,
+    update_external_job_status,
 )
 from posthog.warehouse.models import (
     ExternalDataJob,
-    get_active_schemas_for_source_id,
     ExternalDataSource,
-    get_external_data_source,
 )
-from posthog.temporal.common.logger import bind_temporal_worker_logger
-from posthog.warehouse.models.external_data_schema import aupdate_should_sync
+from posthog.temporal.common.logger import bind_temporal_worker_logger_sync
+from posthog.warehouse.models.external_data_schema import update_should_sync
 
 
 Non_Retryable_Schema_Errors: dict[ExternalDataSource.Type, list[str]] = {
@@ -76,11 +66,15 @@ class UpdateExternalDataJobStatusInputs:
 
 
 @activity.defn
-async def update_external_data_job_model(inputs: UpdateExternalDataJobStatusInputs) -> None:
-    logger = await bind_temporal_worker_logger(team_id=inputs.team_id)
+def update_external_data_job_model(inputs: UpdateExternalDataJobStatusInputs) -> None:
+    logger = bind_temporal_worker_logger_sync(team_id=inputs.team_id)
 
     if inputs.job_id is None:
-        job: ExternalDataJob | None = await aget_running_job_for_schema(inputs.schema_id)
+        job: ExternalDataJob | None = (
+
ExternalDataJob.objects.filter(schema_id=inputs.schema_id, status=ExternalDataJob.Status.RUNNING) + .order_by("-created_at") + .first() + ) if job is None: logger.info("No job to update status on") return @@ -94,7 +88,7 @@ async def update_external_data_job_model(inputs: UpdateExternalDataJobStatusInpu f"External data job failed for external data schema {inputs.schema_id} with error: {inputs.internal_error}" ) - source: ExternalDataSource = await get_external_data_source(inputs.source_id) + source: ExternalDataSource = ExternalDataSource.objects.get(pk=inputs.source_id) non_retryable_errors = Non_Retryable_Schema_Errors.get(ExternalDataSource.Type(source.source_type)) if non_retryable_errors is not None: @@ -113,9 +107,9 @@ async def update_external_data_job_model(inputs: UpdateExternalDataJobStatusInpu "error": inputs.internal_error, }, ) - await aupdate_should_sync(schema_id=inputs.schema_id, team_id=inputs.team_id, should_sync=False) + update_should_sync(schema_id=inputs.schema_id, team_id=inputs.team_id, should_sync=False) - await aupdate_external_job_status( + update_external_job_status( job_id=job_id, status=inputs.status, latest_error=inputs.latest_error, @@ -134,34 +128,8 @@ class CreateSourceTemplateInputs: @activity.defn -async def create_source_templates(inputs: CreateSourceTemplateInputs) -> None: - await create_warehouse_templates_for_source(team_id=inputs.team_id, run_id=inputs.run_id) - - -@activity.defn -async def check_schedule_activity(inputs: ExternalDataWorkflowInputs) -> bool: - logger = await bind_temporal_worker_logger(team_id=inputs.team_id) - - # Creates schedules for all schemas if they don't exist yet, and then remove itself as a source schedule - if inputs.external_data_schema_id is None: - logger.info("Schema ID is none, creating schedules for schemas...") - schemas = await get_active_schemas_for_source_id( - team_id=inputs.team_id, source_id=inputs.external_data_source_id - ) - for schema in schemas: - if await a_external_data_workflow_exists(schema.id): - await a_trigger_external_data_workflow(schema) - logger.info(f"Schedule exists for schema {schema.id}. Triggered schedule") - else: - await a_sync_external_data_job_workflow(schema, create=True) - logger.info(f"Created schedule for schema {schema.id}") - # Delete the source schedule in favour of the schema schedules - await a_delete_external_data_schedule(ExternalDataSource(id=inputs.external_data_source_id)) - logger.info(f"Deleted schedule for source {inputs.external_data_source_id}") - return True - - logger.info("Schema ID is set. 
Continuing...") - return False +def create_source_templates(inputs: CreateSourceTemplateInputs) -> None: + create_warehouse_templates_for_source(team_id=inputs.team_id, run_id=inputs.run_id) # TODO: update retry policies @@ -174,21 +142,6 @@ def parse_inputs(inputs: list[str]) -> ExternalDataWorkflowInputs: @workflow.run async def run(self, inputs: ExternalDataWorkflowInputs): - should_exit = await workflow.execute_activity( - check_schedule_activity, - inputs, - start_to_close_timeout=dt.timedelta(minutes=1), - retry_policy=RetryPolicy( - initial_interval=dt.timedelta(seconds=10), - maximum_interval=dt.timedelta(seconds=60), - maximum_attempts=0, - non_retryable_error_types=["NotNullViolation", "IntegrityError"], - ), - ) - - if should_exit: - return - assert inputs.external_data_schema_id is not None update_inputs = UpdateExternalDataJobStatusInputs( @@ -262,24 +215,12 @@ async def run(self, inputs: ExternalDataWorkflowInputs): else {"start_to_close_timeout": dt.timedelta(hours=12), "retry_policy": RetryPolicy(maximum_attempts=3)} ) - if is_posthog_team(inputs.team_id) and ( - source_type == ExternalDataSource.Type.POSTGRES or source_type == ExternalDataSource.Type.BIGQUERY - ): - # Sync activity for testing - await workflow.execute_activity( - import_data_activity_sync, - job_inputs, - heartbeat_timeout=dt.timedelta(minutes=5), - **timeout_params, - ) # type: ignore - else: - # Async activity for everyone else - await workflow.execute_activity( - import_data_activity, - job_inputs, - heartbeat_timeout=dt.timedelta(minutes=5), - **timeout_params, - ) # type: ignore + await workflow.execute_activity( + import_data_activity_sync, + job_inputs, + heartbeat_timeout=dt.timedelta(minutes=5), + **timeout_params, + ) # type: ignore # Create source templates await workflow.execute_activity( diff --git a/posthog/temporal/data_imports/pipelines/pipeline.py b/posthog/temporal/data_imports/pipelines/pipeline.py deleted file mode 100644 index 24099e698fb7c..0000000000000 --- a/posthog/temporal/data_imports/pipelines/pipeline.py +++ /dev/null @@ -1,266 +0,0 @@ -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass -from typing import Literal -from uuid import UUID - -import dlt -from django.conf import settings -from dlt.pipeline.exceptions import PipelineStepFailed - -from asgiref.sync import async_to_sync -import asyncio -from posthog.settings.base_variables import TEST -from structlog.typing import FilteringBoundLogger -from dlt.common.libs.deltalake import get_delta_tables -from dlt.load.exceptions import LoadClientJobRetry -from dlt.sources import DltSource -from deltalake.exceptions import DeltaError -from collections import Counter - -from posthog.warehouse.data_load.validate_schema import update_last_synced_at, validate_schema_and_update_table -from posthog.warehouse.models.external_data_job import ExternalDataJob, get_external_data_job -from posthog.warehouse.models.external_data_schema import ExternalDataSchema, aget_schema_by_id -from posthog.warehouse.models.external_data_source import ExternalDataSource -from posthog.warehouse.models.table import DataWarehouseTable -from posthog.temporal.data_imports.util import prepare_s3_files_for_querying - - -@dataclass -class PipelineInputs: - source_id: UUID - run_id: str - schema_id: UUID - dataset_name: str - job_type: ExternalDataSource.Type - team_id: int - - -class DataImportPipeline: - loader_file_format: Literal["parquet"] = "parquet" - - def __init__( - self, - inputs: PipelineInputs, - source: DltSource, - 
logger: FilteringBoundLogger, - reset_pipeline: bool, - incremental: bool = False, - ): - self.inputs = inputs - self.logger = logger - - self._incremental = incremental - self.refresh_dlt = reset_pipeline - self.should_chunk_pipeline = ( - incremental - and inputs.job_type != ExternalDataSource.Type.POSTGRES - and inputs.job_type != ExternalDataSource.Type.MYSQL - and inputs.job_type != ExternalDataSource.Type.MSSQL - and inputs.job_type != ExternalDataSource.Type.SNOWFLAKE - and inputs.job_type != ExternalDataSource.Type.BIGQUERY - ) - - if self.should_chunk_pipeline: - # Incremental syncs: Assuming each page is 100 items for now so bound each run at 50_000 items - self.source = source.add_limit(500) - else: - self.source = source - - def _get_pipeline_name(self): - return f"{self.inputs.job_type}_pipeline_{self.inputs.team_id}_run_{self.inputs.schema_id}" - - def _get_destination(self): - if TEST: - credentials = { - "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY, - "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET, - "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT, - "region_name": settings.AIRBYTE_BUCKET_REGION, - "AWS_ALLOW_HTTP": "true", - "AWS_S3_ALLOW_UNSAFE_RENAME": "true", - } - else: - credentials = { - "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY, - "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET, - "region_name": settings.AIRBYTE_BUCKET_REGION, - "AWS_S3_ALLOW_UNSAFE_RENAME": "true", - } - - return dlt.destinations.filesystem( - credentials=credentials, - bucket_url=settings.BUCKET_URL, # type: ignore - ) - - def _create_pipeline(self): - pipeline_name = self._get_pipeline_name() - destination = self._get_destination() - - dlt.config["normalize.parquet_normalizer.add_dlt_load_id"] = True - dlt.config["normalize.parquet_normalizer.add_dlt_id"] = True - - return dlt.pipeline( - pipeline_name=pipeline_name, destination=destination, dataset_name=self.inputs.dataset_name, progress="log" - ) - - async def _prepare_s3_files_for_querying(self, file_uris: list[str]): - job: ExternalDataJob = await get_external_data_job(job_id=self.inputs.run_id) - schema: ExternalDataSchema = await aget_schema_by_id(self.inputs.schema_id, self.inputs.team_id) - - prepare_s3_files_for_querying(job.folder_path(), schema.name, file_uris) - - def _run(self) -> dict[str, int]: - if self.refresh_dlt: - self.logger.info("Pipeline getting a full refresh due to reset_pipeline being set") - - pipeline = self._create_pipeline() - - total_counts: Counter[str] = Counter({}) - - # Do chunking for incremental syncing on API based endpoints (e.g. 
not sql databases) - if self.should_chunk_pipeline: - # will get overwritten - counts: Counter[str] = Counter({"start": 1}) - pipeline_runs = 0 - - while counts: - self.logger.info(f"Running incremental (non-sql) pipeline, run ${pipeline_runs}") - - try: - pipeline.run( - self.source, - loader_file_format=self.loader_file_format, - refresh="drop_sources" if self.refresh_dlt and pipeline_runs == 0 else None, - ) - except PipelineStepFailed as e: - # Remove once DLT support writing empty Delta files - if isinstance(e.exception, LoadClientJobRetry): - if "Generic S3 error" not in e.exception.retry_message: - raise - elif isinstance(e.exception, DeltaError): - if e.exception.args[0] != "Generic error: No data source supplied to write command.": - raise - else: - raise - - if pipeline.last_trace.last_normalize_info is not None: - row_counts = pipeline.last_trace.last_normalize_info.row_counts - else: - row_counts = {} - # Remove any DLT tables from the counts - filtered_rows = dict(filter(lambda pair: not pair[0].startswith("_dlt"), row_counts.items())) - counts = Counter(filtered_rows) - total_counts = counts + total_counts - - if total_counts.total() > 0: - delta_tables = get_delta_tables(pipeline) - - table_format = DataWarehouseTable.TableFormat.DeltaS3Wrapper - - # Workaround while we fix msising table_format on DLT resource - if len(delta_tables.values()) == 0: - table_format = DataWarehouseTable.TableFormat.Delta - - # There should only ever be one table here - for table in delta_tables.values(): - self.logger.info("Compacting delta table") - table.optimize.compact() - table.vacuum(retention_hours=24, enforce_retention_duration=False, dry_run=False) - - file_uris = table.file_uris() - self.logger.info(f"Preparing S3 files - total parquet files: {len(file_uris)}") - async_to_sync(self._prepare_s3_files_for_querying)(file_uris) - - self.logger.info(f"Table format: {table_format}") - - async_to_sync(validate_schema_and_update_table)( - run_id=self.inputs.run_id, - team_id=self.inputs.team_id, - schema_id=self.inputs.schema_id, - table_schema=self.source.schema.tables, - row_count=total_counts.total(), - table_format=table_format, - ) - else: - self.logger.info("No table_counts, skipping validate_schema_and_update_table") - - pipeline_runs = pipeline_runs + 1 - else: - self.logger.info("Running standard pipeline") - try: - pipeline.run( - self.source, - loader_file_format=self.loader_file_format, - refresh="drop_sources" if self.refresh_dlt else None, - ) - except PipelineStepFailed as e: - # Remove once DLT support writing empty Delta files - if isinstance(e.exception, LoadClientJobRetry): - if "Generic S3 error" not in e.exception.retry_message: - raise - elif isinstance(e.exception, DeltaError): - if e.exception.args[0] != "Generic error: No data source supplied to write command.": - raise - else: - raise - - if pipeline.last_trace.last_normalize_info is not None: - row_counts = pipeline.last_trace.last_normalize_info.row_counts - else: - row_counts = {} - - filtered_rows = dict(filter(lambda pair: not pair[0].startswith("_dlt"), row_counts.items())) - counts = Counter(filtered_rows) - total_counts = total_counts + counts - - if total_counts.total() > 0: - delta_tables = get_delta_tables(pipeline) - - table_format = DataWarehouseTable.TableFormat.DeltaS3Wrapper - - # Workaround while we fix msising table_format on DLT resource - if len(delta_tables.values()) == 0: - table_format = DataWarehouseTable.TableFormat.Delta - - # There should only ever be one table here - for table in 
delta_tables.values(): - self.logger.info("Compacting delta table") - table.optimize.compact() - table.vacuum(retention_hours=24, enforce_retention_duration=False, dry_run=False) - - file_uris = table.file_uris() - self.logger.info(f"Preparing S3 files - total parquet files: {len(file_uris)}") - async_to_sync(self._prepare_s3_files_for_querying)(file_uris) - - self.logger.info(f"Table format: {table_format}") - - async_to_sync(validate_schema_and_update_table)( - run_id=self.inputs.run_id, - team_id=self.inputs.team_id, - schema_id=self.inputs.schema_id, - table_schema=self.source.schema.tables, - row_count=total_counts.total(), - table_format=table_format, - ) - else: - self.logger.info("No table_counts, skipping validate_schema_and_update_table") - - # Update last_synced_at on schema - async_to_sync(update_last_synced_at)( - job_id=self.inputs.run_id, schema_id=str(self.inputs.schema_id), team_id=self.inputs.team_id - ) - - # Cleanup: delete local state from the file system - pipeline.drop() - - return dict(total_counts) - - async def run(self) -> dict[str, int]: - try: - # Use a dedicated thread pool to not interfere with the heartbeater thread - with ThreadPoolExecutor(max_workers=5) as pipeline_executor: - loop = asyncio.get_event_loop() - return await loop.run_in_executor(pipeline_executor, self._run) - except PipelineStepFailed as e: - self.logger.exception(f"Data import failed for endpoint with exception {e}", exc_info=e) - raise diff --git a/posthog/temporal/data_imports/pipelines/pipeline_sync.py b/posthog/temporal/data_imports/pipelines/pipeline_sync.py index 1dd1269a72e0b..e3ca8a4ecbdaa 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline_sync.py +++ b/posthog/temporal/data_imports/pipelines/pipeline_sync.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Any, Literal, Optional from collections.abc import Iterator, Sequence import uuid @@ -34,7 +35,6 @@ from clickhouse_driver.errors import ServerException from posthog.temporal.common.logger import bind_temporal_worker_logger_sync -from posthog.temporal.data_imports.pipelines.pipeline import PipelineInputs from posthog.warehouse.data_load.validate_schema import dlt_to_hogql_type from posthog.warehouse.models.credential import get_or_create_datawarehouse_credential from posthog.warehouse.models.external_data_job import ExternalDataJob @@ -44,6 +44,16 @@ from posthog.temporal.data_imports.util import prepare_s3_files_for_querying +@dataclass +class PipelineInputs: + source_id: uuid.UUID + run_id: str + schema_id: uuid.UUID + dataset_name: str + job_type: ExternalDataSource.Type + team_id: int + + class DataImportPipelineSync: loader_file_format: Literal["parquet"] = "parquet" diff --git a/posthog/temporal/data_imports/pipelines/test/test_pipeline.py b/posthog/temporal/data_imports/pipelines/test/test_pipeline_sync.py similarity index 73% rename from posthog/temporal/data_imports/pipelines/test/test_pipeline.py rename to posthog/temporal/data_imports/pipelines/test/test_pipeline_sync.py index 965b77ca5f9ae..3b265f54a352a 100644 --- a/posthog/temporal/data_imports/pipelines/test/test_pipeline.py +++ b/posthog/temporal/data_imports/pipelines/test/test_pipeline_sync.py @@ -4,8 +4,7 @@ import pytest import structlog -from asgiref.sync import sync_to_async -from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline, PipelineInputs +from posthog.temporal.data_imports.pipelines.pipeline_sync import DataImportPipelineSync, PipelineInputs from posthog.temporal.data_imports.pipelines.stripe 
import stripe_source from posthog.test.base import APIBaseTest from posthog.warehouse.models.external_data_job import ExternalDataJob @@ -14,8 +13,8 @@ class TestDataImportPipeline(APIBaseTest): - async def _create_pipeline(self, schema_name: str, incremental: bool): - source = await sync_to_async(ExternalDataSource.objects.create)( + def _create_pipeline(self, schema_name: str, incremental: bool): + source = ExternalDataSource.objects.create( source_id=str(uuid.uuid4()), connection_id=str(uuid.uuid4()), destination_id=str(uuid.uuid4()), @@ -23,13 +22,13 @@ async def _create_pipeline(self, schema_name: str, incremental: bool): status="running", source_type="Stripe", ) - schema = await sync_to_async(ExternalDataSchema.objects.create)( + schema = ExternalDataSchema.objects.create( name=schema_name, team_id=self.team.pk, source_id=source.pk, source=source, ) - job = await sync_to_async(ExternalDataJob.objects.create)( + job = ExternalDataJob.objects.create( team_id=self.team.pk, pipeline_id=source.pk, pipeline=source, @@ -40,7 +39,7 @@ async def _create_pipeline(self, schema_name: str, incremental: bool): workflow_id=str(uuid.uuid4()), ) - pipeline = DataImportPipeline( + pipeline = DataImportPipelineSync( inputs=PipelineInputs( source_id=source.pk, run_id=str(job.pk), @@ -65,45 +64,43 @@ async def _create_pipeline(self, schema_name: str, incremental: bool): return pipeline @pytest.mark.django_db(transaction=True) - @pytest.mark.asyncio - async def test_pipeline_non_incremental(self): + def test_pipeline_non_incremental(self): def mock_create_pipeline(local_self: Any): mock = MagicMock() mock.last_trace.last_normalize_info.row_counts = {"customer": 1} return mock with ( - patch.object(DataImportPipeline, "_create_pipeline", mock_create_pipeline), + patch.object(DataImportPipelineSync, "_create_pipeline", mock_create_pipeline), patch( - "posthog.temporal.data_imports.pipelines.pipeline.validate_schema_and_update_table" + "posthog.temporal.data_imports.pipelines.pipeline_sync.validate_schema_and_update_table_sync" ) as mock_validate_schema_and_update_table, - patch("posthog.temporal.data_imports.pipelines.pipeline.get_delta_tables"), - patch("posthog.temporal.data_imports.pipelines.pipeline.update_last_synced_at"), + patch("posthog.temporal.data_imports.pipelines.pipeline_sync.get_delta_tables"), + patch("posthog.temporal.data_imports.pipelines.pipeline_sync.update_last_synced_at_sync"), ): - pipeline = await self._create_pipeline("Customer", False) - res = await pipeline.run() + pipeline = self._create_pipeline("Customer", False) + res = pipeline.run() assert res.get("customer") == 1 assert mock_validate_schema_and_update_table.call_count == 1 @pytest.mark.django_db(transaction=True) - @pytest.mark.asyncio - async def test_pipeline_incremental(self): + def test_pipeline_incremental(self): def mock_create_pipeline(local_self: Any): mock = MagicMock() type(mock.last_trace.last_normalize_info).row_counts = PropertyMock(side_effect=[{"customer": 1}, {}]) return mock with ( - patch.object(DataImportPipeline, "_create_pipeline", mock_create_pipeline), + patch.object(DataImportPipelineSync, "_create_pipeline", mock_create_pipeline), patch( - "posthog.temporal.data_imports.pipelines.pipeline.validate_schema_and_update_table" + "posthog.temporal.data_imports.pipelines.pipeline_sync.validate_schema_and_update_table_sync" ) as mock_validate_schema_and_update_table, - patch("posthog.temporal.data_imports.pipelines.pipeline.get_delta_tables"), - 
patch("posthog.temporal.data_imports.pipelines.pipeline.update_last_synced_at"), + patch("posthog.temporal.data_imports.pipelines.pipeline_sync.get_delta_tables"), + patch("posthog.temporal.data_imports.pipelines.pipeline_sync.update_last_synced_at_sync"), ): - pipeline = await self._create_pipeline("Customer", True) - res = await pipeline.run() + pipeline = self._create_pipeline("Customer", True) + res = pipeline.run() assert res.get("customer") == 1 assert mock_validate_schema_and_update_table.call_count == 2 diff --git a/posthog/temporal/data_imports/workflow_activities/create_job_model.py b/posthog/temporal/data_imports/workflow_activities/create_job_model.py index 8d3577cf1ff23..02eb6aee7d52a 100644 --- a/posthog/temporal/data_imports/workflow_activities/create_job_model.py +++ b/posthog/temporal/data_imports/workflow_activities/create_job_model.py @@ -1,19 +1,15 @@ import dataclasses import uuid -from asgiref.sync import sync_to_async from temporalio import activity # TODO: remove dependency -from posthog.warehouse.external_data_source.jobs import ( - acreate_external_data_job, -) -from posthog.warehouse.models import ExternalDataSource +from posthog.warehouse.models import ExternalDataJob, ExternalDataSource from posthog.warehouse.models.external_data_schema import ( ExternalDataSchema, ) -from posthog.temporal.common.logger import bind_temporal_worker_logger +from posthog.temporal.common.logger import bind_temporal_worker_logger_sync @dataclasses.dataclass @@ -24,25 +20,27 @@ class CreateExternalDataJobModelActivityInputs: @activity.defn -async def create_external_data_job_model_activity( +def create_external_data_job_model_activity( inputs: CreateExternalDataJobModelActivityInputs, ) -> tuple[str, bool, str]: - logger = await bind_temporal_worker_logger(team_id=inputs.team_id) + logger = bind_temporal_worker_logger_sync(team_id=inputs.team_id) try: - job = await acreate_external_data_job( + job = ExternalDataJob.objects.create( team_id=inputs.team_id, - external_data_source_id=inputs.source_id, - external_data_schema_id=inputs.schema_id, + pipeline_id=inputs.source_id, + schema_id=inputs.schema_id, + status=ExternalDataJob.Status.RUNNING, + rows_synced=0, workflow_id=activity.info().workflow_id, workflow_run_id=activity.info().workflow_run_id, ) - schema = await sync_to_async(ExternalDataSchema.objects.get)(team_id=inputs.team_id, id=inputs.schema_id) + schema = ExternalDataSchema.objects.get(team_id=inputs.team_id, id=inputs.schema_id) schema.status = ExternalDataSchema.Status.RUNNING - await sync_to_async(schema.save)() + schema.save() - source = await sync_to_async(ExternalDataSource.objects.get)(team_id=inputs.team_id, id=schema.source_id) + source: ExternalDataSource = schema.source logger.info( f"Created external data job for external data source {inputs.source_id}", diff --git a/posthog/temporal/data_imports/workflow_activities/import_data.py b/posthog/temporal/data_imports/workflow_activities/import_data.py deleted file mode 100644 index 26ce621f99a3d..0000000000000 --- a/posthog/temporal/data_imports/workflow_activities/import_data.py +++ /dev/null @@ -1,434 +0,0 @@ -import dataclasses -import uuid -from datetime import datetime -from typing import Any - -from structlog.typing import FilteringBoundLogger -from temporalio import activity - -from posthog.temporal.common.heartbeat import Heartbeater -from posthog.temporal.common.logger import bind_temporal_worker_logger -from posthog.temporal.data_imports.pipelines.bigquery import delete_table -from 
posthog.temporal.data_imports.pipelines.helpers import aremove_reset_pipeline, aupdate_job_count - -from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline, PipelineInputs -from posthog.temporal.data_imports.util import is_posthog_team -from posthog.warehouse.models import ( - ExternalDataJob, - ExternalDataSource, - get_external_data_job, -) -from posthog.warehouse.models.external_data_schema import ( - ExternalDataSchema, - aget_schema_by_id, -) -from posthog.warehouse.models.ssh_tunnel import SSHTunnel - - -@dataclasses.dataclass -class ImportDataActivityInputs: - team_id: int - schema_id: uuid.UUID - source_id: uuid.UUID - run_id: str - - -@activity.defn -async def import_data_activity(inputs: ImportDataActivityInputs): - async with Heartbeater(factor=30): # Every 10 secs - model: ExternalDataJob = await get_external_data_job( - job_id=inputs.run_id, - ) - - logger = await bind_temporal_worker_logger(team_id=inputs.team_id) - - logger.debug("Running *ASYNC* import_data") - - job_inputs = PipelineInputs( - source_id=inputs.source_id, - schema_id=inputs.schema_id, - run_id=inputs.run_id, - team_id=inputs.team_id, - job_type=model.pipeline.source_type, - dataset_name=model.folder_path(), - ) - - reset_pipeline = model.pipeline.job_inputs.get("reset_pipeline", "False") == "True" - - schema: ExternalDataSchema = await aget_schema_by_id(inputs.schema_id, inputs.team_id) - - endpoints = [schema.name] - - source = None - if model.pipeline.source_type == ExternalDataSource.Type.STRIPE: - from posthog.temporal.data_imports.pipelines.stripe import stripe_source - - stripe_secret_key = model.pipeline.job_inputs.get("stripe_secret_key", None) - account_id = model.pipeline.job_inputs.get("stripe_account_id", None) - if not stripe_secret_key: - raise ValueError(f"Stripe secret key not found for job {model.id}") - - source = stripe_source( - api_key=stripe_secret_key, - account_id=account_id, - endpoint=schema.name, - team_id=inputs.team_id, - job_id=inputs.run_id, - is_incremental=schema.is_incremental, - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - elif model.pipeline.source_type == ExternalDataSource.Type.HUBSPOT: - from posthog.temporal.data_imports.pipelines.hubspot import hubspot - from posthog.temporal.data_imports.pipelines.hubspot.auth import ( - hubspot_refresh_access_token, - ) - - hubspot_access_code = model.pipeline.job_inputs.get("hubspot_secret_key", None) - refresh_token = model.pipeline.job_inputs.get("hubspot_refresh_token", None) - if not refresh_token: - raise ValueError(f"Hubspot refresh token not found for job {model.id}") - - if not hubspot_access_code: - hubspot_access_code = hubspot_refresh_access_token(refresh_token) - - source = hubspot( - api_key=hubspot_access_code, - refresh_token=refresh_token, - endpoints=tuple(endpoints), - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - elif model.pipeline.source_type in [ - ExternalDataSource.Type.POSTGRES, - ExternalDataSource.Type.MYSQL, - ExternalDataSource.Type.MSSQL, - ]: - if is_posthog_team(inputs.team_id): - from posthog.temporal.data_imports.pipelines.sql_database_v2 import ( - sql_source_for_type, - ) - else: - from posthog.temporal.data_imports.pipelines.sql_database import ( - sql_source_for_type, - ) - - host = model.pipeline.job_inputs.get("host") - port = 
model.pipeline.job_inputs.get("port") - user = model.pipeline.job_inputs.get("user") - password = model.pipeline.job_inputs.get("password") - database = model.pipeline.job_inputs.get("database") - pg_schema = model.pipeline.job_inputs.get("schema") - - using_ssh_tunnel = str(model.pipeline.job_inputs.get("ssh_tunnel_enabled", False)) == "True" - ssh_tunnel_host = model.pipeline.job_inputs.get("ssh_tunnel_host") - ssh_tunnel_port = model.pipeline.job_inputs.get("ssh_tunnel_port") - ssh_tunnel_auth_type = model.pipeline.job_inputs.get("ssh_tunnel_auth_type") - ssh_tunnel_auth_type_username = model.pipeline.job_inputs.get("ssh_tunnel_auth_type_username") - ssh_tunnel_auth_type_password = model.pipeline.job_inputs.get("ssh_tunnel_auth_type_password") - ssh_tunnel_auth_type_passphrase = model.pipeline.job_inputs.get("ssh_tunnel_auth_type_passphrase") - ssh_tunnel_auth_type_private_key = model.pipeline.job_inputs.get("ssh_tunnel_auth_type_private_key") - - ssh_tunnel = SSHTunnel( - enabled=using_ssh_tunnel, - host=ssh_tunnel_host, - port=ssh_tunnel_port, - auth_type=ssh_tunnel_auth_type, - username=ssh_tunnel_auth_type_username, - password=ssh_tunnel_auth_type_password, - passphrase=ssh_tunnel_auth_type_passphrase, - private_key=ssh_tunnel_auth_type_private_key, - ) - - if ssh_tunnel.enabled: - with ssh_tunnel.get_tunnel(host, int(port)) as tunnel: - if tunnel is None: - raise Exception("Can't open tunnel to SSH server") - - source = sql_source_for_type( - source_type=ExternalDataSource.Type(model.pipeline.source_type), - host=tunnel.local_bind_host, - port=tunnel.local_bind_port, - user=user, - password=password, - database=database, - sslmode="prefer", - schema=pg_schema, - table_names=endpoints, - incremental_field=schema.sync_type_config.get("incremental_field") - if schema.is_incremental - else None, - incremental_field_type=schema.sync_type_config.get("incremental_field_type") - if schema.is_incremental - else None, - team_id=inputs.team_id, - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - - source = sql_source_for_type( - source_type=ExternalDataSource.Type(model.pipeline.source_type), - host=host, - port=port, - user=user, - password=password, - database=database, - sslmode="prefer", - schema=pg_schema, - table_names=endpoints, - incremental_field=schema.sync_type_config.get("incremental_field") if schema.is_incremental else None, - incremental_field_type=schema.sync_type_config.get("incremental_field_type") - if schema.is_incremental - else None, - team_id=inputs.team_id, - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - elif model.pipeline.source_type == ExternalDataSource.Type.SNOWFLAKE: - if is_posthog_team(inputs.team_id): - from posthog.temporal.data_imports.pipelines.sql_database_v2 import ( - snowflake_source, - ) - else: - from posthog.temporal.data_imports.pipelines.sql_database import ( - snowflake_source, - ) - - account_id = model.pipeline.job_inputs.get("account_id") - user = model.pipeline.job_inputs.get("user") - password = model.pipeline.job_inputs.get("password") - database = model.pipeline.job_inputs.get("database") - warehouse = model.pipeline.job_inputs.get("warehouse") - sf_schema = model.pipeline.job_inputs.get("schema") - role = model.pipeline.job_inputs.get("role") - - source = snowflake_source( - account_id=account_id, - user=user, - 
password=password, - database=database, - schema=sf_schema, - warehouse=warehouse, - role=role, - table_names=endpoints, - incremental_field=schema.sync_type_config.get("incremental_field") if schema.is_incremental else None, - incremental_field_type=schema.sync_type_config.get("incremental_field_type") - if schema.is_incremental - else None, - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - elif model.pipeline.source_type == ExternalDataSource.Type.SALESFORCE: - from posthog.models.integration import aget_integration_by_id - from posthog.temporal.data_imports.pipelines.salesforce import ( - salesforce_source, - ) - from posthog.temporal.data_imports.pipelines.salesforce.auth import ( - salesforce_refresh_access_token, - ) - - salesforce_integration_id = model.pipeline.job_inputs.get("salesforce_integration_id", None) - - if not salesforce_integration_id: - raise ValueError(f"Salesforce integration not found for job {model.id}") - - integration = await aget_integration_by_id(integration_id=salesforce_integration_id, team_id=inputs.team_id) - salesforce_refresh_token = integration.refresh_token - - if not salesforce_refresh_token: - raise ValueError(f"Salesforce refresh token not found for job {model.id}") - - salesforce_access_token = integration.access_token - - if not salesforce_access_token: - salesforce_access_token = salesforce_refresh_access_token(salesforce_refresh_token) - - salesforce_instance_url = integration.config.get("instance_url") - - source = salesforce_source( - instance_url=salesforce_instance_url, - access_token=salesforce_access_token, - refresh_token=salesforce_refresh_token, - endpoint=schema.name, - team_id=inputs.team_id, - job_id=inputs.run_id, - is_incremental=schema.is_incremental, - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - - elif model.pipeline.source_type == ExternalDataSource.Type.ZENDESK: - from posthog.temporal.data_imports.pipelines.zendesk import zendesk_source - - source = zendesk_source( - subdomain=model.pipeline.job_inputs.get("zendesk_subdomain"), - api_key=model.pipeline.job_inputs.get("zendesk_api_key"), - email_address=model.pipeline.job_inputs.get("zendesk_email_address"), - endpoint=schema.name, - team_id=inputs.team_id, - job_id=inputs.run_id, - is_incremental=schema.is_incremental, - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - elif model.pipeline.source_type == ExternalDataSource.Type.VITALLY: - from posthog.temporal.data_imports.pipelines.vitally import vitally_source - - source = vitally_source( - secret_token=model.pipeline.job_inputs.get("secret_token"), - region=model.pipeline.job_inputs.get("region"), - subdomain=model.pipeline.job_inputs.get("subdomain"), - endpoint=schema.name, - team_id=inputs.team_id, - job_id=inputs.run_id, - is_incremental=schema.is_incremental, - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - elif model.pipeline.source_type == ExternalDataSource.Type.BIGQUERY: - from posthog.temporal.data_imports.pipelines.sql_database_v2 import ( - bigquery_source, - ) - - dataset_id = model.pipeline.job_inputs.get("dataset_id") - project_id = model.pipeline.job_inputs.get("project_id") - 
private_key = model.pipeline.job_inputs.get("private_key") - private_key_id = model.pipeline.job_inputs.get("private_key_id") - client_email = model.pipeline.job_inputs.get("client_email") - token_uri = model.pipeline.job_inputs.get("token_uri") - - destination_table = f"{project_id}.{dataset_id}.__posthog_import_{inputs.run_id}_{str(datetime.now().timestamp()).replace('.', '')}" - try: - source = bigquery_source( - dataset_id=dataset_id, - project_id=project_id, - private_key=private_key, - private_key_id=private_key_id, - client_email=client_email, - token_uri=token_uri, - table_name=schema.name, - bq_destination_table_id=destination_table, - incremental_field=schema.sync_type_config.get("incremental_field") - if schema.is_incremental - else None, - incremental_field_type=schema.sync_type_config.get("incremental_field_type") - if schema.is_incremental - else None, - ) - - await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - except: - raise - finally: - # Delete the destination table (if it exists) after we're done with it - delete_table( - table_id=destination_table, - project_id=project_id, - private_key=private_key, - private_key_id=private_key_id, - client_email=client_email, - token_uri=token_uri, - ) - logger.info(f"Deleting bigquery temp destination table: {destination_table}") - elif model.pipeline.source_type == ExternalDataSource.Type.CHARGEBEE: - from posthog.temporal.data_imports.pipelines.chargebee import ( - chargebee_source, - ) - - source = chargebee_source( - api_key=model.pipeline.job_inputs.get("api_key"), - site_name=model.pipeline.job_inputs.get("site_name"), - endpoint=schema.name, - team_id=inputs.team_id, - job_id=inputs.run_id, - is_incremental=schema.is_incremental, - ) - - return await _run( - job_inputs=job_inputs, - source=source, - logger=logger, - inputs=inputs, - schema=schema, - reset_pipeline=reset_pipeline, - ) - else: - raise ValueError(f"Source type {model.pipeline.source_type} not supported") - - -async def _run( - job_inputs: PipelineInputs, - source: Any, - logger: FilteringBoundLogger, - inputs: ImportDataActivityInputs, - schema: ExternalDataSchema, - reset_pipeline: bool, -): - table_row_counts = await DataImportPipeline(job_inputs, source, logger, reset_pipeline, schema.is_incremental).run() - total_rows_synced = sum(table_row_counts.values()) - - await aupdate_job_count(inputs.run_id, inputs.team_id, total_rows_synced) - await aremove_reset_pipeline(inputs.source_id) diff --git a/posthog/temporal/data_imports/workflow_activities/import_data_sync.py b/posthog/temporal/data_imports/workflow_activities/import_data_sync.py index 9fc9489fabc94..ddb242483ab31 100644 --- a/posthog/temporal/data_imports/workflow_activities/import_data_sync.py +++ b/posthog/temporal/data_imports/workflow_activities/import_data_sync.py @@ -1,3 +1,5 @@ +import dataclasses +import uuid from datetime import datetime from typing import Any @@ -5,13 +7,12 @@ from temporalio import activity +from posthog.models.integration import Integration from posthog.temporal.common.heartbeat_sync import HeartbeaterSync from posthog.temporal.data_imports.pipelines.bigquery import delete_table -from posthog.temporal.data_imports.pipelines.pipeline import PipelineInputs -from posthog.temporal.data_imports.pipelines.pipeline_sync import DataImportPipelineSync +from posthog.temporal.data_imports.pipelines.pipeline_sync import DataImportPipelineSync, PipelineInputs from posthog.temporal.data_imports.util 
import is_posthog_team -from posthog.temporal.data_imports.workflow_activities.import_data import ImportDataActivityInputs from posthog.warehouse.models import ( ExternalDataJob, ExternalDataSource, @@ -22,6 +23,14 @@ from posthog.warehouse.models.ssh_tunnel import SSHTunnel +@dataclasses.dataclass +class ImportDataActivityInputs: + team_id: int + schema_id: uuid.UUID + source_id: uuid.UUID + run_id: str + + @activity.defn def import_data_activity_sync(inputs: ImportDataActivityInputs): logger = bind_temporal_worker_logger_sync(team_id=inputs.team_id) @@ -53,7 +62,60 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): endpoints = [schema.name] source = None - if model.pipeline.source_type in [ + if model.pipeline.source_type == ExternalDataSource.Type.STRIPE: + from posthog.temporal.data_imports.pipelines.stripe import stripe_source + + stripe_secret_key = model.pipeline.job_inputs.get("stripe_secret_key", None) + account_id = model.pipeline.job_inputs.get("stripe_account_id", None) + if not stripe_secret_key: + raise ValueError(f"Stripe secret key not found for job {model.id}") + + source = stripe_source( + api_key=stripe_secret_key, + account_id=account_id, + endpoint=schema.name, + team_id=inputs.team_id, + job_id=inputs.run_id, + is_incremental=schema.is_incremental, + ) + + return _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) + elif model.pipeline.source_type == ExternalDataSource.Type.HUBSPOT: + from posthog.temporal.data_imports.pipelines.hubspot import hubspot + from posthog.temporal.data_imports.pipelines.hubspot.auth import ( + hubspot_refresh_access_token, + ) + + hubspot_access_code = model.pipeline.job_inputs.get("hubspot_secret_key", None) + refresh_token = model.pipeline.job_inputs.get("hubspot_refresh_token", None) + if not refresh_token: + raise ValueError(f"Hubspot refresh token not found for job {model.id}") + + if not hubspot_access_code: + hubspot_access_code = hubspot_refresh_access_token(refresh_token) + + source = hubspot( + api_key=hubspot_access_code, + refresh_token=refresh_token, + endpoints=tuple(endpoints), + ) + + return _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) + elif model.pipeline.source_type in [ ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL, ExternalDataSource.Type.MSSQL, @@ -140,6 +202,134 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): team_id=inputs.team_id, ) + return _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) + elif model.pipeline.source_type == ExternalDataSource.Type.SNOWFLAKE: + if is_posthog_team(inputs.team_id): + from posthog.temporal.data_imports.pipelines.sql_database_v2 import ( + snowflake_source, + ) + else: + from posthog.temporal.data_imports.pipelines.sql_database import ( + snowflake_source, + ) + + account_id = model.pipeline.job_inputs.get("account_id") + user = model.pipeline.job_inputs.get("user") + password = model.pipeline.job_inputs.get("password") + database = model.pipeline.job_inputs.get("database") + warehouse = model.pipeline.job_inputs.get("warehouse") + sf_schema = model.pipeline.job_inputs.get("schema") + role = model.pipeline.job_inputs.get("role") + + source = snowflake_source( + account_id=account_id, + user=user, + password=password, + database=database, + schema=sf_schema, + 
warehouse=warehouse, + role=role, + table_names=endpoints, + incremental_field=schema.sync_type_config.get("incremental_field") if schema.is_incremental else None, + incremental_field_type=schema.sync_type_config.get("incremental_field_type") + if schema.is_incremental + else None, + ) + + return _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) + elif model.pipeline.source_type == ExternalDataSource.Type.SALESFORCE: + from posthog.temporal.data_imports.pipelines.salesforce import ( + salesforce_source, + ) + from posthog.temporal.data_imports.pipelines.salesforce.auth import ( + salesforce_refresh_access_token, + ) + + salesforce_integration_id = model.pipeline.job_inputs.get("salesforce_integration_id", None) + + if not salesforce_integration_id: + raise ValueError(f"Salesforce integration not found for job {model.id}") + + integration = Integration.objects.get(id=salesforce_integration_id, team_id=inputs.team_id) + salesforce_refresh_token = integration.refresh_token + + if not salesforce_refresh_token: + raise ValueError(f"Salesforce refresh token not found for job {model.id}") + + salesforce_access_token = integration.access_token + + if not salesforce_access_token: + salesforce_access_token = salesforce_refresh_access_token(salesforce_refresh_token) + + salesforce_instance_url = integration.config.get("instance_url") + + source = salesforce_source( + instance_url=salesforce_instance_url, + access_token=salesforce_access_token, + refresh_token=salesforce_refresh_token, + endpoint=schema.name, + team_id=inputs.team_id, + job_id=inputs.run_id, + is_incremental=schema.is_incremental, + ) + + return _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) + + elif model.pipeline.source_type == ExternalDataSource.Type.ZENDESK: + from posthog.temporal.data_imports.pipelines.zendesk import zendesk_source + + source = zendesk_source( + subdomain=model.pipeline.job_inputs.get("zendesk_subdomain"), + api_key=model.pipeline.job_inputs.get("zendesk_api_key"), + email_address=model.pipeline.job_inputs.get("zendesk_email_address"), + endpoint=schema.name, + team_id=inputs.team_id, + job_id=inputs.run_id, + is_incremental=schema.is_incremental, + ) + + return _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) + elif model.pipeline.source_type == ExternalDataSource.Type.VITALLY: + from posthog.temporal.data_imports.pipelines.vitally import vitally_source + + source = vitally_source( + secret_token=model.pipeline.job_inputs.get("secret_token"), + region=model.pipeline.job_inputs.get("region"), + subdomain=model.pipeline.job_inputs.get("subdomain"), + endpoint=schema.name, + team_id=inputs.team_id, + job_id=inputs.run_id, + is_incremental=schema.is_incremental, + ) + return _run( job_inputs=job_inputs, source=source, @@ -198,6 +388,28 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): token_uri=token_uri, ) logger.info(f"Deleting bigquery temp destination table: {destination_table}") + elif model.pipeline.source_type == ExternalDataSource.Type.CHARGEBEE: + from posthog.temporal.data_imports.pipelines.chargebee import ( + chargebee_source, + ) + + source = chargebee_source( + api_key=model.pipeline.job_inputs.get("api_key"), + site_name=model.pipeline.job_inputs.get("site_name"), + endpoint=schema.name, + team_id=inputs.team_id, + 
job_id=inputs.run_id, + is_incremental=schema.is_incremental, + ) + + return _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) else: raise ValueError(f"Source type {model.pipeline.source_type} not supported") diff --git a/posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py b/posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py index 34e27b0cd49ff..2bc916d3ec9d4 100644 --- a/posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py +++ b/posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py @@ -1,9 +1,8 @@ import dataclasses -from asgiref.sync import sync_to_async from temporalio import activity -from posthog.temporal.common.logger import bind_temporal_worker_logger +from posthog.temporal.common.logger import bind_temporal_worker_logger_sync from posthog.temporal.data_imports.pipelines.schemas import PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING from posthog.warehouse.models import sync_old_schemas_with_new_schemas, ExternalDataSource @@ -21,12 +20,12 @@ class SyncNewSchemasActivityInputs: @activity.defn -async def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> None: - logger = await bind_temporal_worker_logger(team_id=inputs.team_id) +def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> None: + logger = bind_temporal_worker_logger_sync(team_id=inputs.team_id) logger.info("Syncing new -> old schemas") - source = await sync_to_async(ExternalDataSource.objects.get)(team_id=inputs.team_id, id=inputs.source_id) + source = ExternalDataSource.objects.get(team_id=inputs.team_id, id=inputs.source_id) schemas_to_sync: list[str] = [] @@ -65,8 +64,8 @@ async def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> Non private_key=ssh_tunnel_auth_type_private_key, ) - sql_schemas = await sync_to_async(get_sql_schemas_for_source_type)( - source.source_type, host, port, database, user, password, db_schema, ssh_tunnel + sql_schemas = get_sql_schemas_for_source_type( + ExternalDataSource.Type(source.source_type), host, port, database, user, password, db_schema, ssh_tunnel ) schemas_to_sync = list(sql_schemas.keys()) @@ -82,9 +81,7 @@ async def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> Non sf_schema = source.job_inputs.get("schema") role = source.job_inputs.get("role") - sql_schemas = await sync_to_async(get_snowflake_schemas)( - account_id, database, warehouse, user, password, sf_schema, role - ) + sql_schemas = get_snowflake_schemas(account_id, database, warehouse, user, password, sf_schema, role) schemas_to_sync = list(sql_schemas.keys()) else: @@ -92,7 +89,7 @@ async def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> Non # TODO: this could cause a race condition where each schema worker creates the missing schema - schemas_created = await sync_to_async(sync_old_schemas_with_new_schemas)( + schemas_created = sync_old_schemas_with_new_schemas( schemas_to_sync, source_id=inputs.source_id, team_id=inputs.team_id, diff --git a/posthog/temporal/tests/batch_exports/test_import_data.py b/posthog/temporal/tests/batch_exports/test_import_data.py index 229f063cc9b43..93d20fbd44b23 100644 --- a/posthog/temporal/tests/batch_exports/test_import_data.py +++ b/posthog/temporal/tests/batch_exports/test_import_data.py @@ -1,9 +1,9 @@ from typing import Any from unittest import mock import pytest -from asgiref.sync import sync_to_async from posthog.models.team.team import Team -from 
posthog.temporal.data_imports.workflow_activities.import_data import ImportDataActivityInputs, import_data_activity +from posthog.temporal.data_imports import import_data_activity_sync +from posthog.temporal.data_imports.workflow_activities.import_data_sync import ImportDataActivityInputs from posthog.warehouse.models.credential import DataWarehouseCredential from posthog.warehouse.models.external_data_job import ExternalDataJob from posthog.warehouse.models.external_data_schema import ExternalDataSchema @@ -12,8 +12,8 @@ from posthog.warehouse.models.table import DataWarehouseTable -async def _setup(team: Team, job_inputs: dict[Any, Any]) -> ImportDataActivityInputs: - source = await sync_to_async(ExternalDataSource.objects.create)( +def _setup(team: Team, job_inputs: dict[Any, Any]) -> ImportDataActivityInputs: + source = ExternalDataSource.objects.create( team=team, source_id="source_id", connection_id="connection_id", @@ -21,10 +21,8 @@ async def _setup(team: Team, job_inputs: dict[Any, Any]) -> ImportDataActivityIn source_type=ExternalDataSource.Type.POSTGRES, job_inputs=job_inputs, ) - credentials = await sync_to_async(DataWarehouseCredential.objects.create)( - access_key="blah", access_secret="blah", team=team - ) - warehouse_table = await sync_to_async(DataWarehouseTable.objects.create)( + credentials = DataWarehouseCredential.objects.create(access_key="blah", access_secret="blah", team=team) + warehouse_table = DataWarehouseTable.objects.create( name="table_1", format="Parquet", team=team, @@ -34,7 +32,7 @@ async def _setup(team: Team, job_inputs: dict[Any, Any]) -> ImportDataActivityIn url_pattern="https://bucket.s3/data/*", columns={"id": {"hogql": "StringDatabaseField", "clickhouse": "Nullable(String)", "schema_valid": True}}, ) - schema = await sync_to_async(ExternalDataSchema.objects.create)( + schema = ExternalDataSchema.objects.create( team=team, name="table_1", source=source, @@ -43,7 +41,7 @@ async def _setup(team: Team, job_inputs: dict[Any, Any]) -> ImportDataActivityIn status=ExternalDataSchema.Status.COMPLETED, last_synced_at="2024-01-01", ) - job = await sync_to_async(ExternalDataJob.objects.create)( + job = ExternalDataJob.objects.create( team=team, pipeline=source, schema=schema, @@ -56,8 +54,7 @@ async def _setup(team: Team, job_inputs: dict[Any, Any]) -> ImportDataActivityIn @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_postgres_source_without_ssh_tunnel(activity_environment, team, **kwargs): +def test_postgres_source_without_ssh_tunnel(activity_environment, team, **kwargs): job_inputs = { "host": "host.com", "port": 5432, @@ -67,15 +64,15 @@ async def test_postgres_source_without_ssh_tunnel(activity_environment, team, ** "schema": "schema", } - activity_inputs = await _setup(team, job_inputs) + activity_inputs = _setup(team, job_inputs) with ( mock.patch( "posthog.temporal.data_imports.pipelines.sql_database_v2.sql_source_for_type" ) as sql_source_for_type, - mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"), + mock.patch("posthog.temporal.data_imports.workflow_activities.import_data_sync._run"), ): - await activity_environment.run(import_data_activity, activity_inputs) + activity_environment.run(import_data_activity_sync, activity_inputs) sql_source_for_type.assert_called_once_with( source_type=ExternalDataSource.Type.POSTGRES, @@ -94,8 +91,7 @@ async def test_postgres_source_without_ssh_tunnel(activity_environment, team, ** @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def 
test_postgres_source_with_ssh_tunnel_disabled(activity_environment, team, **kwargs): +def test_postgres_source_with_ssh_tunnel_disabled(activity_environment, team, **kwargs): job_inputs = { "host": "host.com", "port": "5432", @@ -108,15 +104,15 @@ async def test_postgres_source_with_ssh_tunnel_disabled(activity_environment, te "ssh_tunnel_port": "", } - activity_inputs = await _setup(team, job_inputs) + activity_inputs = _setup(team, job_inputs) with ( mock.patch( "posthog.temporal.data_imports.pipelines.sql_database_v2.sql_source_for_type" ) as sql_source_for_type, - mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"), + mock.patch("posthog.temporal.data_imports.workflow_activities.import_data_sync._run"), ): - await activity_environment.run(import_data_activity, activity_inputs) + activity_environment.run(import_data_activity_sync, activity_inputs) sql_source_for_type.assert_called_once_with( source_type=ExternalDataSource.Type.POSTGRES, @@ -136,7 +132,7 @@ async def test_postgres_source_with_ssh_tunnel_disabled(activity_environment, te @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio -async def test_postgres_source_with_ssh_tunnel_enabled(activity_environment, team, **kwargs): +def test_postgres_source_with_ssh_tunnel_enabled(activity_environment, team, **kwargs): job_inputs = { "host": "host.com", "port": "5432", @@ -152,7 +148,7 @@ async def test_postgres_source_with_ssh_tunnel_enabled(activity_environment, tea "ssh_tunnel_auth_type_password": "password", } - activity_inputs = await _setup(team, job_inputs) + activity_inputs = _setup(team, job_inputs) def mock_get_tunnel(self_class, host, port): class MockedTunnel: @@ -171,10 +167,10 @@ def __exit__(self, exc_type, exc_value, exc_traceback): mock.patch( "posthog.temporal.data_imports.pipelines.sql_database_v2.sql_source_for_type" ) as sql_source_for_type_v2, - mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"), + mock.patch("posthog.temporal.data_imports.workflow_activities.import_data_sync._run"), mock.patch.object(SSHTunnel, "get_tunnel", mock_get_tunnel), ): - await activity_environment.run(import_data_activity, activity_inputs) + activity_environment.run(import_data_activity_sync, activity_inputs) sql_source_for_type_v2.assert_called_once_with( source_type=ExternalDataSource.Type.POSTGRES, diff --git a/posthog/temporal/tests/data_imports/test_end_to_end.py b/posthog/temporal/tests/data_imports/test_end_to_end.py index 786d6fdd56596..cb29cbafa5d78 100644 --- a/posthog/temporal/tests/data_imports/test_end_to_end.py +++ b/posthog/temporal/tests/data_imports/test_end_to_end.py @@ -870,10 +870,11 @@ def get_jobs(): return list(jobs) - with mock.patch( - "posthog.temporal.data_imports.workflow_activities.create_job_model.acreate_external_data_job", - ) as acreate_external_data_job: - acreate_external_data_job.side_effect = Exception("Ruhoh!") + with mock.patch.object( + ExternalDataJob.objects, + "create", + ) as create_external_data_job: + create_external_data_job.side_effect = Exception("Ruhoh!") with pytest.raises(Exception): await _execute_run(workflow_id, inputs, stripe_customer["data"]) diff --git a/posthog/temporal/tests/external_data/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py index d554fe81fc5e1..f931c97f93943 100644 --- a/posthog/temporal/tests/external_data/test_external_data_job.py +++ b/posthog/temporal/tests/external_data/test_external_data_job.py @@ -6,9 +6,9 @@ from asgiref.sync import sync_to_async from 
django.test import override_settings +from posthog.temporal.data_imports import import_data_activity_sync from posthog.temporal.data_imports.external_data_job import ( UpdateExternalDataJobStatusInputs, - check_schedule_activity, create_source_templates, update_external_data_job_model, ) @@ -16,58 +16,55 @@ ExternalDataJobWorkflow, ExternalDataWorkflowInputs, ) +from posthog.temporal.data_imports.pipelines.pipeline_sync import DataImportPipelineSync from posthog.temporal.data_imports.workflow_activities.check_billing_limits import check_billing_limits_activity from posthog.temporal.data_imports.workflow_activities.create_job_model import ( CreateExternalDataJobModelActivityInputs, create_external_data_job_model_activity, ) -from posthog.temporal.data_imports.workflow_activities.import_data import ImportDataActivityInputs, import_data_activity +from posthog.temporal.data_imports.workflow_activities.import_data_sync import ImportDataActivityInputs from posthog.temporal.data_imports.workflow_activities.sync_new_schemas import ( SyncNewSchemasActivityInputs, sync_new_schemas_activity, ) -from posthog.warehouse.external_data_source.jobs import acreate_external_data_job from posthog.warehouse.models import ( get_latest_run_if_exists, ExternalDataJob, ExternalDataSource, ExternalDataSchema, - get_external_data_job, ) from posthog.temporal.data_imports.pipelines.schemas import ( PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING, ) from posthog.models import Team -from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline from temporalio.testing import WorkflowEnvironment from temporalio.common import RetryPolicy from temporalio.worker import UnsandboxedWorkflowRunner, Worker from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE import pytest_asyncio -import aioboto3 +import boto3 import functools from django.conf import settings from dlt.sources.helpers.rest_client.client import RESTClient from dlt.common.configuration.specs.aws_credentials import AwsCredentials -import asyncio import psycopg from posthog.warehouse.models.external_data_schema import get_all_schemas_for_source_id BUCKET_NAME = "test-pipeline" -SESSION = aioboto3.Session() +SESSION = boto3.Session() create_test_client = functools.partial(SESSION.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT) -async def delete_all_from_s3(minio_client, bucket_name: str, key_prefix: str): +def delete_all_from_s3(minio_client, bucket_name: str, key_prefix: str): """Delete all objects in bucket_name under key_prefix.""" - response = await minio_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix) + response = minio_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix) if "Contents" in response: for obj in response["Contents"]: if "Key" in obj: - await minio_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) + minio_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) @pytest.fixture @@ -76,28 +73,29 @@ def bucket_name(request) -> str: return BUCKET_NAME -@pytest_asyncio.fixture -async def minio_client(bucket_name): +@pytest.fixture +def minio_client(bucket_name): """Manage an S3 client to interact with a MinIO bucket. Yields the client after creating a bucket. Upon resuming, we delete the contents and the bucket itself. 
""" - async with create_test_client( + minio_client = create_test_client( "s3", aws_access_key_id=settings.OBJECT_STORAGE_ACCESS_KEY_ID, aws_secret_access_key=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, - ) as minio_client: - try: - await minio_client.head_bucket(Bucket=bucket_name) - except: - await minio_client.create_bucket(Bucket=bucket_name) + ) - yield minio_client + try: + minio_client.head_bucket(Bucket=bucket_name) + except: + minio_client.create_bucket(Bucket=bucket_name) - await delete_all_from_s3(minio_client, bucket_name, key_prefix="/") + yield minio_client - await minio_client.delete_bucket(Bucket=bucket_name) + delete_all_from_s3(minio_client, bucket_name, key_prefix="/") + + minio_client.delete_bucket(Bucket=bucket_name) @pytest.fixture @@ -127,8 +125,8 @@ async def postgres_connection(postgres_config, setup_postgres_test_db): await connection.close() -async def _create_schema(schema_name: str, source: ExternalDataSource, team: Team, table_id: Optional[str] = None): - return await sync_to_async(ExternalDataSchema.objects.create)( +def _create_schema(schema_name: str, source: ExternalDataSource, team: Team, table_id: Optional[str] = None): + return ExternalDataSchema.objects.create( name=schema_name, team_id=team.pk, source_id=source.pk, @@ -136,46 +134,64 @@ async def _create_schema(schema_name: str, source: ExternalDataSource, team: Tea ) +def _create_external_data_job( + external_data_source_id: uuid.UUID, + external_data_schema_id: uuid.UUID, + workflow_id: str, + workflow_run_id: str, + team_id: int, +) -> ExternalDataJob: + job = ExternalDataJob.objects.create( + team_id=team_id, + pipeline_id=external_data_source_id, + schema_id=external_data_schema_id, + status=ExternalDataJob.Status.RUNNING, + rows_synced=0, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + ) + + return job + + @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_create_external_job_activity(activity_environment, team, **kwargs): +def test_create_external_job_activity(activity_environment, team, **kwargs): """ Test that the create external job activity creates a new job """ - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Stripe", ) - test_1_schema = await _create_schema("test-1", new_source, team) + test_1_schema = _create_schema("test-1", new_source, team) inputs = CreateExternalDataJobModelActivityInputs( team_id=team.id, source_id=new_source.pk, schema_id=test_1_schema.id ) - run_id, _, __ = await activity_environment.run(create_external_data_job_model_activity, inputs) + run_id, _, __ = activity_environment.run(create_external_data_job_model_activity, inputs) runs = ExternalDataJob.objects.filter(id=run_id) - assert await sync_to_async(runs.exists)() + assert runs.exists() @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_create_external_job_activity_schemas_exist(activity_environment, team, **kwargs): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), +def test_create_external_job_activity_schemas_exist(activity_environment, team, **kwargs): + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + 
connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Stripe", ) - schema = await sync_to_async(ExternalDataSchema.objects.create)( + schema = ExternalDataSchema.objects.create( name=PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING[new_source.source_type][0], team_id=team.id, source_id=new_source.pk, @@ -183,25 +199,24 @@ async def test_create_external_job_activity_schemas_exist(activity_environment, inputs = CreateExternalDataJobModelActivityInputs(team_id=team.id, source_id=new_source.pk, schema_id=schema.id) - run_id, _, __ = await activity_environment.run(create_external_data_job_model_activity, inputs) + run_id, _, __ = activity_environment.run(create_external_data_job_model_activity, inputs) runs = ExternalDataJob.objects.filter(id=run_id) - assert await sync_to_async(runs.exists)() + assert runs.exists() @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_create_external_job_activity_update_schemas(activity_environment, team, **kwargs): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), +def test_create_external_job_activity_update_schemas(activity_environment, team, **kwargs): + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Stripe", ) - await sync_to_async(ExternalDataSchema.objects.create)( + ExternalDataSchema.objects.create( name=PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING[new_source.source_type][0], team_id=team.id, source_id=new_source.pk, @@ -210,36 +225,35 @@ async def test_create_external_job_activity_update_schemas(activity_environment, inputs = SyncNewSchemasActivityInputs(source_id=str(new_source.pk), team_id=team.id) - await activity_environment.run(sync_new_schemas_activity, inputs) + activity_environment.run(sync_new_schemas_activity, inputs) - all_schemas = await sync_to_async(get_all_schemas_for_source_id)(new_source.pk, team.id) + all_schemas = get_all_schemas_for_source_id(new_source.pk, team.id) assert len(all_schemas) == len(PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING[ExternalDataSource.Type.STRIPE]) @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_update_external_job_activity(activity_environment, team, **kwargs): +def test_update_external_job_activity(activity_environment, team, **kwargs): """ Test that the update external job activity updates the job status """ - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Stripe", ) - schema = await sync_to_async(ExternalDataSchema.objects.create)( + schema = ExternalDataSchema.objects.create( name=PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING[new_source.source_type][0], team_id=team.id, source_id=new_source.pk, should_sync=True, ) - new_job = await acreate_external_data_job( + new_job = _create_external_data_job( team_id=team.id, external_data_source_id=new_source.pk, workflow_id=activity_environment.info.workflow_id, @@ -257,34 +271,33 @@ async def test_update_external_job_activity(activity_environment, team, **kwargs team_id=team.id, ) - await 
activity_environment.run(update_external_data_job_model, inputs) - await sync_to_async(new_job.refresh_from_db)() - await sync_to_async(schema.refresh_from_db)() + activity_environment.run(update_external_data_job_model, inputs) + new_job.refresh_from_db() + schema.refresh_from_db() assert new_job.status == ExternalDataJob.Status.COMPLETED assert schema.status == ExternalDataJob.Status.COMPLETED @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_update_external_job_activity_with_retryable_error(activity_environment, team, **kwargs): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), +def test_update_external_job_activity_with_retryable_error(activity_environment, team, **kwargs): + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Stripe", ) - schema = await sync_to_async(ExternalDataSchema.objects.create)( + schema = ExternalDataSchema.objects.create( name=PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING[new_source.source_type][0], team_id=team.id, source_id=new_source.pk, should_sync=True, ) - new_job = await acreate_external_data_job( + new_job = _create_external_data_job( team_id=team.id, external_data_source_id=new_source.pk, workflow_id=activity_environment.info.workflow_id, @@ -302,9 +315,9 @@ async def test_update_external_job_activity_with_retryable_error(activity_enviro team_id=team.id, ) - await activity_environment.run(update_external_data_job_model, inputs) - await sync_to_async(new_job.refresh_from_db)() - await sync_to_async(schema.refresh_from_db)() + activity_environment.run(update_external_data_job_model, inputs) + new_job.refresh_from_db() + schema.refresh_from_db() assert new_job.status == ExternalDataJob.Status.COMPLETED assert schema.status == ExternalDataJob.Status.COMPLETED @@ -312,25 +325,24 @@ async def test_update_external_job_activity_with_retryable_error(activity_enviro @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_update_external_job_activity_with_non_retryable_error(activity_environment, team, **kwargs): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), +def test_update_external_job_activity_with_non_retryable_error(activity_environment, team, **kwargs): + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Postgres", ) - schema = await sync_to_async(ExternalDataSchema.objects.create)( + schema = ExternalDataSchema.objects.create( name="test_123", team_id=team.id, source_id=new_source.pk, should_sync=True, ) - new_job = await acreate_external_data_job( + new_job = _create_external_data_job( team_id=team.id, external_data_source_id=new_source.pk, workflow_id=activity_environment.info.workflow_id, @@ -348,10 +360,10 @@ async def test_update_external_job_activity_with_non_retryable_error(activity_en team_id=team.id, ) with mock.patch("posthog.warehouse.models.external_data_schema.external_data_workflow_exists", return_value=False): - await activity_environment.run(update_external_data_job_model, inputs) + activity_environment.run(update_external_data_job_model, inputs) - await sync_to_async(new_job.refresh_from_db)() - await 
sync_to_async(schema.refresh_from_db)() + new_job.refresh_from_db() + schema.refresh_from_db() assert new_job.status == ExternalDataJob.Status.COMPLETED assert schema.status == ExternalDataJob.Status.COMPLETED @@ -359,22 +371,21 @@ async def test_update_external_job_activity_with_non_retryable_error(activity_en @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_run_stripe_job(activity_environment, team, minio_client, **kwargs): - async def setup_job_1(): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), +def test_run_stripe_job(activity_environment, team, minio_client, **kwargs): + def setup_job_1(): + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Stripe", job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"}, ) - customer_schema = await _create_schema("Customer", new_source, team) + customer_schema = _create_schema("Customer", new_source, team) - new_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)( + new_job: ExternalDataJob = ExternalDataJob.objects.create( team_id=team.id, pipeline_id=new_source.pk, status=ExternalDataJob.Status.RUNNING, @@ -382,7 +393,7 @@ async def setup_job_1(): schema=customer_schema, ) - new_job = await get_external_data_job(new_job.id) + new_job = ExternalDataJob.objects.get(id=new_job.id) inputs = ImportDataActivityInputs( team_id=team.id, @@ -393,20 +404,20 @@ async def setup_job_1(): return new_job, inputs - async def setup_job_2(): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), + def setup_job_2(): + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Stripe", job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"}, ) - charge_schema = await _create_schema("Charge", new_source, team) + charge_schema = _create_schema("Charge", new_source, team) - new_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)( + new_job: ExternalDataJob = ExternalDataJob.objects.create( team_id=team.id, pipeline_id=new_source.pk, status=ExternalDataJob.Status.RUNNING, @@ -414,7 +425,7 @@ async def setup_job_2(): schema=charge_schema, ) - new_job = await get_external_data_job(new_job.id) + new_job = ExternalDataJob.objects.get(id=new_job.id) inputs = ImportDataActivityInputs( team_id=team.id, @@ -425,8 +436,8 @@ async def setup_job_2(): return new_job, inputs - job_1, job_1_inputs = await setup_job_1() - job_2, job_2_inputs = await setup_job_2() + job_1, job_1_inputs = setup_job_1() + job_2, job_2_inputs = setup_job_2() def mock_customers_paginate( class_self, @@ -504,14 +515,10 @@ def mock_to_object_store_rs_credentials(class_self): mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), ): - await asyncio.gather( - activity_environment.run(import_data_activity, job_1_inputs), - ) + activity_environment.run(import_data_activity_sync, job_1_inputs) - folder_path = await sync_to_async(job_1.folder_path)() - job_1_customer_objects = 
await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/" - ) + folder_path = job_1.folder_path() + job_1_customer_objects = minio_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/") assert len(job_1_customer_objects["Contents"]) == 2 @@ -531,33 +538,28 @@ def mock_to_object_store_rs_credentials(class_self): mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), ): - await asyncio.gather( - activity_environment.run(import_data_activity, job_2_inputs), - ) + activity_environment.run(import_data_activity_sync, job_2_inputs) - job_2_charge_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{job_2.folder_path()}/charge/" - ) + job_2_charge_objects = minio_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=f"{job_2.folder_path()}/charge/") assert len(job_2_charge_objects["Contents"]) == 2 @pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_run_stripe_job_row_count_update(activity_environment, team, minio_client, **kwargs): - async def setup_job_1(): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), +def test_run_stripe_job_row_count_update(activity_environment, team, minio_client, **kwargs): + def setup_job_1(): + new_source = ExternalDataSource.objects.create( + source_id=str(uuid.uuid4()), + connection_id=str(uuid.uuid4()), + destination_id=str(uuid.uuid4()), team=team, status="running", source_type="Stripe", job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"}, ) - customer_schema = await _create_schema("Customer", new_source, team) + customer_schema = _create_schema("Customer", new_source, team) - new_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)( + new_job: ExternalDataJob = ExternalDataJob.objects.create( team_id=team.id, pipeline_id=new_source.pk, status=ExternalDataJob.Status.RUNNING, @@ -565,9 +567,9 @@ async def setup_job_1(): schema=customer_schema, ) - new_job = await sync_to_async( - ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").prefetch_related("schema").get - )() + new_job = ( + ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").prefetch_related("schema").get() + ) inputs = ImportDataActivityInputs( team_id=team.id, @@ -578,7 +580,7 @@ async def setup_job_1(): return new_job, inputs - job_1, job_1_inputs = await setup_job_1() + job_1, job_1_inputs = setup_job_1() def mock_customers_paginate( class_self, @@ -636,18 +638,14 @@ def mock_to_object_store_rs_credentials(class_self): mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), ): - await asyncio.gather( - activity_environment.run(import_data_activity, job_1_inputs), - ) + activity_environment.run(import_data_activity_sync, job_1_inputs) - folder_path = await sync_to_async(job_1.folder_path)() - job_1_customer_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/" - ) + folder_path = job_1.folder_path() + job_1_customer_objects = minio_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/") assert len(job_1_customer_objects["Contents"]) == 2 - await 
sync_to_async(job_1.refresh_from_db)() + job_1.refresh_from_db() assert job_1.rows_synced == 1 @@ -680,24 +678,30 @@ async def test_external_data_job_workflow_with_schema(team, **kwargs): external_data_schema_id=schema.id, ) - async def mock_async_func(inputs): + def mock_func(inputs): return {} with ( mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns", return_value={"id": "string"}), - mock.patch.object(DataImportPipeline, "run", mock_async_func), + mock.patch.object(DataImportPipelineSync, "run", mock_func), ): - with override_settings(AIRBYTE_BUCKET_KEY="test-key", AIRBYTE_BUCKET_SECRET="test-secret"): + with override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + AIRBYTE_BUCKET_REGION="us-east-1", + AIRBYTE_BUCKET_DOMAIN="objectstorage:19000", + BUCKET_NAME=BUCKET_NAME, + ): async with await WorkflowEnvironment.start_time_skipping() as activity_environment: async with Worker( activity_environment.client, task_queue=DATA_WAREHOUSE_TASK_QUEUE, workflows=[ExternalDataJobWorkflow], activities=[ - check_schedule_activity, create_external_data_job_model_activity, update_external_data_job_model, - import_data_activity, + import_data_activity_sync, create_source_templates, check_billing_limits_activity, sync_new_schemas_activity, @@ -752,7 +756,7 @@ async def setup_job_1(): }, ) - posthog_test_schema = await _create_schema("posthog_test", new_source, team) + posthog_test_schema = await sync_to_async(_create_schema)("posthog_test", new_source, team) new_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)( team_id=team.id, @@ -806,127 +810,8 @@ def mock_to_object_store_rs_credentials(class_self): mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), ): - await asyncio.gather( - activity_environment.run(import_data_activity, job_1_inputs), - ) + await sync_to_async(activity_environment.run)(import_data_activity_sync, job_1_inputs) folder_path = await sync_to_async(job_1.folder_path)() - job_1_team_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{folder_path}/posthog_test/" - ) + job_1_team_objects = minio_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=f"{folder_path}/posthog_test/") assert len(job_1_team_objects["Contents"]) == 2 - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_check_schedule_activity_with_schema_id(activity_environment, team, **kwargs): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), - team=team, - status="running", - source_type="Stripe", - job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"}, - ) - - test_1_schema = await _create_schema("test-1", new_source, team) - - should_exit = await activity_environment.run( - check_schedule_activity, - ExternalDataWorkflowInputs( - team_id=team.id, - external_data_source_id=new_source.id, - external_data_schema_id=test_1_schema.id, - ), - ) - - assert should_exit is False - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_check_schedule_activity_with_missing_schema_id_but_with_schedule(activity_environment, team, **kwargs): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - 
source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), - team=team, - status="running", - source_type="Stripe", - job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"}, - ) - - await sync_to_async(ExternalDataSchema.objects.create)( - name="test-1", - team_id=team.id, - source_id=new_source.pk, - should_sync=True, - ) - - with ( - mock.patch( - "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=True - ), - mock.patch( - "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True - ), - mock.patch( - "posthog.temporal.data_imports.external_data_job.a_trigger_external_data_workflow" - ) as mock_a_trigger_external_data_workflow, - ): - should_exit = await activity_environment.run( - check_schedule_activity, - ExternalDataWorkflowInputs( - team_id=team.id, - external_data_source_id=new_source.id, - external_data_schema_id=None, - ), - ) - - assert should_exit is True - assert mock_a_trigger_external_data_workflow.call_count == 1 - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_check_schedule_activity_with_missing_schema_id_and_no_schedule(activity_environment, team, **kwargs): - new_source = await sync_to_async(ExternalDataSource.objects.create)( - source_id=uuid.uuid4(), - connection_id=uuid.uuid4(), - destination_id=uuid.uuid4(), - team=team, - status="running", - source_type="Stripe", - job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"}, - ) - - await sync_to_async(ExternalDataSchema.objects.create)( - name="test-1", - team_id=team.id, - source_id=new_source.pk, - should_sync=True, - ) - - with ( - mock.patch( - "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=False - ), - mock.patch( - "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True - ), - mock.patch( - "posthog.temporal.data_imports.external_data_job.a_sync_external_data_job_workflow" - ) as mock_a_sync_external_data_job_workflow, - ): - should_exit = await activity_environment.run( - check_schedule_activity, - ExternalDataWorkflowInputs( - team_id=team.id, - external_data_source_id=new_source.id, - external_data_schema_id=None, - ), - ) - - assert should_exit is True - assert mock_a_sync_external_data_job_workflow.call_count == 1 diff --git a/posthog/warehouse/data_load/source_templates.py b/posthog/warehouse/data_load/source_templates.py index 5a7d515bc8536..6b993e00d3d97 100644 --- a/posthog/warehouse/data_load/source_templates.py +++ b/posthog/warehouse/data_load/source_templates.py @@ -1,11 +1,9 @@ -from posthog.temporal.common.logger import bind_temporal_worker_logger -from posthog.warehouse.models.external_data_job import ExternalDataJob, get_external_data_job, get_latest_run_if_exists +from posthog.temporal.common.logger import bind_temporal_worker_logger_sync +from posthog.warehouse.models.external_data_job import ExternalDataJob from posthog.warehouse.models.external_data_source import ExternalDataSource from posthog.warehouse.models.join import DataWarehouseJoin -from posthog.warehouse.util import database_sync_to_async -@database_sync_to_async def database_operations(team_id: int, table_prefix: str) -> None: customer_join_exists = ( DataWarehouseJoin.objects.filter( @@ -54,11 +52,18 @@ def database_operations(team_id: int, table_prefix: str) -> None: ) -async def create_warehouse_templates_for_source(team_id: int, run_id: str) -> 
None: - logger = await bind_temporal_worker_logger(team_id=team_id) +def create_warehouse_templates_for_source(team_id: int, run_id: str) -> None: + logger = bind_temporal_worker_logger_sync(team_id=team_id) - job: ExternalDataJob = await get_external_data_job(job_id=run_id) - last_successful_job: ExternalDataJob | None = await get_latest_run_if_exists(job.team_id, job.pipeline_id) + job: ExternalDataJob = ExternalDataJob.objects.get(pk=run_id) + last_successful_job: ExternalDataJob | None = ( + ExternalDataJob.objects.filter( + team_id=job.team_id, pipeline_id=job.pipeline_id, status=ExternalDataJob.Status.COMPLETED + ) + .prefetch_related("pipeline") + .order_by("-created_at") + .first() + ) source: ExternalDataSource.Type = job.pipeline.source_type @@ -71,7 +76,7 @@ async def create_warehouse_templates_for_source(team_id: int, run_id: str) -> No table_prefix = job.pipeline.prefix or "" - await database_operations(team_id, table_prefix) + database_operations(team_id, table_prefix) logger.info( f"Created warehouse template for job {run_id}", diff --git a/posthog/warehouse/external_data_source/jobs.py b/posthog/warehouse/external_data_source/jobs.py index d21210f2ec097..b7d37eb746270 100644 --- a/posthog/warehouse/external_data_source/jobs.py +++ b/posthog/warehouse/external_data_source/jobs.py @@ -1,4 +1,3 @@ -from uuid import UUID from posthog.warehouse.util import database_sync_to_async from posthog.warehouse.models.external_data_job import ExternalDataJob from posthog.warehouse.models.external_data_schema import ExternalDataSchema @@ -9,27 +8,6 @@ def get_external_data_source(team_id: str, external_data_source_id: str) -> Exte return ExternalDataSource.objects.get(team_id=team_id, id=external_data_source_id) -@database_sync_to_async -def acreate_external_data_job( - external_data_source_id: UUID, - external_data_schema_id: UUID, - workflow_id: str, - workflow_run_id: str, - team_id: int, -) -> ExternalDataJob: - job = ExternalDataJob.objects.create( - team_id=team_id, - pipeline_id=external_data_source_id, - schema_id=external_data_schema_id, - status=ExternalDataJob.Status.RUNNING, - rows_synced=0, - workflow_id=workflow_id, - workflow_run_id=workflow_run_id, - ) - - return job - - @database_sync_to_async def aget_running_job_for_schema(schema_id: str) -> ExternalDataJob | None: return ( @@ -39,8 +17,7 @@ def aget_running_job_for_schema(schema_id: str) -> ExternalDataJob | None: ) -@database_sync_to_async -def aupdate_external_job_status( +def update_external_job_status( job_id: str, team_id: int, status: ExternalDataJob.Status, latest_error: str | None ) -> ExternalDataJob: model = ExternalDataJob.objects.get(id=job_id, team_id=team_id) diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py index c90a5c2e472bb..3bcbc6c658f7f 100644 --- a/posthog/warehouse/models/external_data_schema.py +++ b/posthog/warehouse/models/external_data_schema.py @@ -99,8 +99,7 @@ def aget_schema_by_id(schema_id: str, team_id: int) -> ExternalDataSchema | None ) -@database_sync_to_async -def aupdate_should_sync(schema_id: str, team_id: int, should_sync: bool) -> ExternalDataSchema | None: +def update_should_sync(schema_id: str, team_id: int, should_sync: bool) -> ExternalDataSchema | None: schema = ExternalDataSchema.objects.get(id=schema_id, team_id=team_id) schema.should_sync = should_sync schema.save() @@ -119,15 +118,6 @@ def aupdate_should_sync(schema_id: str, team_id: int, should_sync: bool) -> Exte return schema -@database_sync_to_async -def 
get_active_schemas_for_source_id(source_id: uuid.UUID, team_id: int):
-    return list(
-        ExternalDataSchema.objects.exclude(deleted=True)
-        .filter(team_id=team_id, source_id=source_id, should_sync=True)
-        .all()
-    )
-
-
 def get_all_schemas_for_source_id(source_id: uuid.UUID, team_id: int):
     return list(ExternalDataSchema.objects.exclude(deleted=True).filter(team_id=team_id, source_id=source_id).all())
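
Reviewer note, not part of the patch: a minimal sketch of the new call shape, mirroring the converted tests above. The imports, ImportDataActivityInputs, and import_data_activity_sync come from this diff; the run_import_in_tests wrapper and its parameters are hypothetical.

import uuid

from posthog.temporal.data_imports import import_data_activity_sync
from posthog.temporal.data_imports.workflow_activities.import_data_sync import ImportDataActivityInputs


def run_import_in_tests(activity_environment, team_id: int, schema_id: uuid.UUID, source_id: uuid.UUID, run_id: str):
    # Build the inputs dataclass exactly as the converted tests do.
    inputs = ImportDataActivityInputs(
        team_id=team_id,
        schema_id=schema_id,
        source_id=source_id,
        run_id=run_id,
    )
    # The activity is now a plain function, so no await is needed; the
    # ActivityEnvironment returns its result directly.
    return activity_environment.run(import_data_activity_sync, inputs)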
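
A related assumption that is not visible in this patch: because the converted activities are plain synchronous functions, the Temporal worker that registers them is expected to provide an activity_executor. A rough sketch under that assumption, reusing names that do appear in the diff (ExternalDataJobWorkflow, DATA_WAREHOUSE_TASK_QUEUE, UnsandboxedWorkflowRunner, sync_new_schemas_activity); the run_worker helper itself is hypothetical.

import concurrent.futures

from temporalio.client import Client
from temporalio.worker import UnsandboxedWorkflowRunner, Worker

from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE
from posthog.temporal.data_imports import import_data_activity_sync
from posthog.temporal.data_imports.external_data_job import ExternalDataJobWorkflow
from posthog.temporal.data_imports.workflow_activities.sync_new_schemas import sync_new_schemas_activity


async def run_worker(client: Client) -> None:
    # Sketch only: non-async activities need a synchronous executor,
    # otherwise the temporalio Worker refuses to start.
    with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
        worker = Worker(
            client,
            task_queue=DATA_WAREHOUSE_TASK_QUEUE,
            workflows=[ExternalDataJobWorkflow],
            activities=[import_data_activity_sync, sync_new_schemas_activity],
            activity_executor=executor,
            workflow_runner=UnsandboxedWorkflowRunner(),
        )
        await worker.run()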