diff --git a/frontend/public/postgres-logo.svg b/frontend/public/postgres-logo.svg new file mode 100644 index 0000000000000..6b65997a98d5e --- /dev/null +++ b/frontend/public/postgres-logo.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 017a8ee9e3894..7f2b0f2e2ca03 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -1623,7 +1623,7 @@ const api = { async list(): Promise> { return await new ApiRequest().externalDataSources().get() }, - async create(data: ExternalDataSourceCreatePayload): Promise { + async create(data: ExternalDataSourceCreatePayload): Promise { return await new ApiRequest().externalDataSources().create({ data }) }, async delete(sourceId: ExternalDataSource['id']): Promise { diff --git a/frontend/src/scenes/data-warehouse/external/SourceModal.tsx b/frontend/src/scenes/data-warehouse/external/SourceModal.tsx index 2308ec31b3140..69afb95960889 100644 --- a/frontend/src/scenes/data-warehouse/external/SourceModal.tsx +++ b/frontend/src/scenes/data-warehouse/external/SourceModal.tsx @@ -1,10 +1,9 @@ import { LemonButton, LemonDivider, LemonInput, LemonModal, LemonModalProps } from '@posthog/lemon-ui' import { Form } from 'kea-forms' -import { ConnectorConfigType, sourceModalLogic } from './sourceModalLogic' +import { ConnectorConfigType, FORM_PAYLOAD_TYPES, FormPayloadType, sourceModalLogic } from './sourceModalLogic' import { useActions, useValues } from 'kea' import { DatawarehouseTableForm } from '../new_table/DataWarehouseTableForm' import { Field } from 'lib/forms/Field' -import stripeLogo from 'public/stripe-logo.svg' interface SourceModalProps extends LemonModalProps {} @@ -21,7 +20,7 @@ export default function SourceModal(props: SourceModalProps): JSX.Element { return ( - {`stripe + {`stripe ) } @@ -37,38 +36,40 @@ export default function SourceModal(props: SourceModalProps): JSX.Element { toggleManualLinkFormVisible(true) } - const formToShow = (): JSX.Element => { + const formPayloadTypeToField = (formPayloadType: FormPayloadType): JSX.Element => { + return ( + + + + ) + } + + const buildPayloadTypeForm = (payloadType: string): JSX.Element => { + return ( +
+ {FORM_PAYLOAD_TYPES[payloadType].map(formPayloadTypeToField)} + +
+ + Back + + + Link + +
+ + ) + } + + const formToShow = (selectedConnector: ConnectorConfigType): JSX.Element => { if (selectedConnector) { - return ( -
- - - - - - - -
- - Back - - - Link - -
- - ) + return buildPayloadTypeForm(selectedConnector.name) } if (isManualLinkFormVisible) { @@ -104,16 +105,7 @@ export default function SourceModal(props: SourceModalProps): JSX.Element { ) } - return ( -
- {connectors.map((config, index) => ( - - ))} - - Manual Link - -
- ) + return <> } return ( @@ -123,7 +115,18 @@ export default function SourceModal(props: SourceModalProps): JSX.Element { title="Data Sources" description={selectedConnector ? selectedConnector.caption : null} > - {formToShow()} + {selectedConnector ? ( + formToShow(selectedConnector) + ) : ( +
+ {connectors.map((config, index) => ( + + ))} + + Manual Link + +
+ )} ) } diff --git a/frontend/src/scenes/data-warehouse/external/sourceModalLogic.ts b/frontend/src/scenes/data-warehouse/external/sourceModalLogic.ts index 4aa757303d708..4051c61d77ff4 100644 --- a/frontend/src/scenes/data-warehouse/external/sourceModalLogic.ts +++ b/frontend/src/scenes/data-warehouse/external/sourceModalLogic.ts @@ -10,24 +10,82 @@ import { dataWarehouseSceneLogic } from './dataWarehouseSceneLogic' import { router } from 'kea-router' import { urls } from 'scenes/urls' import { dataWarehouseSettingsLogic } from '../settings/dataWarehouseSettingsLogic' +import stripeLogo from 'public/stripe-logo.svg' +import postgresLogo from 'public/postgres-logo.svg' export interface ConnectorConfigType { name: string - fields: string[] caption: string disabledReason: string | null + icon: string } // TODO: add icon export const CONNECTORS: ConnectorConfigType[] = [ { - name: 'Stripe', - fields: ['accound_id', 'client_secret'], + name: 'stripe', caption: 'Enter your Stripe credentials to link your Stripe to PostHog', disabledReason: null, + icon: stripeLogo, + }, + { + name: 'postgres', + caption: 'Enter your Postgres credentials to link your Postgres database to PostHog', + disabledReason: null, + icon: postgresLogo, }, ] +type FormTypes = 'input' | 'select' + +export interface FormPayloadType { + name: string + type: FormTypes + label: string +} + +export const FORM_PAYLOAD_TYPES: Record = { + stripe: [ + { + name: 'account_id', + type: 'input', + label: 'Account Id', + }, + { + name: 'client_secret', + type: 'input', + label: 'Client Secret', + }, + ], + postgres: [ + { + name: 'host', + type: 'input', + label: 'Host', + }, + { + name: 'port', + type: 'input', + label: 'Port', + }, + { + name: 'database', + type: 'input', + label: 'Database', + }, + { + name: 'username', + type: 'input', + label: 'Username', + }, + { + name: 'password', + type: 'input', + label: 'Password', + }, + ], +} + export const sourceModalLogic = kea([ path(['scenes', 'data-warehouse', 'external', 'sourceModalLogic']), actions({ @@ -79,16 +137,19 @@ export const sourceModalLogic = kea([ }), forms(() => ({ externalDataSource: { - defaults: { account_id: '', client_secret: '' } as ExternalDataSourceCreatePayload, + defaults: { account_id: '', client_secret: '' }, errors: ({ account_id, client_secret }) => { return { account_id: !account_id && 'Please enter an account id.', client_secret: !client_secret && 'Please enter a client secret.', } }, - submit: async (payload: ExternalDataSourceCreatePayload) => { - const newResource = await api.externalDataSources.create(payload) - return newResource + submit: async (payload) => { + await api.externalDataSources.create({ + payload, + payload_type: 'stripe', + } as ExternalDataSourceCreatePayload) + return payload }, }, })), diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 2de4b3c837364..8d0271a3b7519 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -3232,10 +3232,9 @@ export interface DataWarehouseViewLink { } export interface ExternalDataSourceCreatePayload { - account_id: string - client_secret: string + payload_type: string + payload: Record } - export interface ExternalDataSource { id: string source_id: string diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py index a9737ed2a87ea..6c14ec0dc2644 100644 --- a/posthog/warehouse/api/external_data_source.py +++ b/posthog/warehouse/api/external_data_source.py @@ -7,7 +7,8 @@ from rest_framework import filters, serializers, viewsets from posthog.warehouse.models import ExternalDataSource from posthog.warehouse.external_data_source.workspace import get_or_create_workspace -from posthog.warehouse.external_data_source.source import StripeSourcePayload, create_stripe_source, delete_source +from posthog.warehouse.external_data_source.source import create_source, delete_source +from posthog.warehouse.external_data_source.source_definitions import SOURCE_TYPE_MAPPING from posthog.warehouse.external_data_source.connection import ( create_connection, start_sync, @@ -83,16 +84,17 @@ def get_queryset(self): return self.queryset.filter(team_id=self.team_id).prefetch_related("created_by").order_by(self.ordering) def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: - account_id = request.data["account_id"] - client_secret = request.data["client_secret"] + payload = request.data["payload"] + payload_type = request.data["payload_type"] - workspace_id = get_or_create_workspace(self.team_id) + if payload_type not in SOURCE_TYPE_MAPPING.keys(): + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"detail": f"Payload type {payload_type} is not supported."}, + ) - stripe_payload = StripeSourcePayload( - account_id=account_id, - client_secret=client_secret, - ) - new_source = create_stripe_source(stripe_payload, workspace_id) + workspace_id = get_or_create_workspace(self.team_id) + new_source = create_source(payload_type, payload, workspace_id) try: new_destination = create_destination(self.team_id, workspace_id) @@ -101,7 +103,7 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: raise e try: - new_connection = create_connection(new_source.source_id, new_destination.destination_id) + new_connection = create_connection(payload_type, new_source.source_id, new_destination.destination_id) except Exception as e: delete_source(new_source.source_id) delete_destination(new_destination.destination_id) @@ -113,7 +115,7 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: destination_id=new_destination.destination_id, team=self.team, status="running", - source_type="Stripe", + source_type=payload_type, ) start_sync(new_connection.connection_id) diff --git a/posthog/warehouse/external_data_source/connection.py b/posthog/warehouse/external_data_source/connection.py index c080907b3cc65..7095c4b3a2d8f 100644 --- a/posthog/warehouse/external_data_source/connection.py +++ b/posthog/warehouse/external_data_source/connection.py @@ -1,6 +1,7 @@ from pydantic import BaseModel from posthog.warehouse.external_data_source.client import send_request from posthog.warehouse.models import ExternalDataSource +from posthog.warehouse.external_data_source.source_definitions import SOURCE_TYPE_MAPPING import structlog from typing import List @@ -19,17 +20,24 @@ class ExternalDataConnection(BaseModel): workspace_id: str -def create_connection(source_id: str, destination_id: str) -> ExternalDataConnection: +def create_connection(source_type: str, source_id: str, destination_id: str) -> ExternalDataConnection: + default_streams_by_type = SOURCE_TYPE_MAPPING[source_type]["default_streams"] payload = { "schedule": {"scheduleType": "cron", "cronExpression": "0 0 0 * * ?"}, "namespaceFormat": None, "sourceId": source_id, "destinationId": destination_id, + "prefix": f"{source_type}_", } - response = send_request(AIRBYTE_CONNECTION_URL, method="POST", payload=payload) + if default_streams_by_type: + payload["configurations"] = { + "streams": [ + {"name": streamName, "syncMode": "full_refresh_overwrite"} for streamName in default_streams_by_type + ] + } - update_connection_stream(response["connectionId"], ["customers"]) + response = send_request(AIRBYTE_CONNECTION_URL, method="POST", payload=payload) return ExternalDataConnection( source_id=response["sourceId"], @@ -70,14 +78,12 @@ def update_connection_status_by_id(connection_id: str, status: str): def update_connection_stream(connection_id: str, streams: List): connection_id_url = f"{AIRBYTE_CONNECTION_URL}/{connection_id}" - # TODO: hardcoded to stripe stream right now payload = { "configurations": { "streams": [{"name": streamName, "syncMode": "full_refresh_overwrite"} for streamName in streams] }, "schedule": {"scheduleType": "cron", "cronExpression": "0 0 0 * * ?"}, "namespaceFormat": None, - "prefix": "stripe_", } send_request(connection_id_url, method="PATCH", payload=payload) diff --git a/posthog/warehouse/external_data_source/source.py b/posthog/warehouse/external_data_source/source.py index bacd99e812197..a00b67de99e23 100644 --- a/posthog/warehouse/external_data_source/source.py +++ b/posthog/warehouse/external_data_source/source.py @@ -1,43 +1,11 @@ -from posthog.models.utils import UUIDT -from pydantic import BaseModel, field_validator -from typing import Dict, Optional -import datetime as dt +from pydantic import BaseModel +from typing import Dict from posthog.warehouse.external_data_source.client import send_request +from posthog.warehouse.external_data_source.source_definitions import SOURCE_TYPE_MAPPING AIRBYTE_SOURCE_URL = "https://api.airbyte.com/v1/sources" -class StripeSourcePayload(BaseModel): - account_id: str - client_secret: str - start_date: Optional[dt.datetime] = None - lookback_window_days: Optional[int] = None - slice_range: Optional[int] = None - - @field_validator("account_id") - @classmethod - def account_id_is_valid_uuid(cls, v: str) -> str: - try: - UUIDT.is_valid_uuid(v) - except ValueError: - raise ValueError("account_id must be a valid UUID.") - return v - - @field_validator("start_date") - @classmethod - def valid_iso_start_date(cls, v: Optional[str]) -> Optional[str]: - from posthog.batch_exports.http import validate_date_input - - if not v: - return v - - try: - validate_date_input(v) - except ValueError: - raise ValueError("start_date must be a valid ISO date string.") - return v - - class ExternalDataSource(BaseModel): source_id: str name: str @@ -45,28 +13,22 @@ class ExternalDataSource(BaseModel): workspace_id: str -def create_stripe_source(payload: StripeSourcePayload, workspace_id: str) -> ExternalDataSource: - optional_config = {} - if payload.start_date: - optional_config["start_date"] = payload.start_date.isoformat() - - if payload.lookback_window_days: - optional_config["lookback_window_days"] = payload.lookback_window_days +def create_source(source_type: str, payload: Dict, workspace_id: str) -> ExternalDataSource: + try: + source_payload = SOURCE_TYPE_MAPPING[source_type]["payload_type"](**payload) + except Exception as e: + raise ValueError(f"Invalid payload for source type {source_type}: {e}") - if payload.slice_range: - optional_config["slice_range"] = payload.slice_range - - payload = { + request_payload = { "configuration": { - "sourceType": "stripe", - "account_id": payload.account_id, - "client_secret": payload.client_secret, - **optional_config, + "sourceType": source_type, + **source_payload.dict(), }, - "name": "stripe source", + "name": f"{source_type} source", "workspaceId": workspace_id, } - return _create_source(payload) + + return _create_source(request_payload) def _create_source(payload: Dict) -> ExternalDataSource: diff --git a/posthog/warehouse/external_data_source/source_definitions.py b/posthog/warehouse/external_data_source/source_definitions.py new file mode 100644 index 0000000000000..2e60126089905 --- /dev/null +++ b/posthog/warehouse/external_data_source/source_definitions.py @@ -0,0 +1,67 @@ +from posthog.models.utils import UUIDT +from pydantic import BaseModel, field_validator +from typing import Optional +import datetime as dt + + +class PostgresSourcePayload(BaseModel): + host: str + port: int + database: str + schemas: Optional[list[str]] = None + username: str + password: Optional[str] = None + + +class SalesforceSourcePayload(BaseModel): + client_id: str + client_secret: str + refresh_token: str + sourceType: str = "salesforce" + + +class StripeSourcePayload(BaseModel): + account_id: str + client_secret: str + start_date: Optional[dt.datetime] = None + lookback_window_days: Optional[int] = None + slice_range: Optional[int] = None + + @field_validator("account_id") + @classmethod + def account_id_is_valid_uuid(cls, v: str) -> str: + try: + UUIDT.is_valid_uuid(v) + except ValueError: + raise ValueError("account_id must be a valid UUID.") + return v + + @field_validator("start_date") + @classmethod + def valid_iso_start_date(cls, v: Optional[str]) -> Optional[str]: + from posthog.batch_exports.http import validate_date_input + + if not v: + return v + + try: + validate_date_input(v) + except ValueError: + raise ValueError("start_date must be a valid ISO date string.") + return v + + +SOURCE_TYPE_MAPPING = { + "stripe": { + "payload_type": StripeSourcePayload, + "default_streams": ["customers"], + }, + "salesforce": { + "payload_type": SalesforceSourcePayload, + "default_streams": ["accounts"], + }, + "postgres": { + "payload_type": PostgresSourcePayload, + "default_streams": None, + }, +} diff --git a/posthog/warehouse/external_data_source/test/__init__.py b/posthog/warehouse/external_data_source/test/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/posthog/warehouse/external_data_source/test/test_source.py b/posthog/warehouse/external_data_source/test/test_source.py new file mode 100644 index 0000000000000..ffa4d6e416e10 --- /dev/null +++ b/posthog/warehouse/external_data_source/test/test_source.py @@ -0,0 +1,69 @@ +from posthog.test.base import ( + APIBaseTest, +) +from unittest.mock import patch +from posthog.warehouse.external_data_source.source import create_source + + +class TestSource(APIBaseTest): + @patch("posthog.warehouse.external_data_source.source.send_request") + def test_create_stripe_source(self, send_request_mock): + send_request_mock.return_value = { + "sourceId": "123", + "name": "stripe source", + "sourceType": "stripe", + "workspaceId": "456", + } + + source_payload = { + "account_id": "some_account_id", + "client_secret": "some_secret", + } + + data_source = create_source("stripe", source_payload, "456") + + self.assertEqual(data_source.source_id, "123") + self.assertEqual(data_source.name, "stripe source") + self.assertEqual(data_source.source_type, "stripe") + self.assertEqual(data_source.workspace_id, "456") + + @patch("posthog.warehouse.external_data_source.source.send_request") + def test_create_salesforce_source(self, send_request_mock): + send_request_mock.return_value = { + "sourceId": "123", + "name": "salesforce source", + "sourceType": "salesforce", + "workspaceId": "456", + } + + source_payload = {"client_id": "some_account_id", "client_secret": "some_secret", "refresh_token": "some_token"} + + data_source = create_source("salesforce", source_payload, "456") + + self.assertEqual(data_source.source_id, "123") + self.assertEqual(data_source.name, "salesforce source") + self.assertEqual(data_source.source_type, "salesforce") + self.assertEqual(data_source.workspace_id, "456") + + @patch("posthog.warehouse.external_data_source.source.send_request") + def test_create_postgres_source(self, send_request_mock): + send_request_mock.return_value = { + "sourceId": "123", + "name": "postgres source", + "sourceType": "postgres", + "workspaceId": "456", + } + + source_payload = { + "host": "localhost", + "port": 5432, + "database": "test-db", + "username": "posthog", + } + + data_source = create_source("postgres", source_payload, "456") + + self.assertEqual(data_source.source_id, "123") + self.assertEqual(data_source.name, "postgres source") + self.assertEqual(data_source.source_type, "postgres") + self.assertEqual(data_source.workspace_id, "456")