Fixes for 100% pass rate in end_to_end tests
Gilbert09 committed Dec 3, 2024
1 parent 27b47e7 commit 8e6ab69
Showing 8 changed files with 133 additions and 73 deletions.
10 changes: 8 additions & 2 deletions posthog/hogql/database/s3_table.py
@@ -45,9 +45,15 @@ def return_expr(expr: str) -> str:
query_folder = "__query_v2" if pipeline_version == ExternalDataJob.PipelineVersion.V2 else "__query"

if url.endswith("/"):
escaped_url = add_param(f"{url[:len(url) - 1]}{query_folder}/*.parquet")
if pipeline_version == ExternalDataJob.PipelineVersion.V2:
escaped_url = add_param(f"{url[:-5]}{query_folder}/*.parquet")
else:
escaped_url = add_param(f"{url[:-1]}{query_folder}/*.parquet")
else:
escaped_url = add_param(f"{url}{query_folder}/*.parquet")
if pipeline_version == ExternalDataJob.PipelineVersion.V2:
escaped_url = add_param(f"{url[:-4]}{query_folder}/*.parquet")
else:
escaped_url = add_param(f"{url}{query_folder}/*.parquet")

if structure:
escaped_structure = add_param(structure, False)
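For context, a rough standalone sketch of what the branching above produces. It assumes, as the warehouse model change later in this commit suggests, that V2 folder URLs end in a "__v2" suffix that has to be trimmed before the query folder is appended; the helper name and example paths are made up, not taken from the codebase.

# Illustrative sketch only - not PostHog's actual helper.
def query_url(url: str, is_v2: bool) -> str:
    query_folder = "__query_v2" if is_v2 else "__query"
    if url.endswith("/"):
        # Assumed: V2 folder URLs end with "__v2/", so five characters are trimmed;
        # V1 URLs only need the trailing slash removed.
        base = url[:-5] if is_v2 else url[:-1]
    else:
        base = url[:-4] if is_v2 else url
    return f"{base}{query_folder}/*.parquet"

print(query_url("s3://bucket/job/customer/", is_v2=False))
# s3://bucket/job/customer__query/*.parquet
print(query_url("s3://bucket/job/customer__v2/", is_v2=True))
# s3://bucket/job/customer__query_v2/*.parquet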
3 changes: 2 additions & 1 deletion posthog/temporal/data_imports/external_data_job.py
@@ -13,6 +13,7 @@
from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE_V2

# TODO: remove dependency
from posthog.settings.base_variables import TEST
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.common.client import sync_connect
from posthog.temporal.data_imports.workflow_activities.check_billing_limits import (
@@ -183,7 +184,7 @@ def parse_inputs(inputs: list[str]) -> ExternalDataWorkflowInputs:
async def run(self, inputs: ExternalDataWorkflowInputs):
assert inputs.external_data_schema_id is not None

if settings.TEMPORAL_TASK_QUEUE != DATA_WAREHOUSE_TASK_QUEUE_V2:
if settings.TEMPORAL_TASK_QUEUE != DATA_WAREHOUSE_TASK_QUEUE_V2 and not TEST:
await workflow.execute_activity(
trigger_pipeline_v2,
inputs,
@@ -25,6 +25,7 @@ def _get_credentials(self):
"aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
"endpoint_url": settings.OBJECT_STORAGE_ENDPOINT,
"region_name": settings.AIRBYTE_BUCKET_REGION,
"AWS_DEFAULT_REGION": settings.AIRBYTE_BUCKET_REGION,
"AWS_ALLOW_HTTP": "true",
"AWS_S3_ALLOW_UNSAFE_RENAME": "true",
}
@@ -95,7 +96,10 @@ def write_to_deltalake(
schema_mode = "overwrite"

if delta_table is None:
delta_table = deltalake.DeltaTable.create(table_uri=self._get_delta_table_uri(), schema=data.schema)
storage_options = self._get_credentials()
delta_table = deltalake.DeltaTable.create(
table_uri=self._get_delta_table_uri(), schema=data.schema, storage_options=storage_options
)

deltalake.write_deltalake(
table_or_uri=delta_table,
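The change above passes the bucket credentials from _get_credentials() into DeltaTable.create as storage_options, since creating the table writes to object storage directly. A rough sketch of the pattern, with placeholder credentials and bucket names instead of the real Django settings:

import deltalake
import pyarrow as pa

# Placeholder values; in the real helper these come from Django settings.
storage_options = {
    "aws_access_key_id": "object-storage-key",
    "aws_secret_access_key": "object-storage-secret",
    "endpoint_url": "http://localhost:19000",
    "region_name": "us-east-1",
    "AWS_ALLOW_HTTP": "true",
    "AWS_S3_ALLOW_UNSAFE_RENAME": "true",
}

schema = pa.schema([("id", pa.int64()), ("value", pa.string())])

# Without storage_options, create() falls back to default credential discovery
# and fails against a MinIO/S3 bucket that needs explicit keys.
delta_table = deltalake.DeltaTable.create(
    table_uri="s3://example-bucket/example_table",
    schema=schema,
    storage_options=storage_options,
)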
5 changes: 5 additions & 0 deletions posthog/temporal/data_imports/pipelines/pipeline_sync.py
@@ -460,6 +460,11 @@ def validate_schema_and_update_table_sync(
using_v2_pipeline = job.pipeline_version == ExternalDataJob.PipelineVersion.V2
pipeline_version = ExternalDataJob.PipelineVersion(job.pipeline_version)

# Temp so we dont create a bunch of orphaned Table objects
if using_v2_pipeline:
logger.debug("Using V2 pipeline - dont create table object or get columns")
return

credential = get_or_create_datawarehouse_credential(
team_id=team_id,
access_key=settings.AIRBYTE_BUCKET_KEY,
162 changes: 97 additions & 65 deletions posthog/temporal/tests/data_imports/test_end_to_end.py
@@ -19,7 +19,7 @@
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import UnsandboxedWorkflowRunner, Worker

from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE
from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE, DATA_WAREHOUSE_TASK_QUEUE_V2
from posthog.hogql.modifiers import create_default_modifiers_for_team
from posthog.hogql.query import execute_hogql_query
from posthog.hogql_queries.insights.funnels.funnel import Funnel
@@ -99,6 +99,19 @@ async def minio_client():
yield minio_client


def pytest_generate_tests(metafunc):
if "task_queue" in metafunc.fixturenames:
metafunc.parametrize("task_queue", [DATA_WAREHOUSE_TASK_QUEUE, DATA_WAREHOUSE_TASK_QUEUE_V2], indirect=True)


@pytest.fixture(autouse=True)
def task_queue(request):
queue = getattr(request, "param", None)

with override_settings(TEMPORAL_TASK_QUEUE=queue):
yield
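The two additions above are the standard indirect-parametrization pattern: pytest_generate_tests supplies the queue names, and because the autouse fixture receives each value through request.param, every test in the module runs once per task queue. A minimal self-contained sketch of the same mechanism, with made-up names:

import pytest

def pytest_generate_tests(metafunc):
    # Parametrize any test that (directly or via autouse) uses the "mode" fixture.
    if "mode" in metafunc.fixturenames:
        metafunc.parametrize("mode", ["v1", "v2"], indirect=True)

@pytest.fixture(autouse=True)
def mode(request):
    # request.param is only present when the fixture was parametrized indirectly.
    yield getattr(request, "param", None)

def test_runs_once_per_mode(mode):
    assert mode in ("v1", "v2")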


async def _run(
team: Team,
schema_name: str,
@@ -142,18 +155,23 @@ async def _run(
assert run.status == ExternalDataJob.Status.COMPLETED

await sync_to_async(schema.refresh_from_db)()
assert schema.last_synced_at == run.created_at

res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM {table_name}", team)
assert len(res.results) == 1
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
assert schema.last_synced_at == run.created_at
else:
assert schema.last_synced_at is None

for name, field in external_tables.get(table_name, {}).items():
if field.hidden:
continue
assert name in (res.columns or [])
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM {table_name}", team)
assert len(res.results) == 1

await sync_to_async(source.refresh_from_db)()
assert source.job_inputs.get("reset_pipeline", None) is None
for name, field in external_tables.get(table_name, {}).items():
if field.hidden:
continue
assert name in (res.columns or [])

await sync_to_async(source.refresh_from_db)()
assert source.job_inputs.get("reset_pipeline", None) is None

return workflow_id, inputs

@@ -203,11 +221,12 @@ def mock_to_object_store_rs_credentials(class_self):
),
mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials),
mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials),
mock.patch("posthog.temporal.data_imports.external_data_job.trigger_pipeline_v2"),
):
async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
activity_environment.client,
task_queue=DATA_WAREHOUSE_TASK_QUEUE,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[ExternalDataJobWorkflow],
activities=ACTIVITIES, # type: ignore
workflow_runner=UnsandboxedWorkflowRunner(),
@@ -218,7 +237,7 @@
ExternalDataJobWorkflow.run,
inputs,
id=workflow_id,
task_queue=DATA_WAREHOUSE_TASK_QUEUE,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
)

@@ -525,12 +544,13 @@ async def test_postgres_binary_columns(team, postgres_config, postgres_connectio
mock_data_response=[],
)

res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM postgres_binary_col_test", team)
columns = res.columns
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM postgres_binary_col_test", team)
columns = res.columns

assert columns is not None
assert len(columns) == 1
assert columns[0] == "id"
assert columns is not None
assert len(columns) == 1
assert columns[0] == "id"


@pytest.mark.django_db(transaction=True)
@@ -558,9 +578,14 @@ def get_jobs():
latest_job = jobs[0]
folder_path = await sync_to_async(latest_job.folder_path)()

s3_objects = await minio_client.list_objects_v2(
Bucket=BUCKET_NAME, Prefix=f"{folder_path}/balance_transaction__query/"
)
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
s3_objects = await minio_client.list_objects_v2(
Bucket=BUCKET_NAME, Prefix=f"{folder_path}/balance_transaction__query/"
)
else:
s3_objects = await minio_client.list_objects_v2(
Bucket=BUCKET_NAME, Prefix=f"{folder_path}/balance_transaction__query_v2/"
)

assert len(s3_objects["Contents"]) != 0

@@ -587,23 +612,24 @@ async def test_funnels_lazy_joins_ordering(team, stripe_customer):
field_name="stripe_customer",
)

query = FunnelsQuery(
series=[EventsNode(), EventsNode()],
breakdownFilter=BreakdownFilter(
breakdown_type=BreakdownType.DATA_WAREHOUSE_PERSON_PROPERTY, breakdown="stripe_customer.email"
),
)
funnel_class = Funnel(context=FunnelQueryContext(query=query, team=team))

query_ast = funnel_class.get_query()
await sync_to_async(execute_hogql_query)(
query_type="FunnelsQuery",
query=query_ast,
team=team,
modifiers=create_default_modifiers_for_team(
team, HogQLQueryModifiers(personsOnEventsMode=PersonsOnEventsMode.PERSON_ID_OVERRIDE_PROPERTIES_JOINED)
),
)
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
query = FunnelsQuery(
series=[EventsNode(), EventsNode()],
breakdownFilter=BreakdownFilter(
breakdown_type=BreakdownType.DATA_WAREHOUSE_PERSON_PROPERTY, breakdown="stripe_customer.email"
),
)
funnel_class = Funnel(context=FunnelQueryContext(query=query, team=team))

query_ast = funnel_class.get_query()
await sync_to_async(execute_hogql_query)(
query_type="FunnelsQuery",
query=query_ast,
team=team,
modifiers=create_default_modifiers_for_team(
team, HogQLQueryModifiers(personsOnEventsMode=PersonsOnEventsMode.PERSON_ID_OVERRIDE_PROPERTIES_JOINED)
),
)


@pytest.mark.django_db(transaction=True)
@@ -636,12 +662,13 @@ async def test_postgres_schema_evolution(team, postgres_config, postgres_connect
sync_type_config={"incremental_field": "id", "incremental_field_type": "integer"},
)

res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team)
columns = res.columns
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team)
columns = res.columns

assert columns is not None
assert len(columns) == 1
assert any(x == "id" for x in columns)
assert columns is not None
assert len(columns) == 1
assert any(x == "id" for x in columns)

# Evole schema
await postgres_connection.execute(
@@ -655,18 +682,20 @@
# Execute the same schema again - load
await _execute_run(str(uuid.uuid4()), inputs, [])

res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team)
columns = res.columns
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team)
columns = res.columns

assert columns is not None
assert len(columns) == 2
assert any(x == "id" for x in columns)
assert any(x == "new_col" for x in columns)
assert columns is not None
assert len(columns) == 2
assert any(x == "id" for x in columns)
assert any(x == "new_col" for x in columns)


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_sql_database_missing_incremental_values(team, postgres_config, postgres_connection):
await postgres_connection.execute("CREATE SCHEMA IF NOT EXISTS {schema}".format(schema=postgres_config["schema"]))
await postgres_connection.execute(
"CREATE TABLE IF NOT EXISTS {schema}.test_table (id integer)".format(schema=postgres_config["schema"])
)
@@ -697,15 +726,16 @@ async def test_sql_database_missing_incremental_values(team, postgres_config, po
sync_type_config={"incremental_field": "id", "incremental_field_type": "integer"},
)

res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team)
columns = res.columns
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team)
columns = res.columns

assert columns is not None
assert len(columns) == 1
assert any(x == "id" for x in columns)
assert columns is not None
assert len(columns) == 1
assert any(x == "id" for x in columns)

# Exclude rows that don't have the incremental cursor key set
assert len(res.results) == 1
# Exclude rows that don't have the incremental cursor key set
assert len(res.results) == 1


@pytest.mark.django_db(transaction=True)
@@ -739,15 +769,16 @@ async def test_sql_database_incremental_initial_value(team, postgres_config, pos
sync_type_config={"incremental_field": "id", "incremental_field_type": "integer"},
)

res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team)
columns = res.columns
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team)
columns = res.columns

assert columns is not None
assert len(columns) == 1
assert any(x == "id" for x in columns)
assert columns is not None
assert len(columns) == 1
assert any(x == "id" for x in columns)

# Include rows that have the same incremental value as the `initial_value`
assert len(res.results) == 1
# Include rows that have the same incremental value as the `initial_value`
assert len(res.results) == 1


@pytest.mark.django_db(transaction=True)
@@ -1007,7 +1038,8 @@ async def test_delta_table_deleted(team, stripe_balance_transaction):
sync_type=ExternalDataSchema.SyncType.FULL_REFRESH,
)

with mock.patch.object(DeltaTable, "delete") as mock_delta_table_delete:
await _execute_run(str(uuid.uuid4()), inputs, stripe_balance_transaction["data"])
if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE:
with mock.patch.object(DeltaTable, "delete") as mock_delta_table_delete:
await _execute_run(str(uuid.uuid4()), inputs, stripe_balance_transaction["data"])

mock_delta_table_delete.assert_called_once()
mock_delta_table_delete.assert_called_once()
12 changes: 10 additions & 2 deletions posthog/warehouse/models/external_data_job.py
@@ -41,9 +41,17 @@ def folder_path(self) -> str:

def url_pattern_by_schema(self, schema: str) -> str:
if TEST:
return f"http://{settings.AIRBYTE_BUCKET_DOMAIN}/{settings.BUCKET}/{self.folder_path()}/{schema.lower()}/"
if self.pipeline_version == ExternalDataJob.PipelineVersion.V1:
return (
f"http://{settings.AIRBYTE_BUCKET_DOMAIN}/{settings.BUCKET}/{self.folder_path()}/{schema.lower()}/"
)
else:
return f"http://{settings.AIRBYTE_BUCKET_DOMAIN}/{settings.BUCKET}/{self.folder_path()}/{schema.lower()}__v2/"

return f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{self.folder_path()}/{schema.lower()}/"
if self.pipeline_version == ExternalDataJob.PipelineVersion.V1:
return f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{self.folder_path()}/{schema.lower()}/"

return f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{self.folder_path()}/{schema.lower()}__v2/"


@database_sync_to_async
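In practice the V1 and V2 branches above differ only in the trailing folder name. A hypothetical stand-in showing the shape of the returned URLs (domain, bucket and folder path are invented placeholders, not real settings):

def url_pattern(schema: str, is_v2: bool, is_test: bool) -> str:
    # Placeholder for self.folder_path() on the job.
    folder = "team_1_job_abc"
    suffix = "__v2/" if is_v2 else "/"
    if is_test:
        return f"http://objectstorage:19000/posthog/{folder}/{schema.lower()}{suffix}"
    return f"https://bucket.example.com/dlt/{folder}/{schema.lower()}{suffix}"

print(url_pattern("Customer", is_v2=False, is_test=False))
# https://bucket.example.com/dlt/team_1_job_abc/customer/
print(url_pattern("Customer", is_v2=True, is_test=False))
# https://bucket.example.com/dlt/team_1_job_abc/customer__v2/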
7 changes: 5 additions & 2 deletions posthog/warehouse/models/external_data_schema.py
@@ -3,6 +3,7 @@
from typing import Any, Optional
from django.db import models
from django_deprecate_fields import deprecate_field
import numpy
import snowflake.connector
from django.conf import settings
from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE_V2
@@ -73,13 +74,15 @@ def is_incremental(self):
def update_incremental_field_last_value(self, last_value: Any) -> None:
incremental_field_type = self.sync_type_config.get("incremental_field_type")

last_value_py = last_value.item() if isinstance(last_value, numpy.generic) else last_value

if (
incremental_field_type == IncrementalFieldType.Integer
or incremental_field_type == IncrementalFieldType.Numeric
):
last_value_json = last_value
last_value_json = last_value_py
else:
last_value_json = str(last_value)
last_value_json = str(last_value_py)

if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2:
key = "incremental_field_last_value_v2"
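The .item() call added above converts numpy scalars (for example a numpy.int64 coming out of the pipeline) into native Python values before they are stored, since the standard json encoder cannot serialize numpy types. A minimal illustration:

import json
import numpy

last_value = numpy.int64(42)

# json.dumps(last_value) raises:
#   TypeError: Object of type int64 is not JSON serializable
last_value_py = last_value.item() if isinstance(last_value, numpy.generic) else last_value

print(type(last_value_py))  # <class 'int'>
print(json.dumps({"incremental_field_last_value": last_value_py}))  # {"incremental_field_last_value": 42}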
1 change: 1 addition & 0 deletions posthog/warehouse/models/external_table_definitions.py
@@ -16,6 +16,7 @@
"*": {
"__dlt_id": StringDatabaseField(name="_dlt_id", hidden=True),
"__dlt_load_id": StringDatabaseField(name="_dlt_load_id", hidden=True),
"__ph_debug": StringJSONDatabaseField(name="_ph_debug", hidden=True),
},
"stripe_account": {
"id": StringDatabaseField(name="id"),
