fix(data-warehouse): Split out the create model activity (#25683)
Gilbert09 authored Oct 18, 2024
1 parent 14a8b5c commit a867559
Showing 7 changed files with 147 additions and 87 deletions.
23 changes: 11 additions & 12 deletions mypy-baseline.txt
@@ -610,14 +610,13 @@ posthog/warehouse/api/external_data_schema.py:0: note: def [_T] get(self, Type,
posthog/warehouse/api/table.py:0: error: Unused "type: ignore" comment [unused-ignore]
posthog/warehouse/api/table.py:0: error: Unused "type: ignore" comment [unused-ignore]
posthog/warehouse/api/table.py:0: error: Unused "type: ignore" comment [unused-ignore]
posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: error: Argument 1 has incompatible type "str"; expected "Type" [arg-type]
posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: error: Incompatible types in assignment (expression has type "list[Any]", variable has type "dict[str, list[tuple[str, str]]]") [assignment]
posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: error: No overload variant of "get" of "dict" matches argument types "str", "tuple[()]" [call-overload]
posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: note: Possible overload variants:
posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: note: def get(self, Type, /) -> Sequence[str] | None
posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: note: def get(self, Type, Sequence[str], /) -> Sequence[str]
posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: note: def [_T] get(self, Type, _T, /) -> Sequence[str] | _T
posthog/temporal/data_imports/workflow_activities/create_job_model.py:0: error: Argument 1 has incompatible type "dict[str, list[tuple[str, str]]]"; expected "list[Any]" [arg-type]
posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: error: Argument 1 has incompatible type "str"; expected "Type" [arg-type]
posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: error: No overload variant of "get" of "dict" matches argument types "str", "tuple[()]" [call-overload]
posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: Possible overload variants:
posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: def get(self, Type, /) -> Sequence[str] | None
posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: def get(self, Type, Sequence[str], /) -> Sequence[str]
posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: def [_T] get(self, Type, _T, /) -> Sequence[str] | _T
posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: error: Argument "source_id" has incompatible type "str"; expected "UUID" [arg-type]
posthog/tasks/exports/test/test_csv_exporter.py:0: error: Function is missing a return type annotation [no-untyped-def]
posthog/tasks/exports/test/test_csv_exporter.py:0: error: Function is missing a type annotation [no-untyped-def]
posthog/tasks/exports/test/test_csv_exporter.py:0: error: Function is missing a type annotation for one or more arguments [no-untyped-def]
@@ -758,6 +757,10 @@ posthog/admin/inlines/plugin_attachment_inline.py:0: note: Subclass:
posthog/admin/inlines/plugin_attachment_inline.py:0: note: def has_delete_permission(self, request: Any, obj: Any) -> Any
posthog/admin/admins/plugin_admin.py:0: error: Item "None" of "Organization | None" has no attribute "pk" [union-attr]
posthog/admin/admins/plugin_admin.py:0: error: Item "None" of "Organization | None" has no attribute "name" [union-attr]
ee/clickhouse/views/experiments.py:0: error: Argument 4 to "ClickhouseTrendExperimentResult" has incompatible type "datetime | None"; expected "datetime" [arg-type]
ee/clickhouse/views/experiments.py:0: error: Argument 4 to "ClickhouseFunnelExperimentResult" has incompatible type "datetime | None"; expected "datetime" [arg-type]
ee/clickhouse/views/experiments.py:0: error: Argument 4 to "ClickhouseSecondaryExperimentResult" has incompatible type "datetime | None"; expected "datetime" [arg-type]
ee/clickhouse/views/experiments.py:0: error: Item "None" of "User | None" has no attribute "email" [union-attr]
posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Unused "type: ignore" comment [unused-ignore]
posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Unused "type: ignore" comment [unused-ignore]
posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Unused "type: ignore" comment [unused-ignore]
@@ -787,10 +790,6 @@ posthog/api/plugin.py:0: error: Item "None" of "IO[Any] | None" has no attribute
posthog/api/plugin.py:0: error: Item "None" of "Team | None" has no attribute "organization" [union-attr]
posthog/api/plugin.py:0: error: Item "None" of "Team | None" has no attribute "id" [union-attr]
posthog/admin/admins/plugin_config_admin.py:0: error: Item "None" of "Team | None" has no attribute "name" [union-attr]
ee/clickhouse/views/experiments.py:0: error: Argument 4 to "ClickhouseTrendExperimentResult" has incompatible type "datetime | None"; expected "datetime" [arg-type]
ee/clickhouse/views/experiments.py:0: error: Argument 4 to "ClickhouseFunnelExperimentResult" has incompatible type "datetime | None"; expected "datetime" [arg-type]
ee/clickhouse/views/experiments.py:0: error: Argument 4 to "ClickhouseSecondaryExperimentResult" has incompatible type "datetime | None"; expected "datetime" [arg-type]
ee/clickhouse/views/experiments.py:0: error: Item "None" of "User | None" has no attribute "email" [union-attr]
posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_execute_calls" (hint: "_execute_calls: list[<type>] = ...") [var-annotated]
posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_execute_async_calls" (hint: "_execute_async_calls: list[<type>] = ...") [var-annotated]
posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_cursors" (hint: "_cursors: list[<type>] = ...") [var-annotated]
2 changes: 2 additions & 0 deletions posthog/temporal/data_imports/__init__.py
@@ -6,6 +6,7 @@
update_external_data_job_model,
check_schedule_activity,
check_billing_limits_activity,
sync_new_schemas_activity,
)

WORKFLOWS = [ExternalDataJobWorkflow]
@@ -17,4 +18,5 @@
create_source_templates,
check_schedule_activity,
check_billing_limits_activity,
sync_new_schemas_activity,
]
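
For context, the WORKFLOWS and activities lists this diff extends are what a Temporal worker registers at startup. A minimal, hypothetical bootstrap is sketched below; the ACTIVITIES name, task queue, and server address are assumptions for illustration, not taken from this commit.

    import asyncio

    from temporalio.client import Client
    from temporalio.worker import Worker

    # WORKFLOWS appears in the diff above; ACTIVITIES is an assumed name for the
    # activity list that now also includes sync_new_schemas_activity.
    from posthog.temporal.data_imports import ACTIVITIES, WORKFLOWS


    async def main() -> None:
        # Placeholder address; a real deployment would read this from settings.
        client = await Client.connect("localhost:7233")
        worker = Worker(
            client,
            task_queue="data-warehouse-task-queue",  # placeholder queue name
            workflows=WORKFLOWS,
            activities=ACTIVITIES,
        )
        await worker.run()


    if __name__ == "__main__":
        asyncio.run(main())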
21 changes: 18 additions & 3 deletions posthog/temporal/data_imports/external_data_job.py
@@ -11,6 +11,10 @@
CheckBillingLimitsActivityInputs,
check_billing_limits_activity,
)
from posthog.temporal.data_imports.workflow_activities.sync_new_schemas import (
SyncNewSchemasActivityInputs,
sync_new_schemas_activity,
)
from posthog.temporal.utils import ExternalDataWorkflowInputs
from posthog.temporal.data_imports.workflow_activities.create_job_model import (
CreateExternalDataJobModelActivityInputs,
@@ -152,16 +156,15 @@ async def run(self, inputs: ExternalDataWorkflowInputs):
source_id=inputs.external_data_source_id,
)

# TODO: split out the creation of the external data job model from schema getting to seperate out exception handling
job_id, incremental = await workflow.execute_activity(
create_external_data_job_model_activity,
create_external_data_job_inputs,
start_to_close_timeout=dt.timedelta(minutes=5),
start_to_close_timeout=dt.timedelta(minutes=1),
retry_policy=RetryPolicy(
initial_interval=dt.timedelta(seconds=10),
maximum_interval=dt.timedelta(seconds=60),
maximum_attempts=3,
non_retryable_error_types=["NotNullViolation", "IntegrityError", "BaseSSHTunnelForwarderError"],
non_retryable_error_types=["NotNullViolation", "IntegrityError"],
),
)

@@ -191,6 +194,18 @@ async def run(self, inputs: ExternalDataWorkflowInputs):
)

try:
await workflow.execute_activity(
sync_new_schemas_activity,
SyncNewSchemasActivityInputs(source_id=str(inputs.external_data_source_id), team_id=inputs.team_id),
start_to_close_timeout=dt.timedelta(minutes=10),
retry_policy=RetryPolicy(
initial_interval=dt.timedelta(seconds=10),
maximum_interval=dt.timedelta(seconds=60),
maximum_attempts=3,
non_retryable_error_types=["NotNullViolation", "IntegrityError", "BaseSSHTunnelForwarderError"],
),
)

job_inputs = ImportDataActivityInputs(
team_id=inputs.team_id,
run_id=job_id,
66 changes: 1 addition & 65 deletions posthog/temporal/data_imports/workflow_activities/create_job_model.py
@@ -5,19 +5,15 @@
from temporalio import activity

# TODO: remove dependency
from posthog.temporal.data_imports.pipelines.schemas import PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING

from posthog.warehouse.external_data_source.jobs import (
create_external_data_job,
)
from posthog.warehouse.models import sync_old_schemas_with_new_schemas, ExternalDataSource, aget_schema_by_id
from posthog.warehouse.models import aget_schema_by_id
from posthog.warehouse.models.external_data_schema import (
ExternalDataSchema,
get_sql_schemas_for_source_type,
get_snowflake_schemas,
)
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.warehouse.models.ssh_tunnel import SSHTunnel


@dataclasses.dataclass
@@ -44,66 +40,6 @@ async def create_external_data_job_model_activity(inputs: CreateExternalDataJobM
schema.status = ExternalDataSchema.Status.RUNNING
await sync_to_async(schema.save)()

source = await sync_to_async(ExternalDataSource.objects.get)(team_id=inputs.team_id, id=inputs.source_id)

if source.source_type in [
ExternalDataSource.Type.POSTGRES,
ExternalDataSource.Type.MYSQL,
ExternalDataSource.Type.MSSQL,
]:
host = source.job_inputs.get("host")
port = source.job_inputs.get("port")
user = source.job_inputs.get("user")
password = source.job_inputs.get("password")
database = source.job_inputs.get("database")
db_schema = source.job_inputs.get("schema")

using_ssh_tunnel = str(source.job_inputs.get("ssh_tunnel_enabled", False)) == "True"
ssh_tunnel_host = source.job_inputs.get("ssh_tunnel_host")
ssh_tunnel_port = source.job_inputs.get("ssh_tunnel_port")
ssh_tunnel_auth_type = source.job_inputs.get("ssh_tunnel_auth_type")
ssh_tunnel_auth_type_username = source.job_inputs.get("ssh_tunnel_auth_type_username")
ssh_tunnel_auth_type_password = source.job_inputs.get("ssh_tunnel_auth_type_password")
ssh_tunnel_auth_type_passphrase = source.job_inputs.get("ssh_tunnel_auth_type_passphrase")
ssh_tunnel_auth_type_private_key = source.job_inputs.get("ssh_tunnel_auth_type_private_key")

ssh_tunnel = SSHTunnel(
enabled=using_ssh_tunnel,
host=ssh_tunnel_host,
port=ssh_tunnel_port,
auth_type=ssh_tunnel_auth_type,
username=ssh_tunnel_auth_type_username,
password=ssh_tunnel_auth_type_password,
passphrase=ssh_tunnel_auth_type_passphrase,
private_key=ssh_tunnel_auth_type_private_key,
)

schemas_to_sync = await sync_to_async(get_sql_schemas_for_source_type)(
source.source_type, host, port, database, user, password, db_schema, ssh_tunnel
)
elif source.source_type == ExternalDataSource.Type.SNOWFLAKE:
account_id = source.job_inputs.get("account_id")
user = source.job_inputs.get("user")
password = source.job_inputs.get("password")
database = source.job_inputs.get("database")
warehouse = source.job_inputs.get("warehouse")
sf_schema = source.job_inputs.get("schema")
role = source.job_inputs.get("role")

schemas_to_sync = await sync_to_async(get_snowflake_schemas)(
account_id, database, warehouse, user, password, sf_schema, role
)
else:
schemas_to_sync = list(PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING.get(source.source_type, ()))

# TODO: this could cause a race condition where each schema worker creates the missing schema

await sync_to_async(sync_old_schemas_with_new_schemas)(
schemas_to_sync,
source_id=inputs.source_id,
team_id=inputs.team_id,
)

logger.info(
f"Created external data job for external data source {inputs.source_id}",
)
104 changes: 104 additions & 0 deletions posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py
@@ -0,0 +1,104 @@
import dataclasses

from asgiref.sync import sync_to_async
from temporalio import activity

from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.data_imports.pipelines.schemas import PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING

from posthog.warehouse.models import sync_old_schemas_with_new_schemas, ExternalDataSource
from posthog.warehouse.models.external_data_schema import (
get_sql_schemas_for_source_type,
get_snowflake_schemas,
)
from posthog.warehouse.models.ssh_tunnel import SSHTunnel


@dataclasses.dataclass
class SyncNewSchemasActivityInputs:
source_id: str
team_id: int


@activity.defn
async def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> None:
logger = await bind_temporal_worker_logger(team_id=inputs.team_id)

logger.info("Syncing new -> old schemas")

source = await sync_to_async(ExternalDataSource.objects.get)(team_id=inputs.team_id, id=inputs.source_id)

schemas_to_sync: list[str] = []

if source.source_type in [
ExternalDataSource.Type.POSTGRES,
ExternalDataSource.Type.MYSQL,
ExternalDataSource.Type.MSSQL,
]:
if not source.job_inputs:
return

host = source.job_inputs.get("host")
port = source.job_inputs.get("port")
user = source.job_inputs.get("user")
password = source.job_inputs.get("password")
database = source.job_inputs.get("database")
db_schema = source.job_inputs.get("schema")

using_ssh_tunnel = str(source.job_inputs.get("ssh_tunnel_enabled", False)) == "True"
ssh_tunnel_host = source.job_inputs.get("ssh_tunnel_host")
ssh_tunnel_port = source.job_inputs.get("ssh_tunnel_port")
ssh_tunnel_auth_type = source.job_inputs.get("ssh_tunnel_auth_type")
ssh_tunnel_auth_type_username = source.job_inputs.get("ssh_tunnel_auth_type_username")
ssh_tunnel_auth_type_password = source.job_inputs.get("ssh_tunnel_auth_type_password")
ssh_tunnel_auth_type_passphrase = source.job_inputs.get("ssh_tunnel_auth_type_passphrase")
ssh_tunnel_auth_type_private_key = source.job_inputs.get("ssh_tunnel_auth_type_private_key")

ssh_tunnel = SSHTunnel(
enabled=using_ssh_tunnel,
host=ssh_tunnel_host,
port=ssh_tunnel_port,
auth_type=ssh_tunnel_auth_type,
username=ssh_tunnel_auth_type_username,
password=ssh_tunnel_auth_type_password,
passphrase=ssh_tunnel_auth_type_passphrase,
private_key=ssh_tunnel_auth_type_private_key,
)

sql_schemas = await sync_to_async(get_sql_schemas_for_source_type)(
source.source_type, host, port, database, user, password, db_schema, ssh_tunnel
)

schemas_to_sync = list(sql_schemas.keys())
elif source.source_type == ExternalDataSource.Type.SNOWFLAKE:
if not source.job_inputs:
return

account_id = source.job_inputs.get("account_id")
user = source.job_inputs.get("user")
password = source.job_inputs.get("password")
database = source.job_inputs.get("database")
warehouse = source.job_inputs.get("warehouse")
sf_schema = source.job_inputs.get("schema")
role = source.job_inputs.get("role")

sql_schemas = await sync_to_async(get_snowflake_schemas)(
account_id, database, warehouse, user, password, sf_schema, role
)

schemas_to_sync = list(sql_schemas.keys())
else:
schemas_to_sync = list(PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING.get(source.source_type, ()))

# TODO: this could cause a race condition where each schema worker creates the missing schema

schemas_created = await sync_to_async(sync_old_schemas_with_new_schemas)(
schemas_to_sync,
source_id=inputs.source_id,
team_id=inputs.team_id,
)

if len(schemas_created) > 0:
logger.info(f"Added new schemas: {', '.join(schemas_created)}")
else:
logger.info("No new schemas to create")
14 changes: 8 additions & 6 deletions posthog/temporal/tests/external_data/test_external_data_job.py
@@ -21,6 +21,10 @@
create_external_data_job_model_activity,
)
from posthog.temporal.data_imports.workflow_activities.import_data import ImportDataActivityInputs, import_data_activity
from posthog.temporal.data_imports.workflow_activities.sync_new_schemas import (
SyncNewSchemasActivityInputs,
sync_new_schemas_activity,
)
from posthog.warehouse.external_data_source.jobs import create_external_data_job
from posthog.warehouse.models import (
get_latest_run_if_exists,
@@ -196,19 +200,16 @@ async def test_create_external_job_activity_update_schemas(activity_environment,
source_type="Stripe",
)

schema = await sync_to_async(ExternalDataSchema.objects.create)(
await sync_to_async(ExternalDataSchema.objects.create)(
name=PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING[new_source.source_type][0],
team_id=team.id,
source_id=new_source.pk,
should_sync=True,
)

inputs = CreateExternalDataJobModelActivityInputs(team_id=team.id, source_id=new_source.pk, schema_id=schema.id)

run_id, _ = await activity_environment.run(create_external_data_job_model_activity, inputs)
inputs = SyncNewSchemasActivityInputs(source_id=str(new_source.pk), team_id=team.id)

runs = ExternalDataJob.objects.filter(id=run_id)
assert await sync_to_async(runs.exists)()
await activity_environment.run(sync_new_schemas_activity, inputs)

all_schemas = await sync_to_async(get_all_schemas_for_source_id)(new_source.pk, team.id)

@@ -698,6 +699,7 @@ async def mock_async_func(inputs):
import_data_activity,
create_source_templates,
check_billing_limits_activity,
sync_new_schemas_activity,
],
workflow_runner=UnsandboxedWorkflowRunner(),
):
4 changes: 3 additions & 1 deletion posthog/warehouse/models/external_data_schema.py
@@ -132,7 +132,7 @@ def get_all_schemas_for_source_id(source_id: uuid.UUID, team_id: int):
return list(ExternalDataSchema.objects.exclude(deleted=True).filter(team_id=team_id, source_id=source_id).all())


def sync_old_schemas_with_new_schemas(new_schemas: list, source_id: uuid.UUID, team_id: int):
def sync_old_schemas_with_new_schemas(new_schemas: list[str], source_id: uuid.UUID, team_id: int) -> list[str]:
old_schemas = get_all_schemas_for_source_id(source_id=source_id, team_id=team_id)
old_schemas_names = [schema.name for schema in old_schemas]

@@ -141,6 +141,8 @@ def sync_old_schemas_with_new_schemas(new_schemas: list, source_id: uuid.UUID, t
for schema in schemas_to_create:
ExternalDataSchema.objects.create(name=schema, team_id=team_id, source_id=source_id, should_sync=False)

return schemas_to_create


def sync_frequency_to_sync_frequency_interval(frequency: str) -> timedelta:
if frequency == "5min":
