- {appearance.thankYouMessageDescription || ''}
-
- powered by {posthogLogoSVG} PostHog
+
+
+
+ setShowThankYou(false)}>
+ {cancel}
+
+
{appearance?.thankYouMessageHeader || 'Thank you!'}
+
{appearance?.thankYouMessageDescription || ''}
+
setShowThankYou(false)}>
+ Close
+
+ {!appearance.whiteLabel && (
+
+ Survey by {posthogLogoSVG}
+
+ )}
)
diff --git a/frontend/src/scenes/surveys/SurveyAppearanceUtils.tsx b/frontend/src/scenes/surveys/SurveyAppearanceUtils.tsx
index c0876cb33bf2b..57ecfd93fde3b 100644
--- a/frontend/src/scenes/surveys/SurveyAppearanceUtils.tsx
+++ b/frontend/src/scenes/surveys/SurveyAppearanceUtils.tsx
@@ -1,27 +1,3 @@
-export const posthogLogoSVG = (
-
-
-
-
-
-
-
-)
export const satisfiedEmoji = (
@@ -47,3 +23,84 @@ export const verySatisfiedEmoji = (
)
+export const cancel = (
+
+
+
+)
+export const check = (
+
+
+
+)
+export const posthogLogoSVG = (
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+)
+export function getTextColor(el: never): string {
+ const backgroundColor = window.getComputedStyle(el).backgroundColor
+ if (backgroundColor === 'rgba(0, 0, 0, 0)') {
+ return 'black'
+ }
+ const colorMatch = backgroundColor.match(/^rgba?\((\d+),\s*(\d+),\s*(\d+)(?:,\s*(\d+(?:\.\d+)?))?\)$/)
+ if (!colorMatch) {
+ return 'black'
+ }
+ const r = parseInt(colorMatch[1]),
+ g = parseInt(colorMatch[2]),
+ b = parseInt(colorMatch[3]),
+ hsp = Math.sqrt(0.299 * (r * r) + 0.587 * (g * g) + 0.114 * (b * b))
+ return hsp > 127.5 ? 'black' : 'white'
+}
diff --git a/frontend/src/scenes/surveys/SurveyView.tsx b/frontend/src/scenes/surveys/SurveyView.tsx
index 20aae5ae7af40..7f4fb0e241eab 100644
--- a/frontend/src/scenes/surveys/SurveyView.tsx
+++ b/frontend/src/scenes/surveys/SurveyView.tsx
@@ -10,7 +10,7 @@ import { capitalizeFirstLetter } from 'lib/utils'
import { useState, useEffect } from 'react'
import { pluginsLogic } from 'scenes/plugins/pluginsLogic'
import { Query } from '~/queries/Query/Query'
-import { defaultSurveyAppearance, surveyEventName, surveyLogic } from './surveyLogic'
+import { surveyLogic } from './surveyLogic'
import { surveysLogic } from './surveysLogic'
import { PageHeader } from 'lib/components/PageHeader'
import { SurveyReleaseSummary } from './Survey'
@@ -21,6 +21,7 @@ import { LemonBanner } from 'lib/lemon-ui/LemonBanner'
import { IconOpenInNew } from 'lib/lemon-ui/icons'
import { NodeKind } from '~/queries/schema'
import { dayjs } from 'lib/dayjs'
+import { defaultSurveyAppearance, SURVEY_EVENT_NAME } from './constants'
import { FEATURE_FLAGS } from 'lib/constants'
import { featureFlagLogic } from 'lib/logic/featureFlagLogic'
@@ -326,7 +327,7 @@ function SurveyNPSResults({ survey }: { survey: Survey }): JSX.Element {
},
series: [
{
- event: surveyEventName,
+ event: SURVEY_EVENT_NAME,
kind: NodeKind.EventsNode,
custom_name: 'Promoters',
properties: [
@@ -339,7 +340,7 @@ function SurveyNPSResults({ survey }: { survey: Survey }): JSX.Element {
],
},
{
- event: surveyEventName,
+ event: SURVEY_EVENT_NAME,
kind: NodeKind.EventsNode,
custom_name: 'Passives',
properties: [
@@ -352,7 +353,7 @@ function SurveyNPSResults({ survey }: { survey: Survey }): JSX.Element {
],
},
{
- event: surveyEventName,
+ event: SURVEY_EVENT_NAME,
kind: NodeKind.EventsNode,
custom_name: 'Detractors',
properties: [
diff --git a/frontend/src/scenes/surveys/Surveys.stories.tsx b/frontend/src/scenes/surveys/Surveys.stories.tsx
index 61a9edb6bb187..7bd9755f5b841 100644
--- a/frontend/src/scenes/surveys/Surveys.stories.tsx
+++ b/frontend/src/scenes/surveys/Surveys.stories.tsx
@@ -27,7 +27,7 @@ const MOCK_BASIC_SURVEY: Survey = {
linked_flag_id: null,
targeting_flag: null,
targeting_flag_filters: undefined,
- appearance: { backgroundColor: 'white', textColor: 'black', submitButtonColor: '#2C2C2C' },
+ appearance: { backgroundColor: 'white', submitButtonColor: '#2C2C2C' },
start_date: null,
end_date: null,
archived: false,
@@ -47,7 +47,7 @@ const MOCK_SURVEY_WITH_RELEASE_CONS: Survey = {
email: 'test2@posthog.com',
},
questions: [{ question: 'question 2?', type: SurveyQuestionType.Open }],
- appearance: { backgroundColor: 'white', textColor: 'black', submitButtonColor: '#2C2C2C' },
+ appearance: { backgroundColor: 'white', submitButtonColor: '#2C2C2C' },
conditions: { url: 'posthog', selector: '' },
linked_flag: {
id: 7,
diff --git a/frontend/src/scenes/surveys/Surveys.tsx b/frontend/src/scenes/surveys/Surveys.tsx
index de69d4c030811..e43e6c90e3cba 100644
--- a/frontend/src/scenes/surveys/Surveys.tsx
+++ b/frontend/src/scenes/surveys/Surveys.tsx
@@ -1,4 +1,14 @@
-import { LemonButton, LemonTable, LemonDivider, Link, LemonTag, LemonTagType, Spinner } from '@posthog/lemon-ui'
+import {
+ LemonButton,
+ LemonDivider,
+ LemonInput,
+ LemonSelect,
+ LemonTable,
+ Link,
+ LemonTag,
+ LemonTagType,
+ Spinner,
+} from '@posthog/lemon-ui'
import { PageHeader } from 'lib/components/PageHeader'
import { More } from 'lib/lemon-ui/LemonButton/More'
import stringWithWBR from 'lib/utils/stringWithWBR'
@@ -14,13 +24,13 @@ import { LemonTabs } from 'lib/lemon-ui/LemonTabs'
import { useState } from 'react'
import { ProductIntroduction } from 'lib/components/ProductIntroduction/ProductIntroduction'
import { userLogic } from 'scenes/userLogic'
-import { LemonSkeleton } from 'lib/lemon-ui/LemonSkeleton'
import { dayjs } from 'lib/dayjs'
import { VersionCheckerBanner } from 'lib/components/VersionChecker/VersionCheckerBanner'
import { teamLogic } from 'scenes/teamLogic'
import { LemonBanner } from 'lib/lemon-ui/LemonBanner'
import { IconSettings } from 'lib/lemon-ui/icons'
import { openSurveysSettingsDialog } from './SurveySettings'
+import { SurveyQuestionLabel } from './constants'
import { FEATURE_FLAGS } from 'lib/constants'
import { featureFlagLogic } from 'lib/logic/featureFlagLogic'
@@ -30,34 +40,37 @@ export const scene: SceneExport = {
}
export enum SurveysTabs {
- All = 'all',
+ Active = 'active',
Yours = 'yours',
Archived = 'archived',
}
export function Surveys(): JSX.Element {
const {
- nonArchivedSurveys,
- archivedSurveys,
surveys,
+ searchedSurveys,
surveysLoading,
surveysResponsesCount,
surveysResponsesCountLoading,
usingSurveysSiteApp,
+ searchTerm,
+ filters,
+ uniqueCreators,
} = useValues(surveysLogic)
- const { deleteSurvey, updateSurvey } = useActions(surveysLogic)
+ const { deleteSurvey, updateSurvey, setSearchTerm, setSurveysFilters } = useActions(surveysLogic)
+
const { user } = useValues(userLogic)
const { featureFlags } = useValues(featureFlagLogic)
const { currentTeam } = useValues(teamLogic)
const surveysPopupDisabled = currentTeam && !currentTeam?.surveys_opt_in
- const [tab, setSurveyTab] = useState(SurveysTabs.All)
+ const [tab, setSurveyTab] = useState(SurveysTabs.Active)
const shouldShowEmptyState = !surveysLoading && surveys.length === 0
return (
-
+
@@ -84,10 +97,13 @@ export function Surveys(): JSX.Element {
/>
setSurveyTab(newTab)}
+ onChange={(newTab) => {
+ setSurveyTab(newTab)
+ setSurveysFilters({ ...filters, archived: newTab === SurveysTabs.Archived })
+ }}
tabs={[
- { key: SurveysTabs.All, label: 'All surveys' },
- { key: SurveysTabs.Archived, label: 'Archived surveys' },
+ { key: SurveysTabs.Active, label: 'Active' },
+ { key: SurveysTabs.Archived, label: 'Archived' },
]}
/>
{featureFlags[FEATURE_FLAGS.SURVEYS_SITE_APP_DEPRECATION] && (
@@ -111,25 +127,72 @@ export function Surveys(): JSX.Element {
) : null}
)}
- {surveysLoading ? (
-
- ) : (
- <>
- {(shouldShowEmptyState || !user?.has_seen_product_intro_for?.[ProductKey.SURVEYS]) && (
-
router.actions.push(urls.survey('new'))}
- isEmpty={surveys.length === 0}
- productKey={ProductKey.SURVEYS}
- />
- )}
- {!shouldShowEmptyState && (
+
+ <>
+ {(shouldShowEmptyState || !user?.has_seen_product_intro_for?.[ProductKey.SURVEYS]) && (
+ router.actions.push(urls.survey('new'))}
+ isEmpty={surveys.length === 0}
+ productKey={ProductKey.SURVEYS}
+ />
+ )}
+ {!shouldShowEmptyState && (
+ <>
+
+
+
+
+
+ Status
+
+ {
+ setSurveysFilters({ status })
+ }}
+ options={[
+ { label: 'Any', value: 'any' },
+ { label: 'Draft', value: 'draft' },
+ { label: 'Running', value: 'running' },
+ { label: 'Complete', value: 'complete' },
+ ]}
+ value={filters.status}
+ />
+
+ Created by
+
+ {
+ setSurveysFilters({ created_by: user })
+ }}
+ options={uniqueCreators}
+ value={filters.created_by}
+ />
+
+
+
() as LemonTableColumn,
createdAtColumn() as LemonTableColumn,
@@ -280,18 +351,10 @@ export function Surveys(): JSX.Element {
},
},
]}
- dataSource={tab === SurveysTabs.Archived ? archivedSurveys : nonArchivedSurveys}
- defaultSorting={{
- columnKey: 'created_at',
- order: -1,
- }}
- nouns={['survey', 'surveys']}
- data-attr="surveys-table"
- emptyState="No surveys. Create a new survey?"
/>
- )}
- >
- )}
+ >
+ )}
+ >
)
}
diff --git a/frontend/src/scenes/surveys/constants.ts b/frontend/src/scenes/surveys/constants.ts
new file mode 100644
index 0000000000000..fa14c9310288c
--- /dev/null
+++ b/frontend/src/scenes/surveys/constants.ts
@@ -0,0 +1,137 @@
+import { FeatureFlagFilters, Survey, SurveyQuestionType, SurveyType } from '~/types'
+
+export const SURVEY_EVENT_NAME = 'survey sent'
+export const SURVEY_RESPONSE_PROPERTY = '$survey_response'
+
+export const SurveyQuestionLabel = {
+ [SurveyQuestionType.Open]: 'Open text',
+ [SurveyQuestionType.Rating]: 'Rating',
+ [SurveyQuestionType.Link]: 'Link',
+ [SurveyQuestionType.SingleChoice]: 'Single choice select',
+ [SurveyQuestionType.MultipleChoice]: 'Multiple choice select',
+}
+
+export const defaultSurveyAppearance = {
+ backgroundColor: 'white',
+ textColor: 'black',
+ submitButtonText: 'Submit',
+ submitButtonColor: '#2c2c2c',
+ ratingButtonColor: '#e0e2e8',
+ descriptionTextColor: '#4b4b52',
+ whiteLabel: false,
+ displayThankYouMessage: true,
+ placeholder: '',
+ position: 'right',
+ thankYouMessageHeader: 'Thank you for your feedback!',
+}
+
+export const defaultSurveyFieldValues = {
+ [SurveyQuestionType.Open]: {
+ questions: [
+ {
+ question: 'Give us feedback on our product!',
+ description: '',
+ },
+ ],
+ appearance: {
+ submitButtonText: 'Submit',
+ thankYouMessageHeader: 'Thank you for your feedback!',
+ },
+ },
+ [SurveyQuestionType.Link]: {
+ questions: [
+ {
+ question: 'Do you want to join our upcoming webinar?',
+ description: '',
+ },
+ ],
+ appearance: {
+ submitButtonText: 'Register',
+ thankYouMessageHeader: 'Redirecting ...',
+ },
+ },
+ [SurveyQuestionType.Rating]: {
+ questions: [
+ {
+ question: 'How likely are you to recommend us to a friend?',
+ description: '',
+ display: 'number',
+ scale: 10,
+ lowerBoundLabel: 'Unlikely',
+ upperBoundLabel: 'Very likely',
+ },
+ ],
+ appearance: {
+ thankYouMessageHeader: 'Thank you for your feedback!',
+ },
+ },
+ [SurveyQuestionType.SingleChoice]: {
+ questions: [
+ {
+ question: 'Have you found this tutorial useful?',
+ description: '',
+ choices: ['Yes', 'No'],
+ },
+ ],
+ appearance: {
+ submitButtonText: 'Submit',
+ thankYouMessageHeader: 'Thank you for your feedback!',
+ },
+ },
+ [SurveyQuestionType.MultipleChoice]: {
+ questions: [
+ {
+ question: 'Which types of content would you like to see more of?',
+ description: '',
+ choices: ['Tutorials', 'Customer case studies', 'Product announcements'],
+ },
+ ],
+ appearance: {
+ submitButtonText: 'Submit',
+ thankYouMessageHeader: 'Thank you for your feedback!',
+ },
+ },
+}
+
+export interface NewSurvey
+ extends Pick<
+ Survey,
+ | 'name'
+ | 'description'
+ | 'type'
+ | 'conditions'
+ | 'questions'
+ | 'start_date'
+ | 'end_date'
+ | 'linked_flag'
+ | 'targeting_flag'
+ | 'archived'
+ | 'appearance'
+ > {
+ id: 'new'
+ linked_flag_id: number | undefined
+ targeting_flag_filters: Pick
| undefined
+}
+
+export const NEW_SURVEY: NewSurvey = {
+ id: 'new',
+ name: '',
+ description: '',
+ questions: [
+ {
+ type: SurveyQuestionType.Open,
+ question: defaultSurveyFieldValues[SurveyQuestionType.Open].questions[0].question,
+ description: defaultSurveyFieldValues[SurveyQuestionType.Open].questions[0].description,
+ },
+ ],
+ type: SurveyType.Popover,
+ linked_flag_id: undefined,
+ targeting_flag_filters: undefined,
+ linked_flag: null,
+ targeting_flag: null,
+ start_date: null,
+ end_date: null,
+ conditions: null,
+ archived: false,
+ appearance: defaultSurveyAppearance,
+}
diff --git a/frontend/src/scenes/surveys/surveyLogic.tsx b/frontend/src/scenes/surveys/surveyLogic.tsx
index 2d03f7a707a8f..65e93c8c31956 100644
--- a/frontend/src/scenes/surveys/surveyLogic.tsx
+++ b/frontend/src/scenes/surveys/surveyLogic.tsx
@@ -8,7 +8,6 @@ import { urls } from 'scenes/urls'
import {
Breadcrumb,
ChartDisplayType,
- FeatureFlagFilters,
PluginType,
PropertyFilterType,
PropertyOperator,
@@ -26,133 +25,13 @@ import { eventUsageLogic } from 'lib/utils/eventUsageLogic'
import { featureFlagLogic } from 'scenes/feature-flags/featureFlagLogic'
import { featureFlagLogic as enabledFlagLogic } from 'lib/logic/featureFlagLogic'
import { FEATURE_FLAGS } from 'lib/constants'
-
-export interface NewSurvey
- extends Pick<
- Survey,
- | 'name'
- | 'description'
- | 'type'
- | 'conditions'
- | 'questions'
- | 'start_date'
- | 'end_date'
- | 'linked_flag'
- | 'targeting_flag'
- | 'archived'
- | 'appearance'
- > {
- id: 'new'
- linked_flag_id: number | undefined
- targeting_flag_filters: Pick | undefined
-}
-
-export const defaultSurveyAppearance = {
- backgroundColor: 'white',
- textColor: 'black',
- submitButtonText: 'Submit',
- submitButtonColor: '#2c2c2c',
- ratingButtonColor: '#e0e2e8',
- descriptionTextColor: '#4b4b52',
- whiteLabel: false,
- displayThankYouMessage: true,
- thankYouMessageHeader: 'Thank you for your feedback!',
-}
-
-export const defaultSurveyFieldValues = {
- [SurveyQuestionType.Open]: {
- questions: [
- {
- question: 'Give us feedback on our product!',
- description: '',
- },
- ],
- appearance: {
- submitButtonText: 'Submit',
- thankYouMessageHeader: 'Thank you for your feedback!',
- },
- },
- [SurveyQuestionType.Link]: {
- questions: [
- {
- question: 'Do you want to join our upcoming webinar?',
- description: '',
- },
- ],
- appearance: {
- submitButtonText: 'Register',
- thankYouMessageHeader: 'Redirecting ...',
- },
- },
- [SurveyQuestionType.Rating]: {
- questions: [
- {
- question: 'How likely are you to recommend us to a friend?',
- description: '',
- display: 'number',
- scale: 10,
- lowerBoundLabel: 'Unlikely',
- upperBoundLabel: 'Very likely',
- },
- ],
- appearance: {
- thankYouMessageHeader: 'Thank you for your feedback!',
- },
- },
- [SurveyQuestionType.SingleChoice]: {
- questions: [
- {
- question: 'Have you found this tutorial useful?',
- description: '',
- choices: ['Yes', 'No'],
- },
- ],
- appearance: {
- submitButtonText: 'Submit',
- thankYouMessageHeader: 'Thank you for your feedback!',
- },
- },
- [SurveyQuestionType.MultipleChoice]: {
- questions: [
- {
- question: 'Which types of content would you like to see more of?',
- description: '',
- choices: ['Tutorials', 'Customer case studies', 'Product announcements'],
- },
- ],
- appearance: {
- submitButtonText: 'Submit',
- thankYouMessageHeader: 'Thank you for your feedback!',
- },
- },
-}
-
-export const NEW_SURVEY: NewSurvey = {
- id: 'new',
- name: '',
- description: '',
- questions: [
- {
- type: SurveyQuestionType.Open,
- question: defaultSurveyFieldValues[SurveyQuestionType.Open].questions[0].question,
- description: defaultSurveyFieldValues[SurveyQuestionType.Open].questions[0].description,
- },
- ],
- type: SurveyType.Popover,
- linked_flag_id: undefined,
- targeting_flag_filters: undefined,
- linked_flag: null,
- targeting_flag: null,
- start_date: null,
- end_date: null,
- conditions: null,
- archived: false,
- appearance: defaultSurveyAppearance,
-}
-
-export const surveyEventName = 'survey sent'
-
-const SURVEY_RESPONSE_PROPERTY = '$survey_response'
+import {
+ defaultSurveyFieldValues,
+ SURVEY_EVENT_NAME,
+ SURVEY_RESPONSE_PROPERTY,
+ NEW_SURVEY,
+ NewSurvey,
+} from './constants'
export interface SurveyLogicProps {
id: string | 'new'
@@ -360,7 +239,7 @@ export const surveyLogic = kea([
kind: NodeKind.EventsQuery,
select: ['*', `properties.${SURVEY_RESPONSE_PROPERTY}`, 'timestamp', 'person'],
orderBy: ['timestamp DESC'],
- where: [`event == 'survey sent' or event == '${survey.name} survey sent'`],
+ where: [`event == 'survey sent'`],
after: createdAt,
properties: [
{
@@ -441,7 +320,7 @@ export const surveyLogic = kea([
value: survey.id,
},
],
- series: [{ event: surveyEventName, kind: NodeKind.EventsNode }],
+ series: [{ event: SURVEY_EVENT_NAME, kind: NodeKind.EventsNode }],
trendsFilter: { display: ChartDisplayType.ActionsBarValue },
breakdown: { breakdown: '$survey_response', breakdown_type: 'event' },
},
diff --git a/frontend/src/scenes/surveys/surveysLogic.tsx b/frontend/src/scenes/surveys/surveysLogic.tsx
index e58379813e1f8..fa25749343280 100644
--- a/frontend/src/scenes/surveys/surveysLogic.tsx
+++ b/frontend/src/scenes/surveys/surveysLogic.tsx
@@ -1,6 +1,7 @@
-import { afterMount, connect, kea, listeners, path, selectors } from 'kea'
+import { afterMount, connect, kea, listeners, path, selectors, actions, reducers } from 'kea'
import { loaders } from 'kea-loaders'
import api from 'lib/api'
+import Fuse from 'fuse.js'
import { AvailableFeature, Breadcrumb, ProgressStatus, Survey } from '~/types'
import { urls } from 'scenes/urls'
@@ -9,6 +10,7 @@ import { lemonToast } from '@posthog/lemon-ui'
import { userLogic } from 'scenes/userLogic'
import { router } from 'kea-router'
import { pluginsLogic } from 'scenes/plugins/pluginsLogic'
+import { LemonSelectOption } from 'lib/lemon-ui/LemonSelect'
export function getSurveyStatus(survey: Survey): ProgressStatus {
if (!survey.start_date) {
@@ -19,17 +21,25 @@ export function getSurveyStatus(survey: Survey): ProgressStatus {
return ProgressStatus.Complete
}
+export interface SurveysFilters {
+ status: string
+ created_by: string
+ archived: boolean
+}
+
+interface SurveysCreators {
+ [id: string]: string
+}
+
export const surveysLogic = kea([
path(['scenes', 'surveys', 'surveysLogic']),
connect(() => ({
- values: [
- pluginsLogic,
- ['installedPlugins', 'loading as pluginsLoading', 'enabledPlugins'],
- // ['enabledPlugins'],
- userLogic,
- ['user'],
- ],
+ values: [pluginsLogic, ['loading as pluginsLoading', 'enabledPlugins'], userLogic, ['user']],
})),
+ actions({
+ setSearchTerm: (searchTerm: string) => ({ searchTerm }),
+ setSurveysFilters: (filters: Partial, replace?: boolean) => ({ filters, replace }),
+ }),
loaders(({ values }) => ({
surveys: {
__default: [] as Survey[],
@@ -54,7 +64,24 @@ export const surveysLogic = kea([
},
},
})),
- listeners(() => ({
+ reducers({
+ searchTerm: {
+ setSearchTerm: (_, { searchTerm }) => searchTerm,
+ },
+ filters: [
+ {
+ archived: false,
+ status: 'any',
+ created_by: 'any',
+ } as Partial,
+ {
+ setSurveysFilters: (state, { filters }) => {
+ return { ...state, ...filters }
+ },
+ },
+ ],
+ }),
+ listeners(({ actions }) => ({
deleteSurveySuccess: () => {
lemonToast.success('Survey deleted')
router.actions.push(urls.surveys())
@@ -62,8 +89,49 @@ export const surveysLogic = kea([
updateSurveySuccess: () => {
lemonToast.success('Survey updated')
},
+ setSurveysFilters: () => {
+ actions.loadSurveys()
+ actions.loadResponsesCount()
+ },
})),
selectors({
+ searchedSurveys: [
+ (selectors) => [selectors.surveys, selectors.searchTerm, selectors.filters],
+ (surveys, searchTerm, filters) => {
+ let searchedSurveys = surveys
+
+ if (!searchTerm && Object.keys(filters).length === 0) {
+ return searchedSurveys
+ }
+
+ if (searchTerm) {
+ searchedSurveys = new Fuse(searchedSurveys, {
+ keys: ['key', 'name'],
+ threshold: 0.3,
+ })
+ .search(searchTerm)
+ .map((result) => result.item)
+ }
+
+ const { status, created_by, archived } = filters
+ if (status !== 'any') {
+ searchedSurveys = searchedSurveys.filter((survey) => getSurveyStatus(survey) === status)
+ }
+ if (created_by !== 'any') {
+ searchedSurveys = searchedSurveys.filter(
+ (survey) => survey.created_by?.id === (created_by ? parseInt(created_by) : '')
+ )
+ }
+
+ if (archived) {
+ searchedSurveys = searchedSurveys.filter((survey) => survey.archived)
+ } else {
+ searchedSurveys = searchedSurveys.filter((survey) => !survey.archived)
+ }
+
+ return searchedSurveys
+ },
+ ],
breadcrumbs: [
() => [],
(): Breadcrumb[] => [
@@ -73,13 +141,23 @@ export const surveysLogic = kea([
},
],
],
- nonArchivedSurveys: [
- (s) => [s.surveys],
- (surveys: Survey[]): Survey[] => surveys.filter((survey) => !survey.archived),
- ],
- archivedSurveys: [
- (s) => [s.surveys],
- (surveys: Survey[]): Survey[] => surveys.filter((survey) => survey.archived),
+ uniqueCreators: [
+ (selectors) => [selectors.surveys],
+ (surveys) => {
+ const creators: SurveysCreators = {}
+ for (const survey of surveys) {
+ if (survey.created_by) {
+ if (!creators[survey.created_by.id]) {
+ creators[survey.created_by.id] = survey.created_by.first_name
+ }
+ }
+ }
+ const response: LemonSelectOption[] = [
+ { label: 'Any user', value: 'any' },
+ ...Object.entries(creators).map(([id, first_name]) => ({ label: first_name, value: id })),
+ ]
+ return response
+ },
],
whitelabelAvailable: [
(s) => [s.user],
@@ -92,8 +170,8 @@ export const surveysLogic = kea([
},
],
}),
- afterMount(async ({ actions }) => {
- await actions.loadSurveys()
- await actions.loadResponsesCount()
+ afterMount(({ actions }) => {
+ actions.loadSurveys()
+ actions.loadResponsesCount()
}),
])
diff --git a/frontend/src/scenes/teamLogic.tsx b/frontend/src/scenes/teamLogic.tsx
index e59d8b12fa224..a30bab40298f0 100644
--- a/frontend/src/scenes/teamLogic.tsx
+++ b/frontend/src/scenes/teamLogic.tsx
@@ -86,7 +86,10 @@ export const teamLogic = kea([
payload.slack_incoming_webhook
)}`
: 'Webhook integration disabled'
- } else if (updatedAttribute === 'completed_snippet_onboarding') {
+ } else if (
+ updatedAttribute === 'completed_snippet_onboarding' ||
+ updatedAttribute === 'has_completed_onboarding_for'
+ ) {
message = "Congrats! You're now ready to use PostHog."
} else {
message = `${parseUpdatedAttributeName(updatedAttribute)} updated successfully!`
diff --git a/frontend/src/scenes/web-analytics/WebAnalyticsScene.tsx b/frontend/src/scenes/web-analytics/WebAnalyticsScene.tsx
index 61a5b16ae49f7..4fb41fe7261cc 100644
--- a/frontend/src/scenes/web-analytics/WebAnalyticsScene.tsx
+++ b/frontend/src/scenes/web-analytics/WebAnalyticsScene.tsx
@@ -6,38 +6,38 @@ import { NodeKind } from '~/queries/schema'
export function WebAnalyticsScene(): JSX.Element {
return (
- Top pages
+ Top sources
- Top sources
+ Top clicks
- Top clicks
+ Top pages
= now() - INTERVAL 7 DAY
-AND events.properties.$event_type = 'click'
-AND el_text IS NOT NULL
-GROUP BY
- el_text
-ORDER BY total_clicks DESC
- `
-
-const TOP_PAGES_SQL = `
-WITH
-
-scroll_depth_cte AS (
-SELECT
- events.properties.\`$prev_pageview_pathname\` AS pathname,
- countIf(events.event == '$pageview') as total_pageviews,
- COUNT(DISTINCT events.properties.distinct_id) as unique_visitors, -- might want to use person id? have seen a small number of pages where unique > total
- avg(CASE
- WHEN events.properties.\`$prev_pageview_max_content_percentage\` IS NULL THEN NULL
- WHEN events.properties.\`$prev_pageview_max_content_percentage\` > 0.8 THEN 100
- ELSE 0
- END) AS scroll_gt80_percentage,
- avg(events.properties.$prev_pageview_max_scroll_percentage) * 100 as average_scroll_percentage
-FROM
- events
-WHERE
- (event = '$pageview' OR event = '$pageleave') AND events.properties.\`$prev_pageview_pathname\` IS NOT NULL
- AND events.timestamp >= now() - INTERVAL 7 DAY
-GROUP BY pathname
-)
-
-,
-
-session_cte AS (
-SELECT
- events.properties.\`$session_id\` AS session_id,
- min(events.timestamp) AS min_timestamp,
- max(events.timestamp) AS max_timestamp,
- dateDiff('second', min_timestamp, max_timestamp) AS duration_s,
-
- -- create a tuple so that these are grouped in the same order, see https://github.com/ClickHouse/ClickHouse/discussions/42338
- groupArray((events.timestamp, events.properties.\`$referrer\`, events.properties.\`$pathname\`, events.properties.utm_source)) AS tuple_array,
- arrayFirstIndex(x -> tupleElement(x, 1) == min_timestamp, tuple_array) as index_of_earliest,
- arrayFirstIndex(x -> tupleElement(x, 1) == max_timestamp, tuple_array) as index_of_latest,
- tupleElement(arrayElement(
- tuple_array,
- index_of_earliest
- ), 2) AS earliest_referrer,
- tupleElement(arrayElement(
- tuple_array,
- index_of_earliest
- ), 3) AS earliest_pathname,
- tupleElement(arrayElement(
- tuple_array,
- index_of_earliest
- ), 4) AS earliest_utm_source,
-
- if(domain(earliest_referrer) = '', earliest_referrer, domain(earliest_referrer)) AS referrer_domain,
- multiIf(
- earliest_utm_source IS NOT NULL, earliest_utm_source,
- -- This will need to be an approach that scales better
- referrer_domain == 'app.posthog.com', 'posthog',
- referrer_domain == 'eu.posthog.com', 'posthog',
- referrer_domain == 'posthog.com', 'posthog',
- referrer_domain == 'www.google.com', 'google',
- referrer_domain == 'www.google.co.uk', 'google',
- referrer_domain == 'www.google.com.hk', 'google',
- referrer_domain == 'www.google.de', 'google',
- referrer_domain == 't.co', 'twitter',
- referrer_domain == 'github.com', 'github',
- referrer_domain == 'duckduckgo.com', 'duckduckgo',
- referrer_domain == 'www.bing.com', 'bing',
- referrer_domain == 'bing.com', 'bing',
- referrer_domain == 'yandex.ru', 'yandex',
- referrer_domain == 'quora.com', 'quora',
- referrer_domain == 'www.quora.com', 'quora',
- referrer_domain == 'linkedin.com', 'linkedin',
- referrer_domain == 'www.linkedin.com', 'linkedin',
- startsWith(referrer_domain, 'http://localhost:'), 'localhost',
- referrer_domain
- ) AS blended_source,
-
- countIf(events.event == '$pageview') AS num_pageviews,
- countIf(events.event == '$autocapture') AS num_autocaptures,
- -- in v1 we'd also want to count whether there were any conversion events
-
- any(events.person_id) as person_id,
- -- definition of a GA4 bounce from here https://support.google.com/analytics/answer/12195621?hl=en
- (num_autocaptures == 0 AND num_pageviews <= 1 AND duration_s < 10) AS is_bounce
-FROM
- events
-WHERE
- session_id IS NOT NULL
-AND
- events.timestamp >= now() - INTERVAL 8 DAY
-GROUP BY
- events.properties.\`$session_id\`
-HAVING
- min_timestamp >= now() - INTERVAL 7 DAY
-)
-
-,
-
-bounce_rate_cte AS (
-SELECT session_cte.earliest_pathname,
- avg(session_cte.is_bounce) as bounce_rate
-FROM session_cte
-GROUP BY earliest_pathname
-)
-
-
-
-SELECT scroll_depth_cte.pathname as pathname,
-scroll_depth_cte.total_pageviews as total_pageviews,
-scroll_depth_cte.unique_visitors as unique_visitors,
-scroll_depth_cte.scroll_gt80_percentage as scroll_gt80_percentage,
-scroll_depth_cte.average_scroll_percentage as average_scroll_percentage,
-bounce_rate_cte.bounce_rate as bounce_rate
-FROM
- scroll_depth_cte LEFT OUTER JOIN bounce_rate_cte
-ON scroll_depth_cte.pathname = bounce_rate_cte.earliest_pathname
-ORDER BY total_pageviews DESC
-`
-
-const TOP_SOURCES = `
-WITH
-
-session_cte AS (
-SELECT
- events.properties.\`$session_id\` AS session_id,
- min(events.timestamp) AS min_timestamp,
- max(events.timestamp) AS max_timestamp,
- dateDiff('second', min_timestamp, max_timestamp) AS duration_s,
-
- -- create a tuple so that these are grouped in the same order, see https://github.com/ClickHouse/ClickHouse/discussions/42338
- groupArray((events.timestamp, events.properties.\`$referrer\`, events.properties.\`$pathname\`, events.properties.utm_source)) AS tuple_array,
- arrayFirstIndex(x -> tupleElement(x, 1) == min_timestamp, tuple_array) as index_of_earliest,
- arrayFirstIndex(x -> tupleElement(x, 1) == max_timestamp, tuple_array) as index_of_latest,
- tupleElement(arrayElement(
- tuple_array,
- index_of_earliest
- ), 2) AS earliest_referrer,
- tupleElement(arrayElement(
- tuple_array,
- index_of_earliest
- ), 3) AS earliest_pathname,
- tupleElement(arrayElement(
- tuple_array,
- index_of_earliest
- ), 4) AS earliest_utm_source,
-
- if(domain(earliest_referrer) = '', earliest_referrer, domain(earliest_referrer)) AS referrer_domain,
- multiIf(
- earliest_utm_source IS NOT NULL, earliest_utm_source,
- -- This will need to be an approach that scales better
- referrer_domain == 'app.posthog.com', 'posthog',
- referrer_domain == 'eu.posthog.com', 'posthog',
- referrer_domain == 'posthog.com', 'posthog',
- referrer_domain == 'www.google.com', 'google',
- referrer_domain == 'www.google.co.uk', 'google',
- referrer_domain == 'www.google.com.hk', 'google',
- referrer_domain == 'www.google.de', 'google',
- referrer_domain == 't.co', 'twitter',
- referrer_domain == 'github.com', 'github',
- referrer_domain == 'duckduckgo.com', 'duckduckgo',
- referrer_domain == 'www.bing.com', 'bing',
- referrer_domain == 'bing.com', 'bing',
- referrer_domain == 'yandex.ru', 'yandex',
- referrer_domain == 'quora.com', 'quora',
- referrer_domain == 'www.quora.com', 'quora',
- referrer_domain == 'linkedin.com', 'linkedin',
- referrer_domain == 'www.linkedin.com', 'linkedin',
- startsWith(referrer_domain, 'http://localhost:'), 'localhost',
- referrer_domain
- ) AS blended_source,
-
- countIf(events.event == '$pageview') AS num_pageviews,
- countIf(events.event == '$autocapture') AS num_autocaptures,
- -- in v1 we'd also want to count whether there were any conversion events
-
- any(events.person_id) as person_id,
- -- definition of a GA4 bounce from here https://support.google.com/analytics/answer/12195621?hl=en
- (num_autocaptures == 0 AND num_pageviews <= 1 AND duration_s < 10) AS is_bounce
-FROM
- events
-WHERE
- session_id IS NOT NULL
-AND
- events.timestamp >= now() - INTERVAL 8 DAY
-GROUP BY
- events.properties.\`$session_id\`
-HAVING
- min_timestamp >= now() - INTERVAL 7 DAY
-)
-
-
-
-SELECT
- blended_source,
- count(num_pageviews) as total_pageviews,
- count(DISTINCT person_id) as unique_visitors,
- avg(is_bounce) AS bounce_rate
-FROM
- session_cte
-WHERE
- blended_source IS NOT NULL
-GROUP BY blended_source
-
-ORDER BY total_pageviews DESC
-LIMIT 100
-
-
-`
diff --git a/frontend/src/toolbar/actions/EditAction.tsx b/frontend/src/toolbar/actions/EditAction.tsx
index 6898002fc6349..409f439565704 100644
--- a/frontend/src/toolbar/actions/EditAction.tsx
+++ b/frontend/src/toolbar/actions/EditAction.tsx
@@ -64,7 +64,11 @@ export function EditAction(): JSX.Element {
What did your user do?
-
+
diff --git a/frontend/src/toolbar/actions/StepField.tsx b/frontend/src/toolbar/actions/StepField.tsx
index 5bb2d5802900c..b066bae78d239 100644
--- a/frontend/src/toolbar/actions/StepField.tsx
+++ b/frontend/src/toolbar/actions/StepField.tsx
@@ -63,12 +63,14 @@ export function StepField({ step, item, label, caption }: StepFieldProps): JSX.E
className={clsx(!selected && 'opacity-50')}
onChange={onChange}
value={value ?? ''}
+ stopPropagation={true}
/>
) : (
)
}}
diff --git a/frontend/src/types.ts b/frontend/src/types.ts
index 8a038927dc90e..14d971b8624e8 100644
--- a/frontend/src/types.ts
+++ b/frontend/src/types.ts
@@ -2111,15 +2111,16 @@ export enum SurveyType {
export interface SurveyAppearance {
backgroundColor?: string
submitButtonColor?: string
- textColor?: string
submitButtonText?: string
- descriptionTextColor?: string
ratingButtonColor?: string
- ratingButtonHoverColor?: string
+ ratingButtonActiveColor?: string
+ borderColor?: string
+ placeholder?: string
whiteLabel?: boolean
displayThankYouMessage?: boolean
thankYouMessageHeader?: string
thankYouMessageDescription?: string
+ position?: string
}
export interface SurveyQuestionBase {
diff --git a/package.json b/package.json
index c1a3772e89de9..d766ed4da53bb 100644
--- a/package.json
+++ b/package.json
@@ -141,12 +141,11 @@
"react-draggable": "^4.2.0",
"react-grid-layout": "^1.3.0",
"react-intersection-observer": "^9.4.3",
- "react-json-view": "^1.21.3",
+ "@microlink/react-json-view": "^1.21.3",
"react-markdown": "^5.0.3",
"react-modal": "^3.15.1",
"react-resizable": "^3.0.5",
"react-shadow": "^18.4.2",
- "react-sortable-hoc": "^1.11.0",
"react-syntax-highlighter": "^15.5.0",
"react-textarea-autosize": "^8.3.3",
"react-textfit": "^1.1.1",
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 74ec71a92715e..d0c40db53220a 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -1,4 +1,4 @@
-lockfileVersion: '6.1'
+lockfileVersion: '6.0'
settings:
autoInstallPeers: true
@@ -29,6 +29,9 @@ dependencies:
'@medv/finder':
specifier: ^2.1.0
version: 2.1.0
+ '@microlink/react-json-view':
+ specifier: ^1.21.3
+ version: 1.22.2(@types/react@16.14.34)(react-dom@16.14.0)(react@16.14.0)
'@monaco-editor/react':
specifier: 4.4.6
version: 4.4.6(monaco-editor@0.39.0)(react-dom@16.14.0)(react@16.14.0)
@@ -242,9 +245,6 @@ dependencies:
react-intersection-observer:
specifier: ^9.4.3
version: 9.4.3(react@16.14.0)
- react-json-view:
- specifier: ^1.21.3
- version: 1.21.3(@types/react@16.14.34)(react-dom@16.14.0)(react@16.14.0)
react-markdown:
specifier: ^5.0.3
version: 5.0.3(@types/react@16.14.34)(react@16.14.0)
@@ -257,9 +257,6 @@ dependencies:
react-shadow:
specifier: ^18.4.2
version: 18.6.2(prop-types@15.8.1)(react-dom@16.14.0)(react@16.14.0)
- react-sortable-hoc:
- specifier: ^1.11.0
- version: 1.11.0(prop-types@15.8.1)(react-dom@16.14.0)(react@16.14.0)
react-syntax-highlighter:
specifier: ^15.5.0
version: 15.5.0(react@16.14.0)
@@ -3167,6 +3164,23 @@ packages:
resolution: {integrity: sha512-Egrg5XO4kLol24b1Kv50HDfi5hW0yQ6aWSsO0Hea1eJ4rogKElIN0M86FdVnGF4XIGYyA7QWx0MgbOzVPA0qkA==}
dev: false
+ /@microlink/react-json-view@1.22.2(@types/react@16.14.34)(react-dom@16.14.0)(react@16.14.0):
+ resolution: {integrity: sha512-liJzdlbspT5GbEuPffw4jzZfXOypKLK1Er9br03T31bAaIi/WptZqpcJaXPi7OmwC7v/YYczCkmAS7WaEfItPQ==}
+ peerDependencies:
+ react: '>= 15'
+ react-dom: '>= 15'
+ dependencies:
+ flux: 4.0.3(react@16.14.0)
+ react: 16.14.0
+ react-base16-styling: 0.6.0
+ react-dom: 16.14.0(react@16.14.0)
+ react-lifecycles-compat: 3.0.4
+ react-textarea-autosize: 8.3.4(@types/react@16.14.34)(react@16.14.0)
+ transitivePeerDependencies:
+ - '@types/react'
+ - encoding
+ dev: false
+
/@monaco-editor/loader@1.3.3(monaco-editor@0.39.0):
resolution: {integrity: sha512-6KKF4CTzcJiS8BJwtxtfyYt9shBiEv32ateQ9T4UVogwn4HM/uPo9iJd2Dmbkpz8CM6Y0PDUpjnZzCwC+eYo2Q==}
peerDependencies:
@@ -11505,6 +11519,7 @@ packages:
resolution: {integrity: sha512-phJfQVBuaJM5raOpJjSfkiD6BpbCE4Ns//LaXl6wGYtUBY83nWS6Rf9tXm2e8VaK60JEjYldbPif/A2B1C2gNA==}
dependencies:
loose-envify: 1.4.0
+ dev: true
/ip6addr@0.2.5:
resolution: {integrity: sha512-9RGGSB6Zc9Ox5DpDGFnJdIeF0AsqXzdH+FspCfPPaU/L/4tI6P+5lIoFUFm9JXs9IrJv1boqAaNCQmoDADTSKQ==}
@@ -13097,7 +13112,7 @@ packages:
dependencies:
universalify: 2.0.0
optionalDependencies:
- graceful-fs: 4.2.10
+ graceful-fs: 4.2.11
/jsprim@2.0.2:
resolution: {integrity: sha512-gqXddjPqQ6G40VdnI6T6yObEC+pDNvyP95wdQhkWkg7crHH3km5qP1FsOXEkzEQwnz6gz5qGTn1c2Y52wP3OyQ==}
@@ -16082,23 +16097,6 @@ packages:
resolution: {integrity: sha512-xWGDIW6x921xtzPkhiULtthJHoJvBbF3q26fzloPCK0hsvxtPVelvftw3zjbHWSkR2km9Z+4uxbDDK/6Zw9B8w==}
dev: true
- /react-json-view@1.21.3(@types/react@16.14.34)(react-dom@16.14.0)(react@16.14.0):
- resolution: {integrity: sha512-13p8IREj9/x/Ye4WI/JpjhoIwuzEgUAtgJZNBJckfzJt1qyh24BdTm6UQNGnyTq9dapQdrqvquZTo3dz1X6Cjw==}
- peerDependencies:
- react: ^17.0.0 || ^16.3.0 || ^15.5.4
- react-dom: ^17.0.0 || ^16.3.0 || ^15.5.4
- dependencies:
- flux: 4.0.3(react@16.14.0)
- react: 16.14.0
- react-base16-styling: 0.6.0
- react-dom: 16.14.0(react@16.14.0)
- react-lifecycles-compat: 3.0.4
- react-textarea-autosize: 8.3.4(@types/react@16.14.34)(react@16.14.0)
- transitivePeerDependencies:
- - '@types/react'
- - encoding
- dev: false
-
/react-lifecycles-compat@3.0.4:
resolution: {integrity: sha512-fBASbA6LnOU9dOU2eW7aQ8xmYBSXUIWr+UmF9b1efZBazGNO+rcXT/icdKnYm2pTwcRylVUYwW7H1PHfLekVzA==}
dev: false
@@ -16217,20 +16215,6 @@ packages:
react-use: 15.3.8(react-dom@16.14.0)(react@16.14.0)
dev: false
- /react-sortable-hoc@1.11.0(prop-types@15.8.1)(react-dom@16.14.0)(react@16.14.0):
- resolution: {integrity: sha512-v1CDCvdfoR3zLGNp6qsBa4J1BWMEVH25+UKxF/RvQRh+mrB+emqtVHMgZ+WreUiKJoEaiwYoScaueIKhMVBHUg==}
- peerDependencies:
- prop-types: ^15.5.7
- react: ^0.14.0 || ^15.0.0 || ^16.0.0
- react-dom: ^0.14.0 || ^15.0.0 || ^16.0.0
- dependencies:
- '@babel/runtime': 7.22.10
- invariant: 2.2.4
- prop-types: 15.8.1
- react: 16.14.0
- react-dom: 16.14.0(react@16.14.0)
- dev: false
-
/react-style-singleton@2.2.1(@types/react@16.14.34)(react@16.14.0):
resolution: {integrity: sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==}
engines: {node: '>=10'}
diff --git a/posthog/api/decide.py b/posthog/api/decide.py
index 66364ba617ae9..5bd187b819cb7 100644
--- a/posthog/api/decide.py
+++ b/posthog/api/decide.py
@@ -2,6 +2,7 @@
import re
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
+from posthog.api.survey import SURVEY_TARGETING_FLAG_PREFIX
from posthog.database_healthcheck import DATABASE_FOR_FLAG_MATCHING
from posthog.metrics import LABEL_TEAM_ID
from posthog.models.feature_flag.flag_analytics import increment_request_count
@@ -244,11 +245,13 @@ def get_decide(request: HttpRequest):
if feature_flags:
# Billing analytics for decide requests with feature flags
+ # Don't count the request if every flag in it is a survey targeting flag.
+ if not all(flag.startswith(SURVEY_TARGETING_FLAG_PREFIX) for flag in feature_flags.keys()):
- # Sample no. of decide requests with feature flags
- if settings.DECIDE_BILLING_SAMPLING_RATE and random() < settings.DECIDE_BILLING_SAMPLING_RATE:
- count = int(1 / settings.DECIDE_BILLING_SAMPLING_RATE)
- increment_request_count(team.pk, count)
+ # Sample no. of decide requests with feature flags
+ if settings.DECIDE_BILLING_SAMPLING_RATE and random() < settings.DECIDE_BILLING_SAMPLING_RATE:
+ count = int(1 / settings.DECIDE_BILLING_SAMPLING_RATE)
+ increment_request_count(team.pk, count)
else:
# no auth provided
diff --git a/posthog/api/query.py b/posthog/api/query.py
index 1c18b4e74c163..0917e5bf1446b 100644
--- a/posthog/api/query.py
+++ b/posthog/api/query.py
@@ -206,7 +206,15 @@ def process_query(
tag_queries(query=query_json)
- if query_kind == "LifecycleQuery" or query_kind == "TrendsQuery":
+ refreshable_queries = [
+ "LifecycleQuery",
+ "TrendsQuery",
+ "WebTopSourcesQuery",
+ "WebTopClicksQuery",
+ "WebTopPagesQuery",
+ ]
+
+ if query_kind in refreshable_queries:
refresh_requested = refresh_requested_by_client(request) if request else False
query_runner = get_query_runner(query_json, team)
return _unwrap_pydantic_dict(query_runner.run(refresh_requested=refresh_requested))
diff --git a/posthog/api/survey.py b/posthog/api/survey.py
index 081cee0f25a10..485a0bef58c5e 100644
--- a/posthog/api/survey.py
+++ b/posthog/api/survey.py
@@ -26,6 +26,8 @@
from posthog.utils_cors import cors_response
+SURVEY_TARGETING_FLAG_PREFIX = "survey-targeting-"
+
class SurveySerializer(serializers.ModelSerializer):
linked_flag_id = serializers.IntegerField(required=False, allow_null=True, source="linked_flag.id")
@@ -174,7 +176,7 @@ def update(self, instance: Survey, validated_data):
return super().update(instance, validated_data)
def _create_new_targeting_flag(self, name, filters):
- feature_flag_key = slugify(f"survey-targeting-{name}")
+ feature_flag_key = slugify(f"{SURVEY_TARGETING_FLAG_PREFIX}{name}")
feature_flag_serializer = FeatureFlagSerializer(
data={
"key": feature_flag_key,
diff --git a/posthog/api/test/test_capture.py b/posthog/api/test/test_capture.py
index 14b5238ac7cb9..8fff90642f112 100644
--- a/posthog/api/test/test_capture.py
+++ b/posthog/api/test/test_capture.py
@@ -1143,10 +1143,7 @@ def test_handle_invalid_snapshot(self):
for headers in [
(
"sentry",
- [
- "traceparent",
- "request-id",
- ],
+ ["traceparent", "request-id", "Sentry-Trace", "Baggage"],
),
(
"aws",
@@ -1171,8 +1168,8 @@ def test_cors_allows_tracing_headers(self, _: str, path: str, headers: List[str]
HTTP_ACCESS_CONTROL_REQUEST_HEADERS=presented_headers,
HTTP_ACCESS_CONTROL_REQUEST_METHOD="POST",
)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.headers["Access-Control-Allow-Headers"], expected_headers)
+ assert response.status_code == 200
+ assert response.headers["Access-Control-Allow-Headers"] == expected_headers
@patch("posthog.kafka_client.client._KafkaProducer.produce")
def test_legacy_recording_ingestion_data_sent_to_kafka(self, kafka_produce) -> None:
diff --git a/posthog/api/test/test_decide.py b/posthog/api/test/test_decide.py
index 548c32b09103e..4eadf7238f244 100644
--- a/posthog/api/test/test_decide.py
+++ b/posthog/api/test/test_decide.py
@@ -2089,6 +2089,99 @@ def test_decide_analytics_samples_dont_break_with_zero_sampling(self, *args):
# check that no increments made it to redis
self.assertEqual(client.hgetall(f"posthog:decide_requests:{self.team.pk}"), {})
+ @patch("posthog.models.feature_flag.flag_analytics.CACHE_BUCKET_SIZE", 10)
+ def test_decide_analytics_only_fires_with_non_survey_targeting_flags(self, *args):
+ ff = FeatureFlag.objects.create(
+ team=self.team, rollout_percentage=50, name="Beta feature", key="beta-feature", created_by=self.user
+ )
+ # use a non-csrf client to make requests
+ req_client = Client()
+ req_client.force_login(self.user)
+ response = req_client.post(
+ f"/api/projects/{self.team.id}/surveys/",
+ data={
+ "name": "Notebooks power users survey",
+ "type": "popover",
+ "questions": [{"type": "open", "question": "What would you want to improve from notebooks?"}],
+ "linked_flag_id": ff.id,
+ "targeting_flag_filters": {
+ "groups": [
+ {
+ "variant": None,
+ "rollout_percentage": None,
+ "properties": [
+ {"key": "billing_plan", "value": ["cloud"], "operator": "exact", "type": "person"}
+ ],
+ }
+ ]
+ },
+ "conditions": {"url": "https://app.posthog.com/notebooks"},
+ },
+ format="json",
+ content_type="application/json",
+ )
+
+ response_data = response.json()
+ assert response.status_code == status.HTTP_201_CREATED, response_data
+ req_client.logout()
+ self.client.logout()
+
+ with self.settings(DECIDE_BILLING_SAMPLING_RATE=1), freeze_time("2022-05-07 12:23:07"):
+ response = self._post_decide(api_version=3)
+ self.assertEqual(response.status_code, 200)
+
+ client = redis.get_client()
+ # check that single increment made it to redis
+ self.assertEqual(client.hgetall(f"posthog:decide_requests:{self.team.pk}"), {b"165192618": b"1"})
+
+ @patch("posthog.models.feature_flag.flag_analytics.CACHE_BUCKET_SIZE", 10)
+ def test_decide_analytics_does_not_fire_for_survey_targeting_flags(self, *args):
+ FeatureFlag.objects.create(
+ team=self.team,
+ rollout_percentage=50,
+ name="Beta feature",
+ key="survey-targeting-random",
+ created_by=self.user,
+ )
+ # use a non-csrf client to make requests
+ req_client = Client()
+ req_client.force_login(self.user)
+ response = req_client.post(
+ f"/api/projects/{self.team.id}/surveys/",
+ data={
+ "name": "Notebooks power users survey",
+ "type": "popover",
+ "questions": [{"type": "open", "question": "What would you want to improve from notebooks?"}],
+ "targeting_flag_filters": {
+ "groups": [
+ {
+ "variant": None,
+ "rollout_percentage": None,
+ "properties": [
+ {"key": "billing_plan", "value": ["cloud"], "operator": "exact", "type": "person"}
+ ],
+ }
+ ]
+ },
+ "conditions": {"url": "https://app.posthog.com/notebooks"},
+ },
+ format="json",
+ content_type="application/json",
+ )
+
+ response_data = response.json()
+ assert response.status_code == status.HTTP_201_CREATED, response_data
+ req_client.logout()
+ self.client.logout()
+
+ with self.settings(DECIDE_BILLING_SAMPLING_RATE=1), freeze_time("2022-05-07 12:23:07"):
+ response = self._post_decide(api_version=3)
+ self.assertEqual(response.status_code, 200)
+
+ client = redis.get_client()
+ # check that no increments made it to redis
+ self.assertEqual(client.hgetall(f"posthog:decide_requests:{self.team.pk}"), {})
+
class TestDatabaseCheckForDecide(BaseTest, QueryMatchingTest):
"""
diff --git a/posthog/api/test/test_element.py b/posthog/api/test/test_element.py
index 401c9af4f68ab..81b28c3198be7 100644
--- a/posthog/api/test/test_element.py
+++ b/posthog/api/test/test_element.py
@@ -7,7 +7,6 @@
from rest_framework import status
from posthog.models import Element, ElementGroup, Organization
-from posthog.settings import CORS_ALLOW_HEADERS
from posthog.test.base import (
APIBaseTest,
ClickhouseTestMixin,
@@ -274,21 +273,6 @@ def test_element_stats_obeys_limit_parameter(self) -> None:
limit_to_one_results = response_json["results"]
assert limit_to_one_results == [expected_all_data_response_results[1]]
- def test_element_stats_cors_headers(self) -> None:
- # Azure App Insights sends the same tracing headers as Sentry
- # _and_ a request-context header
- # this is added by the cors headers package so should apply to any endpoint
-
- response = self.client.generic(
- "OPTIONS",
- "/api/element/stats/",
- HTTP_ORIGIN="https://localhost",
- HTTP_ACCESS_CONTROL_REQUEST_HEADERS="traceparent,request-id,someotherrandomheader,request-context",
- HTTP_ACCESS_CONTROL_REQUEST_METHOD="POST",
- )
-
- assert response.headers["Access-Control-Allow-Headers"] == ", ".join(CORS_ALLOW_HEADERS)
-
def test_element_stats_does_not_allow_non_numeric_limit(self) -> None:
response = self.client.get(f"/api/element/stats/?limit=not-a-number")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
diff --git a/posthog/batch_exports/service.py b/posthog/batch_exports/service.py
index cfae7802f1e1f..bd68320c37bbd 100644
--- a/posthog/batch_exports/service.py
+++ b/posthog/batch_exports/service.py
@@ -258,6 +258,11 @@ def backfill_export(
async def backfill_schedule(temporal: Client, schedule_id: str, schedule_backfill: ScheduleBackfill):
"""Async call the Temporal client to execute a backfill on the given schedule."""
handle = temporal.get_schedule_handle(schedule_id)
+ description = await handle.describe()
+
+ if description.schedule.spec.jitter is not None:
+ schedule_backfill.end_at += description.schedule.spec.jitter
+
await handle.backfill(schedule_backfill)
diff --git a/posthog/celery.py b/posthog/celery.py
index 7f41920f1aebc..3189ccf11ae2b 100644
--- a/posthog/celery.py
+++ b/posthog/celery.py
@@ -5,6 +5,7 @@
from uuid import UUID
from celery import Celery
+from celery.canvas import Signature
from celery.schedules import crontab
from celery.signals import (
setup_logging,
@@ -101,15 +102,36 @@ def on_worker_start(**kwargs) -> None:
start_http_server(8001)
+def add_periodic_task_with_expiry(
+ sender: Celery, schedule_seconds: int, task_signature: Signature, name: str | None = None
+):
+ """
+ If the workers get delayed in processing tasks, then tasks that fire every X seconds get queued multiple times
+ and so are processed multiple times, even though they often only need to be processed once.
+ This schedules them with an expiry so that they aren't processed multiple times.
+ The expiry is larger than the schedule so that if the worker is only slightly delayed, it still gets processed.
+ """
+ sender.add_periodic_task(
+ schedule_seconds,
+ task_signature,
+ name=name,
+ # we don't want to run multiple of these if the workers build up a backlog
+ expires=schedule_seconds * 1.5,
+ )
+
+
@app.on_after_configure.connect
def setup_periodic_tasks(sender: Celery, **kwargs):
# Monitoring tasks
- sender.add_periodic_task(60.0, monitoring_check_clickhouse_schema_drift.s(), name="Monitor ClickHouse schema drift")
+ add_periodic_task_with_expiry(
+ sender, 60, monitoring_check_clickhouse_schema_drift.s(), "check clickhouse schema drift"
+ )
if not settings.DEBUG:
- sender.add_periodic_task(10.0, redis_celery_queue_depth.s(), name="10 sec queue probe", priority=0)
+ add_periodic_task_with_expiry(sender, 10, redis_celery_queue_depth.s(), "10 sec queue probe")
+
# Heartbeat every 10sec to make sure the worker is alive
- sender.add_periodic_task(10.0, redis_heartbeat.s(), name="10 sec heartbeat", priority=0)
+ add_periodic_task_with_expiry(sender, 10, redis_heartbeat.s(), "10 sec heartbeat")
# Update events table partitions twice a week
sender.add_periodic_task(
@@ -146,30 +168,78 @@ def setup_periodic_tasks(sender: Celery, **kwargs):
sync_insight_cache_states_schedule, sync_insight_cache_states_task.s(), name="sync insight cache states"
)
- sender.add_periodic_task(
+ add_periodic_task_with_expiry(
+ sender,
settings.UPDATE_CACHED_DASHBOARD_ITEMS_INTERVAL_SECONDS,
schedule_cache_updates_task.s(),
- name="check dashboard items",
+ "check dashboard items",
)
sender.add_periodic_task(crontab(minute="*/15"), check_async_migration_health.s())
if settings.INGESTION_LAG_METRIC_TEAM_IDS:
sender.add_periodic_task(60, ingestion_lag.s(), name="ingestion lag")
- sender.add_periodic_task(120, clickhouse_lag.s(), name="clickhouse table lag")
- sender.add_periodic_task(120, clickhouse_row_count.s(), name="clickhouse events table row count")
- sender.add_periodic_task(120, clickhouse_part_count.s(), name="clickhouse table parts count")
- sender.add_periodic_task(120, clickhouse_mutation_count.s(), name="clickhouse table mutations count")
- sender.add_periodic_task(120, clickhouse_errors_count.s(), name="clickhouse instance errors count")
-
- sender.add_periodic_task(120, pg_row_count.s(), name="PG tables row counts")
- sender.add_periodic_task(120, pg_table_cache_hit_rate.s(), name="PG table cache hit rate")
+
+ add_periodic_task_with_expiry(
+ sender,
+ 120,
+ clickhouse_lag.s(),
+ name="clickhouse table lag",
+ )
+
+ add_periodic_task_with_expiry(
+ sender,
+ 120,
+ clickhouse_row_count.s(),
+ name="clickhouse events table row count",
+ )
+ add_periodic_task_with_expiry(
+ sender,
+ 120,
+ clickhouse_part_count.s(),
+ name="clickhouse table parts count",
+ )
+ add_periodic_task_with_expiry(
+ sender,
+ 120,
+ clickhouse_mutation_count.s(),
+ name="clickhouse table mutations count",
+ )
+ add_periodic_task_with_expiry(
+ sender,
+ 120,
+ clickhouse_errors_count.s(),
+ name="clickhouse instance errors count",
+ )
+
+ add_periodic_task_with_expiry(
+ sender,
+ 120,
+ pg_row_count.s(),
+ name="PG tables row counts",
+ )
+ add_periodic_task_with_expiry(
+ sender,
+ 120,
+ pg_table_cache_hit_rate.s(),
+ name="PG table cache hit rate",
+ )
sender.add_periodic_task(
crontab(minute="0", hour="*"), pg_plugin_server_query_timing.s(), name="PG plugin server query timing"
)
- sender.add_periodic_task(60, graphile_worker_queue_size.s(), name="Graphile Worker queue size")
+ add_periodic_task_with_expiry(
+ sender,
+ 60,
+ graphile_worker_queue_size.s(),
+ name="Graphile Worker queue size",
+ )
- sender.add_periodic_task(120, calculate_cohort.s(), name="recalculate cohorts")
+ add_periodic_task_with_expiry(
+ sender,
+ 120,
+ calculate_cohort.s(),
+ name="recalculate cohorts",
+ )
if clear_clickhouse_crontab := get_crontab(settings.CLEAR_CLICKHOUSE_REMOVED_DATA_SCHEDULE_CRON):
sender.add_periodic_task(
@@ -207,12 +277,6 @@ def setup_periodic_tasks(sender: Celery, **kwargs):
sender.add_periodic_task(crontab(hour="*", minute="55"), schedule_all_subscriptions.s())
sender.add_periodic_task(crontab(hour="2", minute=str(randrange(0, 40))), ee_persist_finished_recordings.s())
- sender.add_periodic_task(
- settings.COUNT_TILES_WITH_NO_FILTERS_HASH_INTERVAL_SECONDS,
- count_tiles_with_no_hash.s(),
- name="count tiles with no filters_hash",
- )
-
sender.add_periodic_task(
crontab(minute="0", hour="*"),
check_flags_to_rollback.s(),
@@ -277,15 +341,6 @@ def delete_expired_exported_assets() -> None:
ExportedAsset.delete_expired_assets()
-@app.task(ignore_result=True)
-def count_tiles_with_no_hash() -> None:
- from statshog.defaults.django import statsd
-
- from posthog.models.dashboard_tile import DashboardTile
-
- statsd.gauge("dashboard_tiles.with_no_filters_hash", DashboardTile.objects.filter(filters_hash=None).count())
-
-
@app.task(ignore_result=True)
def redis_heartbeat():
get_client().set("POSTHOG_HEARTBEAT", int(time.time()))
@@ -446,7 +501,7 @@ def ingestion_lag():
from posthog.client import sync_execute
# Requires https://github.com/PostHog/posthog-heartbeat-plugin to be enabled on team 2
- # Note that it runs every minute and we compare it with now(), so there's up to 60s delay
+ # Note that it runs every minute, and we compare it with now(), so there's up to 60s delay
query = """
SELECT event, date_diff('second', max(timestamp), now())
FROM events
@@ -680,21 +735,16 @@ def clear_clickhouse_deleted_person():
@app.task(ignore_result=True)
def redis_celery_queue_depth():
- from statshog.defaults.django import statsd
-
try:
- llen = get_client().llen("celery")
- with pushed_metrics_registry("celery_redis_queue_depth") as registry:
- depth_gauge = Gauge(
- "posthog_celery_queue_depth",
- "Number of tasks in the Celery Redis queue.",
- registry=registry,
+ with pushed_metrics_registry("redis_celery_queue_depth_registry") as registry:
+ celery_task_queue_depth_gauge = Gauge(
+ "posthog_celery_queue_depth", "We use this to monitor the depth of the celery queue.", registry=registry
)
- depth_gauge.set(llen)
- statsd.gauge(f"posthog_celery_queue_depth", llen)
+
+ llen = get_client().llen("celery")
+ celery_task_queue_depth_gauge.set(llen)
except:
- # if we can't connect to statsd don't complain about it.
- # not every installation will have statsd available
+ # if we can't generate the metric don't complain about it.
return
diff --git a/posthog/hogql_queries/insights/__init__.py b/posthog/hogql_queries/insights/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/posthog/hogql_queries/lifecycle_query_runner.py b/posthog/hogql_queries/insights/lifecycle_query_runner.py
similarity index 100%
rename from posthog/hogql_queries/lifecycle_query_runner.py
rename to posthog/hogql_queries/insights/lifecycle_query_runner.py
diff --git a/posthog/hogql_queries/insights/test/__init__.py b/posthog/hogql_queries/insights/test/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/posthog/hogql_queries/test/test_lifecycle_query_runner.py b/posthog/hogql_queries/insights/test/test_lifecycle_query_runner.py
similarity index 99%
rename from posthog/hogql_queries/test/test_lifecycle_query_runner.py
rename to posthog/hogql_queries/insights/test/test_lifecycle_query_runner.py
index 0f5788ecfb705..3dd3602b9f88c 100644
--- a/posthog/hogql_queries/test/test_lifecycle_query_runner.py
+++ b/posthog/hogql_queries/insights/test/test_lifecycle_query_runner.py
@@ -3,7 +3,7 @@
from freezegun import freeze_time
from posthog.hogql.query import execute_hogql_query
-from posthog.hogql_queries.lifecycle_query_runner import LifecycleQueryRunner
+from posthog.hogql_queries.insights.lifecycle_query_runner import LifecycleQueryRunner
from posthog.models.utils import UUIDT
from posthog.schema import DateRange, IntervalType, LifecycleQuery, EventsNode
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person, flush_persons_and_events
diff --git a/posthog/hogql_queries/trends_query_runner.py b/posthog/hogql_queries/insights/trends_query_runner.py
similarity index 100%
rename from posthog/hogql_queries/trends_query_runner.py
rename to posthog/hogql_queries/insights/trends_query_runner.py
diff --git a/posthog/hogql_queries/query_runner.py b/posthog/hogql_queries/query_runner.py
index 98a0139c9353b..53fd9456ac549 100644
--- a/posthog/hogql_queries/query_runner.py
+++ b/posthog/hogql_queries/query_runner.py
@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from datetime import datetime
-from typing import Any, Generic, List, Optional, Type, Dict, TypeVar
+from typing import Any, Generic, List, Optional, Type, Dict, TypeVar, Union, Tuple
-from prometheus_client import Counter
-from django.core.cache import cache
from django.conf import settings
+from django.core.cache import cache
+from prometheus_client import Counter
from pydantic import BaseModel, ConfigDict
from posthog.clickhouse.query_tagging import tag_queries
@@ -14,8 +14,18 @@
from posthog.hogql.timings import HogQLTimings
from posthog.metrics import LABEL_TEAM_ID
from posthog.models import Team
-from posthog.schema import QueryTiming
-from posthog.types import InsightQueryNode
+from posthog.schema import (
+ QueryTiming,
+ TrendsQuery,
+ FunnelsQuery,
+ RetentionQuery,
+ PathsQuery,
+ StickinessQuery,
+ LifecycleQuery,
+ WebTopSourcesQuery,
+ WebTopClicksQuery,
+ WebTopPagesQuery,
+)
from posthog.utils import generate_cache_key, get_safe_cache
QUERY_CACHE_WRITE_COUNTER = Counter(
@@ -43,7 +53,7 @@ def get_query_runner(
kind = query.kind
if kind == "LifecycleQuery":
- from .lifecycle_query_runner import LifecycleQueryRunner
+ from .insights.lifecycle_query_runner import LifecycleQueryRunner
return LifecycleQueryRunner(query=query, team=team, timings=timings)
if kind == "PersonsQuery":
@@ -51,9 +61,23 @@ def get_query_runner(
return PersonsQueryRunner(query=query, team=team, timings=timings)
if kind == "TrendsQuery":
- from .trends_query_runner import TrendsQueryRunner
+ from .insights.trends_query_runner import TrendsQueryRunner
return TrendsQueryRunner(query=query, team=team, timings=timings)
+
+ if kind == "WebTopSourcesQuery":
+ from .web_analytics.top_sources import WebTopSourcesQueryRunner
+
+ return WebTopSourcesQueryRunner(query=query, team=team, timings=timings)
+ if kind == "WebTopClicksQuery":
+ from .web_analytics.top_clicks import WebTopClicksQueryRunner
+
+ return WebTopClicksQueryRunner(query=query, team=team, timings=timings)
+ if kind == "WebTopPagesQuery":
+ from .web_analytics.top_pages import WebTopPagesQueryRunner
+
+ return WebTopPagesQueryRunner(query=query, team=team, timings=timings)
+
raise ValueError(f"Can't get a runner for an unknown query kind: {kind}")
@@ -63,6 +87,8 @@ class QueryResponse(BaseModel, Generic[DataT]):
)
result: DataT
timings: Optional[List[QueryTiming]] = None
+ types: Optional[List[Tuple[str, str]]] = None
+ columns: Optional[List[str]] = None
hogql: Optional[str] = None
@@ -75,13 +101,26 @@ class CachedQueryResponse(QueryResponse):
next_allowed_client_refresh: str
+RunnableQueryNode = Union[
+ TrendsQuery,
+ FunnelsQuery,
+ RetentionQuery,
+ PathsQuery,
+ StickinessQuery,
+ LifecycleQuery,
+ WebTopSourcesQuery,
+ WebTopClicksQuery,
+ WebTopPagesQuery,
+]
+
+
class QueryRunner(ABC):
- query: InsightQueryNode
- query_type: Type[InsightQueryNode]
+ query: RunnableQueryNode
+ query_type: Type[RunnableQueryNode]
team: Team
timings: HogQLTimings
- def __init__(self, query: InsightQueryNode | Dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None):
+ def __init__(self, query: RunnableQueryNode | Dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None):
self.team = team
self.timings = timings or HogQLTimings()
if isinstance(query, self.query_type):
@@ -124,7 +163,6 @@ def run(self, refresh_requested: bool) -> CachedQueryResponse:
def to_query(self) -> ast.SelectQuery:
raise NotImplementedError()
- @abstractmethod
def to_persons_query(self) -> ast.SelectQuery:
# TODO: add support for selecting and filtering by breakdowns
raise NotImplementedError()
@@ -141,7 +179,9 @@ def toJSON(self) -> str:
return self.query.model_dump_json(exclude_defaults=True, exclude_none=True)
def _cache_key(self) -> str:
- return generate_cache_key(f"query_{self.toJSON()}_{self.team.pk}_{self.team.timezone}")
+ return generate_cache_key(
+ f"query_{self.toJSON()}_{self.__class__.__name__}_{self.team.pk}_{self.team.timezone}"
+ )
@abstractmethod
def _is_stale(self, cached_result_package):
diff --git a/posthog/hogql_queries/test/test_query_runner.py b/posthog/hogql_queries/test/test_query_runner.py
index d9af90a1e4ff9..85258fefd380a 100644
--- a/posthog/hogql_queries/test/test_query_runner.py
+++ b/posthog/hogql_queries/test/test_query_runner.py
@@ -1,13 +1,14 @@
from datetime import datetime, timedelta
-from dateutil.parser import isoparse
-from zoneinfo import ZoneInfo
from typing import Any, List, Literal, Optional, Type
+from zoneinfo import ZoneInfo
+
+from dateutil.parser import isoparse
from freezegun import freeze_time
from pydantic import BaseModel
-from posthog.hogql_queries.query_runner import QueryResponse, QueryRunner
+
+from posthog.hogql_queries.query_runner import QueryResponse, QueryRunner, RunnableQueryNode
from posthog.models.team.team import Team
from posthog.test.base import BaseTest
-from posthog.types import InsightQueryNode
class TestQuery(BaseModel):
@@ -17,7 +18,7 @@ class TestQuery(BaseModel):
class QueryRunnerTest(BaseTest):
- def setup_test_query_runner_class(self, query_class: Type[InsightQueryNode] = TestQuery): # type: ignore
+ def setup_test_query_runner_class(self, query_class: Type[RunnableQueryNode] = TestQuery): # type: ignore
"""Setup required methods and attributes of the abstract base class."""
class TestQueryRunner(QueryRunner):
@@ -81,12 +82,27 @@ def test_serializes_to_json_ignores_empty_dict(self):
def test_cache_key(self):
TestQueryRunner = self.setup_test_query_runner_class()
+ # set the pk directly as it affects the hash in the _cache_key call
team = Team.objects.create(pk=42, organization=self.organization)
runner = TestQueryRunner(query={"some_attr": "bla"}, team=team) # type: ignore
cache_key = runner._cache_key()
- self.assertEqual(cache_key, "cache_f0f2ce8b1f3d107b9671a178b25be2aa")
+ self.assertEqual(cache_key, "cache_33c9ea3098895d5a363a75feefafef06")
+
+ def test_cache_key_runner_subclass(self):
+ TestQueryRunner = self.setup_test_query_runner_class()
+
+ class TestSubclassQueryRunner(TestQueryRunner): # type: ignore
+ pass
+
+ # set the pk directly as it affects the hash in the _cache_key call
+ team = Team.objects.create(pk=42, organization=self.organization)
+
+ runner = TestSubclassQueryRunner(query={"some_attr": "bla"}, team=team) # type: ignore
+
+ cache_key = runner._cache_key()
+ self.assertEqual(cache_key, "cache_d626615de8ad0df73c1d8610ca586597")
def test_cache_key_different_timezone(self):
TestQueryRunner = self.setup_test_query_runner_class()
@@ -97,7 +113,7 @@ def test_cache_key_different_timezone(self):
runner = TestQueryRunner(query={"some_attr": "bla"}, team=team) # type: ignore
cache_key = runner._cache_key()
- self.assertEqual(cache_key, "cache_0fa2172980705adb41741351f40189b7")
+ self.assertEqual(cache_key, "cache_aeb23ec9e8de56dd8499f99f2e976d5a")
def test_cache_response(self):
TestQueryRunner = self.setup_test_query_runner_class()
diff --git a/posthog/hogql_queries/web_analytics/__init__.py b/posthog/hogql_queries/web_analytics/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/posthog/hogql_queries/web_analytics/top_clicks.py b/posthog/hogql_queries/web_analytics/top_clicks.py
new file mode 100644
index 0000000000000..c1d7c9ee6ab29
--- /dev/null
+++ b/posthog/hogql_queries/web_analytics/top_clicks.py
@@ -0,0 +1,53 @@
+from django.utils.timezone import datetime
+
+from posthog.hogql import ast
+from posthog.hogql.parser import parse_select
+from posthog.hogql.query import execute_hogql_query
+from posthog.hogql_queries.web_analytics.web_analytics_query_runner import WebAnalyticsQueryRunner
+from posthog.hogql_queries.utils.query_date_range import QueryDateRange
+from posthog.models.filters.mixins.utils import cached_property
+from posthog.schema import WebTopClicksQuery, WebTopClicksQueryResponse
+
+
+class WebTopClicksQueryRunner(WebAnalyticsQueryRunner):
+ query: WebTopClicksQuery
+ query_type = WebTopClicksQuery
+
+ def to_query(self) -> ast.SelectQuery | ast.SelectUnionQuery:
+ with self.timings.measure("top_clicks_query"):
+ top_sources_query = parse_select(
+ """
+SELECT
+ properties.$el_text as el_text,
+ count() as total_clicks,
+ COUNT(DISTINCT events.person_id) as unique_visitors
+FROM
+ events
+WHERE
+ event == '$autocapture'
+AND events.timestamp >= now() - INTERVAL 7 DAY
+AND events.properties.$event_type = 'click'
+AND el_text IS NOT NULL
+GROUP BY
+ el_text
+ORDER BY total_clicks DESC
+ """,
+ timings=self.timings,
+ )
+ return top_sources_query
+
+ def calculate(self):
+ response = execute_hogql_query(
+ query_type="top_sources_query",
+ query=self.to_query(),
+ team=self.team,
+ timings=self.timings,
+ )
+
+ return WebTopClicksQueryResponse(
+ columns=response.columns, result=response.results, timings=response.timings, types=response.types
+ )
+
+ @cached_property
+ def query_date_range(self):
+ return QueryDateRange(date_range=self.query.dateRange, team=self.team, interval=None, now=datetime.now())
diff --git a/posthog/hogql_queries/web_analytics/top_pages.py b/posthog/hogql_queries/web_analytics/top_pages.py
new file mode 100644
index 0000000000000..7ded183b80d1b
--- /dev/null
+++ b/posthog/hogql_queries/web_analytics/top_pages.py
@@ -0,0 +1,161 @@
+from django.utils.timezone import datetime
+
+from posthog.hogql import ast
+from posthog.hogql.parser import parse_select
+from posthog.hogql.query import execute_hogql_query
+from posthog.hogql_queries.web_analytics.web_analytics_query_runner import WebAnalyticsQueryRunner
+from posthog.hogql_queries.utils.query_date_range import QueryDateRange
+from posthog.models.filters.mixins.utils import cached_property
+from posthog.schema import WebTopPagesQuery, WebTopPagesQueryResponse
+
+
+class WebTopPagesQueryRunner(WebAnalyticsQueryRunner):
+    """Query runner that reports pageviews, visitors, scroll depth and bounce rate per pathname.
+
+    Scroll/view stats are keyed by `$prev_pageview_pathname`; bounce rate is joined on each session's first pathname.
+    NOTE(review): the SQL uses fixed 7/8-day windows and never reads `self.query.dateRange` - confirm intended.
+    """
+
+    query: WebTopPagesQuery
+    query_type = WebTopPagesQuery
+
+    def to_query(self) -> ast.SelectQuery | ast.SelectUnionQuery:
+        """Parse and return the HogQL select joining per-path scroll statistics with per-path bounce rates."""
+        with self.timings.measure("top_pages_query"):
+            top_pages_query = parse_select(
+                """
+WITH
+
+scroll_depth_cte AS (
+SELECT
+    events.properties.`$prev_pageview_pathname` AS pathname,
+    countIf(events.event == '$pageview') as total_pageviews,
+    COUNT(DISTINCT events.properties.distinct_id) as unique_visitors, -- might want to use person id? have seen a small number of pages where unique > total
+    avg(CASE
+        WHEN events.properties.`$prev_pageview_max_content_percentage` IS NULL THEN NULL
+        WHEN events.properties.`$prev_pageview_max_content_percentage` > 0.8 THEN 100
+        ELSE 0
+        END) AS scroll_gt80_percentage,
+    avg(events.properties.$prev_pageview_max_scroll_percentage) * 100 as average_scroll_percentage
+FROM
+    events
+WHERE
+    (event = '$pageview' OR event = '$pageleave') AND events.properties.`$prev_pageview_pathname` IS NOT NULL
+    AND events.timestamp >= now() - INTERVAL 7 DAY
+GROUP BY pathname
+)
+
+,
+
+session_cte AS (
+SELECT
+    events.properties.`$session_id` AS session_id,
+    min(events.timestamp) AS min_timestamp,
+    max(events.timestamp) AS max_timestamp,
+    dateDiff('second', min_timestamp, max_timestamp) AS duration_s,
+
+    -- create a tuple so that these are grouped in the same order, see https://github.com/ClickHouse/ClickHouse/discussions/42338
+    groupArray((events.timestamp, events.properties.`$referrer`, events.properties.`$pathname`, events.properties.utm_source)) AS tuple_array,
+    arrayFirstIndex(x -> tupleElement(x, 1) == min_timestamp, tuple_array) as index_of_earliest,
+    arrayFirstIndex(x -> tupleElement(x, 1) == max_timestamp, tuple_array) as index_of_latest,
+    tupleElement(arrayElement(
+        tuple_array,
+        index_of_earliest
+    ), 2) AS earliest_referrer,
+    tupleElement(arrayElement(
+        tuple_array,
+        index_of_earliest
+    ), 3) AS earliest_pathname,
+    tupleElement(arrayElement(
+        tuple_array,
+        index_of_earliest
+    ), 4) AS earliest_utm_source,
+
+    if(domain(earliest_referrer) = '', earliest_referrer, domain(earliest_referrer)) AS referrer_domain,
+    multiIf(
+        earliest_utm_source IS NOT NULL, earliest_utm_source,
+        -- This will need to be an approach that scales better
+        referrer_domain == 'app.posthog.com', 'posthog',
+        referrer_domain == 'eu.posthog.com', 'posthog',
+        referrer_domain == 'posthog.com', 'posthog',
+        referrer_domain == 'www.google.com', 'google',
+        referrer_domain == 'www.google.co.uk', 'google',
+        referrer_domain == 'www.google.com.hk', 'google',
+        referrer_domain == 'www.google.de', 'google',
+        referrer_domain == 't.co', 'twitter',
+        referrer_domain == 'github.com', 'github',
+        referrer_domain == 'duckduckgo.com', 'duckduckgo',
+        referrer_domain == 'www.bing.com', 'bing',
+        referrer_domain == 'bing.com', 'bing',
+        referrer_domain == 'yandex.ru', 'yandex',
+        referrer_domain == 'quora.com', 'quora',
+        referrer_domain == 'www.quora.com', 'quora',
+        referrer_domain == 'linkedin.com', 'linkedin',
+        referrer_domain == 'www.linkedin.com', 'linkedin',
+        startsWith(referrer_domain, 'http://localhost:'), 'localhost',
+        referrer_domain
+    ) AS blended_source,
+
+    countIf(events.event == '$pageview') AS num_pageviews,
+    countIf(events.event == '$autocapture') AS num_autocaptures,
+    -- in v1 we'd also want to count whether there were any conversion events
+
+    any(events.person_id) as person_id,
+    -- definition of a GA4 bounce from here https://support.google.com/analytics/answer/12195621?hl=en
+    (num_autocaptures == 0 AND num_pageviews <= 1 AND duration_s < 10) AS is_bounce
+FROM
+    events
+WHERE
+    session_id IS NOT NULL
+AND
+    events.timestamp >= now() - INTERVAL 8 DAY
+GROUP BY
+    events.properties.`$session_id`
+HAVING
+    min_timestamp >= now() - INTERVAL 7 DAY
+)
+
+,
+
+bounce_rate_cte AS (
+SELECT session_cte.earliest_pathname,
+    avg(session_cte.is_bounce) as bounce_rate
+FROM session_cte
+GROUP BY earliest_pathname
+)
+
+
+
+SELECT scroll_depth_cte.pathname as pathname,
+scroll_depth_cte.total_pageviews as total_pageviews,
+scroll_depth_cte.unique_visitors as unique_visitors,
+scroll_depth_cte.scroll_gt80_percentage as scroll_gt80_percentage,
+scroll_depth_cte.average_scroll_percentage as average_scroll_percentage,
+bounce_rate_cte.bounce_rate as bounce_rate
+FROM
+    scroll_depth_cte LEFT OUTER JOIN bounce_rate_cte
+ON scroll_depth_cte.pathname = bounce_rate_cte.earliest_pathname
+ORDER BY total_pageviews DESC
+                """,
+                timings=self.timings,
+            )
+        return top_pages_query
+
+    def calculate(self):
+        """Execute the query for `self.team` and wrap columns, rows, timings and types in the response model."""
+        response = execute_hogql_query(
+            # Label the execution with this runner's own name; it was copy-pasted as "top_sources_query".
+            query_type="top_pages_query",
+            query=self.to_query(),
+            team=self.team,
+            timings=self.timings,
+        )
+
+        return WebTopPagesQueryResponse(
+            columns=response.columns, result=response.results, timings=response.timings, types=response.types
+        )
+
+    @cached_property
+    def query_date_range(self):
+        """Build (and cache) a `QueryDateRange` from the query's `dateRange`, using the current time as `now`."""
+        return QueryDateRange(date_range=self.query.dateRange, team=self.team, interval=None, now=datetime.now())
diff --git a/posthog/hogql_queries/web_analytics/top_sources.py b/posthog/hogql_queries/web_analytics/top_sources.py
new file mode 100644
index 0000000000000..2762627c6002d
--- /dev/null
+++ b/posthog/hogql_queries/web_analytics/top_sources.py
@@ -0,0 +1,133 @@
+from django.utils.timezone import datetime
+
+from posthog.hogql import ast
+from posthog.hogql.parser import parse_select
+from posthog.hogql.query import execute_hogql_query
+from posthog.hogql_queries.web_analytics.web_analytics_query_runner import WebAnalyticsQueryRunner
+from posthog.hogql_queries.utils.query_date_range import QueryDateRange
+from posthog.models.filters.mixins.utils import cached_property
+from posthog.schema import WebTopSourcesQuery, WebTopSourcesQueryResponse
+
+
+class WebTopSourcesQueryRunner(WebAnalyticsQueryRunner):
+    """Query runner that groups sessions by their first-touch source (UTM source or referrer domain).
+
+    Each result row reports pageviews, distinct persons and the bounce rate for one source.
+    NOTE(review): the SQL uses fixed 7/8-day windows and never reads `self.query.dateRange` - confirm intended.
+    """
+
+    query: WebTopSourcesQuery
+    query_type = WebTopSourcesQuery
+
+    def to_query(self) -> ast.SelectQuery | ast.SelectUnionQuery:
+        """Parse the HogQL select; `total_pageviews` sums each session's pageviews instead of counting sessions."""
+        with self.timings.measure("top_sources_query"):
+            top_sources_query = parse_select(
+                """
+WITH
+
+session_cte AS (
+SELECT
+    events.properties.`$session_id` AS session_id,
+    min(events.timestamp) AS min_timestamp,
+    max(events.timestamp) AS max_timestamp,
+    dateDiff('second', min_timestamp, max_timestamp) AS duration_s,
+
+    -- create a tuple so that these are grouped in the same order, see https://github.com/ClickHouse/ClickHouse/discussions/42338
+    groupArray((events.timestamp, events.properties.`$referrer`, events.properties.`$pathname`, events.properties.utm_source)) AS tuple_array,
+    arrayFirstIndex(x -> tupleElement(x, 1) == min_timestamp, tuple_array) as index_of_earliest,
+    arrayFirstIndex(x -> tupleElement(x, 1) == max_timestamp, tuple_array) as index_of_latest,
+    tupleElement(arrayElement(
+        tuple_array,
+        index_of_earliest
+    ), 2) AS earliest_referrer,
+    tupleElement(arrayElement(
+        tuple_array,
+        index_of_earliest
+    ), 3) AS earliest_pathname,
+    tupleElement(arrayElement(
+        tuple_array,
+        index_of_earliest
+    ), 4) AS earliest_utm_source,
+
+    if(domain(earliest_referrer) = '', earliest_referrer, domain(earliest_referrer)) AS referrer_domain,
+    multiIf(
+        earliest_utm_source IS NOT NULL, earliest_utm_source,
+        -- This will need to be an approach that scales better
+        referrer_domain == 'app.posthog.com', 'posthog',
+        referrer_domain == 'eu.posthog.com', 'posthog',
+        referrer_domain == 'posthog.com', 'posthog',
+        referrer_domain == 'www.google.com', 'google',
+        referrer_domain == 'www.google.co.uk', 'google',
+        referrer_domain == 'www.google.com.hk', 'google',
+        referrer_domain == 'www.google.de', 'google',
+        referrer_domain == 't.co', 'twitter',
+        referrer_domain == 'github.com', 'github',
+        referrer_domain == 'duckduckgo.com', 'duckduckgo',
+        referrer_domain == 'www.bing.com', 'bing',
+        referrer_domain == 'bing.com', 'bing',
+        referrer_domain == 'yandex.ru', 'yandex',
+        referrer_domain == 'quora.com', 'quora',
+        referrer_domain == 'www.quora.com', 'quora',
+        referrer_domain == 'linkedin.com', 'linkedin',
+        referrer_domain == 'www.linkedin.com', 'linkedin',
+        startsWith(referrer_domain, 'http://localhost:'), 'localhost',
+        referrer_domain
+    ) AS blended_source,
+
+    countIf(events.event == '$pageview') AS num_pageviews,
+    countIf(events.event == '$autocapture') AS num_autocaptures,
+    -- in v1 we'd also want to count whether there were any conversion events
+
+    any(events.person_id) as person_id,
+    -- definition of a GA4 bounce from here https://support.google.com/analytics/answer/12195621?hl=en
+    (num_autocaptures == 0 AND num_pageviews <= 1 AND duration_s < 10) AS is_bounce
+FROM
+    events
+WHERE
+    session_id IS NOT NULL
+AND
+    events.timestamp >= now() - INTERVAL 8 DAY
+GROUP BY
+    events.properties.`$session_id`
+HAVING
+    min_timestamp >= now() - INTERVAL 7 DAY
+)
+
+
+
+SELECT
+    blended_source,
+    sum(num_pageviews) as total_pageviews,
+    count(DISTINCT person_id) as unique_visitors,
+    avg(is_bounce) AS bounce_rate
+FROM
+    session_cte
+WHERE
+    blended_source IS NOT NULL
+GROUP BY blended_source
+
+ORDER BY total_pageviews DESC
+LIMIT 100
+                """,
+                timings=self.timings,
+            )
+        return top_sources_query
+
+    def calculate(self):
+        """Execute the query for `self.team` and wrap columns, rows, timings and types in the response model."""
+        response = execute_hogql_query(
+            query_type="top_sources_query",
+            query=self.to_query(),
+            team=self.team,
+            timings=self.timings,
+        )
+
+        return WebTopSourcesQueryResponse(
+            columns=response.columns, result=response.results, timings=response.timings, types=response.types
+        )
+
+    @cached_property
+    def query_date_range(self):
+        """Build (and cache) a `QueryDateRange` from the query's `dateRange`, using the current time as `now`."""
+        return QueryDateRange(date_range=self.query.dateRange, team=self.team, interval=None, now=datetime.now())
diff --git a/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py b/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py
new file mode 100644
index 0000000000000..e023ad32954b3
--- /dev/null
+++ b/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py
@@ -0,0 +1,16 @@
+from abc import ABC
+
+from posthog.caching.insights_api import BASE_MINIMUM_INSIGHT_REFRESH_INTERVAL
+from posthog.hogql_queries.query_runner import QueryRunner
+
+
+class WebAnalyticsQueryRunner(QueryRunner, ABC):
+    """Abstract base class shared by the web analytics query runners."""
+
+    def _is_stale(self, cached_result_package):
+        """Always report the cached result as stale; `cached_result_package` is not inspected."""
+        return True
+
+    def _refresh_frequency(self):
+        """Return `BASE_MINIMUM_INSIGHT_REFRESH_INTERVAL` as the refresh frequency for these runners."""
+        return BASE_MINIMUM_INSIGHT_REFRESH_INTERVAL
diff --git a/posthog/schema.py b/posthog/schema.py
index 99088a5f718b0..03e60279ac7fc 100644
--- a/posthog/schema.py
+++ b/posthog/schema.py
@@ -365,6 +365,7 @@ class SavedInsightNode(BaseModel):
propertiesViaUrl: Optional[bool] = Field(default=None, description="Link properties via the URL (default: false)")
shortId: str
showActions: Optional[bool] = Field(default=None, description="Show the kebab menu at the end of the row")
+ showBackButton: Optional[bool] = Field(default=None, description="Show a button to go back to the source query")
showColumnConfigurator: Optional[bool] = Field(
default=None, description="Show a button to configure the table's columns if possible"
)
@@ -468,6 +469,45 @@ class TrendsQueryResponse(BaseModel):
timings: Optional[List[QueryTiming]] = None
+class WebTopClicksQueryResponse(BaseModel):
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ columns: Optional[List] = None
+ is_cached: Optional[bool] = None
+ last_refresh: Optional[str] = None
+ next_allowed_client_refresh: Optional[str] = None
+ result: List
+ timings: Optional[List[QueryTiming]] = None
+ types: Optional[List] = None
+
+
+class WebTopPagesQueryResponse(BaseModel):
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ columns: Optional[List] = None
+ is_cached: Optional[bool] = None
+ last_refresh: Optional[str] = None
+ next_allowed_client_refresh: Optional[str] = None
+ result: List
+ timings: Optional[List[QueryTiming]] = None
+ types: Optional[List] = None
+
+
+class WebTopSourcesQueryResponse(BaseModel):
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ columns: Optional[List] = None
+ is_cached: Optional[bool] = None
+ last_refresh: Optional[str] = None
+ next_allowed_client_refresh: Optional[str] = None
+ result: List
+ timings: Optional[List[QueryTiming]] = None
+ types: Optional[List] = None
+
+
class Breakdown(BaseModel):
model_config = ConfigDict(
extra="forbid",
@@ -669,6 +709,36 @@ class TimeToSeeDataSessionsQuery(BaseModel):
teamId: Optional[float] = Field(default=None, description="Project to filter on. Defaults to current project")
+class WebTopClicksQuery(BaseModel):
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ dateRange: Optional[DateRange] = None
+ filters: Any
+ kind: Literal["WebTopClicksQuery"] = "WebTopClicksQuery"
+ response: Optional[WebTopClicksQueryResponse] = None
+
+
+class WebTopPagesQuery(BaseModel):
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ dateRange: Optional[DateRange] = None
+ filters: Any
+ kind: Literal["WebTopPagesQuery"] = "WebTopPagesQuery"
+ response: Optional[WebTopPagesQueryResponse] = None
+
+
+class WebTopSourcesQuery(BaseModel):
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ dateRange: Optional[DateRange] = None
+ filters: Any
+ kind: Literal["WebTopSourcesQuery"] = "WebTopSourcesQuery"
+ response: Optional[WebTopSourcesQueryResponse] = None
+
+
class DatabaseSchemaQuery(BaseModel):
model_config = ConfigDict(
extra="forbid",
@@ -1249,6 +1319,7 @@ class DataTableNode(BaseModel):
kind: Literal["DataTableNode"] = "DataTableNode"
propertiesViaUrl: Optional[bool] = Field(default=None, description="Link properties via the URL (default: false)")
showActions: Optional[bool] = Field(default=None, description="Show the kebab menu at the end of the row")
+ showBackButton: Optional[bool] = Field(default=None, description="Show a button to go back to the source query")
showColumnConfigurator: Optional[bool] = Field(
default=None, description="Show a button to configure the table's columns if possible"
)
@@ -1271,9 +1342,17 @@ class DataTableNode(BaseModel):
showSavedQueries: Optional[bool] = Field(default=None, description="Shows a list of saved queries")
showSearch: Optional[bool] = Field(default=None, description="Include a free text search field (PersonsNode only)")
showTimings: Optional[bool] = Field(default=None, description="Show a detailed query timing breakdown")
- source: Union[EventsNode, EventsQuery, PersonsNode, PersonsQuery, HogQLQuery, TimeToSeeDataSessionsQuery] = Field(
- ..., description="Source of the events"
- )
+ source: Union[
+ EventsNode,
+ EventsQuery,
+ PersonsNode,
+ PersonsQuery,
+ HogQLQuery,
+ TimeToSeeDataSessionsQuery,
+ WebTopSourcesQuery,
+ WebTopClicksQuery,
+ WebTopPagesQuery,
+ ] = Field(..., description="Source of the events")
class Model(RootModel):
@@ -1297,6 +1376,9 @@ class Model(RootModel):
PersonsQuery,
HogQLQuery,
HogQLMetadata,
+ WebTopSourcesQuery,
+ WebTopClicksQuery,
+ WebTopPagesQuery,
],
]
diff --git a/posthog/settings/sentry.py b/posthog/settings/sentry.py
index e1afc3c8ccde9..8aa2fe37f1620 100644
--- a/posthog/settings/sentry.py
+++ b/posthog/settings/sentry.py
@@ -10,6 +10,8 @@
from posthog.settings import get_from_env
from posthog.settings.base_variables import TEST
+from datetime import datetime, timezone
+
def traces_sampler(sampling_context: dict) -> float:
#
@@ -36,6 +38,12 @@ def traces_sampler(sampling_context: dict) -> float:
return 0.0000001 # 0.00001%
# Get more traces for /decide than other high volume endpoints
elif path.startswith("/decide"):
+ # Get the current time in GMT
+ current_time_gmt = datetime.now(timezone.utc)
+ # Check if the time is between 5 and 6:59 am GMT, where we get spikes of latency
+ # so we can get more traces to debug
+ if 5 <= current_time_gmt.hour < 7:
+ return 0.001 # 0.1%
return 0.00001 # 0.001%
# Probes/monitoring endpoints
elif path.startswith(("/_health", "/_readyz", "/_livez")):
diff --git a/posthog/tasks/exporter.py b/posthog/tasks/exporter.py
index c557abbb67fbc..87456edb02165 100644
--- a/posthog/tasks/exporter.py
+++ b/posthog/tasks/exporter.py
@@ -1,24 +1,55 @@
from typing import Optional
+from prometheus_client import Counter, Histogram
+
from posthog.celery import app
from posthog.models import ExportedAsset
+EXPORT_QUEUED_COUNTER = Counter(
+ "exporter_task_queued",
+ "An export task was queued",
+ labelnames=["type"],
+)
+EXPORT_SUCCEEDED_COUNTER = Counter(
+ "exporter_task_succeeded",
+ "An export task succeeded",
+ labelnames=["type"],
+)
+EXPORT_ASSET_UNKNOWN_COUNTER = Counter(
+ "exporter_task_unknown_asset",
+ "An export task was for an unknown asset",
+ labelnames=["type"],
+)
+EXPORT_FAILED_COUNTER = Counter(
+ "exporter_task_failed",
+ "An export task failed",
+ labelnames=["type"],
+)
+EXPORT_TIMER = Histogram(
+ "exporter_task_duration_seconds",
+ "Time spent exporting an asset",
+ labelnames=["type"],
+ buckets=(1, 5, 10, 30, 60, 120, 240, 300, 360, 420, 480, 540, 600, float("inf")),
+)
-@app.task(autoretry_for=(Exception,), max_retries=5, retry_backoff=True, acks_late=True)
-def export_asset(exported_asset_id: int, limit: Optional[int] = None) -> None:
- from statshog.defaults.django import statsd
+# export_asset is used in chords/groups and so must not ignore its results
+@app.task(autoretry_for=(Exception,), max_retries=5, retry_backoff=True, acks_late=True, ignore_result=False)
+def export_asset(exported_asset_id: int, limit: Optional[int] = None) -> None:
from posthog.tasks.exports import csv_exporter, image_exporter
- exported_asset: ExportedAsset = ExportedAsset.objects.select_related("insight", "dashboard").get(
- pk=exported_asset_id
- )
+ # if Celery is lagging then you can end up with an exported asset that has had a TTL added
+ # and that TTL has passed, in the exporter we don't care about that.
+ # the TTL is for later cleanup.
+ exported_asset: ExportedAsset = ExportedAsset.objects_including_ttl_deleted.select_related(
+ "insight", "dashboard"
+ ).get(pk=exported_asset_id)
is_csv_export = exported_asset.export_format == ExportedAsset.ExportFormat.CSV
if is_csv_export:
max_limit = exported_asset.export_context.get("max_limit", 10000)
csv_exporter.export_csv(exported_asset, limit=limit, max_limit=max_limit)
- statsd.incr("csv_exporter.queued", tags={"team_id": str(exported_asset.team_id)})
+ EXPORT_QUEUED_COUNTER.labels(type="csv").inc()
else:
image_exporter.export_image(exported_asset)
- statsd.incr("image_exporter.queued", tags={"team_id": str(exported_asset.team_id)})
+ EXPORT_QUEUED_COUNTER.labels(type="image").inc()
diff --git a/posthog/tasks/exports/csv_exporter.py b/posthog/tasks/exports/csv_exporter.py
index 69ad972c37393..64cb23cbac0ac 100644
--- a/posthog/tasks/exports/csv_exporter.py
+++ b/posthog/tasks/exports/csv_exporter.py
@@ -1,20 +1,18 @@
import datetime
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse
-from django.http import QueryDict
import requests
import structlog
+from django.http import QueryDict
from sentry_sdk import capture_exception, push_scope
-from statshog.defaults.django import statsd
-from posthog.jwt import PosthogJwtAudience, encode_jwt
from posthog.api.query import process_query
-from posthog.logging.timing import timed
+from posthog.jwt import PosthogJwtAudience, encode_jwt
from posthog.models.exported_asset import ExportedAsset, save_content
from posthog.utils import absolute_uri
-
from .ordered_csv_renderer import OrderedCsvRenderer
+from ..exporter import EXPORT_FAILED_COUNTER, EXPORT_ASSET_UNKNOWN_COUNTER, EXPORT_SUCCEEDED_COUNTER, EXPORT_TIMER
logger = structlog.get_logger(__name__)
@@ -271,17 +269,17 @@ def make_api_call(
raise ex
-@timed("csv_exporter")
def export_csv(exported_asset: ExportedAsset, limit: Optional[int] = None, max_limit: int = 3_500) -> None:
if not limit:
limit = 1000
try:
if exported_asset.export_format == "text/csv":
- _export_to_csv(exported_asset, limit, max_limit)
- statsd.incr("csv_exporter.succeeded", tags={"team_id": exported_asset.team.id})
+ with EXPORT_TIMER.labels(type="csv").time():
+ _export_to_csv(exported_asset, limit, max_limit)
+ EXPORT_SUCCEEDED_COUNTER.labels(type="csv").inc()
else:
- statsd.incr("csv_exporter.unknown_asset", tags={"team_id": exported_asset.team.id})
+ EXPORT_ASSET_UNKNOWN_COUNTER.labels(type="csv").inc()
raise NotImplementedError(f"Export to format {exported_asset.export_format} is not supported")
except Exception as e:
if exported_asset:
@@ -291,8 +289,9 @@ def export_csv(exported_asset: ExportedAsset, limit: Optional[int] = None, max_l
with push_scope() as scope:
scope.set_tag("celery_task", "csv_export")
+ scope.set_tag("team_id", team_id)
capture_exception(e)
logger.error("csv_exporter.failed", exception=e, exc_info=True)
- statsd.incr("csv_exporter.failed", tags={"team_id": team_id})
+ EXPORT_FAILED_COUNTER.labels(type="csv").inc()
raise e
diff --git a/posthog/tasks/exports/image_exporter.py b/posthog/tasks/exports/image_exporter.py
index 79180670e9b4d..057239a929f50 100644
--- a/posthog/tasks/exports/image_exporter.py
+++ b/posthog/tasks/exports/image_exporter.py
@@ -6,51 +6,30 @@
import structlog
from django.conf import settings
-from prometheus_client import Counter, Summary
from selenium import webdriver
+from selenium.common.exceptions import TimeoutException
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
-from selenium.common.exceptions import TimeoutException
from selenium.webdriver.support.wait import WebDriverWait
from sentry_sdk import capture_exception, configure_scope, push_scope
-
from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.core.os_manager import ChromeType
from posthog.caching.fetch_from_cache import synchronously_update_cache
-from posthog.logging.timing import timed
-from posthog.metrics import LABEL_TEAM_ID
from posthog.models.exported_asset import ExportedAsset, get_public_access_token, save_content
+from posthog.tasks.exporter import EXPORT_SUCCEEDED_COUNTER, EXPORT_FAILED_COUNTER, EXPORT_TIMER
from posthog.tasks.exports.exporter_utils import log_error_if_site_url_not_reachable
from posthog.utils import absolute_uri
logger = structlog.get_logger(__name__)
-IMAGE_EXPORT_SUCCEEDED_COUNTER = Counter(
- "image_exporter_task_succeeded",
- "An image export task succeeded",
- labelnames=[LABEL_TEAM_ID],
-)
-
-IMAGE_EXPORT_FAILED_COUNTER = Counter(
- "image_exporter_task_failure",
- "An image export task failed",
- labelnames=[LABEL_TEAM_ID],
-)
-
-IMAGE_EXPORT_TIMER = Summary(
- "image_exporter_task_success_time",
- "Number of seconds it took to export an image",
- labelnames=[LABEL_TEAM_ID],
-)
-
TMP_DIR = "/tmp" # NOTE: Externalise this to ENV var
ScreenWidth = Literal[800, 1920]
CSSSelector = Literal[".InsightCard", ".ExportedInsight"]
-# NOTE: We purporsefully DONT re-use the driver. It would be slightly faster but would keep an in-memory browser
+# NOTE: We purposefully DON'T re-use the driver. It would be slightly faster but would keep an in-memory browser
# window permanently around which is unnecessary
def get_driver() -> webdriver.Chrome:
options = Options()
@@ -144,15 +123,29 @@ def _screenshot_asset(
try:
WebDriverWait(driver, 20).until_not(lambda x: x.find_element_by_class_name("Spinner"))
except TimeoutException:
- capture_exception()
+ logger.error(
+ "image_exporter.timeout",
+ url_to_render=url_to_render,
+ wait_for_css_selector=wait_for_css_selector,
+ image_path=image_path,
+ )
+ with push_scope() as scope:
+ scope.set_extra("url_to_render", url_to_render)
+ try:
+ driver.save_screenshot(image_path)
+ scope.add_attachment(None, None, image_path)
+ except Exception:
+ pass
+ capture_exception()
height = driver.execute_script("return document.body.scrollHeight")
driver.set_window_size(screenshot_width, height)
driver.save_screenshot(image_path)
except Exception as e:
- if driver:
- # To help with debugging, add a screenshot and any chrome logs
- with configure_scope() as scope:
- # If we encounter issues getting extra info we should silenty fail rather than creating a new exception
+ # To help with debugging, add a screenshot and any chrome logs
+ with configure_scope() as scope:
+ scope.set_extra("url_to_render", url_to_render)
+ if driver:
+ # If we encounter issues getting extra info we should silently fail rather than creating a new exception
try:
all_logs = [x for x in driver.get_log("browser")]
scope.add_attachment(json.dumps(all_logs).encode("utf-8"), "logs.txt")
@@ -163,7 +156,7 @@ def _screenshot_asset(
scope.add_attachment(None, None, image_path)
except Exception:
pass
- capture_exception(e)
+ capture_exception(e)
raise e
finally:
@@ -171,10 +164,9 @@ def _screenshot_asset(
driver.quit()
-@timed("image_exporter")
def export_image(exported_asset: ExportedAsset) -> None:
with push_scope() as scope:
- scope.set_tag("team_id", exported_asset.team if exported_asset else "unknown")
+ scope.set_tag("team_id", exported_asset.team.pk if exported_asset else "unknown")
scope.set_tag("asset_id", exported_asset.id if exported_asset else "unknown")
try:
@@ -184,9 +176,9 @@ def export_image(exported_asset: ExportedAsset) -> None:
synchronously_update_cache(exported_asset.insight, exported_asset.dashboard)
if exported_asset.export_format == "image/png":
- with IMAGE_EXPORT_TIMER.labels(team_id=exported_asset.team.id).time():
+ with EXPORT_TIMER.labels(type="image").time():
_export_to_png(exported_asset)
- IMAGE_EXPORT_SUCCEEDED_COUNTER.labels(team_id=exported_asset.team.id).inc()
+ EXPORT_SUCCEEDED_COUNTER.labels(type="image").inc()
else:
raise NotImplementedError(
f"Export to format {exported_asset.export_format} is not supported for insights"
@@ -197,8 +189,11 @@ def export_image(exported_asset: ExportedAsset) -> None:
else:
team_id = "unknown"
- capture_exception(e)
+ with push_scope() as scope:
+ scope.set_tag("celery_task", "image_export")
+ scope.set_tag("team_id", team_id)
+ capture_exception(e)
logger.error("image_exporter.failed", exception=e, exc_info=True)
- IMAGE_EXPORT_FAILED_COUNTER.labels(team_id=team_id).inc()
+ EXPORT_FAILED_COUNTER.labels(type="image").inc()
raise e
diff --git a/posthog/tasks/exports/test/test_csv_exporter.py b/posthog/tasks/exports/test/test_csv_exporter.py
index f7c866597e27b..0df5c1cf2759b 100644
--- a/posthog/tasks/exports/test/test_csv_exporter.py
+++ b/posthog/tasks/exports/test/test_csv_exporter.py
@@ -242,8 +242,7 @@ def test_csv_exporter_limits_breakdown_insights_correctly(
)
@patch("posthog.tasks.exports.csv_exporter.logger")
- @patch("posthog.tasks.exports.csv_exporter.statsd")
- def test_failing_export_api_is_reported(self, mock_statsd, mock_logger) -> None:
+ def test_failing_export_api_is_reported(self, _mock_logger: MagicMock) -> None:
with patch("posthog.tasks.exports.csv_exporter.requests.request") as patched_request:
exported_asset = self._create_asset()
mock_response = MagicMock()
diff --git a/posthog/tasks/test/__snapshots__/test_usage_report.ambr b/posthog/tasks/test/__snapshots__/test_usage_report.ambr
index ac94eb38d9841..74f71be82a5cc 100644
--- a/posthog/tasks/test/__snapshots__/test_usage_report.ambr
+++ b/posthog/tasks/test/__snapshots__/test_usage_report.ambr
@@ -233,6 +233,28 @@
GROUP BY team_id
'
---
+# name: TestFeatureFlagsUsageReport.test_usage_report_decide_requests.22
+ '
+
+ SELECT team_id,
+ COUNT() as count
+ FROM events
+ WHERE event = 'survey sent'
+ AND timestamp between '2022-01-10 00:00:00' AND '2022-01-10 23:59:59'
+ GROUP BY team_id
+ '
+---
+# name: TestFeatureFlagsUsageReport.test_usage_report_decide_requests.23
+ '
+
+ SELECT team_id,
+ COUNT() as count
+ FROM events
+ WHERE event = 'survey sent'
+ AND timestamp between '2022-01-01 00:00:00' AND '2022-01-10 23:59:59'
+ GROUP BY team_id
+ '
+---
# name: TestFeatureFlagsUsageReport.test_usage_report_decide_requests.3
'
diff --git a/posthog/tasks/test/test_usage_report.py b/posthog/tasks/test/test_usage_report.py
index 86ba59e07a8d8..491d50c0bb57a 100644
--- a/posthog/tasks/test/test_usage_report.py
+++ b/posthog/tasks/test/test_usage_report.py
@@ -93,6 +93,13 @@ def _create_sample_usage_data(self) -> None:
distinct_id = str(uuid4())
_create_person(distinct_ids=[distinct_id], team=self.org_1_team_1)
+ _create_event(
+ distinct_id=distinct_id,
+ event="survey sent",
+ timestamp=now() - relativedelta(hours=12),
+ team=self.org_1_team_1,
+ )
+
Dashboard.objects.create(team=self.org_1_team_1, name="Dash one", created_by=self.user)
dashboard = Dashboard.objects.create(
@@ -338,9 +345,9 @@ def _test_usage_report(self) -> List[dict]:
"plugins_installed": {"Installed and enabled": 1, "Installed but not enabled": 1},
"plugins_enabled": {"Installed and enabled": 1},
"instance_tag": "none",
- "event_count_lifetime": 54,
- "event_count_in_period": 22,
- "event_count_in_month": 42,
+ "event_count_lifetime": 55,
+ "event_count_in_period": 23,
+ "event_count_in_month": 43,
"event_count_with_groups_in_period": 2,
"recording_count_in_period": 5,
"recording_count_total": 16,
@@ -357,6 +364,8 @@ def _test_usage_report(self) -> List[dict]:
"local_evaluation_requests_count_in_period": 0,
"billable_feature_flag_requests_count_in_month": 0,
"billable_feature_flag_requests_count_in_period": 0,
+ "survey_responses_count_in_period": 1,
+ "survey_responses_count_in_month": 1,
"hogql_app_bytes_read": 0,
"hogql_app_rows_read": 0,
"hogql_app_duration_ms": 0,
@@ -377,9 +386,9 @@ def _test_usage_report(self) -> List[dict]:
"team_count": 2,
"teams": {
str(self.org_1_team_1.id): {
- "event_count_lifetime": 43,
- "event_count_in_period": 12,
- "event_count_in_month": 32,
+ "event_count_lifetime": 44,
+ "event_count_in_period": 13,
+ "event_count_in_month": 33,
"event_count_with_groups_in_period": 2,
"recording_count_in_period": 0,
"recording_count_total": 0,
@@ -396,6 +405,8 @@ def _test_usage_report(self) -> List[dict]:
"local_evaluation_requests_count_in_period": 0,
"billable_feature_flag_requests_count_in_month": 0,
"billable_feature_flag_requests_count_in_period": 0,
+ "survey_responses_count_in_period": 1,
+ "survey_responses_count_in_month": 1,
"hogql_app_bytes_read": 0,
"hogql_app_rows_read": 0,
"hogql_app_duration_ms": 0,
@@ -429,6 +440,8 @@ def _test_usage_report(self) -> List[dict]:
"local_evaluation_requests_count_in_period": 0,
"billable_feature_flag_requests_count_in_month": 0,
"billable_feature_flag_requests_count_in_period": 0,
+ "survey_responses_count_in_period": 0,
+ "survey_responses_count_in_month": 0,
"hogql_app_bytes_read": 0,
"hogql_app_rows_read": 0,
"hogql_app_duration_ms": 0,
@@ -482,6 +495,8 @@ def _test_usage_report(self) -> List[dict]:
"local_evaluation_requests_count_in_period": 0,
"billable_feature_flag_requests_count_in_month": 0,
"billable_feature_flag_requests_count_in_period": 0,
+ "survey_responses_count_in_period": 0,
+ "survey_responses_count_in_month": 0,
"hogql_app_bytes_read": 0,
"hogql_app_rows_read": 0,
"hogql_app_duration_ms": 0,
@@ -521,6 +536,8 @@ def _test_usage_report(self) -> List[dict]:
"local_evaluation_requests_count_in_period": 0,
"billable_feature_flag_requests_count_in_month": 0,
"billable_feature_flag_requests_count_in_period": 0,
+ "survey_responses_count_in_period": 0,
+ "survey_responses_count_in_month": 0,
"hogql_app_bytes_read": 0,
"hogql_app_rows_read": 0,
"hogql_app_duration_ms": 0,
@@ -834,6 +851,102 @@ def test_usage_report_local_evaluation_requests(
assert org_2_report["teams"]["5"]["billable_feature_flag_requests_count_in_month"] == 0
+@freeze_time("2022-01-10T00:01:00Z")
+class TestSurveysUsageReport(ClickhouseDestroyTablesMixin, TestCase, ClickhouseTestMixin):
+ def setUp(self) -> None:
+ Team.objects.all().delete()
+ return super().setUp()
+
+ def _setup_teams(self) -> None:
+ self.analytics_org = Organization.objects.create(name="PostHog")
+ self.org_1 = Organization.objects.create(name="Org 1")
+ self.org_2 = Organization.objects.create(name="Org 2")
+
+ self.analytics_team = Team.objects.create(pk=2, organization=self.analytics_org, name="Analytics")
+
+ self.org_1_team_1 = Team.objects.create(pk=3, organization=self.org_1, name="Team 1 org 1")
+ self.org_1_team_2 = Team.objects.create(pk=4, organization=self.org_1, name="Team 2 org 1")
+ self.org_2_team_3 = Team.objects.create(pk=5, organization=self.org_2, name="Team 3 org 2")
+
+ @patch("posthog.tasks.usage_report.Client")
+ @patch("posthog.tasks.usage_report.send_report_to_billing_service")
+ def test_usage_report_survey_responses(self, billing_task_mock: MagicMock, posthog_capture_mock: MagicMock) -> None:
+ self._setup_teams()
+ for i in range(10):
+ _create_event(
+ distinct_id="3",
+ event="survey sent",
+ properties={"$survey_id": "seeeep-o12-as124", "$survey_response": "correct"},
+ timestamp=now() - relativedelta(hours=i),
+ team=self.analytics_team,
+ )
+
+ for i in range(5):
+ _create_event(
+ distinct_id="4",
+ event="survey sent",
+ properties={"$survey_id": "see22eep-o12-as124", "$survey_response": "correct"},
+ timestamp=now() - relativedelta(hours=i),
+ team=self.org_1_team_1,
+ )
+ _create_event(
+ distinct_id="4",
+ event="survey sent",
+ properties={"count": 100, "token": "wrong"},
+ timestamp=now() - relativedelta(hours=i),
+ team=self.org_1_team_2,
+ )
+
+ for i in range(7):
+ _create_event(
+ distinct_id="5",
+ event="survey sent",
+ properties={"count": 100},
+ timestamp=now() - relativedelta(hours=i),
+ team=self.org_2_team_3,
+ )
+
+ # some out of range events
+ _create_event(
+ distinct_id="3",
+ event="survey sent",
+ properties={"count": 20000, "token": "correct"},
+ timestamp=now() - relativedelta(days=20),
+ team=self.analytics_team,
+ )
+ flush_persons_and_events()
+
+ period = get_previous_day(at=now() + relativedelta(days=1))
+ period_start, period_end = period
+ all_reports = _get_all_org_reports(period_start, period_end)
+
+ assert len(all_reports) == 3
+
+ org_1_report = _get_full_org_usage_report_as_dict(
+ _get_full_org_usage_report(all_reports[str(self.org_1.id)], get_instance_metadata(period))
+ )
+ assert org_1_report["organization_name"] == "Org 1"
+ org_2_report = _get_full_org_usage_report_as_dict(
+ _get_full_org_usage_report(all_reports[str(self.org_2.id)], get_instance_metadata(period))
+ )
+
+ assert org_1_report["organization_name"] == "Org 1"
+ assert org_1_report["survey_responses_count_in_period"] == 2
+ assert org_1_report["survey_responses_count_in_month"] == 10
+ assert org_1_report["teams"]["3"]["survey_responses_count_in_period"] == 1
+ assert org_1_report["teams"]["3"]["survey_responses_count_in_month"] == 5
+ assert org_1_report["teams"]["4"]["survey_responses_count_in_period"] == 1
+ assert org_1_report["teams"]["4"]["survey_responses_count_in_month"] == 5
+
+ assert org_2_report["organization_name"] == "Org 2"
+ assert org_2_report["decide_requests_count_in_period"] == 0
+ assert org_2_report["decide_requests_count_in_month"] == 0
+ assert org_2_report["survey_responses_count_in_period"] == 1
+ assert org_2_report["survey_responses_count_in_month"] == 7
+ assert org_2_report["teams"]["5"]["survey_responses_count_in_period"] == 1
+ assert org_2_report["teams"]["5"]["survey_responses_count_in_month"] == 7
+
+
class SendUsageTest(LicensedTestMixin, ClickhouseDestroyTablesMixin, APIBaseTest):
def setUp(self) -> None:
super().setUp()
diff --git a/posthog/tasks/usage_report.py b/posthog/tasks/usage_report.py
index 4627a95af6ab1..b9164dd6cf690 100644
--- a/posthog/tasks/usage_report.py
+++ b/posthog/tasks/usage_report.py
@@ -101,6 +101,9 @@ class UsageReportCounters:
event_explorer_api_bytes_read: int
event_explorer_api_rows_read: int
event_explorer_api_duration_ms: int
+ # Surveys
+ survey_responses_count_in_period: int
+ survey_responses_count_in_month: int
# Instance metadata to be included in oveall report
@@ -533,6 +536,28 @@ def get_teams_with_feature_flag_requests_count_in_period(
return result
+@timed_log()
+@retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF)
+def get_teams_with_survey_responses_count_in_period(
+ begin: datetime,
+ end: datetime,
+) -> List[Tuple[int, int]]:
+
+ results = sync_execute(
+ """
+ SELECT team_id, COUNT() as count
+ FROM events
+ WHERE event = 'survey sent' AND timestamp between %(begin)s AND %(end)s
+ GROUP BY team_id
+ """,
+ {"begin": begin, "end": end},
+ workload=Workload.OFFLINE,
+ settings=CH_BILLING_SETTINGS,
+ )
+
+ return results
+
+
@app.task(ignore_result=True, max_retries=0)
def capture_report(
capture_event_name: str, org_id: str, full_report_dict: Dict[str, Any], at_date: Optional[datetime] = None
@@ -556,6 +581,7 @@ def has_non_zero_usage(report: FullUsageReport) -> bool:
or report.recording_count_in_period > 0
or report.decide_requests_count_in_period > 0
or report.local_evaluation_requests_count_in_period > 0
+ or report.survey_responses_count_in_period > 0
)
@@ -716,6 +742,12 @@ def _get_all_usage_data(period_start: datetime, period_end: datetime) -> Dict[st
query_types=["EventsQuery"],
access_method="personal_api_key",
),
+ teams_with_survey_responses_count_in_period=get_teams_with_survey_responses_count_in_period(
+ period_start, period_end
+ ),
+ teams_with_survey_responses_count_in_month=get_teams_with_survey_responses_count_in_period(
+ period_start.replace(day=1), period_end
+ ),
)
@@ -784,6 +816,8 @@ def _get_team_report(all_data: Dict[str, Any], team: Team) -> UsageReportCounter
event_explorer_api_bytes_read=all_data["teams_with_event_explorer_api_bytes_read"].get(team.id, 0),
event_explorer_api_rows_read=all_data["teams_with_event_explorer_api_rows_read"].get(team.id, 0),
event_explorer_api_duration_ms=all_data["teams_with_event_explorer_api_duration_ms"].get(team.id, 0),
+ survey_responses_count_in_period=all_data["teams_with_survey_responses_count_in_period"].get(team.id, 0),
+ survey_responses_count_in_month=all_data["teams_with_survey_responses_count_in_month"].get(team.id, 0),
)
diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
index bb9e3db7bca61..e361ef436dfb1 100644
--- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
+++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
@@ -1,4 +1,5 @@
import asyncio
+import contextlib
import datetime as dt
import functools
import gzip
@@ -9,7 +10,7 @@
from unittest import mock
from uuid import uuid4
-import boto3
+import aioboto3
import botocore.exceptions
import brotli
import pytest
@@ -55,18 +56,20 @@
TEST_ROOT_BUCKET = "test-batch-exports"
-def check_valid_credentials() -> bool:
+async def check_valid_credentials() -> bool:
"""Check if there are valid AWS credentials in the environment."""
- sts = boto3.client("sts")
+ session = aioboto3.Session()
+ sts = await session.client("sts")
try:
- sts.get_caller_identity()
+ await sts.get_caller_identity()
except botocore.exceptions.ClientError:
return False
else:
return True
-create_test_client = functools.partial(boto3.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT)
+SESSION = aioboto3.Session()
+create_test_client = functools.partial(SESSION.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT)
@pytest.fixture
@@ -75,39 +78,38 @@ def bucket_name() -> str:
return f"{TEST_ROOT_BUCKET}-{str(uuid4())}"
-@pytest.fixture
-def s3_client(bucket_name):
+@pytest_asyncio.fixture
+async def s3_client(bucket_name):
"""Manage a testing S3 client to interact with a testing S3 bucket.
Yields the test S3 client after creating a testing S3 bucket. Upon resuming, we delete
the contents and the bucket itself.
"""
- s3_client = create_test_client(
+ async with create_test_client(
"s3",
aws_access_key_id="object_storage_root_user",
aws_secret_access_key="object_storage_root_password",
- )
+ ) as s3_client:
+ await s3_client.create_bucket(Bucket=bucket_name)
- s3_client.create_bucket(Bucket=bucket_name)
+ yield s3_client
- yield s3_client
+ response = await s3_client.list_objects_v2(Bucket=bucket_name)
- response = s3_client.list_objects_v2(Bucket=bucket_name)
+ if "Contents" in response:
+ for obj in response["Contents"]:
+ if "Key" in obj:
+ await s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"])
- if "Contents" in response:
- for obj in response["Contents"]:
- if "Key" in obj:
- s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"])
+ await s3_client.delete_bucket(Bucket=bucket_name)
- s3_client.delete_bucket(Bucket=bucket_name)
-
-def assert_events_in_s3(
+async def assert_events_in_s3(
s3_client, bucket_name, key_prefix, events, compression: str | None = None, exclude_events: list[str] | None = None
):
"""Assert provided events written to JSON in key_prefix in S3 bucket_name."""
# List the objects in the bucket with the prefix.
- objects = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)
+ objects = await s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix)
# Check that there is only one object.
assert len(objects.get("Contents", [])) == 1
@@ -115,8 +117,8 @@ def assert_events_in_s3(
# Get the object.
key = objects["Contents"][0].get("Key")
assert key
- object = s3_client.get_object(Bucket=bucket_name, Key=key)
- data = object["Body"].read()
+ s3_object = await s3_client.get_object(Bucket=bucket_name, Key=key)
+ data = await s3_object["Body"].read()
# Check that the data is correct.
match compression:
@@ -306,10 +308,12 @@ async def test_insert_into_s3_activity_puts_data_into_s3(
with override_settings(
BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2
): # 5MB, the minimum for Multipart uploads
- with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
+ with mock.patch(
+ "posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
+ ):
await activity_environment.run(insert_into_s3_activity, insert_inputs)
- assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
+ await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
@pytest.mark.django_db
@@ -436,7 +440,9 @@ async def test_s3_export_workflow_with_minio_bucket(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
- with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
+ with mock.patch(
+ "posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
+ ):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -452,7 +458,7 @@ async def test_s3_export_workflow_with_minio_bucket(
run = runs[0]
assert run.status == "Completed"
- assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
+ await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
@pytest.mark.skipif(
@@ -581,45 +587,46 @@ async def test_s3_export_workflow_with_s3_bucket(interval, compression, encrypti
**batch_export.destination.config,
)
- s3_client = boto3.client("s3")
+ async with aioboto3.Session().client("s3") as s3_client:
- def create_s3_client(*args, **kwargs):
- """Mock function to return an already initialized S3 client."""
- return s3_client
+ @contextlib.asynccontextmanager
+ async def create_s3_client(*args, **kwargs):
+ """Mock function to return an already initialized S3 client."""
+ yield s3_client
- async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
- async with Worker(
- activity_environment.client,
- task_queue=settings.TEMPORAL_TASK_QUEUE,
- workflows=[S3BatchExportWorkflow],
- activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
- workflow_runner=UnsandboxedWorkflowRunner(),
- ):
- with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_s3_client):
- await activity_environment.client.execute_workflow(
- S3BatchExportWorkflow.run,
- inputs,
- id=workflow_id,
- task_queue=settings.TEMPORAL_TASK_QUEUE,
- retry_policy=RetryPolicy(maximum_attempts=1),
- execution_timeout=dt.timedelta(seconds=10),
- )
+ async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
+ async with Worker(
+ activity_environment.client,
+ task_queue=settings.TEMPORAL_TASK_QUEUE,
+ workflows=[S3BatchExportWorkflow],
+ activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
+ workflow_runner=UnsandboxedWorkflowRunner(),
+ ):
+ with mock.patch(
+ "posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_s3_client
+ ):
+ await activity_environment.client.execute_workflow(
+ S3BatchExportWorkflow.run,
+ inputs,
+ id=workflow_id,
+ task_queue=settings.TEMPORAL_TASK_QUEUE,
+ retry_policy=RetryPolicy(maximum_attempts=1),
+ execution_timeout=dt.timedelta(seconds=10),
+ )
- runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
- assert len(runs) == 1
+ runs = await afetch_batch_export_runs(batch_export_id=batch_export.id)
+ assert len(runs) == 1
- run = runs[0]
- assert run.status == "Completed"
+ run = runs[0]
+ assert run.status == "Completed"
- assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
+ await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression, exclude_events)
@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip"])
-async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
- client: HttpClient, s3_client, bucket_name, compression
-):
+async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(s3_client, bucket_name, compression):
"""Test the full S3 workflow targetting a MinIO bucket.
The workflow should update the batch export run status to completed and produce the expected
@@ -700,7 +707,9 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
- with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
+ with mock.patch(
+ "posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
+ ):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -716,15 +725,15 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
run = runs[0]
assert run.status == "Completed"
- assert_events_in_s3(s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events, compression)
+ await assert_events_in_s3(
+ s3_client, bucket_name, prefix.format(year=2023, month="04", day="25"), events, compression
+ )
@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
-async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
- client: HttpClient, s3_client, bucket_name, compression
-):
+async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(s3_client, bucket_name, compression):
"""Test the full S3 workflow targetting a MinIO bucket.
In this scenario we assert that when inserted_at is NULL, we default to _timestamp.
@@ -818,7 +827,9 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
- with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
+ with mock.patch(
+ "posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
+ ):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -834,15 +845,13 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
run = runs[0]
assert run.status == "Completed"
- assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
+ await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
-async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
- client: HttpClient, s3_client, bucket_name, compression
-):
+async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(s3_client, bucket_name, compression):
"""Test the S3BatchExport Workflow utilizing a custom key prefix.
We will be asserting that exported events land in the appropiate S3 key according to the prefix.
@@ -921,7 +930,9 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
- with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
+ with mock.patch(
+ "posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
+ ):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -940,20 +951,18 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
expected_key_prefix = prefix.format(
table="events", year="2023", month="04", day="25", hour="14", minute="30", second="00"
)
- objects = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=expected_key_prefix)
+ objects = await s3_client.list_objects_v2(Bucket=bucket_name, Prefix=expected_key_prefix)
key = objects["Contents"][0].get("Key")
assert len(objects.get("Contents", [])) == 1
assert key.startswith(expected_key_prefix)
- assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events, compression)
+ await assert_events_in_s3(s3_client, bucket_name, expected_key_prefix, events, compression)
@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize("compression", [None, "gzip", "brotli"])
-async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
- client: HttpClient, s3_client, bucket_name, compression
-):
+async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(s3_client, bucket_name, compression):
"""Test that S3 Export Workflow end-to-end by using a local MinIO bucket instead of S3.
In this particular instance of the test, we assert no duplicates are exported to S3.
@@ -1065,7 +1074,9 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
activities=[create_export_run, insert_into_s3_activity, update_export_run_status],
workflow_runner=UnsandboxedWorkflowRunner(),
):
- with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
+ with mock.patch(
+ "posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
+ ):
await activity_environment.client.execute_workflow(
S3BatchExportWorkflow.run,
inputs,
@@ -1080,7 +1091,7 @@ async def test_s3_export_workflow_with_minio_bucket_produces_no_duplicates(
run = runs[0]
assert run.status == "Completed"
- assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
+ await assert_events_in_s3(s3_client, bucket_name, prefix, events, compression)
@pytest_asyncio.fixture
@@ -1537,9 +1548,11 @@ def assert_heartbeat_details(*details):
)
with override_settings(BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2):
- with mock.patch("posthog.temporal.workflows.s3_batch_export.boto3.client", side_effect=create_test_client):
+ with mock.patch(
+ "posthog.temporal.workflows.s3_batch_export.aioboto3.Session.client", side_effect=create_test_client
+ ):
await activity_environment.run(insert_into_s3_activity, insert_inputs)
# This checks that the assert_heartbeat_details function was actually called
assert current_part_number > 1
- assert_events_in_s3(s3_client, bucket_name, prefix, events, None, None)
+ await assert_events_in_s3(s3_client, bucket_name, prefix, events, None, None)
diff --git a/posthog/temporal/workflows/batch_exports.py b/posthog/temporal/workflows/batch_exports.py
index 7b9abbe100808..c451a8a036e87 100644
--- a/posthog/temporal/workflows/batch_exports.py
+++ b/posthog/temporal/workflows/batch_exports.py
@@ -295,6 +295,9 @@ def __exit__(self, exc, value, tb):
"""Context-manager protocol exit method."""
return self._file.__exit__(exc, value, tb)
+ def __iter__(self):
+ yield from self._file
+
@property
def brotli_compressor(self):
if self._brotli_compressor is None:
diff --git a/posthog/temporal/workflows/s3_batch_export.py b/posthog/temporal/workflows/s3_batch_export.py
index 4252614a0263f..14568c44541f2 100644
--- a/posthog/temporal/workflows/s3_batch_export.py
+++ b/posthog/temporal/workflows/s3_batch_export.py
@@ -1,11 +1,13 @@
import asyncio
+import contextlib
import datetime as dt
+import io
import json
import posixpath
import typing
from dataclasses import dataclass
-import boto3
+import aioboto3
from django.conf import settings
from temporalio import activity, exceptions, workflow
from temporalio.common import RetryPolicy
@@ -90,8 +92,20 @@ class S3MultiPartUploadState(typing.NamedTuple):
class S3MultiPartUpload:
"""An S3 multi-part upload."""
- def __init__(self, s3_client, bucket_name: str, key: str, encryption: str | None, kms_key_id: str | None):
- self.s3_client = s3_client
+ def __init__(
+ self,
+ region_name: str,
+ bucket_name: str,
+ key: str,
+ encryption: str | None,
+ kms_key_id: str | None,
+ aws_access_key_id: str | None = None,
+ aws_secret_access_key: str | None = None,
+ ):
+ self._session = aioboto3.Session()
+ self.region_name = region_name
+ self.aws_access_key_id = aws_access_key_id
+ self.aws_secret_access_key = aws_secret_access_key
self.bucket_name = bucket_name
self.key = key
self.encryption = encryption
@@ -118,7 +132,17 @@ def is_upload_in_progress(self) -> bool:
return False
return True
- def start(self) -> str:
+ @contextlib.asynccontextmanager
+ async def s3_client(self):
+ async with self._session.client(
+ "s3",
+ region_name=self.region_name,
+ aws_access_key_id=self.aws_access_key_id,
+ aws_secret_access_key=self.aws_secret_access_key,
+ ) as client:
+ yield client
+
+ async def start(self) -> str:
"""Start this S3MultiPartUpload."""
if self.is_upload_in_progress() is True:
raise UploadAlreadyInProgressError(self.upload_id)
@@ -129,11 +153,13 @@ def start(self) -> str:
if self.kms_key_id:
optional_kwargs["SSEKMSKeyId"] = self.kms_key_id
- multipart_response = self.s3_client.create_multipart_upload(
- Bucket=self.bucket_name,
- Key=self.key,
- **optional_kwargs,
- )
+ async with self.s3_client() as s3_client:
+ multipart_response = await s3_client.create_multipart_upload(
+ Bucket=self.bucket_name,
+ Key=self.key,
+ **optional_kwargs,
+ )
+
upload_id: str = multipart_response["UploadId"]
self.upload_id = upload_id
@@ -146,66 +172,72 @@ def continue_from_state(self, state: S3MultiPartUploadState):
return self.upload_id
- def complete(self) -> str:
+ async def complete(self) -> str:
if self.is_upload_in_progress() is False:
raise NoUploadInProgressError()
- response = self.s3_client.complete_multipart_upload(
- Bucket=self.bucket_name,
- Key=self.key,
- UploadId=self.upload_id,
- MultipartUpload={"Parts": self.parts},
- )
+ async with self.s3_client() as s3_client:
+ response = await s3_client.complete_multipart_upload(
+ Bucket=self.bucket_name, Key=self.key, UploadId=self.upload_id, MultipartUpload={"Parts": self.parts}
+ )
self.upload_id = None
self.parts = []
return response["Location"]
- def abort(self):
+ async def abort(self):
if self.is_upload_in_progress() is False:
raise NoUploadInProgressError()
- self.s3_client.abort_multipart_upload(
- Bucket=self.bucket_name,
- Key=self.key,
- UploadId=self.upload_id,
- )
+ async with self.s3_client() as s3_client:
+ await s3_client.abort_multipart_upload(
+ Bucket=self.bucket_name,
+ Key=self.key,
+ UploadId=self.upload_id,
+ )
self.upload_id = None
self.parts = []
- def upload_part(self, body: BatchExportTemporaryFile, rewind: bool = True):
+ async def upload_part(self, body: BatchExportTemporaryFile, rewind: bool = True):
next_part_number = self.part_number + 1
if rewind is True:
body.rewind()
- response = self.s3_client.upload_part(
- Bucket=self.bucket_name,
- Key=self.key,
- PartNumber=next_part_number,
- UploadId=self.upload_id,
- Body=body,
- )
+ # aiohttp is not duck-type friendly and requires an io.IOBase.
+ # BatchExportTemporaryFile complies with the file-like interface of io.IOBase,
+ # so we wrap it in a BufferedReader and tell mypy to be nice with us.
+ reader = io.BufferedReader(body) # type: ignore
+
+ async with self.s3_client() as s3_client:
+ response = await s3_client.upload_part(
+ Bucket=self.bucket_name,
+ Key=self.key,
+ PartNumber=next_part_number,
+ UploadId=self.upload_id,
+ Body=reader,
+ )
+ reader.detach() # BufferedReader closes the file otherwise.
self.parts.append({"PartNumber": next_part_number, "ETag": response["ETag"]})
- def __enter__(self):
+ async def __aenter__(self):
if not self.is_upload_in_progress():
- self.start()
+ await self.start()
return self
- def __exit__(self, exc_type, exc_value, traceback) -> bool:
+ async def __aexit__(self, exc_type, exc_value, traceback) -> bool:
if exc_value is None:
# Succesfully completed the upload
- self.complete()
+ await self.complete()
return True
if exc_type == asyncio.CancelledError:
# Ensure we clean-up the cancelled upload.
- self.abort()
+ await self.abort()
return False
@@ -249,17 +281,20 @@ class S3InsertInputs:
kms_key_id: str | None = None
-def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]:
+async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]:
"""Initialize a S3MultiPartUpload and resume it from a hearbeat state if available."""
logger = get_batch_exports_logger(inputs=inputs)
key = get_s3_key(inputs)
- s3_client = boto3.client(
- "s3",
+
+ s3_upload = S3MultiPartUpload(
+ bucket_name=inputs.bucket_name,
+ key=key,
+ encryption=inputs.encryption,
+ kms_key_id=inputs.kms_key_id,
region_name=inputs.region,
aws_access_key_id=inputs.aws_access_key_id,
aws_secret_access_key=inputs.aws_secret_access_key,
)
- s3_upload = S3MultiPartUpload(s3_client, inputs.bucket_name, key, inputs.encryption, inputs.kms_key_id)
details = activity.info().heartbeat_details
@@ -291,7 +326,7 @@ def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3Mu
logger.info(
f"Export will start from the beginning as we are using brotli compression: {interval_start}",
)
- s3_upload.abort()
+ await s3_upload.abort()
return s3_upload, interval_start
@@ -335,7 +370,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
logger.info("BatchExporting %s rows to S3", count)
- s3_upload, interval_start = initialize_and_resume_multipart_upload(inputs)
+ s3_upload, interval_start = await initialize_and_resume_multipart_upload(inputs)
# Iterate through chunks of results from ClickHouse and push them to S3
# as a multipart upload. The intention here is to keep memory usage low,
@@ -364,7 +399,7 @@ async def worker_shutdown_handler():
asyncio.create_task(worker_shutdown_handler())
- with s3_upload as s3_upload:
+ async with s3_upload as s3_upload:
with BatchExportTemporaryFile(compression=inputs.compression) as local_results_file:
for result in results_iterator:
record = {
@@ -390,7 +425,7 @@ async def worker_shutdown_handler():
local_results_file.bytes_since_last_reset,
)
- s3_upload.upload_part(local_results_file)
+ await s3_upload.upload_part(local_results_file)
last_uploaded_part_timestamp = result["inserted_at"]
activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state())
@@ -405,7 +440,7 @@ async def worker_shutdown_handler():
local_results_file.bytes_since_last_reset,
)
- s3_upload.upload_part(local_results_file)
+ await s3_upload.upload_part(local_results_file)
last_uploaded_part_timestamp = result["inserted_at"]
activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state())
diff --git a/posthog/types.py b/posthog/types.py
index b9fd4bdfe3d55..f768cf8833a85 100644
--- a/posthog/types.py
+++ b/posthog/types.py
@@ -4,7 +4,7 @@
from posthog.models.filters.path_filter import PathFilter
from posthog.models.filters.retention_filter import RetentionFilter
from posthog.models.filters.stickiness_filter import StickinessFilter
-from posthog.schema import FunnelsQuery, LifecycleQuery, PathsQuery, RetentionQuery, StickinessQuery, TrendsQuery
+from posthog.schema import TrendsQuery, FunnelsQuery, RetentionQuery, PathsQuery, StickinessQuery, LifecycleQuery
FilterType = Union[Filter, PathFilter, RetentionFilter, StickinessFilter]
diff --git a/posthog/utils.py b/posthog/utils.py
index d9a9cda7d3439..4de5563bc0d0b 100644
--- a/posthog/utils.py
+++ b/posthog/utils.py
@@ -1,3 +1,4 @@
+import asyncio
import base64
import dataclasses
import datetime
@@ -34,6 +35,8 @@
import posthoganalytics
import pytz
import structlog
+from asgiref.sync import async_to_sync
+from celery.result import AsyncResult
from celery.schedules import crontab
from dateutil import parser
from dateutil.relativedelta import relativedelta
@@ -1182,7 +1185,24 @@ def get_week_start_for_country_code(country_code: str) -> int:
return 1 # Monday
-def wait_for_parallel_celery_group(task: Any, max_timeout: Optional[datetime.timedelta] = None) -> Any:
+def sleep_time_generator() -> Generator[float, None, None]:
+ # a generator that yields a stepped, roughly exponential backoff from 0.1 up to 3 seconds
+ for _ in range(10):
+ yield 0.1 # 1 second in total
+ for _ in range(5):
+ yield 0.2 # 1 second in total
+ for _ in range(5):
+ yield 0.4 # 2 seconds in total
+ for _ in range(5):
+ yield 0.8 # 4 seconds in total
+ for _ in range(10):
+ yield 1.5 # 15 seconds in total
+ while True:
+ yield 3.0
+
+
+@async_to_sync
+async def wait_for_parallel_celery_group(task: Any, max_timeout: Optional[datetime.timedelta] = None) -> Any:
"""
Wait for a group of celery tasks to finish, but don't wait longer than max_timeout.
For parallel tasks, this is the only way to await the entire group.
@@ -1192,10 +1212,36 @@ def wait_for_parallel_celery_group(task: Any, max_timeout: Optional[datetime.tim
start_time = timezone.now()
+ sleep_generator = sleep_time_generator()
+
while not task.ready():
if timezone.now() - start_time > max_timeout:
+ child_states = []
+ child: AsyncResult
+ for child in task.children:
+ child_states.append(child.state)
+ # this child should not be retried...
+ if child.state in ["PENDING", "STARTED"]:
+ # terminating here terminates the process not the task
+ # but if the task is in PENDING or STARTED after 10 minutes
+ # we have to assume the celery process isn't processing another task
+ # see: https://docs.celeryq.dev/en/stable/userguide/workers.html#revoke-revoking-tasks
+ # and: https://docs.celeryq.dev/en/latest/reference/celery.result.html
+ # we terminate the process to avoid leaking an instance of Chrome
+ child.revoke(terminate=True)
+
+ logger.error(
+ "Timed out waiting for celery task to finish",
+ ready=task.ready(),
+ successful=task.successful(),
+ failed=task.failed(),
+ child_states=child_states,
+ timeout=max_timeout,
+ start_time=start_time,
+ )
raise TimeoutError("Timed out waiting for celery task to finish")
- time.sleep(0.1)
+
+ await asyncio.sleep(next(sleep_generator))
return task
diff --git a/posthog/utils_cors.py b/posthog/utils_cors.py
index 88db42c9df77b..0c4f6cb52765a 100644
--- a/posthog/utils_cors.py
+++ b/posthog/utils_cors.py
@@ -6,6 +6,8 @@
"request-context",
"x-amzn-trace-id",
"x-cloud-trace-context",
+ "Sentry-Trace",
+ "Baggage",
)
diff --git a/requirements.in b/requirements.in
index 62f9b7a8a6f1c..673f7e045a34e 100644
--- a/requirements.in
+++ b/requirements.in
@@ -5,9 +5,10 @@
# - `pip-compile --rebuild requirements-dev.in`
#
aiohttp>=3.8.4
+aioboto3==11.1
antlr4-python3-runtime==4.13.0
amqp==5.1.1
-boto3==1.26.66
+boto3==1.26.76
boto3-stubs[s3]
brotli==1.1.0
celery==5.3.4
diff --git a/requirements.txt b/requirements.txt
index 403e3f597cf12..beb9261497bb4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,11 +4,18 @@
#
# pip-compile requirements.in
#
+aioboto3==11.1
+ # via -r requirements.in
+aiobotocore[boto3]==2.5.0
+ # via aioboto3
aiohttp==3.8.5
# via
# -r requirements.in
+ # aiobotocore
# geoip2
# openai
+aioitertools==0.11.0
+ # via aiobotocore
aiosignal==1.2.0
# via aiohttp
amqp==5.1.1
@@ -43,12 +50,15 @@ backoff==2.2.1
# via posthoganalytics
billiard==4.1.0
# via celery
-boto3==1.26.66
- # via -r requirements.in
+boto3==1.26.76
+ # via
+ # -r requirements.in
+ # aiobotocore
boto3-stubs[s3]==1.26.138
# via -r requirements.in
-botocore==1.29.66
+botocore==1.29.76
# via
+ # aiobotocore
# boto3
# s3transfer
botocore-stubs==1.29.130
@@ -539,6 +549,8 @@ webdriver-manager==4.0.1
# via -r requirements.in
whitenoise==6.5.0
# via -r requirements.in
+wrapt==1.15.0
+ # via aiobotocore
wsproto==1.1.0
# via trio-websocket
xmlsec==1.3.13