{
diff --git a/frontend/src/scenes/data-warehouse/external/SourceModal.tsx b/frontend/src/scenes/data-warehouse/external/SourceModal.tsx
index 2308ec31b3140..69afb95960889 100644
--- a/frontend/src/scenes/data-warehouse/external/SourceModal.tsx
+++ b/frontend/src/scenes/data-warehouse/external/SourceModal.tsx
@@ -1,10 +1,9 @@
import { LemonButton, LemonDivider, LemonInput, LemonModal, LemonModalProps } from '@posthog/lemon-ui'
import { Form } from 'kea-forms'
-import { ConnectorConfigType, sourceModalLogic } from './sourceModalLogic'
+import { ConnectorConfigType, FORM_PAYLOAD_TYPES, FormPayloadType, sourceModalLogic } from './sourceModalLogic'
import { useActions, useValues } from 'kea'
import { DatawarehouseTableForm } from '../new_table/DataWarehouseTableForm'
import { Field } from 'lib/forms/Field'
-import stripeLogo from 'public/stripe-logo.svg'
interface SourceModalProps extends LemonModalProps {}
@@ -21,7 +20,7 @@ export default function SourceModal(props: SourceModalProps): JSX.Element {
return (
-
+
)
}
@@ -37,38 +36,40 @@ export default function SourceModal(props: SourceModalProps): JSX.Element {
toggleManualLinkFormVisible(true)
}
- const formToShow = (): JSX.Element => {
+ const formPayloadTypeToField = (formPayloadType: FormPayloadType): JSX.Element => {
+ return (
+
+
+
+ )
+ }
+
+ const buildPayloadTypeForm = (payloadType: string): JSX.Element => {
+ return (
+
+ )
+ }
+
+ const formToShow = (selectedConnector: ConnectorConfigType): JSX.Element => {
if (selectedConnector) {
- return (
-
- )
+ return buildPayloadTypeForm(selectedConnector.name)
}
if (isManualLinkFormVisible) {
@@ -104,16 +105,7 @@ export default function SourceModal(props: SourceModalProps): JSX.Element {
)
}
- return (
-
- {connectors.map((config, index) => (
-
- ))}
-
- Manual Link
-
-
- )
+ return <>>
}
return (
@@ -123,7 +115,18 @@ export default function SourceModal(props: SourceModalProps): JSX.Element {
title="Data Sources"
description={selectedConnector ? selectedConnector.caption : null}
>
- {formToShow()}
+ {selectedConnector ? (
+ formToShow(selectedConnector)
+ ) : (
+
+ {connectors.map((config, index) => (
+
+ ))}
+
+ Manual Link
+
+
+ )}
)
}
diff --git a/frontend/src/scenes/data-warehouse/external/sourceModalLogic.ts b/frontend/src/scenes/data-warehouse/external/sourceModalLogic.ts
index 4aa757303d708..4051c61d77ff4 100644
--- a/frontend/src/scenes/data-warehouse/external/sourceModalLogic.ts
+++ b/frontend/src/scenes/data-warehouse/external/sourceModalLogic.ts
@@ -10,24 +10,82 @@ import { dataWarehouseSceneLogic } from './dataWarehouseSceneLogic'
import { router } from 'kea-router'
import { urls } from 'scenes/urls'
import { dataWarehouseSettingsLogic } from '../settings/dataWarehouseSettingsLogic'
+import stripeLogo from 'public/stripe-logo.svg'
+import postgresLogo from 'public/postgres-logo.svg'
export interface ConnectorConfigType {
name: string
- fields: string[]
caption: string
disabledReason: string | null
+ icon: string
}
// TODO: add icon
export const CONNECTORS: ConnectorConfigType[] = [
{
- name: 'Stripe',
- fields: ['accound_id', 'client_secret'],
+ name: 'stripe',
caption: 'Enter your Stripe credentials to link your Stripe to PostHog',
disabledReason: null,
+ icon: stripeLogo,
+ },
+ {
+ name: 'postgres',
+ caption: 'Enter your Postgres credentials to link your Postgres database to PostHog',
+ disabledReason: null,
+ icon: postgresLogo,
},
]
+type FormTypes = 'input' | 'select'
+
+export interface FormPayloadType {
+ name: string
+ type: FormTypes
+ label: string
+}
+
+export const FORM_PAYLOAD_TYPES: Record = {
+ stripe: [
+ {
+ name: 'account_id',
+ type: 'input',
+ label: 'Account Id',
+ },
+ {
+ name: 'client_secret',
+ type: 'input',
+ label: 'Client Secret',
+ },
+ ],
+ postgres: [
+ {
+ name: 'host',
+ type: 'input',
+ label: 'Host',
+ },
+ {
+ name: 'port',
+ type: 'input',
+ label: 'Port',
+ },
+ {
+ name: 'database',
+ type: 'input',
+ label: 'Database',
+ },
+ {
+ name: 'username',
+ type: 'input',
+ label: 'Username',
+ },
+ {
+ name: 'password',
+ type: 'input',
+ label: 'Password',
+ },
+ ],
+}
+
export const sourceModalLogic = kea([
path(['scenes', 'data-warehouse', 'external', 'sourceModalLogic']),
actions({
@@ -79,16 +137,19 @@ export const sourceModalLogic = kea([
}),
forms(() => ({
externalDataSource: {
- defaults: { account_id: '', client_secret: '' } as ExternalDataSourceCreatePayload,
+ defaults: { account_id: '', client_secret: '' },
errors: ({ account_id, client_secret }) => {
return {
account_id: !account_id && 'Please enter an account id.',
client_secret: !client_secret && 'Please enter a client secret.',
}
},
- submit: async (payload: ExternalDataSourceCreatePayload) => {
- const newResource = await api.externalDataSources.create(payload)
- return newResource
+ submit: async (payload) => {
+ await api.externalDataSources.create({
+ payload,
+ payload_type: 'stripe',
+ } as ExternalDataSourceCreatePayload)
+ return payload
},
},
})),
diff --git a/frontend/src/types.ts b/frontend/src/types.ts
index 2de4b3c837364..8d0271a3b7519 100644
--- a/frontend/src/types.ts
+++ b/frontend/src/types.ts
@@ -3232,10 +3232,9 @@ export interface DataWarehouseViewLink {
}
export interface ExternalDataSourceCreatePayload {
- account_id: string
- client_secret: string
+ payload_type: string
+ payload: Record
}
-
export interface ExternalDataSource {
id: string
source_id: string
diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py
index a9737ed2a87ea..6c14ec0dc2644 100644
--- a/posthog/warehouse/api/external_data_source.py
+++ b/posthog/warehouse/api/external_data_source.py
@@ -7,7 +7,8 @@
from rest_framework import filters, serializers, viewsets
from posthog.warehouse.models import ExternalDataSource
from posthog.warehouse.external_data_source.workspace import get_or_create_workspace
-from posthog.warehouse.external_data_source.source import StripeSourcePayload, create_stripe_source, delete_source
+from posthog.warehouse.external_data_source.source import create_source, delete_source
+from posthog.warehouse.external_data_source.source_definitions import SOURCE_TYPE_MAPPING
from posthog.warehouse.external_data_source.connection import (
create_connection,
start_sync,
@@ -83,16 +84,17 @@ def get_queryset(self):
return self.queryset.filter(team_id=self.team_id).prefetch_related("created_by").order_by(self.ordering)
def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
- account_id = request.data["account_id"]
- client_secret = request.data["client_secret"]
+ payload = request.data["payload"]
+ payload_type = request.data["payload_type"]
- workspace_id = get_or_create_workspace(self.team_id)
+ if payload_type not in SOURCE_TYPE_MAPPING.keys():
+ return Response(
+ status=status.HTTP_400_BAD_REQUEST,
+ data={"detail": f"Payload type {payload_type} is not supported."},
+ )
- stripe_payload = StripeSourcePayload(
- account_id=account_id,
- client_secret=client_secret,
- )
- new_source = create_stripe_source(stripe_payload, workspace_id)
+ workspace_id = get_or_create_workspace(self.team_id)
+ new_source = create_source(payload_type, payload, workspace_id)
try:
new_destination = create_destination(self.team_id, workspace_id)
@@ -101,7 +103,7 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
raise e
try:
- new_connection = create_connection(new_source.source_id, new_destination.destination_id)
+ new_connection = create_connection(payload_type, new_source.source_id, new_destination.destination_id)
except Exception as e:
delete_source(new_source.source_id)
delete_destination(new_destination.destination_id)
@@ -113,7 +115,7 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
destination_id=new_destination.destination_id,
team=self.team,
status="running",
- source_type="Stripe",
+ source_type=payload_type,
)
start_sync(new_connection.connection_id)
diff --git a/posthog/warehouse/external_data_source/connection.py b/posthog/warehouse/external_data_source/connection.py
index c080907b3cc65..7095c4b3a2d8f 100644
--- a/posthog/warehouse/external_data_source/connection.py
+++ b/posthog/warehouse/external_data_source/connection.py
@@ -1,6 +1,7 @@
from pydantic import BaseModel
from posthog.warehouse.external_data_source.client import send_request
from posthog.warehouse.models import ExternalDataSource
+from posthog.warehouse.external_data_source.source_definitions import SOURCE_TYPE_MAPPING
import structlog
from typing import List
@@ -19,17 +20,24 @@ class ExternalDataConnection(BaseModel):
workspace_id: str
-def create_connection(source_id: str, destination_id: str) -> ExternalDataConnection:
+def create_connection(source_type: str, source_id: str, destination_id: str) -> ExternalDataConnection:
+ default_streams_by_type = SOURCE_TYPE_MAPPING[source_type]["default_streams"]
payload = {
"schedule": {"scheduleType": "cron", "cronExpression": "0 0 0 * * ?"},
"namespaceFormat": None,
"sourceId": source_id,
"destinationId": destination_id,
+ "prefix": f"{source_type}_",
}
- response = send_request(AIRBYTE_CONNECTION_URL, method="POST", payload=payload)
+ if default_streams_by_type:
+ payload["configurations"] = {
+ "streams": [
+ {"name": streamName, "syncMode": "full_refresh_overwrite"} for streamName in default_streams_by_type
+ ]
+ }
- update_connection_stream(response["connectionId"], ["customers"])
+ response = send_request(AIRBYTE_CONNECTION_URL, method="POST", payload=payload)
return ExternalDataConnection(
source_id=response["sourceId"],
@@ -70,14 +78,12 @@ def update_connection_status_by_id(connection_id: str, status: str):
def update_connection_stream(connection_id: str, streams: List):
connection_id_url = f"{AIRBYTE_CONNECTION_URL}/{connection_id}"
- # TODO: hardcoded to stripe stream right now
payload = {
"configurations": {
"streams": [{"name": streamName, "syncMode": "full_refresh_overwrite"} for streamName in streams]
},
"schedule": {"scheduleType": "cron", "cronExpression": "0 0 0 * * ?"},
"namespaceFormat": None,
- "prefix": "stripe_",
}
send_request(connection_id_url, method="PATCH", payload=payload)
diff --git a/posthog/warehouse/external_data_source/source.py b/posthog/warehouse/external_data_source/source.py
index bacd99e812197..a00b67de99e23 100644
--- a/posthog/warehouse/external_data_source/source.py
+++ b/posthog/warehouse/external_data_source/source.py
@@ -1,43 +1,11 @@
-from posthog.models.utils import UUIDT
-from pydantic import BaseModel, field_validator
-from typing import Dict, Optional
-import datetime as dt
+from pydantic import BaseModel
+from typing import Dict
from posthog.warehouse.external_data_source.client import send_request
+from posthog.warehouse.external_data_source.source_definitions import SOURCE_TYPE_MAPPING
AIRBYTE_SOURCE_URL = "https://api.airbyte.com/v1/sources"
-class StripeSourcePayload(BaseModel):
- account_id: str
- client_secret: str
- start_date: Optional[dt.datetime] = None
- lookback_window_days: Optional[int] = None
- slice_range: Optional[int] = None
-
- @field_validator("account_id")
- @classmethod
- def account_id_is_valid_uuid(cls, v: str) -> str:
- try:
- UUIDT.is_valid_uuid(v)
- except ValueError:
- raise ValueError("account_id must be a valid UUID.")
- return v
-
- @field_validator("start_date")
- @classmethod
- def valid_iso_start_date(cls, v: Optional[str]) -> Optional[str]:
- from posthog.batch_exports.http import validate_date_input
-
- if not v:
- return v
-
- try:
- validate_date_input(v)
- except ValueError:
- raise ValueError("start_date must be a valid ISO date string.")
- return v
-
-
class ExternalDataSource(BaseModel):
source_id: str
name: str
@@ -45,28 +13,22 @@ class ExternalDataSource(BaseModel):
workspace_id: str
-def create_stripe_source(payload: StripeSourcePayload, workspace_id: str) -> ExternalDataSource:
- optional_config = {}
- if payload.start_date:
- optional_config["start_date"] = payload.start_date.isoformat()
-
- if payload.lookback_window_days:
- optional_config["lookback_window_days"] = payload.lookback_window_days
+def create_source(source_type: str, payload: Dict, workspace_id: str) -> ExternalDataSource:
+ try:
+ source_payload = SOURCE_TYPE_MAPPING[source_type]["payload_type"](**payload)
+ except Exception as e:
+ raise ValueError(f"Invalid payload for source type {source_type}: {e}")
- if payload.slice_range:
- optional_config["slice_range"] = payload.slice_range
-
- payload = {
+ request_payload = {
"configuration": {
- "sourceType": "stripe",
- "account_id": payload.account_id,
- "client_secret": payload.client_secret,
- **optional_config,
+ "sourceType": source_type,
+ **source_payload.dict(),
},
- "name": "stripe source",
+ "name": f"{source_type} source",
"workspaceId": workspace_id,
}
- return _create_source(payload)
+
+ return _create_source(request_payload)
def _create_source(payload: Dict) -> ExternalDataSource:
diff --git a/posthog/warehouse/external_data_source/source_definitions.py b/posthog/warehouse/external_data_source/source_definitions.py
new file mode 100644
index 0000000000000..2e60126089905
--- /dev/null
+++ b/posthog/warehouse/external_data_source/source_definitions.py
@@ -0,0 +1,67 @@
+from posthog.models.utils import UUIDT
+from pydantic import BaseModel, field_validator
+from typing import Optional
+import datetime as dt
+
+
+class PostgresSourcePayload(BaseModel):
+ host: str
+ port: int
+ database: str
+ schemas: Optional[list[str]] = None
+ username: str
+ password: Optional[str] = None
+
+
+class SalesforceSourcePayload(BaseModel):
+ client_id: str
+ client_secret: str
+ refresh_token: str
+ sourceType: str = "salesforce"
+
+
+class StripeSourcePayload(BaseModel):
+ account_id: str
+ client_secret: str
+ start_date: Optional[dt.datetime] = None
+ lookback_window_days: Optional[int] = None
+ slice_range: Optional[int] = None
+
+ @field_validator("account_id")
+ @classmethod
+ def account_id_is_valid_uuid(cls, v: str) -> str:
+ try:
+ UUIDT.is_valid_uuid(v)
+ except ValueError:
+ raise ValueError("account_id must be a valid UUID.")
+ return v
+
+ @field_validator("start_date")
+ @classmethod
+ def valid_iso_start_date(cls, v: Optional[str]) -> Optional[str]:
+ from posthog.batch_exports.http import validate_date_input
+
+ if not v:
+ return v
+
+ try:
+ validate_date_input(v)
+ except ValueError:
+ raise ValueError("start_date must be a valid ISO date string.")
+ return v
+
+
+SOURCE_TYPE_MAPPING = {
+ "stripe": {
+ "payload_type": StripeSourcePayload,
+ "default_streams": ["customers"],
+ },
+ "salesforce": {
+ "payload_type": SalesforceSourcePayload,
+ "default_streams": ["accounts"],
+ },
+ "postgres": {
+ "payload_type": PostgresSourcePayload,
+ "default_streams": None,
+ },
+}
diff --git a/posthog/warehouse/external_data_source/test/__init__.py b/posthog/warehouse/external_data_source/test/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/posthog/warehouse/external_data_source/test/test_source.py b/posthog/warehouse/external_data_source/test/test_source.py
new file mode 100644
index 0000000000000..ffa4d6e416e10
--- /dev/null
+++ b/posthog/warehouse/external_data_source/test/test_source.py
@@ -0,0 +1,69 @@
+from posthog.test.base import (
+ APIBaseTest,
+)
+from unittest.mock import patch
+from posthog.warehouse.external_data_source.source import create_source
+
+
+class TestSource(APIBaseTest):
+ @patch("posthog.warehouse.external_data_source.source.send_request")
+ def test_create_stripe_source(self, send_request_mock):
+ send_request_mock.return_value = {
+ "sourceId": "123",
+ "name": "stripe source",
+ "sourceType": "stripe",
+ "workspaceId": "456",
+ }
+
+ source_payload = {
+ "account_id": "some_account_id",
+ "client_secret": "some_secret",
+ }
+
+ data_source = create_source("stripe", source_payload, "456")
+
+ self.assertEqual(data_source.source_id, "123")
+ self.assertEqual(data_source.name, "stripe source")
+ self.assertEqual(data_source.source_type, "stripe")
+ self.assertEqual(data_source.workspace_id, "456")
+
+ @patch("posthog.warehouse.external_data_source.source.send_request")
+ def test_create_salesforce_source(self, send_request_mock):
+ send_request_mock.return_value = {
+ "sourceId": "123",
+ "name": "salesforce source",
+ "sourceType": "salesforce",
+ "workspaceId": "456",
+ }
+
+ source_payload = {"client_id": "some_account_id", "client_secret": "some_secret", "refresh_token": "some_token"}
+
+ data_source = create_source("salesforce", source_payload, "456")
+
+ self.assertEqual(data_source.source_id, "123")
+ self.assertEqual(data_source.name, "salesforce source")
+ self.assertEqual(data_source.source_type, "salesforce")
+ self.assertEqual(data_source.workspace_id, "456")
+
+ @patch("posthog.warehouse.external_data_source.source.send_request")
+ def test_create_postgres_source(self, send_request_mock):
+ send_request_mock.return_value = {
+ "sourceId": "123",
+ "name": "postgres source",
+ "sourceType": "postgres",
+ "workspaceId": "456",
+ }
+
+ source_payload = {
+ "host": "localhost",
+ "port": 5432,
+ "database": "test-db",
+ "username": "posthog",
+ }
+
+ data_source = create_source("postgres", source_payload, "456")
+
+ self.assertEqual(data_source.source_id, "123")
+ self.assertEqual(data_source.name, "postgres source")
+ self.assertEqual(data_source.source_type, "postgres")
+ self.assertEqual(data_source.workspace_id, "456")