diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py
index caaa628e7eefc..cbb80f4978ca8 100644
--- a/posthog/temporal/data_imports/external_data_job.py
+++ b/posthog/temporal/data_imports/external_data_job.py
@@ -15,7 +15,6 @@
 from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline, PipelineInputs
 from posthog.warehouse.external_data_source.jobs import (
     create_external_data_job,
-    get_external_data_job,
     update_external_job_status,
 )
 from posthog.warehouse.models import (
@@ -23,6 +22,7 @@
     get_active_schemas_for_source_id,
     sync_old_schemas_with_new_schemas,
     ExternalDataSource,
+    get_external_data_job,
 )
 from posthog.warehouse.models.external_data_schema import get_postgres_schemas
 from posthog.temporal.common.logger import bind_temporal_worker_logger
@@ -116,7 +116,7 @@ class ValidateSchemaInputs:

 @activity.defn
 async def validate_schema_activity(inputs: ValidateSchemaInputs) -> None:
-    await sync_to_async(validate_schema_and_update_table)(  # type: ignore
+    await validate_schema_and_update_table(
         run_id=inputs.run_id,
         team_id=inputs.team_id,
         schemas=inputs.schemas,
@@ -144,9 +144,8 @@ class ExternalDataJobInputs:

 @activity.defn
 async def run_external_data_job(inputs: ExternalDataJobInputs) -> None:
-    model: ExternalDataJob = await sync_to_async(get_external_data_job)(  # type: ignore
-        team_id=inputs.team_id,
-        run_id=inputs.run_id,
+    model: ExternalDataJob = await get_external_data_job(
+        job_id=inputs.run_id,
     )

     logger = await bind_temporal_worker_logger(team_id=inputs.team_id)
diff --git a/posthog/temporal/tests/external_data/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py
index a4574502ff9d4..835ee8898a7b6 100644
--- a/posthog/temporal/tests/external_data/test_external_data_job.py
+++ b/posthog/temporal/tests/external_data/test_external_data_job.py
@@ -26,6 +26,7 @@
     ExternalDataJob,
     ExternalDataSource,
     ExternalDataSchema,
+    DataWarehouseCredential,
 )

 from posthog.temporal.data_imports.pipelines.schemas import (
@@ -361,6 +362,122 @@ async def test_validate_schema_and_update_table_activity(activity_environment, t
         )

         assert mock_get_columns.call_count == 10
+        assert (
+            await sync_to_async(DataWarehouseTable.objects.filter(external_data_source_id=new_source.pk).count)() == 5  # type: ignore
+        )
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.asyncio
+async def test_validate_schema_and_update_table_activity_with_existing(activity_environment, team, **kwargs):
+    new_source = await sync_to_async(ExternalDataSource.objects.create)(
+        source_id=uuid.uuid4(),
+        connection_id=uuid.uuid4(),
+        destination_id=uuid.uuid4(),
+        team=team,
+        status="running",
+        source_type="Stripe",
+        job_inputs={"stripe_secret_key": "test-key"},
+    )  # type: ignore
+
+    old_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)(  # type: ignore
+        team_id=team.id,
+        pipeline_id=new_source.pk,
+        status=ExternalDataJob.Status.COMPLETED,
+        rows_synced=0,
+    )
+
+    old_credential = await sync_to_async(DataWarehouseCredential.objects.create)(  # type: ignore
+        team=team,
+        access_key=settings.OBJECT_STORAGE_ACCESS_KEY_ID,
+        access_secret=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
+    )
+
+    url_pattern = await sync_to_async(old_job.url_pattern_by_schema)("test-1")  # type: ignore
+
+    await sync_to_async(DataWarehouseTable.objects.create)(  # type: ignore
+        credential=old_credential,
+        name="stripe_test-1",
+        format="Parquet",
+        url_pattern=url_pattern,
+        team_id=team.pk,
+        external_data_source_id=new_source.pk,
+    )
+
+    new_job = await sync_to_async(ExternalDataJob.objects.create)(  # type: ignore
+        team_id=team.id,
+        pipeline_id=new_source.pk,
+        status=ExternalDataJob.Status.RUNNING,
+        rows_synced=0,
+    )
+
+    with mock.patch(
+        "posthog.warehouse.models.table.DataWarehouseTable.get_columns"
+    ) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS):
+        mock_get_columns.return_value = {"id": "string"}
+        await activity_environment.run(
+            validate_schema_activity,
+            ValidateSchemaInputs(
+                run_id=new_job.pk, team_id=team.id, schemas=["test-1", "test-2", "test-3", "test-4", "test-5"]
+            ),
+        )
+
+        assert mock_get_columns.call_count == 10
+        assert (
+            await sync_to_async(DataWarehouseTable.objects.filter(external_data_source_id=new_source.pk).count)() == 5  # type: ignore
+        )
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.asyncio
+async def test_validate_schema_and_update_table_activity_half_run(activity_environment, team, **kwargs):
+    new_source = await sync_to_async(ExternalDataSource.objects.create)(  # type: ignore
+        source_id=uuid.uuid4(),
+        connection_id=uuid.uuid4(),
+        destination_id=uuid.uuid4(),
+        team=team,
+        status="running",
+        source_type="Stripe",
+        job_inputs={"stripe_secret_key": "test-key"},
+    )
+
+    new_job = await sync_to_async(ExternalDataJob.objects.create)(  # type: ignore
+        team_id=team.id,
+        pipeline_id=new_source.pk,
+        status=ExternalDataJob.Status.RUNNING,
+        rows_synced=0,
+    )
+
+    with mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, mock.patch(
+        "posthog.warehouse.data_load.validate_schema.validate_schema",
+    ) as mock_validate, override_settings(**AWS_BUCKET_MOCK_SETTINGS):
+        mock_get_columns.return_value = {"id": "string"}
+        credential = await sync_to_async(DataWarehouseCredential.objects.create)(  # type: ignore
+            team=team,
+            access_key=settings.OBJECT_STORAGE_ACCESS_KEY_ID,
+            access_secret=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
+        )
+
+        mock_validate.side_effect = [
+            Exception,
+            {
+                "credential": credential,
+                "format": "Parquet",
+                "name": "test_schema",
+                "url_pattern": "test_url_pattern",
+                "team_id": team.pk,
+            },
+        ]
+
+        await activity_environment.run(
+            validate_schema_activity,
+            ValidateSchemaInputs(run_id=new_job.pk, team_id=team.id, schemas=["broken_schema", "test_schema"]),
+        )
+
+        assert mock_get_columns.call_count == 1
+        assert (
+            await sync_to_async(DataWarehouseTable.objects.filter(external_data_source_id=new_source.pk).count)() == 1  # type: ignore
+        )


 @pytest.mark.django_db(transaction=True)
@@ -446,7 +563,7 @@ async def test_external_data_job_workflow_blank(team, **kwargs):
                 retry_policy=RetryPolicy(maximum_attempts=1),
             )

-    run = await sync_to_async(get_latest_run_if_exists)(team_id=team.pk, pipeline_id=new_source.pk)  # type: ignore
+    run = await get_latest_run_if_exists(team_id=team.pk, pipeline_id=new_source.pk)

     assert run is not None
     assert run.status == ExternalDataJob.Status.COMPLETED
@@ -509,7 +626,7 @@ async def mock_async_func(inputs):
                 retry_policy=RetryPolicy(maximum_attempts=1),
             )

-    run = await sync_to_async(get_latest_run_if_exists)(team_id=team.pk, pipeline_id=new_source.pk)  # type: ignore
+    run = await get_latest_run_if_exists(team_id=team.pk, pipeline_id=new_source.pk)

     assert run is not None
     assert run.status == ExternalDataJob.Status.COMPLETED
diff --git a/posthog/warehouse/data_load/test/__init__.py b/posthog/warehouse/data_load/test/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/posthog/warehouse/data_load/test/test_validate_schema.py b/posthog/warehouse/data_load/test/test_validate_schema.py
deleted file mode 100644
index 0f139d899de33..0000000000000
--- a/posthog/warehouse/data_load/test/test_validate_schema.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from posthog.test.base import BaseTest, ClickhouseTestMixin
-from posthog.warehouse.models import ExternalDataJob, ExternalDataSource, DataWarehouseCredential, DataWarehouseTable
-from posthog.warehouse.data_load.validate_schema import validate_schema_and_update_table
-import uuid
-from unittest.mock import patch
-
-
-class TestValidateSchema(ClickhouseTestMixin, BaseTest):
-    def _create_external_data_source(self) -> ExternalDataSource:
-        return ExternalDataSource.objects.create(
-            source_id=str(uuid.uuid4()),
-            connection_id=str(uuid.uuid4()),
-            destination_id=str(uuid.uuid4()),
-            team=self.team,
-            status="Running",
-            source_type="Stripe",
-        )
-
-    def _create_external_data_job(self, source_id, status) -> ExternalDataJob:
-        return ExternalDataJob.objects.create(
-            pipeline_id=source_id,
-            status=status,
-            team_id=self.team.pk,
-        )
-
-    def _create_datawarehouse_credential(self):
-        return DataWarehouseCredential.objects.create(
-            team=self.team,
-            access_key="test-key",
-            access_secret="test-secret",
-        )
-
-    @patch(
-        "posthog.warehouse.models.table.DataWarehouseTable.get_columns",
-        return_value={"id": "String", "a_column": "String"},
-    )
-    def test_validate_schema(self, mock_get_columns):
-        pass
-
-    @patch(
-        "posthog.warehouse.models.table.DataWarehouseTable.get_columns",
-        return_value={"id": "String", "a_column": "String"},
-    )
-    def test_validate_schema_and_update_table_no_existing_table(self, mock_get_columns):
-        source = self._create_external_data_source()
-        job = self._create_external_data_job(source.pk, "Running")
-
-        with self.settings(AIRBYTE_BUCKET_KEY="key", AIRBYTE_BUCKET_SECRET="secret"):
-            validate_schema_and_update_table(
-                run_id=job.pk,
-                team_id=self.team.pk,
-                schemas=["test_schema"],
-            )
-
-        self.assertEqual(DataWarehouseTable.objects.filter(external_data_source_id=source.pk).count(), 1)
-
-    @patch(
-        "posthog.warehouse.models.table.DataWarehouseTable.get_columns",
-        return_value={"id": "String", "a_column": "String"},
-    )
-    def test_validate_schema_and_update_table_existing_table(self, mock_get_columns):
-        source = self._create_external_data_source()
-        old_job = self._create_external_data_job(source.pk, "Completed")
-        job = self._create_external_data_job(source.pk, "Running")
-        DataWarehouseTable.objects.create(
-            credential=self._create_datawarehouse_credential(),
-            name="test_table",
-            format="Parquet",
-            url_pattern=old_job.url_pattern_by_schema("test_schema"),
-            team_id=self.team.pk,
-            external_data_source_id=source.pk,
-        )
-
-        with self.settings(AIRBYTE_BUCKET_KEY="key", AIRBYTE_BUCKET_SECRET="secret"):
-            validate_schema_and_update_table(
-                run_id=job.pk,
-                team_id=self.team.pk,
-                schemas=["test_schema"],
-            )
-
-        tables = DataWarehouseTable.objects.filter(external_data_source_id=source.pk).all()
-        self.assertEqual(len(tables), 1)
-        # url got updated
-        self.assertEqual(tables[0].url_pattern, job.url_pattern_by_schema("test_schema"))
-
-    @patch(
-        "posthog.warehouse.data_load.validate_schema.validate_schema",
-    )
-    @patch(
-        "posthog.warehouse.models.table.DataWarehouseTable.get_columns",
-        return_value={"id": "String", "a_column": "String"},
-    )
-    def test_validate_schema_and_update_table_half_broken(self, mock_get_columns, mock_validate):
-        credential = self._create_datawarehouse_credential()
-        mock_validate.side_effect = [
-            Exception,
-            {
-                "credential": credential,
-                "format": "Parquet",
-                "name": "test_schema",
-                "url_pattern": "test_url_pattern",
-                "team_id": self.team.pk,
-            },
-        ]
-
-        source = self._create_external_data_source()
-        job = self._create_external_data_job(source.pk, "Running")
-
-        with self.settings(AIRBYTE_BUCKET_KEY="test-key", AIRBYTE_BUCKET_SECRET="test-secret"):
-            validate_schema_and_update_table(
-                run_id=job.pk,
-                team_id=self.team.pk,
-                schemas=["broken_schema", "test_schema"],
-            )
-
-        self.assertEqual(DataWarehouseTable.objects.filter(external_data_source_id=source.pk).count(), 1)
diff --git a/posthog/warehouse/data_load/validate_schema.py b/posthog/warehouse/data_load/validate_schema.py
index 8d4ed729b3713..1c9d43d0f7c9d 100644
--- a/posthog/warehouse/data_load/validate_schema.py
+++ b/posthog/warehouse/data_load/validate_schema.py
@@ -6,16 +6,22 @@
     get_table_by_url_pattern_and_source,
     DataWarehouseTable,
     DataWarehouseCredential,
-    get_schema_if_exists,
+    aget_schema_if_exists,
+    get_external_data_job,
+    asave_datawarehousetable,
+    acreate_datawarehousetable,
+    asave_external_data_schema,
 )
 from posthog.warehouse.models.external_data_job import ExternalDataJob
 from posthog.temporal.common.logger import bind_temporal_worker_logger
-from asgiref.sync import async_to_sync
 from clickhouse_driver.errors import ServerException
+from asgiref.sync import sync_to_async
 from typing import Dict


-def validate_schema(credential: DataWarehouseCredential, table_name: str, new_url_pattern: str, team_id: int) -> Dict:
+async def validate_schema(
+    credential: DataWarehouseCredential, table_name: str, new_url_pattern: str, team_id: int
+) -> Dict:
     params = {
         "credential": credential,
         "name": table_name,
@@ -25,7 +31,7 @@ def validate_schema(credential: DataWarehouseCredential, table_name: str, new_ur
     }

     table = DataWarehouseTable(**params)
-    table.columns = table.get_columns(safe_expose_ch_error=False)
+    table.columns = await sync_to_async(table.get_columns)(safe_expose_ch_error=False)  # type: ignore

     return {
         "credential": credential,
@@ -36,8 +42,7 @@ def validate_schema(credential: DataWarehouseCredential, table_name: str, new_ur
     }


-# TODO: make async
-def validate_schema_and_update_table(run_id: str, team_id: int, schemas: list[str]) -> None:
+async def validate_schema_and_update_table(run_id: str, team_id: int, schemas: list[str]) -> None:
     """

     Validates the schemas of data that has been synced by external data job.
@@ -49,12 +54,12 @@ def validate_schema_and_update_table(run_id: str, team_id: int, schemas: list[st
         schemas: The list of schemas that have been synced by the external data job
     """

-    logger = async_to_sync(bind_temporal_worker_logger)(team_id=team_id)
+    logger = await bind_temporal_worker_logger(team_id=team_id)

-    job = ExternalDataJob.objects.get(pk=run_id)
-    last_successful_job = get_latest_run_if_exists(job.team_id, job.pipeline_id)
+    job: ExternalDataJob = await get_external_data_job(job_id=run_id)
+    last_successful_job: ExternalDataJob | None = await get_latest_run_if_exists(job.team_id, job.pipeline_id)

-    credential = get_or_create_datawarehouse_credential(
+    credential: DataWarehouseCredential = await get_or_create_datawarehouse_credential(
         team_id=job.team_id,
         access_key=settings.AIRBYTE_BUCKET_KEY,
         access_secret=settings.AIRBYTE_BUCKET_SECRET,
@@ -66,7 +71,7 @@ def validate_schema_and_update_table(run_id: str, team_id: int, schemas: list[st

         # Check
         try:
-            data = validate_schema(
+            data = await validate_schema(
                 credential=credential, table_name=table_name, new_url_pattern=new_url_pattern, team_id=team_id
             )
         except ServerException as err:
@@ -89,28 +94,31 @@ def validate_schema_and_update_table(run_id: str, team_id: int, schemas: list[st
         if last_successful_job:
             old_url_pattern = last_successful_job.url_pattern_by_schema(_schema_name)
             try:
-                table_created = get_table_by_url_pattern_and_source(
+                table_created = await get_table_by_url_pattern_and_source(
                     team_id=job.team_id, source_id=job.pipeline.id, url_pattern=old_url_pattern
                 )
             except Exception:
                 table_created = None
             else:
                 table_created.url_pattern = new_url_pattern
-                table_created.save()
+                await asave_datawarehousetable(table_created)

         if not table_created:
-            table_created = DataWarehouseTable.objects.create(external_data_source_id=job.pipeline.id, **data)
+            table_created = await acreate_datawarehousetable(external_data_source_id=job.pipeline.id, **data)

-        table_created.columns = table_created.get_columns()
-        table_created.save()
+        # TODO: this should be async too
+        table_created.columns = await sync_to_async(table_created.get_columns)()  # type: ignore
+        await asave_datawarehousetable(table_created)

         # schema could have been deleted by this point
-        schema_model = get_schema_if_exists(schema_name=_schema_name, team_id=job.team_id, source_id=job.pipeline.id)
+        schema_model = await aget_schema_if_exists(
+            schema_name=_schema_name, team_id=job.team_id, source_id=job.pipeline.id
+        )

         if schema_model:
             schema_model.table = table_created
             schema_model.last_synced_at = job.created_at
-            schema_model.save()
+            await asave_external_data_schema(schema_model)

     if last_successful_job:
         try:
diff --git a/posthog/warehouse/external_data_source/jobs.py b/posthog/warehouse/external_data_source/jobs.py
index cb494bc42f0af..7370615e9e3e7 100644
--- a/posthog/warehouse/external_data_source/jobs.py
+++ b/posthog/warehouse/external_data_source/jobs.py
@@ -8,10 +8,6 @@ def get_external_data_source(team_id: str, external_data_source_id: str) -> Exte
     return ExternalDataSource.objects.get(team_id=team_id, id=external_data_source_id)


-def get_external_data_job(team_id: str, run_id: str) -> ExternalDataJob:
-    return ExternalDataJob.objects.select_related("pipeline").get(id=run_id, team_id=team_id)
-
-
 def create_external_data_job(
     external_data_source_id: UUID,
     workflow_id: str,
diff --git a/posthog/warehouse/models/credential.py b/posthog/warehouse/models/credential.py
index 82e4c47660499..0be74816459d8 100644
--- a/posthog/warehouse/models/credential.py
+++ b/posthog/warehouse/models/credential.py
@@ -3,6 +3,7 @@

 from posthog.models.team import Team
 from posthog.models.utils import CreatedMetaFields, UUIDModel, sane_repr
+from posthog.warehouse.util import database_sync_to_async


 class DataWarehouseCredential(CreatedMetaFields, UUIDModel):
@@ -13,6 +14,7 @@ class DataWarehouseCredential(CreatedMetaFields, UUIDModel):
     __repr__ = sane_repr("access_key")


+@database_sync_to_async
 def get_or_create_datawarehouse_credential(team_id, access_key, access_secret) -> DataWarehouseCredential:
     credential, _ = DataWarehouseCredential.objects.get_or_create(
         team_id=team_id, access_key=access_key, access_secret=access_secret
diff --git a/posthog/warehouse/models/external_data_job.py b/posthog/warehouse/models/external_data_job.py
index b9150ee3a64c7..bb357d3ef7211 100644
--- a/posthog/warehouse/models/external_data_job.py
+++ b/posthog/warehouse/models/external_data_job.py
@@ -4,6 +4,7 @@
 from posthog.models.utils import CreatedMetaFields, UUIDModel, sane_repr
 from posthog.warehouse.s3 import get_s3_client
 from uuid import UUID
+from posthog.warehouse.util import database_sync_to_async


 class ExternalDataJob(CreatedMetaFields, UUIDModel):
@@ -37,11 +38,18 @@ def delete_data_in_bucket(self) -> None:
         s3.delete(f"{settings.BUCKET_URL}/{self.folder_path}", recursive=True)


+@database_sync_to_async
+def get_external_data_job(job_id: UUID) -> ExternalDataJob:
+    return ExternalDataJob.objects.prefetch_related("pipeline").get(pk=job_id)
+
+
+@database_sync_to_async
 def get_latest_run_if_exists(team_id: int, pipeline_id: UUID) -> ExternalDataJob | None:
     job = (
         ExternalDataJob.objects.filter(
             team_id=team_id, pipeline_id=pipeline_id, status=ExternalDataJob.Status.COMPLETED
         )
+        .prefetch_related("pipeline")
         .order_by("-created_at")
         .first()
     )
diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py
index 8a4ac00e81416..3d4423b24778e 100644
--- a/posthog/warehouse/models/external_data_schema.py
+++ b/posthog/warehouse/models/external_data_schema.py
@@ -5,6 +5,7 @@
 import uuid
 import psycopg
 from django.conf import settings
+from posthog.warehouse.util import database_sync_to_async


 class ExternalDataSchema(CreatedMetaFields, UUIDModel):
@@ -25,11 +26,21 @@ class ExternalDataSchema(CreatedMetaFields, UUIDModel):
     __repr__ = sane_repr("name")


+@database_sync_to_async
+def asave_external_data_schema(schema: ExternalDataSchema) -> None:
+    schema.save()
+
+
 def get_schema_if_exists(schema_name: str, team_id: int, source_id: uuid.UUID) -> ExternalDataSchema | None:
     schema = ExternalDataSchema.objects.filter(team_id=team_id, source_id=source_id, name=schema_name).first()
     return schema


+@database_sync_to_async
+def aget_schema_if_exists(schema_name: str, team_id: int, source_id: uuid.UUID) -> ExternalDataSchema | None:
+    return get_schema_if_exists(schema_name=schema_name, team_id=team_id, source_id=source_id)
+
+
 def get_active_schemas_for_source_id(source_id: uuid.UUID, team_id: int):
     schemas = ExternalDataSchema.objects.filter(team_id=team_id, source_id=source_id, should_sync=True).values().all()
     return [val["name"] for val in schemas]
diff --git a/posthog/warehouse/models/external_data_source.py b/posthog/warehouse/models/external_data_source.py
index 667ba244aca99..df668c5abfc54 100644
--- a/posthog/warehouse/models/external_data_source.py
+++ b/posthog/warehouse/models/external_data_source.py
@@ -3,6 +3,8 @@

 from posthog.models.team import Team
 from posthog.models.utils import CreatedMetaFields, UUIDModel, sane_repr
+from posthog.warehouse.util import database_sync_to_async
+from uuid import UUID


 class ExternalDataSource(CreatedMetaFields, UUIDModel):
@@ -31,3 +33,8 @@ class Status(models.TextChoices):
     prefix: models.CharField = models.CharField(max_length=100, null=True, blank=True)

     __repr__ = sane_repr("id")
+
+
+@database_sync_to_async
+def get_external_data_source(source_id: UUID) -> ExternalDataSource:
+    return ExternalDataSource.objects.get(pk=source_id)
diff --git a/posthog/warehouse/models/table.py b/posthog/warehouse/models/table.py
index f0d2640fd61a9..2f9b090f6c71a 100644
--- a/posthog/warehouse/models/table.py
+++ b/posthog/warehouse/models/table.py
@@ -24,6 +24,7 @@
 from .credential import DataWarehouseCredential
 from uuid import UUID
 from sentry_sdk import capture_exception
+from posthog.warehouse.util import database_sync_to_async

 CLICKHOUSE_HOGQL_MAPPING = {
     "UUID": StringDatabaseField,
@@ -145,7 +146,18 @@ def _safe_expose_ch_error(self, err):
         raise Exception("Could not get columns")


+@database_sync_to_async
 def get_table_by_url_pattern_and_source(url_pattern: str, source_id: UUID, team_id: int) -> DataWarehouseTable:
     return DataWarehouseTable.objects.filter(Q(deleted=False) | Q(deleted__isnull=True)).get(
         team_id=team_id, external_data_source_id=source_id, url_pattern=url_pattern
     )
+
+
+@database_sync_to_async
+def acreate_datawarehousetable(**kwargs):
+    return DataWarehouseTable.objects.create(**kwargs)
+
+
+@database_sync_to_async
+def asave_datawarehousetable(table: DataWarehouseTable) -> None:
+    table.save()
diff --git a/posthog/warehouse/util.py b/posthog/warehouse/util.py
new file mode 100644
index 0000000000000..8ef5cf00cbbcf
--- /dev/null
+++ b/posthog/warehouse/util.py
@@ -0,0 +1,20 @@
+# From django channels https://github.com/django/channels/blob/b6dc8c127d7bda3f5e5ae205332b1388818540c5/channels/db.py#L16
+
+from asgiref.sync import SyncToAsync
+from django.db import close_old_connections
+
+
+class DatabaseSyncToAsync(SyncToAsync):
+    """
+    SyncToAsync version that cleans up old database connections when it exits.
+    """
+
+    def thread_handler(self, loop, *args, **kwargs):
+        close_old_connections()
+        try:
+            return super().thread_handler(loop, *args, **kwargs)
+        finally:
+            close_old_connections()
+
+
+database_sync_to_async = DatabaseSyncToAsync
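
Usage sketch (illustrative, not part of the diff): the database_sync_to_async decorator added in posthog/warehouse/util.py wraps a blocking Django ORM helper so it can be awaited from async Temporal code; only the decorator, ExternalDataJob, and their import paths come from the diff, while the helper names aget_running_jobs and count_running_jobs are made up for the example.

    from posthog.warehouse.models.external_data_job import ExternalDataJob
    from posthog.warehouse.util import database_sync_to_async


    @database_sync_to_async
    def aget_running_jobs(team_id: int) -> list[ExternalDataJob]:
        # Blocking ORM query; DatabaseSyncToAsync runs it in a worker thread and calls
        # close_old_connections() before and after, so stale connections are not leaked.
        return list(ExternalDataJob.objects.filter(team_id=team_id, status=ExternalDataJob.Status.RUNNING))


    async def count_running_jobs(team_id: int) -> int:
        # The decorated helper is awaited directly from async code (e.g. a Temporal activity).
        jobs = await aget_running_jobs(team_id)
        return len(jobs)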