diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 720c3c3642887..16a42c609e16c 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -73,6 +73,7 @@ import { RolesListParams, RoleType, ScheduledChangeType, + SchemaIncrementalFieldsResponse, SearchListParams, SearchResponse, SessionRecordingPlaylistType, @@ -2080,6 +2081,9 @@ const api = { async resync(schemaId: ExternalDataSourceSchema['id']): Promise { await new ApiRequest().externalDataSourceSchema(schemaId).withAction('resync').create() }, + async incremental_fields(schemaId: ExternalDataSourceSchema['id']): Promise { + return await new ApiRequest().externalDataSourceSchema(schemaId).withAction('incremental_fields').create() + }, }, dataWarehouseViewLinks: { diff --git a/frontend/src/lib/lemon-ui/LemonRadio/LemonRadio.tsx b/frontend/src/lib/lemon-ui/LemonRadio/LemonRadio.tsx index fba801de2a8e2..686255984134a 100644 --- a/frontend/src/lib/lemon-ui/LemonRadio/LemonRadio.tsx +++ b/frontend/src/lib/lemon-ui/LemonRadio/LemonRadio.tsx @@ -15,6 +15,7 @@ export interface LemonRadioProps { onChange: (newValue: T) => void options: LemonRadioOption[] className?: string + radioPosition?: 'center' | 'top' } /** Single choice radio. */ @@ -23,6 +24,7 @@ export function LemonRadio({ onChange, options, className, + radioPosition, }: LemonRadioProps): JSX.Element { return (
@@ -32,7 +34,11 @@ export function LemonRadio({ key={value} className={clsx( 'grid items-center gap-x-2 grid-cols-[min-content_auto] text-sm', - disabledReason ? 'text-muted cursor-not-allowed' : 'cursor-pointer' + disabledReason ? 'text-muted cursor-not-allowed' : 'cursor-pointer', + { + 'items-baseline': radioPosition === 'top', + 'items-center': radioPosition === 'center' || !radioPosition, + } )} > [] => { - const options: LemonSelectOptionLeaf[] = [] - - if (schema.sync_types.full_refresh) { - options.push({ value: 'full_refresh', label: 'Full refresh' }) - } - - if (schema.sync_types.incremental) { - options.push({ value: 'incremental', label: 'Incremental' }) - } - - return options -} +import { SyncMethodForm } from './SyncMethodForm' export default function PostgresSchemaForm(): JSX.Element { - const { toggleSchemaShouldSync, updateSchemaSyncType } = useActions(sourceWizardLogic) + const { toggleSchemaShouldSync, openSyncMethodModal } = useActions(sourceWizardLogic) const { databaseSchema } = useValues(sourceWizardLogic) - const [toggleAllState, setToggleAllState] = useState(false) - - const toggleAllSwitches = (): void => { - databaseSchema.forEach((schema) => { - toggleSchemaShouldSync(schema, toggleAllState) - }) - - setToggleAllState(!toggleAllState) - } return ( -
-
- +
+
+ { + return ( + { + toggleSchemaShouldSync(schema, checked) + }} + /> + ) + }, }, - }, - { - title: ( - <> - Sync - toggleAllSwitches()} - > - {toggleAllState ? 'Enable' : 'Disable'} all - - - ), - key: 'should_sync', - render: function RenderShouldSync(_, schema) { - return ( - { - toggleSchemaShouldSync(schema, checked) - }} - /> - ) + { + title: 'Table', + key: 'table', + render: function RenderTable(_, schema) { + return schema.table + }, }, - }, - { - key: 'sync_type', - title: 'Sync type', - tooltip: - 'Full refresh will refresh the full table on every sync, whereas incremental will only sync new and updated rows since the last sync', - render: (_, schema) => { - const options = syncTypesToOptions(schema) + { + key: 'sync_type', + title: 'Sync method', + align: 'right', + tooltip: + 'Full refresh will refresh the full table on every sync, whereas incremental will only sync new and updated rows since the last sync', + render: (_, schema) => { + if (!schema.sync_type) { + return ( +
+ openSyncMethodModal(schema)} + > + Set up + +
+ ) + } - return ( - updateSchemaSyncType(schema, newValue)} - /> - ) + return ( +
+ openSyncMethodModal(schema)} + > + {schema.sync_type === 'full_refresh' ? 'Full refresh' : 'Incremental'} + +
+ ) + }, }, - }, - ]} - /> + ]} + /> +
-
+ + + ) +} + +const SyncMethodModal = (): JSX.Element => { + const { cancelSyncMethodModal, updateSchemaSyncType, toggleSchemaShouldSync } = useActions(sourceWizardLogic) + const { syncMethodModalOpen, currentSyncMethodModalSchema } = useValues(sourceWizardLogic) + + if (!currentSyncMethodModalSchema) { + return <> + } + + return ( + + { + if (syncType === 'incremental') { + updateSchemaSyncType( + currentSyncMethodModalSchema, + syncType, + incrementalField, + incrementalFieldType + ) + } else { + updateSchemaSyncType(currentSyncMethodModalSchema, syncType ?? null, null, null) + } + + toggleSchemaShouldSync(currentSyncMethodModalSchema, true) + cancelSyncMethodModal() + }} + /> + ) } diff --git a/frontend/src/scenes/data-warehouse/external/forms/SyncMethodForm.tsx b/frontend/src/scenes/data-warehouse/external/forms/SyncMethodForm.tsx new file mode 100644 index 0000000000000..8d4bafe96632a --- /dev/null +++ b/frontend/src/scenes/data-warehouse/external/forms/SyncMethodForm.tsx @@ -0,0 +1,197 @@ +import { LemonButton, LemonSelect, LemonTag, lemonToast } from '@posthog/lemon-ui' +import { LemonRadio } from 'lib/lemon-ui/LemonRadio' +import { useEffect, useState } from 'react' + +import { ExternalDataSourceSyncSchema } from '~/types' + +const getIncrementalSyncSupported = ( + schema: ExternalDataSourceSyncSchema +): { disabled: true; disabledReason: string } | { disabled: false } => { + if (!schema.incremental_available) { + return { + disabled: true, + disabledReason: "Incremental replication isn't supported on this table", + } + } + + if (schema.incremental_fields.length === 0) { + return { + disabled: true, + disabledReason: 'No incremental fields found on table', + } + } + + return { + disabled: false, + } +} + +interface SyncMethodFormProps { + schema: ExternalDataSourceSyncSchema + onClose: () => void + onSave: ( + syncType: ExternalDataSourceSyncSchema['sync_type'], + incrementalField: string | null, + incrementalFieldType: string | null + ) => void + saveButtonIsLoading?: boolean + showRefreshMessageOnChange?: boolean +} + +const hasInputChanged = ( + newSchemaSyncType: ExternalDataSourceSyncSchema['sync_type'], + newSchemaIncrementalField: string | null, + originalSchemaSyncType: ExternalDataSourceSyncSchema['sync_type'], + originalSchemaIncrementalField: string | null +): boolean => { + if (originalSchemaSyncType !== newSchemaSyncType) { + return true + } + + if (newSchemaSyncType === 'incremental' && newSchemaIncrementalField !== originalSchemaIncrementalField) { + return true + } + + return false +} + +const getSaveDisabledReason = ( + syncType: 'full_refresh' | 'incremental' | undefined, + incrementalField: string | null +): string | undefined => { + if (!syncType) { + return 'You must select a sync method before saving' + } + + if (syncType === 'incremental' && !incrementalField) { + return 'You must select an incremental field' + } +} + +export const SyncMethodForm = ({ + schema, + onClose, + onSave, + saveButtonIsLoading, + showRefreshMessageOnChange, +}: SyncMethodFormProps): JSX.Element => { + const [originalSchemaSyncType] = useState(schema.sync_type ?? null) + const [originalSchemaIncrementalField] = useState(schema.incremental_field ?? null) + + const [radioValue, setRadioValue] = useState(schema.sync_type ?? undefined) + const [incrementalFieldValue, setIncrementalFieldValue] = useState(schema.incremental_field ?? null) + + useEffect(() => { + setRadioValue(schema.sync_type ?? undefined) + setIncrementalFieldValue(schema.incremental_field ?? 
null) + }, [schema.table]) + + const incrementalSyncSupported = getIncrementalSyncSupported(schema) + + const inputChanged = hasInputChanged( + radioValue ?? null, + incrementalFieldValue, + originalSchemaSyncType, + originalSchemaIncrementalField + ) + const showRefreshMessage = inputChanged && showRefreshMessageOnChange + + return ( + <> + +
+

Incremental replication

+ {!incrementalSyncSupported.disabled && ( + Recommended + )} +
+

+ When using incremental replication, we'll store the max value of the below field on + each sync and only sync rows with greater or equal value on the next run. +

+

+ You should pick a field that increments or updates each time the row is updated, + such as a updated_at timestamp. +

+ setIncrementalFieldValue(newValue)} + options={ + schema.incremental_fields.map((n) => ({ + value: n.field, + label: ( + <> + {n.label} + + {n.type} + + + ), + })) ?? [] + } + disabledReason={incrementalSyncSupported.disabled ? '' : undefined} + /> +
+ ), + }, + { + value: 'full_refresh', + label: ( +
+
+

Full table replication

+
+

+ We'll replicate the whole table on every sync. This can take longer to sync and + increase your monthly billing. +

+
+ ), + }, + ]} + onChange={(newValue) => setRadioValue(newValue)} + /> + {showRefreshMessage && ( +

+ Note: Changing the sync type or incremental replication field will trigger a full table refresh +

+ )} +
+ + Close + + { + if (radioValue === 'incremental') { + const fieldSelected = schema.incremental_fields.find( + (n) => n.field === incrementalFieldValue + ) + if (!fieldSelected) { + lemonToast.error('Selected field for incremental replication not found') + return + } + + onSave('incremental', incrementalFieldValue, fieldSelected.field_type) + } else { + onSave('full_refresh', null, null) + } + }} + > + Save + +
+ + ) +} diff --git a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx index e1409477f62ac..bbbed3d43a777 100644 --- a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx +++ b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx @@ -351,10 +351,14 @@ export const sourceWizardLogic = kea([ toggleSchemaShouldSync: (schema: ExternalDataSourceSyncSchema, shouldSync: boolean) => ({ schema, shouldSync }), updateSchemaSyncType: ( schema: ExternalDataSourceSyncSchema, - sync_type: ExternalDataSourceSyncSchema['sync_type'] + syncType: ExternalDataSourceSyncSchema['sync_type'], + incrementalField: string | null, + incrementalFieldType: string | null ) => ({ schema, - sync_type, + syncType, + incrementalField, + incrementalFieldType, }), clearSource: true, updateSource: (source: Partial) => ({ source }), @@ -366,6 +370,8 @@ export const sourceWizardLogic = kea([ setStep: (step: number) => ({ step }), getDatabaseSchemas: true, setManualLinkingProvider: (provider: ManualLinkSourceType) => ({ provider }), + openSyncMethodModal: (schema: ExternalDataSourceSyncSchema) => ({ schema }), + cancelSyncMethodModal: true, }), connect({ values: [ @@ -422,10 +428,13 @@ export const sourceWizardLogic = kea([ })) return newSchema }, - updateSchemaSyncType: (state, { schema, sync_type }) => { + updateSchemaSyncType: (state, { schema, syncType, incrementalField, incrementalFieldType }) => { const newSchema = state.map((s) => ({ ...s, - sync_type: s.table === schema.table ? sync_type : s.sync_type, + sync_type: s.table === schema.table ? syncType : s.sync_type, + incremental_field: s.table === schema.table ? incrementalField : s.incremental_field, + incremental_field_type: + s.table === schema.table ? 
incrementalFieldType : s.incremental_field_type, })) return newSchema }, @@ -462,6 +471,26 @@ export const sourceWizardLogic = kea([ setSourceId: (_, { sourceId }) => sourceId, }, ], + syncMethodModalOpen: [ + false as boolean, + { + openSyncMethodModal: () => true, + cancelSyncMethodModal: () => false, + }, + ], + currentSyncMethodModalSchema: [ + null as ExternalDataSourceSyncSchema | null, + { + openSyncMethodModal: (_, { schema }) => schema, + cancelSyncMethodModal: () => null, + updateSchemaSyncType: (_, { schema, syncType, incrementalField, incrementalFieldType }) => ({ + ...schema, + sync_type: syncType, + incremental_field: incrementalField, + incremental_field_type: incrementalFieldType, + }), + }, + ], }), selectors({ isManualLinkingSelected: [(s) => [s.selectedConnector], (selectedConnector): boolean => !selectedConnector], @@ -472,12 +501,20 @@ export const sourceWizardLogic = kea([ }, ], canGoNext: [ - (s) => [s.currentStep, s.isManualLinkingSelected], - (currentStep, isManualLinkingSelected): boolean => { - if (isManualLinkingSelected && currentStep == 1) { + (s) => [s.currentStep, s.isManualLinkingSelected, s.databaseSchema], + (currentStep, isManualLinkingSelected, databaseSchema): boolean => { + if (isManualLinkingSelected && currentStep === 1) { return false } + if (!isManualLinkingSelected && currentStep === 3) { + if (databaseSchema.filter((n) => n.should_sync).length === 0) { + return false + } + + return databaseSchema.filter((n) => n.should_sync && !n.sync_type).length === 0 + } + return true }, ], @@ -638,6 +675,8 @@ export const sourceWizardLogic = kea([ name: schema.table, should_sync: schema.should_sync, sync_type: schema.sync_type, + incremental_field: schema.incremental_field, + incremental_field_type: schema.incremental_field_type, })), }, }) diff --git a/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx b/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx index c4b797426c6dc..f0e8b4f1943f6 100644 --- a/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx +++ b/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx @@ -2,7 +2,9 @@ import { TZLabel } from '@posthog/apps-common' import { LemonButton, LemonDialog, + LemonModal, LemonSelect, + LemonSkeleton, LemonSwitch, LemonTable, LemonTag, @@ -22,6 +24,7 @@ import s3Logo from 'public/s3-logo.png' import snowflakeLogo from 'public/snowflake-logo.svg' import stripeLogo from 'public/stripe-logo.svg' import zendeskLogo from 'public/zendesk-logo.svg' +import { useEffect } from 'react' import { urls } from 'scenes/urls' import { DataTableNode, NodeKind } from '~/queries/schema' @@ -33,7 +36,9 @@ import { ProductKey, } from '~/types' +import { SyncMethodForm } from '../external/forms/SyncMethodForm' import { dataWarehouseSettingsLogic } from './dataWarehouseSettingsLogic' +import { dataWarehouseSourcesTableSyncMethodModalLogic } from './dataWarehouseSourcesTableSyncMethodModalLogic' const StatusTagSetting = { Running: 'primary', @@ -272,154 +277,275 @@ const SchemaTable = ({ schemas }: SchemaTableProps): JSX.Element => { const { schemaReloadingById } = useValues(dataWarehouseSettingsLogic) return ( - {schema.name} - }, - }, - { - title: 'Refresh Type', - key: 'incremental', - render: function RenderIncremental(_, schema) { - return schema.incremental ? 
( - - Incremental - - ) : ( - - Full Refresh - - ) - }, - }, - { - title: 'Enabled', - key: 'should_sync', - render: function RenderShouldSync(_, schema) { - return ( - { - updateSchema({ ...schema, should_sync: active }) - }} - /> - ) + <> + {schema.name} + }, }, - }, - { - title: 'Synced Table', - key: 'table', - render: function RenderTable(_, schema) { - if (schema.table) { - const query: DataTableNode = { - kind: NodeKind.DataTableNode, - full: true, - source: { - kind: NodeKind.HogQLQuery, - // TODO: Use `hogql` tag? - query: `SELECT ${schema.table.columns - .filter( - ({ table, fields, chain, schema_valid }) => - !table && !fields && !chain && schema_valid - ) - .map(({ name }) => name)} FROM ${ - schema.table.name === 'numbers' ? 'numbers(0, 10)' : schema.table.name - } LIMIT 100`, - }, + { + title: 'Sync method', + key: 'incremental', + render: function RenderIncremental(_, schema) { + const { openSyncMethodModal } = useActions( + dataWarehouseSourcesTableSyncMethodModalLogic({ schema }) + ) + + if (!schema.sync_type) { + return ( + <> + openSyncMethodModal(schema)} + > + Set up + + + + ) } + return ( - - {schema.table.name} - + <> + openSyncMethodModal(schema)} + > + {schema.sync_type == 'incremental' ? 'Incremental' : 'Full refresh'} + + + ) - } - return
Not yet synced
+ }, }, - }, - { - title: 'Last Synced At', - key: 'last_synced_at', - render: function Render(_, schema) { - return schema.last_synced_at ? ( - <> - - - ) : null + { + title: 'Enabled', + key: 'should_sync', + render: function RenderShouldSync(_, schema) { + return ( + { + updateSchema({ ...schema, should_sync: active }) + }} + /> + ) + }, }, - }, - { - title: 'Rows Synced', - key: 'rows_synced', - render: function Render(_, schema) { - return schema.table?.row_count ?? '' + { + title: 'Synced Table', + key: 'table', + render: function RenderTable(_, schema) { + if (schema.table) { + const query: DataTableNode = { + kind: NodeKind.DataTableNode, + full: true, + source: { + kind: NodeKind.HogQLQuery, + // TODO: Use `hogql` tag? + query: `SELECT ${schema.table.columns + .filter( + ({ table, fields, chain, schema_valid }) => + !table && !fields && !chain && schema_valid + ) + .map(({ name }) => name)} FROM ${ + schema.table.name === 'numbers' ? 'numbers(0, 10)' : schema.table.name + } LIMIT 100`, + }, + } + return ( + + {schema.table.name} + + ) + } + return
Not yet synced
+ }, }, - }, - { - title: 'Status', - key: 'status', - render: function RenderStatus(_, schema) { - if (!schema.status) { - return null - } - - return {schema.status} + { + title: 'Last Synced At', + key: 'last_synced_at', + render: function Render(_, schema) { + return schema.last_synced_at ? ( + <> + + + ) : null + }, }, - }, - { - key: 'actions', - width: 0, - render: function RenderActions(_, schema) { - if (schemaReloadingById[schema.id]) { + { + title: 'Rows Synced', + key: 'rows_synced', + render: function Render(_, schema) { + return schema.table?.row_count ?? '' + }, + }, + { + title: 'Status', + key: 'status', + render: function RenderStatus(_, schema) { + if (!schema.status) { + return null + } + return ( -
- -
+ {schema.status} ) - } + }, + }, + { + key: 'actions', + width: 0, + render: function RenderActions(_, schema) { + if (schemaReloadingById[schema.id]) { + return ( +
+ +
+ ) + } - return ( -
-
- - { - reloadSchema(schema) - }} - > - Reload - - {schema.incremental && ( - - { - resyncSchema(schema) - }} - status="danger" - > - Resync - - - )} - - } - /> + return ( +
+
+ + { + reloadSchema(schema) + }} + > + Reload + + {schema.incremental && ( + + { + resyncSchema(schema) + }} + status="danger" + > + Resync + + + )} + + } + /> +
-
- ) + ) + }, }, - }, - ]} - /> + ]} + /> + + ) +} + +const SyncMethodModal = ({ schema }: { schema: ExternalDataSourceSchema }): JSX.Element => { + const { + syncMethodModalIsOpen, + currentSyncMethodModalSchema, + schemaIncrementalFields, + schemaIncrementalFieldsLoading, + saveButtonIsLoading, + } = useValues(dataWarehouseSourcesTableSyncMethodModalLogic({ schema })) + const { closeSyncMethodModal, loadSchemaIncrementalFields, resetSchemaIncrementalFields, updateSchema } = + useActions(dataWarehouseSourcesTableSyncMethodModalLogic({ schema })) + + useEffect(() => { + if (currentSyncMethodModalSchema?.id) { + resetSchemaIncrementalFields() + loadSchemaIncrementalFields(currentSyncMethodModalSchema.id) + } + }, [currentSyncMethodModalSchema?.id]) + + const schemaLoading = schemaIncrementalFieldsLoading || !schemaIncrementalFields + const showForm = !schemaLoading && schemaIncrementalFields + + if (!currentSyncMethodModalSchema) { + return <> + } + + return ( + + + + + ) + } + > + {schemaLoading && ( +
+ + +
+ )} + {showForm && ( + { + resetSchemaIncrementalFields() + closeSyncMethodModal() + }} + onSave={(syncType, incrementalField, incrementalFieldType) => { + if (syncType === 'full_refresh') { + updateSchema({ + ...currentSyncMethodModalSchema, + should_sync: true, + sync_type: syncType, + incremental_field: null, + incremental_field_type: null, + }) + } else { + updateSchema({ + ...currentSyncMethodModalSchema, + should_sync: true, + sync_type: syncType, + incremental_field: incrementalField, + incremental_field_type: incrementalFieldType, + }) + } + }} + /> + )} +
) } diff --git a/frontend/src/scenes/data-warehouse/settings/dataWarehouseSettingsLogic.ts b/frontend/src/scenes/data-warehouse/settings/dataWarehouseSettingsLogic.ts index 41423920778dc..312c23bb582bc 100644 --- a/frontend/src/scenes/data-warehouse/settings/dataWarehouseSettingsLogic.ts +++ b/frontend/src/scenes/data-warehouse/settings/dataWarehouseSettingsLogic.ts @@ -38,7 +38,6 @@ export const dataWarehouseSettingsLogic = kea([ resyncSchema: (schema: ExternalDataSourceSchema) => ({ schema }), sourceLoadingFinished: (source: ExternalDataStripeSource) => ({ source }), schemaLoadingFinished: (schema: ExternalDataSourceSchema) => ({ schema }), - updateSchema: (schema: ExternalDataSourceSchema) => ({ schema }), abortAnyRunningQuery: true, setCurrentTab: (tab: DataWarehouseSettingsTab = DataWarehouseSettingsTab.Managed) => ({ tab }), deleteSelfManagedTable: (tableId: string) => ({ tableId }), @@ -74,6 +73,30 @@ export const dataWarehouseSettingsLogic = kea([ }, }, ], + schemas: [ + null, + { + updateSchema: async (schema: ExternalDataSourceSchema) => { + // Optimistic UI updates before sending updates to the backend + const clonedSources = JSON.parse( + JSON.stringify(values.dataWarehouseSources?.results ?? []) + ) as ExternalDataStripeSource[] + const sourceIndex = clonedSources.findIndex((n) => n.schemas.find((m) => m.id === schema.id)) + const schemaIndex = clonedSources[sourceIndex].schemas.findIndex((n) => n.id === schema.id) + clonedSources[sourceIndex].schemas[schemaIndex] = schema + + actions.loadSourcesSuccess({ + ...values.dataWarehouseSources, + results: clonedSources, + }) + + await api.externalDataSchemas.update(schema.id, schema) + actions.loadSources(null) + + return null + }, + }, + ], })), reducers(({ cache }) => ({ dataWarehouseSourcesLoading: [ @@ -261,23 +284,6 @@ export const dataWarehouseSettingsLogic = kea([ } } }, - updateSchema: async ({ schema }) => { - // Optimistic UI updates before sending updates to the backend - const clonedSources = JSON.parse( - JSON.stringify(values.dataWarehouseSources?.results ?? 
[]) - ) as ExternalDataStripeSource[] - const sourceIndex = clonedSources.findIndex((n) => n.schemas.find((m) => m.id === schema.id)) - const schemaIndex = clonedSources[sourceIndex].schemas.findIndex((n) => n.id === schema.id) - clonedSources[sourceIndex].schemas[schemaIndex] = schema - - actions.loadSourcesSuccess({ - ...values.dataWarehouseSources, - results: clonedSources, - }) - - await api.externalDataSchemas.update(schema.id, schema) - actions.loadSources(null) - }, abortAnyRunningQuery: () => { if (cache.abortController) { cache.abortController.abort() diff --git a/frontend/src/scenes/data-warehouse/settings/dataWarehouseSourcesTableSyncMethodModalLogic.ts b/frontend/src/scenes/data-warehouse/settings/dataWarehouseSourcesTableSyncMethodModalLogic.ts new file mode 100644 index 0000000000000..d76f22a3583ad --- /dev/null +++ b/frontend/src/scenes/data-warehouse/settings/dataWarehouseSourcesTableSyncMethodModalLogic.ts @@ -0,0 +1,66 @@ +import { actions, connect, kea, key, listeners, path, props, reducers } from 'kea' +import { loaders } from 'kea-loaders' +import api from 'lib/api' + +import { ExternalDataSourceSchema, SchemaIncrementalFieldsResponse } from '~/types' + +import { dataWarehouseSettingsLogic } from './dataWarehouseSettingsLogic' +import type { dataWarehouseSourcesTableSyncMethodModalLogicType } from './dataWarehouseSourcesTableSyncMethodModalLogicType' + +export interface DataWarehouseSourcesTableSyncMethodModalLogicProps { + schema: ExternalDataSourceSchema +} + +export const dataWarehouseSourcesTableSyncMethodModalLogic = kea([ + path(['scenes', 'data-warehouse', 'settings', 'DataWarehouseSourcesTableSyncMethodModalLogic']), + props({ schema: {} } as DataWarehouseSourcesTableSyncMethodModalLogicProps), + key((props) => props.schema.id), + connect(() => ({ + actions: [dataWarehouseSettingsLogic, ['updateSchema', 'updateSchemaSuccess', 'updateSchemaFailure']], + })), + actions({ + openSyncMethodModal: (schema: ExternalDataSourceSchema) => ({ schema }), + closeSyncMethodModal: true, + }), + loaders({ + schemaIncrementalFields: [ + null as SchemaIncrementalFieldsResponse | null, + { + loadSchemaIncrementalFields: async (schemaId: string) => { + return await api.externalDataSchemas.incremental_fields(schemaId) + }, + resetSchemaIncrementalFields: () => null, + }, + ], + }), + reducers({ + syncMethodModalIsOpen: [ + false as boolean, + { + openSyncMethodModal: () => true, + closeSyncMethodModal: () => false, + }, + ], + currentSyncMethodModalSchema: [ + null as ExternalDataSourceSchema | null, + { + openSyncMethodModal: (_, { schema }) => schema, + closeSyncMethodModal: () => null, + }, + ], + saveButtonIsLoading: [ + false as boolean, + { + updateSchema: () => true, + updateSchemaFailure: () => false, + updateSchemaSuccess: () => false, + }, + ], + }), + listeners(({ actions }) => ({ + updateSchemaSuccess: () => { + actions.resetSchemaIncrementalFields() + actions.closeSyncMethodModal() + }, + })), +]) diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 3e5b2d6500099..4ce45fa13fc91 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -3841,20 +3841,32 @@ export interface SimpleExternalDataSourceSchema { last_synced_at?: Dayjs } +export type SchemaIncrementalFieldsResponse = IncrementalField[] + +export interface IncrementalField { + label: string + type: string + field: string + field_type: string +} + export interface ExternalDataSourceSyncSchema { table: string should_sync: boolean - sync_type: 'full_refresh' | 'incremental' - sync_types: { - 
full_refresh: boolean - incremental: boolean - } + incremental_field: string | null + incremental_field_type: string | null + sync_type: 'full_refresh' | 'incremental' | null + incremental_fields: IncrementalField[] + incremental_available: boolean } export interface ExternalDataSourceSchema extends SimpleExternalDataSourceSchema { table?: SimpleDataWarehouseTable - incremental?: boolean + incremental: boolean + sync_type: 'incremental' | 'full_refresh' | null status?: string + incremental_field: string | null + incremental_field_type: string | null } export interface SimpleDataWarehouseTable { diff --git a/latest_migrations.manifest b/latest_migrations.manifest index 0e208217e0853..4e4b0e6c63d9d 100644 --- a/latest_migrations.manifest +++ b/latest_migrations.manifest @@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name ee: 0016_rolemembership_organization_member otp_static: 0002_throttling otp_totp: 0002_auto_20190420_0723 -posthog: 0430_batchexport_model +posthog: 0431_externaldataschema_sync_type_payload sessions: 0001_initial social_django: 0010_uid_db_index two_factor: 0007_auto_20201201_1019 diff --git a/mypy-baseline.txt b/mypy-baseline.txt index 8d95d21224471..bf1903dfd0c01 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -2,10 +2,13 @@ posthog/temporal/common/utils.py:0: error: Argument 1 to "abstractclassmethod" h posthog/temporal/common/utils.py:0: note: This is likely because "from_activity" has named arguments: "cls". Consider marking them positional-only posthog/temporal/common/utils.py:0: error: Argument 2 to "__get__" of "classmethod" has incompatible type "type[HeartbeatType]"; expected "type[Never]" [arg-type] posthog/warehouse/models/ssh_tunnel.py:0: error: Incompatible types in assignment (expression has type "NoEncryption", variable has type "BestAvailableEncryption") [assignment] +posthog/temporal/data_imports/pipelines/sql_database/helpers.py:0: error: Unused "type: ignore" comment [unused-ignore] posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: error: Dict entry 2 has incompatible type "Literal['auto']": "None"; expected "Literal['json_response', 'header_link', 'auto', 'single_page', 'cursor', 'offset', 'page_number']": "type[BasePaginator]" [dict-item] posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: error: Incompatible types in assignment (expression has type "None", variable has type "AuthConfigBase") [assignment] posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: error: Argument 1 to "get_auth_class" has incompatible type "Literal['bearer', 'api_key', 'http_basic'] | None"; expected "Literal['bearer', 'api_key', 'http_basic']" [arg-type] posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: error: Need type annotation for "dependency_graph" [var-annotated] +posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: error: Value of type variable "TDict" of "update_dict_nested" cannot be "Mapping[str, object]" [type-var] +posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: note: Error code "type-var" not covered by "type: ignore" comment posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: error: Incompatible types in assignment (expression has type "None", target has type "ResolvedParam") [assignment] posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: error: Incompatible return value type (got "tuple[TopologicalSorter[Any], dict[str, EndpointResource], dict[str, ResolvedParam]]", expected "tuple[Any, 
dict[str, EndpointResource], dict[str, ResolvedParam | None]]") [return-value] posthog/temporal/data_imports/pipelines/rest_source/config_setup.py:0: error: Unsupported right operand type for in ("str | Endpoint | None") [operator] @@ -546,7 +549,11 @@ posthog/api/test/test_exports.py:0: error: Incompatible types in assignment (exp posthog/api/test/test_exports.py:0: error: Incompatible types in assignment (expression has type "None", variable has type "Insight") [assignment] posthog/api/notebook.py:0: error: Incompatible types in assignment (expression has type "int", variable has type "str | None") [assignment] posthog/warehouse/data_load/validate_schema.py:0: error: Incompatible types in assignment (expression has type "dict[str, dict[str, str | bool]] | dict[str, str]", variable has type "dict[str, dict[str, str]]") [assignment] +posthog/warehouse/api/external_data_source.py:0: error: Incompatible return value type (got "tuple[ExternalDataSource, dict[str, list[tuple[str, str]]]]", expected "tuple[ExternalDataSource, list[Any]]") [return-value] +posthog/warehouse/api/external_data_source.py:0: error: Incompatible return value type (got "tuple[ExternalDataSource, dict[str, list[tuple[str, str]]]]", expected "tuple[ExternalDataSource, list[Any]]") [return-value] posthog/warehouse/api/table.py:0: error: Unsupported target for indexed assignment ("dict[str, str | bool] | str") [index] +posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: error: Incompatible types in assignment (expression has type "list[str]", variable has type "dict[str, list[tuple[str, str]]]") [assignment] +posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: error: Argument 1 has incompatible type "dict[str, list[tuple[str, str]]]"; expected "list[Any]" [arg-type] posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Not all union combinations were tried because there are too many unions [misc] posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Argument 2 to "source" has incompatible type "str | None"; expected "str" [arg-type] posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Argument 3 to "source" has incompatible type "str | None"; expected "str" [arg-type] diff --git a/posthog/migrations/0431_externaldataschema_sync_type_payload.py b/posthog/migrations/0431_externaldataschema_sync_type_payload.py new file mode 100644 index 0000000000000..4ae377a4a5836 --- /dev/null +++ b/posthog/migrations/0431_externaldataschema_sync_type_payload.py @@ -0,0 +1,45 @@ +# Generated by Django 4.2.11 on 2024-06-19 15:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("posthog", "0430_batchexport_model"), + ] + + operations = [ + migrations.AddField( + model_name="externaldataschema", + name="sync_type_config", + field=models.JSONField(blank=True, default=dict), + ), + migrations.AlterField( + model_name="externaldataschema", + name="sync_type", + field=models.CharField( + blank=True, + choices=[("full_refresh", "full_refresh"), ("incremental", "incremental")], + max_length=128, + null=True, + ), + ), + migrations.RunSQL( # Update Stripe schemas + sql=""" + UPDATE posthog_externaldataschema AS schema + SET sync_type_config = '{"incremental_field": "created"}' + FROM posthog_externaldatasource AS source + WHERE schema.source_id = source.id AND source.source_type = 'Stripe' AND schema.sync_type = 'incremental' + """, + reverse_sql=migrations.RunSQL.noop, + ), + 
migrations.RunSQL( # Update Zendesk schemas + sql=""" + UPDATE posthog_externaldataschema AS schema + SET sync_type_config = '{"incremental_field": "generated_timestamp"}' + FROM posthog_externaldatasource AS source + WHERE schema.source_id = source.id AND source.source_type = 'Zendesk' AND schema.sync_type = 'incremental' + """, + reverse_sql=migrations.RunSQL.noop, + ), + ] diff --git a/posthog/temporal/data_imports/pipelines/helpers.py b/posthog/temporal/data_imports/pipelines/helpers.py index 9e38be0fd919e..776b7f8dd0582 100644 --- a/posthog/temporal/data_imports/pipelines/helpers.py +++ b/posthog/temporal/data_imports/pipelines/helpers.py @@ -1,5 +1,7 @@ +import uuid from posthog.warehouse.models import ExternalDataJob from django.db.models import F +from posthog.warehouse.models.external_data_source import ExternalDataSource from posthog.warehouse.util import database_sync_to_async @@ -20,3 +22,10 @@ def aget_external_data_job(team_id, job_id): @database_sync_to_async def aupdate_job_count(job_id: str, team_id: int, count: int): ExternalDataJob.objects.filter(id=job_id, team_id=team_id).update(rows_synced=F("rows_synced") + count) + + +@database_sync_to_async +def aremove_reset_pipeline(source_id: uuid.UUID): + source = ExternalDataSource.objects.get(id=source_id) + source.job_inputs.pop("reset_pipeline", None) + source.save() diff --git a/posthog/temporal/data_imports/pipelines/pipeline.py b/posthog/temporal/data_imports/pipelines/pipeline.py index 18da8f563cd04..29983da8c68ef 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline.py +++ b/posthog/temporal/data_imports/pipelines/pipeline.py @@ -14,6 +14,7 @@ from collections import Counter from posthog.warehouse.data_load.validate_schema import validate_schema_and_update_table +from posthog.warehouse.models.external_data_source import ExternalDataSource @dataclass @@ -22,7 +23,7 @@ class PipelineInputs: run_id: str schema_id: UUID dataset_name: str - job_type: str + job_type: ExternalDataSource.Type team_id: int @@ -30,7 +31,12 @@ class DataImportPipeline: loader_file_format: Literal["parquet"] = "parquet" def __init__( - self, inputs: PipelineInputs, source: DltSource, logger: FilteringBoundLogger, incremental: bool = False + self, + inputs: PipelineInputs, + source: DltSource, + logger: FilteringBoundLogger, + reset_pipeline: bool, + incremental: bool = False, ): self.inputs = inputs self.logger = logger @@ -41,6 +47,12 @@ def __init__( self.source = source self._incremental = incremental + self.refresh_dlt = reset_pipeline + self.should_chunk_pipeline = ( + incremental + and inputs.job_type != ExternalDataSource.Type.POSTGRES + and inputs.job_type != ExternalDataSource.Type.SNOWFLAKE + ) def _get_pipeline_name(self): return f"{self.inputs.job_type}_pipeline_{self.inputs.team_id}_run_{self.inputs.schema_id}" @@ -75,16 +87,27 @@ def _create_pipeline(self): ) def _run(self) -> dict[str, int]: + if self.refresh_dlt: + self.logger.info("Pipeline getting a full refresh due to reset_pipeline being set") + pipeline = self._create_pipeline() total_counts: Counter[str] = Counter({}) - if self._incremental: + # Do chunking for incremental syncing on API based endpoints (e.g. 
not sql databases) + if self.should_chunk_pipeline: # will get overwritten counts: Counter[str] = Counter({"start": 1}) + pipeline_runs = 0 while counts: - pipeline.run(self.source, loader_file_format=self.loader_file_format) + self.logger.info(f"Running incremental (non-sql) pipeline, run ${pipeline_runs}") + + pipeline.run( + self.source, + loader_file_format=self.loader_file_format, + refresh="drop_sources" if self.refresh_dlt and pipeline_runs == 0 else None, + ) row_counts = pipeline.last_trace.last_normalize_info.row_counts # Remove any DLT tables from the counts @@ -99,8 +122,16 @@ def _run(self) -> dict[str, int]: table_schema=self.source.schema.tables, row_count=total_counts.total(), ) + + pipeline_runs = pipeline_runs + 1 else: - pipeline.run(self.source, loader_file_format=self.loader_file_format) + self.logger.info("Running standard pipeline") + + pipeline.run( + self.source, + loader_file_format=self.loader_file_format, + refresh="drop_sources" if self.refresh_dlt else None, + ) row_counts = pipeline.last_trace.last_normalize_info.row_counts filtered_rows = dict(filter(lambda pair: not pair[0].startswith("_dlt"), row_counts.items())) counts = Counter(filtered_rows) diff --git a/posthog/temporal/data_imports/pipelines/schemas.py b/posthog/temporal/data_imports/pipelines/schemas.py index 8089d1204e8b1..7dccb65eca59b 100644 --- a/posthog/temporal/data_imports/pipelines/schemas.py +++ b/posthog/temporal/data_imports/pipelines/schemas.py @@ -1,12 +1,15 @@ +from posthog.warehouse.types import IncrementalField from posthog.temporal.data_imports.pipelines.zendesk.settings import ( BASE_ENDPOINTS, SUPPORT_ENDPOINTS, INCREMENTAL_ENDPOINTS as ZENDESK_INCREMENTAL_ENDPOINTS, + INCREMENTAL_FIELDS as ZENDESK_INCREMENTAL_FIELDS, ) from posthog.warehouse.models import ExternalDataSource from posthog.temporal.data_imports.pipelines.stripe.settings import ( ENDPOINTS as STRIPE_ENDPOINTS, INCREMENTAL_ENDPOINTS as STRIPE_INCREMENTAL_ENDPOINTS, + INCREMENTAL_FIELDS as STRIPE_INCREMENTAL_FIELDS, ) from posthog.temporal.data_imports.pipelines.hubspot.settings import ENDPOINTS as HUBSPOT_ENDPOINTS @@ -27,3 +30,11 @@ ExternalDataSource.Type.POSTGRES: (), ExternalDataSource.Type.SNOWFLAKE: (), } + +PIPELINE_TYPE_INCREMENTAL_FIELDS_MAPPING: dict[ExternalDataSource.Type, dict[str, list[IncrementalField]]] = { + ExternalDataSource.Type.STRIPE: STRIPE_INCREMENTAL_FIELDS, + ExternalDataSource.Type.HUBSPOT: {}, + ExternalDataSource.Type.ZENDESK: ZENDESK_INCREMENTAL_FIELDS, + ExternalDataSource.Type.POSTGRES: {}, + ExternalDataSource.Type.SNOWFLAKE: {}, +} diff --git a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py index b376d4d2da8db..023c4203a62fc 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py +++ b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py @@ -1,7 +1,9 @@ """Source that loads tables form any SQLAlchemy supported database, supports batching requests and incremental loads.""" -from typing import Optional, Union, List # noqa: UP035 +from datetime import datetime, date +from typing import Any, Optional, Union, List # noqa: UP035 from collections.abc import Iterable +from zoneinfo import ZoneInfo from sqlalchemy import MetaData, Table from sqlalchemy.engine import Engine @@ -12,6 +14,8 @@ from dlt.sources.credentials import ConnectionStringCredentials from urllib.parse import quote +from posthog.warehouse.types import IncrementalFieldType + from .helpers import ( table_rows, 
engine_from_credentials, @@ -20,8 +24,26 @@ ) +def incremental_type_to_initial_value(field_type: IncrementalFieldType) -> Any: + if field_type == IncrementalFieldType.Integer or field_type == IncrementalFieldType.Numeric: + return 0 + if field_type == IncrementalFieldType.DateTime or field_type == IncrementalFieldType.Timestamp: + return datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=ZoneInfo("UTC")) + if field_type == IncrementalFieldType.Date: + return date(1970, 1, 1) + + def postgres_source( - host: str, port: int, user: str, password: str, database: str, sslmode: str, schema: str, table_names: list[str] + host: str, + port: int, + user: str, + password: str, + database: str, + sslmode: str, + schema: str, + table_names: list[str], + incremental_field: Optional[str] = None, + incremental_field_type: Optional[IncrementalFieldType] = None, ) -> DltSource: host = quote(host) user = quote(user) @@ -32,7 +54,15 @@ def postgres_source( credentials = ConnectionStringCredentials( f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={sslmode}" ) - db_source = sql_database(credentials, schema=schema, table_names=table_names) + + if incremental_field is not None and incremental_field_type is not None: + incremental: dlt.sources.incremental | None = dlt.sources.incremental( + cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type) + ) + else: + incremental = None + + db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental) return db_source @@ -46,6 +76,8 @@ def snowflake_source( schema: str, table_names: list[str], role: Optional[str] = None, + incremental_field: Optional[str] = None, + incremental_field_type: Optional[str] = None, ) -> DltSource: account_id = quote(account_id) user = quote(user) @@ -68,6 +100,7 @@ def sql_database( schema: Optional[str] = dlt.config.value, metadata: Optional[MetaData] = None, table_names: Optional[List[str]] = dlt.config.value, # noqa: UP006 + incremental: Optional[dlt.sources.incremental] = None, ) -> Iterable[DltResource]: """ A DLT source which loads data from an SQL database using SQLAlchemy. 
@@ -96,9 +129,17 @@ def sql_database( tables = list(metadata.tables.values()) for table in tables: + # TODO(@Gilbert09): Read column types, convert them to DLT types + # and pass them in here to get empty table materialization yield dlt.resource( table_rows, name=table.name, primary_key=get_primary_key(table), + merge_key=get_primary_key(table), + write_disposition="merge" if incremental else "replace", spec=SqlDatabaseTableConfiguration, - )(engine, table) + )( + engine=engine, + table=table, + incremental=incremental, + ) diff --git a/posthog/temporal/data_imports/pipelines/sql_database/helpers.py b/posthog/temporal/data_imports/pipelines/sql_database/helpers.py index ba93b4246abad..653f93392466c 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database/helpers.py +++ b/posthog/temporal/data_imports/pipelines/sql_database/helpers.py @@ -51,10 +51,10 @@ def make_query(self) -> Select[Any]: last_value_func = self.incremental.last_value_func if last_value_func is max: # Query ordered and filtered according to last_value function order_by = self.cursor_column.asc() # type: ignore - filter_op = operator.ge + filter_op = operator.gt elif last_value_func is min: order_by = self.cursor_column.desc() # type: ignore - filter_op = operator.le + filter_op = operator.lt else: # Custom last_value, load everything and let incremental handle filtering return query query = query.order_by(order_by) @@ -89,6 +89,8 @@ def table_rows( Returns: Iterable[DltResource]: A list of DLT resources for each table to be loaded. """ + yield dlt.mark.materialize_table_schema() # type: ignore + loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size) yield from loader.load_rows() @@ -104,7 +106,15 @@ def engine_from_credentials(credentials: Union[ConnectionStringCredentials, Engi def get_primary_key(table: Table) -> list[str]: - return [c.name for c in table.primary_key] + primary_keys = [c.name for c in table.primary_key] + if len(primary_keys) > 0: + return primary_keys + + column_names = [c.name for c in table.columns] + if "id" in column_names: + return ["id"] + + return [] @configspec diff --git a/posthog/temporal/data_imports/pipelines/stripe/settings.py b/posthog/temporal/data_imports/pipelines/stripe/settings.py index c280391c24c96..3eefb464967f0 100644 --- a/posthog/temporal/data_imports/pipelines/stripe/settings.py +++ b/posthog/temporal/data_imports/pipelines/stripe/settings.py @@ -3,6 +3,69 @@ # the most popular endpoints # Full list of the Stripe API endpoints you can find here: https://stripe.com/docs/api. # These endpoints are converted into ExternalDataSchema objects when a source is linked. 
+ +from posthog.warehouse.types import IncrementalField, IncrementalFieldType + + ENDPOINTS = ("BalanceTransaction", "Subscription", "Customer", "Product", "Price", "Invoice", "Charge") INCREMENTAL_ENDPOINTS = ("BalanceTransaction", "Subscription", "Customer", "Product", "Price", "Invoice", "Charge") + +INCREMENTAL_FIELDS: dict[str, list[IncrementalField]] = { + "BalanceTransaction": [ + { + "label": "created_at", + "type": IncrementalFieldType.DateTime, + "field": "created", + "field_type": IncrementalFieldType.Integer, + } + ], + "Subscription": [ + { + "label": "created_at", + "type": IncrementalFieldType.DateTime, + "field": "created", + "field_type": IncrementalFieldType.Integer, + } + ], + "Customer": [ + { + "label": "created_at", + "type": IncrementalFieldType.DateTime, + "field": "created", + "field_type": IncrementalFieldType.Integer, + } + ], + "Product": [ + { + "label": "created_at", + "type": IncrementalFieldType.DateTime, + "field": "created", + "field_type": IncrementalFieldType.Integer, + } + ], + "Price": [ + { + "label": "created_at", + "type": IncrementalFieldType.DateTime, + "field": "created", + "field_type": IncrementalFieldType.Integer, + } + ], + "Invoice": [ + { + "label": "created_at", + "type": IncrementalFieldType.DateTime, + "field": "created", + "field_type": IncrementalFieldType.Integer, + } + ], + "Charge": [ + { + "label": "created_at", + "type": IncrementalFieldType.DateTime, + "field": "created", + "field_type": IncrementalFieldType.Integer, + } + ], +} diff --git a/posthog/temporal/data_imports/pipelines/test/test_pipeline.py b/posthog/temporal/data_imports/pipelines/test/test_pipeline.py index 435bcda33a9c5..de4f2df7e8694 100644 --- a/posthog/temporal/data_imports/pipelines/test/test_pipeline.py +++ b/posthog/temporal/data_imports/pipelines/test/test_pipeline.py @@ -45,8 +45,8 @@ async def _create_pipeline(self, schema_name: str, incremental: bool): source_id=source.pk, run_id=str(job.pk), schema_id=schema.pk, - dataset_name=job.folder_path, - job_type="Stripe", + dataset_name=job.folder_path(), + job_type=ExternalDataSource.Type.STRIPE, team_id=self.team.pk, ), source=stripe_source( @@ -59,6 +59,7 @@ async def _create_pipeline(self, schema_name: str, incremental: bool): ), logger=structlog.get_logger(), incremental=incremental, + reset_pipeline=False, ) return pipeline diff --git a/posthog/temporal/data_imports/pipelines/zendesk/settings.py b/posthog/temporal/data_imports/pipelines/zendesk/settings.py index ddd75aaafaf41..935bc714ac7cd 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/settings.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/settings.py @@ -1,6 +1,7 @@ """Zendesk source settings and constants""" from dlt.common import pendulum +from posthog.warehouse.types import IncrementalField, IncrementalFieldType DEFAULT_START_DATE = pendulum.datetime(year=2000, month=1, day=1) PAGE_SIZE = 100 @@ -12,6 +13,16 @@ # Resources that will always get pulled BASE_ENDPOINTS = ["ticket_fields", "ticket_events", "tickets", "ticket_metric_events"] INCREMENTAL_ENDPOINTS = ["tickets"] +INCREMENTAL_FIELDS: dict[str, list[IncrementalField]] = { + "tickets": [ + { + "label": "generated_timestamp", + "type": IncrementalFieldType.Integer, + "field": "generated_timestamp", + "field_type": IncrementalFieldType.Integer, + } + ], +} # Tuples of (Resource name, endpoint URL, data_key, supports pagination) # data_key is the key which data list is nested under in responses diff --git a/posthog/temporal/data_imports/workflow_activities/import_data.py 
b/posthog/temporal/data_imports/workflow_activities/import_data.py index dcd1840622589..f293a329d68c1 100644 --- a/posthog/temporal/data_imports/workflow_activities/import_data.py +++ b/posthog/temporal/data_imports/workflow_activities/import_data.py @@ -5,8 +5,7 @@ from dlt.common.schema.typing import TSchemaTables from temporalio import activity -# TODO: remove dependency -from posthog.temporal.data_imports.pipelines.helpers import aupdate_job_count +from posthog.temporal.data_imports.pipelines.helpers import aremove_reset_pipeline, aupdate_job_count from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline, PipelineInputs from posthog.warehouse.models import ( @@ -43,9 +42,11 @@ async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchem run_id=inputs.run_id, team_id=inputs.team_id, job_type=model.pipeline.source_type, - dataset_name=model.folder_path, + dataset_name=model.folder_path(), ) + reset_pipeline = model.pipeline.job_inputs.get("reset_pipeline", "False") == "True" + schema: ExternalDataSchema = await aget_schema_by_id(inputs.schema_id, inputs.team_id) endpoints = [schema.name] @@ -70,7 +71,14 @@ async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchem is_incremental=schema.is_incremental, ) - return await _run(job_inputs=job_inputs, source=source, logger=logger, inputs=inputs, schema=schema) + return await _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) elif model.pipeline.source_type == ExternalDataSource.Type.HUBSPOT: from posthog.temporal.data_imports.pipelines.hubspot.auth import refresh_access_token from posthog.temporal.data_imports.pipelines.hubspot import hubspot @@ -89,7 +97,14 @@ async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchem endpoints=tuple(endpoints), ) - return await _run(job_inputs=job_inputs, source=source, logger=logger, inputs=inputs, schema=schema) + return await _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) elif model.pipeline.source_type == ExternalDataSource.Type.POSTGRES: from posthog.temporal.data_imports.pipelines.sql_database import postgres_source @@ -134,9 +149,22 @@ async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchem sslmode="prefer", schema=pg_schema, table_names=endpoints, + incremental_field=schema.sync_type_config.get("incremental_field") + if schema.is_incremental + else None, + incremental_field_type=schema.sync_type_config.get("incremental_field_type") + if schema.is_incremental + else None, ) - return await _run(job_inputs=job_inputs, source=source, logger=logger, inputs=inputs, schema=schema) + return await _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) source = postgres_source( host=host, @@ -147,9 +175,20 @@ async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchem sslmode="prefer", schema=pg_schema, table_names=endpoints, + incremental_field=schema.sync_type_config.get("incremental_field") if schema.is_incremental else None, + incremental_field_type=schema.sync_type_config.get("incremental_field_type") + if schema.is_incremental + else None, ) - return await _run(job_inputs=job_inputs, source=source, logger=logger, inputs=inputs, schema=schema) + return await _run( + job_inputs=job_inputs, + source=source, + logger=logger, + 
inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) elif model.pipeline.source_type == ExternalDataSource.Type.SNOWFLAKE: from posthog.temporal.data_imports.pipelines.sql_database import snowflake_source @@ -172,7 +211,14 @@ async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchem table_names=endpoints, ) - return await _run(job_inputs=job_inputs, source=source, logger=logger, inputs=inputs, schema=schema) + return await _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) elif model.pipeline.source_type == ExternalDataSource.Type.ZENDESK: from posthog.temporal.data_imports.pipelines.zendesk import zendesk_source @@ -187,7 +233,14 @@ async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchem is_incremental=schema.is_incremental, ) - return await _run(job_inputs=job_inputs, source=source, logger=logger, inputs=inputs, schema=schema) + return await _run( + job_inputs=job_inputs, + source=source, + logger=logger, + inputs=inputs, + schema=schema, + reset_pipeline=reset_pipeline, + ) else: raise ValueError(f"Source type {model.pipeline.source_type} not supported") @@ -198,6 +251,7 @@ async def _run( logger: FilteringBoundLogger, inputs: ImportDataActivityInputs, schema: ExternalDataSchema, + reset_pipeline: bool, ) -> tuple[TSchemaTables, dict[str, int]]: # Temp background heartbeat for now async def heartbeat() -> None: @@ -208,10 +262,13 @@ async def heartbeat() -> None: heartbeat_task = asyncio.create_task(heartbeat()) try: - table_row_counts = await DataImportPipeline(job_inputs, source, logger, schema.is_incremental).run() + table_row_counts = await DataImportPipeline( + job_inputs, source, logger, reset_pipeline, schema.is_incremental + ).run() total_rows_synced = sum(table_row_counts.values()) await aupdate_job_count(inputs.run_id, inputs.team_id, total_rows_synced) + await aremove_reset_pipeline(inputs.source_id) finally: heartbeat_task.cancel() await asyncio.wait([heartbeat_task]) diff --git a/posthog/temporal/tests/batch_exports/test_import_data.py b/posthog/temporal/tests/batch_exports/test_import_data.py index 39dba5199190b..2b743102056d0 100644 --- a/posthog/temporal/tests/batch_exports/test_import_data.py +++ b/posthog/temporal/tests/batch_exports/test_import_data.py @@ -84,6 +84,8 @@ async def test_postgres_source_without_ssh_tunnel(activity_environment, team, ** sslmode="prefer", schema="schema", table_names=["table_1"], + incremental_field=None, + incremental_field_type=None, ) @@ -119,6 +121,8 @@ async def test_postgres_source_with_ssh_tunnel_disabled(activity_environment, te sslmode="prefer", schema="schema", table_names=["table_1"], + incremental_field=None, + incremental_field_type=None, ) @@ -171,4 +175,6 @@ def __exit__(self, exc_type, exc_value, exc_traceback): sslmode="prefer", schema="schema", table_names=["table_1"], + incremental_field=None, + incremental_field_type=None, ) diff --git a/posthog/temporal/tests/data_imports/test_end_to_end.py b/posthog/temporal/tests/data_imports/test_end_to_end.py index c0fce812b2196..b02ef3997be3c 100644 --- a/posthog/temporal/tests/data_imports/test_end_to_end.py +++ b/posthog/temporal/tests/data_imports/test_end_to_end.py @@ -104,6 +104,9 @@ def mock_paginate( continue assert name in (res.columns or []) + await sync_to_async(source.refresh_from_db)() + assert source.job_inputs.get("reset_pipeline", None) is None + @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio @@ -347,3 
+350,16 @@ async def test_zendesk_ticket_metric_events(team, zendesk_ticket_metric_events): }, mock_data_response=zendesk_ticket_metric_events["ticket_metric_events"], ) + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.asyncio +async def test_reset_pipeline(team, stripe_balance_transaction): + await _run( + team=team, + schema_name="BalanceTransaction", + table_name="stripe_balancetransaction", + source_type="Stripe", + job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id", "reset_pipeline": "True"}, + mock_data_response=stripe_balance_transaction["data"], + ) diff --git a/posthog/temporal/tests/external_data/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py index bbaccf7308f9b..469ca2fac1845 100644 --- a/posthog/temporal/tests/external_data/test_external_data_job.py +++ b/posthog/temporal/tests/external_data/test_external_data_job.py @@ -386,8 +386,9 @@ def mock_charges_paginate( activity_environment.run(import_data_activity, job_1_inputs), ) + folder_path = await sync_to_async(job_1.folder_path)() job_1_customer_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/customer/" + Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/" ) assert len(job_1_customer_objects["Contents"]) == 1 @@ -409,7 +410,7 @@ def mock_charges_paginate( ) job_2_charge_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{job_2.folder_path}/charge/" + Bucket=BUCKET_NAME, Prefix=f"{job_2.folder_path()}/charge/" ) assert len(job_2_charge_objects["Contents"]) == 1 @@ -491,8 +492,9 @@ def mock_customers_paginate( activity_environment.run(import_data_activity, job_1_inputs), ) + folder_path = await sync_to_async(job_1.folder_path)() job_1_customer_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/customer/" + Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/" ) # if job was not canceled, this job would run indefinitely @@ -577,8 +579,9 @@ def mock_customers_paginate( activity_environment.run(import_data_activity, job_1_inputs), ) + folder_path = await sync_to_async(job_1.folder_path)() job_1_customer_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/customer/" + Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/" ) assert len(job_1_customer_objects["Contents"]) == 1 @@ -716,8 +719,9 @@ async def setup_job_1(): activity_environment.run(import_data_activity, job_1_inputs), ) + folder_path = await sync_to_async(job_1.folder_path)() job_1_team_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/posthog_test/" + Bucket=BUCKET_NAME, Prefix=f"{folder_path}/posthog_test/" ) assert len(job_1_team_objects["Contents"]) == 1 diff --git a/posthog/warehouse/api/external_data_schema.py b/posthog/warehouse/api/external_data_schema.py index 2375182aa8b18..35669885f6683 100644 --- a/posthog/warehouse/api/external_data_schema.py +++ b/posthog/warehouse/api/external_data_schema.py @@ -1,11 +1,13 @@ from rest_framework import serializers import structlog import temporalio +from posthog.temporal.data_imports.pipelines.schemas import PIPELINE_TYPE_INCREMENTAL_FIELDS_MAPPING from posthog.warehouse.models import ExternalDataSchema, ExternalDataJob from typing import Optional, Any from posthog.api.routing import TeamAndOrgViewSetMixin from rest_framework import viewsets, filters, status from rest_framework.decorators import action +from rest_framework.exceptions 
import ValidationError from rest_framework.request import Request from rest_framework.response import Response from posthog.hogql.database.database import create_hogql_database @@ -20,6 +22,15 @@ cancel_external_data_workflow, delete_data_import_folder, ) +from posthog.warehouse.models.external_data_schema import ( + filter_postgres_incremental_fields, + filter_snowflake_incremental_fields, + get_postgres_schemas, + get_snowflake_schemas, +) +from posthog.warehouse.models.external_data_source import ExternalDataSource +from posthog.warehouse.models.ssh_tunnel import SSHTunnel +from posthog.warehouse.types import IncrementalField logger = structlog.get_logger(__name__) @@ -28,6 +39,8 @@ class ExternalDataSchemaSerializer(serializers.ModelSerializer): table = serializers.SerializerMethodField(read_only=True) incremental = serializers.SerializerMethodField(read_only=True) sync_type = serializers.SerializerMethodField(read_only=True) + incremental_field = serializers.SerializerMethodField(read_only=True) + incremental_field_type = serializers.SerializerMethodField(read_only=True) class Meta: model = ExternalDataSchema @@ -42,13 +55,30 @@ class Meta: "incremental", "status", "sync_type", + "incremental_field", + "incremental_field_type", + ] + + read_only_fields = [ + "id", + "name", + "table", + "last_synced_at", + "latest_error", + "status", ] def get_incremental(self, schema: ExternalDataSchema) -> bool: return schema.is_incremental - def get_sync_type(self, schema: ExternalDataSchema) -> ExternalDataSchema.SyncType: - return schema.sync_type or ExternalDataSchema.SyncType.FULL_REFRESH + def get_incremental_field(self, schema: ExternalDataSchema) -> str | None: + return schema.sync_type_config.get("incremental_field") + + def get_incremental_field_type(self, schema: ExternalDataSchema) -> str | None: + return schema.sync_type_config.get("incremental_field_type") + + def get_sync_type(self, schema: ExternalDataSchema) -> ExternalDataSchema.SyncType | None: + return schema.sync_type def get_table(self, schema: ExternalDataSchema) -> Optional[dict]: from posthog.warehouse.api.table import SimpleTableSerializer @@ -60,7 +90,51 @@ def get_table(self, schema: ExternalDataSchema) -> Optional[dict]: return SimpleTableSerializer(schema.table, context={"database": hogql_context}).data or None def update(self, instance: ExternalDataSchema, validated_data: dict[str, Any]) -> ExternalDataSchema: + data = self.context["request"].data + + sync_type = data.get("sync_type") + + if ( + sync_type is not None + and sync_type != ExternalDataSchema.SyncType.FULL_REFRESH + and sync_type != ExternalDataSchema.SyncType.INCREMENTAL + ): + raise ValidationError("Invalid sync type") + + validated_data["sync_type"] = sync_type + + # Check whether we need a full table refresh + trigger_refresh = False + if instance.sync_type is not None and sync_type is not None: + # If sync type changes + if instance.sync_type != sync_type: + trigger_refresh = True + + # If sync type is incremental and the incremental field changes + if sync_type == ExternalDataSchema.SyncType.INCREMENTAL and instance.sync_type_config.get( + "incremental_field" + ) != data.get("incremental_field"): + trigger_refresh = True + + # Update the validated_data with incremental fields + if sync_type == "incremental": + payload = instance.sync_type_config + payload["incremental_field"] = data.get("incremental_field") + payload["incremental_field_type"] = data.get("incremental_field_type") + + validated_data["sync_type_config"] = payload + else: + payload = 
instance.sync_type_config + payload.pop("incremental_field", None) + payload.pop("incremental_field_type", None) + + validated_data["sync_type_config"] = payload + should_sync = validated_data.get("should_sync", None) + + if should_sync is True and sync_type is None and instance.sync_type is None: + raise ValidationError("Sync type must be set up first before enabling schema") + schedule_exists = external_data_workflow_exists(str(instance.id)) if schedule_exists: @@ -72,6 +146,12 @@ def update(self, instance: ExternalDataSchema, validated_data: dict[str, Any]) - if should_sync is True: sync_external_data_job_workflow(instance, create=True) + if trigger_refresh: + source: ExternalDataSource = instance.source + source.job_inputs.update({"reset_pipeline": True}) + source.save() + trigger_external_data_workflow(instance) + return super().update(instance, validated_data) @@ -146,9 +226,9 @@ def resync(self, request: Request, *args: Any, **kwargs: Any): # Unnecessary to iterate for incremental jobs since they'll all by identified by the schema_id. Be over eager just to clear remnants for job in all_jobs: try: - delete_data_import_folder(job.folder_path) + delete_data_import_folder(job.folder_path()) except Exception as e: - logger.exception(f"Could not clean up data import folder: {job.folder_path}", exc_info=e) + logger.exception(f"Could not clean up data import folder: {job.folder_path()}", exc_info=e) pass try: @@ -159,3 +239,96 @@ def resync(self, request: Request, *args: Any, **kwargs: Any): instance.status = ExternalDataSchema.Status.RUNNING instance.save() return Response(status=status.HTTP_200_OK) + + @action(methods=["POST"], detail=True) + def incremental_fields(self, request: Request, *args: Any, **kwargs: Any): + instance: ExternalDataSchema = self.get_object() + source: ExternalDataSource = instance.source + incremental_columns: list[IncrementalField] = [] + + if source.source_type == ExternalDataSource.Type.POSTGRES: + # TODO(@Gilbert09): Move all this into a util and replace elsewhere + host = source.job_inputs.get("host") + port = source.job_inputs.get("port") + user = source.job_inputs.get("user") + password = source.job_inputs.get("password") + database = source.job_inputs.get("database") + pg_schema = source.job_inputs.get("schema") + + using_ssh_tunnel = str(source.job_inputs.get("ssh_tunnel_enabled", False)) == "True" + ssh_tunnel_host = source.job_inputs.get("ssh_tunnel_host") + ssh_tunnel_port = source.job_inputs.get("ssh_tunnel_port") + ssh_tunnel_auth_type = source.job_inputs.get("ssh_tunnel_auth_type") + ssh_tunnel_auth_type_username = source.job_inputs.get("ssh_tunnel_auth_type_username") + ssh_tunnel_auth_type_password = source.job_inputs.get("ssh_tunnel_auth_type_password") + ssh_tunnel_auth_type_passphrase = source.job_inputs.get("ssh_tunnel_auth_type_passphrase") + ssh_tunnel_auth_type_private_key = source.job_inputs.get("ssh_tunnel_auth_type_private_key") + + ssh_tunnel = SSHTunnel( + enabled=using_ssh_tunnel, + host=ssh_tunnel_host, + port=ssh_tunnel_port, + auth_type=ssh_tunnel_auth_type, + username=ssh_tunnel_auth_type_username, + password=ssh_tunnel_auth_type_password, + passphrase=ssh_tunnel_auth_type_passphrase, + private_key=ssh_tunnel_auth_type_private_key, + ) + + pg_schemas = get_postgres_schemas( + host=host, + port=port, + database=database, + user=user, + password=password, + schema=pg_schema, + ssh_tunnel=ssh_tunnel, + ) + + columns = pg_schemas.get(instance.name, []) + incremental_columns = [ + {"field": name, "field_type": field_type, "label": 
name, "type": field_type} + for name, field_type in filter_postgres_incremental_fields(columns) + ] + elif source.source_type == ExternalDataSource.Type.SNOWFLAKE: + # TODO(@Gilbert09): Move all this into a util and replace elsewhere + account_id = source.job_inputs.get("account_id") + user = source.job_inputs.get("user") + password = source.job_inputs.get("password") + database = source.job_inputs.get("database") + warehouse = source.job_inputs.get("warehouse") + sf_schema = source.job_inputs.get("schema") + role = source.job_inputs.get("role") + + sf_schemas = get_snowflake_schemas( + account_id=account_id, + database=database, + warehouse=warehouse, + user=user, + password=password, + schema=sf_schema, + role=role, + ) + + columns = sf_schemas.get(instance.name, []) + incremental_columns = [ + {"field": name, "field_type": field_type, "label": name, "type": field_type} + for name, field_type in filter_snowflake_incremental_fields(columns) + ] + else: + mapping = PIPELINE_TYPE_INCREMENTAL_FIELDS_MAPPING.get(source.source_type) + if not mapping: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": f'Source type "{source.source_type}" not found'}, + ) + mapping_fields = mapping.get(instance.name) + if not mapping_fields: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": f'Incremental fields for "{source.source_type}.{instance.name}" can\'t be found'}, + ) + + incremental_columns = mapping_fields + + return Response(status=status.HTTP_200_OK, data=incremental_columns) diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py index 36a705dfc1849..ae0487d58f66d 100644 --- a/posthog/warehouse/api/external_data_source.py +++ b/posthog/warehouse/api/external_data_source.py @@ -23,12 +23,18 @@ from posthog.hogql.database.database import create_hogql_database from posthog.temporal.data_imports.pipelines.schemas import ( PIPELINE_TYPE_INCREMENTAL_ENDPOINTS_MAPPING, + PIPELINE_TYPE_INCREMENTAL_FIELDS_MAPPING, PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING, ) from posthog.temporal.data_imports.pipelines.hubspot.auth import ( get_access_token_from_code, ) -from posthog.warehouse.models.external_data_schema import get_postgres_schemas, get_snowflake_schemas +from posthog.warehouse.models.external_data_schema import ( + filter_postgres_incremental_fields, + filter_snowflake_incremental_fields, + get_postgres_schemas, + get_snowflake_schemas, +) import temporalio @@ -233,12 +239,37 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: active_schemas: list[ExternalDataSchema] = [] for schema in schemas: + sync_type = schema.get("sync_type") + is_incremental = sync_type == "incremental" + incremental_field = schema.get("incremental_field") + incremental_field_type = schema.get("incremental_field_type") + + if is_incremental and incremental_field is None: + new_source_model.delete() + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": "Incremental schemas given do not have an incremental field set"}, + ) + + if is_incremental and incremental_field_type is None: + new_source_model.delete() + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": "Incremental schemas given do not have an incremental field type set"}, + ) + schema_model = ExternalDataSchema.objects.create( name=schema.get("name"), team=self.team, source=new_source_model, should_sync=schema.get("should_sync"), - sync_type=schema.get("sync_type"), + sync_type=sync_type, + sync_type_config={ + 
"incremental_field": incremental_field, + "incremental_field_type": incremental_field_type, + } + if is_incremental + else {}, ) if schema.get("should_sync"): @@ -461,9 +492,9 @@ def destroy(self, request: Request, *args: Any, **kwargs: Any) -> Response: ).all() for job in all_jobs: try: - delete_data_import_folder(job.folder_path) + delete_data_import_folder(job.folder_path()) except Exception as e: - logger.exception(f"Could not clean up data import folder: {job.folder_path}", exc_info=e) + logger.exception(f"Could not clean up data import folder: {job.folder_path()}", exc_info=e) pass for schema in ExternalDataSchema.objects.filter( @@ -578,7 +609,7 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): try: result = get_postgres_schemas(host, port, database, user, password, schema, ssh_tunnel) - if len(result) == 0: + if len(result.keys()) == 0: return Response( status=status.HTTP_400_BAD_REQUEST, data={"message": "Postgres schema doesn't exist"}, @@ -607,14 +638,23 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): data={"message": GenericPostgresError}, ) + filtered_results = [ + (table_name, filter_postgres_incremental_fields(columns)) for table_name, columns in result.items() + ] + result_mapped_to_options = [ { - "table": row, - "should_sync": True, - "sync_types": {"full_refresh": True, "incremental": False}, - "sync_type": "full_refresh", + "table": table_name, + "should_sync": False, + "incremental_fields": [ + {"label": column_name, "type": column_type, "field": column_name, "field_type": column_type} + for column_name, column_type in columns + ], + "incremental_available": True, + "incremental_field": columns[0][0] if len(columns) > 0 and len(columns[0]) > 0 else None, + "sync_type": None, } - for row in result + for table_name, columns in filtered_results ] return Response(status=status.HTTP_200_OK, data=result_mapped_to_options) elif source_type == ExternalDataSource.Type.SNOWFLAKE: @@ -636,7 +676,7 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): try: result = get_snowflake_schemas(account_id, database, warehouse, user, password, schema, role) - if len(result) == 0: + if len(result.keys()) == 0: return Response( status=status.HTTP_400_BAD_REQUEST, data={"message": "Snowflake schema doesn't exist"}, @@ -659,20 +699,31 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): status=status.HTTP_400_BAD_REQUEST, data={"message": GenericSnowflakeError}, ) + + filtered_results = [ + (table_name, filter_snowflake_incremental_fields(columns)) for table_name, columns in result.items() + ] + result_mapped_to_options = [ { - "table": row, - "should_sync": True, - "sync_types": {"full_refresh": True, "incremental": False}, - "sync_type": "full_refresh", + "table": table_name, + "should_sync": False, + "incremental_fields": [ + {"label": column_name, "type": column_type, "field": column_name, "field_type": column_type} + for column_name, column_type in columns + ], + "incremental_available": True, + "incremental_field": columns[0][0] if len(columns) > 0 and len(columns[0]) > 0 else None, + "sync_type": None, } - for row in result + for table_name, columns in filtered_results ] return Response(status=status.HTTP_200_OK, data=result_mapped_to_options) # Return the possible endpoints for all other source types schemas = PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING.get(source_type, None) incremental_schemas = PIPELINE_TYPE_INCREMENTAL_ENDPOINTS_MAPPING.get(source_type, ()) + incremental_fields = 
PIPELINE_TYPE_INCREMENTAL_FIELDS_MAPPING.get(source_type, {}) if schemas is None: return Response( @@ -683,9 +734,21 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): options = [ { "table": row, - "should_sync": True, - "sync_types": {"full_refresh": True, "incremental": row in incremental_schemas}, - "sync_type": "incremental" if row in incremental_schemas else "full_refresh", + "should_sync": False, + "incremental_fields": [ + { + "label": field["label"], + "type": field["type"], + "field": field["field"], + "field_type": field["field_type"], + } + for field in incremental_fields.get(row, []) + ], + "incremental_available": row in incremental_schemas, + "incremental_field": incremental_fields.get(row, [])[0]["field"] + if row in incremental_schemas + else None, + "sync_type": None, } for row in schemas ] diff --git a/posthog/warehouse/api/test/conftest.py b/posthog/warehouse/api/test/conftest.py new file mode 100644 index 0000000000000..fd75bceb4cc0e --- /dev/null +++ b/posthog/warehouse/api/test/conftest.py @@ -0,0 +1,71 @@ +import psycopg +import pytest_asyncio +from psycopg import sql + + +@pytest_asyncio.fixture +async def setup_postgres_test_db(postgres_config): + """Fixture to manage a database for Postgres export testing. + + Managing a test database involves the following steps: + 1. Creating a test database. + 2. Initializing a connection to that database. + 3. Creating a test schema. + 4. Yielding the connection to be used in tests. + 5. After tests, drop the test schema and any tables in it. + 6. Drop the test database. + """ + connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute( + sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), + (postgres_config["database"],), + ) + + if await cursor.fetchone() is None: + await cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"]))) + + await connection.close() + + # We need a new connection to connect to the database we just created. + connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + dbname=postgres_config["database"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute( + sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"])) + ) + + yield + + async with connection.cursor() as cursor: + await cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"]))) + + await connection.close() + + # We need a new connection to drop the database, as we cannot drop the current database. 
+ connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))) + + await connection.close() diff --git a/posthog/warehouse/api/test/test_external_data_schema.py b/posthog/warehouse/api/test/test_external_data_schema.py new file mode 100644 index 0000000000000..4babec0833ab0 --- /dev/null +++ b/posthog/warehouse/api/test/test_external_data_schema.py @@ -0,0 +1,358 @@ +from unittest import mock +import uuid +import psycopg +import pytest +from asgiref.sync import sync_to_async +import pytest_asyncio +from posthog.test.base import APIBaseTest +from posthog.warehouse.models.external_data_schema import ExternalDataSchema +from posthog.warehouse.models.external_data_source import ExternalDataSource +from django.conf import settings + + +@pytest.fixture +def postgres_config(): + return { + "user": settings.PG_USER, + "password": settings.PG_PASSWORD, + "database": "external_data_database", + "schema": "external_data_schema", + "host": settings.PG_HOST, + "port": int(settings.PG_PORT), + } + + +@pytest_asyncio.fixture +async def postgres_connection(postgres_config, setup_postgres_test_db): + if setup_postgres_test_db: + await anext(setup_postgres_test_db) + + connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + dbname=postgres_config["database"], + host=postgres_config["host"], + port=postgres_config["port"], + ) + + yield connection + + await connection.close() + + +@pytest.mark.usefixtures("postgres_connection", "postgres_config") +class TestExternalDataSchema(APIBaseTest): + @pytest.fixture(autouse=True) + def _setup(self, postgres_connection, postgres_config): + self.postgres_connection = postgres_connection + self.postgres_config = postgres_config + + def test_incremental_fields_stripe(self): + source = ExternalDataSource.objects.create( + team=self.team, + source_type=ExternalDataSource.Type.STRIPE, + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=True, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=ExternalDataSchema.SyncType.FULL_REFRESH, + ) + + response = self.client.post( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}/incremental_fields", + ) + payload = response.json() + + assert payload == [{"label": "created_at", "type": "datetime", "field": "created", "field_type": "integer"}] + + def test_incremental_fields_missing_source_type(self): + source = ExternalDataSource.objects.create( + team=self.team, + source_type="bad_source", + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=True, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=ExternalDataSchema.SyncType.FULL_REFRESH, + ) + + response = self.client.post( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}/incremental_fields", + ) + + assert response.status_code == 400 + + def test_incremental_fields_missing_table_name(self): + source = ExternalDataSource.objects.create( + team=self.team, + source_type=ExternalDataSource.Type.STRIPE, + ) + schema = ExternalDataSchema.objects.create( + name="Some_other_non_existent_table",
team=self.team, + source=source, + should_sync=True, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=ExternalDataSchema.SyncType.FULL_REFRESH, + ) + + response = self.client.post( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}/incremental_fields", + ) + + assert response.status_code == 400 + + @pytest.mark.django_db(transaction=True) + @pytest.mark.asyncio + async def test_incremental_fields_postgres(self): + if not isinstance(self.postgres_connection, psycopg.AsyncConnection): + postgres_connection: psycopg.AsyncConnection = await anext(self.postgres_connection) + else: + postgres_connection = self.postgres_connection + + await postgres_connection.execute( + "CREATE TABLE IF NOT EXISTS {schema}.posthog_test (id integer)".format( + schema=self.postgres_config["schema"] + ) + ) + await postgres_connection.execute( + "INSERT INTO {schema}.posthog_test (id) VALUES (1)".format(schema=self.postgres_config["schema"]) + ) + await postgres_connection.commit() + + source = await sync_to_async(ExternalDataSource.objects.create)( + source_id=uuid.uuid4(), + connection_id=uuid.uuid4(), + destination_id=uuid.uuid4(), + team=self.team, + status="running", + source_type="Postgres", + job_inputs={ + "host": self.postgres_config["host"], + "port": self.postgres_config["port"], + "database": self.postgres_config["database"], + "user": self.postgres_config["user"], + "password": self.postgres_config["password"], + "schema": self.postgres_config["schema"], + "ssh_tunnel_enabled": False, + }, + ) + + schema = await sync_to_async(ExternalDataSchema.objects.create)( + name="posthog_test", + team=self.team, + source=source, + ) + + response = await sync_to_async(self.client.post)( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}/incremental_fields", + ) + payload = response.json() + + assert payload == [{"label": "id", "type": "integer", "field": "id", "field_type": "integer"}] + + def test_update_schema_change_sync_type(self): + source = ExternalDataSource.objects.create( + team=self.team, source_type=ExternalDataSource.Type.STRIPE, job_inputs={} + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=True, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=ExternalDataSchema.SyncType.INCREMENTAL, + ) + + with mock.patch( + "posthog.warehouse.api.external_data_schema.trigger_external_data_workflow" + ) as mock_trigger_external_data_workflow: + response = self.client.patch( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}", + data={"sync_type": "full_refresh"}, + ) + + assert response.status_code == 200 + mock_trigger_external_data_workflow.assert_called_once() + source.refresh_from_db() + assert source.job_inputs.get("reset_pipeline") == "True" + + def test_update_schema_change_sync_type_incremental_field(self): + source = ExternalDataSource.objects.create( + team=self.team, source_type=ExternalDataSource.Type.STRIPE, job_inputs={} + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=True, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=ExternalDataSchema.SyncType.INCREMENTAL, + sync_type_config={"incremental_field": "some_other_field", "incremental_field_type": "datetime"}, + ) + + with mock.patch( + "posthog.warehouse.api.external_data_schema.trigger_external_data_workflow" + ) as mock_trigger_external_data_workflow: + response = self.client.patch(
f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}", + data={"sync_type": "incremental", "incremental_field": "field", "incremental_field_type": "integer"}, + ) + + assert response.status_code == 200 + mock_trigger_external_data_workflow.assert_called_once() + + source.refresh_from_db() + assert source.job_inputs.get("reset_pipeline") == "True" + + schema.refresh_from_db() + assert schema.sync_type_config.get("incremental_field") == "field" + assert schema.sync_type_config.get("incremental_field_type") == "integer" + + def test_update_schema_change_should_sync_off(self): + source = ExternalDataSource.objects.create( + team=self.team, source_type=ExternalDataSource.Type.STRIPE, job_inputs={} + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=True, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=ExternalDataSchema.SyncType.FULL_REFRESH, + ) + + with ( + mock.patch( + "posthog.warehouse.api.external_data_schema.external_data_workflow_exists" + ) as mock_external_data_workflow_exists, + mock.patch( + "posthog.warehouse.api.external_data_schema.pause_external_data_schedule" + ) as mock_pause_external_data_schedule, + ): + mock_external_data_workflow_exists.return_value = True + + response = self.client.patch( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}", + data={"should_sync": False}, + ) + + assert response.status_code == 200 + mock_pause_external_data_schedule.assert_called_once() + + schema.refresh_from_db() + assert schema.should_sync is False + + def test_update_schema_change_should_sync_on_with_existing_schedule(self): + source = ExternalDataSource.objects.create( + team=self.team, source_type=ExternalDataSource.Type.STRIPE, job_inputs={} + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=False, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=ExternalDataSchema.SyncType.FULL_REFRESH, + ) + + with ( + mock.patch( + "posthog.warehouse.api.external_data_schema.external_data_workflow_exists" + ) as mock_external_data_workflow_exists, + mock.patch( + "posthog.warehouse.api.external_data_schema.unpause_external_data_schedule" + ) as mock_unpause_external_data_schedule, + ): + mock_external_data_workflow_exists.return_value = True + + response = self.client.patch( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}", + data={"should_sync": True}, + ) + + assert response.status_code == 200 + mock_unpause_external_data_schedule.assert_called_once() + + schema.refresh_from_db() + assert schema.should_sync is True + + def test_update_schema_change_should_sync_on_without_existing_schedule(self): + source = ExternalDataSource.objects.create( + team=self.team, source_type=ExternalDataSource.Type.STRIPE, job_inputs={} + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=False, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=ExternalDataSchema.SyncType.FULL_REFRESH, + ) + + with ( + mock.patch( + "posthog.warehouse.api.external_data_schema.external_data_workflow_exists" + ) as mock_external_data_workflow_exists, + mock.patch( + "posthog.warehouse.api.external_data_schema.sync_external_data_job_workflow" + ) as mock_sync_external_data_job_workflow, + ): + mock_external_data_workflow_exists.return_value = False + + response = self.client.patch( + 
f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}", + data={"should_sync": True}, + ) + + assert response.status_code == 200 + mock_sync_external_data_job_workflow.assert_called_once() + + schema.refresh_from_db() + assert schema.should_sync is True + + def test_update_schema_change_should_sync_on_without_sync_type(self): + source = ExternalDataSource.objects.create( + team=self.team, source_type=ExternalDataSource.Type.STRIPE, job_inputs={} + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=False, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=None, + ) + + response = self.client.patch( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}", + data={"should_sync": True}, + ) + + assert response.status_code == 400 + + def test_update_schema_change_sync_type_with_invalid_type(self): + source = ExternalDataSource.objects.create( + team=self.team, source_type=ExternalDataSource.Type.STRIPE, job_inputs={} + ) + schema = ExternalDataSchema.objects.create( + name="BalanceTransaction", + team=self.team, + source=source, + should_sync=False, + status=ExternalDataSchema.Status.COMPLETED, + sync_type=None, + ) + + response = self.client.patch( + f"/api/projects/{self.team.pk}/external_data_schemas/{schema.id}", + data={"sync_type": "blah"}, + ) + + assert response.status_code == 400 diff --git a/posthog/warehouse/api/test/test_external_data_source.py b/posthog/warehouse/api/test/test_external_data_source.py index 1c8e53d99b6a7..69a3ae2248205 100644 --- a/posthog/warehouse/api/test/test_external_data_source.py +++ b/posthog/warehouse/api/test/test_external_data_source.py @@ -16,7 +16,7 @@ import datetime -class TestSavedQuery(APIBaseTest): +class TestExternalDataSource(APIBaseTest): def _create_external_data_source(self) -> ExternalDataSource: return ExternalDataSource.objects.create( team_id=self.team.pk, @@ -38,7 +38,7 @@ def _create_external_data_schema(self, source_id) -> ExternalDataSchema: def test_create_external_data_source(self): response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/", + f"/api/projects/{self.team.pk}/external_data_sources/", data={ "source_type": "Stripe", "payload": { @@ -66,7 +66,7 @@ def test_create_external_data_source(self): def test_create_external_data_source_delete_on_missing_schemas(self): response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/", + f"/api/projects/{self.team.pk}/external_data_sources/", data={ "source_type": "Stripe", "payload": { @@ -81,7 +81,7 @@ def test_create_external_data_source_delete_on_missing_schemas(self): def test_create_external_data_source_delete_on_bad_schema(self): response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/", + f"/api/projects/{self.team.pk}/external_data_sources/", data={ "source_type": "Stripe", "payload": { @@ -100,7 +100,7 @@ def test_prefix_external_data_source(self): # Create no prefix response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/", + f"/api/projects/{self.team.pk}/external_data_sources/", data={ "source_type": "Stripe", "payload": { @@ -122,7 +122,7 @@ def test_prefix_external_data_source(self): # Try to create same type without prefix again response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/", + f"/api/projects/{self.team.pk}/external_data_sources/", data={ "source_type": "Stripe", "payload": { @@ -145,7 +145,7 @@ def 
test_prefix_external_data_source(self): # Create with prefix response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/", + f"/api/projects/{self.team.pk}/external_data_sources/", data={ "source_type": "Stripe", "payload": { @@ -168,7 +168,7 @@ def test_prefix_external_data_source(self): # Try to create same type with same prefix again response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/", + f"/api/projects/{self.team.pk}/external_data_sources/", data={ "source_type": "Stripe", "payload": { @@ -190,11 +190,188 @@ def test_prefix_external_data_source(self): self.assertEqual(response.status_code, 400) self.assertEqual(response.json(), {"message": "Prefix already exists"}) + def test_create_external_data_source_incremental(self): + response = self.client.post( + f"/api/projects/{self.team.pk}/external_data_sources/", + data={ + "source_type": "Stripe", + "payload": { + "client_secret": "sk_test_123", + "schemas": [ + { + "name": "BalanceTransaction", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + "incremental_field_type": "integer", + }, + { + "name": "Subscription", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + "incremental_field_type": "integer", + }, + { + "name": "Customer", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + "incremental_field_type": "integer", + }, + { + "name": "Product", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + "incremental_field_type": "integer", + }, + { + "name": "Price", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + "incremental_field_type": "integer", + }, + { + "name": "Invoice", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + "incremental_field_type": "integer", + }, + { + "name": "Charge", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + "incremental_field_type": "integer", + }, + ], + }, + }, + ) + self.assertEqual(response.status_code, 201) + + def test_create_external_data_source_incremental_missing_field(self): + response = self.client.post( + f"/api/projects/{self.team.pk}/external_data_sources/", + data={ + "source_type": "Stripe", + "payload": { + "client_secret": "sk_test_123", + "schemas": [ + { + "name": "BalanceTransaction", + "should_sync": True, + "sync_type": "incremental", + "incremental_field_type": "integer", + }, + { + "name": "Subscription", + "should_sync": True, + "sync_type": "incremental", + "incremental_field_type": "integer", + }, + { + "name": "Customer", + "should_sync": True, + "sync_type": "incremental", + "incremental_field_type": "integer", + }, + { + "name": "Product", + "should_sync": True, + "sync_type": "incremental", + "incremental_field_type": "integer", + }, + { + "name": "Price", + "should_sync": True, + "sync_type": "incremental", + "incremental_field_type": "integer", + }, + { + "name": "Invoice", + "should_sync": True, + "sync_type": "incremental", + "incremental_field_type": "integer", + }, + { + "name": "Charge", + "should_sync": True, + "sync_type": "incremental", + "incremental_field_type": "integer", + }, + ], + }, + }, + ) + assert response.status_code == 400 + assert len(ExternalDataSource.objects.all()) == 0 + + def test_create_external_data_source_incremental_missing_type(self): + response = self.client.post( + 
f"/api/projects/{self.team.pk}/external_data_sources/", + data={ + "source_type": "Stripe", + "payload": { + "client_secret": "sk_test_123", + "schemas": [ + { + "name": "BalanceTransaction", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + }, + { + "name": "Subscription", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + }, + { + "name": "Customer", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + }, + { + "name": "Product", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + }, + { + "name": "Price", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + }, + { + "name": "Invoice", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + }, + { + "name": "Charge", + "should_sync": True, + "sync_type": "incremental", + "incremental_field": "created", + }, + ], + }, + }, + ) + assert response.status_code == 400 + assert len(ExternalDataSource.objects.all()) == 0 + def test_list_external_data_source(self): self._create_external_data_source() self._create_external_data_source() - response = self.client.get(f"/api/projects/{self.team.id}/external_data_sources/") + response = self.client.get(f"/api/projects/{self.team.pk}/external_data_sources/") payload = response.json() self.assertEqual(response.status_code, 200) @@ -204,7 +381,7 @@ def test_get_external_data_source_with_schema(self): source = self._create_external_data_source() schema = self._create_external_data_schema(source.pk) - response = self.client.get(f"/api/projects/{self.team.id}/external_data_sources/{source.pk}") + response = self.client.get(f"/api/projects/{self.team.pk}/external_data_sources/{source.pk}") payload = response.json() self.assertEqual(response.status_code, 200) @@ -228,6 +405,8 @@ def test_get_external_data_source_with_schema(self): { "id": str(schema.pk), "incremental": False, + "incremental_field": None, + "incremental_field_type": None, "last_synced_at": schema.last_synced_at, "name": schema.name, "should_sync": schema.should_sync, @@ -243,7 +422,7 @@ def test_delete_external_data_source(self): source = self._create_external_data_source() schema = self._create_external_data_schema(source.pk) - response = self.client.delete(f"/api/projects/{self.team.id}/external_data_sources/{source.pk}") + response = self.client.delete(f"/api/projects/{self.team.pk}/external_data_sources/{source.pk}") self.assertEqual(response.status_code, 204) @@ -255,7 +434,7 @@ def test_delete_external_data_source(self): def test_reload_external_data_source(self, mock_trigger): source = self._create_external_data_source() - response = self.client.post(f"/api/projects/{self.team.id}/external_data_sources/{source.pk}/reload/") + response = self.client.post(f"/api/projects/{self.team.pk}/external_data_sources/{source.pk}/reload/") source.refresh_from_db() @@ -285,7 +464,7 @@ def test_database_schema(self): postgres_connection.commit() response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/database_schema/", + f"/api/projects/{self.team.pk}/external_data_sources/database_schema/", data={ "source_type": "Postgres", "host": settings.PG_HOST, @@ -315,7 +494,7 @@ def test_database_schema(self): def test_database_schema_non_postgres_source(self): response = self.client.post( - f"/api/projects/{self.team.id}/external_data_sources/database_schema/", + 
f"/api/projects/{self.team.pk}/external_data_sources/database_schema/", data={ "source_type": "Stripe", }, @@ -330,7 +509,7 @@ def test_database_schema_non_postgres_source(self): @patch("posthog.warehouse.api.external_data_source.get_postgres_schemas") def test_internal_postgres(self, patch_get_postgres_schemas): - patch_get_postgres_schemas.return_value = ["table_1"] + patch_get_postgres_schemas.return_value = {"table_1": [("id", "integer")]} with override_settings(CLOUD_DEPLOYMENT="US"): team_2, _ = Team.objects.get_or_create(id=2, organization=self.team.organization) @@ -349,17 +528,19 @@ def test_internal_postgres(self, patch_get_postgres_schemas): assert response.status_code == 200 assert response.json() == [ { - "should_sync": True, "table": "table_1", - "sync_type": "full_refresh", - "sync_types": {"full_refresh": True, "incremental": False}, + "should_sync": False, + "incremental_fields": [{"label": "id", "type": "integer", "field": "id", "field_type": "integer"}], + "incremental_available": True, + "incremental_field": "id", + "sync_type": None, } ] new_team = Team.objects.create(name="new_team", organization=self.team.organization) response = self.client.post( - f"/api/projects/{new_team.id}/external_data_sources/database_schema/", + f"/api/projects/{new_team.pk}/external_data_sources/database_schema/", data={ "source_type": "Postgres", "host": "172.16.0.0", @@ -391,17 +572,19 @@ def test_internal_postgres(self, patch_get_postgres_schemas): assert response.status_code == 200 assert response.json() == [ { - "should_sync": True, "table": "table_1", - "sync_type": "full_refresh", - "sync_types": {"full_refresh": True, "incremental": False}, + "should_sync": False, + "incremental_fields": [{"label": "id", "type": "integer", "field": "id", "field_type": "integer"}], + "incremental_available": True, + "incremental_field": "id", + "sync_type": None, } ] new_team = Team.objects.create(name="new_team", organization=self.team.organization) response = self.client.post( - f"/api/projects/{new_team.id}/external_data_sources/database_schema/", + f"/api/projects/{new_team.pk}/external_data_sources/database_schema/", data={ "source_type": "Postgres", "host": "172.16.0.0", @@ -430,7 +613,7 @@ def test_update_source_sync_frequency(self, _patch_sync_external_data_job_workfl # test api response = self.client.patch( - f"/api/projects/{self.team.id}/external_data_sources/{source.pk}/", + f"/api/projects/{self.team.pk}/external_data_sources/{source.pk}/", data={"sync_frequency": ExternalDataSource.SyncFrequency.WEEKLY}, ) self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/posthog/warehouse/models/external_data_job.py b/posthog/warehouse/models/external_data_job.py index 0b40c41f11069..08a0e1c0fcd5f 100644 --- a/posthog/warehouse/models/external_data_job.py +++ b/posthog/warehouse/models/external_data_job.py @@ -31,24 +31,23 @@ class Status(models.TextChoices): __repr__ = sane_repr("id") - @property def folder_path(self) -> str: if self.schema and self.schema.is_incremental: - return f"team_{self.team_id}_{self.pipeline.source_type}_{str(self.schema.pk)}".lower().replace("-", "_") + return f"team_{self.team_id}_{self.pipeline.source_type}_{str(self.schema_id)}".lower().replace("-", "_") return f"team_{self.team_id}_{self.pipeline.source_type}_{str(self.pk)}".lower().replace("-", "_") def url_pattern_by_schema(self, schema: str) -> str: if TEST: return ( - f"http://{settings.AIRBYTE_BUCKET_DOMAIN}/test-pipeline/{self.folder_path}/{schema.lower()}/*.parquet" + 
f"http://{settings.AIRBYTE_BUCKET_DOMAIN}/test-pipeline/{self.folder_path()}/{schema.lower()}/*.parquet" ) - return f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{self.folder_path}/{schema.lower()}/*.parquet" + return f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{self.folder_path()}/{schema.lower()}/*.parquet" def delete_data_in_bucket(self) -> None: s3 = get_s3_client() - s3.delete(f"{settings.BUCKET_URL}/{self.folder_path}", recursive=True) + s3.delete(f"{settings.BUCKET_URL}/{self.folder_path()}", recursive=True) @database_sync_to_async diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py index 314410f4071bf..109c440eccc2c 100644 --- a/posthog/warehouse/models/external_data_schema.py +++ b/posthog/warehouse/models/external_data_schema.py @@ -1,10 +1,12 @@ -from typing import Any, Optional +from collections import defaultdict +from typing import Optional from django.db import models import snowflake.connector from posthog.models.team import Team from posthog.models.utils import CreatedMetaFields, UUIDModel, sane_repr import uuid import psycopg2 +from posthog.warehouse.types import IncrementalFieldType from posthog.warehouse.models.ssh_tunnel import SSHTunnel from posthog.warehouse.util import database_sync_to_async @@ -35,8 +37,10 @@ class SyncType(models.TextChoices): ) status: models.CharField = models.CharField(max_length=400, null=True, blank=True) last_synced_at: models.DateTimeField = models.DateTimeField(null=True, blank=True) - sync_type: models.CharField = models.CharField( - max_length=128, choices=SyncType.choices, default=SyncType.FULL_REFRESH, blank=True + sync_type: models.CharField = models.CharField(max_length=128, choices=SyncType.choices, null=True, blank=True) + sync_type_config: models.JSONField = models.JSONField( + default=dict, + blank=True, ) __repr__ = sane_repr("name") @@ -85,9 +89,25 @@ def sync_old_schemas_with_new_schemas(new_schemas: list, source_id: uuid.UUID, t ExternalDataSchema.objects.create(name=schema, team_id=team_id, source_id=source_id, should_sync=False) +def filter_snowflake_incremental_fields(columns: list[tuple[str, str]]) -> list[tuple[str, IncrementalFieldType]]: + results: list[tuple[str, IncrementalFieldType]] = [] + for column_name, type in columns: + type = type.lower() + if type.startswith("timestamp"): + results.append((column_name, IncrementalFieldType.Timestamp)) + elif type == "date": + results.append((column_name, IncrementalFieldType.Date)) + elif type == "datetime": + results.append((column_name, IncrementalFieldType.DateTime)) + elif type == "numeric": + results.append((column_name, IncrementalFieldType.Numeric)) + + return results + + def get_snowflake_schemas( account_id: str, database: str, warehouse: str, user: str, password: str, schema: str, role: Optional[str] = None -) -> list[Any]: +) -> dict[str, list[tuple[str, str]]]: with snowflake.connector.connect( user=user, password=password, @@ -102,17 +122,35 @@ def get_snowflake_schemas( raise Exception("Can't create cursor to Snowflake") cursor.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = %(schema)s", {"schema": schema} + "SELECT table_name, column_name, data_type FROM information_schema.columns WHERE table_schema = %(schema)s ORDER BY table_name ASC", + {"schema": schema}, ) - results = cursor.fetchall() - results = [row[0] for row in results] + result = cursor.fetchall() + + schema_list = defaultdict(list) + for row in result: + schema_list[row[0]].append((row[1], row[2])) - 
return results + return schema_list + + +def filter_postgres_incremental_fields(columns: list[tuple[str, str]]) -> list[tuple[str, IncrementalFieldType]]: + results: list[tuple[str, IncrementalFieldType]] = [] + for column_name, type in columns: + type = type.lower() + if type.startswith("timestamp"): + results.append((column_name, IncrementalFieldType.Timestamp)) + elif type == "date": + results.append((column_name, IncrementalFieldType.Date)) + elif type == "integer" or type == "smallint" or type == "bigint": + results.append((column_name, IncrementalFieldType.Integer)) + + return results def get_postgres_schemas( host: str, port: str, database: str, user: str, password: str, schema: str, ssh_tunnel: SSHTunnel -) -> list[Any]: +) -> dict[str, list[tuple[str, str]]]: def get_schemas(postgres_host: str, postgres_port: int): connection = psycopg2.connect( host=postgres_host, @@ -129,14 +167,18 @@ def get_schemas(postgres_host: str, postgres_port: int): with connection.cursor() as cursor: cursor.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = %(schema)s", {"schema": schema} + "SELECT table_name, column_name, data_type FROM information_schema.columns WHERE table_schema = %(schema)s ORDER BY table_name ASC", + {"schema": schema}, ) result = cursor.fetchall() - result = [row[0] for row in result] + + schema_list = defaultdict(list) + for row in result: + schema_list[row[0]].append((row[1], row[2])) connection.close() - return result + return schema_list if ssh_tunnel.enabled: with ssh_tunnel.get_tunnel(host, int(port)) as tunnel: diff --git a/posthog/warehouse/types.py b/posthog/warehouse/types.py new file mode 100644 index 0000000000000..57455ac361232 --- /dev/null +++ b/posthog/warehouse/types.py @@ -0,0 +1,17 @@ +from enum import Enum +from typing import TypedDict + + +class IncrementalFieldType(Enum): + Integer = "integer" + Numeric = "numeric" # For snowflake + DateTime = "datetime" + Date = "date" + Timestamp = "timestamp" + + +class IncrementalField(TypedDict): + label: str # Label shown in the UI + type: IncrementalFieldType # Field type shown in the UI + field: str # Actual DB field accessed + field_type: IncrementalFieldType # Actual DB type of the field diff --git a/requirements.in b/requirements.in index 6b59c0ad8f13a..fb24969950a0f 100644 --- a/requirements.in +++ b/requirements.in @@ -33,7 +33,7 @@ djangorestframework==3.15.1 djangorestframework-csv==2.1.1 djangorestframework-dataclasses==1.2.0 django-fernet-encrypted-fields==0.1.3 -dlt==0.4.12 +dlt==0.4.13a0 dnspython==2.2.1 drf-exceptions-hog==0.4.0 drf-extensions==0.7.0 diff --git a/requirements.txt b/requirements.txt index 13796a6133fd7..49d060f506970 100644 --- a/requirements.txt +++ b/requirements.txt @@ -207,7 +207,7 @@ djangorestframework-csv==2.1.1 # via -r requirements.in djangorestframework-dataclasses==1.2.0 # via -r requirements.in -dlt==0.4.12 +dlt==0.4.13a0 # via -r requirements.in dnspython==2.2.1 # via -r requirements.in
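A minimal, self-contained sketch of the column-to-incremental-field mapping this change introduces: it mirrors the filter_postgres_incremental_fields rules and the IncrementalField shape from posthog/warehouse/types.py so it can run standalone. The sample columns list is hypothetical, standing in for one table's entry from get_postgres_schemas(); in the API view the same comprehension builds the payload returned by the new incremental_fields action.

# Standalone sketch: mirrors the helpers added in this diff; the sample columns are hypothetical.
from enum import Enum
from typing import TypedDict


class IncrementalFieldType(Enum):
    Integer = "integer"
    Numeric = "numeric"  # For snowflake
    DateTime = "datetime"
    Date = "date"
    Timestamp = "timestamp"


class IncrementalField(TypedDict):
    label: str  # Label shown in the UI
    type: IncrementalFieldType  # Field type shown in the UI
    field: str  # Actual DB field accessed
    field_type: IncrementalFieldType  # Actual DB type of the field


def filter_postgres_incremental_fields(columns: list[tuple[str, str]]) -> list[tuple[str, IncrementalFieldType]]:
    # Only timestamp-, date-, and integer-like columns can drive an incremental sync.
    results: list[tuple[str, IncrementalFieldType]] = []
    for column_name, column_type in columns:
        column_type = column_type.lower()
        if column_type.startswith("timestamp"):
            results.append((column_name, IncrementalFieldType.Timestamp))
        elif column_type == "date":
            results.append((column_name, IncrementalFieldType.Date))
        elif column_type in ("integer", "smallint", "bigint"):
            results.append((column_name, IncrementalFieldType.Integer))
    return results


# Hypothetical columns for one table, in the (name, data_type) shape that
# get_postgres_schemas() returns per table name.
columns = [("id", "integer"), ("name", "text"), ("created_at", "timestamp with time zone")]

incremental_columns: list[IncrementalField] = [
    {"field": name, "field_type": field_type, "label": name, "type": field_type}
    for name, field_type in filter_postgres_incremental_fields(columns)
]
# -> "id" (integer) and "created_at" (timestamp) become selectable options; "name" is filtered out.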