diff --git a/src/sparkle/application/__init__.py b/src/sparkle/application/__init__.py
index 6939030..2b3745b 100644
--- a/src/sparkle/application/__init__.py
+++ b/src/sparkle/application/__init__.py
@@ -11,28 +11,37 @@ class Sparkle(abc.ABC):
-    """Base class for Spark applications."""
+    """Base class for Spark applications.
+
+    This class provides a foundation for Spark-based data processing applications.
+    It handles setting up the Spark session, configuring readers and writers,
+    and defining the main process logic through abstract methods that must be implemented
+    by subclasses.
+    """
 
     def __init__(
         self,
         config: Config,
+        readers: dict[str, type[Reader]],
         writers: list[Writer],
         spark_extensions: list[str] | None = None,
         spark_packages: list[str] | None = None,
         extra_spark_config: dict[str, str] | None = None,
     ):
-        """Sparkle's application initializer.
+        """Initializes the Sparkle application with the given configuration.
 
         Args:
-            config (Config): Configuration object containing application-specific settings.
-            writers (list[Writer]): list of Writer objects used for writing the processed data.
-            spark_extensions (list[str], optional): list of Spark session extensions to use.
-            spark_packages (list[str], optional): list of Spark packages to include.
-            extra_spark_config (dict[str, str], optional): Additional Spark configurations
-                to merge with default configurations.
+            config (Config): The configuration object containing application-specific settings.
+            readers (dict[str, type[Reader]]): A dictionary of readers for input data, keyed by source name.
+            writers (list[Writer]): A list of Writer objects used to output processed data.
+            spark_extensions (list[str], optional): A list of Spark session extensions to apply.
+            spark_packages (list[str], optional): A list of Spark packages to include in the session.
+            extra_spark_config (dict[str, str], optional): Additional Spark configurations to
+                merge with the default settings.
         """
         self.config = config
         self.writers = writers
+        self.readers = readers
         self.execution_env = config.execution_environment
         self.spark_config = config.get_spark_config(
             self.execution_env, extra_spark_config
@@ -43,16 +52,16 @@ def __init__(
         self.spark_session = self.get_spark_session(self.execution_env)
 
     def get_spark_session(self, env: ExecutionEnvironment) -> SparkSession:
-        """Create and return a Spark session based on the environment.
+        """Creates and returns a Spark session based on the environment.
 
         Args:
-            env (ExecutionEnvironment): The environment type (either LOCAL or AWS).
+            env (ExecutionEnvironment): The environment in which the Spark session is created (LOCAL or AWS).
 
         Returns:
-            SparkSession: The Spark session configured with the appropriate settings.
+            SparkSession: A Spark session configured for the specified environment.
 
         Raises:
-            ValueError: If the environment is unsupported.
+            ValueError: If an unsupported environment is provided.
         """
         if env == ExecutionEnvironment.LOCAL:
             return self._get_local_session()
@@ -62,10 +71,10 @@ def get_spark_session(self, env: ExecutionEnvironment) -> SparkSession:
             raise ValueError(f"Unsupported environment: {env}")
 
     def _get_local_session(self) -> SparkSession:
-        """Create a Spark session for the local environment.
+        """Creates a Spark session for local execution.
 
         Returns:
-            SparkSession: Configured Spark session for local environment.
+            SparkSession: A configured Spark session for the local environment.
         """
         spark_conf = SparkConf()
         for key, value in self.spark_config.items():
@@ -87,15 +96,19 @@ def _get_local_session(self) -> SparkSession:
         return spark_session_builder.getOrCreate()
 
     def _get_aws_session(self) -> SparkSession:
-        """Create a Spark session for the AWS environment.
+        """Creates a Spark session for AWS Glue execution.
 
         Returns:
-            SparkSession: Configured Spark session for AWS environment.
+            SparkSession: A configured Spark session for the AWS Glue environment.
+
+        Raises:
+            ImportError: If the AWS Glue libraries are not available.
         """
         try:
            from awsglue.context import GlueContext  # type: ignore[import]
         except ImportError:
             logger.error("Could not import GlueContext. Is this running on AWS Glue?")
+            raise
 
         spark_conf = SparkConf()
         for key, value in self.spark_config.items():
@@ -106,31 +119,31 @@ def _get_aws_session(self) -> SparkSession:
 
     @property
     def input(self) -> dict[str, Reader]:
-        """Dictionary of input DataReaders used in the application.
+        """Returns the input readers configured for the application.
 
         Returns:
-            dict[str, DataReader]: dictionary of input DataReaders used in the application, keyed by source name.
+            dict[str, Reader]: A dictionary mapping input sources to Reader instances.
 
         Raises:
-            ValueError: If no inputs are configured.
+            ValueError: If no readers are configured for the application.
         """
-        if len(self.config.inputs) == 0:
-            raise ValueError("No inputs configured.")
+        if len(self.readers) == 0:
+            raise ValueError("No readers configured.")
 
         return {
             key: value.with_config(self.config, self.spark_session)
-            for key, value in self.config.inputs.items()
+            for key, value in self.readers.items()
         }
 
     @abc.abstractmethod
     def process(self) -> DataFrame:
-        """Application's entrypoint responsible for the main business logic.
+        """Defines the application's data processing logic.
 
-        This method should be overridden in subclasses to process the
-        input DataFrames and return a resulting DataFrame.
+        This method should be implemented by subclasses to define how the input data
+        is processed and transformed into the desired output.
 
         Returns:
-            DataFrame: The resulting DataFrame after processing.
+            DataFrame: The resulting DataFrame after the processing logic is applied.
 
         Raises:
             NotImplementedError: If the subclass does not implement this method.
@@ -138,14 +151,13 @@ def process(self) -> DataFrame:
         raise NotImplementedError("process method must be implemented by subclasses")
 
     def write(self, df: DataFrame) -> None:
-        """Write output DataFrame to the application's writer(s).
+        """Writes the output DataFrame to the application's configured writers.
 
-        The DataFrame is first persisted in memory to optimize writing
-        operations and then unpersisted after all writers have
-        completed their tasks.
+        The DataFrame is persisted in memory to optimize writing operations,
+        and once all writers have completed their tasks, the DataFrame is unpersisted.
 
         Args:
-            df (DataFrame): The DataFrame to be written to the destinations.
+            df (DataFrame): The DataFrame to be written to the output destinations.
         """
         df.persist(StorageLevel.MEMORY_ONLY)
diff --git a/src/sparkle/config/__init__.py b/src/sparkle/config/__init__.py
index 540bf52..384eb08 100644
--- a/src/sparkle/config/__init__.py
+++ b/src/sparkle/config/__init__.py
@@ -1,8 +1,6 @@
-from __future__ import annotations
 from dataclasses import dataclass
 from enum import Enum
 import os
-from sparkle.reader import Reader
 from .kafka_config import KafkaReaderConfig, KafkaWriterConfig
 from .iceberg_config import IcebergConfig
 from .database_config import TableConfig
@@ -25,7 +23,6 @@ class Config:
         version (str): The version of the application.
         database_bucket (str): The S3 bucket where the database is stored.
         checkpoints_bucket (str): The S3 bucket where the Spark checkpoints are stored.
-        inputs (dict[str, Type[Reader]]): A dictionary mapping input sources to Reader types.
         execution_environment (ExecutionEnvironment): The environment where the app is executed.
         filesystem_scheme (str): The file system scheme, default is 's3a://'.
         spark_trigger (str): The Spark trigger configuration in JSON format.
@@ -40,7 +37,6 @@ class Config:
     version: str
     database_bucket: str
     checkpoints_bucket: str
-    inputs: dict[str, type[Reader]]
     execution_environment: ExecutionEnvironment = ExecutionEnvironment.LOCAL
     filesystem_scheme: str = "s3a://"
     spark_trigger: str = '{"once": True}'
@@ -86,10 +82,6 @@ def get_local_spark_config(
             dict[str, str]: dictionary of Spark configurations for the local environment.
         """
         default_config = {
-            "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
-            "spark.jars.packages": "org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:1.3.1,"
-            "org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0,"
-            "org.apache.spark:spark-avro_2.12:3.3.0",
             "spark.sql.session.timeZone": "UTC",
             "spark.sql.catalog.local": "org.apache.iceberg.spark.SparkCatalog",
             "spark.sql.catalog.local.type": "hadoop",
@@ -112,11 +104,10 @@ def get_aws_spark_config(
             dict[str, str]: dictionary of Spark configurations for the AWS environment.
         """
         default_config = {
-            "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
             "spark.sql.catalog.glue_catalog": "org.apache.iceberg.spark.SparkCatalog",
             "spark.sql.catalog.glue_catalog.catalog-impl": "org.apache.iceberg.aws.glue.GlueCatalog",
             "spark.sql.catalog.glue_catalog.io-impl": "org.apache.iceberg.aws.s3.S3FileIO",
-            "spark.sql.catalog.glue_catalog.warehouse": "./tmp/warehouse",
+            # "spark.sql.catalog.glue_catalog.warehouse": "", TODO: Validate if needed
         }
         if extra_config:
             default_config.update(extra_config)
diff --git a/tests/conftest.py b/tests/conftest.py
index 95299e8..d80db8a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -36,7 +36,7 @@ def spark_session() -> SparkSession:
         "spark.sql.catalog.spark_catalog.type": "hive",
         "spark.sql.catalog.local": "org.apache.iceberg.spark.SparkCatalog",
         "spark.sql.catalog.local.type": "hadoop",
-        "spark.sql.catalog.local.warehouse": "./tmp/warehouse",
+        "spark.sql.catalog.local.warehouse": "./tmp/test_warehouse",
         "spark.sql.defaultCatalog": "local",
     }
 
@@ -47,7 +47,7 @@ def spark_session() -> SparkSession:
 
     spark_session = (
         SparkSession.builder.master("local[*]")
-        .appName("LocalSparkleApp")
+        .appName("LocalTestSparkleApp")
         .config(conf=spark_conf)
     )
 
diff --git a/tests/unit/writer/test_iceberg_writer.py b/tests/unit/writer/test_iceberg_writer.py
index 1b7f61b..06f42b1 100644
--- a/tests/unit/writer/test_iceberg_writer.py
+++ b/tests/unit/writer/test_iceberg_writer.py
@@ -9,7 +9,7 @@ TEST_DB = "default"
 TEST_TABLE = "test_table"
-WAREHOUSE = "./tmp/warehouse"
+WAREHOUSE = "./tmp/test_warehouse"
 CATALOG = "glue_catalog"