diff --git a/cypress/e2e/alerts.cy.ts b/cypress/e2e/alerts.cy.ts index b55f1e09b9494..bd6ca01bcb734 100644 --- a/cypress/e2e/alerts.cy.ts +++ b/cypress/e2e/alerts.cy.ts @@ -16,7 +16,8 @@ describe('Alerts', () => { const createAlert = ( name: string = 'Alert name', lowerThreshold: string = '100', - upperThreshold: string = '200' + upperThreshold: string = '200', + condition?: string ): void => { cy.get('[data-attr=more-button]').click() cy.contains('Manage alerts').click() @@ -24,6 +25,13 @@ describe('Alerts', () => { cy.get('[data-attr=alertForm-name]').clear().type(name) cy.get('[data-attr=subscribed-users').click().type('{downarrow}{enter}') + + if (condition) { + cy.get('[data-attr=alertForm-condition').click() + cy.contains(condition).click() + cy.contains('%').click() + } + cy.get('[data-attr=alertForm-lower-threshold').clear().type(lowerThreshold) cy.get('[data-attr=alertForm-upper-threshold').clear().type(upperThreshold) cy.contains('Create alert').click() @@ -39,7 +47,6 @@ describe('Alerts', () => { cy.get('[data-attr=insight-edit-button]').click() cy.get('[data-attr=chart-filter]').click() cy.contains(displayType).click() - cy.get('.insight-empty-state').should('not.exist') cy.get('[data-attr=insight-save-button]').contains('Save').click() cy.url().should('not.include', '/edit') } @@ -69,7 +76,7 @@ describe('Alerts', () => { }) it('Should warn about an alert deletion', () => { - setInsightDisplayTypeAndSave('Number') + setInsightDisplayTypeAndSave('Area chart') createAlert('Alert to be deleted because of a changed insight') @@ -90,4 +97,28 @@ describe('Alerts', () => { cy.contains('Manage alerts').click() cy.contains('Alert to be deleted because of a changed insight').should('not.exist') }) + + it('Should allow create and delete a relative alert', () => { + cy.get('[data-attr=more-button]').click() + // Alerts should be disabled for trends represented with graphs + cy.get('[data-attr=manage-alerts-button]').should('have.attr', 'aria-disabled', 'true') + + setInsightDisplayTypeAndSave('Bar chart') + + createAlert('Alert name', '10', '20', 'increases by') + cy.reload() + + // Check the alert has the same values as when it was created + cy.get('[data-attr=more-button]').click() + cy.contains('Manage alerts').click() + cy.get('[data-attr=alert-list-item]').contains('Alert name').click() + cy.get('[data-attr=alertForm-name]').should('have.value', 'Alert name') + cy.get('[data-attr=alertForm-lower-threshold').should('have.value', '10') + cy.get('[data-attr=alertForm-upper-threshold').should('have.value', '20') + cy.contains('Delete alert').click() + cy.wait(2000) + + cy.reload() + cy.contains('Alert name').should('not.exist') + }) }) diff --git a/cypress/e2e/notebooks-insights.ts b/cypress/e2e/notebooks-insights.ts new file mode 100644 index 0000000000000..0b007744576c6 --- /dev/null +++ b/cypress/e2e/notebooks-insights.ts @@ -0,0 +1,18 @@ +import { insight, savedInsights } from '../productAnalytics' + +describe('Notebooks', () => { + beforeEach(() => { + cy.clickNavMenu('notebooks') + cy.location('pathname').should('include', '/notebooks') + }) + ;['SQL', 'TRENDS', 'FUNNELS', 'RETENTION', 'PATHS', 'STICKINESS', 'LIFECYCLE'].forEach((insightType) => { + it(`Can add a ${insightType} insight`, () => { + savedInsights.createNewInsightOfType(insightType) + insight.editName(`${insightType} Insight`) + insight.save() + cy.get('[data-attr="notebooks-add-button"]').click() + cy.get('[data-attr="notebooks-select-button-create"]').click() + cy.get('.ErrorBoundary').should('not.exist') + }) + }) 
+}) diff --git a/frontend/__snapshots__/scenes-app-pipeline--pipeline-node-new-hog-function--dark.png b/frontend/__snapshots__/scenes-app-pipeline--pipeline-node-new-hog-function--dark.png index 8ef8f708a388a..b8b549ad31a31 100644 Binary files a/frontend/__snapshots__/scenes-app-pipeline--pipeline-node-new-hog-function--dark.png and b/frontend/__snapshots__/scenes-app-pipeline--pipeline-node-new-hog-function--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-pipeline--pipeline-node-new-hog-function--light.png b/frontend/__snapshots__/scenes-app-pipeline--pipeline-node-new-hog-function--light.png index eb2946a53bac1..7df6f432cbc78 100644 Binary files a/frontend/__snapshots__/scenes-app-pipeline--pipeline-node-new-hog-function--light.png and b/frontend/__snapshots__/scenes-app-pipeline--pipeline-node-new-hog-function--light.png differ diff --git a/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--light.png b/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--light.png index ca494887f274e..b0d04d6d0bca7 100644 Binary files a/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--light.png and b/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--light.png differ diff --git a/frontend/src/lib/components/Alerts/SnoozeButton.tsx b/frontend/src/lib/components/Alerts/SnoozeButton.tsx new file mode 100644 index 0000000000000..28516638f209c --- /dev/null +++ b/frontend/src/lib/components/Alerts/SnoozeButton.tsx @@ -0,0 +1,43 @@ +import { dayjs } from 'lib/dayjs' +import { formatDate } from 'lib/utils' + +import { DateFilter } from '../DateFilter/DateFilter' + +const DATETIME_FORMAT = 'MMM D - HH:mm' + +interface SnoozeButtonProps { + onChange: (snoonzeUntil: string) => void + value?: string +} + +export function SnoozeButton({ onChange, value }: SnoozeButtonProps): JSX.Element { + return ( + { + snoozeUntil && onChange(snoozeUntil) + }} + placeholder="Snooze until" + max={31} + isFixedDateMode + showRollingRangePicker={false} + allowedRollingDateOptions={['days', 'weeks', 'months', 'years']} + showCustom + dateOptions={[ + { + key: 'Tomorrow', + values: ['+1d'], + getFormattedDate: (date: dayjs.Dayjs): string => formatDate(date.add(1, 'd'), DATETIME_FORMAT), + defaultInterval: 'day', + }, + { + key: 'One week from now', + values: ['+1w'], + getFormattedDate: (date: dayjs.Dayjs): string => formatDate(date.add(1, 'w'), DATETIME_FORMAT), + defaultInterval: 'day', + }, + ]} + size="medium" + /> + ) +} diff --git a/frontend/src/lib/components/Alerts/alertFormLogic.ts b/frontend/src/lib/components/Alerts/alertFormLogic.ts index 4230dc9238d01..3c0ab234a8ae1 100644 --- a/frontend/src/lib/components/Alerts/alertFormLogic.ts +++ b/frontend/src/lib/components/Alerts/alertFormLogic.ts @@ -3,7 +3,7 @@ import { forms } from 'kea-forms' import api from 'lib/api' import { lemonToast } from 'lib/lemon-ui/LemonToast/LemonToast' -import { AlertCalculationInterval } from '~/queries/schema' +import { AlertCalculationInterval, AlertConditionType, InsightThresholdType } from '~/queries/schema' import { QueryBasedInsightModel } from '~/types' import type { alertFormLogicType } from './alertFormLogicType' @@ -11,7 +11,7 @@ import { AlertType, AlertTypeWrite } from './types' export type AlertFormType = Pick< AlertType, - 'name' | 'enabled' | 'created_at' | 'threshold' | 'subscribed_users' | 'checks' | 'config' + 'name' | 'enabled' | 'created_at' | 'threshold' | 'condition' | 'subscribed_users' | 'checks' | 'config' > & { id?: AlertType['id'] created_by?: 
AlertType['created_by'] | null @@ -31,6 +31,8 @@ export const alertFormLogic = kea([ actions({ deleteAlert: true, + snoozeAlert: (snoozeUntil: string) => ({ snoozeUntil }), + clearSnooze: true, }), forms(({ props }) => ({ @@ -47,10 +49,9 @@ export const alertFormLogic = kea([ type: 'TrendsAlertConfig', series_index: 0, }, - threshold: { - configuration: { - absoluteThreshold: {}, - }, + threshold: { configuration: { type: InsightThresholdType.ABSOLUTE, bounds: {} } }, + condition: { + type: AlertConditionType.ABSOLUTE_VALUE, }, subscribed_users: [], checks: [], @@ -61,12 +62,17 @@ export const alertFormLogic = kea([ name: !name ? 'You need to give your alert a name' : undefined, }), submit: async (alert) => { - const payload: Partial = { + const payload: AlertTypeWrite = { ...alert, subscribed_users: alert.subscribed_users?.map(({ id }) => id), insight: props.insightId, } + // absolute value alert can only have absolute threshold + if (payload.condition.type === AlertConditionType.ABSOLUTE_VALUE) { + payload.threshold.configuration.type = InsightThresholdType.ABSOLUTE + } + try { if (alert.id === undefined) { const updatedAlert: AlertType = await api.alerts.create(payload) @@ -101,5 +107,21 @@ export const alertFormLogic = kea([ await api.alerts.delete(values.alertForm.id) props.onEditSuccess() }, + snoozeAlert: async ({ snoozeUntil }) => { + // resolution only allowed on created alert (which will have alertId) + if (!values.alertForm.id) { + throw new Error("Cannot resolve alert that doesn't exist") + } + await api.alerts.update(values.alertForm.id, { snoozed_until: snoozeUntil }) + props.onEditSuccess() + }, + clearSnooze: async () => { + // resolution only allowed on created alert (which will have alertId) + if (!values.alertForm.id) { + throw new Error("Cannot resolve alert that doesn't exist") + } + await api.alerts.update(values.alertForm.id, { snoozed_until: null }) + props.onEditSuccess() + }, })), ]) diff --git a/frontend/src/lib/components/Alerts/insightAlertsLogic.ts b/frontend/src/lib/components/Alerts/insightAlertsLogic.ts index dd6a09a29d08c..6bca4dc317fa1 100644 --- a/frontend/src/lib/components/Alerts/insightAlertsLogic.ts +++ b/frontend/src/lib/components/Alerts/insightAlertsLogic.ts @@ -3,7 +3,7 @@ import { loaders } from 'kea-loaders' import api from 'lib/api' import { insightVizDataLogic } from 'scenes/insights/insightVizDataLogic' -import { GoalLine } from '~/queries/schema' +import { GoalLine, InsightThresholdType } from '~/queries/schema' import { getBreakdown, isInsightVizNode, isTrendsQuery } from '~/queries/utils' import { InsightLogicProps } from '~/types' @@ -65,21 +65,27 @@ export const insightAlertsLogic = kea([ (s) => [s.alerts], (alerts: AlertType[]): GoalLine[] => alerts.flatMap((alert) => { - const thresholds = [] + if ( + alert.threshold.configuration.type !== InsightThresholdType.ABSOLUTE || + !alert.threshold.configuration.bounds + ) { + return [] + } - const absoluteThreshold = alert.threshold.configuration.absoluteThreshold + const bounds = alert.threshold.configuration.bounds - if (absoluteThreshold?.upper !== undefined) { + const thresholds = [] + if (bounds?.upper !== undefined) { thresholds.push({ label: `${alert.name} Upper Threshold`, - value: absoluteThreshold?.upper, + value: bounds?.upper, }) } - if (absoluteThreshold?.lower !== undefined) { + if (bounds?.lower !== undefined) { thresholds.push({ label: `${alert.name} Lower Threshold`, - value: absoluteThreshold?.lower, + value: bounds?.lower, }) } diff --git 
a/frontend/src/lib/components/Alerts/types.ts b/frontend/src/lib/components/Alerts/types.ts index 864c2a2321909..4641d7fe0728f 100644 --- a/frontend/src/lib/components/Alerts/types.ts +++ b/frontend/src/lib/components/Alerts/types.ts @@ -12,6 +12,7 @@ export type AlertConfig = TrendsAlertConfig export interface AlertTypeBase { name: string condition: AlertCondition + threshold: { configuration: InsightThreshold } enabled: boolean insight: QueryBasedInsightModel config: AlertConfig @@ -20,6 +21,7 @@ export interface AlertTypeBase { export interface AlertTypeWrite extends Omit { subscribed_users: number[] insight: number + snoozed_until?: string | null } export interface AlertCheck { @@ -33,7 +35,7 @@ export interface AlertCheck { export interface AlertType extends AlertTypeBase { id: string subscribed_users: UserBasicType[] - threshold: { configuration: InsightThreshold } + condition: AlertCondition created_by: UserBasicType created_at: string state: AlertState @@ -41,4 +43,5 @@ export interface AlertType extends AlertTypeBase { last_checked_at: string checks: AlertCheck[] calculation_interval: AlertCalculationInterval + snoozed_until?: string } diff --git a/frontend/src/lib/components/Alerts/views/EditAlertModal.tsx b/frontend/src/lib/components/Alerts/views/EditAlertModal.tsx index 9a0c568bda465..b3c63ea6973e6 100644 --- a/frontend/src/lib/components/Alerts/views/EditAlertModal.tsx +++ b/frontend/src/lib/components/Alerts/views/EditAlertModal.tsx @@ -1,22 +1,24 @@ -import { LemonBanner, LemonCheckbox, LemonInput, LemonSelect, SpinnerOverlay } from '@posthog/lemon-ui' +import { LemonCheckbox, LemonInput, LemonSegmentedButton, LemonSelect, SpinnerOverlay } from '@posthog/lemon-ui' import { useActions, useValues } from 'kea' import { Form, Group } from 'kea-forms' import { AlertStateIndicator } from 'lib/components/Alerts/views/ManageAlertsModal' import { MemberSelectMultiple } from 'lib/components/MemberSelectMultiple' import { TZLabel } from 'lib/components/TZLabel' import { UserActivityIndicator } from 'lib/components/UserActivityIndicator/UserActivityIndicator' +import { dayjs } from 'lib/dayjs' import { IconChevronLeft } from 'lib/lemon-ui/icons' import { LemonButton } from 'lib/lemon-ui/LemonButton' import { LemonField } from 'lib/lemon-ui/LemonField' import { LemonModal } from 'lib/lemon-ui/LemonModal' -import { alphabet } from 'lib/utils' +import { alphabet, formatDate } from 'lib/utils' import { trendsDataLogic } from 'scenes/trends/trendsDataLogic' -import { AlertCalculationInterval } from '~/queries/schema' +import { AlertCalculationInterval, AlertConditionType, AlertState, InsightThresholdType } from '~/queries/schema' import { InsightShortId, QueryBasedInsightModel } from '~/types' import { alertFormLogic } from '../alertFormLogic' import { alertLogic } from '../alertLogic' +import { SnoozeButton } from '../SnoozeButton' import { AlertType } from '../types' export function AlertStateTable({ alert }: { alert: AlertType }): JSX.Element | null { @@ -27,7 +29,8 @@ export function AlertStateTable({ alert }: { alert: AlertType }): JSX.Element | return (

-                    Current status {alert.state}
+                    Current status - {alert.state}
+                    {alert.snoozed_until && ` until ${formatDate(dayjs(alert?.snoozed_until), 'MMM D, HH:mm')}`}{' '}
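For orientation while reading the form hunks below, here is a minimal sketch (not part of the diff) of the alert shapes this PR introduces. The enum and interface names come from the schema.ts / schema.json changes further down in this diff; exampleRelativeAlert is purely illustrative.

// Sketch only, assuming the shapes defined by this PR's schema changes.
enum InsightThresholdType {
    ABSOLUTE = 'absolute',
    PERCENTAGE = 'percentage',
}

enum AlertConditionType {
    ABSOLUTE_VALUE = 'absolute_value',
    RELATIVE_INCREASE = 'relative_increase',
    RELATIVE_DECREASE = 'relative_decrease',
}

interface InsightsThresholdBounds {
    lower?: number
    upper?: number
}

interface InsightThreshold {
    type: InsightThresholdType
    bounds?: InsightsThresholdBounds
}

// Illustrative "increases by %" alert as the form would submit it:
// percentage bounds are stored as fractions (20% -> 0.2), which is why the
// threshold inputs in the hunks below divide the typed value by 100.
const exampleRelativeAlert = {
    condition: { type: AlertConditionType.RELATIVE_INCREASE },
    threshold: {
        configuration: {
            type: InsightThresholdType.PERCENTAGE,
            bounds: { lower: 0.1, upper: 0.2 },
        } as InsightThreshold,
    },
}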

@@ -78,11 +81,11 @@ export function EditAlertModal({ const formLogicProps = { alert, insightId, onEditSuccess } const formLogic = alertFormLogic(formLogicProps) const { alertForm, isAlertFormSubmitting, alertFormChanged } = useValues(formLogic) - const { deleteAlert } = useActions(formLogic) + const { deleteAlert, snoozeAlert, clearSnooze } = useActions(formLogic) const { setAlertFormValue } = useActions(formLogic) const trendsLogic = trendsDataLogic({ dashboardItemId: insightShortId }) - const { alertSeries, breakdownFilter } = useValues(trendsLogic) + const { alertSeries, isNonTimeSeriesDisplay } = useValues(trendsLogic) const creatingNewAlert = alertForm.id === undefined @@ -107,112 +110,220 @@ export function EditAlertModal({ -
- {alert?.created_by ? ( - - ) : null} - - - - - - - - - - {breakdownFilter && ( - - - Alerts on insights with breakdowns alert when any of the breakdown values - breaches the threshold - - - )} - - - - ({ - label: `${alphabet[index]} - ${event}`, - value: index, - }))} - /> - - - - - ['hourly', 'daily'].includes(interval)) - .map((interval) => ({ - label: interval, - value: interval, - }))} - /> - - - - - - +
+
+
+ + - - + - - +
+ {alert?.created_by ? ( + + ) : null} +
- u.id) ?? []} - idKey="id" - onChange={(value) => setAlertFormValue('subscribed_users', value)} - /> +
+

Definition

+
+
+
When
+ + + ({ + label: `${alphabet[index]} - ${event}`, + value: index, + }))} + /> + + + + + + + +
+
+
less than
+ + + setAlertFormValue('threshold', { + configuration: { + type: alertForm.threshold.configuration.type, + bounds: { + ...alertForm.threshold.configuration.bounds, + lower: + value && + alertForm.threshold.configuration.type === + InsightThresholdType.PERCENTAGE + ? value / 100 + : value, + }, + }, + }) + } + /> + +
or more than
+ + + setAlertFormValue('threshold', { + configuration: { + type: alertForm.threshold.configuration.type, + bounds: { + ...alertForm.threshold.configuration.bounds, + upper: + value && + alertForm.threshold.configuration.type === + InsightThresholdType.PERCENTAGE + ? value / 100 + : value, + }, + }, + }) + } + /> + + {alertForm.condition.type !== AlertConditionType.ABSOLUTE_VALUE && ( + + + + + + )} +
+
+
+ {alertForm.condition.type === AlertConditionType.ABSOLUTE_VALUE + ? 'check' + : 'compare'} +
+ + ({ + label: interval, + value: interval, + }))} + /> + +
and notify
+
+ u.id) ?? []} + idKey="id" + onChange={(value) => setAlertFormValue('subscribed_users', value)} + /> +
+
+
+
{alert && }
- {!creatingNewAlert ? ( - - Delete alert - - ) : null} +
+ {!creatingNewAlert ? ( + + Delete alert + + ) : null} + {!creatingNewAlert && alert?.state === AlertState.FIRING ? ( + + ) : null} + {!creatingNewAlert && alert?.state === AlertState.SNOOZED ? ( + + Clear snooze + + ) : null} +
- - Cancel - - -
- ) : ( + return alert.state === AlertState.FIRING ? ( + ) : ( + + + ) } @@ -32,7 +32,9 @@ interface AlertListItemProps { } export function AlertListItem({ alert, onClick }: AlertListItemProps): JSX.Element { - const absoluteThreshold = alert.threshold?.configuration?.absoluteThreshold + const bounds = alert.threshold?.configuration?.bounds + const isPercentage = alert.threshold?.configuration.type === InsightThresholdType.PERCENTAGE + return (
@@ -42,9 +44,11 @@ export function AlertListItem({ alert, onClick }: AlertListItemProps): JSX.Eleme {alert.enabled ? (
-                        {absoluteThreshold?.lower && `Low ${absoluteThreshold.lower}`}
-                        {absoluteThreshold?.lower && absoluteThreshold?.upper ? ' · ' : ''}
-                        {absoluteThreshold?.upper && `High ${absoluteThreshold.upper}`}
+                        {bounds?.lower &&
+                            `Low ${isPercentage ? bounds.lower * 100 : bounds.lower}${isPercentage ? '%' : ''}`}
+                        {bounds?.lower && bounds?.upper ? ' · ' : ''}
+                        {bounds?.upper &&
+                            `High ${isPercentage ? bounds.upper * 100 : bounds.upper}${isPercentage ? '%' : ''}`}
) : (
Disabled
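A short sketch (again not part of the diff) of the percentage round-trip that the threshold inputs in EditAlertModal and the list item above rely on. The helper names are hypothetical (the PR inlines this logic in the components) and InsightThresholdType is reused from the sketch earlier.

// Hypothetical helpers illustrating the stored-fraction vs displayed-percent conversion.
const toStoredBound = (input: number, type: InsightThresholdType): number =>
    type === InsightThresholdType.PERCENTAGE ? input / 100 : input

const toDisplayBound = (stored: number, type: InsightThresholdType): string =>
    type === InsightThresholdType.PERCENTAGE ? `${stored * 100}%` : `${stored}`

// e.g. a user types 20 for a percentage threshold -> 0.2 is stored,
// and 0.2 renders back as "Low 20%" / "High 20%" in the alert list item.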
diff --git a/frontend/src/lib/components/DateFilter/DateFilter.tsx b/frontend/src/lib/components/DateFilter/DateFilter.tsx index e8597357d3e58..d3a35d762b144 100644 --- a/frontend/src/lib/components/DateFilter/DateFilter.tsx +++ b/frontend/src/lib/components/DateFilter/DateFilter.tsx @@ -38,6 +38,7 @@ export interface DateFilterProps { dropdownPlacement?: Placement /* True when we're not dealing with ranges, but a single date / relative date */ isFixedDateMode?: boolean + placeholder?: string } interface RawDateFilterProps extends DateFilterProps { dateFrom?: string | null | dayjs.Dayjs @@ -62,6 +63,7 @@ export function DateFilter({ max, isFixedDateMode = false, allowedRollingDateOptions, + placeholder, }: RawDateFilterProps): JSX.Element { const key = useRef(uuid()).current const logicProps: DateFilterLogicProps = { @@ -72,6 +74,7 @@ export function DateFilter({ dateOptions, isDateFormatted, isFixedDateMode, + placeholder, } const { open, diff --git a/frontend/src/lib/components/DateFilter/dateFilterLogic.ts b/frontend/src/lib/components/DateFilter/dateFilterLogic.ts index 0a1f3680dbc1b..7d8593963d7b7 100644 --- a/frontend/src/lib/components/DateFilter/dateFilterLogic.ts +++ b/frontend/src/lib/components/DateFilter/dateFilterLogic.ts @@ -112,8 +112,9 @@ export const dateFilterLogic = kea([ s.isFixedDate, s.dateOptions, (_, p) => p.isFixedDateMode, + (_, p) => p.placeholder, ], - (dateFrom, dateTo, isFixedRange, isDateToNow, isFixedDate, dateOptions, isFixedDateMode) => + (dateFrom, dateTo, isFixedRange, isDateToNow, isFixedDate, dateOptions, isFixedDateMode, placeholder) => isFixedRange ? formatDateRange(dayjs(dateFrom), dayjs(dateTo)) : isDateToNow @@ -123,7 +124,9 @@ export const dateFilterLogic = kea([ : dateFilterToText( dateFrom, dateTo, - isFixedDateMode ? SELECT_FIXED_VALUE_PLACEHOLDER : NO_OVERRIDE_RANGE_PLACEHOLDER, + isFixedDateMode + ? placeholder ?? 
SELECT_FIXED_VALUE_PLACEHOLDER + : NO_OVERRIDE_RANGE_PLACEHOLDER, dateOptions, false ), diff --git a/frontend/src/lib/components/DateFilter/types.ts b/frontend/src/lib/components/DateFilter/types.ts index 3ebdb781b7c8c..2e95131e9cb34 100644 --- a/frontend/src/lib/components/DateFilter/types.ts +++ b/frontend/src/lib/components/DateFilter/types.ts @@ -17,6 +17,7 @@ export type DateFilterLogicProps = { dateOptions?: DateMappingOption[] isDateFormatted?: boolean isFixedDateMode?: boolean + placeholder?: string } export const CUSTOM_OPTION_KEY = 'Custom' diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 266ca1ec918be..3e3645da4a059 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -401,10 +401,20 @@ }, "AlertCondition": { "additionalProperties": false, + "properties": { + "type": { + "$ref": "#/definitions/AlertConditionType" + } + }, + "required": ["type"], "type": "object" }, + "AlertConditionType": { + "enum": ["absolute_value", "relative_increase", "relative_decrease"], + "type": "string" + }, "AlertState": { - "enum": ["Firing", "Not firing", "Errored"], + "enum": ["Firing", "Not firing", "Errored", "Snoozed"], "type": "string" }, "AnyDataNode": { @@ -6908,12 +6918,20 @@ "InsightThreshold": { "additionalProperties": false, "properties": { - "absoluteThreshold": { - "$ref": "#/definitions/InsightsThresholdAbsolute" + "bounds": { + "$ref": "#/definitions/InsightsThresholdBounds" + }, + "type": { + "$ref": "#/definitions/InsightThresholdType" } }, + "required": ["type"], "type": "object" }, + "InsightThresholdType": { + "enum": ["absolute", "percentage"], + "type": "string" + }, "InsightVizNode": { "additionalProperties": false, "properties": { @@ -7216,7 +7234,7 @@ "required": ["kind"], "type": "object" }, - "InsightsThresholdAbsolute": { + "InsightsThresholdBounds": { "additionalProperties": false, "properties": { "lower": { diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index e156f17d82b5d..17e84a783f3e9 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -1975,25 +1975,38 @@ export interface DashboardFilter { properties?: AnyPropertyFilter[] | null } -export interface InsightsThresholdAbsolute { +export interface InsightsThresholdBounds { lower?: number upper?: number } +export enum InsightThresholdType { + ABSOLUTE = 'absolute', + PERCENTAGE = 'percentage', +} + export interface InsightThreshold { - absoluteThreshold?: InsightsThresholdAbsolute - // More types of thresholds or conditions can be added here + type: InsightThresholdType + bounds?: InsightsThresholdBounds +} + +export enum AlertConditionType { + ABSOLUTE_VALUE = 'absolute_value', // default alert, checks absolute value of current interval + RELATIVE_INCREASE = 'relative_increase', // checks increase in value during current interval compared to previous interval + RELATIVE_DECREASE = 'relative_decrease', // checks decrease in value during current interval compared to previous interval } export interface AlertCondition { // Conditions in addition to the separate threshold // TODO: Think about things like relative thresholds, rate of change, etc. 
+ type: AlertConditionType } export enum AlertState { FIRING = 'Firing', NOT_FIRING = 'Not firing', ERRORED = 'Errored', + SNOOZED = 'Snoozed', } export enum AlertCalculationInterval { diff --git a/frontend/src/scenes/notebooks/Nodes/NotebookNodeQuery.tsx b/frontend/src/scenes/notebooks/Nodes/NotebookNodeQuery.tsx index 38b48e7512036..be59069b7d665 100644 --- a/frontend/src/scenes/notebooks/Nodes/NotebookNodeQuery.tsx +++ b/frontend/src/scenes/notebooks/Nodes/NotebookNodeQuery.tsx @@ -2,7 +2,7 @@ import { Query } from '~/queries/Query/Query' import { DataTableNode, InsightQueryNode, InsightVizNode, NodeKind, QuerySchema } from '~/queries/schema' import { createPostHogWidgetNode } from 'scenes/notebooks/Nodes/NodeWrapper' import { InsightLogicProps, InsightShortId, NotebookNodeType } from '~/types' -import { useActions, useMountedLogic, useValues } from 'kea' +import { BindLogic, useActions, useMountedLogic, useValues } from 'kea' import { useEffect, useMemo } from 'react' import { notebookNodeLogic } from './notebookNodeLogic' import { NotebookNodeProps, NotebookNodeAttributeProperties } from '../Notebook/utils' @@ -35,9 +35,11 @@ const Component = ({ const { expanded } = useValues(nodeLogic) const { setTitlePlaceholder } = useActions(nodeLogic) const summarizeInsight = useSummarizeInsight() - const { insightName } = useValues( - insightLogic({ dashboardItemId: query.kind === NodeKind.SavedInsightNode ? query.shortId : 'new' }) - ) + + const insightLogicProps = { + dashboardItemId: query.kind === NodeKind.SavedInsightNode ? query.shortId : ('new' as const), + } + const { insightName } = useValues(insightLogic(insightLogicProps)) useEffect(() => { let title = 'Query' @@ -96,19 +98,21 @@ const Component = ({ return (
- { - updateAttributes({ - query: { - ...attributes.query, - source: (t as DataTableNode | InsightVizNode).source, - } as QuerySchema, - }) - }} - /> + + { + updateAttributes({ + query: { + ...attributes.query, + source: (t as DataTableNode | InsightVizNode).source, + } as QuerySchema, + }) + }} + /> +
) } diff --git a/frontend/src/scenes/pipeline/hogfunctions/HogFunctionConfiguration.tsx b/frontend/src/scenes/pipeline/hogfunctions/HogFunctionConfiguration.tsx index e0569c8157229..e16f5cadedc7a 100644 --- a/frontend/src/scenes/pipeline/hogfunctions/HogFunctionConfiguration.tsx +++ b/frontend/src/scenes/pipeline/hogfunctions/HogFunctionConfiguration.tsx @@ -13,6 +13,7 @@ import { Link, SpinnerOverlay, } from '@posthog/lemon-ui' +import clsx from 'clsx' import { BindLogic, useActions, useValues } from 'kea' import { Form } from 'kea-forms' import { NotFound } from 'lib/components/NotFound' @@ -340,89 +341,97 @@ export function HogFunctionConfiguration({ templateId, id }: { templateId?: stri
- {showSource ? ( - <> - } - size="small" - type="secondary" - className="my-4" - onClick={() => { - setConfigurationValue('inputs_schema', [ - ...(configuration.inputs_schema ?? []), - { - type: 'string', - key: `input_${ - (configuration.inputs_schema?.length ?? 0) + 1 - }`, - label: '', - required: false, - }, - ]) - }} - > - Add input variable - - - {({ value, onChange }) => ( - <> -
- Function source code - setShowSource(false)} - > - Hide source code - -
- - This is the underlying Hog code that will run whenever the - filters match.{' '} - See the docs{' '} - for more info - - onChange(v ?? '')} - globals={globalsWithInputs} - options={{ - minimap: { - enabled: false, - }, - wordWrap: 'on', - scrollBeyondLastLine: false, - automaticLayout: true, - fixedOverflowWidgets: true, - suggest: { - showInlineDetails: true, - }, - quickSuggestionsDelay: 300, - }} - /> - - )} -
- + } + size="small" + type="secondary" + className="my-4" + onClick={() => { + setConfigurationValue('inputs_schema', [ + ...(configuration.inputs_schema ?? []), + { + type: 'string', + key: `input_${(configuration.inputs_schema?.length ?? 0) + 1}`, + label: '', + required: false, + }, + ]) + }} + > + Add input variable + + ) : null} +
+
+ +
+
+
+

Edit source

+ {!showSource ?

Click here to edit the function's source code

: null} +
+ + {!showSource ? ( + setShowSource(true)} + disabledReason={ + !hasAddon + ? 'Editing the source code requires the Data Pipelines addon' + : undefined + } + > + Edit source code + ) : ( -
- setShowSource(true)} - disabledReason={ - !hasAddon - ? 'Editing the source code requires the Data Pipelines addon' - : undefined - } - > - Show function source code - -
+ setShowSource(false)} + > + Hide source code + )}
+ + {showSource ? ( + + {({ value, onChange }) => ( + <> + + This is the underlying Hog code that will run whenever the filters + match. See the docs{' '} + for more info + + onChange(v ?? '')} + globals={globalsWithInputs} + options={{ + minimap: { + enabled: false, + }, + wordWrap: 'on', + scrollBeyondLastLine: false, + automaticLayout: true, + fixedOverflowWidgets: true, + suggest: { + showInlineDetails: true, + }, + quickSuggestionsDelay: 300, + }} + /> + + )} + + ) : null}
{id ? : } diff --git a/latest_migrations.manifest b/latest_migrations.manifest index e6381aceefe20..66fcab590d608 100644 --- a/latest_migrations.manifest +++ b/latest_migrations.manifest @@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name ee: 0016_rolemembership_organization_member otp_static: 0002_throttling otp_totp: 0002_auto_20190420_0723 -posthog: 0491_team_session_recording_url_trigger_config +posthog: 0491_alertconfiguration_snoozed_until_and_more sessions: 0001_initial social_django: 0010_uid_db_index two_factor: 0007_auto_20201201_1019 diff --git a/posthog/api/alert.py b/posthog/api/alert.py index 19611889c6662..707db62140c4a 100644 --- a/posthog/api/alert.py +++ b/posthog/api/alert.py @@ -16,6 +16,9 @@ from posthog.schema import AlertState from posthog.api.insight import InsightBasicSerializer +from posthog.utils import relative_date_parse +from zoneinfo import ZoneInfo + class ThresholdSerializer(serializers.ModelSerializer): class Meta: @@ -73,6 +76,11 @@ def validate(self, data): return data +class RelativeDateTimeField(serializers.DateTimeField): + def to_internal_value(self, data): + return data + + class AlertSerializer(serializers.ModelSerializer): created_by = UserBasicSerializer(read_only=True) checks = AlertCheckSerializer(many=True, read_only=True) @@ -84,6 +92,7 @@ class AlertSerializer(serializers.ModelSerializer): write_only=True, allow_empty=False, ) + snoozed_until = RelativeDateTimeField(allow_null=True, required=False) class Meta: model = AlertConfiguration @@ -104,6 +113,7 @@ class Meta: "checks", "config", "calculation_interval", + "snoozed_until", ] read_only_fields = [ "id", @@ -149,6 +159,28 @@ def create(self, validated_data: dict) -> AlertConfiguration: return instance def update(self, instance, validated_data): + if "snoozed_until" in validated_data: + snoozed_until_param = validated_data.pop("snoozed_until") + + if snoozed_until_param is None: + instance.state = AlertState.NOT_FIRING + instance.snoozed_until = None + else: + # always store snoozed_until as UTC time + # as we look at current UTC time to check when to run alerts + snoozed_until = relative_date_parse(snoozed_until_param, ZoneInfo("UTC"), increase=True) + instance.state = AlertState.SNOOZED + instance.snoozed_until = snoozed_until + + AlertCheck.objects.create( + alert_configuration=instance, + calculated_value=None, + condition=instance.condition, + targets_notified={}, + state=instance.state, + error=None, + ) + conditions_or_threshold_changed = False threshold_data = validated_data.pop("threshold", None) @@ -183,6 +215,12 @@ def update(self, instance, validated_data): return super().update(instance, validated_data) + def validate_snoozed_until(self, value): + if value is not None and not isinstance(value, str): + raise ValidationError("snoozed_until has to be passed in string format") + + return value + def validate_insight(self, value): if value and not are_alerts_supported_for_insight(value): raise ValidationError("Alerts are not supported for this insight.") diff --git a/posthog/api/test/test_alert.py b/posthog/api/test/test_alert.py index e1a1fcaccd836..4c56520f15027 100644 --- a/posthog/api/test/test_alert.py +++ b/posthog/api/test/test_alert.py @@ -6,6 +6,10 @@ from posthog.test.base import APIBaseTest, QueryMatchingTest from posthog.models.team import Team +from posthog.schema import InsightThresholdType, AlertState +from posthog.models import AlertConfiguration +from posthog.models.alert import AlertCheck +from datetime import datetime class TestAlert(APIBaseTest, 
QueryMatchingTest): @@ -33,7 +37,7 @@ def test_create_and_delete_alert(self) -> None: ], "config": {"type": "TrendsAlertConfig", "series_index": 0}, "name": "alert name", - "threshold": {"configuration": {}}, + "threshold": {"configuration": {"type": InsightThresholdType.ABSOLUTE, "bounds": {}}}, "calculation_interval": "daily", } response = self.client.post(f"/api/projects/{self.team.id}/alerts", creation_request) @@ -52,13 +56,14 @@ def test_create_and_delete_alert(self) -> None: "state": "Not firing", "config": {"type": "TrendsAlertConfig", "series_index": 0}, "threshold": { - "configuration": {}, + "configuration": {"type": InsightThresholdType.ABSOLUTE, "bounds": {}}, "created_at": mock.ANY, "id": mock.ANY, "name": "", }, "last_checked_at": None, "next_check_at": None, + "snoozed_until": None, } assert response.status_code == status.HTTP_201_CREATED, response.content assert response.json() == expected_alert_json @@ -107,7 +112,7 @@ def test_create_and_list_alert(self) -> None: "subscribed_users": [ self.user.id, ], - "threshold": {"configuration": {}}, + "threshold": {"configuration": {"type": InsightThresholdType.ABSOLUTE, "bounds": {}}}, "name": "alert name", } alert = self.client.post(f"/api/projects/{self.team.id}/alerts", creation_request).json() @@ -133,7 +138,7 @@ def test_alert_limit(self) -> None: "subscribed_users": [ self.user.id, ], - "threshold": {"configuration": {}}, + "threshold": {"configuration": {"type": InsightThresholdType.ABSOLUTE, "bounds": {}}}, "name": "alert name", } self.client.post(f"/api/projects/{self.team.id}/alerts", creation_request) @@ -151,7 +156,7 @@ def test_alert_is_deleted_on_insight_update(self) -> None: "subscribed_users": [ self.user.id, ], - "threshold": {"configuration": {}}, + "threshold": {"configuration": {"type": InsightThresholdType.ABSOLUTE, "bounds": {}}}, "name": "alert name", } alert = self.client.post(f"/api/projects/{self.team.id}/alerts", creation_request).json() @@ -176,3 +181,33 @@ def test_alert_is_deleted_on_insight_update(self) -> None: response = self.client.get(f"/api/projects/{self.team.id}/alerts/{alert['id']}") assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_snooze_alert(self) -> None: + creation_request = { + "insight": self.insight["id"], + "subscribed_users": [ + self.user.id, + ], + "threshold": {"configuration": {"type": InsightThresholdType.ABSOLUTE, "bounds": {}}}, + "name": "alert name", + "state": AlertState.FIRING, + } + + alert = self.client.post(f"/api/projects/{self.team.id}/alerts", creation_request).json() + assert alert["state"] == AlertState.NOT_FIRING + + alert = AlertConfiguration.objects.get(pk=alert["id"]) + alert.state = AlertState.FIRING + alert.save() + + firing_alert = AlertConfiguration.objects.get(pk=alert.id) + assert firing_alert.state == AlertState.FIRING + + resolved_alert = self.client.patch( + f"/api/projects/{self.team.id}/alerts/{firing_alert.id}", {"snoozed_until": datetime.now()} + ).json() + assert resolved_alert["state"] == AlertState.SNOOZED + + # should also create a new alert check with resolution + check = AlertCheck.objects.filter(alert_configuration=firing_alert.id).latest("created_at") + assert check.state == AlertState.SNOOZED diff --git a/posthog/migrations/0491_alertconfiguration_snoozed_until_and_more.py b/posthog/migrations/0491_alertconfiguration_snoozed_until_and_more.py new file mode 100644 index 0000000000000..d8fa097c43b32 --- /dev/null +++ b/posthog/migrations/0491_alertconfiguration_snoozed_until_and_more.py @@ -0,0 +1,46 @@ +# Generated by 
Django 4.2.15 on 2024-10-17 09:21 + +from django.db import migrations, models +import posthog.schema + + +class Migration(migrations.Migration): + dependencies = [ + ("posthog", "0490_dashboard_variables"), + ] + + operations = [ + migrations.AddField( + model_name="alertconfiguration", + name="snoozed_until", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AlterField( + model_name="alertcheck", + name="state", + field=models.CharField( + choices=[ + (posthog.schema.AlertState["FIRING"], posthog.schema.AlertState["FIRING"]), + (posthog.schema.AlertState["NOT_FIRING"], posthog.schema.AlertState["NOT_FIRING"]), + (posthog.schema.AlertState["ERRORED"], posthog.schema.AlertState["ERRORED"]), + (posthog.schema.AlertState["SNOOZED"], posthog.schema.AlertState["SNOOZED"]), + ], + default=posthog.schema.AlertState["NOT_FIRING"], + max_length=10, + ), + ), + migrations.AlterField( + model_name="alertconfiguration", + name="state", + field=models.CharField( + choices=[ + (posthog.schema.AlertState["FIRING"], posthog.schema.AlertState["FIRING"]), + (posthog.schema.AlertState["NOT_FIRING"], posthog.schema.AlertState["NOT_FIRING"]), + (posthog.schema.AlertState["ERRORED"], posthog.schema.AlertState["ERRORED"]), + (posthog.schema.AlertState["SNOOZED"], posthog.schema.AlertState["SNOOZED"]), + ], + default=posthog.schema.AlertState["NOT_FIRING"], + max_length=10, + ), + ), + ] diff --git a/posthog/models/alert.py b/posthog/models/alert.py index 8db059a992232..d00425327fd48 100644 --- a/posthog/models/alert.py +++ b/posthog/models/alert.py @@ -1,38 +1,24 @@ from datetime import datetime, UTC, timedelta -from typing import Any, Optional, cast -from dateutil.relativedelta import relativedelta from django.db import models from django.core.exceptions import ValidationError +import pydantic from posthog.hogql_queries.legacy_compatibility.flagged_conversion_manager import conversion_to_query_based from posthog.models.insight import Insight from posthog.models.utils import UUIDModel, CreatedMetaFields -from posthog.schema import AlertCondition, InsightThreshold, AlertState, AlertCalculationInterval +from posthog.schema import InsightThreshold, AlertState, AlertCalculationInterval ALERT_STATE_CHOICES = [ (AlertState.FIRING, AlertState.FIRING), (AlertState.NOT_FIRING, AlertState.NOT_FIRING), (AlertState.ERRORED, AlertState.ERRORED), + (AlertState.SNOOZED, AlertState.SNOOZED), ] -def alert_calculation_interval_to_relativedelta(alert_calculation_interval: AlertCalculationInterval) -> relativedelta: - match alert_calculation_interval: - case AlertCalculationInterval.HOURLY: - return relativedelta(hours=1) - case AlertCalculationInterval.DAILY: - return relativedelta(days=1) - case AlertCalculationInterval.WEEKLY: - return relativedelta(weeks=1) - case AlertCalculationInterval.MONTHLY: - return relativedelta(months=1) - case _: - raise ValueError(f"Invalid alert calculation interval: {alert_calculation_interval}") - - def are_alerts_supported_for_insight(insight: Insight) -> bool: with conversion_to_query_based(insight): query = insight.query @@ -43,32 +29,6 @@ def are_alerts_supported_for_insight(insight: Insight) -> bool: return True -class ConditionValidator: - def __init__(self, threshold: Optional[InsightThreshold], condition: AlertCondition): - self.threshold = threshold - self.condition = condition - - def validate(self, calculated_value: float) -> list[str]: - validators: Any = [ - self.validate_absolute_threshold, - ] - breaches = [] - for validator in validators: - breaches += 
validator(calculated_value) - return breaches - - def validate_absolute_threshold(self, calculated_value: float) -> list[str]: - if not self.threshold or not self.threshold.absoluteThreshold: - return [] - - absolute_threshold = self.threshold.absoluteThreshold - if absolute_threshold.lower is not None and calculated_value < absolute_threshold.lower: - return [f"The trend value ({calculated_value}) is below the lower threshold ({absolute_threshold.lower})"] - if absolute_threshold.upper is not None and calculated_value > absolute_threshold.upper: - return [f"The trend value ({calculated_value}) is above the upper threshold ({absolute_threshold.upper})"] - return [] - - class Alert(models.Model): """ @deprecated("AlertConfiguration should be used instead.") @@ -95,11 +55,15 @@ class Threshold(CreatedMetaFields, UUIDModel): configuration = models.JSONField(default=dict) def clean(self): - config = InsightThreshold.model_validate(self.configuration) - if not config or not config.absoluteThreshold: + try: + config = InsightThreshold.model_validate(self.configuration) + except pydantic.ValidationError as e: + raise ValidationError(f"Invalid threshold configuration: {e}") + + if not config or not config.bounds: return - if config.absoluteThreshold.lower is not None and config.absoluteThreshold.upper is not None: - if config.absoluteThreshold.lower > config.absoluteThreshold.upper: + if config.bounds.lower is not None and config.bounds.upper is not None: + if config.bounds.lower > config.bounds.upper: raise ValidationError("Lower threshold must be less than upper threshold") @@ -145,7 +109,10 @@ class AlertConfiguration(CreatedMetaFields, UUIDModel): last_notified_at = models.DateTimeField(null=True, blank=True) last_checked_at = models.DateTimeField(null=True, blank=True) + # UTC time for when next alert check is due next_check_at = models.DateTimeField(null=True, blank=True) + # UTC time until when we shouldn't check alert/notify user + snoozed_until = models.DateTimeField(null=True, blank=True) def __str__(self): return f"{self.name} (Team: {self.team})" @@ -159,75 +126,6 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) - def evaluate_condition(self, calculated_value) -> list[str]: - threshold = InsightThreshold.model_validate(self.threshold.configuration) if self.threshold else None - condition = AlertCondition.model_validate(self.condition) - validator = ConditionValidator(threshold=threshold, condition=condition) - return validator.validate(calculated_value) - - def add_check( - self, *, aggregated_value: Optional[float], error: Optional[dict] = None - ) -> tuple["AlertCheck", list[str], Optional[dict], bool]: - """ - Add a new AlertCheck, managing state transitions and cool down. 
- - Args: - aggregated_value: result of insight calculation compressed to one number to compare against threshold - error: any error raised while calculating insight value, if present then set state as errored - """ - - targets_notified: dict[str, list[str]] = {} - breaches = [] - notify = False - - if not error: - try: - breaches = self.evaluate_condition(aggregated_value) if aggregated_value is not None else [] - except Exception as err: - # error checking the condition - error = { - "message": f"Error checking alert condition {str(err)}", - } - - if error: - # If the alert is not already errored, notify user - if self.state != AlertState.ERRORED: - self.state = AlertState.ERRORED - notify = True - elif breaches: - # If the alert is not already firing, notify user - if self.state != AlertState.FIRING: - self.state = AlertState.FIRING - notify = True - else: - self.state = AlertState.NOT_FIRING # Set the Alert to not firing if the threshold is no longer met - # TODO: Optionally send a resolved notification when alert goes from firing to not_firing? - - now = datetime.now(UTC) - self.last_checked_at = datetime.now(UTC) - - # IMPORTANT: update next_check_at according to interval - # ensure we don't recheck alert until the next interval is due - self.next_check_at = (self.next_check_at or now) + alert_calculation_interval_to_relativedelta( - cast(AlertCalculationInterval, self.calculation_interval) - ) - - if notify: - self.last_notified_at = now - targets_notified = {"users": list(self.subscribed_users.all().values_list("email", flat=True))} - - alert_check = AlertCheck.objects.create( - alert_configuration=self, - calculated_value=aggregated_value, - condition=self.condition, - targets_notified=targets_notified, - state=self.state, - error=error, - ) - - self.save() - return alert_check, breaches, error, notify - class AlertSubscription(CreatedMetaFields, UUIDModel): user = models.ForeignKey( diff --git a/posthog/schema.py b/posthog/schema.py index f3256462172fa..afd2ca10dde3e 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -50,17 +50,17 @@ class AlertCalculationInterval(StrEnum): MONTHLY = "monthly" -class AlertCondition(BaseModel): - pass - model_config = ConfigDict( - extra="forbid", - ) +class AlertConditionType(StrEnum): + ABSOLUTE_VALUE = "absolute_value" + RELATIVE_INCREASE = "relative_increase" + RELATIVE_DECREASE = "relative_decrease" class AlertState(StrEnum): FIRING = "Firing" NOT_FIRING = "Not firing" ERRORED = "Errored" + SNOOZED = "Snoozed" class Kind(StrEnum): @@ -797,7 +797,12 @@ class InsightNodeKind(StrEnum): LIFECYCLE_QUERY = "LifecycleQuery" -class InsightsThresholdAbsolute(BaseModel): +class InsightThresholdType(StrEnum): + ABSOLUTE = "absolute" + PERCENTAGE = "percentage" + + +class InsightsThresholdBounds(BaseModel): model_config = ConfigDict( extra="forbid", ) @@ -1704,6 +1709,13 @@ class ActorsQueryResponse(BaseModel): types: list[str] +class AlertCondition(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: AlertConditionType + + class Breakdown(BaseModel): model_config = ConfigDict( extra="forbid", @@ -3078,7 +3090,8 @@ class InsightThreshold(BaseModel): model_config = ConfigDict( extra="forbid", ) - absoluteThreshold: Optional[InsightsThresholdAbsolute] = None + bounds: Optional[InsightsThresholdBounds] = None + type: InsightThresholdType class LifecycleFilter(BaseModel): diff --git a/posthog/settings/temporal.py b/posthog/settings/temporal.py index b73a7a0b6af83..dcab7bfb9a58a 100644 --- a/posthog/settings/temporal.py +++ 
b/posthog/settings/temporal.py @@ -17,6 +17,7 @@ BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 100 # 100MB BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 50 # 50MB BATCH_EXPORT_HTTP_BATCH_SIZE: int = 5000 +BATCH_EXPORT_BUFFER_QUEUE_MAX_SIZE_BYTES: int = 1024 * 1024 * 300 # 300MB UNCONSTRAINED_TIMESTAMP_TEAM_IDS: list[str] = get_list(os.getenv("UNCONSTRAINED_TIMESTAMP_TEAM_IDS", "")) ASYNC_ARROW_STREAMING_TEAM_IDS: list[str] = get_list(os.getenv("ASYNC_ARROW_STREAMING_TEAM_IDS", "")) diff --git a/posthog/tasks/alerts/checks.py b/posthog/tasks/alerts/checks.py index 7c66c1158c12b..4986899047faa 100644 --- a/posthog/tasks/alerts/checks.py +++ b/posthog/tasks/alerts/checks.py @@ -1,17 +1,14 @@ from datetime import datetime, timedelta, UTC -from typing import Optional, cast +from typing import cast from dateutil.relativedelta import relativedelta +import traceback from celery import shared_task from celery.canvas import chain from django.db import transaction -from django.utils import timezone import structlog from sentry_sdk import capture_exception -from posthog.api.services.query import ExecutionMode -from posthog.caching.calculate_results import calculate_for_query_based_insight -from posthog.email import EmailMessage from posthog.errors import CHQueryErrorTooManySimultaneousQueries from posthog.hogql_queries.legacy_compatibility.flagged_conversion_manager import ( conversion_to_query_based, @@ -21,40 +18,28 @@ from posthog.tasks.utils import CeleryQueue from posthog.schema import ( TrendsQuery, - IntervalType, - ChartDisplayType, - NodeKind, AlertCalculationInterval, AlertState, - TrendsAlertConfig, ) from posthog.utils import get_from_dict_or_attr -from posthog.caching.fetch_from_cache import InsightResult -from posthog.clickhouse.client.limit import limit_concurrency from prometheus_client import Counter, Gauge from django.db.models import Q, F -from typing import TypedDict, NotRequired from collections import defaultdict +from posthog.tasks.alerts.utils import ( + AlertEvaluationResult, + calculation_interval_to_order, + send_notifications_for_errors, + send_notifications_for_breaches, + WRAPPER_NODE_KINDS, + alert_calculation_interval_to_relativedelta, +) +from posthog.tasks.alerts.trends import check_trends_alert + + +logger = structlog.get_logger(__name__) -# TODO: move the TrendResult UI type to schema.ts and use that instead -class TrendResult(TypedDict): - action: dict - actions: list[dict] - count: int - data: list[float] - days: list[str] - dates: list[str] - label: str - labels: list[str] - breakdown_value: str | int | list[str] - aggregated_value: NotRequired[float] - status: str | None - compare_label: str | None - compare: bool - persons_urls: list[dict] - persons: dict - filter: dict +class AlertCheckException(Exception): ... 
HOURLY_ALERTS_BACKLOG_GAUGE = Gauge( @@ -78,28 +63,9 @@ class TrendResult(TypedDict): ) -logger = structlog.get_logger(__name__) - - -WRAPPER_NODE_KINDS = [NodeKind.DATA_TABLE_NODE, NodeKind.DATA_VISUALIZATION_NODE, NodeKind.INSIGHT_VIZ_NODE] - -NON_TIME_SERIES_DISPLAY_TYPES = { - ChartDisplayType.BOLD_NUMBER, - ChartDisplayType.ACTIONS_PIE, - ChartDisplayType.ACTIONS_BAR_VALUE, - ChartDisplayType.ACTIONS_TABLE, - ChartDisplayType.WORLD_MAP, -} - - -def calculation_interval_to_order(interval: AlertCalculationInterval | None) -> int: - match interval: - case AlertCalculationInterval.HOURLY: - return 0 - case AlertCalculationInterval.DAILY: - return 1 - case _: - return 2 +@shared_task(ignore_result=True) +def checks_cleanup_task() -> None: + AlertCheck.clean_up_old_checks() @shared_task( @@ -145,48 +111,18 @@ def check_alerts_task() -> None: """ This runs every 2min to check for alerts that are due to recalculate """ - check_alerts() - - -@shared_task( - ignore_result=True, - queue=CeleryQueue.ALERTS.value, - autoretry_for=(CHQueryErrorTooManySimultaneousQueries,), - retry_backoff=1, - retry_backoff_max=10, - max_retries=3, - expires=60 * 60, -) -@limit_concurrency(5) # Max 5 concurrent alert checks -def check_alert_task(alert_id: str) -> None: - try: - check_alert(alert_id) - except Exception as err: - ALERT_CHECK_ERROR_COUNTER.inc() - capture_exception(Exception(f"Error checking alert, user wasn't notified: {err}")) - raise - - -@shared_task(ignore_result=True) -def checks_cleanup_task() -> None: - AlertCheck.clean_up_old_checks() - - -def check_alerts() -> None: now = datetime.now(UTC) # Use a fixed expiration time since tasks in the chain are executed sequentially expire_after = now + timedelta(minutes=30) - # find all alerts with the provided interval that are due to be calculated (next_check_at is null or less than now) + # find all alerts with the provided interval that are due to be calculated + # (next_check_at is null or less than now) and it's not snoozed alerts = ( AlertConfiguration.objects.filter( Q(enabled=True, is_calculating=False, next_check_at__lte=now) - | Q( - enabled=True, - is_calculating=False, - next_check_at__isnull=True, - ) + | Q(enabled=True, is_calculating=False, next_check_at__isnull=True) ) + .filter(Q(snoozed_until__isnull=True) | Q(snoozed_until__lt=now)) .order_by(F("next_check_at").asc(nulls_first=True)) .only("id", "team", "calculation_interval") ) @@ -207,6 +143,20 @@ def check_alerts() -> None: chain(*(check_alert_task.si(str(alert_id)).set(expires=expire_after) for alert_id in alert_ids))() +@shared_task( + ignore_result=True, + queue=CeleryQueue.ALERTS.value, + autoretry_for=(CHQueryErrorTooManySimultaneousQueries,), + retry_backoff=1, + retry_backoff_max=10, + max_retries=3, + expires=60 * 60, +) +# @limit_concurrency(5) Concurrency controlled by CeleryQueue.ALERTS for now +def check_alert_task(alert_id: str) -> None: + check_alert(alert_id) + + def check_alert(alert_id: str) -> None: try: alert = AlertConfiguration.objects.get(id=alert_id, enabled=True) @@ -230,12 +180,28 @@ def check_alert(alert_id: str) -> None: ) return + if alert.snoozed_until: + if alert.snoozed_until > now: + logger.warning( + "Alert has been snoozed so skipping checking it now", + alert=alert, + ) + return + else: + # not snoozed (anymore) so clear snoozed_until + alert.snoozed_until = None + alert.state = AlertState.NOT_FIRING + alert.is_calculating = True alert.save() try: - check_alert_atomically(alert) - except Exception: + check_alert_and_notify_atomically(alert) + 
except Exception as err: + ALERT_CHECK_ERROR_COUNTER.inc() + logger.exception(AlertCheckException(err)) + capture_exception(AlertCheckException(err)) + # raise again so alert check is retried depending on error type raise finally: # Get all updates with alert checks @@ -245,185 +211,122 @@ def check_alert(alert_id: str) -> None: @transaction.atomic -def check_alert_atomically(alert: AlertConfiguration) -> None: +def check_alert_and_notify_atomically(alert: AlertConfiguration) -> None: """ - Alert check only gets updated when we successfully - 1. Compute the aggregated value for the insight for the interval - 2. Compare the aggregated value with the threshold - 3. Send notifications if breaches are found + Computes insight results, checks alert for breaches and notifies user. + Only commits updates to alert state if all of the above complete successfully. + TODO: Later separate notification mechanism from alert checking mechanism (when we move to CDP) + so we can retry notification without re-computing insight. """ ALERT_COMPUTED_COUNTER.inc() + value = breaches = error = None - insight = alert.insight - aggregated_value: Optional[float] = None - error: Optional[dict] = None - + # 1. Evaluate insight and get alert value try: - with conversion_to_query_based(insight): - query = insight.query - kind = get_from_dict_or_attr(query, "kind") - - if kind in WRAPPER_NODE_KINDS: - query = get_from_dict_or_attr(query, "source") - kind = get_from_dict_or_attr(query, "kind") - - if kind == "TrendsQuery": - query = TrendsQuery.model_validate(query) - - filters_override = _calculate_date_range_override_for_alert(query) - - calculation_result = calculate_for_query_based_insight( - insight, - team=alert.team, - execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE, - user=None, - filters_override=filters_override, - ) - else: - raise NotImplementedError(f"Alerts for {query.kind} are not supported yet") - - if not calculation_result.result: - raise RuntimeError(f"No results for alert {alert.id}") - - aggregated_value = _aggregate_insight_result_value(alert, query, calculation_result) + alert_evaluation_result = check_alert_for_insight(alert) + value = alert_evaluation_result.value + breaches = alert_evaluation_result.breaches except CHQueryErrorTooManySimultaneousQueries: - # error on our side, need to make sure to retry the alert check + # error on our side so we raise + # as celery task can be retried according to config raise except Exception as err: - # error possibly on user's config side - # notify user that alert check errored - error_message = f"AlertCheckError: error computing aggregate value for insight, alert_id = {alert.id}" - logger.exception(error_message) + capture_exception(AlertCheckException(err)) + # error can be on user side (incorrectly configured insight/alert) + # we won't retry and set alert to errored state + error = {"message": str(err), "traceback": traceback.format_exc()} - event_id = capture_exception( - Exception(error_message), - {"alert_id": alert.id, "query": str(query), "message": str(err)}, - ) - - error = { - "sentry_event_id": event_id, - "message": f"{error_message}: {str(err)}", - } + # 2. 
Check alert value against threshold + alert_check = add_alert_check(alert, value, breaches, error) - try: - # Lock alert to prevent concurrent state changes - alert = AlertConfiguration.objects.select_for_update().get(id=alert.id, enabled=True) - check, breaches, error, notify = alert.add_check(aggregated_value=aggregated_value, error=error) - except Exception as err: - error_message = f"AlertCheckError: error comparing insight value with threshold for alert_id = {alert.id}" - logger.exception(error_message) - - event_id = capture_exception( - Exception(error_message), - {"alert_id": alert.id, "query": str(query), "message": str(err)}, - ) - raise - - if not notify: - # no need to notify users + # 3. Notify users if needed + if not alert_check.targets_notified: return try: - match check.state: + match alert_check.state: case AlertState.NOT_FIRING: - logger.info("Check state is %s", check.state, alert_id=alert.id) + logger.info("Check state is %s", alert_check.state, alert_id=alert.id) case AlertState.ERRORED: - if error: - _send_notifications_for_errors(alert, error) + send_notifications_for_errors(alert, alert_check.error) case AlertState.FIRING: - _send_notifications_for_breaches(alert, breaches) + assert breaches is not None + send_notifications_for_breaches(alert, breaches) except Exception as err: error_message = f"AlertCheckError: error sending notifications for alert_id = {alert.id}" logger.exception(error_message) - event_id = capture_exception( + capture_exception( Exception(error_message), - {"alert_id": alert.id, "query": str(query), "message": str(err)}, + {"alert_id": alert.id, "message": str(err)}, ) + + # don't want alert state to be updated (so that it's retried as next_check_at won't be updated) + # so we raise again as @transaction.atomic decorator won't commit db updates + # TODO: later should have a way just to retry notification mechanism raise -def _calculate_date_range_override_for_alert(query: TrendsQuery) -> Optional[dict]: - if query.trendsFilter and query.trendsFilter.display in NON_TIME_SERIES_DISPLAY_TYPES: - # for single value insights, need to recompute with full time range - return None - - match query.interval: - case IntervalType.DAY: - date_from = "-1d" - case IntervalType.WEEK: - date_from = "-1w" - case IntervalType.MONTH: - date_from = "-1m" - case _: - date_from = "-1h" - - return {"date_from": date_from} - - -def _aggregate_insight_result_value(alert: AlertConfiguration, query: TrendsQuery, results: InsightResult) -> float: - if "type" in alert.config and alert.config["type"] == "TrendsAlertConfig": - alert_config = TrendsAlertConfig.model_validate(alert.config) - series_index = alert_config.series_index - result = cast(list[TrendResult], results.result)[series_index] - - if query.trendsFilter and query.trendsFilter.display in NON_TIME_SERIES_DISPLAY_TYPES: - return result["aggregated_value"] - - return result["data"][-1] - - raise ValueError(f"Unsupported alert config type: {alert_config.type}") - - -def _send_notifications_for_breaches(alert: AlertConfiguration, breaches: list[str]) -> None: - subject = f"PostHog alert {alert.name} is firing" - campaign_key = f"alert-firing-notification-{alert.id}-{timezone.now().timestamp()}" - insight_url = f"/project/{alert.team.pk}/insights/{alert.insight.short_id}?alert_id={alert.id}" - alert_url = f"{insight_url}/alerts/{alert.id}" - message = EmailMessage( - campaign_key=campaign_key, - subject=subject, - template_name="alert_check_firing", - template_context={ - "match_descriptions": breaches, - 
"insight_url": insight_url, - "insight_name": alert.insight.name, - "alert_url": alert_url, - "alert_name": alert.name, - }, +def check_alert_for_insight(alert: AlertConfiguration) -> AlertEvaluationResult: + """ + Matches insight type with alert checking logic + """ + insight = alert.insight + + with conversion_to_query_based(insight): + query = insight.query + kind = get_from_dict_or_attr(query, "kind") + + if kind in WRAPPER_NODE_KINDS: + query = get_from_dict_or_attr(query, "source") + kind = get_from_dict_or_attr(query, "kind") + + match kind: + case "TrendsQuery": + query = TrendsQuery.model_validate(query) + return check_trends_alert(alert, insight, query) + case _: + raise NotImplementedError(f"AlertCheckError: Alerts for {query.kind} are not supported yet") + + +def add_alert_check( + alert: AlertConfiguration, value: float | None, breaches: list[str] | None, error: dict | None +) -> AlertCheck: + notify = False + targets_notified = {} + + if error: + alert.state = AlertState.ERRORED + notify = True + elif breaches: + alert.state = AlertState.FIRING + notify = True + else: + alert.state = AlertState.NOT_FIRING # Set the Alert to not firing if the threshold is no longer met + # TODO: Optionally send a resolved notification when alert goes from firing to not_firing? + + now = datetime.now(UTC) + alert.last_checked_at = datetime.now(UTC) + + # IMPORTANT: update next_check_at according to interval + # ensure we don't recheck alert until the next interval is due + alert.next_check_at = (alert.next_check_at or now) + alert_calculation_interval_to_relativedelta( + cast(AlertCalculationInterval, alert.calculation_interval) ) - targets = alert.subscribed_users.all().values_list("email", flat=True) - if not targets: - raise RuntimeError(f"no targets configured for the alert {alert.id}") - for target in targets: - message.add_recipient(email=target) - - logger.info(f"Send notifications about {len(breaches)} anomalies", alert_id=alert.id) - message.send() - - -def _send_notifications_for_errors(alert: AlertConfiguration, error: dict) -> None: - subject = f"PostHog alert {alert.name} check failed to evaluate" - campaign_key = f"alert-firing-notification-{alert.id}-{timezone.now().timestamp()}" - insight_url = f"/project/{alert.team.pk}/insights/{alert.insight.short_id}?alert_id={alert.id}" - alert_url = f"{insight_url}/alerts/{alert.id}" - message = EmailMessage( - campaign_key=campaign_key, - subject=subject, - template_name="alert_check_firing", - template_context={ - "match_descriptions": error, - "insight_url": insight_url, - "insight_name": alert.insight.name, - "alert_url": alert_url, - "alert_name": alert.name, - }, + + if notify: + alert.last_notified_at = now + targets_notified = {"users": list(alert.subscribed_users.all().values_list("email", flat=True))} + + alert_check = AlertCheck.objects.create( + alert_configuration=alert, + calculated_value=value, + condition=alert.condition, + targets_notified=targets_notified, + state=alert.state, + error=error, ) - targets = alert.subscribed_users.all().values_list("email", flat=True) - if not targets: - raise RuntimeError(f"no targets configured for the alert {alert.id}") - for target in targets: - message.add_recipient(email=target) - - logger.info(f"Send notifications about alert checking error", alert_id=alert.id) - message.send() + + alert.save() + + return alert_check diff --git a/posthog/tasks/alerts/test/test_alert_checks.py b/posthog/tasks/alerts/test/test_alert_checks.py index e14c48359aac3..79fe6227180a0 100644 --- 
a/posthog/tasks/alerts/test/test_alert_checks.py +++ b/posthog/tasks/alerts/test/test_alert_checks.py @@ -5,7 +5,8 @@ from posthog.models.alert import AlertCheck from posthog.models.instance_setting import set_instance_setting -from posthog.tasks.alerts.checks import _send_notifications_for_breaches, check_alert +from posthog.tasks.alerts.utils import send_notifications_for_breaches +from posthog.tasks.alerts.checks import check_alert from posthog.test.base import APIBaseTest, _create_event, flush_persons_and_events, ClickhouseDestroyTablesMixin from posthog.api.test.dashboards import DashboardAPI from posthog.schema import ChartDisplayType, EventsNode, TrendsQuery, TrendsFilter, AlertState @@ -14,8 +15,8 @@ @freeze_time("2024-06-02T08:55:00.000Z") -@patch("posthog.tasks.alerts.checks._send_notifications_for_errors") -@patch("posthog.tasks.alerts.checks._send_notifications_for_breaches") +@patch("posthog.tasks.alerts.checks.send_notifications_for_errors") +@patch("posthog.tasks.alerts.checks.send_notifications_for_breaches") class TestAlertChecks(APIBaseTest, ClickhouseDestroyTablesMixin): def setUp(self) -> None: super().setUp() @@ -52,14 +53,15 @@ def setUp(self) -> None: "type": "TrendsAlertConfig", "series_index": 0, }, - "threshold": {"configuration": {"absoluteThreshold": {}}}, + "condition": {"type": "absolute_value"}, + "threshold": {"configuration": {"type": "absolute", "bounds": {}}}, }, ).json() def set_thresholds(self, lower: Optional[int] = None, upper: Optional[int] = None) -> None: self.client.patch( f"/api/projects/{self.team.id}/alerts/{self.alert['id']}", - data={"threshold": {"configuration": {"absoluteThreshold": {"lower": lower, "upper": upper}}}}, + data={"threshold": {"configuration": {"type": "absolute", "bounds": {"lower": lower, "upper": upper}}}}, ) def get_breach_description(self, mock_send_notifications_for_breaches: MagicMock, call_index: int) -> list[str]: @@ -225,7 +227,7 @@ def test_send_error_while_calculating( self, _mock_send_notifications_for_breaches: MagicMock, mock_send_notifications_for_errors: MagicMock ) -> None: with patch( - "posthog.tasks.alerts.checks.calculate_for_query_based_insight" + "posthog.tasks.alerts.trends.calculate_for_query_based_insight" ) as mock_calculate_for_query_based_insight: mock_calculate_for_query_based_insight.side_effect = Exception("Some error") @@ -238,7 +240,6 @@ def test_send_error_while_calculating( ) error_message = latest_alert_check.error["message"] - assert "AlertCheckError: error computing aggregate value for insight" in error_message assert "Some error" in error_message def test_error_while_calculating_on_alert_in_firing_state( @@ -254,7 +255,7 @@ def test_error_while_calculating_on_alert_in_firing_state( assert latest_alert_check.error is None with patch( - "posthog.tasks.alerts.checks.calculate_for_query_based_insight" + "posthog.tasks.alerts.trends.calculate_for_query_based_insight" ) as mock_calculate_for_query_based_insight: mock_calculate_for_query_based_insight.side_effect = Exception("Some error") @@ -269,7 +270,6 @@ def test_error_while_calculating_on_alert_in_firing_state( assert latest_alert_check.state == AlertState.ERRORED error_message = latest_alert_check.error["message"] - assert "AlertCheckError: error computing aggregate value for insight" in error_message assert "Some error" in error_message def test_error_while_calculating_on_alert_in_not_firing_state( @@ -285,7 +285,7 @@ def test_error_while_calculating_on_alert_in_not_firing_state( assert latest_alert_check.error is None with patch( - 
"posthog.tasks.alerts.checks.calculate_for_query_based_insight" + "posthog.tasks.alerts.trends.calculate_for_query_based_insight" ) as mock_calculate_for_query_based_insight: mock_calculate_for_query_based_insight.side_effect = Exception("Some error") @@ -299,7 +299,6 @@ def test_error_while_calculating_on_alert_in_not_firing_state( ) error_message = latest_alert_check.error["message"] - assert "AlertCheckError: error computing aggregate value for insight" in error_message assert "Some error" in error_message def test_alert_with_insight_with_filter( @@ -318,13 +317,13 @@ def test_alert_with_insight_with_filter( anomalies = self.get_breach_description(mock_send_notifications_for_breaches, call_index=0) assert "The trend value (0) is below the lower threshold (1.0)" in anomalies - @patch("posthog.tasks.alerts.checks.EmailMessage") + @patch("posthog.tasks.alerts.utils.EmailMessage") def test_send_emails( self, MockEmailMessage: MagicMock, mock_send_notifications_for_breaches: MagicMock, mock_send_errors: MagicMock ) -> None: mocked_email_messages = mock_email_messages(MockEmailMessage) alert = AlertConfiguration.objects.get(pk=self.alert["id"]) - _send_notifications_for_breaches(alert, ["first anomaly description", "second anomaly description"]) + send_notifications_for_breaches(alert, ["first anomaly description", "second anomaly description"]) assert len(mocked_email_messages) == 1 email = mocked_email_messages[0] diff --git a/posthog/tasks/alerts/test/test_trend_alerts.py b/posthog/tasks/alerts/test/test_trends_absolute_alerts.py similarity index 95% rename from posthog/tasks/alerts/test/test_trend_alerts.py rename to posthog/tasks/alerts/test/test_trends_absolute_alerts.py index a5ff389d59f98..9402117e79fe0 100644 --- a/posthog/tasks/alerts/test/test_trend_alerts.py +++ b/posthog/tasks/alerts/test/test_trends_absolute_alerts.py @@ -30,9 +30,9 @@ @freeze_time("2024-06-02T08:55:00.000Z") -@patch("posthog.tasks.alerts.checks._send_notifications_for_errors") -@patch("posthog.tasks.alerts.checks._send_notifications_for_breaches") -class TestTimeSeriesTrendsAlerts(APIBaseTest, ClickhouseDestroyTablesMixin): +@patch("posthog.tasks.alerts.checks.send_notifications_for_errors") +@patch("posthog.tasks.alerts.checks.send_notifications_for_breaches") +class TestTimeSeriesTrendsAbsoluteAlerts(APIBaseTest, ClickhouseDestroyTablesMixin): def setUp(self) -> None: super().setUp() @@ -54,8 +54,9 @@ def create_alert( "type": "TrendsAlertConfig", "series_index": series_index, }, + "condition": {"type": "absolute_value"}, "calculation_interval": AlertCalculationInterval.DAILY, - "threshold": {"configuration": {"absoluteThreshold": {"lower": lower, "upper": upper}}}, + "threshold": {"configuration": {"type": "absolute", "bounds": {"lower": lower, "upper": upper}}}, }, ).json() diff --git a/posthog/tasks/alerts/test/test_trends_relative_alerts.py b/posthog/tasks/alerts/test/test_trends_relative_alerts.py new file mode 100644 index 0000000000000..6e5b17b633894 --- /dev/null +++ b/posthog/tasks/alerts/test/test_trends_relative_alerts.py @@ -0,0 +1,775 @@ +from typing import Optional, Any +from unittest.mock import MagicMock, patch +import dateutil + + +import dateutil.relativedelta +from freezegun import freeze_time + +from posthog.models.alert import AlertCheck +from posthog.models.instance_setting import set_instance_setting +from posthog.tasks.alerts.checks import check_alert +from posthog.test.base import APIBaseTest, _create_event, flush_persons_and_events, ClickhouseDestroyTablesMixin +from 
posthog.api.test.dashboards import DashboardAPI +from posthog.schema import ( + ChartDisplayType, + EventsNode, + TrendsQuery, + TrendsFilter, + IntervalType, + InsightDateRange, + EventPropertyFilter, + PropertyOperator, + BaseMathType, + AlertState, + AlertCalculationInterval, + AlertConditionType, + InsightThresholdType, + BreakdownFilter, +) +from posthog.models import AlertConfiguration + +# Tuesday +FROZEN_TIME = dateutil.parser.parse("2024-06-04T08:55:00.000Z") + + +@freeze_time(FROZEN_TIME) +@patch("posthog.tasks.alerts.checks.send_notifications_for_errors") +@patch("posthog.tasks.alerts.checks.send_notifications_for_breaches") +class TestTimeSeriesTrendsRelativeAlerts(APIBaseTest, ClickhouseDestroyTablesMixin): + def setUp(self) -> None: + super().setUp() + + set_instance_setting("EMAIL_HOST", "fake_host") + set_instance_setting("EMAIL_ENABLED", True) + + self.dashboard_api = DashboardAPI(self.client, self.team, self.assertEqual) + + def create_alert( + self, + insight: dict, + series_index: int, + condition_type: AlertConditionType, + threshold_type: InsightThresholdType, + lower: Optional[float] = None, + upper: Optional[float] = None, + ) -> dict: + alert = self.client.post( + f"/api/projects/{self.team.id}/alerts", + data={ + "name": "alert name", + "insight": insight["id"], + "subscribed_users": [self.user.id], + "config": { + "type": "TrendsAlertConfig", + "series_index": series_index, + }, + "condition": {"type": condition_type}, + "calculation_interval": AlertCalculationInterval.DAILY, + "threshold": {"configuration": {"type": threshold_type, "bounds": {"lower": lower, "upper": upper}}}, + }, + ).json() + + return alert + + def create_time_series_trend_insight( + self, interval: IntervalType, breakdown: Optional[BreakdownFilter] = None + ) -> dict[str, Any]: + query_dict = TrendsQuery( + series=[ + EventsNode( + event="signed_up", + math=BaseMathType.TOTAL, + properties=[ + EventPropertyFilter( + key="$browser", + operator=PropertyOperator.EXACT, + value=["Chrome"], + ) + ], + ), + EventsNode( + event="$pageview", + name="Pageview", + math=BaseMathType.TOTAL, + ), + ], + breakdownFilter=breakdown, + trendsFilter=TrendsFilter(display=ChartDisplayType.ACTIONS_LINE_GRAPH), + interval=interval, + dateRange=InsightDateRange(date_from="-8w"), + ).model_dump() + + insight = self.dashboard_api.create_insight( + data={ + "name": "insight", + "query": query_dict, + } + )[1] + + return insight + + def test_alert_properties(self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + # alert if sign ups increase by less than 1 + alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + lower=1, + ) + + assert alert["state"] == AlertState.NOT_FIRING + assert alert["last_checked_at"] is None + assert alert["last_notified_at"] is None + assert alert["next_check_at"] is None + + check_alert(alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.last_checked_at == FROZEN_TIME + assert updated_alert.last_notified_at == FROZEN_TIME + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=alert["id"]).latest("created_at") + assert alert_check.calculated_value == 0 + assert alert_check.state == 
AlertState.FIRING + assert alert_check.error is None + + def test_relative_increase_absolute_upper_threshold_breached( + self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock + ) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + + # alert if sign ups increase by more than 1 + alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + upper=1, + ) + + # FROZEN_TIME is on Tue, insight has weekly interval + # we aggregate our weekly insight numbers to display for Sun (19th May, 26th May, 2nd June) + # Previous to previous interval (last to last week) has 0 events + # add events for previous interval (last week on Sat) + last_sat = FROZEN_TIME - dateutil.relativedelta.relativedelta(days=3) + with freeze_time(last_sat): + _create_event( + team=self.team, + event="signed_up", + distinct_id="1", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="2", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + check_alert(alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=alert["id"]).latest("created_at") + + assert alert_check.calculated_value == 2 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + def test_relative_increase_upper_threshold_breached( + self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock + ) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + + # alert if sign ups increase by more than 1 + absolute_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + upper=1, + ) + + # alert if sign ups increase by more than 20% + percentage_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + upper=0.2, + ) + + # FROZEN_TIME is on Tue, insight has weekly interval + # we aggregate our weekly insight numbers to display for Sun (19th May, 26th May, 2nd June) + + # set previous to previous interval (last to last week) to have 1 event + last_to_last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=2) + + with freeze_time(last_to_last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="1", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # set previous interval to have 2 event + # add events for last week (last Tue) + last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=1) + with freeze_time(last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="2", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="3", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="4", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # alert should fire as we had *increase* in events of (2 or 200%) week over week + check_alert(absolute_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=absolute_alert["id"]) + 
assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=absolute_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == 2 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + check_alert(percentage_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=percentage_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=percentage_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == 2 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + def test_relative_increase_lower_threshold_breached_1( + self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock + ) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + + # alert if sign ups increase by less than 2 + absolute_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + lower=2, + ) + + # alert if sign ups increase by less than 20 + percentage_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.PERCENTAGE, + lower=0.5, # 50% + ) + + # FROZEN_TIME is on Tue, insight has weekly interval + # we aggregate our weekly insight numbers to display for Sun (19th May, 26th May, 2nd June) + + # set previous to previous interval (last to last week) to have 2 events + last_to_last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=2) + + with freeze_time(last_to_last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="1", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="2", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # set previous interval to have 1 event + # add events for last week (last Tue) + last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=1) + with freeze_time(last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="3", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # alert should fire as overall we had *decrease* in events (-1 or -50%) week over week + # check absolute alert + check_alert(absolute_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=absolute_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=absolute_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == -1 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + # check percentage alert + check_alert(percentage_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=percentage_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=percentage_alert["id"]).latest("created_at") + + assert 
alert_check.calculated_value == -0.5 # 50% decrease + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + def test_relative_increase_lower_threshold_breached_2( + self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock + ) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + + # alert if sign ups increase by less than 2 + absolute_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + lower=2, + ) + + # alert if sign ups increase by less than 110% + percentage_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.PERCENTAGE, + lower=1.1, + ) + + # FROZEN_TIME is on Tue, insight has weekly interval + # we aggregate our weekly insight numbers to display for Sun (19th May, 26th May, 2nd June) + + # set previous to previous interval (last to last week) to have 1 event + last_to_last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=2) + + with freeze_time(last_to_last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="1", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # set previous interval to have 2 event + # add events for last week (last Tue) + last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=1) + with freeze_time(last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="2", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="3", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # alert should fire as overall we had *increase* in events of just (1 or 100%) week over week + # alert required at least 2 + check_alert(absolute_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=absolute_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=absolute_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == 1 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + check_alert(percentage_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=percentage_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=percentage_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == 1 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + def test_relative_decrease_upper_threshold_breached( + self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock + ) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + + # alert if sign ups decrease by more than 1 + absolute_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_DECREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + upper=1, + ) + + # alert if sign ups decrease by more than 20% + percentage_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_DECREASE, + 
threshold_type=InsightThresholdType.PERCENTAGE, + upper=0.2, + ) + + # FROZEN_TIME is on Tue, insight has weekly interval + # we aggregate our weekly insight numbers to display for Sun (19th May, 26th May, 2nd June) + + # set previous to previous interval (last to last week) to have 3 event + last_to_last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=2) + + with freeze_time(last_to_last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="1", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="2", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="3", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # set previous interval to have 1 event + # add events for last week (last Tue) + last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=1) + with freeze_time(last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="4", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # alert should fire as we had decrease in events of (2 or 200%) week over week + check_alert(absolute_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=absolute_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=absolute_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == 2 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + check_alert(percentage_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=percentage_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=percentage_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == (2 / 3) + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + def test_relative_decrease_lower_threshold_breached( + self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock + ) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + + # alert if sign ups decrease by less than 2 + absolute_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_DECREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + lower=2, + ) + + # alert if sign ups decrease by less than 80% + percentage_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_DECREASE, + threshold_type=InsightThresholdType.PERCENTAGE, + lower=0.8, + ) + + # FROZEN_TIME is on Tue, insight has weekly interval + # we aggregate our weekly insight numbers to display for Sun (19th May, 26th May, 2nd June) + + # set previous to previous interval (last to last week) to have 2 event + last_to_last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=2) + + with freeze_time(last_to_last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="1", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="2", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # set previous interval to have 1 event + 
# add events for last week (last Tue) + last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=1) + with freeze_time(last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="4", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # alert should fire as we had decrease in events of (1 or 50%) week over week + check_alert(absolute_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=absolute_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=absolute_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == 1 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + check_alert(percentage_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=percentage_alert["id"]) + assert updated_alert.state == AlertState.FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=percentage_alert["id"]).latest("created_at") + + assert alert_check.calculated_value == 0.5 + assert alert_check.state == AlertState.FIRING + assert alert_check.error is None + + def test_relative_increase_no_threshold_breached( + self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock + ) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + + # alert if sign ups increase by more than 4 + absolute_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + upper=4, + ) + + # alert if sign ups increase by more than 400% + percentage_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_INCREASE, + threshold_type=InsightThresholdType.PERCENTAGE, + upper=4, + ) + + # FROZEN_TIME is on Tue, insight has weekly interval + # we aggregate our weekly insight numbers to display for Sun (19th May, 26th May, 2nd June) + + # set previous to previous interval (last to last week) to have 1 event + last_to_last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=2) + + with freeze_time(last_to_last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="1", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # set previous interval to have 3 event + # add events for last week (last Tue) + last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=1) + with freeze_time(last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="4", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="2", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="3", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # alert shouldn't fire as increase was only of 2 or 200% + check_alert(absolute_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=absolute_alert["id"]) + assert updated_alert.state == AlertState.NOT_FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=absolute_alert["id"]).latest("created_at") + assert 
alert_check.calculated_value == 2 + assert alert_check.state == AlertState.NOT_FIRING + assert alert_check.error is None + + check_alert(percentage_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=percentage_alert["id"]) + assert updated_alert.state == AlertState.NOT_FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=percentage_alert["id"]).latest("created_at") + assert alert_check.calculated_value == 2 + assert alert_check.state == AlertState.NOT_FIRING + assert alert_check.error is None + + def test_relative_decrease_no_threshold_breached( + self, mock_send_breaches: MagicMock, mock_send_errors: MagicMock + ) -> None: + insight = self.create_time_series_trend_insight(interval=IntervalType.WEEK) + + # alert if sign ups increase by more than 4 + absolute_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_DECREASE, + threshold_type=InsightThresholdType.ABSOLUTE, + upper=4, + ) + + # alert if sign ups decrease by more than 80% + percentage_alert = self.create_alert( + insight, + series_index=0, + condition_type=AlertConditionType.RELATIVE_DECREASE, + threshold_type=InsightThresholdType.PERCENTAGE, + upper=0.8, + ) + + # FROZEN_TIME is on Tue, insight has weekly interval + # we aggregate our weekly insight numbers to display for Sun (19th May, 26th May, 2nd June) + + # set previous to previous interval (last to last week) to have 3 events + last_to_last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=2) + + with freeze_time(last_to_last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="1", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="4", + properties={"$browser": "Chrome"}, + ) + _create_event( + team=self.team, + event="signed_up", + distinct_id="2", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # set previous interval to have 1 event + # add events for last week (last Tue) + last_tue = FROZEN_TIME - dateutil.relativedelta.relativedelta(weeks=1) + with freeze_time(last_tue): + _create_event( + team=self.team, + event="signed_up", + distinct_id="3", + properties={"$browser": "Chrome"}, + ) + flush_persons_and_events() + + # alert shouldn't fire as increase was only of 2 or 200% + check_alert(absolute_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=absolute_alert["id"]) + assert updated_alert.state == AlertState.NOT_FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=absolute_alert["id"]).latest("created_at") + assert alert_check.calculated_value == 2 + assert alert_check.state == AlertState.NOT_FIRING + assert alert_check.error is None + + check_alert(percentage_alert["id"]) + + updated_alert = AlertConfiguration.objects.get(pk=percentage_alert["id"]) + assert updated_alert.state == AlertState.NOT_FIRING + assert updated_alert.next_check_at == FROZEN_TIME + dateutil.relativedelta.relativedelta(days=1) + + alert_check = AlertCheck.objects.filter(alert_configuration=percentage_alert["id"]).latest("created_at") + assert alert_check.calculated_value == (2 / 3) + assert alert_check.state == AlertState.NOT_FIRING + assert alert_check.error is None diff --git a/posthog/tasks/alerts/trends.py b/posthog/tasks/alerts/trends.py new file mode 100644 
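Before the new trends.py (next file in this diff), the week-over-week arithmetic that the relative-alert tests above assert on can be summarised in a few lines. This is an illustrative sketch with example counts, not code from the PR:

prev_prev = 1  # value two intervals ago (last-to-last week)
prev = 3       # value in the last completed interval (last week)

absolute_increase = prev - prev_prev                 # 2, compared against the alert's bounds
percentage_increase = absolute_increase / prev_prev  # 2.0, i.e. 200%, when the threshold type is percentage
# a RELATIVE_DECREASE alert does the mirror image: prev_prev - prev, or (prev_prev - prev) / prev_prev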
index 0000000000000..3f7fdebae2644
--- /dev/null
+++ b/posthog/tasks/alerts/trends.py
@@ -0,0 +1,219 @@
+from typing import Optional, cast
+
+from posthog.api.services.query import ExecutionMode
+from posthog.caching.calculate_results import calculate_for_query_based_insight
+
+from posthog.models import AlertConfiguration, Insight
+from posthog.schema import (
+    TrendsQuery,
+    IntervalType,
+    TrendsAlertConfig,
+    InsightThreshold,
+    AlertCondition,
+    AlertConditionType,
+    InsightsThresholdBounds,
+    InsightThresholdType,
+)
+from posthog.caching.fetch_from_cache import InsightResult
+from typing import TypedDict, NotRequired
+from posthog.tasks.alerts.utils import (
+    AlertEvaluationResult,
+    NON_TIME_SERIES_DISPLAY_TYPES,
+)
+
+
+# TODO: move the TrendResult UI type to schema.ts and use that instead
+class TrendResult(TypedDict):
+    action: dict
+    actions: list[dict]
+    count: int
+    data: list[float]
+    days: list[str]
+    dates: list[str]
+    label: str
+    labels: list[str]
+    breakdown_value: str | int | list[str]
+    aggregated_value: NotRequired[float]
+    status: str | None
+    compare_label: str | None
+    compare: bool
+    persons_urls: list[dict]
+    persons: dict
+    filter: dict
+
+
+def check_trends_alert(alert: AlertConfiguration, insight: Insight, query: TrendsQuery) -> AlertEvaluationResult:
+    if "type" in alert.config and alert.config["type"] == "TrendsAlertConfig":
+        config = TrendsAlertConfig.model_validate(alert.config)
+    else:
+        raise ValueError(f"Unsupported alert config type: {alert.config}")
+
+    condition = AlertCondition.model_validate(alert.condition)
+    threshold = InsightThreshold.model_validate(alert.threshold.configuration) if alert.threshold else None
+
+    if not threshold:
+        return AlertEvaluationResult(value=0, breaches=[])
+
+    match condition.type:
+        case AlertConditionType.ABSOLUTE_VALUE:
+            if threshold.type != InsightThresholdType.ABSOLUTE:
+                raise ValueError("Absolute threshold not configured for alert condition ABSOLUTE_VALUE")
+
+            # want value for current interval (last hour, last day, last week, last month)
+            # depending on the alert calculation interval
+            if _is_non_time_series_trend(query):
+                # for non time series, it's an aggregated value for the full interval
+                # so we need to compute the full insight
+                filters_override = None
+            else:
+                filters_override = _date_range_override_for_intervals(query)
+
+            calculation_result = calculate_for_query_based_insight(
+                insight,
+                team=alert.team,
+                execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE,
+                user=None,
+                filters_override=filters_override,
+            )
+
+            if not calculation_result.result:
+                raise RuntimeError(f"No results found for insight with alert id = {alert.id}")
+
+            current_interval_value = _pick_interval_value_from_trend_result(config, query, calculation_result)
+            breaches = _validate_bounds(threshold.bounds, current_interval_value)
+
+            return AlertEvaluationResult(value=current_interval_value, breaches=breaches)
+
+        case AlertConditionType.RELATIVE_INCREASE:
+            if _is_non_time_series_trend(query):
+                raise ValueError("Relative alerts not supported for non time series trends")
+
+            # to measure a relative increase, we can't alert until the current interval has completed
+            # (e.g. to check an increase of less than X, the interval must be complete)
+            # so we compute trend values for the last 3 intervals
+            # and compare the previous interval with the one before it
+            filters_overrides = _date_range_override_for_intervals(query, last_x_intervals=3)
+
+            calculation_result = calculate_for_query_based_insight(
insight, + team=alert.team, + execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE, + user=None, + filters_override=filters_overrides, + ) + + prev_interval_value = _pick_interval_value_from_trend_result(config, query, calculation_result, -1) + prev_prev_interval_value = _pick_interval_value_from_trend_result(config, query, calculation_result, -2) + + if threshold.type == InsightThresholdType.ABSOLUTE: + increase = prev_interval_value - prev_prev_interval_value + breaches = _validate_bounds(threshold.bounds, increase) + elif threshold.type == InsightThresholdType.PERCENTAGE: + increase = (prev_interval_value - prev_prev_interval_value) / prev_prev_interval_value + breaches = _validate_bounds(threshold.bounds, increase, is_percentage=True) + else: + raise ValueError( + f"Neither relative nor absolute threshold configured for alert condition RELATIVE_INCREASE" + ) + + return AlertEvaluationResult(value=increase, breaches=breaches) + + case AlertConditionType.RELATIVE_DECREASE: + if _is_non_time_series_trend(query): + raise ValueError(f"Relative alerts not supported for non time series trends") + + # to measure relative decrease, we can't alert until current interval has completed + # as to check decrease more than X, we need interval to complete + # so we need to compute the trend values for last 3 intervals + # and then compare the previous interval with value for the interval before previous + filters_overrides = _date_range_override_for_intervals(query, last_x_intervals=3) + + calculation_result = calculate_for_query_based_insight( + insight, + team=alert.team, + execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE, + user=None, + filters_override=filters_overrides, + ) + + prev_interval_value = _pick_interval_value_from_trend_result(config, query, calculation_result, -1) + prev_prev_interval_value = _pick_interval_value_from_trend_result(config, query, calculation_result, -2) + + if threshold.type == InsightThresholdType.ABSOLUTE: + decrease = prev_prev_interval_value - prev_interval_value + breaches = _validate_bounds(threshold.bounds, decrease) + elif threshold.type == InsightThresholdType.PERCENTAGE: + decrease = (prev_prev_interval_value - prev_interval_value) / prev_prev_interval_value + breaches = _validate_bounds(threshold.bounds, decrease, is_percentage=True) + else: + raise ValueError( + f"Neither relative nor absolute threshold configured for alert condition RELATIVE_INCREASE" + ) + + return AlertEvaluationResult(value=decrease, breaches=breaches) + + case _: + raise NotImplementedError(f"Unsupported alert condition type: {condition.type}") + + +def _is_non_time_series_trend(query: TrendsQuery) -> bool: + return bool(query.trendsFilter and query.trendsFilter.display in NON_TIME_SERIES_DISPLAY_TYPES) + + +def _date_range_override_for_intervals(query: TrendsQuery, last_x_intervals: int = 1) -> Optional[dict]: + """ + Resulting filter overrides don't set 'date_to' so we always get value for current interval. 
+ last_x_intervals controls how many intervals to look back to + """ + assert last_x_intervals > 0 + + match query.interval: + case IntervalType.DAY: + date_from = f"-{last_x_intervals}d" + case IntervalType.WEEK: + date_from = f"-{last_x_intervals}w" + case IntervalType.MONTH: + date_from = f"-{last_x_intervals}m" + case _: + date_from = f"-{last_x_intervals}h" + + return {"date_from": date_from} + + +def _pick_interval_value_from_trend_result( + config: TrendsAlertConfig, query: TrendsQuery, results: InsightResult, interval_to_pick: int = 0 +) -> float: + """ + interval_to_pick to controls whether to pick value for current (0), last (-1), one before last (-2)... + """ + assert interval_to_pick <= 0 + + series_index = config.series_index + result = cast(list[TrendResult], results.result)[series_index] + + if _is_non_time_series_trend(query): + # only one value in result + return result["aggregated_value"] + + data = result["data"] + # data is pre sorted in ascending order of timestamps + index_from_back = len(data) - 1 + interval_to_pick + return data[index_from_back] + + +def _validate_bounds( + bounds: InsightsThresholdBounds | None, calculated_value: float, is_percentage: bool = False +) -> list[str]: + if not bounds: + return [] + + formatted_value = f"{calculated_value:.2%}" if is_percentage else calculated_value + + if bounds.lower is not None and calculated_value < bounds.lower: + lower_value = f"{bounds.lower:.2%}" if is_percentage else bounds.lower + return [f"The trend value ({formatted_value}) is below the lower threshold ({lower_value})"] + if bounds.upper is not None and calculated_value > bounds.upper: + upper_value = f"{bounds.upper:.2%}" if is_percentage else bounds.upper + return [f"The trend value ({formatted_value}) is above the upper threshold ({upper_value})"] + + return [] diff --git a/posthog/tasks/alerts/utils.py b/posthog/tasks/alerts/utils.py new file mode 100644 index 0000000000000..06b94cc938089 --- /dev/null +++ b/posthog/tasks/alerts/utils.py @@ -0,0 +1,110 @@ +from dateutil.relativedelta import relativedelta + +from django.utils import timezone +import structlog + +from posthog.email import EmailMessage +from posthog.models import AlertConfiguration +from posthog.schema import ( + ChartDisplayType, + NodeKind, + AlertCalculationInterval, +) +from dataclasses import dataclass + +logger = structlog.get_logger(__name__) + + +@dataclass +class AlertEvaluationResult: + value: float | None + breaches: list[str] | None + + +WRAPPER_NODE_KINDS = [NodeKind.DATA_TABLE_NODE, NodeKind.DATA_VISUALIZATION_NODE, NodeKind.INSIGHT_VIZ_NODE] + +NON_TIME_SERIES_DISPLAY_TYPES = { + ChartDisplayType.BOLD_NUMBER, + ChartDisplayType.ACTIONS_PIE, + ChartDisplayType.ACTIONS_BAR_VALUE, + ChartDisplayType.ACTIONS_TABLE, + ChartDisplayType.WORLD_MAP, +} + + +def calculation_interval_to_order(interval: AlertCalculationInterval | None) -> int: + match interval: + case AlertCalculationInterval.HOURLY: + return 0 + case AlertCalculationInterval.DAILY: + return 1 + case _: + return 2 + + +def alert_calculation_interval_to_relativedelta(alert_calculation_interval: AlertCalculationInterval) -> relativedelta: + match alert_calculation_interval: + case AlertCalculationInterval.HOURLY: + return relativedelta(hours=1) + case AlertCalculationInterval.DAILY: + return relativedelta(days=1) + case AlertCalculationInterval.WEEKLY: + return relativedelta(weeks=1) + case AlertCalculationInterval.MONTHLY: + return relativedelta(months=1) + case _: + raise ValueError(f"Invalid alert calculation interval: 
{alert_calculation_interval}") + + +def send_notifications_for_breaches(alert: AlertConfiguration, breaches: list[str]) -> None: + subject = f"PostHog alert {alert.name} is firing" + campaign_key = f"alert-firing-notification-{alert.id}-{timezone.now().timestamp()}" + insight_url = f"/project/{alert.team.pk}/insights/{alert.insight.short_id}?alert_id={alert.id}" + alert_url = f"{insight_url}/alerts/{alert.id}" + message = EmailMessage( + campaign_key=campaign_key, + subject=subject, + template_name="alert_check_firing", + template_context={ + "match_descriptions": breaches, + "insight_url": insight_url, + "insight_name": alert.insight.name, + "alert_url": alert_url, + "alert_name": alert.name, + }, + ) + targets = alert.subscribed_users.all().values_list("email", flat=True) + if not targets: + raise RuntimeError(f"no targets configured for the alert {alert.id}") + for target in targets: + message.add_recipient(email=target) + + logger.info(f"Send notifications about {len(breaches)} anomalies", alert_id=alert.id) + message.send() + + +def send_notifications_for_errors(alert: AlertConfiguration, error: dict) -> None: + subject = f"PostHog alert {alert.name} check failed to evaluate" + campaign_key = f"alert-firing-notification-{alert.id}-{timezone.now().timestamp()}" + insight_url = f"/project/{alert.team.pk}/insights/{alert.insight.short_id}?alert_id={alert.id}" + alert_url = f"{insight_url}/alerts/{alert.id}" + message = EmailMessage( + campaign_key=campaign_key, + subject=subject, + template_name="alert_check_firing", + template_context={ + "match_descriptions": error, + "insight_url": insight_url, + "insight_name": alert.insight.name, + "alert_url": alert_url, + "alert_name": alert.name, + }, + ) + targets = alert.subscribed_users.all().values_list("email", flat=True) + if not targets: + raise RuntimeError(f"no targets configured for the alert {alert.id}") + for target in targets: + message.add_recipient(email=target) + + logger.info(f"Send notifications about alert checking error", alert_id=alert.id) + message.send() diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index 4c1114cfb2cdf..16d4ccdacf0d0 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -1,3 +1,5 @@ +import asyncio +import collections import collections.abc import dataclasses import datetime as dt @@ -251,6 +253,135 @@ async def iter_records_from_model_view( yield record_batch +class RecordBatchQueue(asyncio.Queue): + """A queue of pyarrow RecordBatch instances limited by bytes.""" + + def __init__(self, max_size_bytes=0): + super().__init__(maxsize=max_size_bytes) + self._bytes_size = 0 + self._schema_set = asyncio.Event() + self.record_batch_schema = None + # This is set by `asyncio.Queue.__init__` calling `_init` + self._queue: collections.deque + + def _get(self) -> pa.RecordBatch: + """Override parent `_get` to keep track of bytes.""" + item = self._queue.popleft() + self._bytes_size -= item.get_total_buffer_size() + return item + + def _put(self, item: pa.RecordBatch) -> None: + """Override parent `_put` to keep track of bytes.""" + self._bytes_size += item.get_total_buffer_size() + + if not self._schema_set.is_set(): + self.set_schema(item) + + self._queue.append(item) + + def set_schema(self, record_batch: pa.RecordBatch) -> None: + """Used to keep track of schema of events in queue.""" + self.record_batch_schema = record_batch.schema + self._schema_set.set() + + async def get_schema(self) 
-> pa.Schema: + """Return the schema of events in queue. + + Currently, this is not enforced. It's purely for reporting to users of + the queue what do the record batches look like. It's up to the producer + to ensure all record batches have the same schema. + """ + await self._schema_set.wait() + return self.record_batch_schema + + def qsize(self) -> int: + """Size in bytes of record batches in the queue. + + This is used to determine when the queue is full, so it returns the + number of bytes. + """ + return self._bytes_size + + +def start_produce_batch_export_record_batches( + client: ClickHouseClient, + model_name: str, + is_backfill: bool, + team_id: int, + interval_start: str, + interval_end: str, + fields: list[BatchExportField] | None = None, + destination_default_fields: list[BatchExportField] | None = None, + **parameters, +): + """Start producing batch export record batches from a model query. + + Depending on the model, we issue a query to ClickHouse and initialize a + producer to stream record batches to a queue. Callers can then consume from + this queue as the record batches arrive. The producer runs asynchronously as + a background task, which is returned. + + Returns: + A tuple containing the record batch queue, an event used by the producer + to indicate there is nothing more to produce, and a reference to the + producer task + """ + if fields is None: + if destination_default_fields is None: + fields = default_fields() + else: + fields = destination_default_fields + + if model_name == "persons": + view = SELECT_FROM_PERSONS_VIEW + + else: + if parameters.get("exclude_events", None): + parameters["exclude_events"] = list(parameters["exclude_events"]) + else: + parameters["exclude_events"] = [] + + if parameters.get("include_events", None): + parameters["include_events"] = list(parameters["include_events"]) + else: + parameters["include_events"] = [] + + if str(team_id) in settings.UNCONSTRAINED_TIMESTAMP_TEAM_IDS: + query_template = SELECT_FROM_EVENTS_VIEW_UNBOUNDED + elif is_backfill: + query_template = SELECT_FROM_EVENTS_VIEW_BACKFILL + else: + query_template = SELECT_FROM_EVENTS_VIEW + lookback_days = settings.OVERRIDE_TIMESTAMP_TEAM_IDS.get(team_id, settings.DEFAULT_TIMESTAMP_LOOKBACK_DAYS) + parameters["lookback_days"] = lookback_days + + if "_inserted_at" not in [field["alias"] for field in fields]: + control_fields = [BatchExportField(expression="_inserted_at", alias="_inserted_at")] + else: + control_fields = [] + + query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in fields + control_fields) + + view = query_template.substitute(fields=query_fields) + + parameters["team_id"] = team_id + parameters["interval_start"] = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S") + parameters["interval_end"] = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S") + extra_query_parameters = parameters.pop("extra_query_parameters", {}) or {} + parameters = {**parameters, **extra_query_parameters} + + queue = RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_BUFFER_QUEUE_MAX_SIZE_BYTES) + query_id = uuid.uuid4() + done_event = asyncio.Event() + produce_task = asyncio.create_task( + client.aproduce_query_as_arrow_record_batches( + view, queue=queue, done_event=done_event, query_parameters=parameters, query_id=str(query_id) + ) + ) + + return queue, done_event, produce_task + + def iter_records( client: ClickHouseClient, team_id: int, diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py 
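A design note on RecordBatchQueue above: asyncio.Queue decides it is full by comparing qsize() against maxsize, so overriding qsize() to report buffered bytes turns maxsize into a byte budget and makes the producer block in put() once that budget is exceeded. A self-contained sketch of the same trick with plain bytes instead of pyarrow record batches (illustrative only, not the PR's class):

import asyncio

class BytesLimitedQueue(asyncio.Queue):
    """Queue whose maxsize is interpreted as a byte budget."""

    def __init__(self, max_size_bytes: int = 0):
        super().__init__(maxsize=max_size_bytes)
        self._bytes_size = 0

    def _put(self, item: bytes) -> None:
        self._bytes_size += len(item)
        super()._put(item)

    def _get(self) -> bytes:
        item = super()._get()
        self._bytes_size -= len(item)
        return item

    def qsize(self) -> int:
        # asyncio.Queue.full() compares qsize() with maxsize,
        # so reporting bytes here makes put() block past the byte budget
        return self._bytes_size

async def demo() -> None:
    queue = BytesLimitedQueue(max_size_bytes=10)
    await queue.put(b"123456")   # 6 bytes buffered, still below the budget
    assert not queue.full()
    await queue.put(b"7890123")  # 13 bytes buffered, now "full"
    assert queue.full()          # a further put() would wait until a get() frees bytes

asyncio.run(demo())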
b/posthog/temporal/batch_exports/bigquery_batch_export.py index 9da8c89e56e53..521c6b1d92f85 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -3,9 +3,12 @@ import contextlib import dataclasses import datetime as dt +import functools import json +import operator import pyarrow as pa +import structlog from django.conf import settings from google.cloud import bigquery from google.oauth2 import service_account @@ -27,8 +30,8 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - iter_model_records, start_batch_export_run, + start_produce_batch_export_record_batches, ) from posthog.temporal.batch_exports.metrics import ( get_bytes_exported_metric, @@ -42,18 +45,19 @@ ) from posthog.temporal.batch_exports.utils import ( JsonType, - apeek_first_and_rewind, cast_record_batch_json_columns, set_status_to_running_task, ) from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater -from posthog.temporal.common.logger import bind_temporal_worker_logger +from posthog.temporal.common.logger import configure_temporal_worker_logger from posthog.temporal.common.utils import ( BatchExportHeartbeatDetails, should_resume_from_activity_heartbeat, ) +logger = structlog.get_logger() + def get_bigquery_fields_from_record_schema( record_schema: pa.Schema, known_json_columns: list[str] @@ -72,6 +76,9 @@ def get_bigquery_fields_from_record_schema( bq_schema: list[bigquery.SchemaField] = [] for name in record_schema.names: + if name == "_inserted_at": + continue + pa_field = record_schema.field(name) if pa.types.is_string(pa_field.type) or isinstance(pa_field.type, JsonType): @@ -264,8 +271,13 @@ async def load_parquet_file(self, parquet_file, table, table_schema): schema=table_schema, ) - load_job = self.load_table_from_file(parquet_file, table, job_config=job_config, rewind=True) - return await asyncio.to_thread(load_job.result) + await logger.adebug("Creating BigQuery load job for Parquet file '%s'", parquet_file) + load_job = await asyncio.to_thread( + self.load_table_from_file, parquet_file, table, job_config=job_config, rewind=True + ) + await logger.adebug("Waiting for BigQuery load job for Parquet file '%s'", parquet_file) + result = await asyncio.to_thread(load_job.result) + return result async def load_jsonl_file(self, jsonl_file, table, table_schema): """Execute a COPY FROM query with given connection to copy contents of jsonl_file.""" @@ -274,8 +286,14 @@ async def load_jsonl_file(self, jsonl_file, table, table_schema): schema=table_schema, ) - load_job = self.load_table_from_file(jsonl_file, table, job_config=job_config, rewind=True) - return await asyncio.to_thread(load_job.result) + await logger.adebug("Creating BigQuery load job for JSONL file '%s'", jsonl_file) + load_job = await asyncio.to_thread( + self.load_table_from_file, jsonl_file, table, job_config=job_config, rewind=True + ) + + await logger.adebug("Waiting for BigQuery load job for JSONL file '%s'", jsonl_file) + result = await asyncio.to_thread(load_job.result) + return result @contextlib.contextmanager @@ -327,7 +345,9 @@ def bigquery_default_fields() -> list[BatchExportField]: @activity.defn async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> RecordsCompleted: """Activity streams data from ClickHouse to BigQuery.""" - logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="BigQuery") + logger = await configure_temporal_worker_logger( 
+ logger=structlog.get_logger(), team_id=inputs.team_id, destination="BigQuery" + ) await logger.ainfo( "Batch exporting range %s - %s to BigQuery: %s.%s.%s", inputs.data_interval_start, @@ -357,24 +377,52 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records field.name for field in dataclasses.fields(inputs) }: model = inputs.batch_export_model + if model is not None: + model_name = model.name + extra_query_parameters = model.schema["values"] if model.schema is not None else None + fields = model.schema["fields"] if model.schema is not None else None + else: + model_name = "events" + extra_query_parameters = None + fields = None else: model = inputs.batch_export_schema + model_name = "custom" + extra_query_parameters = model["values"] if model is not None else {} + fields = model["fields"] if model is not None else None - records_iterator = iter_model_records( + queue, done_event, produce_task = start_produce_batch_export_record_batches( client=client, - model=model, + model_name=model_name, + is_backfill=inputs.is_backfill, team_id=inputs.team_id, interval_start=data_interval_start, interval_end=inputs.data_interval_end, exclude_events=inputs.exclude_events, include_events=inputs.include_events, + fields=fields, destination_default_fields=bigquery_default_fields(), - is_backfill=inputs.is_backfill, + extra_query_parameters=extra_query_parameters, ) - first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator) - if first_record_batch is None: + get_schema_task = asyncio.create_task(queue.get_schema()) + wait_for_producer_done_task = asyncio.create_task(done_event.wait()) + + await asyncio.wait([get_schema_task, wait_for_producer_done_task], return_when=asyncio.FIRST_COMPLETED) + + # Finishing producing happens sequentially after putting to queue and setting the schema. + # So, either we finished both tasks, or we finished without putting anything in the queue. + if get_schema_task.done(): + # In the first case, we'll land here. + # The schema is available, and the queue is not empty, so we can start the batch export. + record_batch_schema = get_schema_task.result() + elif wait_for_producer_done_task.done(): + # In the second case, we'll land here. + # The schema is not available as the queue is empty. + # Since we finished producing with an empty queue, there is nothing to batch export. 
return 0 + else: + raise Exception("Unreachable") if inputs.use_json_type is True: json_type = "JSON" @@ -383,8 +431,6 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records json_type = "STRING" json_columns = [] - first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=json_columns) - if model is None or (isinstance(model, BatchExportModel) and model.name == "events"): schema = [ bigquery.SchemaField("uuid", "STRING"), @@ -401,9 +447,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"), ] else: - column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] - record_schema = first_record_batch.select(column_names).schema - schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns) + schema = get_bigquery_fields_from_record_schema(record_batch_schema, known_json_columns=json_columns) rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() @@ -446,41 +490,47 @@ async def flush_to_bigquery( last: bool, error: Exception | None, ): + table = bigquery_stage_table if requires_merge else bigquery_table await logger.adebug( - "Loading %s records of size %s bytes", + "Loading %s records of size %s bytes to BigQuery table '%s'", records_since_last_flush, bytes_since_last_flush, + table, ) - table = bigquery_stage_table if requires_merge else bigquery_table await bq_client.load_jsonl_file(local_results_file, table, schema) + await logger.adebug("Loading to BigQuery table '%s' finished", table) rows_exported.add(records_since_last_flush) bytes_exported.add(bytes_since_last_flush) heartbeater.details = (str(last_inserted_at),) - record_schema = pa.schema( - # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other - # record batches have them as nullable. - # Until we figure it out, we set all fields to nullable. There are some fields we know - # are not nullable, but I'm opting for the more flexible option until we out why schemas differ - # between batches. 
- [ - field.with_nullable(True) - for field in first_record_batch.select([field.name for field in schema]).schema - ] - ) - writer = JSONLBatchExportWriter( - max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, - flush_callable=flush_to_bigquery, - ) + flush_tasks = [] + while not queue.empty() or not done_event.is_set(): + await logger.adebug("Starting record batch writer") + flush_start_event = asyncio.Event() + task = asyncio.create_task( + consume_batch_export_record_batches( + queue, + done_event, + flush_start_event, + flush_to_bigquery, + json_columns, + settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, + ) + ) + + await flush_start_event.wait() - async with writer.open_temporary_file(): - async for record_batch in records_iterator: - record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns) + flush_tasks.append(task) + + await logger.adebug( + "Finished producing and consuming all record batches, now waiting on any pending flush tasks" + ) + await asyncio.wait(flush_tasks) - await writer.write_record_batch(record_batch) + records_total = functools.reduce(operator.add, (task.result() for task in flush_tasks)) if requires_merge: merge_key = ( @@ -494,7 +544,74 @@ async def flush_to_bigquery( update_fields=schema, ) - return writer.records_total + return records_total + + +async def consume_batch_export_record_batches( + queue: asyncio.Queue, + done_event: asyncio.Event, + flush_start_event: asyncio.Event, + flush_to_bigquery: FlushCallable, + json_columns: list[str], + max_bytes: int, +): + """Consume batch export record batches from queue into a writing loop. + + Each record will be written to a temporary file, and flushed once the + configured `max_bytes` is reached. Flush is done on context manager exit by + `JSONLBatchExportWriter`. + + This coroutine reports when flushing will start by setting the + `flush_start_event`. This is used by the caller to start a new writer + task as flushing is about to begin, since that can be too slow to do + sequentially. + + If there are not enough events to fill up `max_bytes`, the writing + loop will detect that there are no more events produced and shut itself off + by using the `done_event`, which should be set by the queue producer. + + Arguments: + queue: The queue we will be listening on for record batches. + done_event: Event set by the producer when done. + flush_start_event: Event set by this coroutine when flushing is about to + start. + flush_to_bigquery: Callable used to flush the temporary file's contents to BigQuery. + json_columns: Used to cast columns of the record batch to JSON. + max_bytes: Max bytes to write before flushing. + + Returns: + Number of total records written and flushed in this task.
+ """ + writer = JSONLBatchExportWriter( + max_bytes=max_bytes, + flush_callable=flush_to_bigquery, + ) + + async with writer.open_temporary_file(): + await logger.adebug("Starting record batch writing loop") + while True: + try: + record_batch = queue.get_nowait() + except asyncio.QueueEmpty: + if done_event.is_set(): + await logger.adebug("Empty queue with no more events being produced, closing writer loop") + flush_start_event.set() + # Exit context manager to trigger flush + break + else: + await asyncio.sleep(0.1) + continue + + record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns) + await writer.write_record_batch(record_batch, flush=False) + + if writer.should_flush(): + await logger.adebug("Writer finished, ready to flush events") + flush_start_event.set() + # Exit context manager to trigger flush + break + + await logger.adebug("Completed %s records", writer.records_total) + return writer.records_total def get_batch_export_writer( diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index 4d7dc45df5496..97d20bc785e09 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -96,6 +96,9 @@ def __exit__(self, exc, value, tb): def __iter__(self): yield from self._file + def __str__(self) -> str: + return self._file.name + @property def brotli_compressor(self): if self._brotli_compressor is None: @@ -387,7 +390,7 @@ def track_bytes_written(self, batch_export_file: BatchExportTemporaryFile) -> No self.bytes_total = batch_export_file.bytes_total self.bytes_since_last_flush = batch_export_file.bytes_since_last_reset - async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: + async def write_record_batch(self, record_batch: pa.RecordBatch, flush: bool = True) -> None: """Issue a record batch write tracking progress and flushing if required.""" record_batch = record_batch.sort_by("_inserted_at") last_inserted_at = record_batch.column("_inserted_at")[-1].as_py() @@ -401,9 +404,12 @@ async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: self.track_records_written(record_batch) self.track_bytes_written(self.batch_export_file) - if self.bytes_since_last_flush >= self.max_bytes: + if flush and self.should_flush(): await self.flush(last_inserted_at) + def should_flush(self) -> bool: + return self.bytes_since_last_flush >= self.max_bytes + async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: """Call the provided `flush_callable` and reset underlying file. 
diff --git a/posthog/temporal/common/asyncpa.py b/posthog/temporal/common/asyncpa.py index 31eab18d02928..d76dffb5ecb9c 100644 --- a/posthog/temporal/common/asyncpa.py +++ b/posthog/temporal/common/asyncpa.py @@ -1,6 +1,10 @@ +import asyncio import typing import pyarrow as pa +import structlog + +logger = structlog.get_logger() CONTINUATION_BYTES = b"\xff\xff\xff\xff" @@ -128,3 +132,20 @@ async def read_schema(self) -> pa.Schema: raise TypeError(f"Expected message of type 'schema' got '{message.type}'") return pa.ipc.read_schema(message) + + +class AsyncRecordBatchProducer(AsyncRecordBatchReader): + def __init__(self, bytes_iter: typing.AsyncIterator[tuple[bytes, bool]]) -> None: + super().__init__(bytes_iter) + + async def produce(self, queue: asyncio.Queue, done_event: asyncio.Event): + await logger.adebug("Starting record batch produce loop") + while True: + try: + record_batch = await self.read_next_record_batch() + except StopAsyncIteration: + await logger.adebug("No more record batches to produce, closing loop") + done_event.set() + return + + await queue.put(record_batch) diff --git a/posthog/temporal/common/clickhouse.py b/posthog/temporal/common/clickhouse.py index 485eb68901e21..570cfe8d5bb5e 100644 --- a/posthog/temporal/common/clickhouse.py +++ b/posthog/temporal/common/clickhouse.py @@ -1,3 +1,4 @@ +import asyncio import collections.abc import contextlib import datetime as dt @@ -11,7 +12,7 @@ import requests from django.conf import settings -from posthog.temporal.common.asyncpa import AsyncRecordBatchReader +import posthog.temporal.common.asyncpa as asyncpa def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: @@ -383,13 +384,31 @@ async def astream_query_as_arrow( """Execute the given query in ClickHouse and stream back the response as Arrow record batches. This method makes sense when running with FORMAT ArrowStream, although we currently do not enforce this. - As pyarrow doesn't support async/await buffers, this method is sync and utilizes requests instead of aiohttp. """ async with self.apost_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response: - reader = AsyncRecordBatchReader(response.content.iter_chunks()) + reader = asyncpa.AsyncRecordBatchReader(response.content.iter_chunks()) async for batch in reader: yield batch + async def aproduce_query_as_arrow_record_batches( + self, + query, + *data, + queue: asyncio.Queue, + done_event: asyncio.Event, + query_parameters=None, + query_id: str | None = None, + ) -> None: + """Execute the given query in ClickHouse and produce Arrow record batches to the given buffer queue. + + This method makes sense when running with FORMAT ArrowStream, although we currently do not enforce this. + This method is intended to be run as a background task, producing record batches continuously, while other + downstream consumer tasks process them from the queue.
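+ + Arguments: + queue: The buffer queue that produced record batches are put on. + done_event: Event set by this method once all record batches have been produced.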
+ """ + async with self.apost_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response: + reader = asyncpa.AsyncRecordBatchProducer(response.content.iter_chunks()) + await reader.produce(queue=queue, done_event=done_event) + async def __aenter__(self): """Enter method part of the AsyncContextManager protocol.""" self.connector = aiohttp.TCPConnector(ssl=self.ssl) diff --git a/posthog/temporal/common/logger.py b/posthog/temporal/common/logger.py index c769116921f6c..2b1107d8124cc 100644 --- a/posthog/temporal/common/logger.py +++ b/posthog/temporal/common/logger.py @@ -1,8 +1,8 @@ import asyncio import json import logging -import uuid import ssl +import uuid import aiokafka import structlog @@ -14,7 +14,6 @@ from posthog.kafka_client.topics import KAFKA_LOG_ENTRIES - BACKGROUND_LOGGER_TASKS = set() @@ -29,6 +28,18 @@ async def bind_temporal_worker_logger(team_id: int, destination: str | None = No return logger.new(team_id=team_id, destination=destination, **temporal_context) +async def configure_temporal_worker_logger( + logger, team_id: int, destination: str | None = None +) -> FilteringBoundLogger: + """Return a bound logger for Temporal Workers.""" + if not structlog.is_configured(): + configure_logger() + + temporal_context = get_temporal_context() + + return logger.new(team_id=team_id, destination=destination, **temporal_context) + + async def bind_temporal_org_worker_logger( organization_id: uuid.UUID, destination: str | None = None ) -> FilteringBoundLogger: diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py index dda307dda004a..8c3fb186b82cd 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_exports.py +++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py @@ -2,15 +2,19 @@ import json import operator from random import randint +import asyncio import pytest from django.test import override_settings +import pyarrow as pa from posthog.batch_exports.service import BatchExportModel from posthog.temporal.batch_exports.batch_exports import ( get_data_interval, iter_model_records, iter_records, + start_produce_batch_export_record_batches, + RecordBatchQueue, ) from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse @@ -404,3 +408,427 @@ def test_get_data_interval(interval, data_interval_end, expected): """Test get_data_interval returns the expected data interval tuple.""" result = get_data_interval(interval, data_interval_end) assert result == expected + + +async def get_record_batch_from_queue(queue, done_event): + while not queue.empty() or not done_event.is_set(): + try: + record_batch = queue.get_nowait() + except asyncio.QueueEmpty: + if done_event.is_set(): + break + else: + await asyncio.sleep(0.1) + continue + + return record_batch + return None + + +async def test_start_produce_batch_export_record_batches_uses_extra_query_parameters(clickhouse_client): + """Test start_produce_batch_export_record_batches uses a HogQL value.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=False, + properties={"$browser": "Chrome", "$os": "Mac OS X", 
"custom": 3}, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + fields=[ + {"expression": "JSONExtractInt(properties, %(hogql_val_0)s)", "alias": "custom_prop"}, + ], + extra_query_parameters={"hogql_val_0": "custom"}, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + for expected, record in zip(events, records): + if expected["properties"] is None: + raise ValueError("Empty properties") + + assert record["custom_prop"] == expected["properties"]["custom"] + + +async def test_start_produce_batch_export_record_batches_can_flatten_properties(clickhouse_client): + """Test start_produce_batch_export_record_batches can flatten properties.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=False, + properties={"$browser": "Chrome", "$os": "Mac OS X", "custom-property": 3}, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + fields=[ + {"expression": "event", "alias": "event"}, + {"expression": "JSONExtractString(properties, '$browser')", "alias": "browser"}, + {"expression": "JSONExtractString(properties, '$os')", "alias": "os"}, + {"expression": "JSONExtractInt(properties, 'custom-property')", "alias": "custom_prop"}, + ], + extra_query_parameters={"hogql_val_0": "custom"}, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + all_expected = sorted(events, key=operator.itemgetter("event")) + all_record = sorted(records, key=operator.itemgetter("event")) + + for expected, record in zip(all_expected, all_record): + if expected["properties"] is None: + raise ValueError("Empty properties") + + assert record["browser"] == expected["properties"]["$browser"] + assert record["os"] == expected["properties"]["$os"] + assert record["custom_prop"] == expected["properties"]["custom-property"] + + +@pytest.mark.parametrize( + "field", + [ + {"expression": "event", "alias": "event_name"}, + {"expression": "team_id", "alias": "team"}, + {"expression": "timestamp", "alias": "time_the_stamp"}, + {"expression": "created_at", "alias": "creation_time"}, + ], +) +async def test_start_produce_batch_export_record_batches_with_single_field_and_alias(clickhouse_client, field): + """Test start_produce_batch_export_record_batches can return a single aliased field.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = 
dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=False, + properties={"$browser": "Chrome", "$os": "Mac OS X"}, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + fields=[field], + extra_query_parameters={}, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + all_expected = sorted(events, key=operator.itemgetter(field["expression"])) + all_record = sorted(records, key=operator.itemgetter(field["alias"])) + + for expected, record in zip(all_expected, all_record): + assert len(record) == 2 + # Always set for progress tracking + assert record.get("_inserted_at", None) is not None + + result = record[field["alias"]] + expected_value = expected[field["expression"]] # type: ignore + + if isinstance(result, dt.datetime): + # Event generation function returns datetimes as strings. + expected_value = dt.datetime.fromisoformat(expected_value).replace(tzinfo=dt.UTC) + + assert result == expected_value + + +async def test_start_produce_batch_export_record_batches_ignores_timestamp_predicates(clickhouse_client): + """Test the rows returned ignore timestamp predicates when configured.""" + team_id = randint(1, 1000000) + + inserted_at = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + data_interval_end = inserted_at + dt.timedelta(hours=1) + + # Insert some data with timestamps a couple of years before inserted_at + timestamp_start = inserted_at - dt.timedelta(hours=24 * 365 * 2) + timestamp_end = inserted_at - dt.timedelta(hours=24 * 365) + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=timestamp_start, + end_time=timestamp_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=True, + person_properties={"$browser": "Chrome", "$os": "Mac OS X"}, + inserted_at=inserted_at, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=inserted_at.isoformat(), + interval_end=data_interval_end.isoformat(), + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + assert len(records) == 0 + + with override_settings(UNCONSTRAINED_TIMESTAMP_TEAM_IDS=[str(team_id)]): + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=inserted_at.isoformat(), + interval_end=data_interval_end.isoformat(), + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + 
records.append(record) + + assert_records_match_events(records, events) + + +async def test_start_produce_batch_export_record_batches_can_include_events(clickhouse_client): + """Test the rows returned can include events.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=10000, + count_outside_range=0, + count_other_team=0, + duplicate=True, + person_properties={"$browser": "Chrome", "$os": "Mac OS X"}, + ) + + # Include the latter half of events. + include_events = (event["event"] for event in events[5000:]) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + include_events=include_events, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + assert_records_match_events(records, events[5000:]) + + +async def test_start_produce_batch_export_record_batches_can_exclude_events(clickhouse_client): + """Test the rows returned can exclude events.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=10000, + count_outside_range=0, + count_other_team=0, + duplicate=True, + person_properties={"$browser": "Chrome", "$os": "Mac OS X"}, + ) + + # Exclude the latter half of events.
+ exclude_events = (event["event"] for event in events[5000:]) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + exclude_events=exclude_events, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + assert_records_match_events(records, events[:5000]) + + +async def test_start_produce_batch_export_record_batches_handles_duplicates(clickhouse_client): + """Test the rows returned are de-duplicated.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=100, + count_outside_range=0, + count_other_team=0, + duplicate=True, + person_properties={"$browser": "Chrome", "$os": "Mac OS X"}, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + assert_records_match_events(records, events) + + +async def test_record_batch_queue_tracks_bytes(): + """Test `RecordBatchQueue` tracks bytes from `RecordBatch`.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + + queue = RecordBatchQueue() + + await queue.put(record_batch) + assert record_batch.get_total_buffer_size() == queue.qsize() + + item = await queue.get() + + assert item == record_batch + assert queue.qsize() == 0 + + +async def test_record_batch_queue_raises_queue_full(): + """Test `QueueFull` is raised when we put too many bytes.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + record_batch_size = record_batch.get_total_buffer_size() + + queue = RecordBatchQueue(max_size_bytes=record_batch_size) + + await queue.put(record_batch) + assert record_batch.get_total_buffer_size() == queue.qsize() + + with pytest.raises(asyncio.QueueFull): + queue.put_nowait(record_batch) + + item = await queue.get() + + assert item == record_batch + assert queue.qsize() == 0 + + +async def test_record_batch_queue_sets_schema(): + """Test `RecordBatchQueue` sets a schema from first `RecordBatch`.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + + queue = RecordBatchQueue() + + await queue.put(record_batch) + + assert queue._schema_set.is_set() + + schema = await queue.get_schema() + assert schema == record_batch.schema diff --git a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py index 0f184b79356a1..00228adcb8cff 
100644 --- a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py @@ -105,7 +105,12 @@ async def assert_clickhouse_records_in_bigquery( inserted_bq_ingested_timestamp.append(v) continue - inserted_record[k] = json.loads(v) if k in json_columns and v is not None else v + if k in json_columns: + assert ( + isinstance(v, dict) or v is None + ), f"Expected '{k}' to be JSON, but it was not deserialized to dict" + + inserted_record[k] = v inserted_records.append(inserted_record) diff --git a/posthog/utils.py b/posthog/utils.py index 5fc94a7722dbf..7535df0700638 100644 --- a/posthog/utils.py +++ b/posthog/utils.py @@ -175,8 +175,14 @@ def relative_date_parse_with_delta_mapping( *, always_truncate: bool = False, now: Optional[datetime.datetime] = None, + increase: bool = False, ) -> tuple[datetime.datetime, Optional[dict[str, int]], str | None]: - """Returns the parsed datetime, along with the period mapping - if the input was a relative datetime string.""" + """ + Returns the parsed datetime, along with the period mapping - if the input was a relative datetime string. + + `increase` controls whether to add the relative delta to the current time or subtract it. + The direction should later be controlled with a +/- prefix in front of the input instead. + """ try: try: # This supports a few formats, but we primarily care about: @@ -245,9 +251,13 @@ def relative_date_parse_with_delta_mapping( delta_mapping["month"] = 1 delta_mapping["day"] = 1 elif match.group("position") == "End": - delta_mapping["month"] = 12 delta_mapping["day"] = 31 - parsed_dt -= relativedelta(**delta_mapping) # type: ignore + + if increase: + parsed_dt += relativedelta(**delta_mapping) # type: ignore + else: + parsed_dt -= relativedelta(**delta_mapping) # type: ignore + if always_truncate: # Truncate to the start of the hour for hour-precision datetimes, to the start of the day for larger intervals # TODO: Remove this from this function, this should not be the responsibility of it @@ -264,8 +274,11 @@ def relative_date_parse( *, always_truncate: bool = False, now: Optional[datetime.datetime] = None, + increase: bool = False, ) -> datetime.datetime: - return relative_date_parse_with_delta_mapping(input, timezone_info, always_truncate=always_truncate, now=now)[0] + return relative_date_parse_with_delta_mapping( + input, timezone_info, always_truncate=always_truncate, now=now, increase=increase + )[0] def get_js_url(request: HttpRequest) -> str: diff --git a/requirements.in b/requirements.in index 4fef89511c686..45151b4d5d38c 100644 --- a/requirements.in +++ b/requirements.in @@ -39,7 +39,7 @@ drf-exceptions-hog==0.4.0 drf-extensions==0.7.0 drf-spectacular==0.27.2 geoip2==4.6.0 -google-cloud-bigquery==3.11.4 +google-cloud-bigquery==3.26 gunicorn==20.1.0 infi-clickhouse-orm@ git+https://github.com/PostHog/infi.clickhouse_orm@9578c79f29635ee2c1d01b7979e89adab8383de2 kafka-python==2.0.2 diff --git a/requirements.txt b/requirements.txt index d7ed441fe2c61..b9fdf3b435d36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -244,16 +244,17 @@ google-api-core==2.11.1 google-auth==2.22.0 # via # google-api-core + # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core # sqlalchemy-bigquery -google-cloud-bigquery==3.11.4 +google-cloud-bigquery==3.26.0 # via # -r requirements.in # sqlalchemy-bigquery google-cloud-bigquery-storage==2.26.0 # via sqlalchemy-bigquery -google-cloud-core==2.3.3 +google-cloud-core==2.4.1 # via google-cloud-bigquery
google-crc32c==1.5.0 # via google-resumable-media @@ -263,10 +264,11 @@ googleapis-common-protos==1.60.0 # via # google-api-core # grpcio-status +greenlet==3.1.1 + # via sqlalchemy grpcio==1.57.0 # via # google-api-core - # google-cloud-bigquery # grpcio-status # sqlalchemy-bigquery grpcio-status==1.57.0 @@ -445,13 +447,10 @@ prometheus-client==0.14.1 prompt-toolkit==3.0.39 # via click-repl proto-plus==1.22.3 - # via - # google-cloud-bigquery - # google-cloud-bigquery-storage + # via google-cloud-bigquery-storage protobuf==4.22.1 # via # google-api-core - # google-cloud-bigquery # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status