diff --git a/frontend/src/lib/integrations/IntegrationScopesWarning.tsx b/frontend/src/lib/integrations/IntegrationScopesWarning.tsx new file mode 100644 index 0000000000000..c9e6c7a61d764 --- /dev/null +++ b/frontend/src/lib/integrations/IntegrationScopesWarning.tsx @@ -0,0 +1,65 @@ +import api from 'lib/api' +import { LemonBanner } from 'lib/lemon-ui/LemonBanner' +import { Link } from 'lib/lemon-ui/Link' +import { useMemo } from 'react' + +import { HogFunctionInputSchemaType, IntegrationType } from '~/types' + +export function IntegrationScopesWarning({ + integration, + schema, +}: { + integration: IntegrationType + schema?: HogFunctionInputSchemaType +}): JSX.Element { + const getScopes = useMemo((): string[] => { + const scopes: any[] = [] + const possibleScopeLocation = [integration.config.scope, integration.config.scopes] + + possibleScopeLocation.map((scope) => { + if (typeof scope === 'string') { + scopes.push(scope.split(' ')) + scopes.push(scope.split(',')) + } + if (typeof scope === 'object') { + scopes.push(scope) + } + }) + return scopes + .filter((scope: any) => typeof scope === 'object') + .reduce((a, b) => (a.length > b.length ? a : b), []) + }, [integration.config]) + + const requiredScopes = schema?.requiredScopes?.split(' ') || [] + const missingScopes = requiredScopes.filter((scope: string) => !getScopes.includes(scope)) + + if (missingScopes.length === 0 || getScopes.length === 0) { + return <> + } + return ( +
+ + Required scopes are missing: [{missingScopes.join(', ')}]. + {integration.kind === 'hubspot' ? ( + + Note that some features may not be available on your current HubSpot plan. Check out{' '} + + this page + {' '} + for more details. + + ) : null} + +
+ ) +} diff --git a/frontend/src/lib/integrations/IntegrationView.tsx b/frontend/src/lib/integrations/IntegrationView.tsx index 31cd12e82eb40..80590299bda4d 100644 --- a/frontend/src/lib/integrations/IntegrationView.tsx +++ b/frontend/src/lib/integrations/IntegrationView.tsx @@ -1,15 +1,18 @@ import { LemonBanner } from '@posthog/lemon-ui' import api from 'lib/api' import { UserActivityIndicator } from 'lib/components/UserActivityIndicator/UserActivityIndicator' +import { IntegrationScopesWarning } from 'lib/integrations/IntegrationScopesWarning' -import { IntegrationType } from '~/types' +import { HogFunctionInputSchemaType, IntegrationType } from '~/types' export function IntegrationView({ integration, suffix, + schema, }: { integration: IntegrationType suffix?: JSX.Element + schema?: HogFunctionInputSchemaType }): JSX.Element { const errors = (integration.errors && integration.errors?.split(',')) || [] @@ -36,7 +39,7 @@ export function IntegrationView({ {suffix} - {errors.length > 0 && ( + {errors.length > 0 ? (
+ ) : ( + )} ) diff --git a/frontend/src/scenes/pipeline/hogfunctions/integrations/HogFunctionInputIntegration.tsx b/frontend/src/scenes/pipeline/hogfunctions/integrations/HogFunctionInputIntegration.tsx index f92b2f9123deb..e73b679afcd40 100644 --- a/frontend/src/scenes/pipeline/hogfunctions/integrations/HogFunctionInputIntegration.tsx +++ b/frontend/src/scenes/pipeline/hogfunctions/integrations/HogFunctionInputIntegration.tsx @@ -16,6 +16,7 @@ export function HogFunctionInputIntegration({ schema, ...props }: HogFunctionInp <> persistForUnload()} diff --git a/frontend/src/scenes/pipeline/hogfunctions/integrations/IntegrationChoice.tsx b/frontend/src/scenes/pipeline/hogfunctions/integrations/IntegrationChoice.tsx index cee61f7c80c88..334c17ee3d859 100644 --- a/frontend/src/scenes/pipeline/hogfunctions/integrations/IntegrationChoice.tsx +++ b/frontend/src/scenes/pipeline/hogfunctions/integrations/IntegrationChoice.tsx @@ -7,10 +7,13 @@ import { IntegrationView } from 'lib/integrations/IntegrationView' import { capitalizeFirstLetter } from 'lib/utils' import { urls } from 'scenes/urls' +import { HogFunctionInputSchemaType } from '~/types' + export type IntegrationConfigureProps = { value?: number onChange?: (value: number | null) => void redirectUrl?: string + schema?: HogFunctionInputSchemaType integration?: string beforeRedirect?: () => void } @@ -18,6 +21,7 @@ export type IntegrationConfigureProps = { export function IntegrationChoice({ onChange, value, + schema, integration, redirectUrl, beforeRedirect, @@ -124,5 +128,13 @@ export function IntegrationChoice({ ) - return <>{integrationKind ? : button} + return ( + <> + {integrationKind ? ( + + ) : ( + button + )} + + ) } diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 4747aff1f17c5..e2ac6193659cf 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -4605,6 +4605,7 @@ export type HogFunctionInputSchemaType = { integration?: string integration_key?: string integration_field?: 'slack_channel' + requiredScopes?: string } export type HogFunctionInputType = { diff --git a/plugin-server/src/cdp/types.ts b/plugin-server/src/cdp/types.ts index 8a675e605e017..e9d506a7a7823 100644 --- a/plugin-server/src/cdp/types.ts +++ b/plugin-server/src/cdp/types.ts @@ -272,6 +272,7 @@ export type HogFunctionInputSchemaType = { integration?: string integration_key?: string integration_field?: 'slack_channel' + requiredScopes?: string } export type HogFunctionTypeType = 'destination' | 'email' | 'sms' | 'push' | 'activity' | 'alert' | 'broadcast' diff --git a/posthog/api/survey.py b/posthog/api/survey.py index df100f8717b32..835860bb00906 100644 --- a/posthog/api/survey.py +++ b/posthog/api/survey.py @@ -325,7 +325,7 @@ def validate(self, data): if response_sampling_start_date < today_utc: raise serializers.ValidationError( { - "response_sampling_start_date": "Response sampling start date must be today or a future date in UTC." + "response_sampling_start_date": f"Response sampling start date must be today or a future date in UTC. Got {response_sampling_start_date} when current time is {today_utc}" } ) diff --git a/posthog/api/test/test_survey.py b/posthog/api/test/test_survey.py index c874d88abcfbc..35b0bb1cdc553 100644 --- a/posthog/api/test/test_survey.py +++ b/posthog/api/test/test_survey.py @@ -2378,6 +2378,7 @@ def test_can_clear_associated_actions(self): assert len(survey.actions.all()) == 0 +@freeze_time("2024-12-12 00:00:00") class TestSurveyResponseSampling(APIBaseTest): def _create_survey_with_sampling_limits( self, @@ -2407,6 +2408,7 @@ def _create_survey_with_sampling_limits( ) response_data = response.json() + assert response.status_code == status.HTTP_201_CREATED, response_data survey = Survey.objects.get(id=response_data["id"]) return survey diff --git a/posthog/batch_exports/service.py b/posthog/batch_exports/service.py index 59b217b4fc8f2..d17bb3b1b69c3 100644 --- a/posthog/batch_exports/service.py +++ b/posthog/batch_exports/service.py @@ -794,3 +794,19 @@ async def aupdate_batch_export_backfill_status(backfill_id: UUID, status: str) - raise ValueError(f"BatchExportBackfill with id {backfill_id} not found.") return await model.aget() + + +async def aupdate_records_total_count( + batch_export_id: UUID, interval_start: dt.datetime, interval_end: dt.datetime, count: int +) -> int: + """Update the expected records count for a set of batch export runs. + + Typically, there is one batch export run per batch export interval, however + there could be multiple if data has been backfilled. + """ + rows_updated = await BatchExportRun.objects.filter( + batch_export_id=batch_export_id, + data_interval_start=interval_start, + data_interval_end=interval_end, + ).aupdate(records_total_count=count) + return rows_updated diff --git a/posthog/batch_exports/sql.py b/posthog/batch_exports/sql.py index baa0216afdbbc..9a7fd0cea95aa 100644 --- a/posthog/batch_exports/sql.py +++ b/posthog/batch_exports/sql.py @@ -318,3 +318,22 @@ SETTINGS optimize_aggregation_in_order=1 ) """ + +# TODO: is this the best query to use? +EVENT_COUNT_BY_INTERVAL = """ +SELECT + toStartOfInterval(_inserted_at, INTERVAL {interval}) AS interval_start, + interval_start + INTERVAL {interval} AS interval_end, + COUNT(*) as total_count +FROM + events_batch_export_recent( + team_id={team_id}, + interval_start={overall_interval_start}, + interval_end={overall_interval_end}, + include_events={include_events}::Array(String), + exclude_events={exclude_events}::Array(String) + ) AS events +GROUP BY interval_start +ORDER BY interval_start desc +SETTINGS max_replica_delay_for_distributed_queries=1 +""" diff --git a/posthog/cdp/templates/google_ads/template_google_ads.py b/posthog/cdp/templates/google_ads/template_google_ads.py index 3743ca93db541..9cc61d507fe56 100644 --- a/posthog/cdp/templates/google_ads/template_google_ads.py +++ b/posthog/cdp/templates/google_ads/template_google_ads.py @@ -55,6 +55,7 @@ "type": "integration", "integration": "google-ads", "label": "Google Ads account", + "requiredScopes": "https://www.googleapis.com/auth/adwords https://www.googleapis.com/auth/userinfo.email", "secret": False, "required": True, }, diff --git a/posthog/cdp/templates/hubspot/template_hubspot.py b/posthog/cdp/templates/hubspot/template_hubspot.py index f8f6c9cf06a72..a36c850725972 100644 --- a/posthog/cdp/templates/hubspot/template_hubspot.py +++ b/posthog/cdp/templates/hubspot/template_hubspot.py @@ -61,6 +61,7 @@ "type": "integration", "integration": "hubspot", "label": "Hubspot connection", + "requiredScopes": "crm.objects.contacts.write crm.objects.contacts.read", "secret": False, "required": True, }, @@ -307,6 +308,7 @@ "type": "integration", "integration": "hubspot", "label": "Hubspot connection", + "requiredScopes": "analytics.behavioral_events.send behavioral_events.event_definitions.read_write", "secret": False, "required": True, }, diff --git a/posthog/cdp/templates/salesforce/template_salesforce.py b/posthog/cdp/templates/salesforce/template_salesforce.py index eedfd9980efb1..844c86ad15803 100644 --- a/posthog/cdp/templates/salesforce/template_salesforce.py +++ b/posthog/cdp/templates/salesforce/template_salesforce.py @@ -15,6 +15,7 @@ "type": "integration", "integration": "salesforce", "label": "Salesforce account", + "requiredScopes": "refresh_token full", "secret": False, "required": True, } diff --git a/posthog/cdp/templates/slack/template_slack.py b/posthog/cdp/templates/slack/template_slack.py index 16bb0383c1c0b..8cfb5a84101de 100644 --- a/posthog/cdp/templates/slack/template_slack.py +++ b/posthog/cdp/templates/slack/template_slack.py @@ -34,6 +34,7 @@ "type": "integration", "integration": "slack", "label": "Slack workspace", + "requiredScopes": "channels:read groups:read chat:write chat:write.customize", "secret": False, "required": True, }, diff --git a/posthog/cdp/validation.py b/posthog/cdp/validation.py index 0ca2fa353dc26..a0466d8128dab 100644 --- a/posthog/cdp/validation.py +++ b/posthog/cdp/validation.py @@ -65,6 +65,7 @@ class InputsSchemaItemSerializer(serializers.Serializer): integration = serializers.CharField(required=False) integration_key = serializers.CharField(required=False) integration_field = serializers.ChoiceField(choices=["slack_channel"], required=False) + requiredScopes = serializers.CharField(required=False) # TODO Validate choices if type=choice diff --git a/posthog/models/integration.py b/posthog/models/integration.py index d8e49cc5d67aa..f42f70da332d9 100644 --- a/posthog/models/integration.py +++ b/posthog/models/integration.py @@ -163,7 +163,7 @@ def oauth_config_for_kind(cls, kind: str) -> OauthConfig: authorize_url="https://app.hubspot.com/oauth/authorize", token_url="https://api.hubapi.com/oauth/v1/token", token_info_url="https://api.hubapi.com/oauth/v1/access-tokens/:access_token", - token_info_config_fields=["hub_id", "hub_domain", "user", "user_id"], + token_info_config_fields=["hub_id", "hub_domain", "user", "user_id", "scopes"], client_id=settings.HUBSPOT_APP_CLIENT_ID, client_secret=settings.HUBSPOT_APP_CLIENT_SECRET, scope="tickets crm.objects.contacts.write sales-email-read crm.objects.companies.read crm.objects.deals.read crm.objects.contacts.read crm.objects.quotes.read crm.objects.companies.write", @@ -187,7 +187,7 @@ def oauth_config_for_kind(cls, kind: str) -> OauthConfig: token_url="https://oauth2.googleapis.com/token", client_id=settings.SOCIAL_AUTH_GOOGLE_OAUTH2_KEY, client_secret=settings.SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET, - scope="https://www.googleapis.com/auth/adwords email", + scope="https://www.googleapis.com/auth/adwords https://www.googleapis.com/auth/userinfo.email", id_path="sub", name_path="email", ) diff --git a/posthog/models/test/test_integration_model.py b/posthog/models/test/test_integration_model.py index 456f085d9c2e9..d4184d9e0265a 100644 --- a/posthog/models/test/test_integration_model.py +++ b/posthog/models/test/test_integration_model.py @@ -120,7 +120,7 @@ def test_authorize_url_with_additional_authorize_params(self): url = OauthIntegration.authorize_url("google-ads", next="/projects/test") assert ( url - == "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fadwords+email&redirect_uri=https%3A%2F%2Flocalhost%3A8000%2Fintegrations%2Fgoogle-ads%2Fcallback&response_type=code&state=next%3D%252Fprojects%252Ftest&access_type=offline&prompt=consent" + == "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fadwords+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&redirect_uri=https%3A%2F%2Flocalhost%3A8000%2Fintegrations%2Fgoogle-ads%2Fcallback&response_type=code&state=next%3D%252Fprojects%252Ftest&access_type=offline&prompt=consent" ) @patch("posthog.models.integration.requests.post") @@ -199,6 +199,10 @@ def test_integration_fetches_info_from_token_info_url(self, mock_get, mock_post) "user": "user", "user_id": "user_id", "should_not": "be_saved", + "scopes": [ + "crm.objects.contacts.read", + "crm.objects.contacts.write", + ], } with freeze_time("2024-01-01T12:00:00Z"): @@ -219,6 +223,10 @@ def test_integration_fetches_info_from_token_info_url(self, mock_get, mock_post) "user": "user", "user_id": "user_id", "refreshed_at": 1704110400, + "scopes": [ + "crm.objects.contacts.read", + "crm.objects.contacts.write", + ], } assert integration.sensitive_config == { "access_token": "FAKES_ACCESS_TOKEN", diff --git a/posthog/temporal/batch_exports/__init__.py b/posthog/temporal/batch_exports/__init__.py index 33c1b200e6a97..a3616f1107c5b 100644 --- a/posthog/temporal/batch_exports/__init__.py +++ b/posthog/temporal/batch_exports/__init__.py @@ -17,6 +17,12 @@ HttpBatchExportWorkflow, insert_into_http_activity, ) +from posthog.temporal.batch_exports.monitoring import ( + BatchExportMonitoringWorkflow, + get_batch_export, + get_event_counts, + update_batch_export_runs, +) from posthog.temporal.batch_exports.noop import NoOpWorkflow, noop_activity from posthog.temporal.batch_exports.postgres_batch_export import ( PostgresBatchExportWorkflow, @@ -54,6 +60,7 @@ SnowflakeBatchExportWorkflow, HttpBatchExportWorkflow, SquashPersonOverridesWorkflow, + BatchExportMonitoringWorkflow, ] ACTIVITIES = [ @@ -76,4 +83,7 @@ update_batch_export_backfill_model_status, wait_for_mutation, wait_for_table, + get_batch_export, + get_event_counts, + update_batch_export_runs, ] diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index edaf13d1888af..30c600d210802 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -391,7 +391,17 @@ async def amerge_person_tables( merge_query = f""" MERGE `{final_table.full_table_id.replace(":", ".", 1)}` final - USING `{stage_table.full_table_id.replace(":", ".", 1)}` stage + USING ( + SELECT * FROM + ( + SELECT + *, + ROW_NUMBER() OVER (PARTITION BY {",".join(field.name for field in merge_key)}) row_num + FROM + `{stage_table.full_table_id.replace(":", ".", 1)}` + ) + WHERE row_num = 1 + ) stage {merge_condition} WHEN MATCHED AND (stage.`{person_version_key}` > final.`{person_version_key}` OR stage.`{person_distinct_id_version_key}` > final.`{person_distinct_id_version_key}`) THEN diff --git a/posthog/temporal/batch_exports/monitoring.py b/posthog/temporal/batch_exports/monitoring.py new file mode 100644 index 0000000000000..97eaf6c2430d9 --- /dev/null +++ b/posthog/temporal/batch_exports/monitoring.py @@ -0,0 +1,227 @@ +import datetime as dt +import json +from dataclasses import dataclass +from uuid import UUID + +from temporalio import activity, workflow +from temporalio.common import RetryPolicy + +from posthog.batch_exports.models import BatchExport +from posthog.batch_exports.service import aupdate_records_total_count +from posthog.batch_exports.sql import EVENT_COUNT_BY_INTERVAL +from posthog.temporal.batch_exports.base import PostHogWorkflow +from posthog.temporal.common.clickhouse import get_client +from posthog.temporal.common.heartbeat import Heartbeater + + +class BatchExportNotFoundError(Exception): + """Exception raised when batch export is not found.""" + + def __init__(self, batch_export_id: UUID): + super().__init__(f"Batch export with id {batch_export_id} not found") + + +class NoValidBatchExportsFoundError(Exception): + """Exception raised when no valid batch export is found.""" + + def __init__(self, message: str = "No valid batch exports found"): + super().__init__(message) + + +@dataclass +class BatchExportMonitoringInputs: + """Inputs for the BatchExportMonitoringWorkflow. + + Attributes: + batch_export_id: The batch export id to monitor. + """ + + batch_export_id: UUID + + +@dataclass +class BatchExportDetails: + id: UUID + team_id: int + interval: str + exclude_events: list[str] + include_events: list[str] + + +@activity.defn +async def get_batch_export(batch_export_id: UUID) -> BatchExportDetails: + """Fetch a batch export from the database and return its details.""" + batch_export = ( + await BatchExport.objects.filter(id=batch_export_id, model="events", paused=False, deleted=False) + .prefetch_related("destination") + .afirst() + ) + if batch_export is None: + raise BatchExportNotFoundError(batch_export_id) + if batch_export.deleted is True: + raise NoValidBatchExportsFoundError("Batch export has been deleted") + if batch_export.paused is True: + raise NoValidBatchExportsFoundError("Batch export is paused") + if batch_export.model != "events": + raise NoValidBatchExportsFoundError("Batch export model is not 'events'") + if batch_export.interval_time_delta != dt.timedelta(minutes=5): + raise NoValidBatchExportsFoundError( + "Only batch exports with interval of 5 minutes are supported for monitoring at this time." + ) + config = batch_export.destination.config + return BatchExportDetails( + id=batch_export.id, + team_id=batch_export.team_id, + interval=batch_export.interval, + exclude_events=config.get("exclude_events", []), + include_events=config.get("include_events", []), + ) + + +@dataclass +class GetEventCountsInputs: + team_id: int + interval: str + overall_interval_start: str + overall_interval_end: str + exclude_events: list[str] + include_events: list[str] + + +@dataclass +class EventCountsOutput: + interval_start: str + interval_end: str + count: int + + +@dataclass +class GetEventCountsOutputs: + results: list[EventCountsOutput] + + +@activity.defn +async def get_event_counts(inputs: GetEventCountsInputs) -> GetEventCountsOutputs: + """Get the total number of events for a given team over a set of time intervals.""" + + query = EVENT_COUNT_BY_INTERVAL + + interval = inputs.interval + # we check interval is "every 5 minutes" above but double check here + if not interval.startswith("every 5 minutes"): + raise NoValidBatchExportsFoundError( + "Only intervals of 'every 5 minutes' are supported for monitoring at this time." + ) + _, value, unit = interval.split(" ") + interval = f"{value} {unit}" + + query_params = { + "team_id": inputs.team_id, + "interval": interval, + "overall_interval_start": inputs.overall_interval_start, + "overall_interval_end": inputs.overall_interval_end, + "include_events": inputs.include_events, + "exclude_events": inputs.exclude_events, + } + async with Heartbeater(), get_client() as client: + if not await client.is_alive(): + raise ConnectionError("Cannot establish connection to ClickHouse") + + response = await client.read_query(query, query_params) + results = [] + for line in response.decode("utf-8").splitlines(): + interval_start, interval_end, count = line.strip().split("\t") + results.append( + EventCountsOutput(interval_start=interval_start, interval_end=interval_end, count=int(count)) + ) + + return GetEventCountsOutputs(results=results) + + +@dataclass +class UpdateBatchExportRunsInputs: + batch_export_id: UUID + results: list[EventCountsOutput] + + +@activity.defn +async def update_batch_export_runs(inputs: UpdateBatchExportRunsInputs) -> int: + """Update BatchExportRuns with the expected number of events.""" + + total_rows_updated = 0 + async with Heartbeater(): + for result in inputs.results: + total_rows_updated += await aupdate_records_total_count( + batch_export_id=inputs.batch_export_id, + interval_start=dt.datetime.strptime(result.interval_start, "%Y-%m-%d %H:%M:%S").replace(tzinfo=dt.UTC), + interval_end=dt.datetime.strptime(result.interval_end, "%Y-%m-%d %H:%M:%S").replace(tzinfo=dt.UTC), + count=result.count, + ) + activity.logger.info(f"Updated {total_rows_updated} BatchExportRuns") + return total_rows_updated + + +@workflow.defn(name="batch-export-monitoring") +class BatchExportMonitoringWorkflow(PostHogWorkflow): + """Workflow to monitor batch exports. + + We have had some issues with batch exports in the past, where some events + have been missing. The purpose of this workflow is to monitor the status of + batch exports for a given customer by reconciling the number of exported + events with the number of events in ClickHouse for a given interval. + """ + + @staticmethod + def parse_inputs(inputs: list[str]) -> BatchExportMonitoringInputs: + """Parse inputs from the management command CLI.""" + loaded = json.loads(inputs[0]) + return BatchExportMonitoringInputs(**loaded) + + @workflow.run + async def run(self, inputs: BatchExportMonitoringInputs): + """Workflow implementation to monitor batch exports for a given team.""" + # TODO - check if this is the right way to do logging since there seems to be a few different ways + workflow.logger.info( + "Starting batch exports monitoring workflow for batch export id %s", inputs.batch_export_id + ) + + batch_export_details = await workflow.execute_activity( + get_batch_export, + inputs.batch_export_id, + start_to_close_timeout=dt.timedelta(minutes=1), + retry_policy=RetryPolicy( + initial_interval=dt.timedelta(seconds=20), + non_retryable_error_types=["BatchExportNotFoundError", "NoValidBatchExportsFoundError"], + ), + ) + + # time interval to check is not the previous hour but the hour before that + # (just to ensure all recent batch exports have run successfully) + now = dt.datetime.now(tz=dt.UTC) + interval_end = now.replace(minute=0, second=0, microsecond=0) - dt.timedelta(hours=1) + interval_start = interval_end - dt.timedelta(hours=1) + interval_end_str = interval_end.strftime("%Y-%m-%d %H:%M:%S") + interval_start_str = interval_start.strftime("%Y-%m-%d %H:%M:%S") + + total_events = await workflow.execute_activity( + get_event_counts, + GetEventCountsInputs( + team_id=batch_export_details.team_id, + interval=batch_export_details.interval, + overall_interval_start=interval_start_str, + overall_interval_end=interval_end_str, + exclude_events=batch_export_details.exclude_events, + include_events=batch_export_details.include_events, + ), + start_to_close_timeout=dt.timedelta(hours=1), + retry_policy=RetryPolicy(maximum_attempts=3, initial_interval=dt.timedelta(seconds=20)), + heartbeat_timeout=dt.timedelta(minutes=1), + ) + + return await workflow.execute_activity( + update_batch_export_runs, + UpdateBatchExportRunsInputs(batch_export_id=batch_export_details.id, results=total_events.results), + start_to_close_timeout=dt.timedelta(hours=1), + retry_policy=RetryPolicy(maximum_attempts=3, initial_interval=dt.timedelta(seconds=20)), + heartbeat_timeout=dt.timedelta(minutes=1), + ) diff --git a/posthog/temporal/tests/batch_exports/conftest.py b/posthog/temporal/tests/batch_exports/conftest.py index 67c321205a14f..7044d8fe96868 100644 --- a/posthog/temporal/tests/batch_exports/conftest.py +++ b/posthog/temporal/tests/batch_exports/conftest.py @@ -152,8 +152,8 @@ async def create_clickhouse_tables_and_views(clickhouse_client, django_db_setup) from posthog.batch_exports.sql import ( CREATE_EVENTS_BATCH_EXPORT_VIEW, CREATE_EVENTS_BATCH_EXPORT_VIEW_BACKFILL, - CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED, CREATE_EVENTS_BATCH_EXPORT_VIEW_RECENT, + CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED, CREATE_PERSONS_BATCH_EXPORT_VIEW, CREATE_PERSONS_BATCH_EXPORT_VIEW_BACKFILL, ) @@ -211,8 +211,12 @@ def data_interval_start(request, data_interval_end, interval): @pytest.fixture -def data_interval_end(interval): +def data_interval_end(request, interval): """Set a test data interval end.""" + try: + return request.param + except AttributeError: + pass return dt.datetime(2023, 4, 25, 15, 0, 0, tzinfo=dt.UTC) diff --git a/posthog/temporal/tests/batch_exports/test_monitoring.py b/posthog/temporal/tests/batch_exports/test_monitoring.py new file mode 100644 index 0000000000000..cab50c25d3177 --- /dev/null +++ b/posthog/temporal/tests/batch_exports/test_monitoring.py @@ -0,0 +1,201 @@ +import datetime as dt +import uuid + +import pytest +import pytest_asyncio +from temporalio.common import RetryPolicy +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import UnsandboxedWorkflowRunner, Worker + +from posthog import constants +from posthog.batch_exports.models import BatchExportRun +from posthog.temporal.batch_exports.monitoring import ( + BatchExportMonitoringInputs, + BatchExportMonitoringWorkflow, + get_batch_export, + get_event_counts, + update_batch_export_runs, +) +from posthog.temporal.tests.utils.models import ( + acreate_batch_export, + adelete_batch_export, + afetch_batch_export_runs, +) + +pytestmark = [pytest.mark.asyncio, pytest.mark.django_db] + +GENERATE_TEST_DATA_END = dt.datetime.now(tz=dt.UTC).replace( + minute=0, second=0, microsecond=0, tzinfo=dt.UTC +) - dt.timedelta(hours=1) +GENERATE_TEST_DATA_START = GENERATE_TEST_DATA_END - dt.timedelta(hours=1) + + +@pytest_asyncio.fixture +async def batch_export(ateam, temporal_client): + """Provide a batch export for tests, not intended to be used.""" + destination_data = { + "type": "S3", + "config": { + "bucket_name": "a-bucket", + "region": "us-east-1", + "prefix": "a-key", + "aws_access_key_id": "object_storage_root_user", + "aws_secret_access_key": "object_storage_root_password", + }, + } + + batch_export_data = { + "name": "my-production-s3-bucket-destination", + "destination": destination_data, + "interval": "every 5 minutes", + } + + batch_export = await acreate_batch_export( + team_id=ateam.pk, + name=batch_export_data["name"], # type: ignore + destination_data=batch_export_data["destination"], # type: ignore + interval=batch_export_data["interval"], # type: ignore + ) + + yield batch_export + + await adelete_batch_export(batch_export, temporal_client) + + +@pytest_asyncio.fixture +async def generate_batch_export_runs( + generate_test_data, + data_interval_start: dt.datetime, + data_interval_end: dt.datetime, + interval: str, + batch_export, +): + # to keep things simple for now, we assume 5 min interval + if interval != "every 5 minutes": + raise NotImplementedError("Only 5 minute intervals are supported for now. Please update the test.") + + events_created, _ = generate_test_data + + batch_export_runs: list[BatchExportRun] = [] + interval_start = data_interval_start + interval_end = interval_start + dt.timedelta(minutes=5) + while interval_end <= data_interval_end: + run = BatchExportRun( + batch_export_id=batch_export.id, + data_interval_start=interval_start, + data_interval_end=interval_end, + status="completed", + records_completed=len( + [ + e + for e in events_created + if interval_start + <= dt.datetime.fromisoformat(e["inserted_at"]).replace(tzinfo=dt.UTC) + < interval_end + ] + ), + ) + await run.asave() + batch_export_runs.append(run) + interval_start = interval_end + interval_end += dt.timedelta(minutes=5) + + yield + + for run in batch_export_runs: + await run.adelete() + + +async def test_monitoring_workflow_when_no_event_data(batch_export): + workflow_id = str(uuid.uuid4()) + inputs = BatchExportMonitoringInputs(batch_export_id=batch_export.id) + async with await WorkflowEnvironment.start_time_skipping() as activity_environment: + async with Worker( + activity_environment.client, + # TODO - not sure if this is the right task queue + task_queue=constants.BATCH_EXPORTS_TASK_QUEUE, + workflows=[BatchExportMonitoringWorkflow], + activities=[ + get_batch_export, + get_event_counts, + update_batch_export_runs, + ], + workflow_runner=UnsandboxedWorkflowRunner(), + ): + batch_export_runs_updated = await activity_environment.client.execute_workflow( + BatchExportMonitoringWorkflow.run, + inputs, + id=workflow_id, + task_queue=constants.BATCH_EXPORTS_TASK_QUEUE, + retry_policy=RetryPolicy(maximum_attempts=1), + execution_timeout=dt.timedelta(seconds=30), + ) + assert batch_export_runs_updated == 0 + + +@pytest.mark.parametrize( + "data_interval_start", + [GENERATE_TEST_DATA_START], + indirect=True, +) +@pytest.mark.parametrize( + "data_interval_end", + [GENERATE_TEST_DATA_END], + indirect=True, +) +@pytest.mark.parametrize( + "interval", + ["every 5 minutes"], + indirect=True, +) +async def test_monitoring_workflow( + batch_export, + generate_test_data, + data_interval_start, + data_interval_end, + interval, + generate_batch_export_runs, +): + """Test the monitoring workflow with a batch export that has data. + + We generate 2 hours of data between 13:00 and 15:00, and then run the + monitoring workflow at 15:30. The monitoring workflow should check the data + between 14:00 and 15:00, and update the batch export runs. + + We generate some dummy batch export runs based on the event data we + generated and assert that the expected records count matches the records + completed. + """ + workflow_id = str(uuid.uuid4()) + inputs = BatchExportMonitoringInputs(batch_export_id=batch_export.id) + async with await WorkflowEnvironment.start_time_skipping() as activity_environment: + async with Worker( + activity_environment.client, + # TODO - not sure if this is the right task queue + task_queue=constants.BATCH_EXPORTS_TASK_QUEUE, + workflows=[BatchExportMonitoringWorkflow], + activities=[ + get_batch_export, + get_event_counts, + update_batch_export_runs, + ], + workflow_runner=UnsandboxedWorkflowRunner(), + ): + await activity_environment.client.execute_workflow( + BatchExportMonitoringWorkflow.run, + inputs, + id=workflow_id, + task_queue=constants.BATCH_EXPORTS_TASK_QUEUE, + retry_policy=RetryPolicy(maximum_attempts=1), + execution_timeout=dt.timedelta(seconds=30), + ) + + batch_export_runs = await afetch_batch_export_runs(batch_export_id=batch_export.id) + + for run in batch_export_runs: + if run.records_completed == 0: + # TODO: in the actual monitoring activity it would be better to + # update the actual count to 0 rather than None + assert run.records_total_count is None + else: + assert run.records_completed == run.records_total_count