diff --git a/pyproject.toml b/pyproject.toml
index 7ae3021..cedcf37 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,3 +74,6 @@ addopts = "-vv --tb=auto --disable-warnings"
 pythonpath = [
   "src"
 ]
+markers = [
+  "wip: mark tests as work in progress",
+]
diff --git a/src/sparkle/writer/kafka_writer.py b/src/sparkle/writer/kafka_writer.py
index 0828edc..db09367 100644
--- a/src/sparkle/writer/kafka_writer.py
+++ b/src/sparkle/writer/kafka_writer.py
@@ -1,5 +1,7 @@
 from typing import Any
-from pyspark.sql import SparkSession, DataFrame
+
+from pyspark.sql import DataFrame, SparkSession
+
 from sparkle.config import Config
 from sparkle.utils.spark import to_kafka_dataframe
 from sparkle.writer import Writer
@@ -51,9 +53,7 @@ def __init__(
         self.trigger_once = trigger_once
 
     @classmethod
-    def with_config(
-        cls, config: Config, spark: SparkSession, **kwargs: Any
-    ) -> "KafkaStreamPublisher":
+    def with_config(cls, config: Config, spark: SparkSession, **kwargs: Any) -> "KafkaStreamPublisher":
         """Create a KafkaStreamPublisher object with a configuration.
 
         Args:
@@ -109,3 +109,55 @@ def write(self, df: DataFrame) -> None:
             .option("topic", self.kafka_topic)
             .start()
         )
+
+
+class KafkaBatchPublisher(Writer):
+    """Writer for publishing DataFrames to a Kafka topic in batch mode."""
+
+    def __init__(
+        self,
+        kafka_options: dict[str, Any],
+        kafka_topic: str,
+        unique_identifier_column_name: str,
+        spark: SparkSession,
+    ) -> None:
+        """Initialize the KafkaBatchPublisher with Kafka options, topic, and key column."""
+        super().__init__(spark)
+        self.kafka_options = kafka_options
+        self.kafka_topic = kafka_topic
+        self.unique_identifier_column_name = unique_identifier_column_name
+
+    @classmethod
+    def with_config(cls, config: Config, spark: SparkSession, **kwargs: Any) -> "KafkaBatchPublisher":
+        """Create a KafkaBatchPublisher object with a configuration."""
+        if not config.kafka_output:
+            raise ValueError("Kafka output configuration is missing")
+
+        c = config.kafka_output
+
+        return cls(
+            kafka_options=c.kafka_config.spark_kafka_config,
+            kafka_topic=c.kafka_topic,
+            unique_identifier_column_name=c.unique_identifier_column_name,
+            spark=spark,
+        )
+
+    def write(self, df: DataFrame) -> None:
+        """Write the DataFrame to the configured Kafka topic as a single batch."""
+        # Convert the DataFrame to a Kafka-friendly format
+        kafka_df = to_kafka_dataframe(self.unique_identifier_column_name, df)
+
+        if "key" not in kafka_df.columns or "value" not in kafka_df.columns:
+            raise KeyError(
+                "The DataFrame must contain 'key' and 'value' columns. "
+                "Ensure that `to_kafka_dataframe` transformation is correctly applied."
+            )
+
+        # fmt: off
+        (
+            kafka_df.write.format("kafka")
+            .options(**self.kafka_options)
+            .option("topic", self.kafka_topic)
+            .save()
+        )
+        # fmt: on
diff --git a/tests/unit/writer/test_kafka_batch_writer.py b/tests/unit/writer/test_kafka_batch_writer.py
index e69de29..a9a9f47 100644
--- a/tests/unit/writer/test_kafka_batch_writer.py
+++ b/tests/unit/writer/test_kafka_batch_writer.py
@@ -0,0 +1,94 @@
+import logging
+from typing import Any
+
+import pytest
+from pyspark.sql.functions import monotonically_increasing_id
+
+from sparkle.config.kafka_config import SchemaFormat
+from sparkle.reader.kafka_reader import KafkaReader
+from sparkle.reader.schema_registry import SchemaRegistry
+from sparkle.writer.kafka_writer import KafkaBatchPublisher
+
+KAFKA_BROKER_URL = "localhost:9092"
+UNIQUE_ID_COLUMN = "id"
+
+
+@pytest.fixture
+def kafka_config() -> dict[str, Any]:
+    """Fixture that provides Kafka configuration options for testing.
+
+    Returns:
+        dict[str, Any]: A dictionary containing Kafka configuration options,
+        including Kafka bootstrap servers, security protocol, Kafka topic,
+        and unique identifier column name.
+    """
+    return {
+        "kafka_options": {
+            "kafka.bootstrap.servers": KAFKA_BROKER_URL,
+            "kafka.security.protocol": "PLAINTEXT",
+        },
+        "kafka_topic": "test-kafka-batch-writer-topic",
+        "unique_identifier_column_name": UNIQUE_ID_COLUMN,
+    }
+
+
+@pytest.fixture
+def mock_schema_registry(mocker):
+    """Fixture to create a mock schema registry client."""
+    mock = mocker.Mock(spec=SchemaRegistry)
+    # mock.cached_schema.return_value = (
+    #     '{"type": "record", "name": "test", "fields":'
+    #     '[{"name": "test", "type": "string"}]}'
+    # )
+    return mock
+
+
+@pytest.mark.wip
+def test_kafka_batch_publisher_write(user_dataframe, kafka_config, spark_session, mock_schema_registry):
+    """Test the write method of KafkaBatchPublisher by publishing to Kafka."""
+    # fmt: off
+    df = (
+        user_dataframe
+        .orderBy(user_dataframe.columns[0])
+        .withColumn(UNIQUE_ID_COLUMN, monotonically_increasing_id().cast("string"))
+    )
+    # fmt: on
+
+    publisher = KafkaBatchPublisher(
+        kafka_options=kafka_config["kafka_options"],
+        kafka_topic=kafka_config["kafka_topic"],
+        unique_identifier_column_name=kafka_config["unique_identifier_column_name"],
+        spark=spark_session,
+    )
+
+    # TODO: Clean up the reader setup below.
+    reader = KafkaReader(
+        spark=spark_session,
+        topic=kafka_config["kafka_topic"],
+        schema_registry=mock_schema_registry,
+        format_=SchemaFormat.raw,
+        schema_version="latest",
+        kafka_spark_options={
+            "kafka.bootstrap.servers": KAFKA_BROKER_URL,
+            "auto.offset.reset": "earliest",
+            "enable.auto.commit": True,
+        },
+    )
+
+    query = reader.read().writeStream.format("memory").queryName("kafka_test").outputMode("append").start()
+
+    publisher.write(df)
+
+    # NOTE: Clarify with Shahin why the streaming read behaves this way.
+
+    query.awaitTermination(10)
+
+    df = spark_session.sql("SELECT * FROM kafka_test")
+
+    logging.info(df.schema)
+    logging.info(df.count())
+    res = df.collect()
+
+    logging.info(res)
+
+    assert len(res) == user_dataframe.count()
diff --git a/tests/unit/writer/test_kafka_stream_writer.py b/tests/unit/writer/test_kafka_stream_writer.py
index cb8290d..6421d0b 100644
--- a/tests/unit/writer/test_kafka_stream_writer.py
+++ b/tests/unit/writer/test_kafka_stream_writer.py
@@ -25,7 +25,7 @@ def kafka_config() -> dict[str, Any]:
             "kafka.security.protocol": "PLAINTEXT",
         },
         "checkpoint_location": "/tmp/checkpoint",
-        "kafka_topic": "test-kafka-writer-topic",
+        "kafka_topic": "test-kafka-stream-writer-topic",
         "output_mode": "append",
         "unique_identifier_column_name": "id",
         "trigger_once": True,