fix: Merge conflicts
tomasfarias committed Jun 25, 2024
1 parent f749369 commit 9034a2c
Showing 3 changed files with 99 additions and 107 deletions.
16 changes: 10 additions & 6 deletions posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -35,12 +35,13 @@
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
)
from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running
from posthog.temporal.batch_exports.utils import apeek_first_and_rewind
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.common.utils import (
BatchExportHeartbeatDetails,
set_status_to_running_task,
should_resume_from_activity_heartbeat,
)

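The import hunk above drops try_set_batch_export_run_to_running in favour of set_status_to_running_task, which the activities below enter as an async context manager (see the Snowflake activity further down). A minimal sketch of that pattern, assuming a generic update_fn callback in place of PostHog's actual run-update helper, so everything beyond the run_id and logger parameters is hypothetical:

    import contextlib
    import logging
    from collections.abc import Awaitable, Callable

    @contextlib.asynccontextmanager
    async def set_status_to_running_task_sketch(
        run_id: str | None,
        logger: logging.Logger,
        update_fn: Callable[[str, str], Awaitable[None]],
    ):
        """Best-effort: mark the batch export run as RUNNING without failing the activity."""
        if run_id is None:
            # Nothing to update, e.g. a manually triggered test run.
            yield None
            return
        try:
            await update_fn(run_id, "Running")
        except Exception:
            # A failed status update should not abort the export itself.
            logger.warning("Could not mark run %s as RUNNING", run_id, exc_info=True)
        yield None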
@@ -239,7 +240,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
fields = inputs.batch_export_schema["fields"]
query_parameters = inputs.batch_export_schema["values"]

records_iterator = iter_records(
records_iterator = iter_model_records(
client=client,
team_id=inputs.team_id,
interval_start=data_interval_start,
@@ -251,7 +252,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
is_backfill=inputs.is_backfill,
)

first_record_batch, records_iterator = peek_first_and_rewind(records_iterator)
first_record_batch, records_iterator = apeek_first_and_rewind(records_iterator)
if first_record_batch is None:
return 0

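The call above now goes through apeek_first_and_rewind, the async counterpart of peek_first_and_rewind: the first record batch is consumed once, used to detect an empty export (the return 0 just above) and to derive the schema, and then replayed so nothing is lost. A rough sketch of the pattern, not necessarily PostHog's exact implementation:

    from collections.abc import AsyncIterator
    from typing import TypeVar

    T = TypeVar("T")

    async def apeek_first_and_rewind_sketch(
        iterator: AsyncIterator[T],
    ) -> tuple[T | None, AsyncIterator[T]]:
        """Return the first item (or None when empty) plus an iterator that replays it."""
        try:
            first = await anext(iterator)
        except StopAsyncIteration:
            first = None

        async def rewound() -> AsyncIterator[T]:
            if first is not None:
                yield first
            async for item in iterator:
                yield item

        return first, rewound()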
@@ -318,7 +319,10 @@ async def flush_to_bigquery(bigquery_table, table_schema):
if first_record_batch is None:
return 0

for record_batch in records_iterator:
# Columns need to be sorted according to BigQuery schema.
record_columns = [field.name for field in schema] + ["_inserted_at"]

async for record_batch in records_iterator:
for record in record_batch.select(record_columns).to_pylist():
inserted_at = record.pop("_inserted_at")

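The added record_columns list orders columns to match the BigQuery table schema, with _inserted_at appended for bookkeeping; RecordBatch.select then restricts and reorders the batch before it is turned into plain dicts. A small pyarrow illustration with made-up data:

    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict(
        {
            "_inserted_at": [1, 2],
            "event": ["pageview", "click"],
            "uuid": ["a", "b"],
        }
    )

    # Order columns as the destination schema expects, keeping _inserted_at last.
    record_columns = ["uuid", "event", "_inserted_at"]
    for record in batch.select(record_columns).to_pylist():
        inserted_at = record.pop("_inserted_at")
        print(inserted_at, record)  # 1 {'uuid': 'a', 'event': 'pageview'} ...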
@@ -329,8 +333,8 @@ async def flush_to_bigquery(bigquery_table, table_schema):
# TODO: Parquet is a much more efficient format to send data to BigQuery.
jsonl_file.write_records_to_jsonl([record])

rows_exported.add(jsonl_file.records_since_last_reset)
bytes_exported.add(jsonl_file.bytes_since_last_reset)
rows_exported.add(jsonl_file.records_since_last_reset)
bytes_exported.add(jsonl_file.bytes_since_last_reset)

if inputs.use_json_type is True:
json_type = "JSON"
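For context on use_json_type just above: when the flag is set, JSON-like columns are created with BigQuery's JSON type rather than STRING. A hedged sketch of such a toggle using the google-cloud-bigquery client; the field list here is illustrative, not taken from this diff:

    from google.cloud import bigquery

    def build_schema_sketch(use_json_type: bool) -> list[bigquery.SchemaField]:
        json_type = "JSON" if use_json_type else "STRING"
        return [
            bigquery.SchemaField("uuid", "STRING"),
            bigquery.SchemaField("event", "STRING"),
            bigquery.SchemaField("properties", json_type),
            bigquery.SchemaField("timestamp", "TIMESTAMP"),
        ]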
98 changes: 50 additions & 48 deletions posthog/temporal/batch_exports/s3_batch_export.py
@@ -43,7 +43,7 @@
ParquetBatchExportWriter,
UnsupportedFileFormatError,
)
from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running
from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, set_status_to_running_task
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
@@ -453,7 +453,8 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
fields = inputs.batch_export_schema["fields"]
query_parameters = inputs.batch_export_schema["values"]

record_iterator = iter_records(
record_iterator = iter_model_records(
model="events",
client=client,
team_id=inputs.team_id,
interval_start=interval_start,
@@ -465,62 +466,63 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
is_backfill=inputs.is_backfill,
)

record_iterator = iter_model_records(
model="events",
client=client,
team_id=inputs.team_id,
interval_start=interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
extra_query_parameters=query_parameters,
is_backfill=inputs.is_backfill,
)

first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)

records_completed = 0
if first_record_batch is None:
return records_completed

for record_batch in record_iterator:
record_batch = cast_record_batch_json_columns(record_batch)
records_completed = 0
if first_record_batch is None:
return records_completed

await writer.write_record_batch(record_batch)
async with s3_upload as s3_upload:

async def flush_to_s3(
local_results_file,
records_since_last_flush: int,
bytes_since_last_flush: int,
last_inserted_at: dt.datetime,
last: bool,
):
logger.debug(
"Uploading %s part %s containing %s records with size %s bytes",
"last " if last else "",
s3_upload.part_number + 1,
records_since_last_flush,
bytes_since_last_flush,
)

await s3_upload.complete()
await s3_upload.upload_part(local_results_file)
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)

heartbeater.details = (str(last_inserted_at), s3_upload.to_state())
heartbeater.details = (str(last_inserted_at), s3_upload.to_state())

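flush_to_s3 is the flush_callable handed to the batch export writer below: whenever the writer's buffer appears to cross max_bytes, and once more for the final flush, it uploads one multipart part, bumps the export metrics, and stores heartbeat details so a retried activity can resume the upload. A stripped-down sketch of how such a writer might drive a callback with this shape; illustrative only, not the module's ParquetBatchExportWriter:

    import datetime as dt
    import tempfile

    class TinyWriterSketch:
        """Buffers bytes to a temporary file and flushes via a callback (illustrative)."""

        def __init__(self, flush_callable, max_bytes: int) -> None:
            self.flush_callable = flush_callable
            self.max_bytes = max_bytes
            self.file = tempfile.NamedTemporaryFile()
            self.records_since_last_flush = 0
            self.bytes_since_last_flush = 0
            self.last_inserted_at: dt.datetime | None = None

        async def write_record(self, line: bytes, inserted_at: dt.datetime) -> None:
            self.file.write(line)
            self.records_since_last_flush += 1
            self.bytes_since_last_flush += len(line)
            self.last_inserted_at = inserted_at
            if self.bytes_since_last_flush >= self.max_bytes:
                await self.flush(last=False)

        async def flush(self, last: bool) -> None:
            self.file.seek(0)
            await self.flush_callable(
                self.file,
                self.records_since_last_flush,
                self.bytes_since_last_flush,
                self.last_inserted_at,
                last,
            )
            # Reset the buffer and counters for the next part.
            self.file.truncate(0)
            self.file.seek(0)
            self.records_since_last_flush = 0
            self.bytes_since_last_flush = 0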
first_record_batch = cast_record_batch_json_columns(first_record_batch)
column_names = first_record_batch.column_names
column_names.pop(column_names.index("_inserted_at"))
first_record_batch = cast_record_batch_json_columns(first_record_batch)
column_names = first_record_batch.column_names
column_names.pop(column_names.index("_inserted_at"))

schema = pa.schema(
# NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other
# record batches have them as nullable.
# Until we figure it out, we set all fields to nullable. There are some fields we know
# are not nullable, but I'm opting for the more flexible option until we figure out why schemas differ
# between batches.
[field.with_nullable(True) for field in first_record_batch.select(column_names).schema]
)
schema = pa.schema(
# NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other
# record batches have them as nullable.
# Until we figure it out, we set all fields to nullable. There are some fields we know
# are not nullable, but I'm opting for the more flexible option until we figure out why schemas differ
# between batches.
[field.with_nullable(True) for field in first_record_batch.select(column_names).schema]
)

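The schema block above is a workaround for record batches that disagree on nullability: every field is forced to nullable before the schema is handed to the writer, and the _inserted_at bookkeeping column is dropped first. In isolation, with toy data:

    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict({"uuid": ["a"], "team_id": [1], "_inserted_at": [0]})

    # Drop the bookkeeping column and relax nullability on everything that remains.
    column_names = [name for name in batch.column_names if name != "_inserted_at"]
    schema = pa.schema(
        [field.with_nullable(True) for field in batch.select(column_names).schema]
    )
    print(schema)  # uuid: string, team_id: int64, both nullable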
writer = get_batch_export_writer(
inputs,
flush_callable=flush_to_s3,
max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES,
schema=schema,
)
writer = get_batch_export_writer(
inputs,
flush_callable=flush_to_s3,
max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES,
schema=schema,
)

async with writer.open_temporary_file():
rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()
async with writer.open_temporary_file():
rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()

async for record_batch in record_iterator:
record_batch = cast_record_batch_json_columns(record_batch)
async for record_batch in record_iterator:
record_batch = cast_record_batch_json_columns(record_batch)

await writer.write_record_batch(record_batch)
await writer.write_record_batch(record_batch)

records_completed = writer.records_total
await s3_upload.complete()
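The nesting above pins down the ordering: parts are uploaded from inside flush_to_s3 while record batches stream through the writer, and only after the writer's context exits does s3_upload.complete() finalize the multipart upload. A toy stand-in for the interface the callback relies on (part_number, upload_part, to_state, complete); the real S3MultiPartUpload in this module talks to S3 and, judging by the heartbeat details, can be rebuilt from to_state():

    class FakeMultiPartUpload:
        """In-memory stand-in mirroring the interface flush_to_s3 uses (not the real class)."""

        def __init__(self) -> None:
            self.part_number = 0
            self.parts: list[bytes] = []

        async def upload_part(self, local_results_file) -> None:
            local_results_file.seek(0)
            self.parts.append(local_results_file.read())
            self.part_number += 1

        def to_state(self) -> dict:
            # Enough state for a retried activity to resume the upload.
            return {"part_number": self.part_number}

        async def complete(self) -> bytes:
            # A real multipart upload would stitch the uploaded parts together on S3.
            return b"".join(self.parts)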
92 changes: 39 additions & 53 deletions posthog/temporal/batch_exports/snowflake_batch_export.py
@@ -38,7 +38,7 @@
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
)
from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running
from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, set_status_to_running_task
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
@@ -408,9 +408,12 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor

async with (
Heartbeater() as heartbeater,
get_client(team_id=inputs.team_id) as client,
set_status_to_running_task(run_id=inputs.run_id, logger=logger),
get_client(team_id=inputs.team_id) as client,
):
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")

should_resume, details = await should_resume_from_activity_heartbeat(
activity, SnowflakeHeartbeatDetails, logger
)
@@ -424,52 +427,23 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
last_inserted_at = None
file_no = 0

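The heartbeat details for this activity are a (last_inserted_at, file_no) pair (see heartbeater.details near the end of the function), and should_resume_from_activity_heartbeat decides whether a retried run may pick up from them. A rough illustration of the parsing side, with a simplified stand-in for SnowflakeHeartbeatDetails:

    import dataclasses
    import datetime as dt

    @dataclasses.dataclass
    class HeartbeatDetailsSketch:
        """Simplified stand-in for what a Snowflake batch export heartbeat carries."""

        last_inserted_at: dt.datetime
        file_no: int

        @classmethod
        def from_activity_details(cls, details: tuple) -> "HeartbeatDetailsSketch":
            raw_last_inserted_at, raw_file_no = details
            return cls(
                last_inserted_at=dt.datetime.fromisoformat(raw_last_inserted_at),
                file_no=int(raw_file_no),
            )

    # Presumably a resumed run narrows the export interval to start after
    # last_inserted_at and keeps numbering staged files from file_no.
    details = HeartbeatDetailsSketch.from_activity_details(("2024-06-25T00:00:00+00:00", 3))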
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")

rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()

async def flush_to_snowflake(
connection: SnowflakeConnection,
file: BatchExportTemporaryFile,
table_name: str,
file_no: int,
last: bool = False,
):
logger.info(
"Putting %sfile %s containing %s records with size %s bytes",
"last " if last else "",
file_no,
file.records_since_last_reset,
file.bytes_since_last_reset,
)

await put_file_to_snowflake_table(connection, file, table_name, file_no)
rows_exported.add(file.records_since_last_reset)
bytes_exported.add(file.bytes_since_last_reset)

if inputs.batch_export_schema is None:
fields = snowflake_default_fields()
query_parameters = None

else:
fields = inputs.batch_export_schema["fields"]
query_parameters = inputs.batch_export_schema["values"]

record_iterator = iter_model_records(
client=client,
model="events",
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
extra_query_parameters=query_parameters,
is_backfill=inputs.is_backfill,
async def flush_to_snowflake(
connection: SnowflakeConnection,
file: BatchExportTemporaryFile,
table_name: str,
file_no: int,
last: bool = False,
):
logger.info(
"Putting %sfile %s containing %s records with size %s bytes",
"last " if last else "",
file_no,
file.records_since_last_reset,
file.bytes_since_last_reset,
)
first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)

await put_file_to_snowflake_table(connection, file, table_name, file_no)
rows_exported.add(file.records_since_last_reset)
@@ -483,8 +457,9 @@ async def flush_to_snowflake(
fields = inputs.batch_export_schema["fields"]
query_parameters = inputs.batch_export_schema["values"]

record_iterator = iter_records(
record_iterator = iter_model_records(
client=client,
model="events",
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
Expand All @@ -494,15 +469,26 @@ async def flush_to_snowflake(
extra_query_parameters=query_parameters,
is_backfill=inputs.is_backfill,
)
first_record_batch, record_iterator = peek_first_and_rewind(record_iterator)
first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)

if first_record_batch is None:
return 0

with BatchExportTemporaryFile() as local_results_file:
async for record_batch in record_iterator:
for record in record_batch.select(record_columns).to_pylist():
inserted_at = record.pop("_inserted_at")
known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"]
if inputs.batch_export_schema is None:
table_fields = [
("uuid", "STRING"),
("event", "STRING"),
("properties", "VARIANT"),
("elements", "VARIANT"),
("people_set", "VARIANT"),
("people_set_once", "VARIANT"),
("distinct_id", "STRING"),
("team_id", "INTEGER"),
("ip", "STRING"),
("site_url", "STRING"),
("timestamp", "TIMESTAMP"),
]

else:
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
@@ -520,7 +506,7 @@ async def flush_to_snowflake(
inserted_at = None

with BatchExportTemporaryFile() as local_results_file:
for record_batch in record_iterator:
async for record_batch in record_iterator:
for record in record_batch.select(record_columns).to_pylist():
inserted_at = record.pop("_inserted_at")

@@ -548,9 +534,9 @@ async def flush_to_snowflake(

heartbeater.details = (str(last_inserted_at), file_no)

await copy_loaded_files_to_snowflake_table(connection, inputs.table_name)
await copy_loaded_files_to_snowflake_table(connection, inputs.table_name)

return local_results_file.records_total
return local_results_file.records_total
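put_file_to_snowflake_table and copy_loaded_files_to_snowflake_table map onto Snowflake's two-step load: PUT stages a local file into the table's internal stage, and COPY INTO then loads everything staged for the table. A hedged, synchronous sketch using snowflake-connector-python; the exact statements and options used by this module are assumed, not copied from the repo:

    import snowflake.connector

    def load_staged_jsonl_sketch(
        connection: snowflake.connector.SnowflakeConnection,
        local_path: str,
        table_name: str,
    ) -> None:
        """Stage one local JSONL file and load it into table_name (illustrative only)."""
        with connection.cursor() as cursor:
            # PUT uploads the file into the table's internal stage (@%table).
            cursor.execute(f"PUT file://{local_path} @%{table_name}")
            # COPY INTO loads whatever is currently staged for the table.
            cursor.execute(
                f"COPY INTO {table_name} "
                "FILE_FORMAT = (TYPE = 'JSON') "
                "MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE"
            )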


@workflow.defn(name="snowflake-export")
