
fix(data-warehouse): add schema id for validation step and use snake case (#20840)

* add schema id for validation step and use snake case

* fix tests

* more typing
EDsCODE authored Mar 12, 2024
1 parent 5f22a34 commit 989c065
Showing 7 changed files with 111 additions and 38 deletions.
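The change is, at its core, a type change that ripples through the pipeline: wherever schemas used to travel as a plain list of names (list[str]), they now travel as (schema_id, schema_name) pairs (list[Tuple[str, str]]), so the validation step can resolve tables and schema rows by ID instead of by reconstructed URL pattern, and names can be snake-cased consistently. A minimal sketch of the before/after shape (the IDs below are placeholders, not real values):

# Before this commit: schema names only
schemas_before = ["Customer", "Invoice"]

# After this commit: (ExternalDataSchema id, schema name) pairs
schemas_after = [
    ("schema-id-1", "Customer"),
    ("schema-id-2", "Invoice"),
]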
14 changes: 8 additions & 6 deletions posthog/temporal/data_imports/external_data_job.py
@@ -37,7 +37,7 @@ class CreateExternalDataJobInputs:


@activity.defn
async def create_external_data_job_model(inputs: CreateExternalDataJobInputs) -> Tuple[str, list[str]]:
async def create_external_data_job_model(inputs: CreateExternalDataJobInputs) -> Tuple[str, list[Tuple[str, str]]]:
run = await sync_to_async(create_external_data_job)(
team_id=inputs.team_id,
external_data_source_id=inputs.external_data_source_id,
@@ -105,7 +105,7 @@ async def update_external_data_job_model(inputs: UpdateExternalDataJobStatusInpu
class ValidateSchemaInputs:
run_id: str
team_id: int
schemas: list[str]
schemas: list[Tuple[str, str]]


@activity.defn
@@ -133,7 +133,7 @@ class ExternalDataJobInputs:
team_id: int
source_id: uuid.UUID
run_id: str
schemas: list[str]
schemas: list[Tuple[str, str]]


@activity.defn
Expand All @@ -153,6 +153,8 @@ async def run_external_data_job(inputs: ExternalDataJobInputs) -> None:
dataset_name=model.folder_path,
)

endpoints = [schema[1] for schema in inputs.schemas]

source = None
if model.pipeline.source_type == ExternalDataSource.Type.STRIPE:
from posthog.temporal.data_imports.pipelines.stripe.helpers import stripe_source
@@ -162,7 +164,7 @@ async def run_external_data_job(inputs: ExternalDataJobInputs) -> None:
raise ValueError(f"Stripe secret key not found for job {model.id}")
source = stripe_source(
api_key=stripe_secret_key,
endpoints=tuple(inputs.schemas),
endpoints=tuple(endpoints),
team_id=inputs.team_id,
job_id=inputs.run_id,
)
@@ -181,7 +183,7 @@ async def run_external_data_job(inputs: ExternalDataJobInputs) -> None:
source = hubspot(
api_key=hubspot_access_code,
refresh_token=refresh_token,
endpoints=tuple(inputs.schemas),
endpoints=tuple(endpoints),
)
elif model.pipeline.source_type == ExternalDataSource.Type.POSTGRES:
from posthog.temporal.data_imports.pipelines.postgres import postgres_source
@@ -201,7 +203,7 @@ async def run_external_data_job(inputs: ExternalDataJobInputs) -> None:
database=database,
sslmode="prefer" if settings.TEST or settings.DEBUG else "require",
schema=schema,
table_names=inputs.schemas,
table_names=endpoints,
)

else:
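Within run_external_data_job the only shape-sensitive line is the endpoint extraction: the source builders (stripe_source, hubspot, postgres_source) still take plain names, so the names are pulled out of the pairs up front. A small illustration with made-up values:

schemas = [("schema-id-1", "Customer"), ("schema-id-2", "Invoice")]  # shape of inputs.schemas

endpoints = [schema[1] for schema in schemas]  # same expression as in the hunk above
assert endpoints == ["Customer", "Invoice"]
# tuple(endpoints) is what stripe_source and hubspot receive; postgres_source takes the list as table_names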
2 changes: 1 addition & 1 deletion posthog/temporal/data_imports/pipelines/pipeline.py
@@ -16,7 +16,7 @@
class PipelineInputs:
source_id: UUID
run_id: str
schemas: list[str]
schemas: list[tuple[str, str]]
dataset_name: str
job_type: str
team_id: int
89 changes: 72 additions & 17 deletions posthog/temporal/tests/external_data/test_external_data_job.py
@@ -1,6 +1,6 @@
import uuid
from unittest import mock

from typing import Optional
import pytest
from asgiref.sync import sync_to_async
from django.test import override_settings
@@ -32,6 +32,7 @@
from posthog.temporal.data_imports.pipelines.schemas import (
PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING,
)
from posthog.models import Team
from posthog.temporal.data_imports.pipelines.pipeline import DataImportPipeline
from temporalio.testing import WorkflowEnvironment
from temporalio.common import RetryPolicy
@@ -118,6 +119,15 @@ async def postgres_connection(postgres_config, setup_postgres_test_db):
await connection.close()


async def _create_schema(schema_name: str, source: ExternalDataSource, team: Team, table_id: Optional[str] = None):
return await sync_to_async(ExternalDataSchema.objects.create)(
name=schema_name,
team_id=team.id,
source_id=source.pk,
table_id=table_id,
)


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_create_external_job_activity(activity_environment, team, **kwargs):
@@ -232,7 +242,9 @@ async def setup_job_1():

new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)()

schemas = ["Customer"]
customer_schema = await _create_schema("Customer", new_source, team)
schemas = [(customer_schema.id, "Customer")]

inputs = ExternalDataJobInputs(
team_id=team.id,
run_id=new_job.pk,
@@ -262,7 +274,9 @@ async def setup_job_2():

new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)()

schemas = ["Customer", "Invoice"]
customer_schema = await _create_schema("Customer", new_source, team)
invoice_schema = await _create_schema("Invoice", new_source, team)
schemas = [(customer_schema.id, "Customer"), (invoice_schema.id, "Invoice")]
inputs = ExternalDataJobInputs(
team_id=team.id,
run_id=new_job.pk,
@@ -350,7 +364,8 @@ async def setup_job_1():

new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)()

schemas = ["Customer"]
customer_schema = await _create_schema("Customer", new_source, team)
schemas = [(customer_schema.id, "Customer")]
inputs = ExternalDataJobInputs(
team_id=team.id,
run_id=new_job.pk,
@@ -414,7 +429,8 @@ async def setup_job_1():

new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)()

schemas = ["Customer"]
customer_schema = await _create_schema("Customer", new_source, team)
schemas = [(customer_schema.id, "Customer")]
inputs = ExternalDataJobInputs(
team_id=team.id,
run_id=new_job.pk,
@@ -476,15 +492,26 @@ async def test_validate_schema_and_update_table_activity(activity_environment, t
rows_synced=0,
)

test_1_schema = await _create_schema("test-1", new_source, team)
test_2_schema = await _create_schema("test-2", new_source, team)
test_3_schema = await _create_schema("test-3", new_source, team)
test_4_schema = await _create_schema("test-4", new_source, team)
test_5_schema = await _create_schema("test-5", new_source, team)
schemas = [
(test_1_schema.id, "test-1"),
(test_2_schema.id, "test-2"),
(test_3_schema.id, "test-3"),
(test_4_schema.id, "test-4"),
(test_5_schema.id, "test-5"),
]

with mock.patch(
"posthog.warehouse.models.table.DataWarehouseTable.get_columns"
) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS):
mock_get_columns.return_value = {"id": "string"}
await activity_environment.run(
validate_schema_activity,
ValidateSchemaInputs(
run_id=new_job.pk, team_id=team.id, schemas=["test-1", "test-2", "test-3", "test-4", "test-5"]
),
ValidateSchemaInputs(run_id=new_job.pk, team_id=team.id, schemas=schemas),
)

assert mock_get_columns.call_count == 10
@@ -504,6 +531,7 @@ async def test_validate_schema_and_update_table_activity_with_existing(activity_
status="running",
source_type="Stripe",
job_inputs={"stripe_secret_key": "test-key"},
prefix="stripe_",
)

old_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)(
@@ -521,7 +549,7 @@ async def test_validate_schema_and_update_table_activity_with_existing(activity_

url_pattern = await sync_to_async(old_job.url_pattern_by_schema)("test-1")

await sync_to_async(DataWarehouseTable.objects.create)(
existing_table = await sync_to_async(DataWarehouseTable.objects.create)(
credential=old_credential,
name="stripe_test-1",
format="Parquet",
@@ -537,15 +565,26 @@ async def test_validate_schema_and_update_table_activity_with_existing(activity_
rows_synced=0,
)

test_1_schema = await _create_schema("test-1", new_source, team, table_id=existing_table.id)
test_2_schema = await _create_schema("test-2", new_source, team)
test_3_schema = await _create_schema("test-3", new_source, team)
test_4_schema = await _create_schema("test-4", new_source, team)
test_5_schema = await _create_schema("test-5", new_source, team)
schemas = [
(test_1_schema.id, "test-1"),
(test_2_schema.id, "test-2"),
(test_3_schema.id, "test-3"),
(test_4_schema.id, "test-4"),
(test_5_schema.id, "test-5"),
]

with mock.patch(
"posthog.warehouse.models.table.DataWarehouseTable.get_columns"
) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS):
mock_get_columns.return_value = {"id": "string"}
await activity_environment.run(
validate_schema_activity,
ValidateSchemaInputs(
run_id=new_job.pk, team_id=team.id, schemas=["test-1", "test-2", "test-3", "test-4", "test-5"]
),
ValidateSchemaInputs(run_id=new_job.pk, team_id=team.id, schemas=schemas),
)

assert mock_get_columns.call_count == 10
@@ -595,9 +634,13 @@ async def test_validate_schema_and_update_table_activity_half_run(activity_envir
},
]

broken_schema = await _create_schema("broken_schema", new_source, team)
test_schema = await _create_schema("test_schema", new_source, team)
schemas = [(broken_schema.id, "broken_schema"), (test_schema.id, "test_schema")]

await activity_environment.run(
validate_schema_activity,
ValidateSchemaInputs(run_id=new_job.pk, team_id=team.id, schemas=["broken_schema", "test_schema"]),
ValidateSchemaInputs(run_id=new_job.pk, team_id=team.id, schemas=schemas),
)

assert mock_get_columns.call_count == 1
@@ -626,15 +669,26 @@ async def test_create_schema_activity(activity_environment, team, **kwargs):
rows_synced=0,
)

test_1_schema = await _create_schema("test-1", new_source, team)
test_2_schema = await _create_schema("test-2", new_source, team)
test_3_schema = await _create_schema("test-3", new_source, team)
test_4_schema = await _create_schema("test-4", new_source, team)
test_5_schema = await _create_schema("test-5", new_source, team)
schemas = [
(test_1_schema.id, "test-1"),
(test_2_schema.id, "test-2"),
(test_3_schema.id, "test-3"),
(test_4_schema.id, "test-4"),
(test_5_schema.id, "test-5"),
]

with mock.patch(
"posthog.warehouse.models.table.DataWarehouseTable.get_columns"
) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS):
mock_get_columns.return_value = {"id": "string"}
await activity_environment.run(
validate_schema_activity,
ValidateSchemaInputs(
run_id=new_job.pk, team_id=team.id, schemas=["test-1", "test-2", "test-3", "test-4", "test-5"]
),
ValidateSchemaInputs(run_id=new_job.pk, team_id=team.id, schemas=schemas),
)

assert mock_get_columns.call_count == 10
@@ -802,7 +856,8 @@ async def setup_job_1():

new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)()

schemas = ["posthog_test"]
posthog_test_schema = await _create_schema("posthog_test", new_source, team)
schemas = [(posthog_test_schema.id, "posthog_test")]
inputs = ExternalDataJobInputs(
team_id=team.id,
run_id=new_job.pk,
4 changes: 4 additions & 0 deletions posthog/utils.py
@@ -1329,3 +1329,7 @@ def label_for_team_id_to_track(team_id: int) -> str:
pass

return "unknown"


def camel_to_snake_case(name: str) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
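A quick usage sketch of the new helper; the expected outputs below are my reading of the regex rather than values taken from this PR's tests:

import re

def camel_to_snake_case(name: str) -> str:
    # Insert "_" before every capital letter that is not at the start, then lowercase.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

assert camel_to_snake_case("Customer") == "customer"
assert camel_to_snake_case("BalanceTransaction") == "balance_transaction"
assert camel_to_snake_case("already_snake") == "already_snake"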
27 changes: 14 additions & 13 deletions posthog/warehouse/data_load/validate_schema.py
@@ -3,20 +3,21 @@
from posthog.warehouse.models import (
get_latest_run_if_exists,
get_or_create_datawarehouse_credential,
get_table_by_url_pattern_and_source,
DataWarehouseTable,
DataWarehouseCredential,
aget_schema_if_exists,
get_external_data_job,
asave_datawarehousetable,
acreate_datawarehousetable,
asave_external_data_schema,
get_table_by_schema_id,
aget_schema_by_id,
)
from posthog.warehouse.models.external_data_job import ExternalDataJob
from posthog.temporal.common.logger import bind_temporal_worker_logger
from clickhouse_driver.errors import ServerException
from asgiref.sync import sync_to_async
from typing import Dict
from typing import Dict, Tuple
from posthog.utils import camel_to_snake_case


async def validate_schema(
@@ -42,7 +43,7 @@ async def validate_schema(
}


async def validate_schema_and_update_table(run_id: str, team_id: int, schemas: list[str]) -> None:
async def validate_schema_and_update_table(run_id: str, team_id: int, schemas: list[Tuple[str, str]]) -> None:
"""
Validates the schemas of data that has been synced by external data job.
@@ -65,9 +66,12 @@ async def validate_schema_and_update_table(run_id: str, team_id: int, schemas: l
access_secret=settings.AIRBYTE_BUCKET_SECRET,
)

for _schema_name in schemas:
for _schema in schemas:
_schema_id = _schema[0]
_schema_name = _schema[1]

table_name = f"{job.pipeline.prefix or ''}{job.pipeline.source_type}_{_schema_name}".lower()
new_url_pattern = job.url_pattern_by_schema(_schema_name)
new_url_pattern = job.url_pattern_by_schema(camel_to_snake_case(_schema_name))

# Check
try:
@@ -92,11 +96,10 @@ async def validate_schema_and_update_table(run_id: str, team_id: int, schemas: l
# create or update
table_created = None
if last_successful_job:
old_url_pattern = last_successful_job.url_pattern_by_schema(_schema_name)
try:
table_created = await get_table_by_url_pattern_and_source(
team_id=job.team_id, source_id=job.pipeline.id, url_pattern=old_url_pattern
)
table_created = await get_table_by_schema_id(_schema_id, team_id)
if not table_created:
raise DataWarehouseTable.DoesNotExist
except Exception:
table_created = None
else:
@@ -111,9 +114,7 @@ async def validate_schema_and_update_table(run_id: str, team_id: int, schemas: l
await asave_datawarehousetable(table_created)

# schema could have been deleted by this point
schema_model = await aget_schema_if_exists(
schema_name=_schema_name, team_id=job.team_id, source_id=job.pipeline.id
)
schema_model = await aget_schema_by_id(schema_id=_schema_id, team_id=job.team_id)

if schema_model:
schema_model.table = table_created
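Putting the pieces together, the validation loop now unpacks each pair, derives the table name from the raw schema name, derives the URL pattern from its snake-cased form, and uses the ID for lookups via get_table_by_schema_id and aget_schema_by_id. A reduced, runnable illustration of the naming part only, with prefix and source type chosen to mirror the Stripe tests in this PR:

from posthog.utils import camel_to_snake_case  # helper added by this commit

prefix, source_type = "stripe_", "Stripe"  # illustrative values
schemas = [("schema-id-1", "Customer"), ("schema-id-2", "BalanceTransaction")]

for schema_id, schema_name in schemas:
    table_name = f"{prefix}{source_type}_{schema_name}".lower()
    url_key = camel_to_snake_case(schema_name)
    # schema_id (not the URL pattern) now drives the table/schema lookups
    print(schema_id, table_name, url_key)
# schema-id-1 stripe_stripe_customer customer
# schema-id-2 stripe_stripe_balancetransaction balance_transaction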
7 changes: 6 additions & 1 deletion posthog/warehouse/models/external_data_schema.py
@@ -41,9 +41,14 @@ def aget_schema_if_exists(schema_name: str, team_id: int, source_id: uuid.UUID)
return get_schema_if_exists(schema_name=schema_name, team_id=team_id, source_id=source_id)


@database_sync_to_async
def aget_schema_by_id(schema_id: str, team_id: int) -> ExternalDataSchema | None:
return ExternalDataSchema.objects.get(id=schema_id, team_id=team_id)


def get_active_schemas_for_source_id(source_id: uuid.UUID, team_id: int):
schemas = ExternalDataSchema.objects.filter(team_id=team_id, source_id=source_id, should_sync=True).values().all()
return [val["name"] for val in schemas]
return [(val["id"], val["name"]) for val in schemas]


def get_all_schemas_for_source_id(source_id: uuid.UUID, team_id: int):
Expand Down
6 changes: 6 additions & 0 deletions posthog/warehouse/models/table.py
@@ -20,6 +20,7 @@
sane_repr,
)
from posthog.warehouse.models.util import remove_named_tuples
from posthog.warehouse.models.external_data_schema import ExternalDataSchema
from django.db.models import Q
from .credential import DataWarehouseCredential
from uuid import UUID
@@ -154,6 +155,11 @@ def get_table_by_url_pattern_and_source(url_pattern: str, source_id: UUID, team_
)


@database_sync_to_async
def get_table_by_schema_id(schema_id: str, team_id: int):
return ExternalDataSchema.objects.get(id=schema_id, team_id=team_id).table


@database_sync_to_async
def acreate_datawarehousetable(**kwargs):
return DataWarehouseTable.objects.create(**kwargs)
