refactor(batch-exports): Insert rows instead of using COPY
tomasfarias committed Oct 30, 2023
1 parent 5f91a9d commit d97a84a
Showing 2 changed files with 213 additions and 70 deletions.
112 changes: 75 additions & 37 deletions posthog/temporal/workflows/postgres_batch_export.py
@@ -1,9 +1,11 @@
import collections.abc
import contextlib
import datetime as dt
import json
from dataclasses import dataclass

import psycopg2
import psycopg2.extensions
from django.conf import settings
from psycopg2 import sql
from temporalio import activity, exceptions, workflow
@@ -26,7 +28,7 @@


@contextlib.contextmanager
def postgres_connection(inputs):
def postgres_connection(inputs) -> collections.abc.Iterator[psycopg2.extensions.connection]:
"""Manage a Postgres connection."""
connection = psycopg2.connect(
user=inputs.user,
@@ -52,8 +54,22 @@ def postgres_connection(inputs):
connection.close()

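For reference, a minimal usage sketch of the annotated context manager (assuming inputs is any object carrying the connection settings, such as the insert inputs defined later in this file):

with postgres_connection(inputs) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1")  # the context manager closes the connection on exit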

def copy_tsv_to_postgres(tsv_file, postgres_connection, schema: str, table_name: str, schema_columns):
"""Execute a COPY FROM query with given connection to copy contents of tsv_file."""
def copy_tsv_to_postgres(
tsv_file,
postgres_connection: psycopg2.extensions.connection,
schema: str,
table_name: str,
schema_columns: list[str],
):
"""Execute a COPY FROM query with given connection to copy contents of tsv_file.
Arguments:
tsv_file: A file-like object whose contents are read as TSV and copied.
postgres_connection: A connection to Postgres as set up by psycopg2.
schema: The schema containing the destination table.
table_name: The name of the table to copy into.
schema_columns: A list of the destination column names.
"""
tsv_file.seek(0)

with postgres_connection.cursor() as cursor:
@@ -67,6 +83,44 @@ def copy_tsv_to_postgres(tsv_file, postgres_connection, schema: str, table_name:
)

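The hidden portion of this function issues the COPY FROM query described in the docstring. A generic psycopg2 sketch of that pattern follows; the exact statement, format options, and identifiers here are assumptions, not necessarily what the collapsed hunk contains:

copy_query = sql.SQL("COPY {table} ({fields}) FROM STDIN WITH (FORMAT CSV, DELIMITER E'\t')").format(
    table=sql.Identifier(schema, table_name),
    fields=sql.SQL(", ").join(sql.Identifier(column) for column in schema_columns),
)
cursor.copy_expert(copy_query.as_string(postgres_connection), tsv_file)  # stream the file through COPY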

Field = tuple[str, str]
Fields = collections.abc.Iterable[Field]


def create_table_in_postgres(
postgres_connection: psycopg2.extensions.connection, schema: str | None, table_name: str, fields: Fields
) -> None:
"""Create a table in a Postgres database if it doesn't exist already.
Arguments:
postgres_connection: A connection to Postgres as set up by psycopg2.
schema: An existing schema in which to create the table.
table_name: The name of the table to create.
fields: An iterable of (name, type) tuples representing the fields of the table.
"""
if schema:
table_identifier = sql.Identifier(schema, table_name)
else:
table_identifier = sql.Identifier(table_name)

with postgres_connection.cursor() as cursor:
cursor.execute(
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {table} (
{fields}
)
"""
).format(
table=table_identifier,
fields=sql.SQL(",").join(
sql.SQL("{field} {type}").format(field=sql.Identifier(field), type=sql.SQL(field_type))
for field, field_type in fields
),
)
)

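As an illustration of what the composed statement renders to, with made-up schema, table, and field names (Composable.as_string needs a connection or cursor for context, and whitespace is compacted here):

fields = [("uuid", "VARCHAR(200)"), ("team_id", "INTEGER")]
query = sql.SQL("CREATE TABLE IF NOT EXISTS {table} ({fields})").format(
    table=sql.Identifier("exports", "events"),
    fields=sql.SQL(",").join(
        sql.SQL("{field} {type}").format(field=sql.Identifier(field), type=sql.SQL(field_type))
        for field, field_type in fields
    ),
)
print(query.as_string(connection))
# -> CREATE TABLE IF NOT EXISTS "exports"."events" ("uuid" VARCHAR(200),"team_id" INTEGER)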

@dataclass
class PostgresInsertInputs:
"""Inputs for Postgres insert activity."""
@@ -84,19 +138,6 @@ class PostgresInsertInputs:
port: int = 5432
exclude_events: list[str] | None = None
include_events: list[str] | None = None
fields: list[tuple[str, str]] = [
("uuid", "VARCHAR(200)"),
("event", "VARCHAR(200)"),
("properties", "JSONB"),
("elements", "JSONB"),
("set", "JSONB"),
("set_once", "JSONB"),
("distinct_id", "VARCHAR(200)"),
("team_id", "INTEGER"),
("ip", "VARCHAR(200)"),
("site_url", "VARCHAR(200)"),
("timestamp", "TIMESTAMP WITH TIME ZONE"),
]


@activity.defn
@@ -141,27 +182,24 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs):
include_events=inputs.include_events,
)
with postgres_connection(inputs) as connection:
with connection.cursor() as cursor:
if inputs.schema:
table_identifier = sql.Identifier(inputs.schema, inputs.table_name)
else:
table_identifier = sql.Identifier(inputs.table_name)

result = cursor.execute(
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {table} (
{fields}
)
"""
).format(
table=table_identifier,
fields=sql.SQL(",").join(
sql.SQL("{field} {type}").format(field=sql.Identifier(field), type=sql.SQL(field_type))
for field, field_type in inputs.fields
),
)
)
create_table_in_postgres(
connection,
schema=inputs.schema,
table_name=inputs.table_name,
fields=[
("uuid", "VARCHAR(200)"),
("event", "VARCHAR(200)"),
("properties", "JSONB"),
("elements", "JSONB"),
("set", "JSONB"),
("set_once", "JSONB"),
("distinct_id", "VARCHAR(200)"),
("team_id", "INTEGER"),
("ip", "VARCHAR(200)"),
("site_url", "VARCHAR(200)"),
("timestamp", "TIMESTAMP WITH TIME ZONE"),
],
)

schema_columns = [
"uuid",
171 changes: 138 additions & 33 deletions posthog/temporal/workflows/redshift_batch_export.py
@@ -1,8 +1,13 @@
import datetime as dt
import json
import typing
from dataclasses import dataclass

from temporalio import workflow
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2 import sql
from temporalio import activity, workflow
from temporalio.common import RetryPolicy

from posthog.batch_exports.service import RedshiftBatchExportInputs
@@ -14,34 +19,148 @@
execute_batch_export_insert_activity,
get_batch_exports_logger,
get_data_interval,
get_results_iterator,
get_rows_count,
)
from posthog.temporal.workflows.clickhouse import get_client
from posthog.temporal.workflows.postgres_batch_export import (
PostgresInsertInputs,
insert_into_postgres_activity,
create_table_in_postgres,
postgres_connection,
)


def insert_record_to_redshift(
record: dict[str, typing.Any],
redshift_connection: psycopg2.extensions.connection,
schema: str,
table: str,
):
"""Execute an INSERT query with given Redshift connection.
The recommended way to insert multiple values into Redshift is using a COPY statement (see:
https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html). However, Redshift cannot COPY from local
files like Postgres can; it can only COPY from files in S3 or by executing commands on SSH hosts. Setting
that up would be quite complex and would require more configuration from the user compared to the old
Redshift export plugin. For these reasons, we are going with basic INSERT statements for now, and we can
migrate to COPY from S3 later if the need arises.
Arguments:
record: A dictionary representing the record to insert. Each key should correspond to a column
in the destination table.
redshift_connection: A connection to Redshift as set up by psycopg2.
schema: The schema that contains the table to insert the record into.
table: The name of the table to insert the record into.
"""
columns = record.keys()

with redshift_connection.cursor() as cursor:
query = sql.SQL("INSERT INTO {table} ({fields}) VALUES {placeholder}").format(
table=sql.Identifier(schema, table),
fields=sql.SQL(", ").join(map(sql.Identifier, columns)),
placeholder=sql.Placeholder(),
)
template = sql.SQL("({})").format(sql.SQL(", ").join(map(sql.Placeholder, columns)))

psycopg2.extras.execute_values(cursor, query, [record], template)

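For reference, with columns ("uuid", "event") the composed pieces above render roughly as shown below (schema and table names invented for the example); execute_values then replaces the single %s produced by sql.Placeholder() with one filled-in template per record in the list:

# query rendered:    INSERT INTO "public"."events" ("uuid", "event") VALUES %s
# template rendered: (%(uuid)s, %(event)s)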

@dataclass
class RedshiftInsertInputs(PostgresInsertInputs):
"""Inputs for Redshift insert activity.
Inherit from PostgresInsertInputs as they are the same, but
update fields to account for JSONB not being supported in Redshift.
Inherit from PostgresInsertInputs as they are the same, but allow
for setting properties_data_type, which is unique to Redshift.
"""

fields: list[tuple[str, str]] = [
("uuid", "VARCHAR(200)"),
("event", "VARCHAR(200)"),
("properties", "VARCHAR(65535)"),
("elements", "VARCHAR(65535)"),
("set", "VARCHAR(65535)"),
("set_once", "VARCHAR(65535)"),
("distinct_id", "VARCHAR(200)"),
("team_id", "INTEGER"),
("ip", "VARCHAR(200)"),
("site_url", "VARCHAR(200)"),
("timestamp", "TIMESTAMP WITH TIME ZONE"),
]
properties_data_type: str = "varchar"


@activity.defn
async def insert_into_redshift_activity(inputs: RedshiftInsertInputs):
"""Activity streams data from ClickHouse to Redshift."""
logger = get_batch_exports_logger(inputs=inputs)
logger.info(
"Running Redshift export batch %s - %s",
inputs.data_interval_start,
inputs.data_interval_end,
)

async with get_client() as client:
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")

count = await get_rows_count(
client=client,
team_id=inputs.team_id,
interval_start=inputs.data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
)

if count == 0:
logger.info(
"Nothing to export in batch %s - %s",
inputs.data_interval_start,
inputs.data_interval_end,
)
return

logger.info("BatchExporting %s rows to Redshift", count)

results_iterator = get_results_iterator(
client=client,
team_id=inputs.team_id,
interval_start=inputs.data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
)
properties_type = "VARCHAR(65535)" if inputs.properties_data_type == "varchar" else "SUPER"

with postgres_connection(inputs) as connection:
create_table_in_postgres(
connection,
schema=inputs.schema,
table_name=inputs.table_name,
fields=[
("uuid", "VARCHAR(200)"),
("event", "VARCHAR(200)"),
("properties", properties_type),
("elements", properties_type),
("set", properties_type),
("set_once", properties_type),
("distinct_id", "VARCHAR(200)"),
("team_id", "INTEGER"),
("ip", "VARCHAR(200)"),
("site_url", "VARCHAR(200)"),
("timestamp", "TIMESTAMP WITH TIME ZONE"),
],
)

schema_columns = [
"uuid",
"event",
"properties",
"elements",
"set",
"set_once",
"distinct_id",
"team_id",
"ip",
"site_url",
"timestamp",
]
json_columns = ("properties", "elements", "set", "set_once")

with postgres_connection(inputs) as connection:
for result in results_iterator:
record = {
key: json.dumps(result[key]) if key in json_columns and result[key] is not None else result[key]
for key in schema_columns
}
insert_record_to_redshift(record, connection, inputs.schema, inputs.table_name)

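A small sketch of the per-row transformation in the loop above, using the schema_columns and json_columns defined there and a made-up result row; JSON-like columns are serialized to strings because the destination columns are VARCHAR (or SUPER), while everything else passes through unchanged:

result = {"uuid": "a1b2", "event": "$pageview", "properties": {"$browser": "Chrome"}, "elements": None,
          "set": None, "set_once": None, "distinct_id": "user-1", "team_id": 1, "ip": "127.0.0.1",
          "site_url": "https://example.com", "timestamp": "2023-10-30 00:00:00"}
record = {
    key: json.dumps(result[key]) if key in json_columns and result[key] is not None else result[key]
    for key in schema_columns
}
assert record["properties"] == '{"$browser": "Chrome"}'  # serialized to a JSON string
assert record["elements"] is None  # None values pass through, not serialized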

@workflow.defn(name="redshift-export")
@@ -92,8 +211,6 @@ async def run(self, inputs: RedshiftBatchExportInputs):

update_inputs = UpdateBatchExportRunStatusInputs(id=run_id, status="Completed")

properties_type = "VARCHAR(65535)" if inputs.properties_data_type == "varchar" else "SUPER"

insert_inputs = RedshiftInsertInputs(
team_id=inputs.team_id,
user=inputs.user,
@@ -108,21 +225,9 @@
data_interval_end=data_interval_end.isoformat(),
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=[
("uuid", "VARCHAR(200)"),
("event", "VARCHAR(200)"),
("properties", properties_type),
("elements", properties_type),
("set", properties_type),
("set_once", properties_type),
("distinct_id", "VARCHAR(200)"),
("team_id", "INTEGER"),
("ip", "VARCHAR(200)"),
("site_url", "VARCHAR(200)"),
("timestamp", "TIMESTAMP WITH TIME ZONE"),
],
properties_data_type=inputs.properties_data_type,
)

await execute_batch_export_insert_activity(
insert_into_postgres_activity, insert_inputs, non_retryable_error_types=[], update_inputs=update_inputs
insert_into_redshift_activity, insert_inputs, non_retryable_error_types=[], update_inputs=update_inputs
)
