diff --git a/devenv.nix b/devenv.nix
index b2a3c24..830bbd4 100644
--- a/devenv.nix
+++ b/devenv.nix
@@ -15,6 +15,7 @@ let
     black
     pylint
   ];
+  compose-path = "./tests/docker-compose.yml";
 in
 {
   name = "sparkle";
@@ -50,6 +51,9 @@ in
   scripts.down.exec = "devenv processes down";
   scripts.down.description = "Stop processes.";

+  scripts.cleanup.exec = "docker compose -f ${compose-path} rm -vf";
+  scripts.cleanup.description = "Remove stopped Docker containers and volumes.";
+
   scripts.show.exec = ''
     GREEN="\033[0;32m";
     YELLOW="\033[33m";
@@ -103,7 +107,7 @@ in

   processes = {
     kafka-test.exec = ''
-      docker compose -f tests/docker-compose.yml up --build
+      docker compose -f ${compose-path} up --build
     '';
   };

diff --git a/poetry.lock b/poetry.lock
index 2399606..e9ed615 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -143,6 +143,20 @@ files = [
     {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
 ]

+[[package]]
+name = "chispa"
+version = "0.10.1"
+description = "Pyspark test helper library"
+optional = false
+python-versions = "<4.0,>=3.8"
+files = [
+    {file = "chispa-0.10.1-py3-none-any.whl", hash = "sha256:f040d6eaaa9f6165a31ff675c44cad0778bb260d833f35f023610357b5f9b5ab"},
+    {file = "chispa-0.10.1.tar.gz", hash = "sha256:7ccdbfcc187c3d630efcccc853aa7a7797d3e02a4ee16278c9aeb66fe24c88ca"},
+]
+
+[package.dependencies]
+prettytable = ">=3.10.2,<4.0.0"
+
 [[package]]
 name = "colorama"
 version = "0.4.6"
@@ -613,6 +627,23 @@ files = [
 dev = ["pre-commit", "tox"]
 testing = ["pytest", "pytest-benchmark"]

+[[package]]
+name = "prettytable"
+version = "3.11.0"
+description = "A simple Python library for easily displaying tabular data in a visually appealing ASCII table format"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "prettytable-3.11.0-py3-none-any.whl", hash = "sha256:aa17083feb6c71da11a68b2c213b04675c4af4ce9c541762632ca3f2cb3546dd"},
+    {file = "prettytable-3.11.0.tar.gz", hash = "sha256:7e23ca1e68bbfd06ba8de98bf553bf3493264c96d5e8a615c0471025deeba722"},
+]
+
+[package.dependencies]
+wcwidth = "*"
+
+[package.extras]
+tests = ["pytest", "pytest-cov", "pytest-lazy-fixtures"]
+
 [[package]]
 name = "prompt-toolkit"
 version = "3.0.36"
@@ -1032,4 +1063,4 @@ test = ["pytest"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10.14, <4.0"
-content-hash = "d7cec6d06cd50cd6e3908959695f4785d2bfe9cbb0a8f0b93f179d75cba0af57"
+content-hash = "404af9e2eb8f022cbb637b551e3526ec6ac2a09ef25b369abf03cf0cbd3305ba"
diff --git a/pyproject.toml b/pyproject.toml
index d22a8a0..51b2831 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ pytest-mock = "^3.14.0"
 confluent-kafka = "^2.5.3"
 fastavro = "^1.9.7"
 types-confluent-kafka = "^1.2.2"
+chispa = "^0.10.1"

 [tool.commitizen]
 version = "0.5.1"
@@ -71,7 +72,20 @@ source = [
 skip_empty = true

 [tool.pytest.ini_options]
-addopts = "-vv --tb=auto --disable-warnings"
+addopts = "-v --tb=short -ra --no-header --show-capture=log"
+# -v: add sufficient verbosity without being overwhelming
+# --tb=short: show the failing line and related context without printing whole function bodies
+# -ra: short recap at the end of the pytest output, excluding passed tests
+# --no-header: skip the pytest header
+# --show-capture=log: reduce output clutter by capturing only logging calls
+log_level = "info"
 pythonpath = [
     "src"
 ]
+markers = [
+    "wip: mark tests as work in progress",
+]
+
+[[tool.mypy.overrides]]
+module = "chispa.*"
+ignore_missing_imports = true
diff --git a/src/sparkle/writer/kafka_writer.py b/src/sparkle/writer/kafka_writer.py
index 0828edc..9670bb0 100644
--- a/src/sparkle/writer/kafka_writer.py
+++ b/src/sparkle/writer/kafka_writer.py
@@ -1,5 +1,7 @@
 from typing import Any
-from pyspark.sql import SparkSession, DataFrame
+
+from pyspark.sql import DataFrame, SparkSession
+
 from sparkle.config import Config
 from sparkle.utils.spark import to_kafka_dataframe
 from sparkle.writer import Writer
@@ -51,9 +53,7 @@ def __init__(
         self.trigger_once = trigger_once

     @classmethod
-    def with_config(
-        cls, config: Config, spark: SparkSession, **kwargs: Any
-    ) -> "KafkaStreamPublisher":
+    def with_config(cls, config: Config, spark: SparkSession, **kwargs: Any) -> "KafkaStreamPublisher":
         """Create a KafkaStreamPublisher object with a configuration.

         Args:
@@ -109,3 +109,91 @@ def write(self, df: DataFrame) -> None:
             .option("topic", self.kafka_topic)
             .start()
         )
+
+
+class KafkaBatchPublisher(Writer):
+    """KafkaBatchPublisher class for writing DataFrames in batch to Kafka.
+
+    Inherits from the Writer abstract base class and implements the write
+    method for writing data to Kafka topics.
+
+    Args:
+        kafka_options (dict[str, Any]): Kafka connection options.
+        kafka_topic (str): Kafka topic to which data will be written.
+        unique_identifier_column_name (str): Column name used as the Kafka key.
+        spark (SparkSession): Spark session instance to use.
+    """
+
+    def __init__(
+        self,
+        kafka_options: dict[str, Any],
+        kafka_topic: str,
+        unique_identifier_column_name: str,
+        spark: SparkSession,
+    ) -> None:
+        """Initialize the KafkaBatchPublisher object.
+
+        Args:
+            kafka_options (dict[str, Any]): Kafka options for the connection.
+            kafka_topic (str): The target Kafka topic for writing data.
+            unique_identifier_column_name (str): Column name to be used as the Kafka key.
+            spark (SparkSession): The Spark session to be used for writing data.
+        """
+        super().__init__(spark)
+        self.kafka_options = kafka_options
+        self.kafka_topic = kafka_topic
+        self.unique_identifier_column_name = unique_identifier_column_name
+
+    @classmethod
+    def with_config(cls, config: Config, spark: SparkSession, **kwargs: Any) -> "KafkaBatchPublisher":
+        """Create a KafkaBatchPublisher object with a configuration.
+
+        Args:
+            config (Config): Configuration object containing settings for the writer.
+            spark (SparkSession): The Spark session to be used for writing data.
+            **kwargs (Any): Additional keyword arguments.
+
+        Returns:
+            KafkaBatchPublisher: An instance configured with the provided settings.
+        """
+        if not config.kafka_output:
+            raise ValueError("Kafka output configuration is missing")
+
+        c = config.kafka_output
+
+        return cls(
+            kafka_options=c.kafka_config.spark_kafka_config,
+            kafka_topic=c.kafka_topic,
+            unique_identifier_column_name=c.unique_identifier_column_name,
+            spark=spark,
+        )
+
+    def write(self, df: DataFrame) -> None:
+        """Write the DataFrame to Kafka by converting it to JSON using the configured primary key.
+
+        This method transforms the DataFrame using the unique identifier column name
+        and writes it to the configured Kafka topic.
+
+        Args:
+            df (DataFrame): The DataFrame to be written.
+
+        Raises:
+            KeyError: If the DataFrame does not have the required 'key' and 'value' columns.
+        """
+        # Convert the DataFrame to a Kafka-friendly format
+        kafka_df = to_kafka_dataframe(self.unique_identifier_column_name, df)
+
+        if set(kafka_df.columns) != {"key", "value"}:
+            raise KeyError(
+                "The DataFrame must contain 'key' and 'value' columns. "
" + "Ensure that `to_kafka_dataframe` transformation is correctly applied." + ) + + # fmt: off + ( + kafka_df.write.format("kafka") + .options(**self.kafka_options) + .option("topic", self.kafka_topic) + .save() + ) + # fmt: on diff --git a/tests/conftest.py b/tests/conftest.py index d80db8a..ae8dea4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,14 @@ -import pytest -from typing import Any +import io import json +import logging import os -from pyspark.sql import SparkSession +import shutil +from contextlib import redirect_stdout +from typing import Any + +import pytest from pyspark.conf import SparkConf +from pyspark.sql import DataFrame, SparkSession @pytest.fixture(scope="session") @@ -45,11 +50,7 @@ def spark_session() -> SparkSession: for key, value in LOCAL_CONFIG.items(): spark_conf.set(key, str(value)) - spark_session = ( - SparkSession.builder.master("local[*]") - .appName("LocalTestSparkleApp") - .config(conf=spark_conf) - ) + spark_session = SparkSession.builder.master("local[*]").appName("LocalTestSparkleApp").config(conf=spark_conf) if ivy_settings_path: spark_session.config("spark.jars.ivySettings", ivy_settings_path) @@ -57,6 +58,43 @@ def spark_session() -> SparkSession: return spark_session.getOrCreate() +@pytest.fixture(scope="session") +def checkpoint_directory(): + """Fixture to validate and remove the checkpoint directory after tests. + + To avoid test failures due to non-unique directories, the user should add a + subdirectory to this path when using this fixture. + + Example: + >>> dir = checkpoint_directory + subdir + """ + checkpoint_dir = "/tmp/checkpoint/" + + yield checkpoint_dir + + # Remove the checkpoint directory if it exists + if os.path.exists(checkpoint_dir): + shutil.rmtree(checkpoint_dir) + logging.info(f"Checkpoint directory {checkpoint_dir} has been removed.") + else: + logging.warning(f"Checkpoint directory {checkpoint_dir} was not found.") + + +@pytest.fixture(scope="session", autouse=True) +def cleanup_logging_handlers(): + """Fixture to cleanup logging handlers after tests. + + Prevents logging errors at the end of the report. + Taken from [here](https://github.com/pytest-dev/pytest/issues/5502#issuecomment-1803676152) + """ + try: + yield + finally: + for handler in logging.root.handlers[:]: + if isinstance(handler, logging.StreamHandler): + logging.root.removeHandler(handler) + + @pytest.fixture def user_dataframe(spark_session: SparkSession): """Fixture for creating a DataFrame with user data. @@ -71,21 +109,11 @@ def user_dataframe(spark_session: SparkSession): pyspark.sql.DataFrame: A Spark DataFrame with sample user data. """ data = [ - { - "name": "John", - "surname": "Doe", - "phone": "12345", - "email": "john@test.com", - }, - { - "name": "Jane", - "surname": "Doe", - "phone": "12345", - "email": "jane.doe@test.com", - }, + ["John", "Doe", "12345", "john@test.com"], + ["Jane", "Doe", "12345", "jane.doe@test.com"], ] - - return spark_session.createDataFrame(data) + schema = ["name", "surname", "phone", "email"] + return spark_session.createDataFrame(data, schema=schema) @pytest.fixture @@ -127,3 +155,17 @@ def json_to_string(dictionary: dict[str, Any]) -> str: ensure_ascii=True, separators=(",", ":"), ).replace("\n", "") + + +def log_spark_dataframe(df: DataFrame, *, truncate: bool = False, name: str = "") -> None: + """Logs the contents of a Spark DataFrame in tabular format. + + Useful when Pytest is configured to capture only logs, so `df.show()` won't work. 
+
+    Example:
+        >>> log_spark_dataframe(df, name="My DataFrame")
+    """
+    buffer = io.StringIO()
+    with redirect_stdout(buffer):
+        df.show(truncate=truncate)
+    logging.info(f"\n{name}\n{buffer.getvalue()}")
diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml
index d8f8e97..4629613 100644
--- a/tests/docker-compose.yml
+++ b/tests/docker-compose.yml
@@ -44,3 +44,14 @@ services:
       SCHEMA_REGISTRY_HOST_NAME: schema-registry
       SCHEMA_REGISTRY_KAFKASTORE_BOOTSTRAP_SERVERS: 'broker:29092'
       SCHEMA_REGISTRY_LISTENERS: http://0.0.0.0:8081
+
+  kafka-ui:
+    container_name: kafka-ui
+    image: provectuslabs/kafka-ui:latest
+    ports:
+      - 8080:8080
+    environment:
+      DYNAMIC_CONFIG_ENABLED: true
+      KAFKA_CLUSTERS_0_NAME: local
+      KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: broker:29092
+      KAFKA_CLUSTERS_0_SCHEMAREGISTRY: http://schema-registry:8081
diff --git a/tests/unit/reader/test_kafka_reader.py b/tests/unit/reader/test_kafka_reader.py
index 050a01b..8dd92c6 100644
--- a/tests/unit/reader/test_kafka_reader.py
+++ b/tests/unit/reader/test_kafka_reader.py
@@ -1,19 +1,22 @@
+import logging
+from collections.abc import Generator
 from time import sleep
 from typing import Any
-from collections.abc import Generator
+
 import pytest
-from pyspark.sql import SparkSession, DataFrame
 from confluent_kafka import Producer
 from confluent_kafka.admin import AdminClient, NewTopic
-from confluent_kafka.schema_registry import SchemaRegistryClient, Schema
+from confluent_kafka.schema_registry import Schema, SchemaRegistryClient
 from confluent_kafka.schema_registry.avro import AvroSerializer
 from confluent_kafka.serialization import (
-    StringSerializer,
-    SerializationContext,
     MessageField,
+    SerializationContext,
+    StringSerializer,
 )
-from sparkle.reader.kafka_reader import KafkaReader, SchemaRegistry
+from pyspark.sql import DataFrame, SparkSession
+
 from sparkle.config.kafka_config import SchemaFormat
+from sparkle.reader.kafka_reader import KafkaReader, SchemaRegistry

 KAFKA_BROKER_URL = "localhost:9092"
 SCHEMA_REGISTRY_URL = "http://localhost:8081"
@@ -31,14 +34,13 @@ def kafka_setup() -> Generator[str, None, None]:
     """
     admin_client = AdminClient({"bootstrap.servers": KAFKA_BROKER_URL})

-    admin_client.create_topics(
-        [NewTopic(TEST_TOPIC, num_partitions=1, replication_factor=1)]
-    )
+    admin_client.create_topics([NewTopic(TEST_TOPIC, num_partitions=1, replication_factor=1)])

     yield TEST_TOPIC

     # Cleanup
     admin_client.delete_topics([TEST_TOPIC])
+    logging.info("Deleted Kafka topic %s", TEST_TOPIC)


 @pytest.fixture
@@ -135,9 +137,7 @@ def produce_avro_message(
     string_serializer = StringSerializer("utf_8")
     producer.produce(
         topic=topic,
-        key=string_serializer(
-            value["name"], SerializationContext(topic, MessageField.KEY)
-        ),
+        key=string_serializer(value["name"], SerializationContext(topic, MessageField.KEY)),
         value=avro_serializer(value, SerializationContext(topic, MessageField.VALUE)),
     )
     producer.flush()
diff --git a/tests/unit/utils/test_spark.py b/tests/unit/utils/test_spark.py
index 24f8951..0c95d88 100644
--- a/tests/unit/utils/test_spark.py
+++ b/tests/unit/utils/test_spark.py
@@ -1,14 +1,14 @@
 import pytest
-from sparkle.writer.iceberg_writer import IcebergWriter
-from sparkle.utils.spark import table_exists
-from sparkle.utils.spark import to_kafka_dataframe
-from tests.conftest import json_to_string
-from pyspark.sql import DataFrame, SparkSession, Row
+from chispa.dataframe_comparer import assert_df_equality
+from pyspark.sql import DataFrame, Row, SparkSession
+from pyspark.sql import functions as F
+from pyspark.sql.avro.functions import to_avro
 from pyspark.sql.functions import col, lit, struct
+
 from sparkle.reader.schema_registry import SchemaRegistry
-from sparkle.utils.spark import parse_by_avro
-from pyspark.sql.avro.functions import to_avro
-from pyspark.sql import functions as F
+from sparkle.utils.spark import parse_by_avro, table_exists, to_kafka_dataframe
+from sparkle.writer.iceberg_writer import IcebergWriter
+from tests.conftest import json_to_string


 @pytest.mark.parametrize(
@@ -58,10 +58,10 @@ def test_generate_kafka_acceptable_dataframe(user_dataframe: DataFrame, spark_se
             "key": "john@test.com",
             "value": json_to_string(
                 {
-                    "email": "john@test.com",
                     "name": "John",
-                    "phone": "12345",
                     "surname": "Doe",
+                    "phone": "12345",
+                    "email": "john@test.com",
                 },
             ),
         },
@@ -69,33 +69,31 @@ def test_generate_kafka_acceptable_dataframe(user_dataframe: DataFrame, spark_se
             "key": "jane.doe@test.com",
             "value": json_to_string(
                 {
-                    "email": "jane.doe@test.com",
                     "name": "Jane",
-                    "phone": "12345",
                     "surname": "Doe",
+                    "phone": "12345",
+                    "email": "jane.doe@test.com",
                 },
             ),
         },
     ]
-    expected_df = spark_session.createDataFrame(
-        expected_result, schema=["key", "value"]
-    )
+    expected_df = spark_session.createDataFrame(expected_result, schema=["key", "value"])

-    df = to_kafka_dataframe("email", user_dataframe)
+    actual_df = to_kafka_dataframe("email", user_dataframe)

-    assert df.count() == expected_df.count()
-    assert expected_df.join(df, ["key"]).count() == expected_df.count()
-    assert expected_df.join(df, ["value"]).count() == expected_df.count()
+    assert_df_equality(expected_df, actual_df)


 @pytest.fixture
 def mock_schema_registry(mocker):
     """Fixture to create a mock schema registry client."""
     mock = mocker.Mock(spec=SchemaRegistry)
+    # fmt: off
     mock.cached_schema.return_value = (
         '{"type": "record", "name": "test", "fields":'
         '[{"name": "test", "type": "string"}]}'
     )
+    # fmt: on
     return mock


@@ -122,7 +120,8 @@ def test_parse_by_avro(spark_session: SparkSession, mock_schema_registry):

     # Add magic byte and schema ID to simulate real Kafka Avro messages
     kafka_data = kafka_data.withColumn(
-        "value", F.concat(F.lit(b"\x00\x00\x00\x00\x01"), col("value"))
+        "value",
+        F.concat(F.lit(b"\x00\x00\x00\x00\x01"), col("value")),
     )

     # Create the transformer function using the parse_by_avro function
diff --git a/tests/unit/writer/test_kafka_batch_writer.py b/tests/unit/writer/test_kafka_batch_writer.py
new file mode 100644
index 0000000..bd7e4ad
--- /dev/null
+++ b/tests/unit/writer/test_kafka_batch_writer.py
@@ -0,0 +1,100 @@
+import logging
+from typing import Any
+
+import pytest
+from chispa.dataframe_comparer import assert_df_equality
+from confluent_kafka.admin import AdminClient, NewTopic
+
+from sparkle.config.kafka_config import SchemaFormat
+from sparkle.reader.kafka_reader import KafkaReader
+from sparkle.reader.schema_registry import SchemaRegistry
+from sparkle.writer.kafka_writer import KafkaBatchPublisher
+
+TOPIC = "test-kafka-batch-writer-topic"
+BROKER_URL = "localhost:9092"
+
+
+@pytest.fixture
+def kafka_config(user_dataframe, checkpoint_directory) -> dict[str, Any]:
+    """Fixture that provides Kafka configuration options for testing.
+
+    Returns:
+        dict[str, any]: A dictionary containing Kafka configuration options,
+        including Kafka bootstrap servers, security protocol, Kafka topic,
+        and unique identifier column name.
+ """ + return { + "kafka_options": { + "kafka.bootstrap.servers": BROKER_URL, + "kafka.security.protocol": "PLAINTEXT", + }, + "checkpoint_location": checkpoint_directory + TOPIC, + "kafka_topic": TOPIC, + "unique_identifier_column_name": user_dataframe.columns[0], + } + + +@pytest.fixture +def kafka_setup(): + """Create a Kafka topic and deletes it after the test.""" + kafka_client = AdminClient({"bootstrap.servers": BROKER_URL}) + kafka_client.create_topics([NewTopic(TOPIC, num_partitions=1, replication_factor=1)]) + yield + kafka_client.delete_topics([TOPIC]) + logging.info("Deleted Kafka topic %s", TOPIC) + + +def test_kafka_batch_publisher_write( + user_dataframe, + kafka_config, + spark_session, + mocker, + kafka_setup, +): + """Test the write method of KafkaBatchPublisher by publishing to Kafka.""" + publisher = KafkaBatchPublisher( + kafka_options=kafka_config["kafka_options"], + kafka_topic=kafka_config["kafka_topic"], + unique_identifier_column_name=kafka_config["unique_identifier_column_name"], + spark=spark_session, + ) + reader = KafkaReader( + spark=spark_session, + topic=kafka_config["kafka_topic"], + schema_registry=mocker.Mock(spec=SchemaRegistry), + format_=SchemaFormat.raw, + schema_version="latest", + kafka_spark_options={ + "kafka.bootstrap.servers": kafka_config["kafka_options"]["kafka.bootstrap.servers"], + "startingOffsets": "earliest", + "enable.auto.commit": True, + }, + ) + + publisher.write(user_dataframe) + query = ( + reader.read() + .writeStream.format("memory") + .queryName("batch_data") + .outputMode("append") + .option("checkpointLocation", kafka_config["checkpoint_location"]) + .trigger(once=True) + .start() + ) + query.awaitTermination(timeout=10) + + actual_df = spark_session.sql(""" + SELECT + parsed_json.* + FROM ( + SELECT + from_json( + cast(value as string), + 'name STRING, surname STRING, phone STRING, email STRING' + ) as parsed_json + FROM batch_data + ) + """) + + expected_df = user_dataframe + assert_df_equality(expected_df, actual_df, ignore_row_order=True) diff --git a/tests/unit/writer/test_kafka_writer.py b/tests/unit/writer/test_kafka_stream_writer.py similarity index 77% rename from tests/unit/writer/test_kafka_writer.py rename to tests/unit/writer/test_kafka_stream_writer.py index cb8290d..a9cfdac 100644 --- a/tests/unit/writer/test_kafka_writer.py +++ b/tests/unit/writer/test_kafka_stream_writer.py @@ -1,5 +1,4 @@ import os -import shutil import time from typing import Any @@ -9,9 +8,11 @@ from sparkle.writer.kafka_writer import KafkaStreamPublisher +TOPIC = "test-kafka-stream-writer-topic" + @pytest.fixture -def kafka_config() -> dict[str, Any]: +def kafka_config(checkpoint_directory) -> dict[str, Any]: """Fixture that provides Kafka configuration options for testing. Returns: @@ -24,8 +25,8 @@ def kafka_config() -> dict[str, Any]: "kafka.bootstrap.servers": "localhost:9092", "kafka.security.protocol": "PLAINTEXT", }, - "checkpoint_location": "/tmp/checkpoint", - "kafka_topic": "test-kafka-writer-topic", + "checkpoint_location": checkpoint_directory + TOPIC, + "kafka_topic": TOPIC, "output_mode": "append", "unique_identifier_column_name": "id", "trigger_once": True, @@ -51,32 +52,10 @@ def rate_stream_dataframe(spark_session) -> DataFrame: return rate_df -@pytest.fixture -def cleanup_checkpoint_directory(kafka_config): - """Fixture that validates and removes the checkpoint directory after tests. - - Args: - kafka_config (dict[str, any]): The Kafka configuration dictionary. 
-
-    Yields:
-        None: This fixture ensures that the checkpoint directory specified in the
-        Kafka configuration is removed after test execution if it exists.
-    """
-    checkpoint_dir = kafka_config["checkpoint_location"]
-
-    yield
-
-    # Remove the checkpoint directory if it exists
-    if os.path.exists(checkpoint_dir):
-        shutil.rmtree(checkpoint_dir)
-        print(f"Checkpoint directory {checkpoint_dir} has been removed.")
-
-
 def test_kafka_stream_publisher_write(
     spark_session: SparkSession,
     rate_stream_dataframe,
     kafka_config: dict[str, Any],
-    cleanup_checkpoint_directory,
 ):
     """Test the write method of KafkaStreamPublisher by publishing to Kafka.

@@ -88,7 +67,6 @@ def test_kafka_stream_publisher_write(
         spark_session (SparkSession): The Spark session used for the test.
         rate_stream_dataframe (DataFrame): The streaming DataFrame to be published.
         kafka_config (dict[str, any]): Kafka configuration options.
-        cleanup_checkpoint_directory: Fixture to clean up the checkpoint directory.

     Raises:
         AssertionError: If the commit file does not exist after the stream terminates.