diff --git a/frontend/src/scenes/data-warehouse/settings/dataWarehouseSettingsLogic.ts b/frontend/src/scenes/data-warehouse/settings/dataWarehouseSettingsLogic.ts
index c39b308145fe11..439cf8d14c7d38 100644
--- a/frontend/src/scenes/data-warehouse/settings/dataWarehouseSettingsLogic.ts
+++ b/frontend/src/scenes/data-warehouse/settings/dataWarehouseSettingsLogic.ts
@@ -1,6 +1,7 @@
 import { actions, afterMount, kea, listeners, path, reducers, selectors } from 'kea'
 import { loaders } from 'kea-loaders'
 import api, { PaginatedResponse } from 'lib/api'
+import { lemonToast } from 'lib/lemon-ui/LemonToast/LemonToast'
 import { Scene } from 'scenes/sceneTypes'
 import { urls } from 'scenes/urls'
 
@@ -80,8 +81,16 @@ export const dataWarehouseSettingsLogic = kea<dataWarehouseSettingsLogicType>([
             actions.loadingFinished(source)
         },
         reloadSource: async ({ source }) => {
-            await api.externalDataSources.reload(source.id)
-            actions.loadSources()
+            try {
+                await api.externalDataSources.reload(source.id)
+                actions.loadSources()
+            } catch (e: any) {
+                if (e.message) {
+                    lemonToast.error(e.message)
+                } else {
+                    lemonToast.error('Cannot refresh source at this time')
+                }
+            }
             actions.loadingFinished(source)
         },
         updateSchema: async ({ schema }) => {
diff --git a/posthog/celery.py b/posthog/celery.py
index 7980df823600b1..a76fa36510f30b 100644
--- a/posthog/celery.py
+++ b/posthog/celery.py
@@ -355,6 +355,12 @@ def setup_periodic_tasks(sender: Celery, **kwargs):
         name="calculate external data rows synced",
     )
 
+    sender.add_periodic_task(
+        crontab(minute="23", hour="*"),
+        check_data_import_row_limits.s(),
+        name="check external data rows synced",
+    )
+
 
 # Set up clickhouse query instrumentation
 @task_prerun.connect
@@ -1118,3 +1124,13 @@ def sync_datawarehouse_sources():
         pass
     else:
         sync_resources()
+
+
+@app.task(ignore_result=True)
+def check_data_import_row_limits():
+    try:
+        from posthog.tasks.warehouse import check_synced_row_limits
+    except ImportError:
+        pass
+    else:
+        check_synced_row_limits()
diff --git a/posthog/tasks/test/test_warehouse.py b/posthog/tasks/test/test_warehouse.py
index 20b669b7549956..01b5ac561f5dd2 100644
--- a/posthog/tasks/test/test_warehouse.py
+++ b/posthog/tasks/test/test_warehouse.py
@@ -5,8 +5,9 @@
     _traverse_jobs_by_field,
     capture_workspace_rows_synced_by_team,
     check_external_data_source_billing_limit_by_team,
+    check_synced_row_limits_of_team,
 )
-from posthog.warehouse.models import ExternalDataSource
+from posthog.warehouse.models import ExternalDataSource, ExternalDataJob
 
 from freezegun import freeze_time
 
@@ -165,3 +166,33 @@ def test_external_data_source_billing_limit_activate(
 
         external_source.refresh_from_db()
         self.assertEqual(external_source.status, "running")
+
+    @patch("posthog.tasks.warehouse.MONTHLY_LIMIT", 100)
+    @patch("posthog.tasks.warehouse.cancel_external_data_workflow")
+    @patch("posthog.tasks.warehouse.pause_external_data_schedule")
+    def test_check_synced_row_limits_of_team(
+        self, pause_schedule_mock: MagicMock, cancel_workflow_mock: MagicMock
+    ) -> None:
+        source = ExternalDataSource.objects.create(
+            source_id="test_id",
+            connection_id="fake connection_id",
+            destination_id="fake destination_id",
+            team=self.team,
+            status="Running",
+            source_type="Stripe",
+        )
+
+        job = ExternalDataJob.objects.create(
+            pipeline=source, workflow_id="fake_workflow_id", team=self.team, status="Running", rows_synced=100000
+        )
+
+        check_synced_row_limits_of_team(self.team.pk)
+
+        source.refresh_from_db()
+        self.assertEqual(source.status, ExternalDataSource.Status.PAUSED)
+
+        job.refresh_from_db()
+        self.assertEqual(job.status, ExternalDataJob.Status.CANCELLED)
+
+        self.assertEqual(pause_schedule_mock.call_count, 1)
+        self.assertEqual(cancel_workflow_mock.call_count, 1)
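Note (illustration, not part of the patch): the new beat entry fires once an hour, at minute 23. A minimal standalone sketch of an equivalent declarative schedule, assuming a Celery app named `app` and the default task name derived from the module path (both illustrative):

from celery import Celery
from celery.schedules import crontab

app = Celery("posthog")  # illustrative app instance, not the real posthog.celery setup

# Run the row-limit check at minute 23 of every hour, the same cadence as the added add_periodic_task call.
app.conf.beat_schedule = {
    "check external data rows synced": {
        "task": "posthog.celery.check_data_import_row_limits",  # assumed default task name
        "schedule": crontab(minute="23", hour="*"),
    }
}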
diff --git a/posthog/tasks/warehouse.py b/posthog/tasks/warehouse.py
index 2450251830c597..5ab889fcd54a10 100644
--- a/posthog/tasks/warehouse.py
+++ b/posthog/tasks/warehouse.py
@@ -2,8 +2,12 @@
 import datetime
 from posthog.models import Team
 from posthog.warehouse.external_data_source.client import send_request
-from posthog.warehouse.models.external_data_source import ExternalDataSource
-from posthog.warehouse.models import DataWarehouseCredential, DataWarehouseTable
+from posthog.warehouse.data_load.service import (
+    cancel_external_data_workflow,
+    pause_external_data_schedule,
+    unpause_external_data_schedule,
+)
+from posthog.warehouse.models import DataWarehouseCredential, DataWarehouseTable, ExternalDataSource, ExternalDataJob
 from posthog.warehouse.external_data_source.connection import retrieve_sync
 from urllib.parse import urlencode
 from posthog.ph_client import get_ph_client
@@ -165,3 +169,55 @@ def _traverse_jobs_by_field(
         return _traverse_jobs_by_field(ph_client, team, response_next, field, acc)
 
     return acc
+
+
+MONTHLY_LIMIT = 1_000_000
+
+
+def check_synced_row_limits() -> None:
+    team_ids = ExternalDataSource.objects.values_list("team", flat=True)
+    for team_id in team_ids:
+        check_synced_row_limits_of_team.delay(team_id)
+
+
+@app.task(ignore_result=True)
+def check_synced_row_limits_of_team(team_id: int) -> None:
+    logger.info("Checking synced row limits of team", team_id=team_id)
+    start_of_month = datetime.datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
+    rows_synced_list = [
+        x
+        for x in ExternalDataJob.objects.filter(team_id=team_id, created_at__gte=start_of_month).values_list(
+            "rows_synced", flat=True
+        )
+        if x
+    ]
+    total_rows_synced = sum(rows_synced_list)
+
+    if total_rows_synced > MONTHLY_LIMIT:
+        running_jobs = ExternalDataJob.objects.filter(team_id=team_id, status=ExternalDataJob.Status.RUNNING)
+        for job in running_jobs:
+            try:
+                cancel_external_data_workflow(job.workflow_id)
+            except Exception as e:
+                logger.exception("Could not cancel external data workflow", exc_info=e)
+
+            try:
+                pause_external_data_schedule(job.pipeline)
+            except Exception as e:
+                logger.exception("Could not pause external data schedule", exc_info=e)
+
+            job.status = ExternalDataJob.Status.CANCELLED
+            job.save()
+
+            job.pipeline.status = ExternalDataSource.Status.PAUSED
+            job.pipeline.save()
+    else:
+        all_sources = ExternalDataSource.objects.filter(team_id=team_id)
+        for source in all_sources:
+            try:
+                unpause_external_data_schedule(source)
+            except Exception as e:
+                logger.exception("Could not unpause external data schedule", exc_info=e)
+
+            source.status = ExternalDataSource.Status.COMPLETED
+            source.save()
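Note (illustration, not part of the patch): `check_synced_row_limits_of_team` sums `rows_synced` across a team's jobs created since the start of the calendar month and compares the total against `MONTHLY_LIMIT`. A self-contained sketch of that decision, with hypothetical in-memory data standing in for the ORM query:

import datetime
from typing import Iterable, Optional, Tuple

MONTHLY_LIMIT = 1_000_000


def over_monthly_limit(jobs: Iterable[Tuple[datetime.datetime, Optional[int]]], now: datetime.datetime) -> bool:
    # Mirrors the task: jobs are (created_at, rows_synced) pairs, None/zero counts are skipped,
    # and the window starts at the first of the current month.
    start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
    total = sum(rows for created_at, rows in jobs if created_at >= start_of_month and rows)
    return total > MONTHLY_LIMIT


# Example: 600k + 500k rows synced this month trips the limit; last month's job is ignored.
jobs = [
    (datetime.datetime(2023, 12, 28), 900_000),
    (datetime.datetime(2024, 1, 3), 600_000),
    (datetime.datetime(2024, 1, 15), 500_000),
]
assert over_monthly_limit(jobs, now=datetime.datetime(2024, 1, 20))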
diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py
index fc75f6beed4de7..cdb218c0cce31f 100644
--- a/posthog/temporal/data_imports/external_data_job.py
+++ b/posthog/temporal/data_imports/external_data_job.py
@@ -26,6 +26,7 @@
 )
 from posthog.temporal.common.logger import bind_temporal_worker_logger
 from typing import Tuple
+import asyncio
 
 
 @dataclasses.dataclass
@@ -151,11 +152,25 @@ async def run_external_data_job(inputs: ExternalDataJobInputs) -> None:
         stripe_secret_key = model.pipeline.job_inputs.get("stripe_secret_key", None)
         if not stripe_secret_key:
             raise ValueError(f"Stripe secret key not found for job {model.id}")
-        source = stripe_source(api_key=stripe_secret_key, endpoints=tuple(inputs.schemas))
+        source = stripe_source(
+            api_key=stripe_secret_key, endpoints=tuple(inputs.schemas), job_id=str(model.id), team_id=inputs.team_id
+        )
     else:
         raise ValueError(f"Source type {model.pipeline.source_type} not supported")
 
-    await DataImportPipeline(job_inputs, source, logger).run()
+    # Temp background heartbeat for now
+    async def heartbeat() -> None:
+        while True:
+            await asyncio.sleep(10)
+            activity.heartbeat()
+
+    heartbeat_task = asyncio.create_task(heartbeat())
+
+    try:
+        await DataImportPipeline(job_inputs, source, logger).run()
+    finally:
+        heartbeat_task.cancel()
+        await asyncio.wait([heartbeat_task])
 
 
 # TODO: update retry policies
@@ -205,6 +220,7 @@ async def run(self, inputs: ExternalDataWorkflowInputs):
             job_inputs,
             start_to_close_timeout=dt.timedelta(minutes=90),
             retry_policy=RetryPolicy(maximum_attempts=5),
+            heartbeat_timeout=dt.timedelta(minutes=1),
         )
 
         # check schema first
diff --git a/posthog/temporal/data_imports/pipelines/helpers.py b/posthog/temporal/data_imports/pipelines/helpers.py
new file mode 100644
index 00000000000000..753cce2ea9cb44
--- /dev/null
+++ b/posthog/temporal/data_imports/pipelines/helpers.py
@@ -0,0 +1,38 @@
+from posthog.warehouse.models import ExternalDataJob
+from django.db.models import F
+
+CHUNK_SIZE = 10_000
+
+
+def limit_paginated_generator(f):
+    """
+    Limits the number of items returned by a paginated generator.
+
+    Must wrap a function called with kwargs:
+    team_id: int
+    job_id: str (ExternalDataJob id)
+    """
+
+    def wrapped(**kwargs):
+        job_id = kwargs.pop("job_id")
+        team_id = kwargs.pop("team_id")
+
+        model = ExternalDataJob.objects.get(id=job_id, team_id=team_id)
+        gen = f(**kwargs)
+        count = 0
+        for item in gen:
+            if count >= CHUNK_SIZE:
+                ExternalDataJob.objects.filter(id=job_id, team_id=team_id).update(rows_synced=F("rows_synced") + count)
+                count = 0
+
+                model.refresh_from_db()
+
+                if model.status == ExternalDataJob.Status.CANCELLED:
+                    break
+
+            yield item
+            count += len(item)
+
+        ExternalDataJob.objects.filter(id=job_id, team_id=team_id).update(rows_synced=F("rows_synced") + count)
+
+    return wrapped
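Note (illustration, not part of the patch): the activity keeps Temporal's heartbeat alive by running a plain asyncio task alongside the pipeline and cancelling it in `finally`. The same pattern in standalone form, with `print` standing in for `activity.heartbeat()` and a sleep standing in for the pipeline run:

import asyncio


async def do_work() -> None:
    await asyncio.sleep(25)  # stand-in for DataImportPipeline(...).run()


async def run_with_heartbeat() -> None:
    async def heartbeat() -> None:
        while True:
            await asyncio.sleep(10)
            print("heartbeat")  # stand-in for activity.heartbeat()

    heartbeat_task = asyncio.create_task(heartbeat())
    try:
        await do_work()
    finally:
        heartbeat_task.cancel()
        # Await the cancelled task so its CancelledError is consumed before returning.
        await asyncio.wait([heartbeat_task])


asyncio.run(run_with_heartbeat())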
diff --git a/posthog/temporal/data_imports/pipelines/stripe/helpers.py b/posthog/temporal/data_imports/pipelines/stripe/helpers.py
index a6d71ed809a53d..81140f1518442b 100644
--- a/posthog/temporal/data_imports/pipelines/stripe/helpers.py
+++ b/posthog/temporal/data_imports/pipelines/stripe/helpers.py
@@ -7,6 +7,7 @@
 from dlt.common import pendulum
 from dlt.sources import DltResource
 from pendulum import DateTime
+from posthog.temporal.data_imports.pipelines.helpers import limit_paginated_generator
 
 stripe.api_version = "2022-11-15"
 
@@ -48,11 +49,10 @@ def stripe_get_data(
     return response
 
 
+@limit_paginated_generator
 def stripe_pagination(
     api_key: str,
     endpoint: str,
-    start_date: Optional[Any] = None,
-    end_date: Optional[Any] = None,
     starting_after: Optional[str] = None,
 ):
     """
@@ -71,8 +71,6 @@ def stripe_pagination(
         response = stripe_get_data(
             api_key,
             endpoint,
-            start_date=start_date,
-            end_date=end_date,
             starting_after=starting_after,
         )
 
@@ -86,11 +84,7 @@ def stripe_pagination(
 
 @dlt.source(max_table_nesting=0)
 def stripe_source(
-    api_key: str,
-    endpoints: Tuple[str, ...],
-    start_date: Optional[Any] = None,
-    end_date: Optional[Any] = None,
-    starting_after: Optional[str] = None,
+    api_key: str, endpoints: Tuple[str, ...], job_id: str, team_id: int, starting_after: Optional[str] = None
 ) -> Iterable[DltResource]:
     for endpoint in endpoints:
         yield dlt.resource(
@@ -100,7 +94,7 @@ def stripe_source(
         )(
             api_key=api_key,
             endpoint=endpoint,
-            start_date=start_date,
-            end_date=end_date,
+            job_id=job_id,
+            team_id=team_id,
             starting_after=starting_after,
         )
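Note (illustration, not part of the patch): with the decorator applied, `job_id` and `team_id` are popped by `limit_paginated_generator`'s wrapper and never reach `stripe_pagination` itself, which is why the dlt resource can pass them alongside the Stripe arguments. A toy sketch of that kwarg-popping wrapper, with a simple row cap standing in for the CANCELLED-status check and without the Django model updates:

from typing import Iterator, List


def limit_paginated_generator(f):
    # Simplified wrapper: pop bookkeeping kwargs, count yielded rows, stop at a hypothetical cap.
    def wrapped(**kwargs) -> Iterator[List[int]]:
        kwargs.pop("job_id")
        kwargs.pop("team_id")
        count = 0
        for page in f(**kwargs):
            if count >= 25:
                break
            yield page
            count += len(page)

    return wrapped


@limit_paginated_generator
def fake_pagination(page_size: int) -> Iterator[List[int]]:
    # Stand-in for stripe_pagination: yields one "page" (a list of rows) per iteration, forever.
    n = 0
    while True:
        yield list(range(n, n + page_size))
        n += page_size


pages = list(fake_pagination(page_size=10, job_id="job-123", team_id=1))
assert len(pages) == 3  # the cap of 25 rows is reached after the third page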
diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py
index c72cbd14e6c9d0..48f8babed4a5a6 100644
--- a/posthog/warehouse/api/external_data_source.py
+++ b/posthog/warehouse/api/external_data_source.py
@@ -18,12 +18,14 @@
     delete_external_data_schedule,
     cancel_external_data_workflow,
     delete_data_import_folder,
+    is_any_external_data_job_paused,
 )
 from posthog.warehouse.models import ExternalDataSource, ExternalDataSchema, ExternalDataJob
 from posthog.warehouse.api.external_data_schema import ExternalDataSchemaSerializer
 from posthog.temporal.data_imports.pipelines.schemas import (
     PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING,
 )
+import temporalio
 
 logger = structlog.get_logger(__name__)
 
@@ -118,6 +120,12 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
         elif self.prefix_exists(source_type, prefix):
             return Response(status=status.HTTP_400_BAD_REQUEST, data={"message": "Prefix already exists"})
 
+        if is_any_external_data_job_paused(self.team_id):
+            return Response(
+                status=status.HTTP_400_BAD_REQUEST,
+                data={"message": "Monthly sync limit reached. Please contact PostHog support to increase your limit."},
+            )
+
         # TODO: remove dummy vars
         new_source_model = ExternalDataSource.objects.create(
             source_id=str(uuid.uuid4()),
@@ -140,7 +148,11 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
             source=new_source_model,
         )
 
-        sync_external_data_job_workflow(new_source_model, create=True)
+        try:
+            sync_external_data_job_workflow(new_source_model, create=True)
+        except Exception as e:
+            # Log error but don't fail because the source model was already created
+            logger.exception("Could not trigger external data job", exc_info=e)
 
         return Response(status=status.HTTP_201_CREATED, data={"id": new_source_model.pk})
 
@@ -185,7 +197,23 @@ def destroy(self, request: Request, *args: Any, **kwargs: Any) -> Response:
     @action(methods=["POST"], detail=True)
     def reload(self, request: Request, *args: Any, **kwargs: Any):
         instance = self.get_object()
-        trigger_external_data_workflow(instance)
+
+        if is_any_external_data_job_paused(self.team_id):
+            return Response(
+                status=status.HTTP_400_BAD_REQUEST,
+                data={"message": "Monthly sync limit reached. Please contact PostHog support to increase your limit."},
+            )
+
+        try:
+            trigger_external_data_workflow(instance)
+
+        except temporalio.service.RPCError as e:
+            # schedule doesn't exist
+            if e.message == "sql: no rows in result set":
+                sync_external_data_job_workflow(instance, create=True)
+        except Exception as e:
+            logger.exception("Could not trigger external data job", exc_info=e)
+            raise
 
         instance.status = "Running"
         instance.save()
diff --git a/posthog/warehouse/data_load/service.py b/posthog/warehouse/data_load/service.py
index 7a614b127145ce..d88ccae59bb65b 100644
--- a/posthog/warehouse/data_load/service.py
+++ b/posthog/warehouse/data_load/service.py
@@ -19,6 +19,7 @@
     trigger_schedule,
     update_schedule,
     delete_schedule,
+    unpause_schedule,
 )
 from posthog.temporal.data_imports.external_data_job import (
     ExternalDataWorkflowInputs,
@@ -73,11 +74,16 @@ def trigger_external_data_workflow(external_data_source: ExternalDataSource):
     trigger_schedule(temporal, schedule_id=str(external_data_source.id))
 
 
-def pause_external_data_workflow(external_data_source: ExternalDataSource):
+def pause_external_data_schedule(external_data_source: ExternalDataSource):
     temporal = sync_connect()
     pause_schedule(temporal, schedule_id=str(external_data_source.id))
 
 
+def unpause_external_data_schedule(external_data_source: ExternalDataSource):
+    temporal = sync_connect()
+    unpause_schedule(temporal, schedule_id=str(external_data_source.id))
+
+
 def delete_external_data_schedule(external_data_source: ExternalDataSource):
     temporal = sync_connect()
     try:
@@ -107,3 +113,7 @@ def delete_data_import_folder(folder_path: str):
     )
     bucket_name = settings.BUCKET_URL
     s3.delete(f"{bucket_name}/{folder_path}", recursive=True)
+
+
+def is_any_external_data_job_paused(team_id: int) -> bool:
+    return ExternalDataSource.objects.filter(team_id=team_id, status=ExternalDataSource.Status.PAUSED).exists()
diff --git a/posthog/warehouse/models/external_data_source.py b/posthog/warehouse/models/external_data_source.py
index 06c8d8dddf7717..287a4a3f2cd995 100644
--- a/posthog/warehouse/models/external_data_source.py
+++ b/posthog/warehouse/models/external_data_source.py
@@ -9,6 +9,13 @@ class ExternalDataSource(CreatedMetaFields, UUIDModel):
     class Type(models.TextChoices):
         STRIPE = "Stripe", "Stripe"
 
+    class Status(models.TextChoices):
+        RUNNING = "Running", "Running"
+        PAUSED = "Paused", "Paused"
+        ERROR = "Error", "Error"
+        COMPLETED = "Completed", "Completed"
+        CANCELLED = "Cancelled", "Cancelled"
+
     source_id: models.CharField = models.CharField(max_length=400)
     connection_id: models.CharField = models.CharField(max_length=400)
     destination_id: models.CharField = models.CharField(max_length=400, null=True, blank=True)
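Note (illustration, not part of the patch): the new `Status` choices are Django `TextChoices`, so members compare equal to their raw string values. That is why the API can keep assigning `instance.status = "Running"` while the tasks and tests compare against `ExternalDataSource.Status.*`. A minimal check of that behaviour, using a hypothetical stand-in enum with the same shape:

from django.db import models


class Status(models.TextChoices):
    # Same shape as ExternalDataSource.Status in the patch.
    RUNNING = "Running", "Running"
    PAUSED = "Paused", "Paused"


assert Status.RUNNING == "Running"
assert Status.PAUSED.value == "Paused"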