+
+ Required scopes are missing: [{missingScopes.join(', ')}].
+ {integration.kind === 'hubspot' ? (
+
+ Note that some features may not be available on your current HubSpot plan. Check out{' '}
+
+ this page
+ {' '}
+ for more details.
+
+ ) : null}
+
+
+ )
+}
diff --git a/frontend/src/lib/integrations/IntegrationView.tsx b/frontend/src/lib/integrations/IntegrationView.tsx
index 31cd12e82eb40..80590299bda4d 100644
--- a/frontend/src/lib/integrations/IntegrationView.tsx
+++ b/frontend/src/lib/integrations/IntegrationView.tsx
@@ -1,15 +1,18 @@
import { LemonBanner } from '@posthog/lemon-ui'
import api from 'lib/api'
import { UserActivityIndicator } from 'lib/components/UserActivityIndicator/UserActivityIndicator'
+import { IntegrationScopesWarning } from 'lib/integrations/IntegrationScopesWarning'
-import { IntegrationType } from '~/types'
+import { HogFunctionInputSchemaType, IntegrationType } from '~/types'
export function IntegrationView({
integration,
suffix,
+ schema,
}: {
integration: IntegrationType
suffix?: JSX.Element
+ schema?: HogFunctionInputSchemaType
}): JSX.Element {
const errors = (integration.errors && integration.errors?.split(',')) || []
@@ -36,7 +39,7 @@ export function IntegrationView({
{suffix}
- {errors.length > 0 && (
+ {errors.length > 0 ? (
+ ) : (
+
)}
)
diff --git a/frontend/src/scenes/pipeline/hogfunctions/integrations/HogFunctionInputIntegration.tsx b/frontend/src/scenes/pipeline/hogfunctions/integrations/HogFunctionInputIntegration.tsx
index f92b2f9123deb..e73b679afcd40 100644
--- a/frontend/src/scenes/pipeline/hogfunctions/integrations/HogFunctionInputIntegration.tsx
+++ b/frontend/src/scenes/pipeline/hogfunctions/integrations/HogFunctionInputIntegration.tsx
@@ -16,6 +16,7 @@ export function HogFunctionInputIntegration({ schema, ...props }: HogFunctionInp
<>
persistForUnload()}
diff --git a/frontend/src/scenes/pipeline/hogfunctions/integrations/IntegrationChoice.tsx b/frontend/src/scenes/pipeline/hogfunctions/integrations/IntegrationChoice.tsx
index cee61f7c80c88..334c17ee3d859 100644
--- a/frontend/src/scenes/pipeline/hogfunctions/integrations/IntegrationChoice.tsx
+++ b/frontend/src/scenes/pipeline/hogfunctions/integrations/IntegrationChoice.tsx
@@ -7,10 +7,13 @@ import { IntegrationView } from 'lib/integrations/IntegrationView'
import { capitalizeFirstLetter } from 'lib/utils'
import { urls } from 'scenes/urls'
+import { HogFunctionInputSchemaType } from '~/types'
+
export type IntegrationConfigureProps = {
value?: number
onChange?: (value: number | null) => void
redirectUrl?: string
+ schema?: HogFunctionInputSchemaType
integration?: string
beforeRedirect?: () => void
}
@@ -18,6 +21,7 @@ export type IntegrationConfigureProps = {
export function IntegrationChoice({
onChange,
value,
+ schema,
integration,
redirectUrl,
beforeRedirect,
@@ -124,5 +128,13 @@ export function IntegrationChoice({
)
- return <>{integrationKind ? : button}>
+ return (
+ <>
+ {integrationKind ? (
+
+ ) : (
+ button
+ )}
+ >
+ )
}
diff --git a/frontend/src/types.ts b/frontend/src/types.ts
index 4747aff1f17c5..e2ac6193659cf 100644
--- a/frontend/src/types.ts
+++ b/frontend/src/types.ts
@@ -4605,6 +4605,7 @@ export type HogFunctionInputSchemaType = {
integration?: string
integration_key?: string
integration_field?: 'slack_channel'
+ requiredScopes?: string
}
export type HogFunctionInputType = {
diff --git a/plugin-server/src/cdp/types.ts b/plugin-server/src/cdp/types.ts
index 8a675e605e017..e9d506a7a7823 100644
--- a/plugin-server/src/cdp/types.ts
+++ b/plugin-server/src/cdp/types.ts
@@ -272,6 +272,7 @@ export type HogFunctionInputSchemaType = {
integration?: string
integration_key?: string
integration_field?: 'slack_channel'
+ requiredScopes?: string
}
export type HogFunctionTypeType = 'destination' | 'email' | 'sms' | 'push' | 'activity' | 'alert' | 'broadcast'
diff --git a/posthog/api/survey.py b/posthog/api/survey.py
index df100f8717b32..835860bb00906 100644
--- a/posthog/api/survey.py
+++ b/posthog/api/survey.py
@@ -325,7 +325,7 @@ def validate(self, data):
if response_sampling_start_date < today_utc:
raise serializers.ValidationError(
{
- "response_sampling_start_date": "Response sampling start date must be today or a future date in UTC."
+ "response_sampling_start_date": f"Response sampling start date must be today or a future date in UTC. Got {response_sampling_start_date} when current time is {today_utc}"
}
)
diff --git a/posthog/api/test/test_survey.py b/posthog/api/test/test_survey.py
index c874d88abcfbc..35b0bb1cdc553 100644
--- a/posthog/api/test/test_survey.py
+++ b/posthog/api/test/test_survey.py
@@ -2378,6 +2378,7 @@ def test_can_clear_associated_actions(self):
assert len(survey.actions.all()) == 0
+@freeze_time("2024-12-12 00:00:00")
class TestSurveyResponseSampling(APIBaseTest):
def _create_survey_with_sampling_limits(
self,
@@ -2407,6 +2408,7 @@ def _create_survey_with_sampling_limits(
)
response_data = response.json()
+ assert response.status_code == status.HTTP_201_CREATED, response_data
survey = Survey.objects.get(id=response_data["id"])
return survey
diff --git a/posthog/batch_exports/service.py b/posthog/batch_exports/service.py
index 59b217b4fc8f2..d17bb3b1b69c3 100644
--- a/posthog/batch_exports/service.py
+++ b/posthog/batch_exports/service.py
@@ -794,3 +794,19 @@ async def aupdate_batch_export_backfill_status(backfill_id: UUID, status: str) -
raise ValueError(f"BatchExportBackfill with id {backfill_id} not found.")
return await model.aget()
+
+
+async def aupdate_records_total_count(
+ batch_export_id: UUID, interval_start: dt.datetime, interval_end: dt.datetime, count: int
+) -> int:
+ """Update the expected records count for a set of batch export runs.
+
+ Typically, there is one batch export run per batch export interval, however
+ there could be multiple if data has been backfilled.
+ """
+ rows_updated = await BatchExportRun.objects.filter(
+ batch_export_id=batch_export_id,
+ data_interval_start=interval_start,
+ data_interval_end=interval_end,
+ ).aupdate(records_total_count=count)
+ return rows_updated
diff --git a/posthog/batch_exports/sql.py b/posthog/batch_exports/sql.py
index baa0216afdbbc..9a7fd0cea95aa 100644
--- a/posthog/batch_exports/sql.py
+++ b/posthog/batch_exports/sql.py
@@ -318,3 +318,22 @@
SETTINGS optimize_aggregation_in_order=1
)
"""
+
+# TODO: is this the best query to use?
+EVENT_COUNT_BY_INTERVAL = """
+SELECT
+ toStartOfInterval(_inserted_at, INTERVAL {interval}) AS interval_start,
+ interval_start + INTERVAL {interval} AS interval_end,
+ COUNT(*) as total_count
+FROM
+ events_batch_export_recent(
+ team_id={team_id},
+ interval_start={overall_interval_start},
+ interval_end={overall_interval_end},
+ include_events={include_events}::Array(String),
+ exclude_events={exclude_events}::Array(String)
+ ) AS events
+GROUP BY interval_start
+ORDER BY interval_start desc
+SETTINGS max_replica_delay_for_distributed_queries=1
+"""
diff --git a/posthog/cdp/templates/google_ads/template_google_ads.py b/posthog/cdp/templates/google_ads/template_google_ads.py
index 3743ca93db541..9cc61d507fe56 100644
--- a/posthog/cdp/templates/google_ads/template_google_ads.py
+++ b/posthog/cdp/templates/google_ads/template_google_ads.py
@@ -55,6 +55,7 @@
"type": "integration",
"integration": "google-ads",
"label": "Google Ads account",
+ "requiredScopes": "https://www.googleapis.com/auth/adwords https://www.googleapis.com/auth/userinfo.email",
"secret": False,
"required": True,
},
diff --git a/posthog/cdp/templates/hubspot/template_hubspot.py b/posthog/cdp/templates/hubspot/template_hubspot.py
index f8f6c9cf06a72..a36c850725972 100644
--- a/posthog/cdp/templates/hubspot/template_hubspot.py
+++ b/posthog/cdp/templates/hubspot/template_hubspot.py
@@ -61,6 +61,7 @@
"type": "integration",
"integration": "hubspot",
"label": "Hubspot connection",
+ "requiredScopes": "crm.objects.contacts.write crm.objects.contacts.read",
"secret": False,
"required": True,
},
@@ -307,6 +308,7 @@
"type": "integration",
"integration": "hubspot",
"label": "Hubspot connection",
+ "requiredScopes": "analytics.behavioral_events.send behavioral_events.event_definitions.read_write",
"secret": False,
"required": True,
},
diff --git a/posthog/cdp/templates/salesforce/template_salesforce.py b/posthog/cdp/templates/salesforce/template_salesforce.py
index eedfd9980efb1..844c86ad15803 100644
--- a/posthog/cdp/templates/salesforce/template_salesforce.py
+++ b/posthog/cdp/templates/salesforce/template_salesforce.py
@@ -15,6 +15,7 @@
"type": "integration",
"integration": "salesforce",
"label": "Salesforce account",
+ "requiredScopes": "refresh_token full",
"secret": False,
"required": True,
}
diff --git a/posthog/cdp/templates/slack/template_slack.py b/posthog/cdp/templates/slack/template_slack.py
index 16bb0383c1c0b..8cfb5a84101de 100644
--- a/posthog/cdp/templates/slack/template_slack.py
+++ b/posthog/cdp/templates/slack/template_slack.py
@@ -34,6 +34,7 @@
"type": "integration",
"integration": "slack",
"label": "Slack workspace",
+ "requiredScopes": "channels:read groups:read chat:write chat:write.customize",
"secret": False,
"required": True,
},
diff --git a/posthog/cdp/validation.py b/posthog/cdp/validation.py
index 0ca2fa353dc26..a0466d8128dab 100644
--- a/posthog/cdp/validation.py
+++ b/posthog/cdp/validation.py
@@ -65,6 +65,7 @@ class InputsSchemaItemSerializer(serializers.Serializer):
integration = serializers.CharField(required=False)
integration_key = serializers.CharField(required=False)
integration_field = serializers.ChoiceField(choices=["slack_channel"], required=False)
+ requiredScopes = serializers.CharField(required=False)
# TODO Validate choices if type=choice
diff --git a/posthog/models/integration.py b/posthog/models/integration.py
index d8e49cc5d67aa..f42f70da332d9 100644
--- a/posthog/models/integration.py
+++ b/posthog/models/integration.py
@@ -163,7 +163,7 @@ def oauth_config_for_kind(cls, kind: str) -> OauthConfig:
authorize_url="https://app.hubspot.com/oauth/authorize",
token_url="https://api.hubapi.com/oauth/v1/token",
token_info_url="https://api.hubapi.com/oauth/v1/access-tokens/:access_token",
- token_info_config_fields=["hub_id", "hub_domain", "user", "user_id"],
+ token_info_config_fields=["hub_id", "hub_domain", "user", "user_id", "scopes"],
client_id=settings.HUBSPOT_APP_CLIENT_ID,
client_secret=settings.HUBSPOT_APP_CLIENT_SECRET,
scope="tickets crm.objects.contacts.write sales-email-read crm.objects.companies.read crm.objects.deals.read crm.objects.contacts.read crm.objects.quotes.read crm.objects.companies.write",
@@ -187,7 +187,7 @@ def oauth_config_for_kind(cls, kind: str) -> OauthConfig:
token_url="https://oauth2.googleapis.com/token",
client_id=settings.SOCIAL_AUTH_GOOGLE_OAUTH2_KEY,
client_secret=settings.SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET,
- scope="https://www.googleapis.com/auth/adwords email",
+ scope="https://www.googleapis.com/auth/adwords https://www.googleapis.com/auth/userinfo.email",
id_path="sub",
name_path="email",
)
diff --git a/posthog/models/test/test_integration_model.py b/posthog/models/test/test_integration_model.py
index 456f085d9c2e9..d4184d9e0265a 100644
--- a/posthog/models/test/test_integration_model.py
+++ b/posthog/models/test/test_integration_model.py
@@ -120,7 +120,7 @@ def test_authorize_url_with_additional_authorize_params(self):
url = OauthIntegration.authorize_url("google-ads", next="/projects/test")
assert (
url
- == "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fadwords+email&redirect_uri=https%3A%2F%2Flocalhost%3A8000%2Fintegrations%2Fgoogle-ads%2Fcallback&response_type=code&state=next%3D%252Fprojects%252Ftest&access_type=offline&prompt=consent"
+ == "https://accounts.google.com/o/oauth2/v2/auth?client_id=google-client-id&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fadwords+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&redirect_uri=https%3A%2F%2Flocalhost%3A8000%2Fintegrations%2Fgoogle-ads%2Fcallback&response_type=code&state=next%3D%252Fprojects%252Ftest&access_type=offline&prompt=consent"
)
@patch("posthog.models.integration.requests.post")
@@ -199,6 +199,10 @@ def test_integration_fetches_info_from_token_info_url(self, mock_get, mock_post)
"user": "user",
"user_id": "user_id",
"should_not": "be_saved",
+ "scopes": [
+ "crm.objects.contacts.read",
+ "crm.objects.contacts.write",
+ ],
}
with freeze_time("2024-01-01T12:00:00Z"):
@@ -219,6 +223,10 @@ def test_integration_fetches_info_from_token_info_url(self, mock_get, mock_post)
"user": "user",
"user_id": "user_id",
"refreshed_at": 1704110400,
+ "scopes": [
+ "crm.objects.contacts.read",
+ "crm.objects.contacts.write",
+ ],
}
assert integration.sensitive_config == {
"access_token": "FAKES_ACCESS_TOKEN",
diff --git a/posthog/temporal/batch_exports/__init__.py b/posthog/temporal/batch_exports/__init__.py
index 33c1b200e6a97..a3616f1107c5b 100644
--- a/posthog/temporal/batch_exports/__init__.py
+++ b/posthog/temporal/batch_exports/__init__.py
@@ -17,6 +17,12 @@
HttpBatchExportWorkflow,
insert_into_http_activity,
)
+from posthog.temporal.batch_exports.monitoring import (
+ BatchExportMonitoringWorkflow,
+ get_batch_export,
+ get_event_counts,
+ update_batch_export_runs,
+)
from posthog.temporal.batch_exports.noop import NoOpWorkflow, noop_activity
from posthog.temporal.batch_exports.postgres_batch_export import (
PostgresBatchExportWorkflow,
@@ -54,6 +60,7 @@
SnowflakeBatchExportWorkflow,
HttpBatchExportWorkflow,
SquashPersonOverridesWorkflow,
+ BatchExportMonitoringWorkflow,
]
ACTIVITIES = [
@@ -76,4 +83,7 @@
update_batch_export_backfill_model_status,
wait_for_mutation,
wait_for_table,
+ get_batch_export,
+ get_event_counts,
+ update_batch_export_runs,
]
diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py
index edaf13d1888af..30c600d210802 100644
--- a/posthog/temporal/batch_exports/bigquery_batch_export.py
+++ b/posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -391,7 +391,17 @@ async def amerge_person_tables(
merge_query = f"""
MERGE `{final_table.full_table_id.replace(":", ".", 1)}` final
- USING `{stage_table.full_table_id.replace(":", ".", 1)}` stage
+ USING (
+ SELECT * FROM
+ (
+ SELECT
+ *,
+ ROW_NUMBER() OVER (PARTITION BY {",".join(field.name for field in merge_key)}) row_num
+ FROM
+ `{stage_table.full_table_id.replace(":", ".", 1)}`
+ )
+ WHERE row_num = 1
+ ) stage
{merge_condition}
WHEN MATCHED AND (stage.`{person_version_key}` > final.`{person_version_key}` OR stage.`{person_distinct_id_version_key}` > final.`{person_distinct_id_version_key}`) THEN
diff --git a/posthog/temporal/batch_exports/monitoring.py b/posthog/temporal/batch_exports/monitoring.py
new file mode 100644
index 0000000000000..97eaf6c2430d9
--- /dev/null
+++ b/posthog/temporal/batch_exports/monitoring.py
@@ -0,0 +1,227 @@
+import datetime as dt
+import json
+from dataclasses import dataclass
+from uuid import UUID
+
+from temporalio import activity, workflow
+from temporalio.common import RetryPolicy
+
+from posthog.batch_exports.models import BatchExport
+from posthog.batch_exports.service import aupdate_records_total_count
+from posthog.batch_exports.sql import EVENT_COUNT_BY_INTERVAL
+from posthog.temporal.batch_exports.base import PostHogWorkflow
+from posthog.temporal.common.clickhouse import get_client
+from posthog.temporal.common.heartbeat import Heartbeater
+
+
+class BatchExportNotFoundError(Exception):
+ """Exception raised when batch export is not found."""
+
+ def __init__(self, batch_export_id: UUID):
+ super().__init__(f"Batch export with id {batch_export_id} not found")
+
+
+class NoValidBatchExportsFoundError(Exception):
+ """Exception raised when no valid batch export is found."""
+
+ def __init__(self, message: str = "No valid batch exports found"):
+ super().__init__(message)
+
+
+@dataclass
+class BatchExportMonitoringInputs:
+ """Inputs for the BatchExportMonitoringWorkflow.
+
+ Attributes:
+ batch_export_id: The batch export id to monitor.
+ """
+
+ batch_export_id: UUID
+
+
+@dataclass
+class BatchExportDetails:
+ id: UUID
+ team_id: int
+ interval: str
+ exclude_events: list[str]
+ include_events: list[str]
+
+
+@activity.defn
+async def get_batch_export(batch_export_id: UUID) -> BatchExportDetails:
+ """Fetch a batch export from the database and return its details."""
+ batch_export = (
+ await BatchExport.objects.filter(id=batch_export_id, model="events", paused=False, deleted=False)
+ .prefetch_related("destination")
+ .afirst()
+ )
+ if batch_export is None:
+ raise BatchExportNotFoundError(batch_export_id)
+ if batch_export.deleted is True:
+ raise NoValidBatchExportsFoundError("Batch export has been deleted")
+ if batch_export.paused is True:
+ raise NoValidBatchExportsFoundError("Batch export is paused")
+ if batch_export.model != "events":
+ raise NoValidBatchExportsFoundError("Batch export model is not 'events'")
+ if batch_export.interval_time_delta != dt.timedelta(minutes=5):
+ raise NoValidBatchExportsFoundError(
+ "Only batch exports with interval of 5 minutes are supported for monitoring at this time."
+ )
+ config = batch_export.destination.config
+ return BatchExportDetails(
+ id=batch_export.id,
+ team_id=batch_export.team_id,
+ interval=batch_export.interval,
+ exclude_events=config.get("exclude_events", []),
+ include_events=config.get("include_events", []),
+ )
+
+
+@dataclass
+class GetEventCountsInputs:
+ team_id: int
+ interval: str
+ overall_interval_start: str
+ overall_interval_end: str
+ exclude_events: list[str]
+ include_events: list[str]
+
+
+@dataclass
+class EventCountsOutput:
+ interval_start: str
+ interval_end: str
+ count: int
+
+
+@dataclass
+class GetEventCountsOutputs:
+ results: list[EventCountsOutput]
+
+
+@activity.defn
+async def get_event_counts(inputs: GetEventCountsInputs) -> GetEventCountsOutputs:
+ """Get the total number of events for a given team over a set of time intervals."""
+
+ query = EVENT_COUNT_BY_INTERVAL
+
+ interval = inputs.interval
+ # we check interval is "every 5 minutes" above but double check here
+ if not interval.startswith("every 5 minutes"):
+ raise NoValidBatchExportsFoundError(
+ "Only intervals of 'every 5 minutes' are supported for monitoring at this time."
+ )
+ _, value, unit = interval.split(" ")
+ interval = f"{value} {unit}"
+
+ query_params = {
+ "team_id": inputs.team_id,
+ "interval": interval,
+ "overall_interval_start": inputs.overall_interval_start,
+ "overall_interval_end": inputs.overall_interval_end,
+ "include_events": inputs.include_events,
+ "exclude_events": inputs.exclude_events,
+ }
+ async with Heartbeater(), get_client() as client:
+ if not await client.is_alive():
+ raise ConnectionError("Cannot establish connection to ClickHouse")
+
+ response = await client.read_query(query, query_params)
+ results = []
+ for line in response.decode("utf-8").splitlines():
+ interval_start, interval_end, count = line.strip().split("\t")
+ results.append(
+ EventCountsOutput(interval_start=interval_start, interval_end=interval_end, count=int(count))
+ )
+
+ return GetEventCountsOutputs(results=results)
+
+
+@dataclass
+class UpdateBatchExportRunsInputs:
+ batch_export_id: UUID
+ results: list[EventCountsOutput]
+
+
+@activity.defn
+async def update_batch_export_runs(inputs: UpdateBatchExportRunsInputs) -> int:
+ """Update BatchExportRuns with the expected number of events."""
+
+ total_rows_updated = 0
+ async with Heartbeater():
+ for result in inputs.results:
+ total_rows_updated += await aupdate_records_total_count(
+ batch_export_id=inputs.batch_export_id,
+ interval_start=dt.datetime.strptime(result.interval_start, "%Y-%m-%d %H:%M:%S").replace(tzinfo=dt.UTC),
+ interval_end=dt.datetime.strptime(result.interval_end, "%Y-%m-%d %H:%M:%S").replace(tzinfo=dt.UTC),
+ count=result.count,
+ )
+ activity.logger.info(f"Updated {total_rows_updated} BatchExportRuns")
+ return total_rows_updated
+
+
+@workflow.defn(name="batch-export-monitoring")
+class BatchExportMonitoringWorkflow(PostHogWorkflow):
+ """Workflow to monitor batch exports.
+
+ We have had some issues with batch exports in the past, where some events
+ have been missing. The purpose of this workflow is to monitor the status of
+ batch exports for a given customer by reconciling the number of exported
+ events with the number of events in ClickHouse for a given interval.
+ """
+
+ @staticmethod
+ def parse_inputs(inputs: list[str]) -> BatchExportMonitoringInputs:
+ """Parse inputs from the management command CLI."""
+ loaded = json.loads(inputs[0])
+ return BatchExportMonitoringInputs(**loaded)
+
+ @workflow.run
+ async def run(self, inputs: BatchExportMonitoringInputs):
+ """Workflow implementation to monitor batch exports for a given team."""
+ # TODO - check if this is the right way to do logging since there seems to be a few different ways
+ workflow.logger.info(
+ "Starting batch exports monitoring workflow for batch export id %s", inputs.batch_export_id
+ )
+
+ batch_export_details = await workflow.execute_activity(
+ get_batch_export,
+ inputs.batch_export_id,
+ start_to_close_timeout=dt.timedelta(minutes=1),
+ retry_policy=RetryPolicy(
+ initial_interval=dt.timedelta(seconds=20),
+ non_retryable_error_types=["BatchExportNotFoundError", "NoValidBatchExportsFoundError"],
+ ),
+ )
+
+ # time interval to check is not the previous hour but the hour before that
+ # (just to ensure all recent batch exports have run successfully)
+ now = dt.datetime.now(tz=dt.UTC)
+ interval_end = now.replace(minute=0, second=0, microsecond=0) - dt.timedelta(hours=1)
+ interval_start = interval_end - dt.timedelta(hours=1)
+ interval_end_str = interval_end.strftime("%Y-%m-%d %H:%M:%S")
+ interval_start_str = interval_start.strftime("%Y-%m-%d %H:%M:%S")
+
+ total_events = await workflow.execute_activity(
+ get_event_counts,
+ GetEventCountsInputs(
+ team_id=batch_export_details.team_id,
+ interval=batch_export_details.interval,
+ overall_interval_start=interval_start_str,
+ overall_interval_end=interval_end_str,
+ exclude_events=batch_export_details.exclude_events,
+ include_events=batch_export_details.include_events,
+ ),
+ start_to_close_timeout=dt.timedelta(hours=1),
+ retry_policy=RetryPolicy(maximum_attempts=3, initial_interval=dt.timedelta(seconds=20)),
+ heartbeat_timeout=dt.timedelta(minutes=1),
+ )
+
+ return await workflow.execute_activity(
+ update_batch_export_runs,
+ UpdateBatchExportRunsInputs(batch_export_id=batch_export_details.id, results=total_events.results),
+ start_to_close_timeout=dt.timedelta(hours=1),
+ retry_policy=RetryPolicy(maximum_attempts=3, initial_interval=dt.timedelta(seconds=20)),
+ heartbeat_timeout=dt.timedelta(minutes=1),
+ )
diff --git a/posthog/temporal/tests/batch_exports/conftest.py b/posthog/temporal/tests/batch_exports/conftest.py
index 67c321205a14f..7044d8fe96868 100644
--- a/posthog/temporal/tests/batch_exports/conftest.py
+++ b/posthog/temporal/tests/batch_exports/conftest.py
@@ -152,8 +152,8 @@ async def create_clickhouse_tables_and_views(clickhouse_client, django_db_setup)
from posthog.batch_exports.sql import (
CREATE_EVENTS_BATCH_EXPORT_VIEW,
CREATE_EVENTS_BATCH_EXPORT_VIEW_BACKFILL,
- CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED,
CREATE_EVENTS_BATCH_EXPORT_VIEW_RECENT,
+ CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED,
CREATE_PERSONS_BATCH_EXPORT_VIEW,
CREATE_PERSONS_BATCH_EXPORT_VIEW_BACKFILL,
)
@@ -211,8 +211,12 @@ def data_interval_start(request, data_interval_end, interval):
@pytest.fixture
-def data_interval_end(interval):
+def data_interval_end(request, interval):
"""Set a test data interval end."""
+ try:
+ return request.param
+ except AttributeError:
+ pass
return dt.datetime(2023, 4, 25, 15, 0, 0, tzinfo=dt.UTC)
diff --git a/posthog/temporal/tests/batch_exports/test_monitoring.py b/posthog/temporal/tests/batch_exports/test_monitoring.py
new file mode 100644
index 0000000000000..cab50c25d3177
--- /dev/null
+++ b/posthog/temporal/tests/batch_exports/test_monitoring.py
@@ -0,0 +1,201 @@
+import datetime as dt
+import uuid
+
+import pytest
+import pytest_asyncio
+from temporalio.common import RetryPolicy
+from temporalio.testing import WorkflowEnvironment
+from temporalio.worker import UnsandboxedWorkflowRunner, Worker
+
+from posthog import constants
+from posthog.batch_exports.models import BatchExportRun
+from posthog.temporal.batch_exports.monitoring import (
+ BatchExportMonitoringInputs,
+ BatchExportMonitoringWorkflow,
+ get_batch_export,
+ get_event_counts,
+ update_batch_export_runs,
+)
+from posthog.temporal.tests.utils.models import (
+ acreate_batch_export,
+ adelete_batch_export,
+ afetch_batch_export_runs,
+)
+
+pytestmark = [pytest.mark.asyncio, pytest.mark.django_db]
+
+GENERATE_TEST_DATA_END = dt.datetime.now(tz=dt.UTC).replace(
+ minute=0, second=0, microsecond=0, tzinfo=dt.UTC
+) - dt.timedelta(hours=1)
+GENERATE_TEST_DATA_START = GENERATE_TEST_DATA_END - dt.timedelta(hours=1)
+
+
+@pytest_asyncio.fixture
+async def batch_export(ateam, temporal_client):
+ """Provide a batch export for tests, not intended to be used."""
+ destination_data = {
+ "type": "S3",
+ "config": {
+ "bucket_name": "a-bucket",
+ "region": "us-east-1",
+ "prefix": "a-key",
+ "aws_access_key_id": "object_storage_root_user",
+ "aws_secret_access_key": "object_storage_root_password",
+ },
+ }
+
+ batch_export_data = {
+ "name": "my-production-s3-bucket-destination",
+ "destination": destination_data,
+ "interval": "every 5 minutes",
+ }
+
+ batch_export = await acreate_batch_export(
+ team_id=ateam.pk,
+ name=batch_export_data["name"], # type: ignore
+ destination_data=batch_export_data["destination"], # type: ignore
+ interval=batch_export_data["interval"], # type: ignore
+ )
+
+ yield batch_export
+
+ await adelete_batch_export(batch_export, temporal_client)
+
+
+@pytest_asyncio.fixture
+async def generate_batch_export_runs(
+ generate_test_data,
+ data_interval_start: dt.datetime,
+ data_interval_end: dt.datetime,
+ interval: str,
+ batch_export,
+):
+ # to keep things simple for now, we assume 5 min interval
+ if interval != "every 5 minutes":
+ raise NotImplementedError("Only 5 minute intervals are supported for now. Please update the test.")
+
+ events_created, _ = generate_test_data
+
+ batch_export_runs: list[BatchExportRun] = []
+ interval_start = data_interval_start
+ interval_end = interval_start + dt.timedelta(minutes=5)
+ while interval_end <= data_interval_end:
+ run = BatchExportRun(
+ batch_export_id=batch_export.id,
+ data_interval_start=interval_start,
+ data_interval_end=interval_end,
+ status="completed",
+ records_completed=len(
+ [
+ e
+ for e in events_created
+ if interval_start
+ <= dt.datetime.fromisoformat(e["inserted_at"]).replace(tzinfo=dt.UTC)
+ < interval_end
+ ]
+ ),
+ )
+ await run.asave()
+ batch_export_runs.append(run)
+ interval_start = interval_end
+ interval_end += dt.timedelta(minutes=5)
+
+ yield
+
+ for run in batch_export_runs:
+ await run.adelete()
+
+
+async def test_monitoring_workflow_when_no_event_data(batch_export):
+ workflow_id = str(uuid.uuid4())
+ inputs = BatchExportMonitoringInputs(batch_export_id=batch_export.id)
+ async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
+ async with Worker(
+ activity_environment.client,
+ # TODO - not sure if this is the right task queue
+ task_queue=constants.BATCH_EXPORTS_TASK_QUEUE,
+ workflows=[BatchExportMonitoringWorkflow],
+ activities=[
+ get_batch_export,
+ get_event_counts,
+ update_batch_export_runs,
+ ],
+ workflow_runner=UnsandboxedWorkflowRunner(),
+ ):
+ batch_export_runs_updated = await activity_environment.client.execute_workflow(
+ BatchExportMonitoringWorkflow.run,
+ inputs,
+ id=workflow_id,
+ task_queue=constants.BATCH_EXPORTS_TASK_QUEUE,
+ retry_policy=RetryPolicy(maximum_attempts=1),
+ execution_timeout=dt.timedelta(seconds=30),
+ )
+ assert batch_export_runs_updated == 0
+
+
+@pytest.mark.parametrize(
+ "data_interval_start",
+ [GENERATE_TEST_DATA_START],
+ indirect=True,
+)
+@pytest.mark.parametrize(
+ "data_interval_end",
+ [GENERATE_TEST_DATA_END],
+ indirect=True,
+)
+@pytest.mark.parametrize(
+ "interval",
+ ["every 5 minutes"],
+ indirect=True,
+)
+async def test_monitoring_workflow(
+ batch_export,
+ generate_test_data,
+ data_interval_start,
+ data_interval_end,
+ interval,
+ generate_batch_export_runs,
+):
+ """Test the monitoring workflow with a batch export that has data.
+
+ We generate 2 hours of data between 13:00 and 15:00, and then run the
+ monitoring workflow at 15:30. The monitoring workflow should check the data
+ between 14:00 and 15:00, and update the batch export runs.
+
+ We generate some dummy batch export runs based on the event data we
+ generated and assert that the expected records count matches the records
+ completed.
+ """
+ workflow_id = str(uuid.uuid4())
+ inputs = BatchExportMonitoringInputs(batch_export_id=batch_export.id)
+ async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
+ async with Worker(
+ activity_environment.client,
+ # TODO - not sure if this is the right task queue
+ task_queue=constants.BATCH_EXPORTS_TASK_QUEUE,
+ workflows=[BatchExportMonitoringWorkflow],
+ activities=[
+ get_batch_export,
+ get_event_counts,
+ update_batch_export_runs,
+ ],
+ workflow_runner=UnsandboxedWorkflowRunner(),
+ ):
+ await activity_environment.client.execute_workflow(
+ BatchExportMonitoringWorkflow.run,
+ inputs,
+ id=workflow_id,
+ task_queue=constants.BATCH_EXPORTS_TASK_QUEUE,
+ retry_policy=RetryPolicy(maximum_attempts=1),
+ execution_timeout=dt.timedelta(seconds=30),
+ )
+
+ batch_export_runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
+
+ for run in batch_export_runs:
+ if run.records_completed == 0:
+ # TODO: in the actual monitoring activity it would be better to
+ # update the actual count to 0 rather than None
+ assert run.records_total_count is None
+ else:
+ assert run.records_completed == run.records_total_count