diff --git a/.github/workflows/ci-plugin-server.yml b/.github/workflows/ci-plugin-server.yml index a62bd4a66851a..55b071a49b02a 100644 --- a/.github/workflows/ci-plugin-server.yml +++ b/.github/workflows/ci-plugin-server.yml @@ -44,7 +44,7 @@ jobs: - 'plugin-server/**' - 'posthog/clickhouse/migrations/**' - 'ee/migrations/**' - - 'ee/management/commands/setup_test_environment.py' + - 'posthog/management/commands/setup_test_environment.py' - 'posthog/migrations/**' - 'posthog/plugins/**' - 'docker*.yml' diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a6ab35daa8d3b..3624e6c028c0f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -34,6 +34,9 @@ jobs: - '.github/workflows/rust.yml' - '.github/workflows/rust-docker-build.yml' - '.github/workflows/rust-hook-migrator-docker.yml' + - 'posthog/management/commands/setup_test_environment.py' + - 'posthog/migrations/**' + - 'ee/migrations/**' build: name: Build rust services @@ -73,6 +76,11 @@ jobs: test: name: Test rust services + strategy: + matrix: + package: + - feature-flags + - others needs: changes runs-on: depot-ubuntu-22.04-4 timeout-minutes: 10 @@ -86,11 +94,15 @@ jobs: # Use sparse checkout to only select files in rust directory # Turning off cone mode ensures that files in the project root are not included during checkout - uses: actions/checkout@v3 - if: needs.changes.outputs.rust == 'true' + if: needs.changes.outputs.rust == 'true' && matrix.package == 'others' with: sparse-checkout: 'rust/' sparse-checkout-cone-mode: false + # For flags checkout entire repository + - uses: actions/checkout@v3 + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + - name: Login to DockerHub if: needs.changes.outputs.rust == 'true' uses: docker/login-action@v2 @@ -99,8 +111,15 @@ jobs: username: posthog password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Setup main repo dependencies for flags + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + run: | + docker compose -f ../docker-compose.dev.yml down + docker compose -f ../docker-compose.dev.yml up -d + echo "127.0.0.1 kafka" | sudo tee -a /etc/hosts + - name: Setup dependencies - if: needs.changes.outputs.rust == 'true' + if: needs.changes.outputs.rust == 'true' && matrix.package == 'others' run: | docker compose up kafka redis db echo_server -d --wait docker compose up setup_test_db @@ -119,9 +138,46 @@ jobs: rust/target key: ${ runner.os }-cargo-debug-${{ hashFiles('**/Cargo.lock') }} + - name: Set up Python + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + uses: actions/setup-python@v5 + with: + python-version: 3.11.9 + cache: 'pip' + cache-dependency-path: '**/requirements*.txt' + token: ${{ secrets.POSTHOG_BOT_GITHUB_TOKEN }} + + # uv is a fast pip alternative: https://github.com/astral-sh/uv/ + - run: pip install uv + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + + - name: Install SAML (python3-saml) dependencies + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + run: | + sudo apt-get update + sudo apt-get install libxml2-dev libxmlsec1-dev libxmlsec1-openssl + + - name: Install python dependencies + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + run: | + uv pip install --system -r ../requirements-dev.txt + uv pip install --system -r ../requirements.txt + + - name: Set up databases + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + 
env: + DEBUG: 'true' + TEST: 'true' + SECRET_KEY: 'abcdef' # unsafe - for testing only + DATABASE_URL: 'postgres://posthog:posthog@localhost:5432/posthog' + run: cd ../ && python manage.py setup_test_environment --only-postgres + - name: Run cargo test if: needs.changes.outputs.rust == 'true' - run: cargo test --all-features + run: | + echo "Starting cargo test" + cargo test --all-features ${{ matrix.package == 'feature-flags' && '--package feature-flags' || '--workspace --exclude feature-flags' }} + echo "Cargo test completed" linting: name: Lint rust services diff --git a/.storybook/test-runner.ts b/.storybook/test-runner.ts index 24cb830f76ddd..026a3f786ada3 100644 --- a/.storybook/test-runner.ts +++ b/.storybook/test-runner.ts @@ -161,9 +161,8 @@ async function expectStoryToMatchSnapshot( }) // Wait for all images to load - await page.waitForFunction(() => Array.from(document.images).every((i: HTMLImageElement) => i.complete)) await waitForPageReady(page) - await page.waitForLoadState('networkidle') + await page.waitForFunction(() => Array.from(document.images).every((i: HTMLImageElement) => !!i.naturalWidth)) await page.waitForTimeout(2000) await check(page, context, browser, 'light', storyContext.parameters?.testOptions?.snapshotTargetSelector) @@ -174,9 +173,8 @@ async function expectStoryToMatchSnapshot( }) // Wait for all images to load - await page.waitForFunction(() => Array.from(document.images).every((i: HTMLImageElement) => i.complete)) await waitForPageReady(page) - await page.waitForLoadState('networkidle') + await page.waitForFunction(() => Array.from(document.images).every((i: HTMLImageElement) => !!i.naturalWidth)) await page.waitForTimeout(100) await check(page, context, browser, 'dark', storyContext.parameters?.testOptions?.snapshotTargetSelector) diff --git a/frontend/__snapshots__/components-cards-insight-details--trends-world-map--dark.png b/frontend/__snapshots__/components-cards-insight-details--trends-world-map--dark.png index 15b60cc25575c..6b08cf6fd3d08 100644 Binary files a/frontend/__snapshots__/components-cards-insight-details--trends-world-map--dark.png and b/frontend/__snapshots__/components-cards-insight-details--trends-world-map--dark.png differ diff --git a/frontend/__snapshots__/components-cards-insight-details--trends-world-map--light.png b/frontend/__snapshots__/components-cards-insight-details--trends-world-map--light.png index 06e1bfb343f04..975a66aad5e55 100644 Binary files a/frontend/__snapshots__/components-cards-insight-details--trends-world-map--light.png and b/frontend/__snapshots__/components-cards-insight-details--trends-world-map--light.png differ diff --git a/frontend/__snapshots__/scenes-app-dashboards--insight-legend--dark.png b/frontend/__snapshots__/scenes-app-dashboards--insight-legend--dark.png index d7d9398cd00cc..9d811913930ca 100644 Binary files a/frontend/__snapshots__/scenes-app-dashboards--insight-legend--dark.png and b/frontend/__snapshots__/scenes-app-dashboards--insight-legend--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-errortracking--group-page--dark.png b/frontend/__snapshots__/scenes-app-errortracking--group-page--dark.png index 4a2a4bfe03265..6ced54136d89e 100644 Binary files a/frontend/__snapshots__/scenes-app-errortracking--group-page--dark.png and b/frontend/__snapshots__/scenes-app-errortracking--group-page--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-errortracking--group-page--light.png b/frontend/__snapshots__/scenes-app-errortracking--group-page--light.png index 
e771dbfc37400..f84c4b3332880 100644 Binary files a/frontend/__snapshots__/scenes-app-errortracking--group-page--light.png and b/frontend/__snapshots__/scenes-app-errortracking--group-page--light.png differ diff --git a/frontend/__snapshots__/scenes-app-errortracking--list-page--dark.png b/frontend/__snapshots__/scenes-app-errortracking--list-page--dark.png index 3108b395d78fd..fda158a90c73f 100644 Binary files a/frontend/__snapshots__/scenes-app-errortracking--list-page--dark.png and b/frontend/__snapshots__/scenes-app-errortracking--list-page--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-errortracking--list-page--light.png b/frontend/__snapshots__/scenes-app-errortracking--list-page--light.png index 3a44c91f3b344..1070b31f74596 100644 Binary files a/frontend/__snapshots__/scenes-app-errortracking--list-page--light.png and b/frontend/__snapshots__/scenes-app-errortracking--list-page--light.png differ diff --git a/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--dark.png b/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--dark.png index 1912f55d0ffa8..d83ed27fe8dca 100644 Binary files a/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--dark.png and b/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--light.png b/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--light.png index 4ca21748c10e4..01df2b71a3a80 100644 Binary files a/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--light.png and b/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--light.png differ diff --git a/frontend/public/services/mysql.png b/frontend/public/services/mysql.png new file mode 100644 index 0000000000000..923f456fd3fbd Binary files /dev/null and b/frontend/public/services/mysql.png differ diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 648ea4aba4478..6bb3b9159b040 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -8,7 +8,13 @@ import posthog from 'posthog-js' import { SavedSessionRecordingPlaylistsResult } from 'scenes/session-recordings/saved-playlists/savedSessionRecordingPlaylistsLogic' import { getCurrentExporterData } from '~/exporter/exporterViewLogic' -import { DatabaseSerializedFieldType, QuerySchema, QueryStatusResponse, RefreshType } from '~/queries/schema' +import { + DatabaseSerializedFieldType, + ErrorTrackingGroup, + QuerySchema, + QueryStatusResponse, + RefreshType, +} from '~/queries/schema' import { ActionType, ActivityScope, @@ -658,6 +664,15 @@ class ApiRequest { return this.surveys(teamId).addPathComponent('activity') } + // Error tracking + public errorTracking(teamId?: TeamType['id']): ApiRequest { + return this.projectsDetail(teamId).addPathComponent('error_tracking') + } + + public errorTrackingGroup(fingerprint: ErrorTrackingGroup['fingerprint'], teamId?: TeamType['id']): ApiRequest { + return this.errorTracking(teamId).addPathComponent(fingerprint) + } + // # Warehouse public dataWarehouseTables(teamId?: TeamType['id']): ApiRequest { return this.projectsDetail(teamId).addPathComponent('warehouse_tables') @@ -1698,6 +1713,15 @@ const api = { }, }, + errorTracking: { + async update( + fingerprint: ErrorTrackingGroup['fingerprint'], + data: Partial> + ): Promise { + return await new ApiRequest().errorTrackingGroup(fingerprint).update({ data }) + }, + }, + recordings: { async list(params: 
Record): Promise { return await new ApiRequest().recordings().withQueryString(toParams(params)).get() diff --git a/frontend/src/lib/components/ActivityLog/activityLogLogic.insight.test.tsx b/frontend/src/lib/components/ActivityLog/activityLogLogic.insight.test.tsx index fc98560ef7f20..c86e548e0b8ba 100644 --- a/frontend/src/lib/components/ActivityLog/activityLogLogic.insight.test.tsx +++ b/frontend/src/lib/components/ActivityLog/activityLogLogic.insight.test.tsx @@ -4,6 +4,7 @@ import { render } from '@testing-library/react' import { MOCK_TEAM_ID } from 'lib/api.mock' import { makeTestSetup } from 'lib/components/ActivityLog/activityLogLogic.test.setup' +import { BreakdownFilter } from '~/queries/schema' import { ActivityScope } from '~/types' jest.mock('lib/colors') @@ -84,81 +85,104 @@ describe('the activity log logic', () => { }) it('can handle change of insight query', async () => { - const logic = await insightTestSetup('test insight', 'updated', [ - { - type: ActivityScope.INSIGHT, - action: 'changed', - field: 'query', - after: { - kind: 'TrendsQuery', - properties: { - type: 'AND', - values: [ - { - type: 'OR', - values: [ - { - type: 'event', - key: '$current_url', - operator: 'exact', - value: ['https://hedgebox.net/files/'], - }, - { - type: 'event', - key: '$geoip_country_code', - operator: 'exact', - value: ['US', 'AU'], - }, - ], - }, - ], - }, - filterTestAccounts: false, - interval: 'day', - dateRange: { - date_from: '-7d', - }, - series: [ + const insightMock = { + type: ActivityScope.INSIGHT, + action: 'changed', + field: 'query', + after: { + kind: 'TrendsQuery', + properties: { + type: 'AND', + values: [ { - kind: 'EventsNode', - name: '$pageview', - custom_name: 'Views', - event: '$pageview', - properties: [ + type: 'OR', + values: [ { type: 'event', - key: '$browser', + key: '$current_url', operator: 'exact', - value: 'Chrome', + value: ['https://hedgebox.net/files/'], }, { - type: 'cohort', - key: 'id', - value: 2, + type: 'event', + key: '$geoip_country_code', + operator: 'exact', + value: ['US', 'AU'], }, ], - limit: 100, }, ], - trendsFilter: { - display: 'ActionsAreaGraph', - }, - breakdownFilter: { - breakdown: '$geoip_country_code', - breakdown_type: 'event', + }, + filterTestAccounts: false, + interval: 'day', + dateRange: { + date_from: '-7d', + }, + series: [ + { + kind: 'EventsNode', + name: '$pageview', + custom_name: 'Views', + event: '$pageview', + properties: [ + { + type: 'event', + key: '$browser', + operator: 'exact', + value: 'Chrome', + }, + { + type: 'cohort', + key: 'id', + value: 2, + }, + ], + limit: 100, }, + ], + trendsFilter: { + display: 'ActionsAreaGraph', + }, + breakdownFilter: { + breakdown: '$geoip_country_code', + breakdown_type: 'event', }, }, - ]) - const actual = logic.values.humanizedActivity + } - const renderedDescription = render(<>{actual[0].description}).container + let logic = await insightTestSetup('test insight', 'updated', [insightMock as any]) + let actual = logic.values.humanizedActivity + + let renderedDescription = render(<>{actual[0].description}).container expect(renderedDescription).toHaveTextContent('peter changed query definition on test insight') - const renderedExtendedDescription = render(<>{actual[0].extendedDescription}).container + let renderedExtendedDescription = render(<>{actual[0].extendedDescription}).container expect(renderedExtendedDescription).toHaveTextContent( "Query summaryAShowing \"Views\"Pageviewcounted by total countwhere event'sBrowser= equals Chromeand person belongs to cohortID 
2FiltersEvent'sCurrent URL= equals https://hedgebox.net/files/or event'sCountry Code= equals US or AUBreakdown byCountry Code" ) + ;(insightMock.after.breakdownFilter as BreakdownFilter) = { + breakdowns: [ + { + property: '$geoip_country_code', + type: 'event', + }, + { + property: '$session_duration', + type: 'session', + }, + ], + } + + logic = await insightTestSetup('test insight', 'updated', [insightMock as any]) + actual = logic.values.humanizedActivity + + renderedDescription = render(<>{actual[0].description}).container + expect(renderedDescription).toHaveTextContent('peter changed query definition on test insight') + + renderedExtendedDescription = render(<>{actual[0].extendedDescription}).container + expect(renderedExtendedDescription).toHaveTextContent( + "Query summaryAShowing \"Views\"Pageviewcounted by total countwhere event'sBrowser= equals Chromeand person belongs to cohortID 2FiltersEvent'sCurrent URL= equals https://hedgebox.net/files/or event'sCountry Code= equals US or AUBreakdown byCountry CodeSession duration" + ) }) it('can handle change of filters on a retention graph', async () => { diff --git a/frontend/src/lib/components/ActivityLog/complex.sql b/frontend/src/lib/components/ActivityLog/complex.sql new file mode 100644 index 0000000000000..377778eff333d --- /dev/null +++ b/frontend/src/lib/components/ActivityLog/complex.sql @@ -0,0 +1,27 @@ +SELECT + count() AS total, + toStartOfDay(min_timestamp) AS day_start, + breakdown_value AS breakdown_value +FROM + (SELECT + min(timestamp) AS min_timestamp, + argMin(breakdown_value, timestamp) AS breakdown_value + FROM + (SELECT + person_id, + timestamp, + ifNull(nullIf(toString(properties.$browser), ''), '$$_posthog_breakdown_null_$$') AS breakdown_value + FROM + events AS e SAMPLE 1 + WHERE + and(equals(event, '$pageview'), lessOrEquals(timestamp, assumeNotNull(toDateTime('2025-01-20 23:59:59')))) + ) + GROUP BY + person_id + ) +WHERE + greaterOrEquals(min_timestamp, toStartOfDay(assumeNotNull(toDateTime('2020-01-09 00:00:00')))) +GROUP BY + day_start, + breakdown_value +LIMIT 50000 diff --git a/frontend/src/lib/components/ActivityLog/full_query.sql b/frontend/src/lib/components/ActivityLog/full_query.sql new file mode 100644 index 0000000000000..b7129f83c64f4 --- /dev/null +++ b/frontend/src/lib/components/ActivityLog/full_query.sql @@ -0,0 +1,41 @@ +SELECT + arrayMap(number -> plus(toStartOfDay(assumeNotNull(toDateTime('2024-07-16 00:00:00'))), toIntervalDay(number)), range(0, plus(coalesce(dateDiff('day', toStartOfDay(assumeNotNull(toDateTime('2024-07-16 00:00:00'))), toStartOfDay(assumeNotNull(toDateTime('2024-07-23 23:59:59'))))), 1))) AS date, + arrayMap(_match_date -> arraySum(arraySlice(groupArray(count), indexOf(groupArray(day_start) AS _days_for_count, _match_date) AS _index, plus(minus(arrayLastIndex(x -> equals(x, _match_date), _days_for_count), _index), 1))), date) AS total +FROM + (SELECT + sum(total) AS count, + day_start + FROM (SELECT + count() AS total, + day_start, + breakdown_value + FROM ( + SELECT + min(timestamp) as day_start, + argMin(breakdown_value, timestamp) AS breakdown_value, + FROM + ( + SELECT + person_id, + toStartOfDay(timestamp) AS timestamp, + ifNull(nullIf(toString(person.properties.email), ''), '$$_posthog_breakdown_null_$$') AS breakdown_value + FROM + events AS e SAMPLE 1 + WHERE + and(lessOrEquals(timestamp, assumeNotNull(toDateTime('2024-07-23 23:59:59'))), equals(properties.$browser, 'Safari')) + ) + WHERE + greaterOrEquals(timestamp, 
toStartOfDay(assumeNotNull(toDateTime('2024-07-16 00:00:00')))) + GROUP BY + person_id + ) + GROUP BY + day_start, + breakdown_value) + GROUP BY + day_start + ORDER BY + day_start ASC) +ORDER BY + arraySum(total) DESC +LIMIT 50000 diff --git a/frontend/src/lib/components/ActivityLog/weekly.sql b/frontend/src/lib/components/ActivityLog/weekly.sql new file mode 100644 index 0000000000000..579249842508c --- /dev/null +++ b/frontend/src/lib/components/ActivityLog/weekly.sql @@ -0,0 +1,45 @@ +SELECT + arrayMap(number -> plus(toStartOfDay(assumeNotNull(toDateTime('2024-07-16 00:00:00'))), toIntervalDay(number)), range(0, plus(coalesce(dateDiff('day', toStartOfDay(assumeNotNull(toDateTime('2024-07-16 00:00:00'))), toStartOfDay(assumeNotNull(toDateTime('2024-07-23 23:59:59'))))), 1))) AS date, + arrayMap(_match_date -> arraySum(arraySlice(groupArray(count), indexOf(groupArray(day_start) AS _days_for_count, _match_date) AS _index, plus(minus(arrayLastIndex(x -> equals(x, _match_date), _days_for_count), _index), 1))), date) AS total +FROM + (SELECT + sum(total) AS count, + day_start + FROM + (SELECT + counts AS total, + toStartOfDay(timestamp) AS day_start + FROM + (SELECT + d.timestamp, + count(DISTINCT actor_id) AS counts + FROM + (SELECT + minus(toStartOfDay(assumeNotNull(toDateTime('2024-07-23 23:59:59'))), toIntervalDay(number)) AS timestamp + FROM + numbers(dateDiff('day', minus(toStartOfDay(assumeNotNull(toDateTime('2024-07-16 00:00:00'))), toIntervalDay(7)), assumeNotNull(toDateTime('2024-07-23 23:59:59')))) AS numbers) AS d + CROSS JOIN (SELECT + timestamp AS timestamp, + e.person_id AS actor_id + FROM + events AS e SAMPLE 1 + WHERE + and(equals(event, '$pageview'), greaterOrEquals(timestamp, minus(assumeNotNull(toDateTime('2024-07-16 00:00:00')), toIntervalDay(7))), lessOrEquals(timestamp, assumeNotNull(toDateTime('2024-07-23 23:59:59')))) + GROUP BY + timestamp, + actor_id) AS e + WHERE + and(lessOrEquals(e.timestamp, plus(d.timestamp, toIntervalDay(1))), greater(e.timestamp, minus(d.timestamp, toIntervalDay(6)))) + GROUP BY + d.timestamp + ORDER BY + d.timestamp ASC) + WHERE + and(greaterOrEquals(timestamp, toStartOfDay(assumeNotNull(toDateTime('2024-07-16 00:00:00')))), lessOrEquals(timestamp, assumeNotNull(toDateTime('2024-07-23 23:59:59'))))) + GROUP BY + day_start + ORDER BY + day_start ASC) +ORDER BY + arraySum(total) DESC +LIMIT 50000 diff --git a/frontend/src/lib/components/Cards/InsightCard/InsightDetails.tsx b/frontend/src/lib/components/Cards/InsightCard/InsightDetails.tsx index 3c5d5636e29d1..9a0038556dea6 100644 --- a/frontend/src/lib/components/Cards/InsightCard/InsightDetails.tsx +++ b/frontend/src/lib/components/Cards/InsightCard/InsightDetails.tsx @@ -39,6 +39,7 @@ import { isLifecycleQuery, isPathsQuery, isTrendsQuery, + isValidBreakdown, } from '~/queries/utils' import { AnyPropertyFilter, @@ -155,11 +156,7 @@ function SeriesDisplay({ const { mathDefinitions } = useValues(mathsLogic) const filter = query.series[seriesIndex] - const hasBreakdown = - isInsightQueryWithBreakdown(query) && - query.breakdownFilter != null && - query.breakdownFilter.breakdown_type != null && - query.breakdownFilter.breakdown != null + const hasBreakdown = isInsightQueryWithBreakdown(query) && isValidBreakdown(query.breakdownFilter) const mathDefinition = mathDefinitions[ isLifecycleQuery(query) @@ -338,25 +335,26 @@ export function LEGACY_FilterBasedBreakdownSummary({ filters }: { filters: Parti } export function BreakdownSummary({ query }: { query: InsightQueryNode }): JSX.Element | null { - 
if ( - !isInsightQueryWithBreakdown(query) || - !query.breakdownFilter || - query.breakdownFilter.breakdown_type == null || - query.breakdownFilter.breakdown == null - ) { + if (!isInsightQueryWithBreakdown(query) || !isValidBreakdown(query.breakdownFilter)) { return null } - const { breakdown_type, breakdown } = query.breakdownFilter - const breakdownArray = Array.isArray(breakdown) ? breakdown : [breakdown] + const { breakdown_type, breakdown, breakdowns } = query.breakdownFilter return ( <>
Breakdown by
- {breakdownArray.map((b) => ( - - ))} + {Array.isArray(breakdowns) + ? breakdowns.map((b) => ( + + )) + : breakdown && + (Array.isArray(breakdown) + ? breakdown + : [breakdown].map((b) => ( + + )))}
) diff --git a/frontend/src/lib/components/DatabaseTableTree/TreeRow.tsx b/frontend/src/lib/components/DatabaseTableTree/TreeRow.tsx index 099b77ff9d89f..d50c84c42875d 100644 --- a/frontend/src/lib/components/DatabaseTableTree/TreeRow.tsx +++ b/frontend/src/lib/components/DatabaseTableTree/TreeRow.tsx @@ -1,5 +1,6 @@ import { IconChevronDown, IconEllipsis } from '@posthog/icons' import { LemonButton, Spinner } from '@posthog/lemon-ui' +import { copyToClipboard } from 'lib/utils/copyToClipboard' import { useCallback, useState } from 'react' import { DatabaseSchemaTable } from '~/queries/schema' @@ -13,10 +14,17 @@ export interface TreeRowProps { selected?: boolean } -export function TreeRow({ item, selected }: TreeRowProps): JSX.Element { +export function TreeRow({ item }: TreeRowProps): JSX.Element { return (
  • - {item.icon} : null}> + { + void copyToClipboard(item.name, item.name) + }} + size="xsmall" + fullWidth + icon={item.icon ? <>{item.icon} : null} + > {item.name} {item.type} diff --git a/frontend/src/lib/components/MemberSelect.tsx b/frontend/src/lib/components/MemberSelect.tsx index c1a919dbf6b68..853aa4ae11ba1 100644 --- a/frontend/src/lib/components/MemberSelect.tsx +++ b/frontend/src/lib/components/MemberSelect.tsx @@ -1,4 +1,4 @@ -import { LemonButton, LemonDropdown, LemonInput, ProfilePicture } from '@posthog/lemon-ui' +import { LemonButton, LemonDropdown, LemonDropdownProps, LemonInput, ProfilePicture } from '@posthog/lemon-ui' import { useActions, useValues } from 'kea' import { fullName } from 'lib/utils' import { useEffect, useMemo, useState } from 'react' @@ -9,11 +9,12 @@ import { UserBasicType } from '~/types' export type MemberSelectProps = { defaultLabel?: string // NOTE: Trying to cover a lot of different cases - if string we assume uuid, if number we assume id - value: UserBasicType | string | number | null + value: string | number | null onChange: (value: UserBasicType | null) => void + children?: (selectedUser: UserBasicType | null) => LemonDropdownProps['children'] } -export function MemberSelect({ defaultLabel = 'Any user', value, onChange }: MemberSelectProps): JSX.Element { +export function MemberSelect({ defaultLabel = 'Any user', value, onChange, children }: MemberSelectProps): JSX.Element { const { meFirstMembers, filteredMembers, search, membersLoading } = useValues(membersLogic) const { ensureAllMembersLoaded, setSearch } = useActions(membersLogic) const [showPopover, setShowPopover] = useState(false) @@ -22,11 +23,8 @@ export function MemberSelect({ defaultLabel = 'Any user', value, onChange }: Mem if (!value) { return null } - if (typeof value === 'string' || typeof value === 'number') { - const propToCompare = typeof value === 'string' ? 'uuid' : 'id' - return meFirstMembers.find((member) => member.user[propToCompare] === value)?.user ?? `${value}` - } - return value + const propToCompare = typeof value === 'string' ? 'uuid' : 'id' + return meFirstMembers.find((member) => member.user[propToCompare] === value)?.user ?? null }, [value, meFirstMembers]) const _onChange = (value: UserBasicType | null): void => { @@ -94,18 +92,20 @@ export function MemberSelect({ defaultLabel = 'Any user', value, onChange }: Mem } > - - {typeof selectedMemberAsUser === 'string' ? ( - selectedMemberAsUser - ) : selectedMemberAsUser ? ( - - {fullName(selectedMemberAsUser)} - {meFirstMembers[0].user.uuid === selectedMemberAsUser.uuid ? ` (you)` : ''} - - ) : ( - defaultLabel - )} - + {children ? ( + children(selectedMemberAsUser) + ) : ( + + {selectedMemberAsUser ? ( + + {fullName(selectedMemberAsUser)} + {meFirstMembers[0].user.uuid === selectedMemberAsUser.uuid ? ` (you)` : ''} + + ) : ( + defaultLabel + )} + + )} ) } diff --git a/frontend/src/lib/lemon-ui/LemonDialog/LemonDialog.tsx b/frontend/src/lib/lemon-ui/LemonDialog/LemonDialog.tsx index ffdf4a7c7fe3a..4efcb5666664c 100644 --- a/frontend/src/lib/lemon-ui/LemonDialog/LemonDialog.tsx +++ b/frontend/src/lib/lemon-ui/LemonDialog/LemonDialog.tsx @@ -131,7 +131,15 @@ export const LemonFormDialog = ({ }, []) return ( -
    + ): void => { + if (e.key === 'Enter' && primaryButton?.htmlType === 'submit' && isFormValid) { + void onSubmit(form) + } + }} + > ) diff --git a/frontend/src/lib/lemon-ui/LemonTable/LemonTableLink.tsx b/frontend/src/lib/lemon-ui/LemonTable/LemonTableLink.tsx index c5c1b4915af45..868754e9f19fd 100644 --- a/frontend/src/lib/lemon-ui/LemonTable/LemonTableLink.tsx +++ b/frontend/src/lib/lemon-ui/LemonTable/LemonTableLink.tsx @@ -5,7 +5,7 @@ export function LemonTableLink({ title, description, ...props -}: Pick & { +}: Pick & { title: JSX.Element | string description?: JSX.Element | string }): JSX.Element { diff --git a/frontend/src/queries/nodes/DataTable/DataTable.tsx b/frontend/src/queries/nodes/DataTable/DataTable.tsx index 308d53bc17f78..2eef1f4d5b282 100644 --- a/frontend/src/queries/nodes/DataTable/DataTable.tsx +++ b/frontend/src/queries/nodes/DataTable/DataTable.tsx @@ -168,7 +168,7 @@ export function DataTable({ uniqueKey, query, setQuery, context, cachedResults } ...columnsInLemonTable.map((key, index) => ({ dataIndex: key as any, ...renderColumnMeta(key, query, context), - render: function RenderDataTableColumn(_: any, { result, label }: DataTableRow) { + render: function RenderDataTableColumn(_: any, { result, label }: DataTableRow, recordIndex: number) { if (label) { if (index === (expandable ? 1 : 0)) { return { @@ -179,9 +179,9 @@ export function DataTable({ uniqueKey, query, setQuery, context, cachedResults } return { props: { colSpan: 0 } } } else if (result) { if (sourceFeatures.has(QueryFeature.resultIsArrayOfArrays)) { - return renderColumn(key, result[index], result, query, setQuery, context) + return renderColumn(key, result[index], result, recordIndex, query, setQuery, context) } - return renderColumn(key, result[key], result, query, setQuery, context) + return renderColumn(key, result[key], result, recordIndex, query, setQuery, context) } }, sorter: undefined, // using custom sorting code diff --git a/frontend/src/queries/nodes/DataTable/renderColumn.tsx b/frontend/src/queries/nodes/DataTable/renderColumn.tsx index 5eddb88cba390..d8c5da94857bc 100644 --- a/frontend/src/queries/nodes/DataTable/renderColumn.tsx +++ b/frontend/src/queries/nodes/DataTable/renderColumn.tsx @@ -26,6 +26,7 @@ export function renderColumn( key: string, value: any, record: Record | any[], + recordIndex: number, query: DataTableNode, setQuery?: (query: DataTableNode) => void, context?: QueryContext @@ -37,6 +38,26 @@ export function renderColumn( return } else if (value === errorColumn) { return Error + } else if (queryContextColumnName && queryContextColumn?.render) { + const Component = queryContextColumn?.render + return ( + + ) + } else if (context?.columns?.[key] && context?.columns?.[key].render) { + const Component = context?.columns?.[key]?.render + return Component ? ( + + ) : ( + String(value) + ) + } else if (typeof value === 'object' && Array.isArray(value) && value[0] === '__hx_tag') { + return renderHogQLX(value) } else if (value === null) { return ( @@ -45,11 +66,6 @@ export function renderColumn(
    ) - } else if (queryContextColumnName && queryContextColumn?.render) { - const Component = queryContextColumn?.render - return - } else if (typeof value === 'object' && Array.isArray(value) && value[0] === '__hx_tag') { - return renderHogQLX(value) } else if (isHogQLQuery(query.source)) { if (typeof value === 'string') { try { @@ -230,13 +246,10 @@ export function renderColumn( const columnName = trimQuotes(key.substring(16)) // 16 = "context.columns.".length const Component = context?.columns?.[columnName]?.render return Component ? ( - + ) : ( String(value) ) - } else if (context?.columns?.[key]) { - const Component = context?.columns?.[key]?.render - return Component ? : String(value) } else if (key === 'id' && (isPersonsNode(query.source) || isActorsQuery(query.source))) { return ( diff --git a/frontend/src/queries/utils.ts b/frontend/src/queries/utils.ts index 05e42c680b0e2..3062354c7df91 100644 --- a/frontend/src/queries/utils.ts +++ b/frontend/src/queries/utils.ts @@ -447,3 +447,14 @@ export function hogql(strings: TemplateStringsArray, ...values: any[]): string { return strings.reduce((acc, str, i) => acc + str + (i < strings.length - 1 ? formatHogQlValue(values[i]) : ''), '') } hogql.identifier = hogQlIdentifier + +/** + * Whether we have a valid `breakdownFilter` or not. + */ +export function isValidBreakdown(breakdownFilter?: BreakdownFilter | null): breakdownFilter is BreakdownFilter { + return !!( + breakdownFilter && + ((breakdownFilter.breakdown && breakdownFilter.breakdown_type) || + (breakdownFilter.breakdowns && breakdownFilter.breakdowns.length > 0)) + ) +} diff --git a/frontend/src/scenes/dashboard/DashboardInsightCardLegend.stories.tsx b/frontend/src/scenes/dashboard/DashboardInsightCardLegend.stories.tsx index c2cfb06a3ff07..5c9caef93e5d2 100644 --- a/frontend/src/scenes/dashboard/DashboardInsightCardLegend.stories.tsx +++ b/frontend/src/scenes/dashboard/DashboardInsightCardLegend.stories.tsx @@ -22,6 +22,7 @@ const meta: Meta = { mockDate: '2023-02-01', waitForSelector: '.InsightCard', }, + tags: ['test-skip'], // Flakey } export default meta diff --git a/frontend/src/scenes/data-warehouse/external/DataWarehouseExternalScene.tsx b/frontend/src/scenes/data-warehouse/external/DataWarehouseExternalScene.tsx index 20e0cee417cee..dcbadd261a832 100644 --- a/frontend/src/scenes/data-warehouse/external/DataWarehouseExternalScene.tsx +++ b/frontend/src/scenes/data-warehouse/external/DataWarehouseExternalScene.tsx @@ -32,9 +32,9 @@ export const humanFriendlyDataWarehouseTabName = (tab: DataWarehouseTab): string case DataWarehouseTab.Explore: return 'Explore' case DataWarehouseTab.ManagedSources: - return 'Managed Sources' + return 'Managed sources' case DataWarehouseTab.SelfManagedSources: - return 'Self-Managed Sources' + return 'Self-Managed sources' } } diff --git a/frontend/src/scenes/data-warehouse/external/DataWarehouseTables.tsx b/frontend/src/scenes/data-warehouse/external/DataWarehouseTables.tsx index d8c1e6114a913..83a1c9925bf3f 100644 --- a/frontend/src/scenes/data-warehouse/external/DataWarehouseTables.tsx +++ b/frontend/src/scenes/data-warehouse/external/DataWarehouseTables.tsx @@ -4,6 +4,7 @@ import { clsx } from 'clsx' import { BindLogic, useActions, useValues } from 'kea' import { router } from 'kea-router' import { DatabaseTableTree, TreeItem } from 'lib/components/DatabaseTableTree/DatabaseTableTree' +import { copyToClipboard } from 'lib/utils/copyToClipboard' import { useState } from 'react' import { insightDataLogic } from 
'scenes/insights/insightDataLogic' import { insightLogic } from 'scenes/insights/insightLogic' @@ -97,6 +98,15 @@ export const DatabaseTableTreeWithItems = ({ inline }: DatabaseTableTreeProps): const dropdownOverlay = (table: DatabaseSchemaTable): JSX.Element => ( <> + { + void copyToClipboard(table.name, table.name) + }} + fullWidth + data-attr="schema-list-item-copy" + > + Copy table name + { selectRow(table) diff --git a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx index 8aad2919fc516..b0216397540e9 100644 --- a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx +++ b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx @@ -198,6 +198,137 @@ export const SOURCE_DETAILS: Record = { }, ], }, + MySQL: { + name: 'MySQL', + caption: ( + <> + Enter your MySQL/MariaDB credentials to automatically pull your MySQL data into the PostHog Data + warehouse. + + ), + fields: [ + { + name: 'host', + label: 'Host', + type: 'text', + required: true, + placeholder: 'localhost', + }, + { + name: 'port', + label: 'Port', + type: 'number', + required: true, + placeholder: '3306', + }, + { + name: 'dbname', + label: 'Database', + type: 'text', + required: true, + placeholder: 'mysql', + }, + { + name: 'user', + label: 'User', + type: 'text', + required: true, + placeholder: 'mysql', + }, + { + name: 'password', + label: 'Password', + type: 'password', + required: true, + placeholder: '', + }, + { + name: 'schema', + label: 'Schema', + type: 'text', + required: true, + placeholder: 'public', + }, + { + name: 'ssh-tunnel', + label: 'Use SSH tunnel?', + type: 'switch-group', + default: false, + fields: [ + { + name: 'host', + label: 'Tunnel host', + type: 'text', + required: true, + placeholder: 'localhost', + }, + { + name: 'port', + label: 'Tunnel port', + type: 'number', + required: true, + placeholder: '22', + }, + { + type: 'select', + name: 'auth_type', + label: 'Authentication type', + required: true, + defaultValue: 'password', + options: [ + { + label: 'Password', + value: 'password', + fields: [ + { + name: 'username', + label: 'Tunnel username', + type: 'text', + required: true, + placeholder: 'User1', + }, + { + name: 'password', + label: 'Tunnel password', + type: 'password', + required: true, + placeholder: '', + }, + ], + }, + { + label: 'Key pair', + value: 'keypair', + fields: [ + { + name: 'username', + label: 'Tunnel username', + type: 'text', + required: false, + placeholder: 'User1', + }, + { + name: 'private_key', + label: 'Tunnel private key', + type: 'textarea', + required: true, + placeholder: '', + }, + { + name: 'passphrase', + label: 'Tunnel passphrase', + type: 'password', + required: false, + placeholder: '', + }, + ], + }, + ], + }, + ], + }, + ], + }, Snowflake: { name: 'Snowflake', caption: ( diff --git a/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx b/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx index 51153885c7bbb..0c97d1bd088f8 100644 --- a/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx +++ b/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx @@ -10,6 +10,7 @@ import Iconazure from 'public/services/azure.png' import IconCloudflare from 'public/services/cloudflare.png' import IconGoogleCloudStorage from 'public/services/google-cloud-storage.png' import IconHubspot from 'public/services/hubspot.png' +import IconMySQL from 'public/services/mysql.png' 
import IconPostgres from 'public/services/postgres.png' import IconSnowflake from 'public/services/snowflake.png' import IconStripe from 'public/services/stripe.png' @@ -187,6 +188,7 @@ export function RenderDataWarehouseSourceIcon({ Hubspot: IconHubspot, Zendesk: IconZendesk, Postgres: IconPostgres, + MySQL: IconMySQL, Snowflake: IconSnowflake, aws: IconAwsS3, 'google-cloud': IconGoogleCloudStorage, diff --git a/frontend/src/scenes/data-warehouse/settings/source/Schemas.tsx b/frontend/src/scenes/data-warehouse/settings/source/Schemas.tsx index 93f469198b53c..916bc13c99f41 100644 --- a/frontend/src/scenes/data-warehouse/settings/source/Schemas.tsx +++ b/frontend/src/scenes/data-warehouse/settings/source/Schemas.tsx @@ -70,14 +70,19 @@ export const SchemaTable = ({ schemas, isLoading }: SchemaTableProps): JSX.Eleme return ( updateSchema({ ...schema, sync_frequency: value as DataWarehouseSyncInterval }) } options={[ - { value: 'day' as DataWarehouseSyncInterval, label: 'Daily' }, - { value: 'week' as DataWarehouseSyncInterval, label: 'Weekly' }, - { value: 'month' as DataWarehouseSyncInterval, label: 'Monthly' }, + { value: '5min' as DataWarehouseSyncInterval, label: '5 mins' }, + { value: '30min' as DataWarehouseSyncInterval, label: '30 mins' }, + { value: '1hour' as DataWarehouseSyncInterval, label: '1 hour' }, + { value: '6hour' as DataWarehouseSyncInterval, label: '6 hours' }, + { value: '12hour' as DataWarehouseSyncInterval, label: '12 hours' }, + { value: '24hour' as DataWarehouseSyncInterval, label: 'Daily' }, + { value: '7day' as DataWarehouseSyncInterval, label: 'Weekly' }, + { value: '30day' as DataWarehouseSyncInterval, label: 'Monthly' }, ]} /> ) diff --git a/frontend/src/scenes/error-tracking/ErrorTrackingActions.tsx b/frontend/src/scenes/error-tracking/ErrorTrackingActions.tsx new file mode 100644 index 0000000000000..d5bdbf4889d54 --- /dev/null +++ b/frontend/src/scenes/error-tracking/ErrorTrackingActions.tsx @@ -0,0 +1,62 @@ +import { LemonSelect } from '@posthog/lemon-ui' +import { useActions, useValues } from 'kea' +import { DateFilter } from 'lib/components/DateFilter/DateFilter' + +import { errorTrackingLogic } from './errorTrackingLogic' +import { errorTrackingSceneLogic } from './errorTrackingSceneLogic' + +export const ErrorTrackingActions = ({ showOrder = true }: { showOrder?: boolean }): JSX.Element => { + const { dateRange } = useValues(errorTrackingLogic) + const { setDateRange } = useActions(errorTrackingLogic) + const { order } = useValues(errorTrackingSceneLogic) + const { setOrder } = useActions(errorTrackingSceneLogic) + + return ( +
    +
    + Date range: + { + setDateRange({ date_from: changedDateFrom, date_to: changedDateTo }) + }} + size="small" + /> +
    + {showOrder && ( +
    + Sort by: + +
    + )} +
    + ) +} diff --git a/frontend/src/scenes/error-tracking/ErrorTrackingFilters.tsx b/frontend/src/scenes/error-tracking/ErrorTrackingFilters.tsx index 4f10f2459020d..37d5007133f28 100644 --- a/frontend/src/scenes/error-tracking/ErrorTrackingFilters.tsx +++ b/frontend/src/scenes/error-tracking/ErrorTrackingFilters.tsx @@ -1,6 +1,4 @@ -import { LemonSelect } from '@posthog/lemon-ui' import { useActions, useValues } from 'kea' -import { DateFilter } from 'lib/components/DateFilter/DateFilter' import { TaxonomicFilterGroupType } from 'lib/components/TaxonomicFilter/types' import UniversalFilters from 'lib/components/UniversalFilters/UniversalFilters' import { universalFiltersLogic } from 'lib/components/UniversalFilters/universalFiltersLogic' @@ -9,76 +7,29 @@ import { useEffect, useState } from 'react' import { TestAccountFilter } from 'scenes/insights/filters/TestAccountFilter' import { errorTrackingLogic } from './errorTrackingLogic' -import { errorTrackingSceneLogic } from './errorTrackingSceneLogic' -export const ErrorTrackingFilters = ({ showOrder = true }: { showOrder?: boolean }): JSX.Element => { - const { dateRange, filterGroup, filterTestAccounts } = useValues(errorTrackingLogic) - const { setDateRange, setFilterGroup, setFilterTestAccounts } = useActions(errorTrackingLogic) - const { order } = useValues(errorTrackingSceneLogic) - const { setOrder } = useActions(errorTrackingSceneLogic) +export const ErrorTrackingFilters = (): JSX.Element => { + const { filterGroup, filterTestAccounts } = useValues(errorTrackingLogic) + const { setFilterGroup, setFilterTestAccounts } = useActions(errorTrackingLogic) return ( -
    -
    -
    - { - setDateRange({ date_from: changedDateFrom, date_to: changedDateTo }) - }} - size="small" - /> - {showOrder && ( - - )} -
    -
    - { - setFilterTestAccounts(filter_test_accounts || false) - }} - /> -
    -
    -
    - - - +
    + + + +
    + { + setFilterTestAccounts(filter_test_accounts || false) + }} + />
    ) diff --git a/frontend/src/scenes/error-tracking/ErrorTrackingGroupScene.tsx b/frontend/src/scenes/error-tracking/ErrorTrackingGroupScene.tsx index 668565b97495c..6009b8f03961e 100644 --- a/frontend/src/scenes/error-tracking/ErrorTrackingGroupScene.tsx +++ b/frontend/src/scenes/error-tracking/ErrorTrackingGroupScene.tsx @@ -1,9 +1,10 @@ import './ErrorTracking.scss' -import { LemonTabs } from '@posthog/lemon-ui' +import { LemonDivider, LemonTabs } from '@posthog/lemon-ui' import { useActions, useValues } from 'kea' import { SceneExport } from 'scenes/sceneTypes' +import { ErrorTrackingActions } from './ErrorTrackingActions' import { ErrorTrackingFilters } from './ErrorTrackingFilters' import { ErrorGroupTab, errorTrackingGroupSceneLogic } from './errorTrackingGroupSceneLogic' import { BreakdownsTab } from './groups/BreakdownsTab' @@ -12,8 +13,8 @@ import { OverviewTab } from './groups/OverviewTab' export const scene: SceneExport = { component: ErrorTrackingGroupScene, logic: errorTrackingGroupSceneLogic, - paramsToProps: ({ params: { id } }): (typeof errorTrackingGroupSceneLogic)['props'] => ({ - id, + paramsToProps: ({ params: { id: fingerprint } }): (typeof errorTrackingGroupSceneLogic)['props'] => ({ + fingerprint, }), } @@ -23,7 +24,9 @@ export function ErrorTrackingGroupScene(): JSX.Element { return (
    - + + + - errorTrackingQuery({ - order, - dateRange, - filterTestAccounts, - filterGroup, - sparklineSelectedPeriod, - }), - [order, dateRange, filterTestAccounts, filterGroup, sparklineSelectedPeriod] - ) + const insightProps: InsightLogicProps = { + dashboardItemId: 'new-ErrorTrackingQuery', + } const context: QueryContext = { columns: { @@ -42,16 +37,24 @@ export function ErrorTrackingScene(): JSX.Element { width: '50%', render: CustomGroupTitleColumn, }, + occurrences: { align: 'center' }, volume: { renderTitle: CustomVolumeColumnHeader }, + assignee: { render: AssigneeColumn, align: 'center' }, }, showOpenEditorButton: false, + insightProps: insightProps, + alwaysRefresh: true, } return ( -
    - - -
    + +
    + + + + +
    +
    ) } @@ -80,19 +83,52 @@ const CustomGroupTitleColumn: QueryContextColumnComponent = (props) => { const record = props.record as ErrorTrackingGroup return ( - -
    {record.description}
    -
    - - | - +
    + +
    {record.description}
    +
    + + | + +
    -
    - } - to={urls.errorTrackingGroup(record.fingerprint)} - /> + } + className="flex-1" + to={urls.errorTrackingGroup(record.fingerprint)} + /> +
    + ) +} + +const AssigneeColumn: QueryContextColumnComponent = (props) => { + const { assignGroup } = useActions(errorTrackingDataLogic) + + const record = props.record as ErrorTrackingGroup + + return ( + { + const assigneeId = user?.id || null + assignGroup(props.recordIndex, assigneeId) + }} + > + {(user) => ( + + ) : ( + + ) + } + /> + )} + ) } diff --git a/frontend/src/scenes/error-tracking/errorTrackingDataLogic.tsx b/frontend/src/scenes/error-tracking/errorTrackingDataLogic.tsx new file mode 100644 index 0000000000000..5f7ad4ee30345 --- /dev/null +++ b/frontend/src/scenes/error-tracking/errorTrackingDataLogic.tsx @@ -0,0 +1,44 @@ +import { actions, connect, kea, listeners, path, props } from 'kea' +import api from 'lib/api' + +import { dataNodeLogic, DataNodeLogicProps } from '~/queries/nodes/DataNode/dataNodeLogic' +import { ErrorTrackingGroup } from '~/queries/schema' + +import type { errorTrackingDataLogicType } from './errorTrackingDataLogicType' + +export interface ErrorTrackingDataLogicProps { + query: DataNodeLogicProps['query'] + key: DataNodeLogicProps['key'] +} + +export const errorTrackingDataLogic = kea([ + path(['scenes', 'error-tracking', 'errorTrackingDataLogic']), + props({} as ErrorTrackingDataLogicProps), + + connect(({ key, query }: ErrorTrackingDataLogicProps) => ({ + values: [dataNodeLogic({ key, query }), ['response']], + actions: [dataNodeLogic({ key, query }), ['setResponse']], + })), + + actions({ + assignGroup: (recordIndex: number, assigneeId: number | null) => ({ + recordIndex, + assigneeId, + }), + }), + + listeners(({ values, actions }) => ({ + assignGroup: async ({ recordIndex, assigneeId }) => { + const response = values.response + if (response) { + const params = { assignee: assigneeId } + const results = values.response?.results as ErrorTrackingGroup[] + const group = { ...results[recordIndex], ...params } + results.splice(recordIndex, 1, group) + // optimistically update local results + actions.setResponse({ ...response, results: results }) + await api.errorTracking.update(group.fingerprint, params) + } + }, + })), +]) diff --git a/frontend/src/scenes/error-tracking/errorTrackingGroupSceneLogic.ts b/frontend/src/scenes/error-tracking/errorTrackingGroupSceneLogic.ts index e5a07b73b03fd..8289707d37d1e 100644 --- a/frontend/src/scenes/error-tracking/errorTrackingGroupSceneLogic.ts +++ b/frontend/src/scenes/error-tracking/errorTrackingGroupSceneLogic.ts @@ -13,7 +13,7 @@ import { errorTrackingLogic } from './errorTrackingLogic' import { errorTrackingGroupQuery } from './queries' export interface ErrorTrackingGroupSceneLogicProps { - id: string + fingerprint: string } export enum ErrorGroupTab { @@ -61,7 +61,7 @@ export const errorTrackingGroupSceneLogic = kea { const response = await api.query( errorTrackingGroupQuery({ - fingerprint: props.id, + fingerprint: props.fingerprint, dateRange: values.dateRange, filterTestAccounts: values.filterTestAccounts, filterGroup: values.filterGroup, @@ -78,8 +78,8 @@ export const errorTrackingGroupSceneLogic = kea [p.id], - (id): Breadcrumb[] => { + (_, p) => [p.fingerprint], + (fingerprint): Breadcrumb[] => { return [ { key: Scene.ErrorTracking, @@ -87,8 +87,8 @@ export const errorTrackingGroupSceneLogic = kea ({ setErrorGroupTab: () => { - const searchParams = {} + const searchParams = router.values.searchParams if (values.errorGroupTab != ErrorGroupTab.Overview) { searchParams['tab'] = values.errorGroupTab diff --git a/frontend/src/scenes/error-tracking/errorTrackingSceneLogic.ts 
b/frontend/src/scenes/error-tracking/errorTrackingSceneLogic.ts index 5bc91d06a7206..145c047d59648 100644 --- a/frontend/src/scenes/error-tracking/errorTrackingSceneLogic.ts +++ b/frontend/src/scenes/error-tracking/errorTrackingSceneLogic.ts @@ -1,12 +1,18 @@ -import { actions, kea, path, reducers } from 'kea' +import { actions, connect, kea, path, reducers, selectors } from 'kea' -import { ErrorTrackingQuery } from '~/queries/schema' +import { DataTableNode, ErrorTrackingQuery } from '~/queries/schema' +import { errorTrackingLogic } from './errorTrackingLogic' import type { errorTrackingSceneLogicType } from './errorTrackingSceneLogicType' +import { errorTrackingQuery } from './queries' export const errorTrackingSceneLogic = kea([ path(['scenes', 'error-tracking', 'errorTrackingSceneLogic']), + connect({ + values: [errorTrackingLogic, ['dateRange', 'filterTestAccounts', 'filterGroup', 'sparklineSelectedPeriod']], + }), + actions({ setOrder: (order: ErrorTrackingQuery['order']) => ({ order }), }), @@ -19,4 +25,18 @@ export const errorTrackingSceneLogic = kea([ }, ], }), + + selectors({ + query: [ + (s) => [s.order, s.dateRange, s.filterTestAccounts, s.filterGroup, s.sparklineSelectedPeriod], + (order, dateRange, filterTestAccounts, filterGroup, sparklineSelectedPeriod): DataTableNode => + errorTrackingQuery({ + order, + dateRange, + filterTestAccounts, + filterGroup, + sparklineSelectedPeriod, + }), + ], + }), ]) diff --git a/frontend/src/scenes/error-tracking/queries.ts b/frontend/src/scenes/error-tracking/queries.ts index 9124197d14f08..f6f2a49998aa3 100644 --- a/frontend/src/scenes/error-tracking/queries.ts +++ b/frontend/src/scenes/error-tracking/queries.ts @@ -47,7 +47,7 @@ export const errorTrackingQuery = ({ }): DataTableNode => { const select: string[] = [] - const columns = ['error', 'occurrences', 'sessions', 'users'] + const columns = ['error', 'occurrences', 'sessions', 'users', 'assignee'] if (sparklineSelectedPeriod) { const { value, displayAs, offsetHours } = parseSparklineSelection(sparklineSelectedPeriod) diff --git a/frontend/src/scenes/insights/filters/BreakdownFilter/taxonomicBreakdownFilterLogic.test.ts b/frontend/src/scenes/insights/filters/BreakdownFilter/taxonomicBreakdownFilterLogic.test.ts index 1064db47dd981..a011a4a4e1eb6 100644 --- a/frontend/src/scenes/insights/filters/BreakdownFilter/taxonomicBreakdownFilterLogic.test.ts +++ b/frontend/src/scenes/insights/filters/BreakdownFilter/taxonomicBreakdownFilterLogic.test.ts @@ -2,7 +2,7 @@ import { expectLogic } from 'kea-test-utils' import { TaxonomicFilterGroup, TaxonomicFilterGroupType } from 'lib/components/TaxonomicFilter/types' import { initKeaTests } from '~/test/init' -import { InsightLogicProps } from '~/types' +import { ChartDisplayType, InsightLogicProps } from '~/types' import * as breakdownLogic from './taxonomicBreakdownFilterLogic' @@ -138,6 +138,32 @@ describe('taxonomicBreakdownFilterLogic', () => { }) }) + it('resets the map view when adding a next breakdown', async () => { + logic = taxonomicBreakdownFilterLogic({ + insightProps, + breakdownFilter: { + breakdown: '$geoip_country_code', + breakdown_type: 'person', + }, + isTrends: true, + display: ChartDisplayType.WorldMap, + updateBreakdownFilter, + updateDisplay, + }) + logic.mount() + const changedBreakdown = 'c' + const group: TaxonomicFilterGroup = taxonomicGroupFor(TaxonomicFilterGroupType.EventProperties, undefined) + + await expectLogic(logic, () => { + logic.actions.addBreakdown(changedBreakdown, group) + }).toFinishListeners() + + 
expect(updateBreakdownFilter).toHaveBeenCalledWith({ + breakdown_type: 'event', + breakdown: 'c', + }) + }) + it('sets a limit', async () => { logic = taxonomicBreakdownFilterLogic({ insightProps, @@ -700,6 +726,35 @@ describe('taxonomicBreakdownFilterLogic', () => { expect(updateBreakdownFilter.mock.calls[0][0]).toHaveProperty('breakdowns', undefined) }) + + it('resets the map view when adding a next breakdown', async () => { + const logic = taxonomicBreakdownFilterLogic({ + insightProps, + breakdownFilter: { + breakdowns: [{ property: '$geoip_country_code', type: 'person' }], + }, + isTrends: true, + display: ChartDisplayType.WorldMap, + updateBreakdownFilter, + updateDisplay, + }) + mockFeatureFlag(logic) + logic.mount() + const changedBreakdown = 'c' + const group: TaxonomicFilterGroup = taxonomicGroupFor(TaxonomicFilterGroupType.EventProperties, undefined) + + await expectLogic(logic, () => { + logic.actions.addBreakdown(changedBreakdown, group) + }).toFinishListeners() + + expect(updateBreakdownFilter).toHaveBeenCalledWith({ + breakdowns: [ + { property: '$geoip_country_code', type: 'person' }, + { property: 'c', type: 'event' }, + ], + }) + expect(updateDisplay).toHaveBeenCalledWith(undefined) + }) }) describe('single breakdown to multiple breakdowns', () => { diff --git a/frontend/src/scenes/insights/filters/BreakdownFilter/taxonomicBreakdownFilterLogic.ts b/frontend/src/scenes/insights/filters/BreakdownFilter/taxonomicBreakdownFilterLogic.ts index 291968d794bf9..d9ea96530c955 100644 --- a/frontend/src/scenes/insights/filters/BreakdownFilter/taxonomicBreakdownFilterLogic.ts +++ b/frontend/src/scenes/insights/filters/BreakdownFilter/taxonomicBreakdownFilterLogic.ts @@ -277,6 +277,15 @@ export const taxonomicBreakdownFilterLogic = kea { }, ], }, + '/api/projects/:team/dashboards/34/': { + id: 33, + filters: {}, + tiles: [ + { + layouts: {}, + color: null, + insight: { + id: 42, + short_id: Insight43, + result: 'result!', + filters: { insight: InsightType.TRENDS, interval: 'month' }, + tags: ['bla'], + }, + }, + ], + }, }, post: { '/api/projects/:team/insights/funnel/': { result: ['result from api'] }, @@ -513,14 +533,19 @@ describe('insightLogic', () => { }) test('saveInsight updates dashboards', async () => { + const dashLogic = dashboardLogic({ id: MOCK_DASHBOARD_ID }) + dashLogic.mount() + await expectLogic(dashLogic).toDispatchActions(['loadDashboard']) + savedInsightsLogic.mount() + logic = insightLogic({ dashboardItemId: Insight43, }) logic.mount() - logic.actions.saveInsight() - await expectLogic(dashboardsModel).toDispatchActions(['updateDashboardInsight']) + + await expectLogic(dashLogic).toDispatchActions(['loadDashboard']) }) test('updateInsight updates dashboards', async () => { diff --git a/frontend/src/scenes/insights/insightLogic.ts b/frontend/src/scenes/insights/insightLogic.ts index 88480775ad8b2..036b501f42f91 100644 --- a/frontend/src/scenes/insights/insightLogic.ts +++ b/frontend/src/scenes/insights/insightLogic.ts @@ -8,6 +8,7 @@ import { lemonToast } from 'lib/lemon-ui/LemonToast/LemonToast' import { featureFlagLogic } from 'lib/logic/featureFlagLogic' import { objectsEqual } from 'lib/utils' import { eventUsageLogic, InsightEventSource } from 'lib/utils/eventUsageLogic' +import { dashboardLogic } from 'scenes/dashboard/dashboardLogic' import { insightSceneLogic } from 'scenes/insights/insightSceneLogic' import { keyForInsightLogicProps } from 'scenes/insights/sharedUtils' import { summarizeInsight } from 'scenes/insights/summarizeInsight' @@ -425,6 +426,16 @@ 
export const insightLogic = kea([ dashboardsModel.actions.updateDashboardInsight(savedInsight) + // reload dashboards with updated insight + // since filters on dashboard might be different from filters on insight + // we need to trigger dashboard reload to pick up results for updated insight + savedInsight.dashboard_tiles?.forEach(({ dashboard_id }) => + dashboardLogic.findMounted({ id: dashboard_id })?.actions.loadDashboard({ + action: 'update', + refresh: 'lazy_async', + }) + ) + const mountedInsightSceneLogic = insightSceneLogic.findMounted() if (redirectToViewMode) { if (!insightNumericId && dashboards?.length === 1) { diff --git a/frontend/src/scenes/insights/insightVizDataLogic.ts b/frontend/src/scenes/insights/insightVizDataLogic.ts index b8e2f852f9fbb..b3038bbe2f25a 100644 --- a/frontend/src/scenes/insights/insightVizDataLogic.ts +++ b/frontend/src/scenes/insights/insightVizDataLogic.ts @@ -179,7 +179,9 @@ export const insightVizDataLogic = kea([ isUsingSessionAnalysis: [ (s) => [s.series, s.breakdownFilter, s.properties], (series, breakdownFilter, properties) => { - const using_session_breakdown = breakdownFilter?.breakdown_type === 'session' + const using_session_breakdown = + breakdownFilter?.breakdown_type === 'session' || + breakdownFilter?.breakdowns?.find((breakdown) => breakdown.type === 'session') const using_session_math = series?.some((entity) => entity.math === 'unique_session') const using_session_property_math = series?.some((entity) => { // Should be made more generic is we ever add more session properties @@ -598,6 +600,11 @@ const handleQuerySourceUpdateSideEffects = ( mergedUpdate['properties'] = [] } + // Remove breakdown filter if display type is BoldNumber because it is not supported + if (kind === NodeKind.TrendsQuery && maybeChangedDisplay === ChartDisplayType.BoldNumber) { + mergedUpdate['breakdownFilter'] = null + } + // Don't allow minutes on anything other than Trends if ( currentState.kind == NodeKind.TrendsQuery && diff --git a/frontend/src/scenes/pipeline/AppMetricSparkLine.tsx b/frontend/src/scenes/pipeline/AppMetricSparkLine.tsx index ed14e30ec0fdd..ee4ff1dbfa5f8 100644 --- a/frontend/src/scenes/pipeline/AppMetricSparkLine.tsx +++ b/frontend/src/scenes/pipeline/AppMetricSparkLine.tsx @@ -4,7 +4,7 @@ import { useEffect } from 'react' import { pipelineNodeMetricsLogic } from './pipelineNodeMetricsLogic' import { pipelineNodeMetricsV2Logic } from './pipelineNodeMetricsV2Logic' -import { PipelineBackend, PipelineNode } from './types' +import { PipelineNode } from './types' export function AppMetricSparkLine({ pipelineNode }: { pipelineNode: PipelineNode }): JSX.Element { const logic = pipelineNodeMetricsLogic({ id: pipelineNode.id }) @@ -19,22 +19,28 @@ export function AppMetricSparkLine({ pipelineNode }: { pipelineNode: PipelineNod const displayData: SparklineTimeSeries[] = [ { color: 'success', - name: pipelineNode.backend == 'batch_export' ? 'Runs succeeded' : 'Events sent', + name: 'Success', values: successes, }, ] + if (appMetricsResponse?.metrics.failures.some((failure) => failure > 0)) { displayData.push({ color: 'danger', - name: pipelineNode.backend == 'batch_export' ? 
'Runs failed' : 'Events dropped', + name: 'Failure', values: failures, }) } - if (pipelineNode.backend == PipelineBackend.HogFunction) { - return Coming soon - } - return + return ( + + ) } export function AppMetricSparkLineV2({ pipelineNode }: { pipelineNode: PipelineNode }): JSX.Element { @@ -59,5 +65,13 @@ export function AppMetricSparkLineV2({ pipelineNode }: { pipelineNode: PipelineN }, ] - return + return ( + + ) } diff --git a/frontend/src/scenes/pipeline/destinations/newDestinationsLogic.tsx b/frontend/src/scenes/pipeline/destinations/newDestinationsLogic.tsx index a0aecfca38bab..5b702e7718654 100644 --- a/frontend/src/scenes/pipeline/destinations/newDestinationsLogic.tsx +++ b/frontend/src/scenes/pipeline/destinations/newDestinationsLogic.tsx @@ -4,7 +4,9 @@ import { actions, afterMount, connect, kea, listeners, path, reducers, selectors import { loaders } from 'kea-loaders' import { actionToUrl, combineUrl, router, urlToAction } from 'kea-router' import api from 'lib/api' +import { FEATURE_FLAGS } from 'lib/constants' import { LemonField } from 'lib/lemon-ui/LemonField' +import { featureFlagLogic } from 'lib/logic/featureFlagLogic' import { objectsEqual } from 'lib/utils' import posthog from 'posthog-js' import { urls } from 'scenes/urls' @@ -42,7 +44,7 @@ export interface Fuse extends FuseClass {} export const newDestinationsLogic = kea([ connect({ - values: [userLogic, ['user']], + values: [userLogic, ['user'], featureFlagLogic, ['featureFlags']], }), path(() => ['scenes', 'pipeline', 'destinations', 'newDestinationsLogic']), actions({ @@ -104,10 +106,26 @@ export const newDestinationsLogic = kea([ }, ], destinations: [ - (s) => [s.plugins, s.hogFunctionTemplates, s.batchExportServiceNames, router.selectors.hashParams], - (plugins, hogFunctionTemplates, batchExportServiceNames, hashParams): NewDestinationItemType[] => { + (s) => [ + s.plugins, + s.hogFunctionTemplates, + s.batchExportServiceNames, + s.featureFlags, + router.selectors.hashParams, + ], + ( + plugins, + hogFunctionTemplates, + batchExportServiceNames, + featureFlags, + hashParams + ): NewDestinationItemType[] => { + const hogTemplates = featureFlags[FEATURE_FLAGS.HOG_FUNCTIONS] + ? 
Object.values(hogFunctionTemplates) + : [] + return [ - ...Object.values(hogFunctionTemplates).map((hogFunction) => ({ + ...hogTemplates.map((hogFunction) => ({ icon: , name: hogFunction.name, description: hogFunction.description, diff --git a/frontend/src/scenes/pipeline/hogfunctions/hogFunctionConfigurationLogic.tsx b/frontend/src/scenes/pipeline/hogfunctions/hogFunctionConfigurationLogic.tsx index 5b66a92a708b1..70d39a02e3e01 100644 --- a/frontend/src/scenes/pipeline/hogfunctions/hogFunctionConfigurationLogic.tsx +++ b/frontend/src/scenes/pipeline/hogfunctions/hogFunctionConfigurationLogic.tsx @@ -8,6 +8,7 @@ import api from 'lib/api' import { dayjs } from 'lib/dayjs' import { uuid } from 'lib/utils' import { deleteWithUndo } from 'lib/utils/deleteWithUndo' +import posthog from 'posthog-js' import { teamLogic } from 'scenes/teamLogic' import { urls } from 'scenes/urls' import { userLogic } from 'scenes/userLogic' @@ -184,6 +185,12 @@ export const hogFunctionConfigurationLogic = kea [s.configuration, s.currentTeam, s.groupTypes], (configuration, currentTeam, groupTypes): HogFunctionInvocationGlobals => { + const currentUrl = window.location.href.split('#')[0] const globals: HogFunctionInvocationGlobals = { event: { uuid: uuid(), @@ -312,7 +320,7 @@ export const hogFunctionConfigurationLogic = kea { case PipelineTab.Destinations: return 'Destinations' case PipelineTab.DataImport: - return 'Data Import' + return 'Data import' case PipelineTab.SiteApps: - return 'Site Apps' + return 'Site apps' case PipelineTab.ImportApps: return 'Legacy sources' case PipelineTab.AppsManagement: diff --git a/frontend/src/scenes/saved-insights/activityDescriptions.tsx b/frontend/src/scenes/saved-insights/activityDescriptions.tsx index 94403b56c65e2..cd7668905e5ad 100644 --- a/frontend/src/scenes/saved-insights/activityDescriptions.tsx +++ b/frontend/src/scenes/saved-insights/activityDescriptions.tsx @@ -21,7 +21,7 @@ import { urls } from 'scenes/urls' import { filtersToQueryNode } from '~/queries/nodes/InsightQuery/utils/filtersToQueryNode' import { queryNodeToFilter } from '~/queries/nodes/InsightQuery/utils/queryNodeToFilter' import { InsightQueryNode, QuerySchema, TrendsQuery } from '~/queries/schema' -import { isInsightQueryNode } from '~/queries/utils' +import { isInsightQueryNode, isValidBreakdown } from '~/queries/utils' import { FilterType, InsightModel, InsightShortId } from '~/types' const nameOrLinkToInsight = (short_id?: InsightShortId | null, name?: string | null): string | JSX.Element => { @@ -235,6 +235,7 @@ const insightActionsMapping: Record< function summarizeChanges(filtersAfter: Partial): ChangeMapping | null { const query = filtersToQueryNode(filtersAfter) + const trendsQuery = query as TrendsQuery return { description: ['changed query definition'], @@ -242,7 +243,7 @@ function summarizeChanges(filtersAfter: Partial): ChangeMapping | nu
    - {(query as TrendsQuery)?.breakdownFilter?.breakdown_type && } + {isValidBreakdown(trendsQuery?.breakdownFilter) && }
    ), } diff --git a/frontend/src/scenes/surveys/SurveyView.tsx b/frontend/src/scenes/surveys/SurveyView.tsx index c95426b72e236..8672ea7070b68 100644 --- a/frontend/src/scenes/surveys/SurveyView.tsx +++ b/frontend/src/scenes/surveys/SurveyView.tsx @@ -27,7 +27,7 @@ import { SurveyType, } from '~/types' -import { SURVEY_EVENT_NAME } from './constants' +import { SURVEY_EVENT_NAME, SurveyQuestionLabel } from './constants' import { SurveyDisplaySummary } from './Survey' import { SurveyAPIEditor } from './SurveyAPIEditor' import { SurveyFormAppearance } from './SurveyFormAppearance' @@ -293,11 +293,7 @@ export function SurveyView({ id }: { id: string }): JSX.Element { {survey.questions[0].question && ( <> Type - - {survey.questions.length > 1 - ? 'Multiple questions' - : capitalizeFirstLetter(survey.questions[0].type)} - + {SurveyQuestionLabel[survey.questions[0].type]} {pluralize( survey.questions.length, diff --git a/frontend/src/scenes/surveys/surveyActivityDescriber.test.tsx b/frontend/src/scenes/surveys/surveyActivityDescriber.test.tsx index cc9dca12c7a6c..5d56642390a89 100644 --- a/frontend/src/scenes/surveys/surveyActivityDescriber.test.tsx +++ b/frontend/src/scenes/surveys/surveyActivityDescriber.test.tsx @@ -255,7 +255,9 @@ describe('describeQuestionChanges', () => { ) expect(getTextContent(changes[1])).toBe('made question optional') expect(getTextContent(changes[2])).toBe('changed button text from "Next" to "Continue"') - expect(getTextContent(changes[3])).toBe('changed question type from single_choice to multiple_choice') + expect(getTextContent(changes[3])).toBe( + 'changed question type from Single choice select to Multiple choice select' + ) expect(getTextContent(changes[4])).toBe('added choices: Maybe') expect(getTextContent(changes[5])).toBe('updated branching logic') }) diff --git a/frontend/src/scenes/surveys/surveyActivityDescriber.tsx b/frontend/src/scenes/surveys/surveyActivityDescriber.tsx index 4a355ca7b2eb7..8b538b9ae35cb 100644 --- a/frontend/src/scenes/surveys/surveyActivityDescriber.tsx +++ b/frontend/src/scenes/surveys/surveyActivityDescriber.tsx @@ -25,6 +25,8 @@ import { SurveyQuestionType, } from '~/types' +import { SurveyQuestionLabel } from './constants' + const isEmptyOrUndefined = (value: any): boolean => value === undefined || value === null || value === '' const nameOrLinkToSurvey = ( @@ -440,7 +442,8 @@ export function describeQuestionChanges(before: SurveyQuestion, after: SurveyQue before.type !== after.type ? 
[ <> - changed question type from {before.type} to {after.type} + changed question type from {SurveyQuestionLabel[before.type]} to{' '} + {SurveyQuestionLabel[after.type]} , ] : [] diff --git a/frontend/src/types.ts b/frontend/src/types.ts index ceede7bd84355..31a685da19cbe 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -3828,7 +3828,7 @@ export enum DataWarehouseSettingsTab { SelfManaged = 'self-managed', } -export const externalDataSources = ['Stripe', 'Hubspot', 'Postgres', 'Zendesk', 'Snowflake'] as const +export const externalDataSources = ['Stripe', 'Hubspot', 'Postgres', 'MySQL', 'Zendesk', 'Snowflake'] as const export type ExternalDataSourceType = (typeof externalDataSources)[number] @@ -4012,7 +4012,7 @@ export type BatchExportService = export type PipelineInterval = 'hour' | 'day' | 'every 5 minutes' -export type DataWarehouseSyncInterval = 'day' | 'week' | 'month' +export type DataWarehouseSyncInterval = '5min' | '30min' | '1hour' | '6hour' | '12hour' | '24hour' | '7day' | '30day' export type BatchExportConfiguration = { // User provided data for the export. This is the data that the user diff --git a/latest_migrations.manifest b/latest_migrations.manifest index d9d360ae6dc94..3079c62e846e3 100644 --- a/latest_migrations.manifest +++ b/latest_migrations.manifest @@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name ee: 0016_rolemembership_organization_member otp_static: 0002_throttling otp_totp: 0002_auto_20190420_0723 -posthog: 0447_alter_integration_kind +posthog: 0451_datawarehousetable_updated_at_and_more sessions: 0001_initial social_django: 0010_uid_db_index two_factor: 0007_auto_20201201_1019 diff --git a/package.json b/package.json index 6fe9d5d70e712..df58f8681ecad 100644 --- a/package.json +++ b/package.json @@ -147,7 +147,7 @@ "pmtiles": "^2.11.0", "postcss": "^8.4.31", "postcss-preset-env": "^9.3.0", - "posthog-js": "1.149.0", + "posthog-js": "1.149.1", "posthog-js-lite": "3.0.0", "prettier": "^2.8.8", "prop-types": "^15.7.2", diff --git a/plugin-server/functional_tests/plugins.test.ts b/plugin-server/functional_tests/plugins.test.ts index db56c3f2cefdb..e9129c0ae5188 100644 --- a/plugin-server/functional_tests/plugins.test.ts +++ b/plugin-server/functional_tests/plugins.test.ts @@ -583,7 +583,7 @@ test.concurrent('plugins can use attachements', async () => { key: 'testAttachment', contents: 'test', }) - await enablePluginConfig(teamId, plugin.id) + await enablePluginConfig(teamId, pluginConfig.id) await reloadPlugins() diff --git a/plugin-server/src/capabilities.ts b/plugin-server/src/capabilities.ts index 9bfe5a642155e..11158a284b951 100644 --- a/plugin-server/src/capabilities.ts +++ b/plugin-server/src/capabilities.ts @@ -26,6 +26,7 @@ export function getPluginServerCapabilities(config: PluginsServerConfig): Plugin cdpProcessedEvents: true, cdpFunctionCallbacks: true, cdpFunctionOverflow: true, + syncInlinePlugins: true, ...sharedCapabilities, } case PluginServerMode.ingestion: @@ -89,6 +90,7 @@ export function getPluginServerCapabilities(config: PluginsServerConfig): Plugin return { pluginScheduledTasks: true, appManagementSingleton: true, + syncInlinePlugins: true, ...sharedCapabilities, } case PluginServerMode.cdp_processed_events: @@ -121,6 +123,7 @@ export function getPluginServerCapabilities(config: PluginsServerConfig): Plugin sessionRecordingBlobIngestion: true, appManagementSingleton: true, preflightSchedules: true, + syncInlinePlugins: true, ...sharedCapabilities, } } diff --git a/plugin-server/src/cdp/cdp-consumers.ts 
b/plugin-server/src/cdp/cdp-consumers.ts index 8463946bbf672..bbce8d9743506 100644 --- a/plugin-server/src/cdp/cdp-consumers.ts +++ b/plugin-server/src/cdp/cdp-consumers.ts @@ -18,7 +18,6 @@ import { AppMetric2Type, GroupTypeToColumnIndex, Hub, RawClickHouseEvent, TeamId import { KafkaProducerWrapper } from '../utils/db/kafka-producer-wrapper' import { status } from '../utils/status' import { castTimestampOrNow } from '../utils/utils' -import { AppMetrics } from '../worker/ingestion/app-metrics' import { RustyHook } from '../worker/rusty-hook' import { AsyncFunctionExecutor } from './async-function-executor' import { HogExecutor } from './hog-executor' @@ -77,7 +76,6 @@ abstract class CdpConsumerBase { asyncFunctionExecutor: AsyncFunctionExecutor hogExecutor: HogExecutor hogWatcher: HogWatcher - appMetrics?: AppMetrics isStopping = false messagesToProduce: HogFunctionMessageToProduce[] = [] @@ -399,13 +397,6 @@ abstract class CdpConsumerBase { await createKafkaProducer(globalConnectionConfig, globalProducerConfig) ) - this.appMetrics = - this.hub?.appMetrics ?? - new AppMetrics( - this.kafkaProducer, - this.hub.APP_METRICS_FLUSH_FREQUENCY_MS, - this.hub.APP_METRICS_FLUSH_MAX_QUEUE_SIZE - ) this.kafkaProducer.producer.connect() this.batchConsumer = await startBatchConsumer({ diff --git a/plugin-server/src/main/pluginsServer.ts b/plugin-server/src/main/pluginsServer.ts index d8d619be7e7b3..d12a2f4362fe1 100644 --- a/plugin-server/src/main/pluginsServer.ts +++ b/plugin-server/src/main/pluginsServer.ts @@ -28,6 +28,7 @@ import { OrganizationManager } from '../worker/ingestion/organization-manager' import { TeamManager } from '../worker/ingestion/team-manager' import Piscina, { makePiscina as defaultMakePiscina } from '../worker/piscina' import { RustyHook } from '../worker/rusty-hook' +import { syncInlinePlugins } from '../worker/vm/inline/inline' import { GraphileWorker } from './graphile-worker/graphile-worker' import { loadPluginSchedule } from './graphile-worker/schedule' import { startGraphileWorker } from './graphile-worker/worker-setup' @@ -439,6 +440,13 @@ export async function startPluginsServer( healthChecks['webhooks-ingestion'] = isWebhooksIngestionHealthy } + if (capabilities.syncInlinePlugins) { + ;[hub, closeHub] = hub ? [hub, closeHub] : await createHub(serverConfig, capabilities) + serverInstance = serverInstance ? serverInstance : { hub } + + await syncInlinePlugins(hub) + } + if (hub && serverInstance) { pubSub = new PubSub(hub, { [hub.PLUGINS_RELOAD_PUBSUB_CHANNEL]: async () => { diff --git a/plugin-server/src/types.ts b/plugin-server/src/types.ts index c4df28fa9e798..92ec13670deed 100644 --- a/plugin-server/src/types.ts +++ b/plugin-server/src/types.ts @@ -33,7 +33,7 @@ import { TeamManager } from './worker/ingestion/team-manager' import { RustyHook } from './worker/rusty-hook' import { PluginsApiKeyManager } from './worker/vm/extensions/helpers/api-key-manager' import { RootAccessManager } from './worker/vm/extensions/helpers/root-acess-manager' -import { LazyPluginVM } from './worker/vm/lazy' +import { PluginInstance } from './worker/vm/lazy' export { Element } from '@posthog/plugin-scaffold' // Re-export Element from scaffolding, for backwards compat. 
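A minimal sketch (not part of this diff; the helper name and import paths are assumed) of how calling code is meant to use the PluginInstance interface imported above and defined later in this diff in plugin-server/src/worker/vm/lazy.ts: pluginConfig.instance and getPluginMethod() take the place of pluginConfig.vm and getVmMethod(), and the resolved method may be null when the plugin never became ready, so callers treat it as optional.

import { PluginEvent } from '@posthog/plugin-scaffold'

import { PluginConfig } from '../../types' // path assumed for this sketch

// Resolve a plugin method through the PluginInstance interface instead of reaching into LazyPluginVM.
export async function processEventViaInstance(pluginConfig: PluginConfig, event: PluginEvent) {
    // getPluginMethod() resolves to null when the plugin is missing or failed to set up.
    const processEvent = await pluginConfig.instance?.getPluginMethod('processEvent')
    return processEvent ? await processEvent(event) : event
}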
@@ -314,7 +314,7 @@ export interface Hub extends PluginsServerConfig { // diagnostics lastActivity: number lastActivityType: string - statelessVms: StatelessVmMap + statelessVms: StatelessInstanceMap conversionBufferEnabledTeams: Set // functions enqueuePluginJob: (job: EnqueuedPluginJob) => Promise @@ -344,6 +344,7 @@ export interface PluginServerCapabilities { preflightSchedules?: boolean // Used for instance health checks on hobby deploy, not useful on cloud http?: boolean mmdb?: boolean + syncInlinePlugins?: boolean } export type EnqueuedJob = EnqueuedPluginJob | GraphileWorkerCronScheduleJob @@ -394,9 +395,9 @@ export interface JobSpec { export interface Plugin { id: number - organization_id: string + organization_id?: string name: string - plugin_type: 'local' | 'respository' | 'custom' | 'source' + plugin_type: 'local' | 'respository' | 'custom' | 'source' | 'inline' description?: string is_global: boolean is_preinstalled?: boolean @@ -443,7 +444,7 @@ export interface PluginConfig { order: number config: Record attachments?: Record - vm?: LazyPluginVM | null + instance?: PluginInstance | null created_at: string updated_at?: string // We're migrating to a new functions that take PostHogEvent instead of PluginEvent @@ -528,7 +529,7 @@ export interface PluginTask { __ignoreForAppMetrics?: boolean } -export type VMMethods = { +export type PluginMethods = { setupPlugin?: () => Promise teardownPlugin?: () => Promise getSettings?: () => PluginSettings @@ -538,7 +539,7 @@ export type VMMethods = { } // Helper when ensuring that a required method is implemented -export type VMMethodsConcrete = Required +export type PluginMethodsConcrete = Required export enum AlertLevel { P0 = 0, @@ -565,7 +566,7 @@ export interface Alert { } export interface PluginConfigVMResponse { vm: VM - methods: VMMethods + methods: PluginMethods tasks: Record> vmResponseVariable: string usedImports: Set @@ -1150,7 +1151,7 @@ export enum PropertyUpdateOperation { SetOnce = 'set_once', } -export type StatelessVmMap = Record +export type StatelessInstanceMap = Record export enum OrganizationPluginsAccessLevel { NONE = 0, diff --git a/plugin-server/src/utils/db/sql.ts b/plugin-server/src/utils/db/sql.ts index 6aab87a5f9ceb..37f2bfeff4384 100644 --- a/plugin-server/src/utils/db/sql.ts +++ b/plugin-server/src/utils/db/sql.ts @@ -1,4 +1,5 @@ import { Hub, Plugin, PluginAttachmentDB, PluginCapabilities, PluginConfig, PluginConfigId } from '../../types' +import { InlinePluginDescription } from '../../worker/vm/inline/inline' import { PostgresUse } from './postgres' function pluginConfigsInForceQuery(specificField?: keyof PluginConfig): string { @@ -58,6 +59,49 @@ const PLUGIN_SELECT = `SELECT LEFT JOIN posthog_pluginsourcefile psf__site_ts ON (psf__site_ts.plugin_id = posthog_plugin.id AND psf__site_ts.filename = 'site.ts')` +const PLUGIN_UPSERT_RETURNING = `INSERT INTO posthog_plugin + ( + name, + url, + tag, + from_json, + from_web, + error, + plugin_type, + organization_id, + is_global, + capabilities, + public_jobs, + is_stateless, + log_level, + description, + is_preinstalled, + config_schema, + updated_at, + created_at + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, NOW(), NOW()) + ON CONFLICT (url) + DO UPDATE SET + name = $1, + tag = $3, + from_json = $4, + from_web = $5, + error = $6, + plugin_type = $7, + organization_id = $8, + is_global = $9, + capabilities = $10, + public_jobs = $11, + is_stateless = $12, + log_level = $13, + description = $14, + is_preinstalled = $15, + 
config_schema = $16, + updated_at = NOW() + RETURNING * +` + export async function getPlugin(hub: Hub, pluginId: number): Promise { const result = await hub.db.postgres.query( PostgresUse.COMMON_READ, @@ -68,14 +112,14 @@ export async function getPlugin(hub: Hub, pluginId: number): Promise { +export async function getActivePluginRows(hub: Hub): Promise { const { rows }: { rows: Plugin[] } = await hub.db.postgres.query( PostgresUse.COMMON_READ, `${PLUGIN_SELECT} WHERE posthog_plugin.id IN (${pluginConfigsInForceQuery('plugin_id')} GROUP BY posthog_pluginconfig.plugin_id)`, undefined, - 'getPluginRows' + 'getActivePluginRows' ) return rows @@ -124,3 +168,53 @@ export async function disablePlugin(hub: Hub, pluginConfigId: PluginConfigId): P ) await hub.db.redisPublish(hub.PLUGINS_RELOAD_PUBSUB_CHANNEL, 'reload!') } + +// Given an inline plugin description, upsert it into the known plugins table, returning the full +// Plugin object. Matching is done based on plugin url, not id, since that varies by region. +export async function upsertInlinePlugin(hub: Hub, inline: InlinePluginDescription): Promise { + const fullPlugin: Plugin = { + id: 0, + name: inline.name, + url: inline.url, + tag: inline.tag, + from_json: false, + from_web: false, + error: undefined, + plugin_type: 'inline', + organization_id: undefined, + is_global: inline.is_global, + capabilities: inline.capabilities, + public_jobs: undefined, + is_stateless: inline.is_stateless, + log_level: inline.log_level, + description: inline.description, + is_preinstalled: inline.is_preinstalled, + config_schema: inline.config_schema, + } + + const { rows }: { rows: Plugin[] } = await hub.db.postgres.query( + PostgresUse.COMMON_WRITE, + `${PLUGIN_UPSERT_RETURNING}`, + [ + fullPlugin.name, + fullPlugin.url, + fullPlugin.tag, + fullPlugin.from_json, + fullPlugin.from_web, + fullPlugin.error, + fullPlugin.plugin_type, + fullPlugin.organization_id, + fullPlugin.is_global, + fullPlugin.capabilities, + fullPlugin.public_jobs, + fullPlugin.is_stateless, + fullPlugin.log_level, + fullPlugin.description, + fullPlugin.is_preinstalled, + JSON.stringify(fullPlugin.config_schema), + ], + 'upsertInlinePlugin' + ) + + return rows[0] +} diff --git a/plugin-server/src/worker/plugins/loadPlugin.ts b/plugin-server/src/worker/plugins/loadPlugin.ts index 26a7d45f97e62..a264961ad8c6c 100644 --- a/plugin-server/src/worker/plugins/loadPlugin.ts +++ b/plugin-server/src/worker/plugins/loadPlugin.ts @@ -18,10 +18,16 @@ export async function loadPlugin(hub: Hub, pluginConfig: PluginConfig): Promise< const isLocalPlugin = plugin?.plugin_type === 'local' if (!plugin) { - pluginConfig.vm?.failInitialization!() + pluginConfig.instance?.failInitialization!() return false } + // Inline plugins don't need "loading", and have no source files. 
+ if (plugin.plugin_type === 'inline') { + await pluginConfig.instance?.initialize!('', pluginDigest(plugin)) + return true + } + try { // load config json const configJson = isLocalPlugin @@ -32,7 +38,7 @@ export async function loadPlugin(hub: Hub, pluginConfig: PluginConfig): Promise< try { config = JSON.parse(configJson) } catch (e) { - pluginConfig.vm?.failInitialization!() + pluginConfig.instance?.failInitialization!() await processError(hub, pluginConfig, `Could not load "plugin.json" for ${pluginDigest(plugin)}`) return false } @@ -46,11 +52,11 @@ export async function loadPlugin(hub: Hub, pluginConfig: PluginConfig): Promise< readFileIfExists(hub.BASE_DIR, plugin, 'index.ts') : plugin.source__index_ts if (pluginSource) { - void pluginConfig.vm?.initialize!(pluginSource, pluginDigest(plugin)) + void pluginConfig.instance?.initialize!(pluginSource, pluginDigest(plugin)) return true } else { // always call this if no backend app present, will signal that the VM is done - pluginConfig.vm?.failInitialization!() + pluginConfig.instance?.failInitialization!() // if there is a frontend or site app, don't save an error if no backend app const hasFrontend = isLocalPlugin @@ -72,7 +78,7 @@ export async function loadPlugin(hub: Hub, pluginConfig: PluginConfig): Promise< } } } catch (error) { - pluginConfig.vm?.failInitialization!() + pluginConfig.instance?.failInitialization!() await processError(hub, pluginConfig, error) } return false diff --git a/plugin-server/src/worker/plugins/loadPluginsFromDB.ts b/plugin-server/src/worker/plugins/loadPluginsFromDB.ts index 81282e0646794..b36fb0e251141 100644 --- a/plugin-server/src/worker/plugins/loadPluginsFromDB.ts +++ b/plugin-server/src/worker/plugins/loadPluginsFromDB.ts @@ -2,7 +2,7 @@ import { PluginAttachment } from '@posthog/plugin-scaffold' import { Summary } from 'prom-client' import { Hub, Plugin, PluginConfig, PluginConfigId, PluginId, PluginMethod, TeamId } from '../../types' -import { getPluginAttachmentRows, getPluginConfigRows, getPluginRows } from '../../utils/db/sql' +import { getActivePluginRows, getPluginAttachmentRows, getPluginConfigRows } from '../../utils/db/sql' const loadPluginsMsSummary = new Summary({ name: 'load_plugins_ms', @@ -29,7 +29,7 @@ export async function loadPluginsFromDB( hub: Hub ): Promise> { const startTimer = new Date() - const pluginRows = await getPluginRows(hub) + const pluginRows = await getActivePluginRows(hub) const plugins = new Map() for (const row of pluginRows) { @@ -78,7 +78,7 @@ export async function loadPluginsFromDB( ...row, plugin: plugin, attachments: attachmentsPerConfig.get(row.id) || {}, - vm: null, + instance: null, method, } pluginConfigs.set(row.id, pluginConfig) diff --git a/plugin-server/src/worker/plugins/loadSchedule.ts b/plugin-server/src/worker/plugins/loadSchedule.ts index 6c5c4684d7390..ff54dae570aa1 100644 --- a/plugin-server/src/worker/plugins/loadSchedule.ts +++ b/plugin-server/src/worker/plugins/loadSchedule.ts @@ -18,7 +18,7 @@ export async function loadSchedule(server: Hub): Promise { let count = 0 for (const [id, pluginConfig] of server.pluginConfigs) { - const tasks = (await pluginConfig.vm?.getScheduledTasks()) ?? {} + const tasks = (await pluginConfig.instance?.getScheduledTasks()) ?? 
{} for (const [taskName, task] of Object.entries(tasks)) { if (task && taskName in pluginSchedule) { pluginSchedule[taskName].push(id) diff --git a/plugin-server/src/worker/plugins/run.ts b/plugin-server/src/worker/plugins/run.ts index 7b24bc10a4a0e..4fb0635994aaf 100644 --- a/plugin-server/src/worker/plugins/run.ts +++ b/plugin-server/src/worker/plugins/run.ts @@ -1,6 +1,6 @@ import { PluginEvent, Webhook } from '@posthog/plugin-scaffold' -import { Hub, PluginConfig, PluginTaskType, PostIngestionEvent, VMMethodsConcrete } from '../../types' +import { Hub, PluginConfig, PluginMethodsConcrete, PluginTaskType, PostIngestionEvent } from '../../types' import { processError } from '../../utils/db/error' import { convertToOnEventPayload, @@ -19,7 +19,7 @@ async function runSingleTeamPluginOnEvent( hub: Hub, event: PostIngestionEvent, pluginConfig: PluginConfig, - onEvent: VMMethodsConcrete['onEvent'] + onEvent: PluginMethodsConcrete['onEvent'] ): Promise { const timeout = setTimeout(() => { status.warn('⌛', `Still running single onEvent plugin for team ${event.teamId} for plugin ${pluginConfig.id}`) @@ -85,7 +85,7 @@ async function runSingleTeamPluginComposeWebhook( hub: Hub, postIngestionEvent: PostIngestionEvent, pluginConfig: PluginConfig, - composeWebhook: VMMethodsConcrete['composeWebhook'] + composeWebhook: PluginMethodsConcrete['composeWebhook'] ): Promise { // 1. Calls `composeWebhook` for the plugin, send `composeWebhook` appmetric success/fail if applicable. // 2. Send via Rusty-Hook if enabled. @@ -329,7 +329,7 @@ export async function runPluginTask( let shouldQueueAppMetric = false try { - const task = await pluginConfig?.vm?.getTask(taskName, taskType) + const task = await pluginConfig?.instance?.getTask(taskName, taskType) if (!task) { throw new Error( `Task "${taskName}" not found for plugin "${pluginConfig?.plugin?.name}" with config id ${pluginConfigId}` @@ -381,23 +381,23 @@ export async function runPluginTask( return response } -async function getPluginMethodsForTeam( +async function getPluginMethodsForTeam( hub: Hub, teamId: number, method: M -): Promise<[PluginConfig, VMMethodsConcrete[M]][]> { +): Promise<[PluginConfig, PluginMethodsConcrete[M]][]> { const pluginConfigs = hub.pluginConfigsPerTeam.get(teamId) || [] if (pluginConfigs.length === 0) { return [] } const methodsObtained = await Promise.all( - pluginConfigs.map(async (pluginConfig) => [pluginConfig, await pluginConfig?.vm?.getVmMethod(method)]) + pluginConfigs.map(async (pluginConfig) => [pluginConfig, await pluginConfig?.instance?.getPluginMethod(method)]) ) const methodsObtainedFiltered = methodsObtained.filter(([_, method]) => !!method) as [ PluginConfig, - VMMethodsConcrete[M] + PluginMethodsConcrete[M] ][] return methodsObtainedFiltered diff --git a/plugin-server/src/worker/plugins/setup.ts b/plugin-server/src/worker/plugins/setup.ts index b2e4e0bdd0f0c..161309f76877a 100644 --- a/plugin-server/src/worker/plugins/setup.ts +++ b/plugin-server/src/worker/plugins/setup.ts @@ -1,8 +1,8 @@ import { Gauge, Summary } from 'prom-client' -import { Hub, StatelessVmMap } from '../../types' +import { Hub, StatelessInstanceMap } from '../../types' import { status } from '../../utils/status' -import { LazyPluginVM } from '../vm/lazy' +import { constructPluginInstance } from '../vm/lazy' import { loadPlugin } from './loadPlugin' import { loadPluginsFromDB } from './loadPluginsFromDB' import { loadSchedule } from './loadSchedule' @@ -24,7 +24,7 @@ export async function setupPlugins(hub: Hub): Promise { status.info('🔁', 
`Loading plugin configs...`) const { plugins, pluginConfigs, pluginConfigsPerTeam } = await loadPluginsFromDB(hub) const pluginVMLoadPromises: Array> = [] - const statelessVms = {} as StatelessVmMap + const statelessInstances = {} as StatelessInstanceMap const timer = new Date() @@ -37,11 +37,11 @@ export async function setupPlugins(hub: Hub): Promise { const pluginChanged = plugin?.updated_at !== prevPlugin?.updated_at if (!pluginConfigChanged && !pluginChanged) { - pluginConfig.vm = prevConfig.vm - } else if (plugin?.is_stateless && statelessVms[plugin.id]) { - pluginConfig.vm = statelessVms[plugin.id] + pluginConfig.instance = prevConfig.instance + } else if (plugin?.is_stateless && statelessInstances[plugin.id]) { + pluginConfig.instance = statelessInstances[plugin.id] } else { - pluginConfig.vm = new LazyPluginVM(hub, pluginConfig) + pluginConfig.instance = constructPluginInstance(hub, pluginConfig) if (hub.PLUGIN_LOAD_SEQUENTIALLY) { await loadPlugin(hub, pluginConfig) } else { @@ -52,7 +52,7 @@ export async function setupPlugins(hub: Hub): Promise { } if (plugin?.is_stateless) { - statelessVms[plugin.id] = pluginConfig.vm + statelessInstances[plugin.id] = pluginConfig.instance } } } @@ -67,7 +67,7 @@ export async function setupPlugins(hub: Hub): Promise { importUsedGauge.reset() const seenPlugins = new Set() for (const pluginConfig of pluginConfigs.values()) { - const usedImports = pluginConfig.vm?.usedImports + const usedImports = pluginConfig.instance?.usedImports if (usedImports && !seenPlugins.has(pluginConfig.plugin_id)) { seenPlugins.add(pluginConfig.plugin_id) for (const importName of usedImports) { diff --git a/plugin-server/src/worker/plugins/teardown.ts b/plugin-server/src/worker/plugins/teardown.ts index 8d465a7644369..4fe1a4f52c19e 100644 --- a/plugin-server/src/worker/plugins/teardown.ts +++ b/plugin-server/src/worker/plugins/teardown.ts @@ -6,9 +6,9 @@ export async function teardownPlugins(server: Hub, pluginConfig?: PluginConfig): const teardownPromises: Promise[] = [] for (const pluginConfig of pluginConfigs) { - if (pluginConfig.vm) { - pluginConfig.vm.clearRetryTimeoutIfExists() - const teardownPlugin = await pluginConfig.vm.getTeardownPlugin() + if (pluginConfig.instance) { + pluginConfig.instance.clearRetryTimeoutIfExists() + const teardownPlugin = await pluginConfig.instance.getTeardown() if (teardownPlugin) { teardownPromises.push( (async () => { diff --git a/plugin-server/src/worker/vm/capabilities.ts b/plugin-server/src/worker/vm/capabilities.ts index 5c4fa2e90386e..daa12444eb9be 100644 --- a/plugin-server/src/worker/vm/capabilities.ts +++ b/plugin-server/src/worker/vm/capabilities.ts @@ -1,17 +1,17 @@ -import { PluginCapabilities, PluginTask, PluginTaskType, VMMethods } from '../../types' +import { PluginCapabilities, PluginMethods, PluginTask, PluginTaskType } from '../../types' import { PluginServerCapabilities } from './../../types' const PROCESS_EVENT_CAPABILITIES = new Set(['ingestion', 'ingestionOverflow', 'ingestionHistorical']) export function getVMPluginCapabilities( - methods: VMMethods, + methods: PluginMethods, tasks: Record> ): PluginCapabilities { const capabilities: Required = { scheduled_tasks: [], jobs: [], methods: [] } if (methods) { for (const [key, value] of Object.entries(methods)) { - if (value as VMMethods[keyof VMMethods] | undefined) { + if (value as PluginMethods[keyof PluginMethods] | undefined) { capabilities.methods.push(key) } } diff --git a/plugin-server/src/worker/vm/extensions/jobs.ts 
b/plugin-server/src/worker/vm/extensions/jobs.ts index cdeaa9c1ff45b..3d9ffac9a35b9 100644 --- a/plugin-server/src/worker/vm/extensions/jobs.ts +++ b/plugin-server/src/worker/vm/extensions/jobs.ts @@ -64,7 +64,7 @@ export function createJobs(server: Hub, pluginConfig: PluginConfig): Jobs { pluginJobEnqueueCounter.labels(String(pluginConfig.plugin?.id)).inc() await server.enqueuePluginJob(job) } catch (e) { - await pluginConfig.vm?.createLogEntry( + await pluginConfig.instance?.createLogEntry( `Failed to enqueue job ${type} with error: ${e.message}`, PluginLogEntryType.Error ) diff --git a/plugin-server/src/worker/vm/inline/inline.ts b/plugin-server/src/worker/vm/inline/inline.ts new file mode 100644 index 0000000000000..42a90248c5c4b --- /dev/null +++ b/plugin-server/src/worker/vm/inline/inline.ts @@ -0,0 +1,92 @@ +import { PluginConfigSchema } from '@posthog/plugin-scaffold' + +import { Hub, PluginCapabilities, PluginConfig, PluginLogLevel } from '../../../types' +import { upsertInlinePlugin } from '../../../utils/db/sql' +import { status } from '../../../utils/status' +import { PluginInstance } from '../lazy' +import { NoopInlinePlugin } from './noop' +import { SEMVER_FLATTENER_CONFIG_SCHEMA, SemverFlattener } from './semver-flattener' + +export function constructInlinePluginInstance(hub: Hub, pluginConfig: PluginConfig): PluginInstance { + const url = pluginConfig.plugin?.url + + if (!INLINE_PLUGIN_URLS.includes(url as InlinePluginId)) { + throw new Error(`Invalid inline plugin URL: ${url}`) + } + const plugin = INLINE_PLUGIN_MAP[url as InlinePluginId] + + return plugin.constructor(hub, pluginConfig) +} + +export interface RegisteredInlinePlugin { + constructor: (hub: Hub, config: PluginConfig) => PluginInstance + description: Readonly +} + +export const INLINE_PLUGIN_URLS = ['inline://noop', 'inline://semver-flattener'] as const +type InlinePluginId = (typeof INLINE_PLUGIN_URLS)[number] + +// TODO - add all inline plugins here +export const INLINE_PLUGIN_MAP: Record = { + 'inline://noop': { + constructor: (hub: Hub, config: PluginConfig) => new NoopInlinePlugin(hub, config), + description: { + name: 'Noop Plugin', + description: 'A plugin that does nothing', + is_global: false, + is_preinstalled: false, + url: 'inline://noop', + config_schema: {}, + tag: 'noop', + capabilities: {}, + is_stateless: true, + log_level: PluginLogLevel.Info, + }, + }, + + 'inline://semver-flattener': { + constructor: (hub: Hub, config: PluginConfig) => new SemverFlattener(hub, config), + description: { + name: 'posthog-semver-flattener', + description: + 'Processes specified properties to flatten sematic versions. Assumes any property contains a string which matches [the SemVer specification](https://semver.org/#backusnaur-form-grammar-for-valid-semver-versions)', + is_global: false, + is_preinstalled: false, + url: 'inline://semver-flattener', + config_schema: SEMVER_FLATTENER_CONFIG_SCHEMA, + tag: 'semver-flattener', + capabilities: { + jobs: [], + scheduled_tasks: [], + methods: ['processEvent'], + }, + is_stateless: false, // TODO - this plugin /could/ be stateless, but right now we cache config parsing, which is stateful + log_level: PluginLogLevel.Info, + }, + }, +} + +// Inline plugins are uniquely identified by their /url/, not their ID, and do +// not have most of the standard plugin properties. This reduced interface is +// the "canonical" description of an inline plugin, but can be mapped to a region +// specific Plugin object by url. 
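A hedged sketch (hypothetical plugin, not part of this diff) of how that URL-keyed registry is extended: append a new inline:// URL to INLINE_PLUGIN_URLS and give INLINE_PLUGIN_MAP a matching RegisteredInlinePlugin entry, reusing the imports already present in inline.ts. Its description has to satisfy the InlinePluginDescription interface that follows.

// Hypothetical entry only - 'inline://example' is not registered anywhere in this PR.
const EXAMPLE_INLINE_PLUGIN: RegisteredInlinePlugin = {
    // Stand-in constructor; a real plugin would return its own PluginInstance implementation.
    constructor: (hub: Hub, config: PluginConfig) => new NoopInlinePlugin(hub, config),
    description: {
        name: 'Example inline plugin',
        description: 'Illustrative registry entry',
        is_global: false,
        is_preinstalled: false,
        url: 'inline://example', // would also be appended to INLINE_PLUGIN_URLS
        config_schema: {},
        tag: 'example',
        capabilities: { methods: ['processEvent'] },
        is_stateless: true,
        log_level: PluginLogLevel.Info,
    },
}

Keying the registry on the URL rather than the numeric plugin ID is what lets one canonical description map onto region-specific Plugin rows, as the comment above notes.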
+export interface InlinePluginDescription { + name: string + description: string + is_global: boolean + is_preinstalled: boolean + url: string + config_schema: Record | PluginConfigSchema[] + tag: string + capabilities: PluginCapabilities + is_stateless: boolean + log_level: PluginLogLevel +} + +export async function syncInlinePlugins(hub: Hub): Promise { + status.info('⚡', 'Syncing inline plugins') + for (const url of INLINE_PLUGIN_URLS) { + const plugin = INLINE_PLUGIN_MAP[url] + await upsertInlinePlugin(hub, plugin.description) + } +} diff --git a/plugin-server/src/worker/vm/inline/noop.ts b/plugin-server/src/worker/vm/inline/noop.ts new file mode 100644 index 0000000000000..aaa80d8b1007f --- /dev/null +++ b/plugin-server/src/worker/vm/inline/noop.ts @@ -0,0 +1,68 @@ +import { PluginEvent } from '@posthog/plugin-scaffold' + +import { + Hub, + PluginConfig, + PluginLogEntrySource, + PluginLogEntryType, + PluginMethods, + PluginTask, + PluginTaskType, +} from '../../../types' +import { PluginInstance } from '../lazy' + +export class NoopInlinePlugin implements PluginInstance { + // The noop plugin has no initialization behavior, or imports + initialize = async () => {} + failInitialization = () => {} + usedImports: Set | undefined + methods: PluginMethods + + hub: Hub + config: PluginConfig + + constructor(hub: Hub, pluginConfig: PluginConfig) { + this.hub = hub + this.config = pluginConfig + this.usedImports = new Set() + + this.methods = { + processEvent: (event: PluginEvent) => { + return Promise.resolve(event) + }, + } + } + + public getTeardown(): Promise { + return Promise.resolve(null) + } + + public getTask(_name: string, _type: PluginTaskType): Promise { + return Promise.resolve(null) + } + + public getScheduledTasks(): Promise> { + return Promise.resolve({}) + } + + public getPluginMethod(method_name: T): Promise { + return Promise.resolve(this.methods[method_name] as PluginMethods[T]) + } + + public clearRetryTimeoutIfExists = () => {} + + public setupPluginIfNeeded(): Promise { + return Promise.resolve(true) + } + + public async createLogEntry(message: string, logType = PluginLogEntryType.Info): Promise { + // TODO - this will be identical across all plugins, so figure out a better place to put it. 
+ await this.hub.db.queuePluginLogEntry({ + message, + pluginConfig: this.config, + source: PluginLogEntrySource.System, + type: logType, + instanceId: this.hub.instanceId, + }) + } +} diff --git a/plugin-server/src/worker/vm/inline/semver-flattener.ts b/plugin-server/src/worker/vm/inline/semver-flattener.ts new file mode 100644 index 0000000000000..50290c6f5066e --- /dev/null +++ b/plugin-server/src/worker/vm/inline/semver-flattener.ts @@ -0,0 +1,135 @@ +import { PluginEvent } from '@posthog/plugin-scaffold' + +import { + Hub, + PluginConfig, + PluginLogEntrySource, + PluginLogEntryType, + PluginMethods, + PluginTask, + PluginTaskType, +} from '../../../types' +import { PluginInstance } from '../lazy' + +export class SemverFlattener implements PluginInstance { + initialize = async () => {} + failInitialization = async () => {} + clearRetryTimeoutIfExists = () => {} + usedImports: Set | undefined + methods: PluginMethods + + hub: Hub + config: PluginConfig + targetProps: string[] + + constructor(hub: Hub, pluginConfig: PluginConfig) { + this.hub = hub + this.config = pluginConfig + this.usedImports = new Set() + + this.targetProps = (this.config.config.properties as string)?.split(',').map((s) => s.trim()) + if (!this.targetProps) { + this.targetProps = [] + } + + this.methods = { + processEvent: (event: PluginEvent) => { + return Promise.resolve(this.flattenSemver(event)) + }, + } + } + + public getTeardown(): Promise { + return Promise.resolve(null) + } + + public getTask(_name: string, _type: PluginTaskType): Promise { + return Promise.resolve(null) + } + + public getScheduledTasks(): Promise> { + return Promise.resolve({}) + } + + public getPluginMethod(method_name: T): Promise { + return Promise.resolve(this.methods[method_name] as PluginMethods[T]) + } + + public setupPluginIfNeeded(): Promise { + return Promise.resolve(true) + } + + public async createLogEntry(message: string, logType = PluginLogEntryType.Info): Promise { + // TODO - this will be identical across all plugins, so figure out a better place to put it. + await this.hub.db.queuePluginLogEntry({ + message, + pluginConfig: this.config, + source: PluginLogEntrySource.System, + type: logType, + instanceId: this.hub.instanceId, + }) + } + + flattenSemver(event: PluginEvent): PluginEvent { + if (!event.properties) { + return event + } + + for (const target of this.targetProps) { + const candidate = event.properties[target] + + if (candidate) { + const { major, minor, patch, preRelease, build } = splitVersion(candidate) + event.properties[`${target}__major`] = major + event.properties[`${target}__minor`] = minor + if (patch !== undefined) { + event.properties[`${target}__patch`] = patch + } + if (preRelease !== undefined) { + event.properties[`${target}__preRelease`] = preRelease + } + if (build !== undefined) { + event.properties[`${target}__build`] = build + } + } + } + + return event + } +} + +export interface VersionParts { + major: number + minor: number + patch?: number + preRelease?: string + build?: string +} + +const splitVersion = (candidate: string): VersionParts => { + const [head, build] = candidate.split('+') + const [version, ...preRelease] = head.split('-') + const [major, minor, patch] = version.split('.') + return { + major: Number(major), + minor: Number(minor), + patch: patch ? Number(patch) : undefined, + preRelease: preRelease.join('-') || undefined, + build, + } +} + +export const SEMVER_FLATTENER_CONFIG_SCHEMA = [ + { + markdown: + 'Processes specified properties to flatten sematic versions. 
Assumes any property contains a string which matches [the SemVer specification](https://semver.org/#backusnaur-form-grammar-for-valid-semver-versions)', + }, + { + key: 'properties', + name: 'comma separated properties to explode version number from', + type: 'string' as const, + hint: 'my_version_number,app_version', + default: '', + required: true, + }, +] diff --git a/plugin-server/src/worker/vm/lazy.ts b/plugin-server/src/worker/vm/lazy.ts index 9c1964a792269..c873c4a437c7e 100644 --- a/plugin-server/src/worker/vm/lazy.ts +++ b/plugin-server/src/worker/vm/lazy.ts @@ -9,9 +9,9 @@ import { PluginConfigVMResponse, PluginLogEntrySource, PluginLogEntryType, + PluginMethods, PluginTask, PluginTaskType, - VMMethods, } from '../../types' import { processError } from '../../utils/db/error' import { disablePlugin, getPlugin, setPluginCapabilities } from '../../utils/db/sql' @@ -20,6 +20,7 @@ import { getNextRetryMs } from '../../utils/retries' import { status } from '../../utils/status' import { pluginDigest } from '../../utils/utils' import { getVMPluginCapabilities, shouldSetupPluginInServer } from '../vm/capabilities' +import { constructInlinePluginInstance } from './inline/inline' import { createPluginConfigVM } from './vm' export const VM_INIT_MAX_RETRIES = 5 @@ -44,7 +45,33 @@ const pluginDisabledBySystemCounter = new Counter({ labelNames: ['plugin_id'], }) -export class LazyPluginVM { +export function constructPluginInstance(hub: Hub, pluginConfig: PluginConfig): PluginInstance { + if (pluginConfig.plugin?.plugin_type == 'inline') { + return constructInlinePluginInstance(hub, pluginConfig) + } + return new LazyPluginVM(hub, pluginConfig) +} + +export interface PluginInstance { + // These are "optional", but if they're not set, loadPlugin will fail + initialize?: (indexJs: string, logInfo: string) => Promise + failInitialization?: () => void + + getTeardown: () => Promise + getTask: (name: string, type: PluginTaskType) => Promise + getScheduledTasks: () => Promise> + getPluginMethod: (method_name: T) => Promise + clearRetryTimeoutIfExists: () => void + setupPluginIfNeeded: () => Promise + + createLogEntry: (message: string, logType?: PluginLogEntryType) => Promise + + // This is only used for metrics, and can probably be dropped as we start to care less about + // what imports are used by plugins (or as inlining more plugins makes imports irrelevant) + usedImports: Set | undefined +} + +export class LazyPluginVM implements PluginInstance { initialize?: (indexJs: string, logInfo: string) => Promise failInitialization?: () => void resolveInternalVm!: Promise @@ -68,15 +95,7 @@ export class LazyPluginVM { this.initVm() } - public async getOnEvent(): Promise { - return await this.getVmMethod('onEvent') - } - - public async getProcessEvent(): Promise { - return await this.getVmMethod('processEvent') - } - - public async getTeardownPlugin(): Promise { + public async getTeardown(): Promise { // if we never ran `setupPlugin`, there's no reason to run `teardownPlugin` - it's essentially "tore down" already if (!this.ready) { return null @@ -112,15 +131,15 @@ export class LazyPluginVM { return tasks || {} } - public async getVmMethod(method: T): Promise { - let vmMethod = (await this.resolveInternalVm)?.methods[method] || null - if (!this.ready && vmMethod) { + public async getPluginMethod(method_name: T): Promise { + let method = (await this.resolveInternalVm)?.methods[method_name] || null + if (!this.ready && method) { const pluginReady = await this.setupPluginIfNeeded() if (!pluginReady) { - 
vmMethod = null + method = null } } - return vmMethod + return method } public clearRetryTimeoutIfExists(): void { @@ -207,6 +226,7 @@ export class LazyPluginVM { return true } + // TODO - this is only called in tests, try to remove at some point. public async _setupPlugin(vm?: VM): Promise { const logInfo = this.pluginConfig.plugin ? pluginDigest(this.pluginConfig.plugin) diff --git a/plugin-server/tests/helpers/sqlMock.ts b/plugin-server/tests/helpers/sqlMock.ts index 378c6bf6273e9..a323d0fd18cb7 100644 --- a/plugin-server/tests/helpers/sqlMock.ts +++ b/plugin-server/tests/helpers/sqlMock.ts @@ -2,7 +2,9 @@ import * as s from '../../src/utils/db/sql' // mock functions that get data from postgres and give them the right types type UnPromisify = F extends (...args: infer A) => Promise ? (...args: A) => T : never -export const getPluginRows = s.getPluginRows as unknown as jest.MockedFunction> +export const getPluginRows = s.getActivePluginRows as unknown as jest.MockedFunction< + UnPromisify +> export const getPluginAttachmentRows = s.getPluginAttachmentRows as unknown as jest.MockedFunction< UnPromisify > diff --git a/plugin-server/tests/server.test.ts b/plugin-server/tests/server.test.ts index 3f497be03703c..52fe0b989bf40 100644 --- a/plugin-server/tests/server.test.ts +++ b/plugin-server/tests/server.test.ts @@ -58,6 +58,7 @@ describe('server', () => { ingestionHistorical: true, appManagementSingleton: true, preflightSchedules: true, + syncInlinePlugins: true, } ) }) @@ -73,6 +74,7 @@ describe('server', () => { { http: true, eventsIngestionPipelines: true, + syncInlinePlugins: true, } ) }) @@ -95,6 +97,7 @@ describe('server', () => { cdpProcessedEvents: true, cdpFunctionCallbacks: true, cdpFunctionOverflow: true, + syncInlinePlugins: true, } ) }) @@ -112,6 +115,7 @@ describe('server', () => { http: true, sessionRecordingBlobIngestion: true, sessionRecordingBlobOverflowIngestion: true, + syncInlinePlugins: true, } ) }) @@ -126,6 +130,7 @@ describe('server', () => { pluginScheduledTasks: true, processAsyncWebhooksHandlers: true, preflightSchedules: true, + syncInlinePlugins: true, } ) @@ -141,7 +146,7 @@ describe('server', () => { test('starts graphile for scheduled tasks capability', async () => { pluginsServer = await createPluginServer( {}, - { ingestion: true, pluginScheduledTasks: true, processPluginJobs: true } + { ingestion: true, pluginScheduledTasks: true, processPluginJobs: true, syncInlinePlugins: true } ) expect(startGraphileWorker).toHaveBeenCalled() diff --git a/plugin-server/tests/sql.test.ts b/plugin-server/tests/sql.test.ts index 24c294a6a97c2..d23b133b4c5bf 100644 --- a/plugin-server/tests/sql.test.ts +++ b/plugin-server/tests/sql.test.ts @@ -1,7 +1,7 @@ import { Hub } from '../src/types' import { createHub } from '../src/utils/db/hub' import { PostgresUse } from '../src/utils/db/postgres' -import { disablePlugin, getPluginAttachmentRows, getPluginConfigRows, getPluginRows } from '../src/utils/db/sql' +import { disablePlugin, getActivePluginRows, getPluginAttachmentRows, getPluginConfigRows } from '../src/utils/db/sql' import { commonOrganizationId } from './helpers/plugins' import { resetTestDatabase } from './helpers/sql' @@ -66,7 +66,7 @@ describe('sql', () => { expect(rows1).toEqual([expectedRow]) }) - test('getPluginRows', async () => { + test('getActivePluginRows', async () => { const rowsExpected = [ { error: null, @@ -92,7 +92,7 @@ describe('sql', () => { }, ] - const rows1 = await getPluginRows(hub) + const rows1 = await getActivePluginRows(hub) 
expect(rows1).toEqual(rowsExpected) await hub.db.postgres.query( PostgresUse.COMMON_WRITE, @@ -100,7 +100,7 @@ describe('sql', () => { undefined, 'testTag' ) - const rows2 = await getPluginRows(hub) + const rows2 = await getActivePluginRows(hub) expect(rows2).toEqual(rowsExpected) }) diff --git a/plugin-server/tests/worker/plugins.test.ts b/plugin-server/tests/worker/plugins.test.ts index e43dd0a628ec0..286f289e46cd4 100644 --- a/plugin-server/tests/worker/plugins.test.ts +++ b/plugin-server/tests/worker/plugins.test.ts @@ -8,6 +8,7 @@ import { loadPlugin } from '../../src/worker/plugins/loadPlugin' import { loadSchedule } from '../../src/worker/plugins/loadSchedule' import { runProcessEvent } from '../../src/worker/plugins/run' import { setupPlugins } from '../../src/worker/plugins/setup' +import { LazyPluginVM } from '../../src/worker/vm/lazy' import { commonOrganizationId, mockPluginSourceCode, @@ -64,7 +65,6 @@ describe('plugins', () => { expect(pluginConfig.enabled).toEqual(pluginConfig39.enabled) expect(pluginConfig.order).toEqual(pluginConfig39.order) expect(pluginConfig.config).toEqual(pluginConfig39.config) - expect(pluginConfig.error).toEqual(pluginConfig39.error) expect(pluginConfig.plugin).toEqual({ ...plugin60, @@ -78,16 +78,15 @@ describe('plugins', () => { contents: pluginAttachment1.contents, }, }) - expect(pluginConfig.vm).toBeDefined() - const vm = await pluginConfig.vm!.resolveInternalVm - expect(Object.keys(vm!.methods).sort()).toEqual([ - 'composeWebhook', - 'getSettings', - 'onEvent', - 'processEvent', - 'setupPlugin', - 'teardownPlugin', - ]) + expect(pluginConfig.instance).toBeDefined() + const instance = pluginConfig.instance! + + expect(instance.getPluginMethod('composeWebhook')).toBeDefined() + expect(instance.getPluginMethod('getSettings')).toBeDefined() + expect(instance.getPluginMethod('onEvent')).toBeDefined() + expect(instance.getPluginMethod('processEvent')).toBeDefined() + expect(instance.getPluginMethod('setupPlugin')).toBeDefined() + expect(instance.getPluginMethod('teardownPlugin')).toBeDefined() // async loading of capabilities expect(setPluginCapabilities).toHaveBeenCalled() @@ -101,7 +100,7 @@ describe('plugins', () => { ], ]) - const processEvent = vm!.methods['processEvent']! + const processEvent = await instance.getPluginMethod('processEvent') const event = { event: '$test', properties: {}, team_id: 2 } as PluginEvent await processEvent(event) @@ -135,10 +134,10 @@ describe('plugins', () => { expect(pluginConfigTeam1.plugin).toEqual(plugin) expect(pluginConfigTeam2.plugin).toEqual(plugin) - expect(pluginConfigTeam1.vm).toBeDefined() - expect(pluginConfigTeam2.vm).toBeDefined() + expect(pluginConfigTeam1.instance).toBeDefined() + expect(pluginConfigTeam2.instance).toBeDefined() - expect(pluginConfigTeam1.vm).toEqual(pluginConfigTeam2.vm) + expect(pluginConfigTeam1.instance).toEqual(pluginConfigTeam2.instance) }) test('plugin returns null', async () => { @@ -211,9 +210,11 @@ describe('plugins', () => { const { pluginConfigs } = hub const pluginConfig = pluginConfigs.get(39)! 
- pluginConfig.vm!.totalInitAttemptsCounter = 20 // prevent more retries + expect(pluginConfig.instance).toBeInstanceOf(LazyPluginVM) + const vm = pluginConfig.instance as LazyPluginVM + vm.totalInitAttemptsCounter = 20 // prevent more retries await delay(4000) // processError is called at end of retries - expect(await pluginConfig.vm!.getScheduledTasks()).toEqual({}) + expect(await pluginConfig.instance!.getScheduledTasks()).toEqual({}) const event = { event: '$test', properties: {}, team_id: 2 } as PluginEvent const returnedEvent = await runProcessEvent(hub, { ...event }) @@ -238,9 +239,11 @@ describe('plugins', () => { const { pluginConfigs } = hub const pluginConfig = pluginConfigs.get(39)! - pluginConfig.vm!.totalInitAttemptsCounter = 20 // prevent more retries + expect(pluginConfig.instance).toBeInstanceOf(LazyPluginVM) + const vm = pluginConfig.instance as LazyPluginVM + vm!.totalInitAttemptsCounter = 20 // prevent more retries await delay(4000) // processError is called at end of retries - expect(await pluginConfig.vm!.getScheduledTasks()).toEqual({}) + expect(await pluginConfig.instance!.getScheduledTasks()).toEqual({}) const event = { event: '$test', properties: {}, team_id: 2 } as PluginEvent const returnedEvent = await runProcessEvent(hub, { ...event }) @@ -308,7 +311,7 @@ describe('plugins', () => { await setupPlugins(hub) const { pluginConfigs } = hub - expect(await pluginConfigs.get(39)!.vm!.getScheduledTasks()).toEqual({}) + expect(await pluginConfigs.get(39)!.instance!.getScheduledTasks()).toEqual({}) const event = { event: '$test', properties: {}, team_id: 2 } as PluginEvent const returnedEvent = await runProcessEvent(hub, { ...event }) @@ -341,7 +344,7 @@ describe('plugins', () => { await setupPlugins(hub) const { pluginConfigs } = hub - expect(await pluginConfigs.get(39)!.vm!.getScheduledTasks()).toEqual({}) + expect(await pluginConfigs.get(39)!.instance!.getScheduledTasks()).toEqual({}) const event = { event: '$test', properties: {}, team_id: 2 } as PluginEvent const returnedEvent = await runProcessEvent(hub, { ...event }) @@ -379,7 +382,7 @@ describe('plugins', () => { `Could not load "plugin.json" for plugin test-maxmind-plugin ID ${plugin60.id} (organization ID ${commonOrganizationId})` ) - expect(await pluginConfigs.get(39)!.vm!.getScheduledTasks()).toEqual({}) + expect(await pluginConfigs.get(39)!.instance!.getScheduledTasks()).toEqual({}) }) test('local plugin with broken plugin.json does not do much', async () => { @@ -403,7 +406,7 @@ describe('plugins', () => { pluginConfigs.get(39)!, expect.stringContaining('Could not load "plugin.json" for plugin ') ) - expect(await pluginConfigs.get(39)!.vm!.getScheduledTasks()).toEqual({}) + expect(await pluginConfigs.get(39)!.instance!.getScheduledTasks()).toEqual({}) unlink() }) @@ -426,7 +429,7 @@ describe('plugins', () => { pluginConfigs.get(39)!, `Could not load source code for plugin test-maxmind-plugin ID 60 (organization ID ${commonOrganizationId}). Tried: index.js` ) - expect(await pluginConfigs.get(39)!.vm!.getScheduledTasks()).toEqual({}) + expect(await pluginConfigs.get(39)!.instance!.getScheduledTasks()).toEqual({}) }) test('plugin config order', async () => { @@ -499,7 +502,7 @@ describe('plugins', () => { const pluginConfig = pluginConfigs.get(39)! 
- await pluginConfig.vm?.resolveInternalVm + await (pluginConfig.instance as LazyPluginVM)?.resolveInternalVm // async loading of capabilities expect(pluginConfig.plugin!.capabilities!.methods!.sort()).toEqual(['processEvent', 'setupPlugin']) @@ -529,7 +532,7 @@ describe('plugins', () => { const pluginConfig = pluginConfigs.get(39)! - await pluginConfig.vm?.resolveInternalVm + await (pluginConfig.instance as LazyPluginVM)?.resolveInternalVm // async loading of capabilities expect(pluginConfig.plugin!.capabilities!.methods!.sort()).toEqual(['onEvent', 'processEvent']) @@ -553,7 +556,7 @@ describe('plugins', () => { const pluginConfig = pluginConfigs.get(39)! - await pluginConfig.vm?.resolveInternalVm + await (pluginConfig.instance as LazyPluginVM)?.resolveInternalVm // async loading of capabilities expect(pluginConfig.plugin!.capabilities!.methods!.sort()).toEqual(['onEvent', 'processEvent']) @@ -581,7 +584,7 @@ describe('plugins', () => { const pluginConfig = pluginConfigs.get(39)! - await pluginConfig.vm?.resolveInternalVm + await (pluginConfig.instance as LazyPluginVM)?.resolveInternalVm // async loading of capabilities expect(pluginConfig.plugin!.capabilities!.methods!.sort()).toEqual(['onEvent', 'processEvent']) @@ -675,7 +678,7 @@ describe('plugins', () => { await setupPlugins(hub) const pluginConfig = hub.pluginConfigs.get(39)! - await pluginConfig.vm?.resolveInternalVm + await (pluginConfig.instance as LazyPluginVM)?.resolveInternalVm // async loading of capabilities expect(setPluginCapabilities.mock.calls.length).toBe(1) @@ -685,7 +688,7 @@ describe('plugins', () => { await setupPlugins(hub) const newPluginConfig = hub.pluginConfigs.get(39)! - await newPluginConfig.vm?.resolveInternalVm + await (newPluginConfig.instance as LazyPluginVM)?.resolveInternalVm // async loading of capabilities expect(newPluginConfig.plugin).not.toBe(pluginConfig.plugin) @@ -694,7 +697,7 @@ describe('plugins', () => { }) describe('loadSchedule()', () => { - const mockConfig = (tasks: any) => ({ vm: { getScheduledTasks: () => Promise.resolve(tasks) } }) + const mockConfig = (tasks: any) => ({ instance: { getScheduledTasks: () => Promise.resolve(tasks) } }) const hub = { pluginConfigs: new Map( diff --git a/plugin-server/tests/worker/plugins/inline.test.ts b/plugin-server/tests/worker/plugins/inline.test.ts new file mode 100644 index 0000000000000..d03d66b357552 --- /dev/null +++ b/plugin-server/tests/worker/plugins/inline.test.ts @@ -0,0 +1,167 @@ +import { PluginEvent } from '@posthog/plugin-scaffold' + +import { Hub, LogLevel, Plugin, PluginConfig } from '../../../src/types' +import { createHub } from '../../../src/utils/db/hub' +import { PostgresUse } from '../../../src/utils/db/postgres' +import { + constructInlinePluginInstance, + INLINE_PLUGIN_MAP, + INLINE_PLUGIN_URLS, + syncInlinePlugins, +} from '../../../src/worker/vm/inline/inline' +import { VersionParts } from '../../../src/worker/vm/inline/semver-flattener' +import { PluginInstance } from '../../../src/worker/vm/lazy' +import { resetTestDatabase } from '../../helpers/sql' + +describe('Inline plugin', () => { + let hub: Hub + let closeHub: () => Promise + + beforeAll(async () => { + console.info = jest.fn() as any + console.warn = jest.fn() as any + ;[hub, closeHub] = await createHub({ LOG_LEVEL: LogLevel.Log }) + await resetTestDatabase() + }) + + afterAll(async () => { + await closeHub() + }) + + // Sync all the inline plugins, then assert that for each plugin URL, a + // plugin exists in the database with the correct properties. 
+ test('syncInlinePlugins', async () => { + await syncInlinePlugins(hub) + + const { rows }: { rows: Plugin[] } = await hub.postgres.query( + PostgresUse.COMMON_WRITE, + 'SELECT * FROM posthog_plugin', + undefined, + 'getPluginRows' + ) + for (const url of INLINE_PLUGIN_URLS) { + const plugin = INLINE_PLUGIN_MAP[url] + const row = rows.find((row) => row.url === url)! + // All the inline plugin properties should align + expect(row).not.toBeUndefined() + expect(row.name).toEqual(plugin.description.name) + expect(row.description).toEqual(plugin.description.description) + expect(row.is_global).toEqual(plugin.description.is_global) + expect(row.is_preinstalled).toEqual(plugin.description.is_preinstalled) + expect(row.config_schema).toEqual(plugin.description.config_schema) + expect(row.tag).toEqual(plugin.description.tag) + expect(row.capabilities).toEqual(plugin.description.capabilities) + expect(row.is_stateless).toEqual(plugin.description.is_stateless) + expect(row.log_level).toEqual(plugin.description.log_level) + + // These non-inline plugin properties should be fixed across all inline plugins + // (in true deployments some of these would not be the case, as they're leftovers from + // before inlining, but in tests the inline plugins are always newly created) + expect(row.plugin_type).toEqual('inline') + expect(row.from_json).toEqual(false) + expect(row.from_web).toEqual(false) + expect(row.source__plugin_json).toBeUndefined() + expect(row.source__index_ts).toBeUndefined() + expect(row.source__frontend_tsx).toBeUndefined() + expect(row.source__site_ts).toBeUndefined() + expect(row.error).toBeNull() + expect(row.organization_id).toBeNull() + expect(row.metrics).toBeNull() + expect(row.public_jobs).toBeNull() + } + }) + + test('semver-flattener', async () => { + interface SemanticVersionTestCase { + versionString: string + expected: VersionParts + } + + const config: PluginConfig = { + plugin: { + id: null, + organization_id: null, + plugin_type: null, + name: null, + is_global: null, + url: 'inline://semver-flattener', + }, + config: { + properties: 'version,version2', + }, + id: null, + plugin_id: null, + enabled: null, + team_id: null, + order: null, + created_at: null, + } + + const instance: PluginInstance = constructInlinePluginInstance(hub, config) + + const versionExamples: SemanticVersionTestCase[] = [ + { + versionString: '1.2.3', + expected: { major: 1, minor: 2, patch: 3, build: undefined }, + }, + { + versionString: '22.7', + expected: { major: 22, minor: 7, preRelease: undefined, build: undefined }, + }, + { + versionString: '22.7-pre-release', + expected: { major: 22, minor: 7, patch: undefined, preRelease: 'pre-release', build: undefined }, + }, + { + versionString: '1.0.0-alpha+001', + expected: { major: 1, minor: 0, patch: 0, preRelease: 'alpha', build: '001' }, + }, + { + versionString: '1.0.0+20130313144700', + expected: { major: 1, minor: 0, patch: 0, build: '20130313144700' }, + }, + { + versionString: '1.2.3-beta+exp.sha.5114f85', + expected: { major: 1, minor: 2, patch: 3, preRelease: 'beta', build: 'exp.sha.5114f85' }, + }, + { + versionString: '1.0.0+21AF26D3—-117B344092BD', + expected: { major: 1, minor: 0, patch: 0, preRelease: undefined, build: '21AF26D3—-117B344092BD' }, + }, + ] + + const test_event: PluginEvent = { + distinct_id: '', + ip: null, + site_url: '', + team_id: 0, + now: '', + event: '', + uuid: '', + properties: {}, + } + + const method = await instance.getPluginMethod('processEvent') + + for (const { versionString, expected } of versionExamples) { 
+ test_event.properties.version = versionString + test_event.properties.version2 = versionString + const flattened = await method(test_event) + + expect(flattened.properties.version__major).toEqual(expected.major) + expect(flattened.properties.version__minor).toEqual(expected.minor) + expect(flattened.properties.version__patch).toEqual(expected.patch) + expect(flattened.properties.version__preRelease).toEqual(expected.preRelease) + expect(flattened.properties.version__build).toEqual(expected.build) + + expect(flattened.properties.version2__major).toEqual(expected.major) + expect(flattened.properties.version2__minor).toEqual(expected.minor) + expect(flattened.properties.version2__patch).toEqual(expected.patch) + expect(flattened.properties.version2__preRelease).toEqual(expected.preRelease) + expect(flattened.properties.version2__build).toEqual(expected.build) + + // reset the event for the next iteration + test_event.properties = {} + } + }) +}) diff --git a/plugin-server/tests/worker/plugins/run.test.ts b/plugin-server/tests/worker/plugins/run.test.ts index aa48e0b8451a1..928b31ee7ab00 100644 --- a/plugin-server/tests/worker/plugins/run.test.ts +++ b/plugin-server/tests/worker/plugins/run.test.ts @@ -20,7 +20,7 @@ describe('runPluginTask()', () => { { team_id: 2, enabled: true, - vm: { + instance: { getTask, }, }, @@ -30,7 +30,7 @@ describe('runPluginTask()', () => { { team_id: 2, enabled: false, - vm: { + instance: { getTask, }, }, @@ -142,8 +142,8 @@ describe('runOnEvent', () => { plugin_id: 100, team_id: 2, enabled: false, - vm: { - getVmMethod: () => onEvent, + instance: { + getPluginMethod: () => onEvent, }, }, @@ -151,8 +151,8 @@ describe('runOnEvent', () => { plugin_id: 101, team_id: 2, enabled: false, - vm: { - getVmMethod: () => onEvent, + instance: { + getPluginMethod: () => onEvent, }, }, ], @@ -264,8 +264,8 @@ describe('runComposeWebhook', () => { plugin_id: 100, team_id: 2, enabled: false, - vm: { - getVmMethod: () => composeWebhook, + instance: { + getPluginMethod: () => composeWebhook, } as any, } mockActionManager = { diff --git a/plugin-server/tests/worker/vm.extra-lazy.test.ts b/plugin-server/tests/worker/vm.extra-lazy.test.ts index e571b2f809b59..78bcc0da60f6c 100644 --- a/plugin-server/tests/worker/vm.extra-lazy.test.ts +++ b/plugin-server/tests/worker/vm.extra-lazy.test.ts @@ -33,7 +33,7 @@ describe('VMs are extra lazy 💤', () => { const pluginConfig = { ...pluginConfig39, plugin: plugin60 } const lazyVm = new LazyPluginVM(hub, pluginConfig) - pluginConfig.vm = lazyVm + pluginConfig.instance = lazyVm jest.spyOn(lazyVm, 'setupPluginIfNeeded') await lazyVm.initialize!(indexJs, pluginDigest(plugin60)) @@ -58,7 +58,7 @@ describe('VMs are extra lazy 💤', () => { const pluginConfig = { ...pluginConfig39, plugin: plugin60 } const lazyVm = new LazyPluginVM(hub, pluginConfig) - pluginConfig.vm = lazyVm + pluginConfig.instance = lazyVm jest.spyOn(lazyVm, 'setupPluginIfNeeded') await lazyVm.initialize!(indexJs, pluginDigest(plugin60)) @@ -80,7 +80,7 @@ describe('VMs are extra lazy 💤', () => { await resetTestDatabase(indexJs) const pluginConfig = { ...pluginConfig39, plugin: plugin60 } const lazyVm = new LazyPluginVM(hub, pluginConfig) - pluginConfig.vm = lazyVm + pluginConfig.instance = lazyVm jest.spyOn(lazyVm, 'setupPluginIfNeeded') await lazyVm.initialize!(indexJs, pluginDigest(plugin60)) @@ -88,7 +88,7 @@ describe('VMs are extra lazy 💤', () => { expect(lazyVm.setupPluginIfNeeded).not.toHaveBeenCalled() expect(fetch).not.toHaveBeenCalled() - await lazyVm.getOnEvent() + await 
lazyVm.getPluginMethod('onEvent') expect(lazyVm.ready).toEqual(true) expect(lazyVm.setupPluginIfNeeded).toHaveBeenCalled() expect(fetch).toHaveBeenCalledWith('https://onevent.com/', undefined) @@ -107,14 +107,14 @@ describe('VMs are extra lazy 💤', () => { await resetTestDatabase(indexJs) const pluginConfig = { ...pluginConfig39, plugin: plugin60 } const lazyVm = new LazyPluginVM(hub, pluginConfig) - pluginConfig.vm = lazyVm + pluginConfig.instance = lazyVm jest.spyOn(lazyVm, 'setupPluginIfNeeded') await lazyVm.initialize!(indexJs, pluginDigest(plugin60)) lazyVm.ready = false lazyVm.inErroredState = true - const onEvent = await lazyVm.getOnEvent() + const onEvent = await lazyVm.getPluginMethod('onEvent') expect(onEvent).toBeNull() expect(lazyVm.ready).toEqual(false) expect(lazyVm.setupPluginIfNeeded).toHaveBeenCalled() diff --git a/plugin-server/tests/worker/vm.lazy.test.ts b/plugin-server/tests/worker/vm.lazy.test.ts index fc77c5c9f3582..cfe13bc628902 100644 --- a/plugin-server/tests/worker/vm.lazy.test.ts +++ b/plugin-server/tests/worker/vm.lazy.test.ts @@ -65,7 +65,7 @@ describe('LazyPluginVM', () => { const vm = createVM() void initializeVm(vm) - expect(await vm.getProcessEvent()).toEqual('processEvent') + expect(await vm.getPluginMethod('processEvent')).toEqual('processEvent') expect(await vm.getTask('someTask', PluginTaskType.Schedule)).toEqual(null) expect(await vm.getTask('runEveryMinute', PluginTaskType.Schedule)).toEqual('runEveryMinute') expect(await vm.getScheduledTasks()).toEqual(mockVM.tasks.schedule) @@ -109,7 +109,7 @@ describe('LazyPluginVM', () => { void initializeVm(vm) - expect(await vm.getProcessEvent()).toEqual(null) + expect(await vm.getPluginMethod('processEvent')).toEqual(null) expect(await vm.getTask('runEveryMinute', PluginTaskType.Schedule)).toEqual(null) expect(await vm.getScheduledTasks()).toEqual({}) }) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 13add50e919ca..307e5fc0965ff 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -263,8 +263,8 @@ dependencies: specifier: ^9.3.0 version: 9.3.0(postcss@8.4.31) posthog-js: - specifier: 1.149.0 - version: 1.149.0 + specifier: 1.149.1 + version: 1.149.1 posthog-js-lite: specifier: 3.0.0 version: 3.0.0 @@ -13913,7 +13913,7 @@ packages: hogan.js: 3.0.2 htm: 3.1.1 instantsearch-ui-components: 0.3.0 - preact: 10.22.1 + preact: 10.23.0 qs: 6.9.7 search-insights: 2.13.0 dev: false @@ -17721,11 +17721,11 @@ packages: resolution: {integrity: sha512-dyajjnfzZD1tht4N7p7iwf7nBnR1MjVaVu+MKr+7gBgA39bn28wizCIJZztZPtHy4PY0YwtSGgwfBCuG/hnHgA==} dev: false - /posthog-js@1.149.0: - resolution: {integrity: sha512-uIknyqxv5uDAToPaYVBzGqWwTiuga56cHs+3OeiXKZgjkm97yWh9VA5/gRD/3LEq3iszxHEOU4I5pVIaUrMNtg==} + /posthog-js@1.149.1: + resolution: {integrity: sha512-n3mkDlV0vJ1QhkDkWwUzY9RIFTPbzDzbKRyjzRE4D6H2PoH3vsrR05DNujoCr3t0hqgsaO4RLXO3VlctpdkGKQ==} dependencies: fflate: 0.4.8 - preact: 10.22.1 + preact: 10.23.0 web-vitals: 4.2.2 dev: false @@ -17733,8 +17733,8 @@ packages: resolution: {integrity: sha512-Q+/tYsFU9r7xoOJ+y/ZTtdVQwTWfzjbiXBDMM/JKUux3+QPP02iUuIoeBQ+Ot6oEDlC+/PGjB/5A3K7KKb7hcw==} dev: false - /preact@10.22.1: - resolution: {integrity: sha512-jRYbDDgMpIb5LHq3hkI0bbl+l/TQ9UnkdQ0ww+lp+4MMOdqaUYdFc5qeyP+IV8FAd/2Em7drVPeKdQxsiWCf/A==} + /preact@10.23.0: + resolution: {integrity: sha512-Pox0jeY4q6PGkFB5AsXni+zHxxx/sAYFIFZzukW4nIpoJLRziRX0xC4WjZENlkSrDQvqVgZcaZzZ/NL8/A+H/w==} dev: false /prelude-ls@1.2.1: diff --git a/posthog/api/__init__.py b/posthog/api/__init__.py index 6ef347008e7fb..4473a35ecaab8 100644 --- 
a/posthog/api/__init__.py +++ b/posthog/api/__init__.py @@ -23,6 +23,7 @@ comments, dead_letter_queue, early_access_feature, + error_tracking, event_definition, exports, feature_flag, @@ -132,7 +133,7 @@ def api_not_found(request): "project_early_access_feature", ["team_id"], ) -project_surveys_router = projects_router.register(r"surveys", survey.SurveyViewSet, "project_surveys", ["team_id"]) +projects_router.register(r"surveys", survey.SurveyViewSet, "project_surveys", ["team_id"]) projects_router.register( r"dashboard_templates", @@ -396,6 +397,13 @@ def api_not_found(request): ["team_id"], ) +projects_router.register( + r"error_tracking", + error_tracking.ErrorTrackingGroupViewSet, + "project_error_tracking", + ["team_id"], +) + projects_router.register( r"comments", comments.CommentViewSet, diff --git a/posthog/api/error_tracking.py b/posthog/api/error_tracking.py new file mode 100644 index 0000000000000..bdef3cc42ee03 --- /dev/null +++ b/posthog/api/error_tracking.py @@ -0,0 +1,23 @@ +from django.db.models import QuerySet +from rest_framework import serializers, viewsets + +from posthog.api.forbid_destroy_model import ForbidDestroyModel +from posthog.api.routing import TeamAndOrgViewSetMixin +from posthog.models.error_tracking import ErrorTrackingGroup + + +class ErrorTrackingGroupSerializer(serializers.ModelSerializer): + class Meta: + model = ErrorTrackingGroup + fields = ["assignee"] + + +class ErrorTrackingGroupViewSet(TeamAndOrgViewSetMixin, ForbidDestroyModel, viewsets.ModelViewSet): + scope_object = "INTERNAL" + queryset = ErrorTrackingGroup.objects.all() + serializer_class = ErrorTrackingGroupSerializer + + def safely_get_object(self, queryset) -> QuerySet: + fingerprint = self.kwargs["pk"] + group, _ = queryset.get_or_create(fingerprint=fingerprint, team=self.team) + return group diff --git a/posthog/api/organization_invite.py b/posthog/api/organization_invite.py index 3dc6d979af801..b3a22fe1bf599 100644 --- a/posthog/api/organization_invite.py +++ b/posthog/api/organization_invite.py @@ -25,6 +25,7 @@ class OrganizationInviteSerializer(serializers.ModelSerializer): created_by = UserBasicSerializer(read_only=True) + send_email = serializers.BooleanField(write_only=True, default=True) class Meta: model = OrganizationInvite @@ -40,6 +41,7 @@ class Meta: "updated_at", "message", "private_project_access", + "send_email", ] read_only_fields = [ "id", @@ -96,12 +98,13 @@ def create(self, validated_data: dict[str, Any], *args: Any, **kwargs: Any) -> O user__email=validated_data["target_email"], ).exists(): raise exceptions.ValidationError("A user with this email address already belongs to the organization.") + send_email = validated_data.pop("send_email", True) invite: OrganizationInvite = OrganizationInvite.objects.create( organization_id=self.context["organization_id"], created_by=self.context["request"].user, **validated_data, ) - if is_email_available(with_absolute_urls=True): + if is_email_available(with_absolute_urls=True) and send_email: invite.emailing_attempt_made = True send_invite(invite_id=invite.id) invite.save() diff --git a/posthog/api/plugin.py b/posthog/api/plugin.py index 481b63476f10e..04578f5e64eba 100644 --- a/posthog/api/plugin.py +++ b/posthog/api/plugin.py @@ -290,7 +290,10 @@ def get_latest_tag(self, plugin: Plugin) -> Optional[str]: return None def get_organization_name(self, plugin: Plugin) -> str: - return plugin.organization.name + if plugin.organization: + return plugin.organization.name + else: + return "posthog-inline" def create(self, validated_data: 
dict, *args: Any, **kwargs: Any) -> Plugin: validated_data["url"] = self.initial_data.get("url", None) diff --git a/posthog/api/signup.py b/posthog/api/signup.py index 13ea4f9c20625..7cda79d66195d 100644 --- a/posthog/api/signup.py +++ b/posthog/api/signup.py @@ -407,9 +407,9 @@ def process_social_invite_signup(strategy: DjangoStrategy, invite_id: str, email def process_social_domain_jit_provisioning_signup( - email: str, full_name: str, user: Optional[User] = None + strategy: DjangoStrategy, email: str, full_name: str, user: Optional[User] = None ) -> Optional[User]: - # Check if the user is on a allowed domain + # Check if the user is on an allowed domain domain = email.split("@")[-1] try: logger.info(f"process_social_domain_jit_provisioning_signup", domain=domain) @@ -429,19 +429,37 @@ def process_social_domain_jit_provisioning_signup( ) if domain_instance.is_verified and domain_instance.jit_provisioning_enabled: if not user: - user = User.objects.create_and_join( - organization=domain_instance.organization, - email=email, - password=None, - first_name=full_name, - is_email_verified=True, - ) - logger.info( - f"process_social_domain_jit_provisioning_join_complete", - domain=domain, - user=user.email, - organization=domain_instance.organization_id, - ) + try: + invite: OrganizationInvite = OrganizationInvite.objects.get( + target_email=email, organization=domain_instance.organization + ) + invite.validate(user=None, email=email) + + try: + user = strategy.create_user( + email=email, first_name=full_name, password=None, is_email_verified=True + ) + assert isinstance(user, User) # type hinting + invite.use(user, prevalidated=True) + except Exception as e: + capture_exception(e) + message = "Account unable to be created. This account may already exist. Please try again or use different credentials." + raise ValidationError(message, code="unknown", params={"source": "social_create_user"}) + + except OrganizationInvite.DoesNotExist: + user = User.objects.create_and_join( + organization=domain_instance.organization, + email=email, + password=None, + first_name=full_name, + is_email_verified=True, + ) + logger.info( + f"process_social_domain_jit_provisioning_join_complete", + domain=domain, + user=user.email, + organization=domain_instance.organization_id, + ) if not user.organizations.filter(pk=domain_instance.organization_id).exists(): user.join(organization=domain_instance.organization) logger.info( @@ -471,7 +489,7 @@ def social_create_user( user.set_unusable_password() user.is_email_verified = True user.save() - process_social_domain_jit_provisioning_signup(user.email, user.first_name, user) + process_social_domain_jit_provisioning_signup(strategy, user.email, user.first_name, user) return {"is_new": False} backend_processor = "social_create_user" @@ -501,7 +519,7 @@ def social_create_user( else: # JIT Provisioning? 
- user = process_social_domain_jit_provisioning_signup(email, full_name) + user = process_social_domain_jit_provisioning_signup(strategy, email, full_name) logger.info( f"social_create_user_jit_user", full_name_len=len(full_name), diff --git a/posthog/api/survey.py b/posthog/api/survey.py index 8c18fe6170032..cf4e965b59276 100644 --- a/posthog/api/survey.py +++ b/posthog/api/survey.py @@ -379,7 +379,11 @@ def update(self, instance: Survey, validated_data): instance.targeting_flag.save() iteration_count = validated_data.get("iteration_count") - if instance.current_iteration is not None and instance.current_iteration > iteration_count > 0: + if ( + instance.current_iteration is not None + and iteration_count is not None + and instance.current_iteration > iteration_count > 0 + ): raise serializers.ValidationError( f"Cannot change survey recurrence to {iteration_count}, should be at least {instance.current_iteration}" ) diff --git a/posthog/api/test/__snapshots__/test_plugin.ambr b/posthog/api/test/__snapshots__/test_plugin.ambr index d658a166f5858..e424770da1794 100644 --- a/posthog/api/test/__snapshots__/test_plugin.ambr +++ b/posthog/api/test/__snapshots__/test_plugin.ambr @@ -141,7 +141,7 @@ "posthog_organization"."personalization", "posthog_organization"."domain_whitelist" FROM "posthog_plugin" - INNER JOIN "posthog_organization" ON ("posthog_plugin"."organization_id" = "posthog_organization"."id") + LEFT OUTER JOIN "posthog_organization" ON ("posthog_plugin"."organization_id" = "posthog_organization"."id") WHERE ("posthog_plugin"."organization_id" = '00000000-0000-0000-0000-000000000000'::uuid OR "posthog_plugin"."is_global" OR "posthog_plugin"."id" IN @@ -329,7 +329,7 @@ "posthog_organization"."personalization", "posthog_organization"."domain_whitelist" FROM "posthog_plugin" - INNER JOIN "posthog_organization" ON ("posthog_plugin"."organization_id" = "posthog_organization"."id") + LEFT OUTER JOIN "posthog_organization" ON ("posthog_plugin"."organization_id" = "posthog_organization"."id") WHERE ("posthog_plugin"."organization_id" = '00000000-0000-0000-0000-000000000000'::uuid OR "posthog_plugin"."is_global" OR "posthog_plugin"."id" IN @@ -542,7 +542,7 @@ "posthog_organization"."personalization", "posthog_organization"."domain_whitelist" FROM "posthog_plugin" - INNER JOIN "posthog_organization" ON ("posthog_plugin"."organization_id" = "posthog_organization"."id") + LEFT OUTER JOIN "posthog_organization" ON ("posthog_plugin"."organization_id" = "posthog_organization"."id") WHERE ("posthog_plugin"."organization_id" = '00000000-0000-0000-0000-000000000000'::uuid OR "posthog_plugin"."is_global" OR "posthog_plugin"."id" IN diff --git a/posthog/api/test/test_error_tracking.py b/posthog/api/test/test_error_tracking.py new file mode 100644 index 0000000000000..c2d6f74b45785 --- /dev/null +++ b/posthog/api/test/test_error_tracking.py @@ -0,0 +1,39 @@ +from posthog.test.base import APIBaseTest +from posthog.models import Team, ErrorTrackingGroup + + +class TestErrorTracking(APIBaseTest): + def test_reuses_existing_group_for_team(self): + fingerprint = "CustomFingerprint" + ErrorTrackingGroup.objects.create(fingerprint=fingerprint, team=self.team) + + self.assertEqual(ErrorTrackingGroup.objects.count(), 1) + self.client.patch( + f"/api/projects/{self.team.id}/error_tracking/{fingerprint}", + data={"assignee": self.user.id}, + ) + self.assertEqual(ErrorTrackingGroup.objects.count(), 1) + + def test_creates_group_if_not_already_existing_for_team(self): + fingerprint = "CustomFingerprint" + other_team 
= Team.objects.create(organization=self.organization)
+        ErrorTrackingGroup.objects.create(fingerprint=fingerprint, team=other_team)
+
+        self.assertEqual(ErrorTrackingGroup.objects.count(), 1)
+        self.client.patch(
+            f"/api/projects/{self.team.id}/error_tracking/{fingerprint}",
+            data={"assignee": self.user.id},
+        )
+        self.assertEqual(ErrorTrackingGroup.objects.count(), 2)
+
+    def test_can_only_update_allowed_fields(self):
+        fingerprint = "CustomFingerprint"
+        other_team = Team.objects.create(organization=self.organization)
+        group = ErrorTrackingGroup.objects.create(fingerprint=fingerprint, team=other_team)
+
+        self.client.patch(
+            f"/api/projects/{self.team.id}/error_tracking/{fingerprint}",
+            data={"fingerprint": "NewFingerprint", "assignee": self.user.id},
+        )
+        group.refresh_from_db()
+        self.assertEqual(group.fingerprint, "CustomFingerprint")
diff --git a/posthog/api/test/test_organization_invites.py b/posthog/api/test/test_organization_invites.py
index 4733e32099333..351d2b21aabc9 100644
--- a/posthog/api/test/test_organization_invites.py
+++ b/posthog/api/test/test_organization_invites.py
@@ -134,6 +134,25 @@ def test_add_organization_invite_with_email(self, mock_capture):
         self.assertListEqual(mail.outbox[0].to, [email])
         self.assertEqual(mail.outbox[0].reply_to, [self.user.email])  # Reply-To is set to the inviting user
 
+    @patch("posthoganalytics.capture")
+    def test_add_organization_invite_with_email_on_instance_but_send_email_prop_false(self, mock_capture):
+        """
+        Email is available on the instance, but the user creating the invite does not want to send an email to the invitee.
+        """
+        set_instance_setting("EMAIL_HOST", "localhost")
+        email = "x@x.com"
+
+        with self.settings(EMAIL_ENABLED=True, SITE_URL="http://test.posthog.com"):
+            response = self.client.post(
+                "/api/organizations/@current/invites/", {"target_email": email, "send_email": False}
+            )
+
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        self.assertTrue(OrganizationInvite.objects.exists())
+
+        # Assert invite email is not sent
+        self.assertEqual(len(mail.outbox), 0)
+
     def test_can_create_invites_for_the_same_email_multiple_times(self):
         email = "x@posthog.com"
         count = OrganizationInvite.objects.count()
diff --git a/posthog/api/test/test_plugin.py b/posthog/api/test/test_plugin.py
index 968a18faa8b98..0176cc6077739 100644
--- a/posthog/api/test/test_plugin.py
+++ b/posthog/api/test/test_plugin.py
@@ -885,6 +885,8 @@ def test_plugin_unused(self, mock_get, mock_reload):
     )
     def test_install_plugin_on_multiple_orgs(self, mock_get, mock_reload):
+        # Expectation: since plugins are url-unique, installing the same plugin on a second org should
+        # return a 400 response, as the plugin is already installed on the first org
         my_org = self.organization
         other_org = Organization.objects.create(
             name="FooBar2", plugins_access_level=Organization.PluginsAccessLevel.INSTALL
         )
@@ -914,6 +916,7 @@ def test_install_plugin_on_multiple_orgs(self, mock_get, mock_reload):
             f"/api/organizations/{other_org.id}/plugins/",
             {"url": "https://github.com/PostHog/helloworldplugin"},
         )
+        # Fails due to org membership
         self.assertEqual(response.status_code, 403)
         self.assertEqual(Plugin.objects.count(), 1)
 
@@ -923,14 +926,9 @@ def test_install_plugin_on_multiple_orgs(self, mock_get, mock_reload):
             f"/api/organizations/{other_org.id}/plugins/",
             {"url": "https://github.com/PostHog/helloworldplugin"},
         )
-        self.assertEqual(response.status_code, 201)
-        self.assertEqual(Plugin.objects.count(), 2)
-        response = self.client.post(
-
f"/api/organizations/{other_org.id}/plugins/", - {"url": "https://github.com/PostHog/helloworldplugin"}, - ) + # Fails since the plugin already exists self.assertEqual(response.status_code, 400) - self.assertEqual(Plugin.objects.count(), 2) + self.assertEqual(Plugin.objects.count(), 1) def test_cannot_access_others_orgs_plugins(self, mock_get, mock_reload): other_org = Organization.objects.create( diff --git a/posthog/api/test/test_signup.py b/posthog/api/test/test_signup.py index 2e9ccfe3f5e7f..919f029787483 100644 --- a/posthog/api/test/test_signup.py +++ b/posthog/api/test/test_signup.py @@ -615,7 +615,7 @@ def test_api_social_login_cannot_create_second_organization(self, mock_sso_provi response, "/login?error_code=no_new_organizations" ) # show the user an error; operation not permitted - def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_capture): + def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_capture, use_invite: bool = False): # Make sure Google Auth is valid for this test instance mock_sso_providers.return_value = {"google-oauth2": True} @@ -627,6 +627,18 @@ def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_cap organization=new_org, ) new_project = Team.objects.create(organization=new_org, name="My First Project") + + if use_invite: + private_project: Team = Team.objects.create( + organization=new_org, name="Private Project", access_control=True + ) + OrganizationInvite.objects.create( + target_email="jane@hogflix.posthog.com", + organization=new_org, + first_name="Jane", + level=OrganizationMembership.Level.MEMBER, + private_project_access=[{"id": private_project.id, "level": ExplicitTeamMembership.Level.ADMIN}], + ) user_count = User.objects.count() response = self.client.get(reverse("social:begin", kwargs={"backend": "google-oauth2"})) self.assertEqual(response.status_code, status.HTTP_302_FOUND) @@ -655,6 +667,23 @@ def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_cap ) self.assertFalse(mock_capture.call_args.kwargs["properties"]["is_organization_first_user"]) + if use_invite: + # make sure the org invite no longer exists + self.assertEqual( + OrganizationInvite.objects.filter( + organization=new_org, target_email="jane@hogflix.posthog.com" + ).count(), + 0, + ) + teams = user.teams.all() + # make sure user has access to the private project specified in the invite + self.assertTrue(teams.filter(pk=private_project.pk).exists()) + org_membership = OrganizationMembership.objects.get(organization=new_org, user=user) + explicit_team_membership = ExplicitTeamMembership.objects.get( + team=private_project, parent_membership=org_membership + ) + assert explicit_team_membership.level == ExplicitTeamMembership.Level.ADMIN + @patch("posthoganalytics.capture") @mock.patch("social_core.backends.base.BaseAuth.request") @mock.patch("posthog.api.authentication.get_instance_available_sso_providers") @@ -689,6 +718,30 @@ def test_social_signup_with_allowed_domain_on_cloud( assert mock_update_billing_customer_email.called_once() assert mock_update_billing_admin_emails.called_once() + @patch("posthoganalytics.capture") + @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_distinct_ids") + @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_customer_email") + @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_admin_emails") + @mock.patch("social_core.backends.base.BaseAuth.request") + 
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers") + @mock.patch("posthog.tasks.user_identify.identify_task") + @pytest.mark.ee + def test_social_signup_with_allowed_domain_on_cloud_with_existing_invite( + self, + mock_identify, + mock_sso_providers, + mock_request, + mock_update_distinct_ids, + mock_update_billing_customer_email, + mock_update_billing_admin_emails, + mock_capture, + ): + with self.is_cloud(True): + self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture, use_invite=True) + assert mock_update_distinct_ids.called_once() + assert mock_update_billing_customer_email.called_once() + assert mock_update_billing_admin_emails.called_once() + @mock.patch("social_core.backends.base.BaseAuth.request") @mock.patch("posthog.api.authentication.get_instance_available_sso_providers") @pytest.mark.ee diff --git a/posthog/api/test/test_survey.py b/posthog/api/test/test_survey.py index 6fd90dfbfb26c..e90ebea3b41b1 100644 --- a/posthog/api/test/test_survey.py +++ b/posthog/api/test/test_survey.py @@ -958,7 +958,6 @@ def test_can_list_surveys(self): response_data = list.json() assert list.status_code == status.HTTP_200_OK, response_data survey = Survey.objects.get(team_id=self.team.id) - assert response_data == { "count": 1, "next": None, @@ -1019,7 +1018,7 @@ def test_can_list_surveys(self): "responses_limit": None, "iteration_count": None, "iteration_frequency_days": None, - "iteration_start_dates": None, + "iteration_start_dates": [], "current_iteration": None, "current_iteration_start_date": None, } @@ -2233,6 +2232,27 @@ def _create_recurring_survey(self) -> Survey: survey = Survey.objects.get(id=response_data["id"]) return survey + def _create_non_recurring_survey(self) -> Survey: + random_id = generate("1234567890abcdef", 10) + response = self.client.post( + f"/api/projects/{self.team.id}/surveys/", + data={ + "name": f"Recurring NPS Survey {random_id}", + "description": "Get feedback on the new notebooks feature", + "type": "popover", + "questions": [ + { + "type": "open", + "question": "What's a survey?", + } + ], + }, + ) + + response_data = response.json() + survey = Survey.objects.get(id=response_data["id"]) + return survey + def test_can_create_recurring_survey(self): survey = self._create_recurring_survey() response = self.client.patch( @@ -2346,6 +2366,41 @@ def test_cannot_reduce_iterations_lt_current_iteration(self): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json()["detail"] == "Cannot change survey recurrence to 1, should be at least 2" + def test_can_handle_non_nil_current_iteration(self): + survey = self._create_non_recurring_survey() + survey.current_iteration = 2 + survey.save() + response = self.client.patch( + f"/api/projects/{self.team.id}/surveys/{survey.id}/", + data={ + "start_date": datetime.now() - timedelta(days=1), + }, + ) + assert response.status_code == status.HTTP_200_OK + + def test_guards_for_nil_iteration_count(self): + survey = self._create_recurring_survey() + survey.current_iteration = 2 + survey.save() + response = self.client.patch( + f"/api/projects/{self.team.id}/surveys/{survey.id}/", + data={ + "start_date": datetime.now() - timedelta(days=1), + }, + ) + assert response.status_code == status.HTTP_200_OK + survey.refresh_from_db() + self.assertIsNone(survey.current_iteration) + response = self.client.patch( + f"/api/projects/{self.team.id}/surveys/{survey.id}/", + data={ + "start_date": datetime.now() - timedelta(days=1), + "iteration_count": 3, + 
"iteration_frequency_days": 30, + }, + ) + assert response.status_code == status.HTTP_200_OK + def test_can_turn_off_recurring_schedule(self): survey = self._create_recurring_survey() response = self.client.patch( diff --git a/posthog/cdp/templates/customerio/template_customerio.py b/posthog/cdp/templates/customerio/template_customerio.py index 420ee678cca26..e82b234b313af 100644 --- a/posthog/cdp/templates/customerio/template_customerio.py +++ b/posthog/cdp/templates/customerio/template_customerio.py @@ -3,7 +3,7 @@ # Based off of https://customer.io/docs/api/track/#operation/entity template: HogFunctionTemplate = HogFunctionTemplate( - status="alpha", + status="beta", id="template-customerio", name="Update persons in Customer.io", description="Updates persons in Customer.io", diff --git a/posthog/cdp/templates/hubspot/template_hubspot.py b/posthog/cdp/templates/hubspot/template_hubspot.py index a046866007c84..8f2f6d7ab96ea 100644 --- a/posthog/cdp/templates/hubspot/template_hubspot.py +++ b/posthog/cdp/templates/hubspot/template_hubspot.py @@ -2,7 +2,7 @@ template: HogFunctionTemplate = HogFunctionTemplate( - status="alpha", + status="beta", id="template-hubspot", name="Create Hubspot contact", description="Creates a new contact in Hubspot whenever an event is triggered.", diff --git a/posthog/cdp/templates/intercom/template_intercom.py b/posthog/cdp/templates/intercom/template_intercom.py index 4cf3d1100926d..f110d10bbb72e 100644 --- a/posthog/cdp/templates/intercom/template_intercom.py +++ b/posthog/cdp/templates/intercom/template_intercom.py @@ -2,7 +2,7 @@ template: HogFunctionTemplate = HogFunctionTemplate( - status="alpha", + status="beta", id="template-Intercom", name="Send data to Intercom", description="Send events and contact information to Intercom", diff --git a/posthog/cdp/templates/sendgrid/template_sendgrid.py b/posthog/cdp/templates/sendgrid/template_sendgrid.py index 9dfbe9415cb86..7f08728e8dc4a 100644 --- a/posthog/cdp/templates/sendgrid/template_sendgrid.py +++ b/posthog/cdp/templates/sendgrid/template_sendgrid.py @@ -3,7 +3,7 @@ # Based off of https://www.twilio.com/docs/sendgrid/api-reference/contacts/add-or-update-a-contact template: HogFunctionTemplate = HogFunctionTemplate( - status="alpha", + status="beta", id="template-sendgrid", name="Update marketing contacts in Sendgrid", description="Update marketing contacts in Sendgrid", diff --git a/posthog/hogql_queries/insights/trends/having.sql b/posthog/hogql_queries/insights/trends/having.sql new file mode 100644 index 0000000000000..ed8045610a9a7 --- /dev/null +++ b/posthog/hogql_queries/insights/trends/having.sql @@ -0,0 +1,11 @@ +SELECT + toStartOfDay(min(timestamp)) as day_start, + argMin(ifNull(nullIf(toString(person.properties.email), ''), '$$_posthog_breakdown_null_$$'), timestamp) AS breakdown_value +FROM + events AS e SAMPLE 1 +WHERE + lessOrEquals(timestamp, assumeNotNull(toDateTime('2024-07-23 23:59:59'))) and event = '$pageview' +GROUP BY + person_id +HAVING + equals(properties.$browser, 'Safari') diff --git a/posthog/management/commands/setup_test_environment.py b/posthog/management/commands/setup_test_environment.py index 39549ec864e6d..07c39f6ce6414 100644 --- a/posthog/management/commands/setup_test_environment.py +++ b/posthog/management/commands/setup_test_environment.py @@ -26,6 +26,12 @@ class Command(BaseCommand): help = "Set up databases for non-Python tests that depend on the Django server" + # has optional arg to only run postgres setup + def add_arguments(self, parser): + parser.add_argument( 
+ "--only-postgres", action="store_true", help="Only set up the Postgres database", default=False + ) + def handle(self, *args, **options): if not TEST: raise ValueError("TEST environment variable needs to be set for this command to function") @@ -36,6 +42,10 @@ def handle(self, *args, **options): test_runner.setup_databases() test_runner.setup_test_environment() + if options["only_postgres"]: + print("Only setting up Postgres database") # noqa: T201 + return + print("\nCreating test ClickHouse database...") # noqa: T201 database = Database( CLICKHOUSE_DATABASE, diff --git a/posthog/management/commands/test/test_create_batch_export_from_app.py b/posthog/management/commands/test/test_create_batch_export_from_app.py index a5c8fffc5f4d4..9357920f909a5 100644 --- a/posthog/management/commands/test/test_create_batch_export_from_app.py +++ b/posthog/management/commands/test/test_create_batch_export_from_app.py @@ -3,6 +3,7 @@ import datetime as dt import json import typing +import uuid import pytest import temporalio.client @@ -36,11 +37,17 @@ def team(organization): team.delete() +# Used to randomize plugin URLs, to prevent tests stepping on each other, since +# plugin urls are constrained to be unique. +def append_random(url: str) -> str: + return f"{url}?random={uuid.uuid4()}" + + @pytest.fixture def snowflake_plugin(organization) -> typing.Generator[Plugin, None, None]: plugin = Plugin.objects.create( name="Snowflake Export", - url="https://github.com/PostHog/snowflake-export-plugin", + url=append_random("https://github.com/PostHog/snowflake-export-plugin"), plugin_type="custom", organization=organization, ) @@ -52,7 +59,7 @@ def snowflake_plugin(organization) -> typing.Generator[Plugin, None, None]: def s3_plugin(organization) -> typing.Generator[Plugin, None, None]: plugin = Plugin.objects.create( name="S3 Export Plugin", - url="https://github.com/PostHog/s3-export-plugin", + url=append_random("https://github.com/PostHog/s3-export-plugin"), plugin_type="custom", organization=organization, ) @@ -64,7 +71,7 @@ def s3_plugin(organization) -> typing.Generator[Plugin, None, None]: def bigquery_plugin(organization) -> typing.Generator[Plugin, None, None]: plugin = Plugin.objects.create( name="BigQuery Export", - url="https://github.com/PostHog/bigquery-plugin", + url=append_random("https://github.com/PostHog/bigquery-plugin"), plugin_type="custom", organization=organization, ) @@ -76,7 +83,7 @@ def bigquery_plugin(organization) -> typing.Generator[Plugin, None, None]: def postgres_plugin(organization) -> typing.Generator[Plugin, None, None]: plugin = Plugin.objects.create( name="PostgreSQL Export Plugin", - url="https://github.com/PostHog/postgres-plugin", + url=append_random("https://github.com/PostHog/postgres-plugin"), plugin_type="custom", organization=organization, ) @@ -88,7 +95,7 @@ def postgres_plugin(organization) -> typing.Generator[Plugin, None, None]: def redshift_plugin(organization) -> typing.Generator[Plugin, None, None]: plugin = Plugin.objects.create( name="Redshift Export Plugin", - url="https://github.com/PostHog/postgres-plugin", + url=append_random("https://github.com/PostHog/postgres-plugin"), plugin_type="custom", organization=organization, ) diff --git a/posthog/migrations/0448_add_mysql_externaldatasource_source_type.py b/posthog/migrations/0448_add_mysql_externaldatasource_source_type.py new file mode 100644 index 0000000000000..b1cc746856069 --- /dev/null +++ b/posthog/migrations/0448_add_mysql_externaldatasource_source_type.py @@ -0,0 +1,27 @@ +# Generated by Django 
4.2.11 on 2024-06-05 17:12 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("posthog", "0447_alter_integration_kind"), + ] + + operations = [ + migrations.AlterField( + model_name="externaldatasource", + name="source_type", + field=models.CharField( + choices=[ + ("Stripe", "Stripe"), + ("Hubspot", "Hubspot"), + ("Postgres", "Postgres"), + ("Zendesk", "Zendesk"), + ("Snowflake", "Snowflake"), + ("MySQL", "MySQL"), + ], + max_length=128, + ), + ), + ] diff --git a/posthog/migrations/0449_alter_plugin_organization_alter_plugin_plugin_type_and_more.py b/posthog/migrations/0449_alter_plugin_organization_alter_plugin_plugin_type_and_more.py new file mode 100644 index 0000000000000..acbeebaac82f1 --- /dev/null +++ b/posthog/migrations/0449_alter_plugin_organization_alter_plugin_plugin_type_and_more.py @@ -0,0 +1,90 @@ +# Generated by Django 4.2.14 on 2024-07-22 08:04 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + atomic = False # Added to support concurrent index creation + dependencies = [ + ("posthog", "0448_add_mysql_externaldatasource_source_type"), + ] + + operations = [ + migrations.SeparateDatabaseAndState( + state_operations=[ + migrations.AlterField( + model_name="plugin", + name="organization", + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="plugins", + related_query_name="plugin", + to="posthog.organization", + ), + ), + ], + database_operations=[ + migrations.RunSQL( + """ + SET CONSTRAINTS "posthog_plugin_organization_id_d040b9a9_fk_posthog_o" IMMEDIATE; -- existing-table-constraint-ignore + ALTER TABLE "posthog_plugin" DROP CONSTRAINT "posthog_plugin_organization_id_d040b9a9_fk_posthog_o"; -- existing-table-constraint-ignore + ALTER TABLE "posthog_plugin" ALTER COLUMN "organization_id" DROP NOT NULL; + ALTER TABLE "posthog_plugin" ADD CONSTRAINT "posthog_plugin_organization_id_d040b9a9_fk_posthog_o" FOREIGN KEY ("organization_id") REFERENCES "posthog_organization" ("id") DEFERRABLE INITIALLY DEFERRED; -- existing-table-constraint-ignore + """, + reverse_sql=""" + SET CONSTRAINTS "posthog_plugin_organization_id_d040b9a9_fk_posthog_o" IMMEDIATE; -- existing-table-constraint-ignore + ALTER TABLE "posthog_plugin" DROP CONSTRAINT "posthog_plugin_organization_id_d040b9a9_fk_posthog_o"; -- existing-table-constraint-ignore + ALTER TABLE "posthog_plugin" ALTER COLUMN "organization_id" SET NOT NULL; + ALTER TABLE "posthog_plugin" ADD CONSTRAINT "posthog_plugin_organization_id_d040b9a9_fk_posthog_o" FOREIGN KEY ("organization_id") REFERENCES "posthog_organization" ("id") DEFERRABLE INITIALLY DEFERRED; -- existing-table-constraint-ignore + """, + ), + ], + ), + migrations.AlterField( + model_name="plugin", + name="plugin_type", + field=models.CharField( + blank=True, + choices=[ + ("local", "local"), + ("custom", "custom"), + ("repository", "repository"), + ("source", "source"), + ("inline", "inline"), + ], + default=None, + max_length=200, + null=True, + ), + ), + migrations.SeparateDatabaseAndState( + state_operations=[ + migrations.AlterField( + model_name="plugin", + name="url", + field=models.CharField(blank=True, max_length=800, null=True, unique=True), + ) + ], + database_operations=[ + migrations.RunSQL( + """ + ALTER TABLE "posthog_plugin" ADD CONSTRAINT "posthog_plugin_url_bccac89d_uniq" UNIQUE ("url"); -- existing-table-constraint-ignore + """, + reverse_sql=""" + ALTER TABLE "posthog_plugin" 
DROP CONSTRAINT IF EXISTS "posthog_plugin_url_bccac89d_uniq";
+                    """,
+                ),
+                # We add the index separately
+                migrations.RunSQL(
+                    """
+                    CREATE INDEX CONCURRENTLY "posthog_plugin_url_bccac89d_like" ON "posthog_plugin" ("url" varchar_pattern_ops);
+                    """,
+                    reverse_sql="""
+                    DROP INDEX IF EXISTS "posthog_plugin_url_bccac89d_like";
+                    """,
+                ),
+            ],
+        ),
+    ]
diff --git a/posthog/migrations/0450_externaldataschema_sync_frequency_interval_and_more.py b/posthog/migrations/0450_externaldataschema_sync_frequency_interval_and_more.py
new file mode 100644
index 0000000000000..0456b88e6dca0
--- /dev/null
+++ b/posthog/migrations/0450_externaldataschema_sync_frequency_interval_and_more.py
@@ -0,0 +1,53 @@
+# Generated by Django 4.2.14 on 2024-07-24 10:13
+
+import datetime
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("posthog", "0449_alter_plugin_organization_alter_plugin_plugin_type_and_more"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="externaldataschema",
+            name="sync_frequency_interval",
+            field=models.DurationField(blank=True, null=True, default=datetime.timedelta(seconds=21600)),
+        ),
+        migrations.RunSQL(
+            sql="""
+            UPDATE posthog_externaldataschema
+            SET sync_frequency_interval = interval '24 hour'
+            WHERE sync_frequency = 'day';
+            """,
+            reverse_sql=migrations.RunSQL.noop,
+        ),
+        migrations.RunSQL(
+            sql="""
+            UPDATE posthog_externaldataschema
+            SET sync_frequency_interval = interval '7 day'
+            WHERE sync_frequency = 'week';
+            """,
+            reverse_sql=migrations.RunSQL.noop,
+        ),
+        migrations.RunSQL(
+            sql="""
+            UPDATE posthog_externaldataschema
+            SET sync_frequency_interval = interval '30 day'
+            WHERE sync_frequency = 'month';
+            """,
+            reverse_sql=migrations.RunSQL.noop,
+        ),
+        migrations.AlterField(
+            model_name="externaldataschema",
+            name="sync_frequency",
+            field=models.CharField(
+                blank=True,
+                choices=[("day", "Daily"), ("week", "Weekly"), ("month", "Monthly")],
+                default="day",
+                max_length=128,
+                null=True,
+            ),
+        ),
+    ]
diff --git a/posthog/migrations/0451_datawarehousetable_updated_at_and_more.py b/posthog/migrations/0451_datawarehousetable_updated_at_and_more.py
new file mode 100644
index 0000000000000..b5ba9d16d99d5
--- /dev/null
+++ b/posthog/migrations/0451_datawarehousetable_updated_at_and_more.py
@@ -0,0 +1,32 @@
+# Generated by Django 4.2.14 on 2024-07-24 11:20
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("posthog", "0450_externaldataschema_sync_frequency_interval_and_more"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="datawarehousetable",
+            name="updated_at",
+            field=models.DateTimeField(auto_now=True, null=True, blank=True),
+        ),
+        migrations.AddField(
+            model_name="externaldatajob",
+            name="updated_at",
+            field=models.DateTimeField(auto_now=True, null=True, blank=True),
+        ),
+        migrations.AddField(
+            model_name="externaldataschema",
+            name="updated_at",
+            field=models.DateTimeField(auto_now=True, null=True, blank=True),
+        ),
+        migrations.AddField(
+            model_name="externaldatasource",
+            name="updated_at",
+            field=models.DateTimeField(auto_now=True, null=True, blank=True),
+        ),
+    ]
diff --git a/posthog/models/async_deletion/delete_events.py b/posthog/models/async_deletion/delete_events.py
index 7997f492738fc..9c94af5203c6e 100644
--- a/posthog/models/async_deletion/delete_events.py
+++ b/posthog/models/async_deletion/delete_events.py
@@ -91,7 +91,7 @@ def _verify_by_column(self, distinct_columns: str, async_deletions:
list[AsyncDe WHERE {" OR ".join(conditions)} """, args, - settings={"max_query_size": MAX_QUERY_SIZE}, + settings={"max_query_size": MAX_QUERY_SIZE, "max_execution_time": 30 * 60}, ) return {tuple(row) for row in clickhouse_result} diff --git a/posthog/models/error_tracking/error_tracking.py b/posthog/models/error_tracking/error_tracking.py index 58e0ad06808b7..9ec2fa296c13d 100644 --- a/posthog/models/error_tracking/error_tracking.py +++ b/posthog/models/error_tracking/error_tracking.py @@ -2,7 +2,7 @@ from django.contrib.postgres.fields import ArrayField from posthog.models.utils import UUIDModel from django.db import transaction -from django.db.models import Q +from django.db.models import Q, QuerySet class ErrorTrackingGroup(UUIDModel): @@ -27,7 +27,7 @@ class Status(models.TextChoices): ) @classmethod - def filter_fingerprints(cls, queryset, fingerprints: list[str]): + def filter_fingerprints(cls, queryset, fingerprints: list[str]) -> QuerySet: query = Q(fingerprint__in=fingerprints) for fp in fingerprints: diff --git a/posthog/models/feedback/survey.py b/posthog/models/feedback/survey.py index 40aabf0aa96dc..d7d2008868c90 100644 --- a/posthog/models/feedback/survey.py +++ b/posthog/models/feedback/survey.py @@ -176,7 +176,12 @@ def update_survey_iterations(sender, instance, *args, **kwargs): iteration_count = 0 if instance.iteration_count is None else instance.iteration_count iteration_frequency_dates = 0 if instance.iteration_frequency_days is None else instance.iteration_frequency_days - if instance.iteration_count == 0 or instance.iteration_frequency_days == 0: + if ( + instance.iteration_count is None + or instance.iteration_frequency_days is None + or instance.iteration_count == 0 + or instance.iteration_frequency_days == 0 + ): instance.iteration_start_dates = [] instance.current_iteration = None instance.current_iteration_start_date = None diff --git a/posthog/models/plugin.py b/posthog/models/plugin.py index 26b3cdde676ca..19d07578cf4a5 100644 --- a/posthog/models/plugin.py +++ b/posthog/models/plugin.py @@ -38,15 +38,11 @@ pass -def raise_if_plugin_installed(url: str, organization_id: str): +def raise_if_plugin_installed(url: str): url_without_private_key = url.split("?")[0] - if ( - Plugin.objects.filter( - models.Q(url=url_without_private_key) | models.Q(url__startswith=f"{url_without_private_key}?") - ) - .filter(organization_id=organization_id) - .exists() - ): + if Plugin.objects.filter( + models.Q(url=url_without_private_key) | models.Q(url__startswith=f"{url_without_private_key}?") + ).exists(): raise ValidationError(f'Plugin from URL "{url_without_private_key}" already installed!') @@ -125,7 +121,7 @@ def install(self, **kwargs) -> "Plugin": plugin_json: Optional[dict[str, Any]] = None if kwargs.get("plugin_type", None) != Plugin.PluginType.SOURCE: plugin_json = update_validated_data_from_url(kwargs, kwargs["url"]) - raise_if_plugin_installed(kwargs["url"], kwargs["organization_id"]) + raise_if_plugin_installed(kwargs["url"]) plugin = Plugin.objects.create(**kwargs) if plugin_json: PluginSourceFile.objects.sync_from_plugin_archive(plugin, plugin_json) @@ -149,12 +145,18 @@ class PluginType(models.TextChoices): "source", "source", ) # coded inside the browser (versioned via plugin_source_version) + INLINE = ( + "inline", + "inline", + ) # Code checked into plugin_server, url starts with "inline:" + # DEPRECATED: plugin-server will own all plugin code, org relations don't make sense organization: models.ForeignKey = models.ForeignKey( "posthog.Organization", 
on_delete=models.CASCADE, related_name="plugins", related_query_name="plugin", + null=True, ) plugin_type: models.CharField = models.CharField( max_length=200, null=True, blank=True, choices=PluginType.choices, default=None @@ -167,7 +169,7 @@ class PluginType(models.TextChoices): name: models.CharField = models.CharField(max_length=200, null=True, blank=True) description: models.TextField = models.TextField(null=True, blank=True) - url: models.CharField = models.CharField(max_length=800, null=True, blank=True) + url: models.CharField = models.CharField(max_length=800, null=True, blank=True, unique=True) icon: models.CharField = models.CharField(max_length=800, null=True, blank=True) # Describe the fields to ask in the interface; store answers in PluginConfig->config # - config_schema = { [fieldKey]: { name: 'api key', type: 'string', default: '', required: true } } diff --git a/posthog/models/utils.py b/posthog/models/utils.py index f16bd09984e13..6bfcc81c30825 100644 --- a/posthog/models/utils.py +++ b/posthog/models/utils.py @@ -146,6 +146,13 @@ class Meta: abstract = True +class UpdatedMetaFields(models.Model): + updated_at: models.DateTimeField = models.DateTimeField(auto_now=True, null=True, blank=True) + + class Meta: + abstract = True + + class DeletedMetaFields(models.Model): deleted: models.BooleanField = models.BooleanField(null=True, blank=True) diff --git a/posthog/session_recordings/test/__snapshots__/test_session_recordings.ambr b/posthog/session_recordings/test/__snapshots__/test_session_recordings.ambr index e1e2bd01a820a..16bc71a7219e4 100644 --- a/posthog/session_recordings/test/__snapshots__/test_session_recordings.ambr +++ b/posthog/session_recordings/test/__snapshots__/test_session_recordings.ambr @@ -474,6 +474,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -518,6 +519,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -652,6 +654,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -696,6 +699,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -1408,6 +1412,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -1452,6 +1457,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ 
-1816,6 +1822,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -1860,6 +1867,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -1931,6 +1939,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -1975,6 +1984,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -2394,6 +2404,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -2438,6 +2449,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -2509,6 +2521,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -2553,6 +2566,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -2935,6 +2949,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -2979,6 +2994,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -3106,6 +3122,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -3150,6 +3167,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", 
"posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -3533,6 +3551,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -3577,6 +3596,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -3648,6 +3668,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -3692,6 +3713,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -3864,6 +3886,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -3908,6 +3931,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -4010,6 +4034,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -4054,6 +4079,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -4415,6 +4441,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -4459,6 +4486,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -4530,6 +4558,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -4574,6 +4603,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + 
"posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -4912,6 +4942,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -4956,6 +4987,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -4992,6 +5024,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -5036,6 +5069,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -5107,6 +5141,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -5151,6 +5186,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -5508,6 +5544,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -5552,6 +5589,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -5641,6 +5679,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -5685,6 +5724,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -6039,6 +6079,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -6083,6 +6124,7 @@ "posthog_datawarehousecredential"."team_id", 
"posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -6154,6 +6196,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -6198,6 +6241,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -6257,6 +6301,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -6301,6 +6346,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", @@ -6610,6 +6656,7 @@ ''' SELECT "posthog_datawarehousetable"."created_by_id", "posthog_datawarehousetable"."created_at", + "posthog_datawarehousetable"."updated_at", "posthog_datawarehousetable"."deleted", "posthog_datawarehousetable"."id", "posthog_datawarehousetable"."name", @@ -6654,6 +6701,7 @@ "posthog_datawarehousecredential"."team_id", "posthog_externaldatasource"."created_by_id", "posthog_externaldatasource"."created_at", + "posthog_externaldatasource"."updated_at", "posthog_externaldatasource"."id", "posthog_externaldatasource"."source_id", "posthog_externaldatasource"."connection_id", diff --git a/posthog/templates/email/batch_export_run_failure.html b/posthog/templates/email/batch_export_run_failure.html index 04cf2021e342c..58244e49a5bc8 100644 --- a/posthog/templates/email/batch_export_run_failure.html +++ b/posthog/templates/email/batch_export_run_failure.html @@ -3,7 +3,7 @@ {% block heading %}PostHog batch export {{ name }} has failed{% endblock %} {% block section %}

    - There's been a fatal error with your batch export {{ name }} at {{ time }}. Due to the nature of the error, it cannot be retried automatically and requires manual intervention. + There's been a fatal error with your batch export {{ name }} at {{ time }}. Due to the nature of the error, we could not automatically recover from it, and manual intervention is required. We recommend reviewing the batch export logs for error details:

    @@ -14,7 +14,9 @@

    - After reviewing the logs, and addressing any errors in them, you can retry the batch export run manually. If the batch export continues to fail we will disable it. + In the logs you may find configuration errors that you can address yourself, like an incorrect credential or an unreachable warehouse. If you're feeling extra-adventurous, sometimes a manual retry can fix things! + + If you can't diagnose the issue, please contact us for help. Keep in mind that if the batch export continues to fail, we will have to disable it.

    {% endblock %} diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index f2b04f3cac869..e0d629c54e35c 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -434,15 +434,16 @@ async def finish_batch_export_run(inputs: FinishBatchExportRunInputs) -> None: logger.error("Batch export failed with error: %s", batch_export_run.latest_error) elif batch_export_run.status == BatchExportRun.Status.FAILED: - logger.error("Batch export failed with non-retryable error: %s", batch_export_run.latest_error) + logger.error("Batch export failed with non-recoverable error: %s", batch_export_run.latest_error) from posthog.tasks.email import send_batch_export_run_failure try: - logger.info("Sending failure notification email for run %s", inputs.id) await database_sync_to_async(send_batch_export_run_failure)(inputs.id) except Exception: logger.exception("Failure email notification could not be sent") + else: + logger.info("Failure notification email for run %s has been sent", inputs.id) is_over_failure_threshold = await check_if_over_failure_threshold( inputs.batch_export_id, diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index f1e7b7c5b157d..735df5b7f3505 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -428,11 +428,12 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records async def flush_to_bigquery( local_results_file, - records_since_last_flush, - bytes_since_last_flush, + records_since_last_flush: int, + bytes_since_last_flush: int, flush_counter: int, last_inserted_at, - last, + last: bool, + error: Exception | None, ): logger.debug( "Loading %s records of size %s bytes", diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py index 5535b82a251b9..75093fe444986 100644 --- a/posthog/temporal/batch_exports/postgres_batch_export.py +++ b/posthog/temporal/batch_exports/postgres_batch_export.py @@ -522,6 +522,7 @@ async def flush_to_postgres( flush_counter: int, last_inserted_at, last: bool, + error: Exception | None, ): logger.debug( "Copying %s records of size %s bytes", diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index f371a5421ce42..7150d116feb71 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -7,6 +7,7 @@ import typing import aioboto3 +import botocore.exceptions import pyarrow as pa from django.conf import settings from temporalio import activity, workflow @@ -117,6 +118,18 @@ def __init__(self): super().__init__("No multi-part upload is in progress. Call 'create' to start one.") +class IntermittentUploadPartTimeoutError(Exception): + """Exception raised when an S3 upload part times out. + + This is generally a transient or intermittent error that can be handled by a retry. + However, it's wrapped by a `botocore.exceptions.ClientError` that generally includes + non-retryable errors. So, we can re-raise our own exception in those cases. 
+ """ + + def __init__(self, part_number: int): + super().__init__(f"An intermittent `RequestTimeout` was raised while attempting to upload part {part_number}") + + class S3MultiPartUploadState(typing.NamedTuple): upload_id: str parts: list[dict[str, str | int]] @@ -274,13 +287,22 @@ async def upload_part(self, body: BatchExportTemporaryFile, rewind: bool = True) reader = io.BufferedReader(body) # type: ignore async with self.s3_client() as s3_client: - response = await s3_client.upload_part( - Bucket=self.bucket_name, - Key=self.key, - PartNumber=next_part_number, - UploadId=self.upload_id, - Body=reader, - ) + try: + response = await s3_client.upload_part( + Bucket=self.bucket_name, + Key=self.key, + PartNumber=next_part_number, + UploadId=self.upload_id, + Body=reader, + ) + except botocore.exceptions.ClientError as err: + error_code = err.response.get("Error", {}).get("Code", None) + + if error_code is not None and error_code == "RequestTimeout": + raise IntermittentUploadPartTimeoutError(part_number=next_part_number) from err + else: + raise + reader.detach() # BufferedReader closes the file otherwise. self.parts.append({"PartNumber": next_part_number, "ETag": response["ETag"]}) @@ -485,7 +507,16 @@ async def flush_to_s3( flush_counter: int, last_inserted_at: dt.datetime, last: bool, + error: Exception | None, ): + if error is not None: + logger.debug("Error while writing part %d", s3_upload.part_number + 1, exc_info=error) + logger.warn( + "An error was detected while writing part %d. Partial part will not be uploaded in case it can be retried.", + s3_upload.part_number + 1, + ) + return + logger.debug( "Uploading %s part %s containing %s records with size %s bytes", "last " if last else "", diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index 0d69335836399..e7fc189b3fc28 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -627,6 +627,7 @@ async def flush_to_snowflake( flush_counter: int, last_inserted_at, last: bool, + error: Exception | None, ): logger.info( "Putting %sfile %s containing %s records with size %s bytes", diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index d69e41edfa495..39651f3560d72 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -243,7 +243,15 @@ def reset(self): BytesSinceLastFlush = int FlushCounter = int FlushCallable = collections.abc.Callable[ - [BatchExportTemporaryFile, RecordsSinceLastFlush, BytesSinceLastFlush, FlushCounter, LastInsertedAt, IsLast], + [ + BatchExportTemporaryFile, + RecordsSinceLastFlush, + BytesSinceLastFlush, + FlushCounter, + LastInsertedAt, + IsLast, + Exception | None, + ], collections.abc.Awaitable[None], ] @@ -306,6 +314,7 @@ def reset_writer_tracking(self): self.bytes_total = 0 self.bytes_since_last_flush = 0 self.flush_counter = 0 + self.error = None @contextlib.asynccontextmanager async def open_temporary_file(self, current_flush_counter: int = 0): @@ -325,6 +334,9 @@ async def open_temporary_file(self, current_flush_counter: int = 0): try: yield + except Exception as e: + self.error = e + raise finally: self.track_bytes_written(temp_file) @@ -401,6 +413,7 @@ async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> N self.flush_counter, last_inserted_at, is_last, + self.error, ) self.batch_export_file.reset() 
diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py index 4fc6e10200866..76ca85db9be5f 100644 --- a/posthog/temporal/data_imports/external_data_job.py +++ b/posthog/temporal/data_imports/external_data_job.py @@ -32,6 +32,14 @@ ExternalDataSource, ) from posthog.temporal.common.logger import bind_temporal_worker_logger +from posthog.warehouse.models.external_data_schema import aupdate_should_sync + + +Non_Retryable_Schema_Errors = [ + "NoSuchTableError", + "401 Client Error: Unauthorized for url: https://api.stripe.com", + "403 Client Error: Forbidden for url: https://api.stripe.com", +] @dataclasses.dataclass @@ -54,6 +62,11 @@ async def update_external_data_job_model(inputs: UpdateExternalDataJobStatusInpu f"External data job failed for external data schema {inputs.schema_id} with error: {inputs.internal_error}" ) + has_non_retryable_error = any(error in inputs.internal_error for error in Non_Retryable_Schema_Errors) + if has_non_retryable_error: + logger.info("Schema has a non-retryable error - turning off syncing") + await aupdate_should_sync(schema_id=inputs.schema_id, team_id=inputs.team_id, should_sync=False) + await sync_to_async(update_external_job_status)( run_id=uuid.UUID(inputs.id), status=inputs.status, @@ -177,7 +190,7 @@ async def run(self, inputs: ExternalDataWorkflowInputs): await workflow.execute_activity( import_data_activity, job_inputs, - heartbeat_timeout=dt.timedelta(minutes=1), + heartbeat_timeout=dt.timedelta(minutes=2), **timeout_params, ) # type: ignore diff --git a/posthog/temporal/data_imports/pipelines/schemas.py b/posthog/temporal/data_imports/pipelines/schemas.py index 7dccb65eca59b..8c0355b34d6ed 100644 --- a/posthog/temporal/data_imports/pipelines/schemas.py +++ b/posthog/temporal/data_imports/pipelines/schemas.py @@ -21,6 +21,7 @@ ), ExternalDataSource.Type.POSTGRES: (), ExternalDataSource.Type.SNOWFLAKE: (), + ExternalDataSource.Type.MYSQL: (), } PIPELINE_TYPE_INCREMENTAL_ENDPOINTS_MAPPING = { @@ -29,6 +30,7 @@ ExternalDataSource.Type.ZENDESK: ZENDESK_INCREMENTAL_ENDPOINTS, ExternalDataSource.Type.POSTGRES: (), ExternalDataSource.Type.SNOWFLAKE: (), + ExternalDataSource.Type.MYSQL: (), } PIPELINE_TYPE_INCREMENTAL_FIELDS_MAPPING: dict[ExternalDataSource.Type, dict[str, list[IncrementalField]]] = { @@ -37,4 +39,5 @@ ExternalDataSource.Type.ZENDESK: ZENDESK_INCREMENTAL_FIELDS, ExternalDataSource.Type.POSTGRES: {}, ExternalDataSource.Type.SNOWFLAKE: {}, + ExternalDataSource.Type.MYSQL: {}, } diff --git a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py index 700e3af65b99e..6f9ec4c1162b7 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py +++ b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py @@ -16,6 +16,8 @@ from urllib.parse import quote from posthog.warehouse.types import IncrementalFieldType +from posthog.warehouse.models.external_data_source import ExternalDataSource +from sqlalchemy.sql import text from .helpers import ( table_rows, @@ -34,7 +36,8 @@ def incremental_type_to_initial_value(field_type: IncrementalFieldType) -> Any: return date(1970, 1, 1) -def postgres_source( +def sql_source_for_type( + source_type: ExternalDataSource.Type, host: str, port: int, user: str, @@ -52,10 +55,6 @@ def postgres_source( database = quote(database) sslmode = quote(sslmode) - credentials = ConnectionStringCredentials( - 
f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={sslmode}" - ) - if incremental_field is not None and incremental_field_type is not None: incremental: dlt.sources.incremental | None = dlt.sources.incremental( cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type) @@ -63,6 +62,15 @@ def postgres_source( else: incremental = None + if source_type == ExternalDataSource.Type.POSTGRES: + credentials = ConnectionStringCredentials( + f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={sslmode}" + ) + elif source_type == ExternalDataSource.Type.MYSQL: + credentials = ConnectionStringCredentials(f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}") + else: + raise Exception("Unsupported source_type") + db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental) return db_source @@ -78,7 +86,7 @@ def snowflake_source( table_names: list[str], role: Optional[str] = None, incremental_field: Optional[str] = None, - incremental_field_type: Optional[str] = None, + incremental_field_type: Optional[IncrementalFieldType] = None, ) -> DltSource: account_id = quote(account_id) user = quote(user) @@ -87,10 +95,17 @@ def snowflake_source( warehouse = quote(warehouse) role = quote(role) if role else None + if incremental_field is not None and incremental_field_type is not None: + incremental: dlt.sources.incremental | None = dlt.sources.incremental( + cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type) + ) + else: + incremental = None + credentials = ConnectionStringCredentials( f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}" ) - db_source = sql_database(credentials, schema=schema, table_names=table_names) + db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental) return db_source @@ -137,7 +152,12 @@ def sql_database( name=table.name, primary_key=get_primary_key(table), merge_key=get_primary_key(table), - write_disposition="merge" if incremental else "replace", + write_disposition={ + "disposition": "merge", + "strategy": "upsert", + } + if incremental + else "replace", spec=SqlDatabaseTableConfiguration, table_format="delta", columns=get_column_hints(engine, schema or "", table.name), @@ -150,14 +170,13 @@ def sql_database( def get_column_hints(engine: Engine, schema_name: str, table_name: str) -> dict[str, TColumnSchema]: with engine.connect() as conn: - execute_result: CursorResult | None = conn.execute( - "SELECT column_name, data_type, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_schema = %(schema_name)s AND table_name = %(table_name)s", + execute_result: CursorResult = conn.execute( + text( + "SELECT column_name, data_type, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_schema = :schema_name AND table_name = :table_name" + ), {"schema_name": schema_name, "table_name": table_name}, ) - if execute_result is None: - return {} - cursor_result = cast(CursorResult, execute_result) results = cursor_result.fetchall() diff --git a/posthog/temporal/data_imports/pipelines/stripe/__init__.py b/posthog/temporal/data_imports/pipelines/stripe/__init__.py index c6ff35e971a74..7bb66a63d1b06 100644 --- a/posthog/temporal/data_imports/pipelines/stripe/__init__.py +++ b/posthog/temporal/data_imports/pipelines/stripe/__init__.py @@ -14,7 +14,12 @@ def get_resource(name: 
str, is_incremental: bool) -> EndpointResource: "name": "BalanceTransaction", "table_name": "balance_transaction", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("stripe_balancetransaction"), # type: ignore "endpoint": { "data_selector": "data", @@ -44,7 +49,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "Charge", "table_name": "charge", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("stripe_charge"), # type: ignore "endpoint": { "data_selector": "data", @@ -73,7 +83,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "Customer", "table_name": "customer", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("stripe_customer"), # type: ignore "endpoint": { "data_selector": "data", @@ -101,7 +116,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "Invoice", "table_name": "invoice", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("stripe_invoice"), # type: ignore "endpoint": { "data_selector": "data", @@ -132,7 +152,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "Price", "table_name": "price", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("stripe_price"), # type: ignore "endpoint": { "data_selector": "data", @@ -164,7 +189,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "Product", "table_name": "product", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("stripe_product"), # type: ignore "endpoint": { "data_selector": "data", @@ -194,7 +224,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "Subscription", "table_name": "subscription", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("stripe_subscription"), # type: ignore "endpoint": { "data_selector": "data", @@ -274,7 +309,12 @@ def stripe_source( }, "resource_defaults": { "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", }, "resources": [get_resource(endpoint, is_incremental)], } diff --git 
a/posthog/temporal/data_imports/pipelines/zendesk/__init__.py b/posthog/temporal/data_imports/pipelines/zendesk/__init__.py index 2e7859935f37d..0bf2510cce8f3 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/__init__.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/__init__.py @@ -1,6 +1,6 @@ import base64 import dlt -from dlt.sources.helpers.rest_client.paginators import BasePaginator +from dlt.sources.helpers.rest_client.paginators import BasePaginator, JSONLinkPaginator from dlt.sources.helpers.requests import Response, Request import requests from posthog.temporal.data_imports.pipelines.rest_source import RESTAPIConfig, rest_api_resources @@ -14,15 +14,17 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "brands", "table_name": "brands", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_brands"), # type: ignore "endpoint": { "data_selector": "brands", "path": "/api/v2/brands", - "paginator": { - "type": "json_response", - "next_url_path": "links.next", - }, + "paginator": JSONLinkPaginator(next_url_path="links.next"), "params": { "page[size]": 100, }, @@ -33,15 +35,17 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "organizations", "table_name": "organizations", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_organizations"), # type: ignore "endpoint": { "data_selector": "organizations", "path": "/api/v2/organizations", - "paginator": { - "type": "json_response", - "next_url_path": "links.next", - }, + "paginator": JSONLinkPaginator(next_url_path="links.next"), "params": { "page[size]": 100, }, @@ -52,15 +56,17 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "groups", "table_name": "groups", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_groups"), # type: ignore "endpoint": { "data_selector": "groups", "path": "/api/v2/groups", - "paginator": { - "type": "json_response", - "next_url_path": "links.next", - }, + "paginator": JSONLinkPaginator(next_url_path="links.next"), "params": { # the parameters below can optionally be configured # "exclude_deleted": "OPTIONAL_CONFIG", @@ -73,15 +79,17 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "sla_policies", "table_name": "sla_policies", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_sla_policies"), # type: ignore "endpoint": { "data_selector": "sla_policies", "path": "/api/v2/slas/policies", - "paginator": { - "type": "json_response", - "next_url_path": "links.next", - }, + "paginator": JSONLinkPaginator(next_url_path="links.next"), }, "table_format": "delta", }, @@ -89,15 +97,17 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": 
"users", "table_name": "users", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_users"), # type: ignore "endpoint": { "data_selector": "users", "path": "/api/v2/users", - "paginator": { - "type": "json_response", - "next_url_path": "links.next", - }, + "paginator": JSONLinkPaginator(next_url_path="links.next"), "params": { # the parameters below can optionally be configured # "role": "OPTIONAL_CONFIG", @@ -113,15 +123,17 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "ticket_fields", "table_name": "ticket_fields", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_ticket_fields"), # type: ignore "endpoint": { "data_selector": "ticket_fields", "path": "/api/v2/ticket_fields", - "paginator": { - "type": "json_response", - "next_url_path": "links.next", - }, + "paginator": JSONLinkPaginator(next_url_path="links.next"), "params": { # the parameters below can optionally be configured # "locale": "OPTIONAL_CONFIG", @@ -135,7 +147,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "ticket_events", "table_name": "ticket_events", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_ticket_events"), # type: ignore "endpoint": { "data_selector": "ticket_events", @@ -159,7 +176,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "tickets", "table_name": "tickets", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_tickets"), # type: ignore "endpoint": { "data_selector": "tickets", @@ -182,7 +204,12 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource: "name": "ticket_metric_events", "table_name": "ticket_metric_events", "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", "columns": get_dlt_mapping_for_external_table("zendesk_ticket_metric_events"), # type: ignore "endpoint": { "data_selector": "ticket_metric_events", @@ -274,7 +301,12 @@ def zendesk_source( }, "resource_defaults": { "primary_key": "id", - "write_disposition": "merge" if is_incremental else "replace", + "write_disposition": { + "disposition": "merge", + "strategy": "upsert", + } + if is_incremental + else "replace", }, "resources": [get_resource(endpoint, is_incremental)], } diff --git a/posthog/temporal/data_imports/workflow_activities/create_job_model.py b/posthog/temporal/data_imports/workflow_activities/create_job_model.py index a35bb1667e7b0..21f5e046d1a28 100644 --- a/posthog/temporal/data_imports/workflow_activities/create_job_model.py +++ b/posthog/temporal/data_imports/workflow_activities/create_job_model.py @@ -13,7 +13,7 @@ from 
posthog.warehouse.models import sync_old_schemas_with_new_schemas, ExternalDataSource, aget_schema_by_id from posthog.warehouse.models.external_data_schema import ( ExternalDataSchema, - get_postgres_schemas, + get_sql_schemas_for_source_type, get_snowflake_schemas, ) from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -46,7 +46,7 @@ async def create_external_data_job_model_activity(inputs: CreateExternalDataJobM source = await sync_to_async(ExternalDataSource.objects.get)(team_id=inputs.team_id, id=inputs.source_id) - if source.source_type == ExternalDataSource.Type.POSTGRES: + if source.source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]: host = source.job_inputs.get("host") port = source.job_inputs.get("port") user = source.job_inputs.get("user") @@ -74,8 +74,8 @@ async def create_external_data_job_model_activity(inputs: CreateExternalDataJobM private_key=ssh_tunnel_auth_type_private_key, ) - schemas_to_sync = await sync_to_async(get_postgres_schemas)( - host, port, database, user, password, db_schema, ssh_tunnel + schemas_to_sync = await sync_to_async(get_sql_schemas_for_source_type)( + source.source_type, host, port, database, user, password, db_schema, ssh_tunnel ) elif source.source_type == ExternalDataSource.Type.SNOWFLAKE: account_id = source.job_inputs.get("account_id") diff --git a/posthog/temporal/data_imports/workflow_activities/import_data.py b/posthog/temporal/data_imports/workflow_activities/import_data.py index 2cba1697ef44e..103408db92bd4 100644 --- a/posthog/temporal/data_imports/workflow_activities/import_data.py +++ b/posthog/temporal/data_imports/workflow_activities/import_data.py @@ -4,6 +4,7 @@ from temporalio import activity +from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.data_imports.pipelines.helpers import aremove_reset_pipeline, aupdate_job_count from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline, PipelineInputs @@ -13,7 +14,6 @@ get_external_data_job, ) from posthog.temporal.common.logger import bind_temporal_worker_logger -import asyncio from structlog.typing import FilteringBoundLogger from posthog.warehouse.models.external_data_schema import ExternalDataSchema, aget_schema_by_id from posthog.warehouse.models.ssh_tunnel import SSHTunnel @@ -102,8 +102,8 @@ async def import_data_activity(inputs: ImportDataActivityInputs): schema=schema, reset_pipeline=reset_pipeline, ) - elif model.pipeline.source_type == ExternalDataSource.Type.POSTGRES: - from posthog.temporal.data_imports.pipelines.sql_database import postgres_source + elif model.pipeline.source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]: + from posthog.temporal.data_imports.pipelines.sql_database import sql_source_for_type host = model.pipeline.job_inputs.get("host") port = model.pipeline.job_inputs.get("port") @@ -137,7 +137,8 @@ async def import_data_activity(inputs: ImportDataActivityInputs): if tunnel is None: raise Exception("Can't open tunnel to SSH server") - source = postgres_source( + source = sql_source_for_type( + source_type=model.pipeline.source_type, host=tunnel.local_bind_host, port=tunnel.local_bind_port, user=user, @@ -163,7 +164,8 @@ async def import_data_activity(inputs: ImportDataActivityInputs): reset_pipeline=reset_pipeline, ) - source = postgres_source( + source = sql_source_for_type( + source_type=model.pipeline.source_type, host=host, port=port, user=user, @@ -206,6 +208,10 @@ async def import_data_activity(inputs: 
ImportDataActivityInputs): warehouse=warehouse, role=role, table_names=endpoints, + incremental_field=schema.sync_type_config.get("incremental_field") if schema.is_incremental else None, + incremental_field_type=schema.sync_type_config.get("incremental_field_type") + if schema.is_incremental + else None, ) return await _run( @@ -250,15 +256,7 @@ async def _run( schema: ExternalDataSchema, reset_pipeline: bool, ): - # Temp background heartbeat for now - async def heartbeat() -> None: - while True: - await asyncio.sleep(10) - activity.heartbeat() - - heartbeat_task = asyncio.create_task(heartbeat()) - - try: + async with Heartbeater(): table_row_counts = await DataImportPipeline( job_inputs, source, logger, reset_pipeline, schema.is_incremental ).run() @@ -266,6 +264,3 @@ async def heartbeat() -> None: await aupdate_job_count(inputs.run_id, inputs.team_id, total_rows_synced) await aremove_reset_pipeline(inputs.source_id) - finally: - heartbeat_task.cancel() - await asyncio.wait([heartbeat_task]) diff --git a/posthog/temporal/tests/batch_exports/test_import_data.py b/posthog/temporal/tests/batch_exports/test_import_data.py index 2b743102056d0..935781c3bdf34 100644 --- a/posthog/temporal/tests/batch_exports/test_import_data.py +++ b/posthog/temporal/tests/batch_exports/test_import_data.py @@ -70,12 +70,13 @@ async def test_postgres_source_without_ssh_tunnel(activity_environment, team, ** activity_inputs = await _setup(team, job_inputs) with ( - mock.patch("posthog.temporal.data_imports.pipelines.sql_database.postgres_source") as postgres_source, + mock.patch("posthog.temporal.data_imports.pipelines.sql_database.sql_source_for_type") as sql_source_for_type, mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"), ): await activity_environment.run(import_data_activity, activity_inputs) - postgres_source.assert_called_once_with( + sql_source_for_type.assert_called_once_with( + source_type=ExternalDataSource.Type.POSTGRES, host="host.com", port="5432", user="Username", @@ -107,12 +108,13 @@ async def test_postgres_source_with_ssh_tunnel_disabled(activity_environment, te activity_inputs = await _setup(team, job_inputs) with ( - mock.patch("posthog.temporal.data_imports.pipelines.sql_database.postgres_source") as postgres_source, + mock.patch("posthog.temporal.data_imports.pipelines.sql_database.sql_source_for_type") as sql_source_for_type, mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"), ): await activity_environment.run(import_data_activity, activity_inputs) - postgres_source.assert_called_once_with( + sql_source_for_type.assert_called_once_with( + source_type=ExternalDataSource.Type.POSTGRES, host="host.com", port="5432", user="Username", @@ -160,13 +162,14 @@ def __exit__(self, exc_type, exc_value, exc_traceback): return MockedTunnel() with ( - mock.patch("posthog.temporal.data_imports.pipelines.sql_database.postgres_source") as postgres_source, + mock.patch("posthog.temporal.data_imports.pipelines.sql_database.sql_source_for_type") as sql_source_for_type, mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"), mock.patch.object(SSHTunnel, "get_tunnel", mock_get_tunnel), ): await activity_environment.run(import_data_activity, activity_inputs) - postgres_source.assert_called_once_with( + sql_source_for_type.assert_called_once_with( + source_type=ExternalDataSource.Type.POSTGRES, host="other-host.com", port=55550, user="Username", diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py 
b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index 624708410a220..4c959170dba9e 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -1,9 +1,12 @@ import asyncio +import contextlib import datetime as dt import functools +import io import json import os import uuid +from unittest import mock import aioboto3 import botocore.exceptions @@ -26,9 +29,11 @@ from posthog.temporal.batch_exports.s3_batch_export import ( FILE_FORMAT_EXTENSIONS, HeartbeatDetails, + IntermittentUploadPartTimeoutError, S3BatchExportInputs, S3BatchExportWorkflow, S3InsertInputs, + S3MultiPartUpload, get_s3_key, insert_into_s3_activity, s3_default_fields, @@ -1251,3 +1256,178 @@ def track_hearbeat_details(*details): data_interval_start=data_interval_start, data_interval_end=data_interval_end, ) + + +async def test_s3_multi_part_upload_raises_retryable_exception(bucket_name, s3_key_prefix): + """Test a retryable exception is raised instead of a `RequestTimeout`. + + Even though they should be retryable, `RequestTimeout`s are wrapped by `ClientError`, which + are all non-retryable. So, we assert our own exception is raised instead. + """ + s3_upload = S3MultiPartUpload( + bucket_name=bucket_name, + key=s3_key_prefix, + encryption=None, + kms_key_id=None, + region_name="us-east-1", + aws_access_key_id="object_storage_root_user", + aws_secret_access_key="object_storage_root_password", + endpoint_url=settings.OBJECT_STORAGE_ENDPOINT, + ) + + async def faulty_upload_part(*args, **kwargs): + raise botocore.exceptions.ClientError( + error_response={ + "Error": {"Code": "RequestTimeout", "Message": "Oh no!"}, + "ResponseMetadata": {"MaxAttemptsReached": True, "RetryAttempts": 2}, # type: ignore + }, + operation_name="UploadPart", + ) + + class FakeSession(aioboto3.Session): + @contextlib.asynccontextmanager + async def client(self, *args, **kwargs): + client = self._session.create_client(*args, **kwargs) + client.upload_part = faulty_upload_part + + yield client + + s3_upload._session = FakeSession() + + with pytest.raises(IntermittentUploadPartTimeoutError): + await s3_upload.upload_part(io.BytesIO(b"1010"), rewind=False) # type: ignore + + +@pytest.mark.parametrize("model", [TEST_S3_MODELS[1], TEST_S3_MODELS[2], None]) +async def test_s3_export_workflow_with_request_timeouts( + clickhouse_client, + ateam, + minio_client, + bucket_name, + interval, + s3_batch_export, + s3_key_prefix, + data_interval_end, + data_interval_start, + model: BatchExportModel | BatchExportSchema | None, + generate_test_data, +): + """Test the S3BatchExport Workflow end-to-end when a `RequestTimeout` occurs. + + We run the S3 batch export workflow with a mocked session that will raise a `ClientError` due + to a `RequestTimeout` on the first run of the batch export. The second run should work normally. 
+ """ + batch_export_schema: BatchExportSchema | None = None + batch_export_model: BatchExportModel | None = None + if isinstance(model, BatchExportModel): + batch_export_model = model + elif model is not None: + batch_export_schema = model + + raised = False + + class FakeSession(aioboto3.Session): + @contextlib.asynccontextmanager + async def client(self, *args, **kwargs): + client = self._session.create_client(*args, **kwargs) + + async with client as client: + original_upload_part = client.upload_part + + async def faulty_upload_part(*args, **kwargs): + nonlocal raised + + if not raised: + raised = True + raise botocore.exceptions.ClientError( + error_response={ + "Error": {"Code": "RequestTimeout", "Message": "Oh no!"}, + "ResponseMetadata": {"MaxAttemptsReached": True, "RetryAttempts": 2}, # type: ignore + }, + operation_name="UploadPart", + ) + else: + return await original_upload_part(*args, **kwargs) + + client.upload_part = faulty_upload_part + + yield client + + workflow_id = str(uuid.uuid4()) + inputs = S3BatchExportInputs( + team_id=ateam.pk, + batch_export_id=str(s3_batch_export.id), + data_interval_end=data_interval_end.isoformat(), + batch_export_model=batch_export_model, + batch_export_schema=batch_export_schema, + interval=interval, + **s3_batch_export.destination.config, + ) + + async with await WorkflowEnvironment.start_time_skipping() as activity_environment: + async with Worker( + activity_environment.client, + task_queue=settings.TEMPORAL_TASK_QUEUE, + workflows=[S3BatchExportWorkflow], + activities=[ + start_batch_export_run, + insert_into_s3_activity, + finish_batch_export_run, + ], + workflow_runner=UnsandboxedWorkflowRunner(), + ): + with mock.patch("posthog.temporal.batch_exports.s3_batch_export.aioboto3.Session", FakeSession): + await activity_environment.client.execute_workflow( + S3BatchExportWorkflow.run, + inputs, + id=workflow_id, + task_queue=settings.TEMPORAL_TASK_QUEUE, + retry_policy=RetryPolicy(maximum_attempts=2), + execution_timeout=dt.timedelta(seconds=10), + ) + + runs = await afetch_batch_export_runs(batch_export_id=s3_batch_export.id) + assert len(runs) == 2 + # Sort by `last_updated_at` as earlier run should be the failed run. + runs.sort(key=lambda r: r.last_updated_at) + + run = runs[0] + (events_to_export_created, persons_to_export_created) = generate_test_data + assert run.status == "FailedRetryable" + assert run.records_completed is None + + run = runs[1] + (events_to_export_created, persons_to_export_created) = generate_test_data + assert run.status == "Completed" + assert run.records_completed == len(events_to_export_created) or run.records_completed == len( + persons_to_export_created + ) + + assert runs[0].data_interval_end == runs[1].data_interval_end + + expected_key_prefix = s3_key_prefix.format( + table=batch_export_model.name if batch_export_model is not None else "events", + year=data_interval_end.year, + # All of these must include leading 0s. 
+ month=data_interval_end.strftime("%m"), + day=data_interval_end.strftime("%d"), + hour=data_interval_end.strftime("%H"), + minute=data_interval_end.strftime("%M"), + second=data_interval_end.strftime("%S"), + ) + + objects = await minio_client.list_objects_v2(Bucket=bucket_name, Prefix=expected_key_prefix) + key = objects["Contents"][0].get("Key") + assert len(objects.get("Contents", [])) == 1 + assert key.startswith(expected_key_prefix) + + await assert_clickhouse_records_in_s3( + s3_compatible_client=minio_client, + clickhouse_client=clickhouse_client, + bucket_name=bucket_name, + key_prefix=expected_key_prefix, + team_id=ateam.pk, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + batch_export_model=model, + ) diff --git a/posthog/temporal/tests/batch_exports/test_temporary_file.py b/posthog/temporal/tests/batch_exports/test_temporary_file.py index e9e70579acbd6..4f6ffc8ad0569 100644 --- a/posthog/temporal/tests/batch_exports/test_temporary_file.py +++ b/posthog/temporal/tests/batch_exports/test_temporary_file.py @@ -226,7 +226,13 @@ async def test_jsonl_writer_writes_record_batches(record_batch): inserted_ats_seen: list[LastInsertedAt] = [] async def store_in_memory_on_flush( - batch_export_file, records_since_last_flush, bytes_since_last_flush, flush_counter, last_inserted_at, is_last + batch_export_file, + records_since_last_flush, + bytes_since_last_flush, + flush_counter, + last_inserted_at, + is_last, + error, ): assert writer.records_since_last_flush == record_batch.num_rows in_memory_file_obj.write(batch_export_file.read()) @@ -264,7 +270,13 @@ async def test_csv_writer_writes_record_batches(record_batch): inserted_ats_seen = [] async def store_in_memory_on_flush( - batch_export_file, records_since_last_flush, bytes_since_last_flush, flush_counter, last_inserted_at, is_last + batch_export_file, + records_since_last_flush, + bytes_since_last_flush, + flush_counter, + last_inserted_at, + is_last, + error, ): in_memory_file_obj.write(batch_export_file.read().decode("utf-8")) inserted_ats_seen.append(last_inserted_at) @@ -304,7 +316,13 @@ async def test_parquet_writer_writes_record_batches(record_batch): inserted_ats_seen = [] async def store_in_memory_on_flush( - batch_export_file, records_since_last_flush, bytes_since_last_flush, flush_counter, last_inserted_at, is_last + batch_export_file, + records_since_last_flush, + bytes_since_last_flush, + flush_counter, + last_inserted_at, + is_last, + error, ): in_memory_file_obj.write(batch_export_file.read()) inserted_ats_seen.append(last_inserted_at) diff --git a/posthog/temporal/tests/external_data/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py index 4734740cb4b47..aa0a83d9941a6 100644 --- a/posthog/temporal/tests/external_data/test_external_data_job.py +++ b/posthog/temporal/tests/external_data/test_external_data_job.py @@ -262,6 +262,99 @@ async def test_update_external_job_activity(activity_environment, team, **kwargs assert schema.status == ExternalDataJob.Status.COMPLETED +@pytest.mark.django_db(transaction=True) +@pytest.mark.asyncio +async def test_update_external_job_activity_with_retryable_error(activity_environment, team, **kwargs): + new_source = await sync_to_async(ExternalDataSource.objects.create)( + source_id=uuid.uuid4(), + connection_id=uuid.uuid4(), + destination_id=uuid.uuid4(), + team=team, + status="running", + source_type="Stripe", + ) + + schema = await sync_to_async(ExternalDataSchema.objects.create)( + 
name=PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING[new_source.source_type][0], + team_id=team.id, + source_id=new_source.pk, + should_sync=True, + ) + + new_job = await sync_to_async(create_external_data_job)( + team_id=team.id, + external_data_source_id=new_source.pk, + workflow_id=activity_environment.info.workflow_id, + workflow_run_id=activity_environment.info.workflow_run_id, + external_data_schema_id=schema.id, + ) + + inputs = UpdateExternalDataJobStatusInputs( + id=str(new_job.id), + run_id=str(new_job.id), + status=ExternalDataJob.Status.COMPLETED, + latest_error=None, + internal_error="Some other retryable error", + schema_id=str(schema.pk), + team_id=team.id, + ) + + await activity_environment.run(update_external_data_job_model, inputs) + await sync_to_async(new_job.refresh_from_db)() + await sync_to_async(schema.refresh_from_db)() + + assert new_job.status == ExternalDataJob.Status.COMPLETED + assert schema.status == ExternalDataJob.Status.COMPLETED + assert schema.should_sync is True + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.asyncio +async def test_update_external_job_activity_with_non_retryable_error(activity_environment, team, **kwargs): + new_source = await sync_to_async(ExternalDataSource.objects.create)( + source_id=uuid.uuid4(), + connection_id=uuid.uuid4(), + destination_id=uuid.uuid4(), + team=team, + status="running", + source_type="Stripe", + ) + + schema = await sync_to_async(ExternalDataSchema.objects.create)( + name=PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING[new_source.source_type][0], + team_id=team.id, + source_id=new_source.pk, + should_sync=True, + ) + + new_job = await sync_to_async(create_external_data_job)( + team_id=team.id, + external_data_source_id=new_source.pk, + workflow_id=activity_environment.info.workflow_id, + workflow_run_id=activity_environment.info.workflow_run_id, + external_data_schema_id=schema.id, + ) + + inputs = UpdateExternalDataJobStatusInputs( + id=str(new_job.id), + run_id=str(new_job.id), + status=ExternalDataJob.Status.COMPLETED, + latest_error=None, + internal_error="NoSuchTableError: TableA", + schema_id=str(schema.pk), + team_id=team.id, + ) + with mock.patch("posthog.warehouse.models.external_data_schema.external_data_workflow_exists", return_value=False): + await activity_environment.run(update_external_data_job_model, inputs) + + await sync_to_async(new_job.refresh_from_db)() + await sync_to_async(schema.refresh_from_db)() + + assert new_job.status == ExternalDataJob.Status.COMPLETED + assert schema.status == ExternalDataJob.Status.COMPLETED + assert schema.should_sync is False + + @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_run_stripe_job(activity_environment, team, minio_client, **kwargs): diff --git a/posthog/warehouse/api/external_data_schema.py b/posthog/warehouse/api/external_data_schema.py index eaba69507f392..116455992060d 100644 --- a/posthog/warehouse/api/external_data_schema.py +++ b/posthog/warehouse/api/external_data_schema.py @@ -24,10 +24,13 @@ delete_data_import_folder, ) from posthog.warehouse.models.external_data_schema import ( + filter_mysql_incremental_fields, filter_postgres_incremental_fields, filter_snowflake_incremental_fields, - get_postgres_schemas, get_snowflake_schemas, + get_sql_schemas_for_source_type, + sync_frequency_interval_to_sync_frequency, + sync_frequency_to_sync_frequency_interval, ) from posthog.warehouse.models.external_data_source import ExternalDataSource from posthog.warehouse.models.ssh_tunnel import SSHTunnel @@ -42,6 +45,7 @@ class 
ExternalDataSchemaSerializer(serializers.ModelSerializer): sync_type = serializers.SerializerMethodField(read_only=True) incremental_field = serializers.SerializerMethodField(read_only=True) incremental_field_type = serializers.SerializerMethodField(read_only=True) + sync_frequency = serializers.SerializerMethodField(read_only=True) class Meta: model = ExternalDataSchema @@ -91,6 +95,9 @@ def get_table(self, schema: ExternalDataSchema) -> Optional[dict]: return SimpleTableSerializer(schema.table, context={"database": hogql_context}).data or None + def get_sync_frequency(self, schema: ExternalDataSchema): + return sync_frequency_interval_to_sync_frequency(schema) + def update(self, instance: ExternalDataSchema, validated_data: dict[str, Any]) -> ExternalDataSchema: data = self.context["request"].data @@ -133,7 +140,16 @@ def update(self, instance: ExternalDataSchema, validated_data: dict[str, Any]) - validated_data["sync_type_config"] = payload should_sync = validated_data.get("should_sync", None) - sync_frequency = validated_data.get("sync_frequency", None) + sync_frequency = data.get("sync_frequency", None) + was_sync_frequency_updated = False + + if sync_frequency: + sync_frequency_interval = sync_frequency_to_sync_frequency_interval(sync_frequency) + + if sync_frequency_interval != instance.sync_frequency_interval: + was_sync_frequency_updated = True + validated_data["sync_frequency_interval"] = sync_frequency_interval + instance.sync_frequency_interval = sync_frequency_interval if should_sync is True and sync_type is None and instance.sync_type is None: raise ValidationError("Sync type must be set up first before enabling schema") @@ -149,7 +165,7 @@ def update(self, instance: ExternalDataSchema, validated_data: dict[str, Any]) - if should_sync is True: sync_external_data_job_workflow(instance, create=True) - if sync_frequency: + if was_sync_frequency_updated: sync_external_data_job_workflow(instance, create=False) if trigger_refresh: @@ -253,7 +269,7 @@ def incremental_fields(self, request: Request, *args: Any, **kwargs: Any): source: ExternalDataSource = instance.source incremental_columns: list[IncrementalField] = [] - if source.source_type == ExternalDataSource.Type.POSTGRES: + if source.source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]: # TODO(@Gilbert09): Move all this into a util and replace elsewhere host = source.job_inputs.get("host") port = source.job_inputs.get("port") @@ -282,7 +298,8 @@ def incremental_fields(self, request: Request, *args: Any, **kwargs: Any): private_key=ssh_tunnel_auth_type_private_key, ) - pg_schemas = get_postgres_schemas( + db_schemas = get_sql_schemas_for_source_type( + source.source_type, host=host, port=port, database=database, @@ -292,10 +309,15 @@ def incremental_fields(self, request: Request, *args: Any, **kwargs: Any): ssh_tunnel=ssh_tunnel, ) - columns = pg_schemas.get(instance.name, []) + columns = db_schemas.get(instance.name, []) + if source.source_type == ExternalDataSource.Type.POSTGRES: + incremental_fields_func = filter_postgres_incremental_fields + else: + incremental_fields_func = filter_mysql_incremental_fields + incremental_columns = [ {"field": name, "field_type": field_type, "label": name, "type": field_type} - for name, field_type in filter_postgres_incremental_fields(columns) + for name, field_type in incremental_fields_func(columns) ] elif source.source_type == ExternalDataSource.Type.SNOWFLAKE: # TODO(@Gilbert09): Move all this into a util and replace elsewhere diff --git 
a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py index af3df2ec43ee2..059e1fe271154 100644 --- a/posthog/warehouse/api/external_data_source.py +++ b/posthog/warehouse/api/external_data_source.py @@ -34,7 +34,7 @@ from posthog.warehouse.models.external_data_schema import ( filter_postgres_incremental_fields, filter_snowflake_incremental_fields, - get_postgres_schemas, + get_sql_schemas_for_source_type, get_snowflake_schemas, ) @@ -50,7 +50,16 @@ logger = structlog.get_logger(__name__) -GenericPostgresError = "Could not connect to Postgres. Please check all connection details are valid." + +def get_generic_sql_error(source_type: ExternalDataSource.Type): + if source_type == ExternalDataSource.Type.MYSQL: + name = "MySQL" + else: + name = "Postgres" + + return f"Could not connect to {name}. Please check all connection details are valid." + + GenericSnowflakeError = "Could not connect to Snowflake. Please check all connection details are valid." PostgresErrors = { "password authentication failed for user": "Invalid user or password", @@ -248,9 +257,9 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: new_source_model = self._handle_hubspot_source(request, *args, **kwargs) elif source_type == ExternalDataSource.Type.ZENDESK: new_source_model = self._handle_zendesk_source(request, *args, **kwargs) - elif source_type == ExternalDataSource.Type.POSTGRES: + elif source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]: try: - new_source_model, postgres_schemas = self._handle_postgres_source(request, *args, **kwargs) + new_source_model, sql_schemas = self._handle_sql_source(request, *args, **kwargs) except InternalPostgresError: return Response( status=status.HTTP_400_BAD_REQUEST, data={"message": "Cannot use internal Postgres database"} @@ -264,8 +273,8 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: payload = request.data["payload"] schemas = payload.get("schemas", None) - if source_type == ExternalDataSource.Type.POSTGRES: - default_schemas = postgres_schemas + if source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]: + default_schemas = sql_schemas elif source_type == ExternalDataSource.Type.SNOWFLAKE: default_schemas = snowflake_schemas else: @@ -408,9 +417,7 @@ def _handle_hubspot_source(self, request: Request, *args: Any, **kwargs: Any) -> return new_source_model - def _handle_postgres_source( - self, request: Request, *args: Any, **kwargs: Any - ) -> tuple[ExternalDataSource, list[Any]]: + def _handle_sql_source(self, request: Request, *args: Any, **kwargs: Any) -> tuple[ExternalDataSource, list[Any]]: payload = request.data["payload"] prefix = request.data.get("prefix", None) source_type = request.data["source_type"] @@ -474,7 +481,16 @@ def _handle_postgres_source( private_key=ssh_tunnel_auth_type_private_key, ) - schemas = get_postgres_schemas(host, port, database, user, password, schema, ssh_tunnel) + schemas = get_sql_schemas_for_source_type( + source_type, + host, + port, + database, + user, + password, + schema, + ssh_tunnel, + ) return new_source_model, schemas @@ -609,7 +625,7 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): ) # Get schemas and validate SQL credentials - if source_type == ExternalDataSource.Type.POSTGRES: + if source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]: host = request.data.get("host", None) port = request.data.get("port", None) database = 
request.data.get("dbname", None) @@ -677,11 +693,20 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): ) try: - result = get_postgres_schemas(host, port, database, user, password, schema, ssh_tunnel) + result = get_sql_schemas_for_source_type( + source_type, + host, + port, + database, + user, + password, + schema, + ssh_tunnel, + ) if len(result.keys()) == 0: return Response( status=status.HTTP_400_BAD_REQUEST, - data={"message": "Postgres schema doesn't exist"}, + data={"message": "Schema doesn't exist"}, ) except OperationalError as e: exposed_error = self._expose_postgres_error(e) @@ -691,12 +716,12 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): return Response( status=status.HTTP_400_BAD_REQUEST, - data={"message": exposed_error or GenericPostgresError}, + data={"message": exposed_error or get_generic_sql_error(source_type)}, ) except BaseSSHTunnelForwarderError as e: return Response( status=status.HTTP_400_BAD_REQUEST, - data={"message": e.value or GenericPostgresError}, + data={"message": e.value or get_generic_sql_error(source_type)}, ) except Exception as e: capture_exception(e) @@ -704,7 +729,7 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): return Response( status=status.HTTP_400_BAD_REQUEST, - data={"message": GenericPostgresError}, + data={"message": get_generic_sql_error(source_type)}, ) filtered_results = [ diff --git a/posthog/warehouse/api/table.py b/posthog/warehouse/api/table.py index 58d6296cb2ae3..d5929a9a04315 100644 --- a/posthog/warehouse/api/table.py +++ b/posthog/warehouse/api/table.py @@ -1,6 +1,6 @@ from typing import Any -from rest_framework import filters, request, response, serializers, status, viewsets +from rest_framework import exceptions, filters, request, response, serializers, status, viewsets from rest_framework.decorators import action from posthog.api.routing import TeamAndOrgViewSetMixin @@ -89,11 +89,17 @@ def get_external_schema(self, instance: DataWarehouseTable): return SimpleExternalDataSchemaSerializer(instance.externaldataschema_set.first(), read_only=True).data or None def create(self, validated_data): - validated_data["team_id"] = self.context["team_id"] + team_id = self.context["team_id"] + + table_name_exists = DataWarehouseTable.objects.filter(team_id=team_id, name=validated_data["name"]).exists() + if table_name_exists: + raise exceptions.ValidationError("Table name already exists.") + + validated_data["team_id"] = team_id validated_data["created_by"] = self.context["request"].user if validated_data.get("credential"): validated_data["credential"] = DataWarehouseCredential.objects.create( - team_id=self.context["team_id"], + team_id=team_id, access_key=validated_data["credential"]["access_key"], access_secret=validated_data["credential"]["access_secret"], ) diff --git a/posthog/warehouse/api/test/test_external_data_schema.py b/posthog/warehouse/api/test/test_external_data_schema.py index 44554f4dce17d..b63f3f1dfab4f 100644 --- a/posthog/warehouse/api/test/test_external_data_schema.py +++ b/posthog/warehouse/api/test/test_external_data_schema.py @@ -1,3 +1,4 @@ +from datetime import timedelta from unittest import mock import uuid import psycopg @@ -370,7 +371,7 @@ def test_update_schema_sync_frequency(self): should_sync=True, status=ExternalDataSchema.Status.COMPLETED, sync_type=ExternalDataSchema.SyncType.FULL_REFRESH, - sync_frequency=ExternalDataSchema.SyncFrequency.DAILY, + sync_frequency_interval=timedelta(hours=24), ) with ( @@ -385,11 +386,11 @@ def 
test_update_schema_sync_frequency(self): response = self.client.patch( f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}", - data={"sync_frequency": "week"}, + data={"sync_frequency": "7day"}, ) assert response.status_code == 200 mock_sync_external_data_job_workflow.assert_called_once() schema.refresh_from_db() - assert schema.sync_frequency == ExternalDataSchema.SyncFrequency.WEEKLY + assert schema.sync_frequency_interval == timedelta(days=7) diff --git a/posthog/warehouse/api/test/test_external_data_source.py b/posthog/warehouse/api/test/test_external_data_source.py index 84517cafb32af..85fbec2cef0f7 100644 --- a/posthog/warehouse/api/test/test_external_data_source.py +++ b/posthog/warehouse/api/test/test_external_data_source.py @@ -15,6 +15,7 @@ from posthog.warehouse.models.external_data_job import ExternalDataJob +from posthog.warehouse.models.external_data_schema import sync_frequency_interval_to_sync_frequency class TestExternalDataSource(APIBaseTest): @@ -415,7 +416,7 @@ def test_get_external_data_source_with_schema(self): "status": schema.status, "sync_type": schema.sync_type, "table": schema.table, - "sync_frequency": schema.sync_frequency, + "sync_frequency": sync_frequency_interval_to_sync_frequency(schema), } ], ) @@ -584,9 +585,10 @@ def test_database_schema_non_postgres_source(self): assert table in table_names @patch( - "posthog.warehouse.api.external_data_source.get_postgres_schemas", return_value={"table_1": [("id", "integer")]} + "posthog.warehouse.api.external_data_source.get_sql_schemas_for_source_type", + return_value={"table_1": [("id", "integer")]}, ) - def test_internal_postgres(self, patch_get_postgres_schemas): + def test_internal_postgres(self, patch_get_sql_schemas_for_source_type): # This test checks handling of project ID 2 in Cloud US and project ID 1 in Cloud EU, # so let's make sure there are no projects with these IDs in the test DB Project.objects.filter(id__in=[1, 2]).delete() diff --git a/posthog/warehouse/api/test/test_table.py b/posthog/warehouse/api/test/test_table.py index 885f5c7267aaa..f8c451c35ab5b 100644 --- a/posthog/warehouse/api/test/test_table.py +++ b/posthog/warehouse/api/test/test_table.py @@ -232,3 +232,49 @@ def test_update_schema_400_with_invalid_type(self): assert response.status_code == 400 assert response.json()["message"] == "Can not parse type another_type for column id - type does not exist" assert table.columns == columns + + @patch( + "posthog.warehouse.models.table.DataWarehouseTable.get_columns", + return_value={ + "id": {"clickhouse": "Nullable(String)", "hogql": "StringDatabaseField", "valid": True}, + "a_column": {"clickhouse": "Nullable(String)", "hogql": "StringDatabaseField", "valid": True}, + }, + ) + @patch( + "posthog.warehouse.models.table.DataWarehouseTable.validate_column_type", + return_value=True, + ) + def test_table_name_duplicate(self, patch_get_columns, patch_validate_column_type): + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_tables/", + { + "name": "whatever", + "url_pattern": "https://your-org.s3.amazonaws.com/bucket/whatever.pqt", + "credential": { + "access_key": "_accesskey", + "access_secret": "_accesssecret", + }, + "format": "Parquet", + }, + ) + assert response.status_code == 201 + data: dict[str, Any] = response.json() + + table = DataWarehouseTable.objects.get(id=data["id"]) + + assert table is not None + + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_tables/", + { + "name": "whatever", + "url_pattern": 
"https://your-org.s3.amazonaws.com/bucket/whatever.pqt", + "credential": { + "access_key": "_accesskey", + "access_secret": "_accesssecret", + }, + "format": "Parquet", + }, + ) + assert response.status_code == 400 + assert DataWarehouseTable.objects.count() == 1 diff --git a/posthog/warehouse/data_load/service.py b/posthog/warehouse/data_load/service.py index 2425f186b5fa7..9ee54510227bc 100644 --- a/posthog/warehouse/data_load/service.py +++ b/posthog/warehouse/data_load/service.py @@ -1,5 +1,6 @@ from dataclasses import asdict from datetime import timedelta +from typing import TYPE_CHECKING from temporalio.client import ( Schedule, @@ -28,7 +29,6 @@ unpause_schedule, ) from posthog.temporal.utils import ExternalDataWorkflowInputs -from posthog.warehouse.models import ExternalDataSource import temporalio from temporalio.client import Client as TemporalClient from asgiref.sync import async_to_sync @@ -36,17 +36,19 @@ from django.conf import settings import s3fs -from posthog.warehouse.models.external_data_schema import ExternalDataSchema +if TYPE_CHECKING: + from posthog.warehouse.models import ExternalDataSource + from posthog.warehouse.models.external_data_schema import ExternalDataSchema -def get_sync_schedule(external_data_schema: ExternalDataSchema): +def get_sync_schedule(external_data_schema: "ExternalDataSchema"): inputs = ExternalDataWorkflowInputs( team_id=external_data_schema.team_id, external_data_schema_id=external_data_schema.id, external_data_source_id=external_data_schema.source_id, ) - sync_frequency = get_sync_frequency(external_data_schema) + sync_frequency, jitter = get_sync_frequency(external_data_schema) return Schedule( action=ScheduleActionStartWorkflow( @@ -56,30 +58,26 @@ def get_sync_schedule(external_data_schema: ExternalDataSchema): task_queue=str(DATA_WAREHOUSE_TASK_QUEUE), ), spec=ScheduleSpec( - intervals=[ - ScheduleIntervalSpec(every=sync_frequency, offset=timedelta(hours=external_data_schema.created_at.hour)) - ], - jitter=timedelta(hours=2), + intervals=[ScheduleIntervalSpec(every=sync_frequency)], + jitter=jitter, ), state=ScheduleState(note=f"Schedule for external data source: {external_data_schema.pk}"), policy=SchedulePolicy(overlap=ScheduleOverlapPolicy.SKIP), ) -def get_sync_frequency(external_data_schema: ExternalDataSchema): - if external_data_schema.sync_frequency == ExternalDataSchema.SyncFrequency.DAILY: - return timedelta(days=1) - elif external_data_schema.sync_frequency == ExternalDataSchema.SyncFrequency.WEEKLY: - return timedelta(weeks=1) - elif external_data_schema.sync_frequency == ExternalDataSchema.SyncFrequency.MONTHLY: - return timedelta(days=30) - else: - raise ValueError(f"Unknown sync frequency: {external_data_schema.source.sync_frequency}") +def get_sync_frequency(external_data_schema: "ExternalDataSchema") -> tuple[timedelta, timedelta]: + if external_data_schema.sync_frequency_interval <= timedelta(hours=1): + return (external_data_schema.sync_frequency_interval, timedelta(minutes=1)) + if external_data_schema.sync_frequency_interval <= timedelta(hours=12): + return (external_data_schema.sync_frequency_interval, timedelta(minutes=30)) + + return (external_data_schema.sync_frequency_interval, timedelta(hours=1)) def sync_external_data_job_workflow( - external_data_schema: ExternalDataSchema, create: bool = False -) -> ExternalDataSchema: + external_data_schema: "ExternalDataSchema", create: bool = False +) -> "ExternalDataSchema": temporal = sync_connect() schedule = get_sync_schedule(external_data_schema) @@ -93,8 +91,8 @@ 
def sync_external_data_job_workflow( async def a_sync_external_data_job_workflow( - external_data_schema: ExternalDataSchema, create: bool = False -) -> ExternalDataSchema: + external_data_schema: "ExternalDataSchema", create: bool = False +) -> "ExternalDataSchema": temporal = await async_connect() schedule = get_sync_schedule(external_data_schema) @@ -107,17 +105,17 @@ async def a_sync_external_data_job_workflow( return external_data_schema -def trigger_external_data_source_workflow(external_data_source: ExternalDataSource): +def trigger_external_data_source_workflow(external_data_source: "ExternalDataSource"): temporal = sync_connect() trigger_schedule(temporal, schedule_id=str(external_data_source.id)) -def trigger_external_data_workflow(external_data_schema: ExternalDataSchema): +def trigger_external_data_workflow(external_data_schema: "ExternalDataSchema"): temporal = sync_connect() trigger_schedule(temporal, schedule_id=str(external_data_schema.id)) -async def a_trigger_external_data_workflow(external_data_schema: ExternalDataSchema): +async def a_trigger_external_data_workflow(external_data_schema: "ExternalDataSchema"): temporal = await async_connect() await a_trigger_schedule(temporal, schedule_id=str(external_data_schema.id)) @@ -153,7 +151,7 @@ def delete_external_data_schedule(schedule_id: str): raise -async def a_delete_external_data_schedule(external_data_source: ExternalDataSource): +async def a_delete_external_data_schedule(external_data_source: "ExternalDataSource"): temporal = await async_connect() try: await a_delete_schedule(temporal, schedule_id=str(external_data_source.id)) @@ -185,4 +183,6 @@ def delete_data_import_folder(folder_path: str): def is_any_external_data_job_paused(team_id: int) -> bool: + from posthog.warehouse.models import ExternalDataSource + return ExternalDataSource.objects.filter(team_id=team_id, status=ExternalDataSource.Status.PAUSED).exists() diff --git a/posthog/warehouse/models/external_data_job.py b/posthog/warehouse/models/external_data_job.py index 488f7ba0f9212..7b7f1cc15e1a8 100644 --- a/posthog/warehouse/models/external_data_job.py +++ b/posthog/warehouse/models/external_data_job.py @@ -2,14 +2,14 @@ from django.db.models import Prefetch from django.conf import settings from posthog.models.team import Team -from posthog.models.utils import CreatedMetaFields, UUIDModel, sane_repr +from posthog.models.utils import CreatedMetaFields, UUIDModel, UpdatedMetaFields, sane_repr from posthog.settings import TEST from posthog.warehouse.s3 import get_s3_client from uuid import UUID from posthog.warehouse.util import database_sync_to_async -class ExternalDataJob(CreatedMetaFields, UUIDModel): +class ExternalDataJob(CreatedMetaFields, UpdatedMetaFields, UUIDModel): class Status(models.TextChoices): RUNNING = "Running", "Running" FAILED = "Failed", "Failed" diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py index b2c68fb9e8cab..95bfb94a2f624 100644 --- a/posthog/warehouse/models/external_data_schema.py +++ b/posthog/warehouse/models/external_data_schema.py @@ -1,17 +1,27 @@ from collections import defaultdict +from datetime import timedelta from typing import Optional from django.db import models +from django_deprecate_fields import deprecate_field import snowflake.connector from posthog.models.team import Team -from posthog.models.utils import CreatedMetaFields, UUIDModel, sane_repr +from posthog.models.utils import CreatedMetaFields, UUIDModel, UpdatedMetaFields, sane_repr import uuid 
import psycopg2 +import pymysql +from .external_data_source import ExternalDataSource +from posthog.warehouse.data_load.service import ( + external_data_workflow_exists, + pause_external_data_schedule, + sync_external_data_job_workflow, + unpause_external_data_schedule, +) from posthog.warehouse.types import IncrementalFieldType from posthog.warehouse.models.ssh_tunnel import SSHTunnel from posthog.warehouse.util import database_sync_to_async -class ExternalDataSchema(CreatedMetaFields, UUIDModel): +class ExternalDataSchema(CreatedMetaFields, UpdatedMetaFields, UUIDModel): class Status(models.TextChoices): RUNNING = "Running", "Running" PAUSED = "Paused", "Paused" @@ -47,8 +57,12 @@ class SyncFrequency(models.TextChoices): default=dict, blank=True, ) - sync_frequency: models.CharField = models.CharField( - max_length=128, choices=SyncFrequency.choices, default=SyncFrequency.DAILY, blank=True + # Deprecated in favour of `sync_frequency_interval` + sync_frequency = deprecate_field( + models.CharField(max_length=128, choices=SyncFrequency.choices, default=SyncFrequency.DAILY, blank=True) + ) + sync_frequency_interval: models.DurationField = models.DurationField( + default=timedelta(hours=6), null=True, blank=True ) __repr__ = sane_repr("name") @@ -78,6 +92,26 @@ def aget_schema_by_id(schema_id: str, team_id: int) -> ExternalDataSchema | None return ExternalDataSchema.objects.prefetch_related("source").get(id=schema_id, team_id=team_id) +@database_sync_to_async +def aupdate_should_sync(schema_id: str, team_id: int, should_sync: bool) -> ExternalDataSchema | None: + schema = ExternalDataSchema.objects.get(id=schema_id, team_id=team_id) + schema.should_sync = should_sync + schema.save() + + schedule_exists = external_data_workflow_exists(schema_id) + + if schedule_exists: + if should_sync is False: + pause_external_data_schedule(schema_id) + elif should_sync is True: + unpause_external_data_schedule(schema_id) + else: + if should_sync is True: + sync_external_data_job_workflow(schema, create=True) + + return schema + + @database_sync_to_async def get_active_schemas_for_source_id(source_id: uuid.UUID, team_id: int): return list(ExternalDataSchema.objects.filter(team_id=team_id, source_id=source_id, should_sync=True).all()) @@ -97,6 +131,48 @@ def sync_old_schemas_with_new_schemas(new_schemas: list, source_id: uuid.UUID, t ExternalDataSchema.objects.create(name=schema, team_id=team_id, source_id=source_id, should_sync=False) +def sync_frequency_to_sync_frequency_interval(frequency: str) -> timedelta: + if frequency == "5min": + return timedelta(minutes=5) + if frequency == "30min": + return timedelta(minutes=30) + if frequency == "1hour": + return timedelta(hours=1) + if frequency == "6hour": + return timedelta(hours=6) + if frequency == "12hour": + return timedelta(hours=12) + if frequency == "24hour": + return timedelta(hours=24) + if frequency == "7day": + return timedelta(days=7) + if frequency == "30day": + return timedelta(days=30) + + raise ValueError(f"Frequency {frequency} is not supported") + + +def sync_frequency_interval_to_sync_frequency(schema: ExternalDataSchema) -> str: + if schema.sync_frequency_interval == timedelta(minutes=5): + return "5min" + if schema.sync_frequency_interval == timedelta(minutes=30): + return "30min" + if schema.sync_frequency_interval == timedelta(hours=1): + return "1hour" + if schema.sync_frequency_interval == timedelta(hours=6): + return "6hour" + if schema.sync_frequency_interval == timedelta(hours=12): + return "12hour" + if 
schema.sync_frequency_interval == timedelta(hours=24): + return "24hour" + if schema.sync_frequency_interval == timedelta(days=7): + return "7day" + if schema.sync_frequency_interval == timedelta(days=30): + return "30day" + + raise ValueError(f"Frequency interval {schema.sync_frequency_interval} is not supported") + + def filter_snowflake_incremental_fields(columns: list[tuple[str, str]]) -> list[tuple[str, IncrementalFieldType]]: results: list[tuple[str, IncrementalFieldType]] = [] for column_name, type in columns: @@ -196,3 +272,83 @@ def get_schemas(postgres_host: str, postgres_port: int): return get_schemas(tunnel.local_bind_host, tunnel.local_bind_port) return get_schemas(host, int(port)) + + +def filter_mysql_incremental_fields(columns: list[tuple[str, str]]) -> list[tuple[str, IncrementalFieldType]]: + results: list[tuple[str, IncrementalFieldType]] = [] + for column_name, type in columns: + type = type.lower() + if type.startswith("timestamp"): + results.append((column_name, IncrementalFieldType.Timestamp)) + elif type == "date": + results.append((column_name, IncrementalFieldType.Date)) + elif type == "datetime": + results.append((column_name, IncrementalFieldType.DateTime)) + elif type == "tinyint" or type == "smallint" or type == "mediumint" or type == "int" or type == "bigint": + results.append((column_name, IncrementalFieldType.Integer)) + + return results + + +def get_mysql_schemas( + host: str, + port: str, + database: str, + user: str, + password: str, + schema: str, + ssh_tunnel: SSHTunnel, +) -> dict[str, list[tuple[str, str]]]: + def get_schemas(mysql_host: str, mysql_port: int): + connection = pymysql.connect( + host=mysql_host, + port=mysql_port, + database=database, + user=user, + password=password, + connect_timeout=5, + ) + + with connection.cursor() as cursor: + cursor.execute( + "SELECT table_name, column_name, data_type FROM information_schema.columns WHERE table_schema = %(schema)s ORDER BY table_name ASC", + {"schema": schema}, + ) + result = cursor.fetchall() + + schema_list = defaultdict(list) + for row in result: + schema_list[row[0]].append((row[1], row[2])) + + connection.close() + + return schema_list + + if ssh_tunnel.enabled: + with ssh_tunnel.get_tunnel(host, int(port)) as tunnel: + if tunnel is None: + raise Exception("Can't open tunnel to SSH server") + + return get_schemas(tunnel.local_bind_host, tunnel.local_bind_port) + + return get_schemas(host, int(port)) + + +def get_sql_schemas_for_source_type( + source_type: ExternalDataSource.Type, + host: str, + port: str, + database: str, + user: str, + password: str, + schema: str, + ssh_tunnel: SSHTunnel, +) -> dict[str, list[tuple[str, str]]]: + if source_type == ExternalDataSource.Type.POSTGRES: + schemas = get_postgres_schemas(host, port, database, user, password, schema, ssh_tunnel) + elif source_type == ExternalDataSource.Type.MYSQL: + schemas = get_mysql_schemas(host, port, database, user, password, schema, ssh_tunnel) + else: + raise Exception("Unsupported source_type") + + return schemas diff --git a/posthog/warehouse/models/external_data_source.py b/posthog/warehouse/models/external_data_source.py index dc21af8db26ad..0919362b80a2c 100644 --- a/posthog/warehouse/models/external_data_source.py +++ b/posthog/warehouse/models/external_data_source.py @@ -2,7 +2,7 @@ from django.db import models from posthog.models.team import Team -from posthog.models.utils import CreatedMetaFields, UUIDModel, sane_repr +from posthog.models.utils import CreatedMetaFields, UUIDModel, UpdatedMetaFields, sane_repr 
from posthog.warehouse.util import database_sync_to_async from uuid import UUID @@ -12,13 +12,14 @@ logger = structlog.get_logger(__name__) -class ExternalDataSource(CreatedMetaFields, UUIDModel): +class ExternalDataSource(CreatedMetaFields, UpdatedMetaFields, UUIDModel): class Type(models.TextChoices): STRIPE = "Stripe", "Stripe" HUBSPOT = "Hubspot", "Hubspot" POSTGRES = "Postgres", "Postgres" ZENDESK = "Zendesk", "Zendesk" SNOWFLAKE = "Snowflake", "Snowflake" + MYSQL = "MySQL", "MySQL" class Status(models.TextChoices): RUNNING = "Running", "Running" @@ -39,7 +40,7 @@ class SyncFrequency(models.TextChoices): destination_id: models.CharField = models.CharField(max_length=400, null=True, blank=True) team: models.ForeignKey = models.ForeignKey(Team, on_delete=models.CASCADE) - # Deprecated, use `ExternalDataSchema.sync_frequency` + # Deprecated, use `ExternalDataSchema.sync_frequency_interval` sync_frequency: models.CharField = models.CharField( max_length=128, choices=SyncFrequency.choices, default=SyncFrequency.DAILY, blank=True ) diff --git a/posthog/warehouse/models/table.py b/posthog/warehouse/models/table.py index a99ed5bcdf0b2..b6454ea379d80 100644 --- a/posthog/warehouse/models/table.py +++ b/posthog/warehouse/models/table.py @@ -13,6 +13,7 @@ CreatedMetaFields, DeletedMetaFields, UUIDModel, + UpdatedMetaFields, sane_repr, ) from posthog.schema import DatabaseSerializedFieldType, HogQLQueryModifiers @@ -65,7 +66,7 @@ def get_queryset(self): ) -class DataWarehouseTable(CreatedMetaFields, UUIDModel, DeletedMetaFields): +class DataWarehouseTable(CreatedMetaFields, UpdatedMetaFields, UUIDModel, DeletedMetaFields): # loading external_data_source and credentials is easily N+1, # so we have a custom object manager meaning people can't forget to load them # this also means we _always_ have two joins whenever we load tables diff --git a/requirements-dev.in b/requirements-dev.in index 9ab0252aecf67..5ca5431dbaf1c 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -26,6 +26,7 @@ packaging==23.1 black~=23.9.1 boto3-stubs[s3] types-markdown==3.3.9 +types-PyMySQL==1.1.0.20240524 types-PyYAML==6.0.1 types-freezegun==1.1.10 types-paramiko==3.4.0.20240423 diff --git a/requirements-dev.txt b/requirements-dev.txt index a528eb65d50a7..938eaead5395c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -37,7 +37,7 @@ black==23.9.1 # -r requirements-dev.in # datamodel-code-generator # inline-snapshot -boto3-stubs[s3]==1.34.84 +boto3-stubs==1.34.84 # via -r requirements-dev.in botocore-stubs==1.34.84 # via boto3-stubs @@ -62,7 +62,7 @@ click==8.1.7 # inline-snapshot colorama==0.4.4 # via pytest-watch -coverage[toml]==5.5 +coverage==5.5 # via pytest-cov cryptography==39.0.2 # via @@ -98,6 +98,7 @@ executing==2.0.1 faker==17.5.0 # via -r requirements-dev.in fakeredis==2.23.3 + # via -r requirements-dev.in flaky==3.7.0 # via -r requirements-dev.in freezegun==1.2.2 @@ -197,7 +198,7 @@ pycparser==2.20 # via # -c requirements.txt # cffi -pydantic[email]==2.5.3 +pydantic==2.5.3 # via # -c requirements.txt # datamodel-code-generator @@ -313,6 +314,8 @@ types-markdown==3.3.9 # via -r requirements-dev.in types-paramiko==3.4.0.20240423 # via -r requirements-dev.in +types-pymysql==1.1.0.20240524 + # via -r requirements-dev.in types-python-dateutil==2.8.3 # via -r requirements-dev.in types-pytz==2023.3.0.0 diff --git a/requirements.in b/requirements.in index 3d586910f5cd5..004661063be43 100644 --- a/requirements.in +++ b/requirements.in @@ -33,8 +33,8 @@ djangorestframework==3.15.1 
djangorestframework-csv==2.1.1 djangorestframework-dataclasses==1.2.0 django-fernet-encrypted-fields==0.1.3 -dlt==0.5.1 -dlt[deltalake]==0.5.1 +dlt==0.5.2a1 +dlt[deltalake]==0.5.2a1 dnspython==2.2.1 drf-exceptions-hog==0.4.0 drf-extensions==0.7.0 @@ -56,6 +56,7 @@ paramiko==3.4.0 Pillow==10.2.0 posthoganalytics==3.5.0 psycopg2-binary==2.9.7 +PyMySQL==1.1.1 psycopg[binary]==3.1.18 pyarrow==15.0.0 pydantic==2.5.3 @@ -74,10 +75,10 @@ semantic_version==2.8.5 scikit-learn==1.5.0 slack_sdk==3.17.1 snowflake-connector-python==3.6.0 -snowflake-sqlalchemy==1.5.3 +snowflake-sqlalchemy==1.6.1 social-auth-app-django==5.0.0 social-auth-core==4.3.0 -sqlalchemy==1.4.52 +sqlalchemy==2.0.31 sshtunnel==0.4.0 statshog==1.0.6 structlog==23.2.0 diff --git a/requirements.txt b/requirements.txt index d5df2bf41f3e5..c4e07d9924e01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -204,7 +204,7 @@ djangorestframework-csv==2.1.1 # via -r requirements.in djangorestframework-dataclasses==1.2.0 # via -r requirements.in -dlt==0.5.1 +dlt==0.5.2a1 # via -r requirements.in dnspython==2.2.1 # via -r requirements.in @@ -414,6 +414,7 @@ psycopg-binary==3.1.18 # via psycopg psycopg2-binary==2.9.7 # via -r requirements.in +PyMySQL==1.1.1 py==1.11.0 # via retry pyarrow==15.0.0 @@ -570,7 +571,7 @@ snowflake-connector-python==3.6.0 # via # -r requirements.in # snowflake-sqlalchemy -snowflake-sqlalchemy==1.5.3 +snowflake-sqlalchemy==1.6.1 # via -r requirements.in social-auth-app-django==5.0.0 # via -r requirements.in @@ -582,7 +583,7 @@ sortedcontainers==2.4.0 # via # snowflake-connector-python # trio -sqlalchemy==1.4.52 +sqlalchemy==2.0.31 # via # -r requirements.in # snowflake-sqlalchemy @@ -641,6 +642,7 @@ typing-extensions==4.7.1 # pydantic-core # qrcode # snowflake-connector-python + # sqlalchemy # stripe # temporalio tzdata==2023.3 diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 3617bee588a41..15c2210f61fb5 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1049,10 +1049,12 @@ dependencies = [ "serde-pickle", "serde_json", "sha1", + "sqlx", "thiserror", "tokio", "tracing", "tracing-subscriber", + "uuid", ] [[package]] diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 08ff21eaed0d8..e4d51dc308d34 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -15,6 +15,7 @@ tokio = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } bytes = { workspace = true } +once_cell = "1.18.0" rand = { workspace = true } redis = { version = "0.23.3", features = [ "tokio-comp", @@ -27,12 +28,13 @@ thiserror = { workspace = true } serde-pickle = { version = "1.1.1"} sha1 = "0.10.6" regex = "1.10.4" +sqlx = { workspace = true } +uuid = { workspace = true } [lints] workspace = true [dev-dependencies] assert-json-diff = { workspace = true } -once_cell = "1.18.0" reqwest = { workspace = true } diff --git a/rust/feature-flags/README.md b/rust/feature-flags/README.md index 1c9500900aade..efce036124524 100644 --- a/rust/feature-flags/README.md +++ b/rust/feature-flags/README.md @@ -1,6 +1,23 @@ # Testing +First, make sure docker compose is running (from main posthog repo), and test database exists: + +``` +docker compose -f ../docker-compose.dev.yml up -d +``` + +``` +TEST=1 python manage.py setup_test_environment --only-postgres +``` + +We only need to run the above once, when the test database is created. + +TODO: Would be nice to make the above automatic. 
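For reference, the integration tests in this package talk to that test database directly. A minimal sketch of such a test, mirroring the `test_fetch_team_from_pg` test added in this PR and assuming the docker compose and `setup_test_environment --only-postgres` steps above have been run:

```rust
use feature_flags::team::Team;
use feature_flags::test_utils::{insert_new_team_in_pg, setup_pg_client};

#[tokio::test]
async fn fetches_team_from_test_postgres() {
    // Assumes the `test_posthog` database is reachable on localhost:5432,
    // matching DEFAULT_TEST_CONFIG in config.rs.
    let client = setup_pg_client(None).await;

    let team = insert_new_team_in_pg(client.clone())
        .await
        .expect("failed to insert team");

    let fetched = Team::from_pg(client.clone(), team.api_token.clone())
        .await
        .expect("failed to fetch team from pg");

    assert_eq!(fetched.id, team.id);
}
```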
+ + +Then, run the tests: + ``` cargo test --package feature-flags ``` diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index ccf4735e5b04a..2caae80bf9af6 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -5,6 +5,9 @@ use axum::response::{IntoResponse, Response}; use serde::{Deserialize, Serialize}; use thiserror::Error; +use crate::database::CustomDatabaseError; +use crate::redis::CustomRedisError; + #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] pub enum FlagsResponseCode { Ok = 1, @@ -42,6 +45,14 @@ pub enum FlagError { DataParsingError, #[error("redis unavailable")] RedisUnavailable, + #[error("database unavailable")] + DatabaseUnavailable, + #[error("Timed out while fetching data")] + TimeoutError, + // TODO: Consider splitting top-level errors (that are returned to the client) + // and FlagMatchingError, like timeouterror which we can gracefully handle. + // This will make the `into_response` a lot clearer as well, since it wouldn't + // have arbitrary errors that actually never make it to the client. } impl IntoResponse for FlagError { @@ -58,10 +69,53 @@ impl IntoResponse for FlagError { FlagError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, self.to_string()), - FlagError::DataParsingError | FlagError::RedisUnavailable => { - (StatusCode::SERVICE_UNAVAILABLE, self.to_string()) - } + FlagError::DataParsingError + | FlagError::RedisUnavailable + | FlagError::DatabaseUnavailable + | FlagError::TimeoutError => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), } .into_response() } } + +impl From for FlagError { + fn from(e: CustomRedisError) -> Self { + match e { + CustomRedisError::NotFound => FlagError::TokenValidationError, + CustomRedisError::PickleError(e) => { + tracing::error!("failed to fetch data: {}", e); + FlagError::DataParsingError + } + CustomRedisError::Timeout(_) => FlagError::TimeoutError, + CustomRedisError::Other(e) => { + tracing::error!("Unknown redis error: {}", e); + FlagError::RedisUnavailable + } + } + } +} + +impl From for FlagError { + fn from(e: CustomDatabaseError) -> Self { + match e { + CustomDatabaseError::NotFound => FlagError::TokenValidationError, + CustomDatabaseError::Other(_) => { + tracing::error!("failed to get connection: {}", e); + FlagError::DatabaseUnavailable + } + CustomDatabaseError::Timeout(_) => FlagError::TimeoutError, + } + } +} + +impl From for FlagError { + fn from(e: sqlx::Error) -> Self { + // TODO: Be more precise with error handling here + tracing::error!("sqlx error: {}", e); + println!("sqlx error: {}", e); + match e { + sqlx::Error::RowNotFound => FlagError::TokenValidationError, + _ => FlagError::DatabaseUnavailable, + } + } +} diff --git a/rust/feature-flags/src/config.rs b/rust/feature-flags/src/config.rs index cc7ad37bf72c1..d9e1bf06b1ee3 100644 --- a/rust/feature-flags/src/config.rs +++ b/rust/feature-flags/src/config.rs @@ -1,16 +1,17 @@ -use std::net::SocketAddr; - use envconfig::Envconfig; +use once_cell::sync::Lazy; +use std::net::SocketAddr; +use std::str::FromStr; -#[derive(Envconfig, Clone)] +#[derive(Envconfig, Clone, Debug)] pub struct Config { #[envconfig(default = "127.0.0.1:3001")] pub address: SocketAddr, - #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] + #[envconfig(default = "postgres://posthog:posthog@localhost:5432/test_posthog")] pub write_database_url: String, - #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] + #[envconfig(default = 
"postgres://posthog:posthog@localhost:5432/test_posthog")] pub read_database_url: String, #[envconfig(default = "1024")] @@ -21,4 +22,83 @@ pub struct Config { #[envconfig(default = "redis://localhost:6379/")] pub redis_url: String, + + #[envconfig(default = "1")] + pub acquire_timeout_secs: u64, +} + +impl Config { + pub fn default_test_config() -> Self { + Self { + address: SocketAddr::from_str("127.0.0.1:0").unwrap(), + redis_url: "redis://localhost:6379/".to_string(), + write_database_url: "postgres://posthog:posthog@localhost:5432/test_posthog" + .to_string(), + read_database_url: "postgres://posthog:posthog@localhost:5432/test_posthog".to_string(), + max_concurrent_jobs: 1024, + max_pg_connections: 100, + acquire_timeout_secs: 1, + } + } +} + +pub static DEFAULT_TEST_CONFIG: Lazy = Lazy::new(Config::default_test_config); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = Config::init_from_env().unwrap(); + assert_eq!( + config.address, + SocketAddr::from_str("127.0.0.1:3001").unwrap() + ); + assert_eq!( + config.write_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!( + config.read_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!(config.max_concurrent_jobs, 1024); + assert_eq!(config.max_pg_connections, 100); + assert_eq!(config.redis_url, "redis://localhost:6379/"); + } + + #[test] + fn test_default_test_config() { + let config = Config::default_test_config(); + assert_eq!(config.address, SocketAddr::from_str("127.0.0.1:0").unwrap()); + assert_eq!( + config.write_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!( + config.read_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!(config.max_concurrent_jobs, 1024); + assert_eq!(config.max_pg_connections, 100); + assert_eq!(config.redis_url, "redis://localhost:6379/"); + } + + #[test] + fn test_default_test_config_static() { + let config = &*DEFAULT_TEST_CONFIG; + assert_eq!(config.address, SocketAddr::from_str("127.0.0.1:0").unwrap()); + assert_eq!( + config.write_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!( + config.read_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!(config.max_concurrent_jobs, 1024); + assert_eq!(config.max_pg_connections, 100); + assert_eq!(config.redis_url, "redis://localhost:6379/"); + } } diff --git a/rust/feature-flags/src/database.rs b/rust/feature-flags/src/database.rs new file mode 100644 index 0000000000000..29360d22b9444 --- /dev/null +++ b/rust/feature-flags/src/database.rs @@ -0,0 +1,98 @@ +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use sqlx::{ + pool::PoolConnection, + postgres::{PgPoolOptions, PgRow}, + Postgres, +}; +use thiserror::Error; +use tokio::time::timeout; + +use crate::config::Config; + +const DATABASE_TIMEOUT_MILLISECS: u64 = 1000; + +#[derive(Error, Debug)] +pub enum CustomDatabaseError { + #[error("Not found in database")] + NotFound, + + #[error("Pg error: {0}")] + Other(#[from] sqlx::Error), + + #[error("Timeout error")] + Timeout(#[from] tokio::time::error::Elapsed), +} + +/// A simple db wrapper +/// Supports running any arbitrary query with a timeout. +/// TODO: Make sqlx prepared statements work with pgbouncer, potentially by setting pooling mode to session. 
+#[async_trait]
+pub trait Client {
+    async fn get_connection(&self) -> Result<PoolConnection<Postgres>, CustomDatabaseError>;
+    async fn run_query(
+        &self,
+        query: String,
+        parameters: Vec<String>,
+        timeout_ms: Option<u64>,
+    ) -> Result<Vec<PgRow>, CustomDatabaseError>;
+}
+
+pub struct PgClient {
+    pool: sqlx::PgPool,
+}
+
+impl PgClient {
+    pub async fn new_read_client(config: &Config) -> Result<PgClient, CustomDatabaseError> {
+        let pool = PgPoolOptions::new()
+            .max_connections(config.max_pg_connections)
+            .acquire_timeout(Duration::from_secs(1))
+            .test_before_acquire(true)
+            .connect(&config.read_database_url)
+            .await?;
+
+        Ok(PgClient { pool })
+    }
+
+    pub async fn new_write_client(config: &Config) -> Result<PgClient, CustomDatabaseError> {
+        let pool = PgPoolOptions::new()
+            .max_connections(config.max_pg_connections)
+            .acquire_timeout(Duration::from_secs(1))
+            .test_before_acquire(true)
+            .connect(&config.write_database_url)
+            .await?;
+
+        Ok(PgClient { pool })
+    }
+}
+
+#[async_trait]
+impl Client for PgClient {
+    async fn run_query(
+        &self,
+        query: String,
+        parameters: Vec<String>,
+        timeout_ms: Option<u64>,
+    ) -> Result<Vec<PgRow>, CustomDatabaseError> {
+        let built_query = sqlx::query(&query);
+        let built_query = parameters
+            .iter()
+            .fold(built_query, |acc, param| acc.bind(param));
+        let query_results = built_query.fetch_all(&self.pool);
+
+        let timeout_ms = match timeout_ms {
+            Some(ms) => ms,
+            None => DATABASE_TIMEOUT_MILLISECS,
+        };
+
+        let fut = timeout(Duration::from_millis(timeout_ms), query_results).await?;
+
+        Ok(fut?)
+    }
+
+    async fn get_connection(&self) -> Result<PoolConnection<Postgres>, CustomDatabaseError> {
+        Ok(self.pool.acquire().await?)
+    }
+}
diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs
index fbbd0445b5998..cc208ae8b073f 100644
--- a/rust/feature-flags/src/flag_definitions.rs
+++ b/rust/feature-flags/src/flag_definitions.rs
@@ -1,11 +1,8 @@
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use std::sync::Arc;
 use tracing::instrument;
 
-use crate::{
-    api::FlagError,
-    redis::{Client, CustomRedisError},
-};
+use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient};
 
 // TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork.
 // TODO: Add integration tests across repos to ensure this doesn't happen.
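The `PgClient` added in `database.rs` above is the main entry point for new callers: `run_query` binds string parameters positionally and enforces a per-query timeout in milliseconds. A hypothetical ad-hoc usage sketch (the query, token value, and timeout here are placeholders, not part of this PR):

```rust
use std::sync::Arc;

use feature_flags::config::DEFAULT_TEST_CONFIG;
use feature_flags::database::{Client, PgClient};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Read-only pool against the test database described by DEFAULT_TEST_CONFIG.
    let client = Arc::new(PgClient::new_read_client(&DEFAULT_TEST_CONFIG).await?);

    // Positional $1 binding; the query is abandoned if it runs longer than 500ms.
    let rows = client
        .run_query(
            "SELECT id, name FROM posthog_team WHERE api_token = $1".to_string(),
            vec!["phc_example_token".to_string()],
            Some(500),
        )
        .await?;

    println!("fetched {} rows", rows.len());
    Ok(())
}
```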
@@ -46,7 +43,7 @@ pub struct PropertyFilter { pub operator: Option, #[serde(rename = "type")] pub prop_type: String, - pub group_type_index: Option, + pub group_type_index: Option, } #[derive(Debug, Clone, Deserialize)] @@ -74,15 +71,15 @@ pub struct MultivariateFlagOptions { pub struct FlagFilters { pub groups: Vec, pub multivariate: Option, - pub aggregation_group_type_index: Option, + pub aggregation_group_type_index: Option, pub payloads: Option, pub super_groups: Option>, } #[derive(Debug, Clone, Deserialize)] pub struct FeatureFlag { - pub id: i64, - pub team_id: i64, + pub id: i32, + pub team_id: i32, pub name: Option, pub key: String, pub filters: FlagFilters, @@ -94,8 +91,20 @@ pub struct FeatureFlag { pub ensure_experience_continuity: bool, } +#[derive(Debug, Serialize, sqlx::FromRow)] +pub struct FeatureFlagRow { + pub id: i32, + pub team_id: i32, + pub name: Option, + pub key: String, + pub filters: serde_json::Value, + pub deleted: bool, + pub active: bool, + pub ensure_experience_continuity: bool, +} + impl FeatureFlag { - pub fn get_group_type_index(&self) -> Option { + pub fn get_group_type_index(&self) -> Option { self.filters.aggregation_group_type_index } @@ -121,27 +130,13 @@ impl FeatureFlagList { /// Returns feature flags from redis given a team_id #[instrument(skip_all)] pub async fn from_redis( - client: Arc, - team_id: i64, + client: Arc, + team_id: i32, ) -> Result { // TODO: Instead of failing here, i.e. if not in redis, fallback to pg let serialized_flags = client .get(format!("{TEAM_FLAGS_CACHE_PREFIX}{}", team_id)) - .await - .map_err(|e| match e { - CustomRedisError::NotFound => FlagError::TokenValidationError, - CustomRedisError::PickleError(_) => { - // TODO: Implement From trait for FlagError so we don't need to map - // CustomRedisError ourselves - tracing::error!("failed to fetch data: {}", e); - println!("failed to fetch data: {}", e); - FlagError::DataParsingError - } - _ => { - tracing::error!("Unknown redis error: {}", e); - FlagError::RedisUnavailable - } - })?; + .await?; let flags_list: Vec = serde_json::from_str(&serialized_flags).map_err(|e| { @@ -153,13 +148,45 @@ impl FeatureFlagList { Ok(FeatureFlagList { flags: flags_list }) } + + /// Returns feature flags from postgres given a team_id + #[instrument(skip_all)] + pub async fn from_pg( + client: Arc, + team_id: i32, + ) -> Result { + let mut conn = client.get_connection().await?; + // TODO: Clean up error handling here + + let query = "SELECT id, team_id, name, key, filters, deleted, active, ensure_experience_continuity FROM posthog_featureflag WHERE team_id = $1"; + let flags_row = sqlx::query_as::<_, FeatureFlagRow>(query) + .bind(team_id) + .fetch_all(&mut *conn) + .await?; + + let serialized_flags = serde_json::to_string(&flags_row).map_err(|e| { + tracing::error!("failed to serialize flags: {}", e); + println!("failed to serialize flags: {}", e); + FlagError::DataParsingError + })?; + + let flags_list: Vec = + serde_json::from_str(&serialized_flags).map_err(|e| { + tracing::error!("failed to parse data to flags list: {}", e); + println!("failed to parse data: {}", e); + + FlagError::DataParsingError + })?; + Ok(FeatureFlagList { flags: flags_list }) + } } #[cfg(test)] mod tests { use super::*; use crate::test_utils::{ - insert_flags_for_team_in_redis, insert_new_team_in_redis, setup_redis_client, + insert_flags_for_team_in_pg, insert_flags_for_team_in_redis, insert_new_team_in_pg, + insert_new_team_in_redis, setup_pg_client, setup_redis_client, }; #[tokio::test] @@ -211,4 +238,64 @@ mod 
tests { _ => panic!("Expected RedisUnavailable"), }; } + + #[tokio::test] + async fn test_fetch_flags_from_pg() { + let client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(client.clone()) + .await + .expect("Failed to insert team in pg"); + + insert_flags_for_team_in_pg(client.clone(), team.id, None) + .await + .expect("Failed to insert flags"); + + let flags_from_pg = FeatureFlagList::from_pg(client.clone(), team.id) + .await + .expect("Failed to fetch flags from pg"); + + assert_eq!(flags_from_pg.flags.len(), 1); + let flag = flags_from_pg.flags.get(0).expect("Flags should be in pg"); + + assert_eq!(flag.key, "flag1"); + assert_eq!(flag.team_id, team.id); + assert_eq!(flag.filters.groups.len(), 1); + assert_eq!( + flag.filters.groups[0] + .properties + .as_ref() + .expect("Properties don't exist on flag") + .len(), + 1 + ); + let property_filter = &flag.filters.groups[0] + .properties + .as_ref() + .expect("Properties don't exist on flag")[0]; + + assert_eq!(property_filter.key, "email"); + assert_eq!(property_filter.value, "a@b.com"); + assert_eq!(property_filter.operator, None); + assert_eq!(property_filter.prop_type, "person"); + assert_eq!(property_filter.group_type_index, None); + assert_eq!(flag.filters.groups[0].rollout_percentage, Some(50.0)); + } + + // TODO: Add more tests to validate deserialization of flags. + // TODO: Also make sure old flag data is handled, or everything is migrated to new style in production + + #[tokio::test] + async fn test_fetch_empty_team_from_pg() { + let client = setup_pg_client(None).await; + + match FeatureFlagList::from_pg(client.clone(), 1234) + .await + .expect("Failed to fetch flags from pg") + { + FeatureFlagList { flags } => { + assert_eq!(flags.len(), 0); + } + } + } } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 510fc153dc87a..485d8a646e823 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,6 +1,12 @@ -use crate::flag_definitions::{FeatureFlag, FlagGroupType}; +use crate::{ + api::FlagError, + database::Client as DatabaseClient, + flag_definitions::{FeatureFlag, FlagGroupType}, + property_matching::match_property, +}; +use serde_json::Value; use sha1::{Digest, Sha1}; -use std::fmt::Write; +use std::{collections::HashMap, fmt::Write, sync::Arc}; #[derive(Debug, PartialEq, Eq)] pub struct FeatureFlagMatch { @@ -11,6 +17,11 @@ pub struct FeatureFlagMatch { //payload } +#[derive(Debug, sqlx::FromRow)] +pub struct Person { + pub properties: sqlx::types::Json>, +} + // TODO: Rework FeatureFlagMatcher - python has a pretty awkward interface, where we pass in all flags, and then again // the flag to match. I don't think there's any reason anymore to store the flags in the matcher, since we can just // pass the flag to match directly to the get_match method. This will also make the matcher more stateless. @@ -21,23 +32,30 @@ pub struct FeatureFlagMatch { // for all teams. If not, we can have a LRU cache, or a cache that stores only the most recent N keys. // But, this can be a future refactor, for now just focusing on getting the basic matcher working, write lots and lots of tests // and then we can easily refactor stuff around. 
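Putting these pieces together, the intended flow is: load a team's flags (from Redis or Postgres), then evaluate each flag for a `distinct_id` with the matcher. A sketch of that flow under the same assumptions as the tests in this PR (`team_id` and `distinct_id` are placeholders; the helpers come from this PR's `test_utils`):

```rust
use feature_flags::flag_definitions::FeatureFlagList;
use feature_flags::flag_matching::FeatureFlagMatcher;
use feature_flags::test_utils::setup_pg_client;

// Illustrative only: evaluates every flag for one person and prints the result.
async fn evaluate_all_flags_for_person(team_id: i32, distinct_id: String) -> anyhow::Result<()> {
    let client = setup_pg_client(None).await;

    let flag_list = FeatureFlagList::from_pg(client.clone(), team_id).await?;
    let mut matcher = FeatureFlagMatcher::new(distinct_id, Some(client.clone()));

    for flag in &flag_list.flags {
        let result = matcher.get_match(flag).await;
        println!(
            "{} -> matches={} variant={:?}",
            flag.key, result.matches, result.variant
        );
    }
    Ok(())
}
```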
-#[derive(Debug)] +// #[derive(Debug)] pub struct FeatureFlagMatcher { // pub flags: Vec, pub distinct_id: String, + pub database_client: Option>, + cached_properties: Option>, } const LONG_SCALE: u64 = 0xfffffffffffffff; impl FeatureFlagMatcher { - pub fn new(distinct_id: String) -> Self { + pub fn new( + distinct_id: String, + database_client: Option>, + ) -> Self { FeatureFlagMatcher { // flags, distinct_id, + database_client, + cached_properties: None, } } - pub fn get_match(&self, feature_flag: &FeatureFlag) -> FeatureFlagMatch { + pub async fn get_match(&mut self, feature_flag: &FeatureFlag) -> FeatureFlagMatch { if self.hashed_identifier(feature_flag).is_none() { return FeatureFlagMatch { matches: false, @@ -49,8 +67,9 @@ impl FeatureFlagMatcher { // TODO: Variant overrides condition sort for (index, condition) in feature_flag.get_conditions().iter().enumerate() { - let (is_match, _evaluation_reason) = - self.is_condition_match(feature_flag, condition, index); + let (is_match, _evaluation_reason) = self + .is_condition_match(feature_flag, condition, index) + .await; if is_match { // TODO: This is a bit awkward, we should handle overrides only when variants exist. @@ -82,20 +101,33 @@ impl FeatureFlagMatcher { } } - pub fn is_condition_match( - &self, + // TODO: Making all this mutable just to store a cached value is annoying. Can I refactor this to be non-mutable? + // Leaning a bit more towards a separate cache store for this. + pub async fn is_condition_match( + &mut self, feature_flag: &FeatureFlag, condition: &FlagGroupType, _index: usize, ) -> (bool, String) { let rollout_percentage = condition.rollout_percentage.unwrap_or(100.0); let mut condition_match = true; - if condition.properties.is_some() { - // TODO: Handle matching conditions - if !condition.properties.as_ref().unwrap().is_empty() { - condition_match = false; + + if let Some(ref properties) = condition.properties { + if properties.is_empty() { + condition_match = true; + } else { + // TODO: First handle given override properties before going to db + let target_properties = self + .get_person_properties(feature_flag.team_id, self.distinct_id.clone()) + .await + .unwrap_or_default(); + // TODO: Handle db issues / person not found + + condition_match = properties.iter().all(|property| { + match_property(property, &target_properties, false).unwrap_or(false) + }); } - } + }; if !condition_match { return (false, "NO_CONDITION_MATCH".to_string()); @@ -157,4 +189,133 @@ impl FeatureFlagMatcher { } None } + + pub async fn get_person_properties( + &mut self, + team_id: i32, + distinct_id: String, + ) -> Result, FlagError> { + // TODO: Do we even need to cache here anymore? + // Depends on how often we're calling this function + // to match all flags for a single person + + if let Some(cached_props) = self.cached_properties.clone() { + // TODO: Maybe we don't want to copy around all user properties, this will by far be the largest chunk + // of data we're copying around. Can we work with references here? + // Worst case, just use a Rc. 
+ return Ok(cached_props); + } + + if self.database_client.is_none() { + return Err(FlagError::DatabaseUnavailable); + } + + let mut conn = self + .database_client + .as_ref() + .expect("client should exist here") + .get_connection() + .await?; + + let query = r#" + SELECT "posthog_person"."properties" + FROM "posthog_person" + INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id") + WHERE ("posthog_persondistinctid"."distinct_id" = $1 + AND "posthog_persondistinctid"."team_id" = $2 + AND "posthog_person"."team_id" = $3) + LIMIT 1; + "#; + + let row = sqlx::query_as::<_, Person>(query) + .bind(&distinct_id) + .bind(team_id) + .bind(team_id) + .fetch_optional(&mut *conn) + .await?; + + let props = match row { + Some(row) => row.properties.0, + None => HashMap::new(), + }; + + self.cached_properties = Some(props.clone()); + + Ok(props) + } +} + +#[cfg(test)] +mod tests { + + use serde_json::json; + + use super::*; + use crate::test_utils::{insert_new_team_in_pg, insert_person_for_team_in_pg, setup_pg_client}; + + #[tokio::test] + async fn test_fetch_properties_from_pg_to_match() { + let client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(client.clone()) + .await + .expect("Failed to insert team in pg"); + + let distinct_id = "user_distinct_id".to_string(); + insert_person_for_team_in_pg(client.clone(), team.id, distinct_id.clone(), None) + .await + .expect("Failed to insert person"); + + let not_matching_distinct_id = "not_matching_distinct_id".to_string(); + insert_person_for_team_in_pg( + client.clone(), + team.id, + not_matching_distinct_id.clone(), + Some(json!({ "email": "a@x.com"})), + ) + .await + .expect("Failed to insert person"); + + let flag = serde_json::from_value(json!( + { + "id": 1, + "team_id": team.id, + "name": "flag1", + "key": "flag1", + "filters": { + "groups": [ + { + "properties": [ + { + "key": "email", + "value": "a@b.com", + "type": "person" + } + ], + "rollout_percentage": 100 + } + ] + } + } + )) + .unwrap(); + + let mut matcher = FeatureFlagMatcher::new(distinct_id, Some(client.clone())); + let match_result = matcher.get_match(&flag).await; + assert_eq!(match_result.matches, true); + assert_eq!(match_result.variant, None); + + // property value is different + let mut matcher = FeatureFlagMatcher::new(not_matching_distinct_id, Some(client.clone())); + let match_result = matcher.get_match(&flag).await; + assert_eq!(match_result.matches, false); + assert_eq!(match_result.variant, None); + + // person does not exist + let mut matcher = + FeatureFlagMatcher::new("other_distinct_id".to_string(), Some(client.clone())); + let match_result = matcher.get_match(&flag).await; + assert_eq!(match_result.matches, false); + assert_eq!(match_result.variant, None); + } } diff --git a/rust/feature-flags/src/lib.rs b/rust/feature-flags/src/lib.rs index 7f03747b9ee6d..7784bd7bf1b8d 100644 --- a/rust/feature-flags/src/lib.rs +++ b/rust/feature-flags/src/lib.rs @@ -1,5 +1,6 @@ pub mod api; pub mod config; +pub mod database; pub mod flag_definitions; pub mod flag_matching; pub mod property_matching; diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 8824d44efdbde..2fbc87c870930 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -2,18 +2,59 @@ use std::sync::Arc; use axum::{routing::post, Router}; -use crate::{redis::Client, v0_endpoint}; +use crate::{database::Client as DatabaseClient, redis::Client as RedisClient, v0_endpoint}; 
#[derive(Clone)] pub struct State { - pub redis: Arc, + pub redis: Arc, // TODO: Add pgClient when ready + pub postgres: Arc, } -pub fn router(redis: Arc) -> Router { - let state = State { redis }; +pub fn router(redis: Arc, postgres: Arc) -> Router +where + R: RedisClient + Send + Sync + 'static, + D: DatabaseClient + Send + Sync + 'static, +{ + let state = State { redis, postgres }; Router::new() .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) .with_state(state) } + +// TODO, eventually we can differentiate read and write postgres clients, if needed +// I _think_ everything is read-only, but I'm not 100% sure yet +// here's how that client would look +// use std::sync::Arc; + +// use axum::{routing::post, Router}; + +// use crate::{database::Client as DatabaseClient, redis::Client as RedisClient, v0_endpoint}; + +// #[derive(Clone)] +// pub struct State { +// pub redis: Arc, +// pub postgres_read: Arc, +// pub postgres_write: Arc, +// } + +// pub fn router( +// redis: Arc, +// postgres_read: Arc, +// postgres_write: Arc, +// ) -> Router +// where +// R: RedisClient + Send + Sync + 'static, +// D: DatabaseClient + Send + Sync + 'static, +// { +// let state = State { +// redis, +// postgres_read, +// postgres_write, +// }; + +// Router::new() +// .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) +// .with_state(state) +// } diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index ffe6b0efb7068..37bd721a9a51f 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use tokio::net::TcpListener; use crate::config::Config; - +use crate::database::PgClient; use crate::redis::RedisClient; use crate::router; @@ -13,13 +13,25 @@ pub async fn serve(config: Config, listener: TcpListener, shutdown: F) where F: Future + Send + 'static, { - let redis_client = - Arc::new(RedisClient::new(config.redis_url).expect("failed to create redis client")); + let redis_client = match RedisClient::new(config.redis_url.clone()) { + Ok(client) => Arc::new(client), + Err(e) => { + tracing::error!("Failed to create Redis client: {}", e); + return; + } + }; + + let read_postgres_client = match PgClient::new_read_client(&config).await { + Ok(client) => Arc::new(client), + Err(e) => { + tracing::error!("Failed to create read Postgres client: {}", e); + return; + } + }; - let app = router::router(redis_client); + // You can decide which client to pass to the router, or pass both if needed + let app = router::router(redis_client, read_postgres_client); - // run our app with hyper - // `axum::Server` is a re-export of `hyper::Server` tracing::info!("listening on {:?}", listener.local_addr().unwrap()); axum::serve( listener, diff --git a/rust/feature-flags/src/team.rs b/rust/feature-flags/src/team.rs index e872aa477968f..7c7cfd9547bbf 100644 --- a/rust/feature-flags/src/team.rs +++ b/rust/feature-flags/src/team.rs @@ -2,18 +2,15 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::instrument; -use crate::{ - api::FlagError, - redis::{Client, CustomRedisError}, -}; +use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; // TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork. // TODO: Add integration tests across repos to ensure this doesn't happen. 
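The team lookup below still reads only from Redis; the TODOs in this file describe a Postgres fallback on cache misses. A hedged illustration of that fallback, not part of this PR, assuming the `Arc<dyn ...>` client types used elsewhere in the crate:

```rust
use std::sync::Arc;

use crate::api::FlagError;
use crate::database::Client as DatabaseClient;
use crate::redis::Client as RedisClient;
use crate::team::Team;

// Sketch of the "fallback to pg" path the TODOs mention: try the Redis cache first,
// and only hit Postgres when the token is missing from the cache or Redis is unavailable.
pub async fn team_from_cache_or_pg(
    redis_client: Arc<dyn RedisClient + Send + Sync>,
    pg_client: Arc<dyn DatabaseClient + Send + Sync>,
    token: String,
) -> Result<Team, FlagError> {
    match Team::from_redis(redis_client, token.clone()).await {
        Ok(team) => Ok(team),
        Err(FlagError::TokenValidationError) | Err(FlagError::RedisUnavailable) => {
            Team::from_pg(pg_client, token).await
        }
        Err(e) => Err(e),
    }
}
```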
pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, sqlx::FromRow)] pub struct Team { - pub id: i64, + pub id: i32, pub name: String, pub api_token: String, } @@ -23,24 +20,13 @@ impl Team { #[instrument(skip_all)] pub async fn from_redis( - client: Arc, + client: Arc, token: String, ) -> Result { // TODO: Instead of failing here, i.e. if not in redis, fallback to pg let serialized_team = client .get(format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token)) - .await - .map_err(|e| match e { - CustomRedisError::NotFound => FlagError::TokenValidationError, - CustomRedisError::PickleError(_) => { - tracing::error!("failed to fetch data: {}", e); - FlagError::DataParsingError - } - _ => { - tracing::error!("Unknown redis error: {}", e); - FlagError::RedisUnavailable - } - })?; + .await?; // TODO: Consider an LRU cache for teams as well, with small TTL to skip redis/pg lookups let team: Team = serde_json::from_str(&serialized_team).map_err(|e| { @@ -50,6 +36,21 @@ impl Team { Ok(team) } + + pub async fn from_pg( + client: Arc, + token: String, + ) -> Result { + let mut conn = client.get_connection().await?; + + let query = "SELECT id, name, api_token FROM posthog_team WHERE api_token = $1"; + let row = sqlx::query_as::<_, Team>(query) + .bind(&token) + .fetch_one(&mut *conn) + .await?; + + Ok(row) + } } #[cfg(test)] @@ -60,14 +61,19 @@ mod tests { use super::*; use crate::{ team, - test_utils::{insert_new_team_in_redis, random_string, setup_redis_client}, + test_utils::{ + insert_new_team_in_pg, insert_new_team_in_redis, random_string, setup_pg_client, + setup_redis_client, + }, }; #[tokio::test] async fn test_fetch_team_from_redis() { let client = setup_redis_client(None); - let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let team = insert_new_team_in_redis(client.clone()) + .await + .expect("Failed to insert team in redis"); let target_token = team.api_token; @@ -137,4 +143,39 @@ mod tests { Ok(_) => panic!("Expected DataParsingError"), }; } + + #[tokio::test] + async fn test_fetch_team_from_pg() { + let client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(client.clone()) + .await + .expect("Failed to insert team in pg"); + + let target_token = team.api_token; + + let team_from_pg = Team::from_pg(client.clone(), target_token.clone()) + .await + .expect("Failed to fetch team from pg"); + + assert_eq!(team_from_pg.api_token, target_token); + assert_eq!(team_from_pg.id, team.id); + assert_eq!(team_from_pg.name, team.name); + } + + #[tokio::test] + async fn test_fetch_team_from_pg_with_invalid_token() { + // TODO: Figure out a way such that `run_database_migrations` is called only once, and already called + // before running these tests. + + let client = setup_pg_client(None).await; + let target_token = "xxxx".to_string(); + + match Team::from_pg(client.clone(), target_token.clone()).await { + Err(FlagError::TokenValidationError) => (), + _ => panic!("Expected TokenValidationError"), + }; + } + + // TODO: Handle cases where db connection fails. 
diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs
index 92bc8a4ff4494..9d1f5970d46b6 100644
--- a/rust/feature-flags/src/test_utils.rs
+++ b/rust/feature-flags/src/test_utils.rs
@@ -1,10 +1,13 @@
 use anyhow::Error;
-use serde_json::json;
+use serde_json::{json, Value};
 use std::sync::Arc;
+use uuid::Uuid;
 
 use crate::{
-    flag_definitions::{self, FeatureFlag},
-    redis::{Client, RedisClient},
+    config::{Config, DEFAULT_TEST_CONFIG},
+    database::{Client as DatabaseClientTrait, PgClient},
+    flag_definitions::{self, FeatureFlag, FeatureFlagRow},
+    redis::{Client as RedisClientTrait, RedisClient},
     team::{self, Team},
 };
 use rand::{distributions::Alphanumeric, Rng};
@@ -44,7 +47,7 @@ pub async fn insert_new_team_in_redis(client: Arc<RedisClient>) -> Result<Team,
 pub async fn insert_flags_for_team_in_redis(
     client: Arc<RedisClient>,
-    team_id: i64,
+    team_id: i32,
     json_value: Option<String>,
 ) -> Result<(), Error> {
     let payload = match json_value {
@@ -124,3 +127,149 @@ pub fn create_flag_from_json(json_value: Option<String>) -> Vec<FeatureFlag> {
         serde_json::from_str(&payload).expect("Failed to parse data to flags list");
     flags
 }
+
+pub async fn setup_pg_client(config: Option<&Config>) -> Arc<PgClient> {
+    let config = config.unwrap_or(&DEFAULT_TEST_CONFIG);
+    Arc::new(
+        PgClient::new_read_client(config)
+            .await
+            .expect("Failed to create pg read client"),
+    )
+}
+
+pub async fn insert_new_team_in_pg(client: Arc<PgClient>) -> Result<Team, Error> {
+    const ORG_ID: &str = "019026a4be8000005bf3171d00629163";
+
+    client.run_query(
+        r#"INSERT INTO posthog_organization
+        (id, name, slug, created_at, updated_at, plugins_access_level, for_internal_metrics, is_member_join_email_enabled, enforce_2fa, is_hipaa, customer_id, available_product_features, personalization, setup_section_2_completed, domain_whitelist)
+        VALUES
+        ($1::uuid, 'Test Organization', 'test-organization', '2024-06-17 14:40:49.298579+00:00', '2024-06-17 14:40:49.298593+00:00', 9, false, true, NULL, false, NULL, '{}', '{}', true, '{}')
+        ON CONFLICT DO NOTHING"#.to_string(),
+        vec![ORG_ID.to_string()],
+        Some(2000),
+    ).await?;
+
+    client
+        .run_query(
+            r#"INSERT INTO posthog_project
+            (id, organization_id, name, created_at)
+            VALUES
+            (1, $1::uuid, 'Test Team', '2024-06-17 14:40:51.329772+00:00')
+            ON CONFLICT DO NOTHING"#
+                .to_string(),
+            vec![ORG_ID.to_string()],
+            Some(2000),
+        )
+        .await?;
+
+    let id = rand::thread_rng().gen_range(0..10_000_000);
+    let token = random_string("phc_", 12);
+    let team = Team {
+        id,
+        name: "team".to_string(),
+        api_token: token,
+    };
+    let uuid = Uuid::now_v7();
+
+    let mut conn = client.get_connection().await?;
+    let res = sqlx::query(
+        r#"INSERT INTO posthog_team
+        (id, uuid, organization_id, project_id, api_token, name, created_at, updated_at, app_urls, anonymize_ips, completed_snippet_onboarding, ingested_event, session_recording_opt_in, is_demo, access_control, test_account_filters, timezone, data_attributes, plugins_opt_in, opt_out_capture, event_names, event_names_with_usage, event_properties, event_properties_with_usage, event_properties_numerical) VALUES
+        ($1, $5, $2::uuid, 1, $3, $4, '2024-06-17 14:40:51.332036+00:00', '2024-06-17', '{}', false, false, false, false, false, false, '{}', 'UTC', '["data-attr"]', false, false, '[]', '[]', '[]', '[]', '[]')"#
+    ).bind(team.id).bind(ORG_ID).bind(&team.api_token).bind(&team.name).bind(uuid).execute(&mut *conn).await?;
+
+    assert_eq!(res.rows_affected(), 1);
+
+    Ok(team)
+}
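// Illustrative usage sketch (not part of this diff): how a test might combine the pg
// helpers in this file, assuming the signatures reconstructed above and just below.
//
// #[tokio::test]
// async fn test_pg_helpers_roundtrip() {
//     let client = setup_pg_client(None).await;
//     let team = insert_new_team_in_pg(client.clone())
//         .await
//         .expect("Failed to insert team in pg");
//     let flag = insert_flags_for_team_in_pg(client.clone(), team.id, None)
//         .await
//         .expect("Failed to insert flag in pg");
//     assert_eq!(flag.team_id, team.id);
// }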
+
+pub async fn insert_flags_for_team_in_pg(
+    client: Arc<PgClient>,
+    team_id: i32,
+    flag: Option<FeatureFlagRow>,
+) -> Result<FeatureFlagRow, Error> {
+    let id = rand::thread_rng().gen_range(0..10_000_000);
+
+    let payload_flag = match flag {
+        Some(value) => value,
+        None => FeatureFlagRow {
+            id,
+            key: "flag1".to_string(),
+            name: Some("flag1 description".to_string()),
+            active: true,
+            deleted: false,
+            ensure_experience_continuity: false,
+            team_id,
+            filters: json!({
+                "groups": [
+                    {
+                        "properties": [
+                            {
+                                "key": "email",
+                                "value": "a@b.com",
+                                "type": "person",
+                            },
+                        ],
+                        "rollout_percentage": 50,
+                    },
+                ],
+            }),
+        },
+    };
+
+    let mut conn = client.get_connection().await?;
+    let res = sqlx::query(
+        r#"INSERT INTO posthog_featureflag
+        (id, team_id, name, key, filters, deleted, active, ensure_experience_continuity, created_at) VALUES
+        ($1, $2, $3, $4, $5, $6, $7, $8, '2024-06-17')"#
+    ).bind(payload_flag.id).bind(team_id).bind(&payload_flag.name).bind(&payload_flag.key).bind(&payload_flag.filters).bind(payload_flag.deleted).bind(payload_flag.active).bind(payload_flag.ensure_experience_continuity).execute(&mut *conn).await?;
+
+    assert_eq!(res.rows_affected(), 1);
+
+    Ok(payload_flag)
+}
+
+pub async fn insert_person_for_team_in_pg(
+    client: Arc<PgClient>,
+    team_id: i32,
+    distinct_id: String,
+    properties: Option<Value>,
+) -> Result<(), Error> {
+    let payload = match properties {
+        Some(value) => value,
+        None => json!({
+            "email": "a@b.com",
+            "name": "Alice",
+        }),
+    };
+
+    let uuid = Uuid::now_v7();
+
+    let mut conn = client.get_connection().await?;
+    let res = sqlx::query(
+        r#"
+        WITH inserted_person AS (
+            INSERT INTO posthog_person (
+                created_at, properties, properties_last_updated_at,
+                properties_last_operation, team_id, is_user_id, is_identified, uuid, version
+            )
+            VALUES ('2023-04-05', $1, '{}', '{}', $2, NULL, true, $3, 0)
+            RETURNING *
+        )
+        INSERT INTO posthog_persondistinctid (distinct_id, person_id, team_id, version)
+        VALUES ($4, (SELECT id FROM inserted_person), $5, 0)
+        "#,
+    )
+    .bind(&payload)
+    .bind(team_id)
+    .bind(uuid)
+    .bind(&distinct_id)
+    .bind(team_id)
+    .execute(&mut *conn)
+    .await?;
+
+    assert_eq!(res.rows_affected(), 1);
+
+    Ok(())
+}
diff --git a/rust/feature-flags/tests/common/mod.rs b/rust/feature-flags/tests/common/mod.rs
index c8644fe1f4542..2b14292e0fda3 100644
--- a/rust/feature-flags/tests/common/mod.rs
+++ b/rust/feature-flags/tests/common/mod.rs
@@ -1,9 +1,6 @@
 use std::net::SocketAddr;
-use std::str::FromStr;
-use std::string::ToString;
 use std::sync::Arc;
 
-use once_cell::sync::Lazy;
 use reqwest::header::CONTENT_TYPE;
 use tokio::net::TcpListener;
 use tokio::sync::Notify;
@@ -11,15 +8,6 @@
 use feature_flags::config::Config;
 use feature_flags::server::serve;
 
-pub static DEFAULT_CONFIG: Lazy<Config> = Lazy::new(|| Config {
-    address: SocketAddr::from_str("127.0.0.1:0").unwrap(),
-    redis_url: "redis://localhost:6379/".to_string(),
-    write_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(),
-    read_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(),
-    max_concurrent_jobs: 1024,
-    max_pg_connections: 100,
-});
-
 pub struct ServerHandle {
     pub addr: SocketAddr,
     shutdown: Arc<Notify>,
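The deleted DEFAULT_CONFIG above is superseded by DEFAULT_TEST_CONFIG, which the test files below import from feature_flags::config; its definition is not part of this section. A hedged sketch of what it presumably looks like inside rust/feature-flags/src/config.rs, reusing only the fields from the deleted block:

// Illustrative only — the real DEFAULT_TEST_CONFIG lives in config.rs, which is not shown
// in this diff section; Config is assumed to be in scope because this would sit in that module.
use std::net::SocketAddr;
use std::str::FromStr;

use once_cell::sync::Lazy;

pub static DEFAULT_TEST_CONFIG: Lazy<Config> = Lazy::new(|| Config {
    address: SocketAddr::from_str("127.0.0.1:0").unwrap(),
    redis_url: "redis://localhost:6379/".to_string(),
    write_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(),
    read_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(),
    max_concurrent_jobs: 1024,
    max_pg_connections: 100,
});

Centralizing the test config in the library crate lets both the integration tests and the pg helpers in test_utils.rs share one source of connection strings.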
diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs
index 4a24b0e16d50e..d4b55ed4e9001 100644
--- a/rust/feature-flags/tests/test_flag_matching_consistency.rs
+++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs
@@ -5,8 +5,8 @@ use feature_flags::flag_matching::{FeatureFlagMatch, FeatureFlagMatcher};
 use feature_flags::test_utils::create_flag_from_json;
 use serde_json::json;
 
-#[test]
-fn it_is_consistent_with_rollout_calculation_for_simple_flags() {
+#[tokio::test]
+async fn it_is_consistent_with_rollout_calculation_for_simple_flags() {
     let flags = create_flag_from_json(Some(
         json!([{
             "id": 1,
@@ -107,7 +107,9 @@ fn it_is_consistent_with_rollout_calculation_for_simple_flags() {
     for i in 0..1000 {
         let distinct_id = format!("distinct_id_{}", i);
 
-        let feature_flag_match = FeatureFlagMatcher::new(distinct_id).get_match(&flags[0]);
+        let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None)
+            .get_match(&flags[0])
+            .await;
 
         if results[i] {
             assert_eq!(
@@ -129,8 +131,8 @@ fn it_is_consistent_with_rollout_calculation_for_simple_flags() {
     }
 }
 
-#[test]
-fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() {
+#[tokio::test]
+async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() {
     let flags = create_flag_from_json(Some(
         json!([{
             "id": 1,
@@ -1186,7 +1188,9 @@ fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() {
     for i in 0..1000 {
         let distinct_id = format!("distinct_id_{}", i);
 
-        let feature_flag_match = FeatureFlagMatcher::new(distinct_id).get_match(&flags[0]);
+        let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None)
+            .get_match(&flags[0])
+            .await;
 
         if results[i].is_some() {
             assert_eq!(
diff --git a/rust/feature-flags/tests/test_flags.rs b/rust/feature-flags/tests/test_flags.rs
index 2ceba24efd712..f9a46e1c543af 100644
--- a/rust/feature-flags/tests/test_flags.rs
+++ b/rust/feature-flags/tests/test_flags.rs
@@ -6,13 +6,14 @@ use serde_json::{json, Value};
 
 use crate::common::*;
 
+use feature_flags::config::DEFAULT_TEST_CONFIG;
 use feature_flags::test_utils::{insert_new_team_in_redis, setup_redis_client};
 
 pub mod common;
 
 #[tokio::test]
 async fn it_sends_flag_request() -> Result<()> {
-    let config = DEFAULT_CONFIG.clone();
+    let config = DEFAULT_TEST_CONFIG.clone();
 
     let distinct_id = "user_distinct_id".to_string();
 
@@ -50,7 +51,7 @@
 #[tokio::test]
 async fn it_rejects_invalid_headers_flag_request() -> Result<()> {
-    let config = DEFAULT_CONFIG.clone();
+    let config = DEFAULT_TEST_CONFIG.clone();
 
     let distinct_id = "user_distinct_id".to_string();