Commit
wip: add kafka batch writer and related tests
also:
- add WIP marker to pytest
Federico Zambelli committed Sep 24, 2024
1 parent 34b7669 commit 9c9d7f0
Showing 4 changed files with 154 additions and 5 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -74,3 +74,6 @@ addopts = "-vv --tb=auto --disable-warnings"
pythonpath = [
"src"
]
markers = [
"wip: mark tests as work in progress",
]
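
For context, the marker registered above is used with pytest's standard marker mechanism; a minimal illustration (not part of this diff, test name assumed):

import pytest

@pytest.mark.wip
def test_batch_writer_roundtrip():
    ...

# Run only work-in-progress tests:   pytest -m wip
# Exclude them:                      pytest -m "not wip"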
60 changes: 56 additions & 4 deletions src/sparkle/writer/kafka_writer.py
@@ -1,5 +1,7 @@
from typing import Any
from pyspark.sql import SparkSession, DataFrame

from pyspark.sql import DataFrame, SparkSession

from sparkle.config import Config
from sparkle.utils.spark import to_kafka_dataframe
from sparkle.writer import Writer
@@ -51,9 +53,7 @@ def __init__(
self.trigger_once = trigger_once

@classmethod
def with_config(
cls, config: Config, spark: SparkSession, **kwargs: Any
) -> "KafkaStreamPublisher":
def with_config(cls, config: Config, spark: SparkSession, **kwargs: Any) -> "KafkaStreamPublisher":
"""Create a KafkaStreamPublisher object with a configuration.
Args:
@@ -109,3 +109,55 @@ def write(self, df: DataFrame) -> None:
.option("topic", self.kafka_topic)
.start()
)


class KafkaBatchPublisher(Writer):
"""FIXME: Write docstring."""

def __init__(
self,
kafka_options: dict[str, Any],
kafka_topic: str,
unique_identifier_column_name: str,
spark: SparkSession,
) -> None:
"""FIXME: Write docstring."""
super().__init__(spark)
self.kafka_options = kafka_options
self.kafka_topic = kafka_topic
self.unique_identifier_column_name = unique_identifier_column_name

@classmethod
def with_config(cls, config: Config, spark: SparkSession, **kwargs: Any) -> "KafkaBatchPublisher":
"""FIXME: Write docstring."""
if not config.kafka_output:
raise ValueError("Kafka output configuration is missing")

c = config.kafka_output

return cls(
kafka_options=c.kafka_config.spark_kafka_config,
kafka_topic=c.kafka_topic,
unique_identifier_column_name=c.unique_identifier_column_name,
spark=spark,
)

def write(self, df: DataFrame) -> None:
"""FIXME: Write docstring."""
# Convert the DataFrame to a Kafka-friendly format
kafka_df = to_kafka_dataframe(self.unique_identifier_column_name, df)

if "key" not in kafka_df.columns or "value" not in kafka_df.columns:
raise KeyError(
"The DataFrame must contain 'key' and 'value' columns. "
"Ensure that `to_kafka_dataframe` transformation is correctly applied."
)

# fmt: off
(
kafka_df.write.format("kafka")
.options(**self.kafka_options)
.option("topic", self.kafka_topic)
.save()
)
# fmt: on
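
A minimal usage sketch of the new batch publisher (illustrative only; the broker address, topic name, and sample data are assumed):

from pyspark.sql import SparkSession

from sparkle.writer.kafka_writer import KafkaBatchPublisher

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("1", "alice"), ("2", "bob")], ["id", "name"])  # assumed sample data

publisher = KafkaBatchPublisher(
    kafka_options={"kafka.bootstrap.servers": "localhost:9092"},  # assumed broker address
    kafka_topic="users",                                          # assumed topic name
    unique_identifier_column_name="id",
    spark=spark,
)
publisher.write(df)  # keys each record by "id" and writes the batch to the topic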
94 changes: 94 additions & 0 deletions tests/unit/writer/test_kafka_batch_writer.py
@@ -0,0 +1,94 @@
import logging
from typing import Any

import pytest
from pyspark.sql.functions import monotonically_increasing_id

from sparkle.config.kafka_config import SchemaFormat
from sparkle.reader.kafka_reader import KafkaReader
from sparkle.reader.schema_registry import SchemaRegistry
from sparkle.writer.kafka_writer import KafkaBatchPublisher

KAFKA_BROKER_URL = "localhost:9092"
UNIQUE_ID_COLUMN = "id"


@pytest.fixture
def kafka_config() -> dict[str, Any]:
"""Fixture that provides Kafka configuration options for testing.
Returns:
dict[str, Any]: A dictionary containing Kafka configuration options,
including Kafka bootstrap servers, security protocol, Kafka topic,
and unique identifier column name.
"""
return {
"kafka_options": {
"kafka.bootstrap.servers": KAFKA_BROKER_URL,
"kafka.security.protocol": "PLAINTEXT",
},
"kafka_topic": "test-kafka-batch-writer-topic",
"unique_identifier_column_name": UNIQUE_ID_COLUMN,
}


@pytest.fixture
def mock_schema_registry(mocker):
"""Fixture to create a mock schema registry client."""
mock = mocker.Mock(spec=SchemaRegistry)
# mock.cached_schema.return_value = (
# '{"type": "record", "name": "test", "fields":'
# '[{"name": "test", "type": "string"}]}'
# )
return mock


@pytest.mark.wip
def test_kafka_batch_publisher_write(user_dataframe, kafka_config, spark_session, mock_schema_registry):
"""Test the write method of KafkaBatchPublisher by publishing to Kafka."""
# fmt: off
df = (
user_dataframe
.orderBy(user_dataframe.columns[0])
.withColumn(UNIQUE_ID_COLUMN, monotonically_increasing_id().cast("string"))
)
# fmt: on

publisher = KafkaBatchPublisher(
kafka_options=kafka_config["kafka_options"],
kafka_topic=kafka_config["kafka_topic"],
unique_identifier_column_name=kafka_config["unique_identifier_column_name"],
spark=spark_session,
)

    # TODO: Clean up this reader setup once the test is finalized.
reader = KafkaReader(
spark=spark_session,
topic=kafka_config["kafka_topic"],
schema_registry=mock_schema_registry,
format_=SchemaFormat.raw,
schema_version="latest",
kafka_spark_options={
"kafka.bootstrap.servers": KAFKA_BROKER_URL,
"auto.offset.reset": "earliest",
"enable.auto.commit": True,
},
)

query = reader.read().writeStream.format("memory").queryName("kafka_test").outputMode("append").start()

publisher.write(df)

    # NOTE: It is unclear why the results only become available this way; confirm the expected behavior with Shahin.

query.awaitTermination(10)

df = spark_session.sql("SELECT * FROM kafka_test")

logging.info(df.schema)
logging.info(df.count())
res = df.collect()

logging.info(res)

    assert False  # WIP: placeholder failure; replace with real assertions on the consumed records.
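
Once the placeholder assertion above is replaced, a concrete check could compare the memory-sink contents with what was published, for example (a sketch only; expected_row_count is a hypothetical value captured before publishing):

expected_row_count = 2  # hypothetical: number of rows in the DataFrame passed to publisher.write
consumed = spark_session.sql("SELECT * FROM kafka_test")
assert consumed.count() == expected_row_count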
2 changes: 1 addition & 1 deletion tests/unit/writer/test_kafka_stream_writer.py
@@ -25,7 +25,7 @@ def kafka_config() -> dict[str, Any]:
"kafka.security.protocol": "PLAINTEXT",
},
"checkpoint_location": "/tmp/checkpoint",
"kafka_topic": "test-kafka-writer-topic",
"kafka_topic": "test-kafka-stream-writer-topic",
"output_mode": "append",
"unique_identifier_column_name": "id",
"trigger_once": True,
