refactor: accept readers in Sparkle class
farbodahm committed Sep 20, 2024
1 parent ffb0bea commit 5921885
Showing 4 changed files with 47 additions and 44 deletions.
74 changes: 43 additions & 31 deletions src/sparkle/application/__init__.py
@@ -11,28 +11,37 @@


 class Sparkle(abc.ABC):
-    """Base class for Spark applications."""
+    """Base class for Spark applications.
+
+    This class provides a foundation for Spark-based data processing applications.
+    It handles setting up the Spark session, configuring readers and writers,
+    and defining the main process logic through abstract methods that must be implemented
+    by subclasses.
+    """
 
     def __init__(
         self,
         config: Config,
+        readers: dict[str, type[Reader]],
         writers: list[Writer],
         spark_extensions: list[str] | None = None,
         spark_packages: list[str] | None = None,
         extra_spark_config: dict[str, str] | None = None,
     ):
-        """Sparkle's application initializer.
+        """Initializes the Sparkle application with the given configuration.
 
         Args:
-            config (Config): Configuration object containing application-specific settings.
-            writers (list[Writer]): list of Writer objects used for writing the processed data.
-            spark_extensions (list[str], optional): list of Spark session extensions to use.
-            spark_packages (list[str], optional): list of Spark packages to include.
-            extra_spark_config (dict[str, str], optional): Additional Spark configurations
-                to merge with default configurations.
+            config (Config): The configuration object containing application-specific settings.
+            readers (dict[str, type[Reader]]): A dictionary of readers for input data, keyed by source name.
+            writers (list[Writer]): A list of Writer objects used to output processed data.
+            spark_extensions (list[str], optional): A list of Spark session extensions to apply.
+            spark_packages (list[str], optional): A list of Spark packages to include in the session.
+            extra_spark_config (dict[str, str], optional): Additional Spark configurations to
+                merge with the default settings.
         """
         self.config = config
         self.writers = writers
+        self.readers = readers
         self.execution_env = config.execution_environment
         self.spark_config = config.get_spark_config(
             self.execution_env, extra_spark_config
@@ -43,16 +52,16 @@ def __init__(
         self.spark_session = self.get_spark_session(self.execution_env)
 
     def get_spark_session(self, env: ExecutionEnvironment) -> SparkSession:
-        """Create and return a Spark session based on the environment.
+        """Creates and returns a Spark session based on the environment.
 
         Args:
-            env (ExecutionEnvironment): The environment type (either LOCAL or AWS).
+            env (ExecutionEnvironment): The environment in which the Spark session is created (LOCAL or AWS).
 
         Returns:
-            SparkSession: The Spark session configured with the appropriate settings.
+            SparkSession: A Spark session configured for the specified environment.
 
         Raises:
-            ValueError: If the environment is unsupported.
+            ValueError: If an unsupported environment is provided.
         """
         if env == ExecutionEnvironment.LOCAL:
             return self._get_local_session()
@@ -62,10 +71,10 @@ def get_spark_session(self, env: ExecutionEnvironment) -> SparkSession:
raise ValueError(f"Unsupported environment: {env}")

def _get_local_session(self) -> SparkSession:
"""Create a Spark session for the local environment.
"""Creates a Spark session for local execution.
Returns:
SparkSession: Configured Spark session for local environment.
SparkSession: A configured Spark session for the local environment.
"""
spark_conf = SparkConf()
for key, value in self.spark_config.items():
@@ -87,15 +96,19 @@ def _get_local_session(self) -> SparkSession:
         return spark_session_builder.getOrCreate()
 
     def _get_aws_session(self) -> SparkSession:
-        """Create a Spark session for the AWS environment.
+        """Creates a Spark session for AWS Glue execution.
 
         Returns:
-            SparkSession: Configured Spark session for AWS environment.
+            SparkSession: A configured Spark session for the AWS Glue environment.
+
+        Raises:
+            ImportError: If the AWS Glue libraries are not available.
         """
         try:
             from awsglue.context import GlueContext  # type: ignore[import]
         except ImportError:
             logger.error("Could not import GlueContext. Is this running on AWS Glue?")
             raise
 
         spark_conf = SparkConf()
         for key, value in self.spark_config.items():
@@ -106,46 +119,45 @@ def _get_aws_session(self) -> SparkSession:

     @property
     def input(self) -> dict[str, Reader]:
-        """Dictionary of input DataReaders used in the application.
+        """Returns the input readers configured for the application.
 
         Returns:
-            dict[str, DataReader]: dictionary of input DataReaders used in the application, keyed by source name.
+            dict[str, Reader]: A dictionary mapping input sources to Reader instances.
 
         Raises:
-            ValueError: If no inputs are configured.
+            ValueError: If no readers are configured for the application.
         """
-        if len(self.config.inputs) == 0:
-            raise ValueError("No inputs configured.")
+        if len(self.readers) == 0:
+            raise ValueError("No readers configured.")
 
         return {
             key: value.with_config(self.config, self.spark_session)
-            for key, value in self.config.inputs.items()
+            for key, value in self.readers.items()
         }
 
     @abc.abstractmethod
     def process(self) -> DataFrame:
-        """Application's entrypoint responsible for the main business logic.
+        """Defines the application's data processing logic.
 
-        This method should be overridden in subclasses to process the
-        input DataFrames and return a resulting DataFrame.
+        This method should be implemented by subclasses to define how the input data
+        is processed and transformed into the desired output.
 
         Returns:
-            DataFrame: The resulting DataFrame after processing.
+            DataFrame: The resulting DataFrame after the processing logic is applied.
 
         Raises:
             NotImplementedError: If the subclass does not implement this method.
         """
         raise NotImplementedError("process method must be implemented by subclasses")
 
     def write(self, df: DataFrame) -> None:
-        """Write output DataFrame to the application's writer(s).
+        """Writes the output DataFrame to the application's configured writers.
 
-        The DataFrame is first persisted in memory to optimize writing
-        operations and then unpersisted after all writers have
-        completed their tasks.
+        The DataFrame is persisted in memory to optimize writing operations,
+        and once all writers have completed their tasks, the DataFrame is unpersisted.
 
         Args:
-            df (DataFrame): The DataFrame to be written to the destinations.
+            df (DataFrame): The DataFrame to be written to the output destinations.
         """
         df.persist(StorageLevel.MEMORY_ONLY)

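For orientation, here is a minimal sketch of a Sparkle subclass under the refactored API above: readers are now injected through the constructor (keyed by source name, passed as classes) instead of being read from Config.inputs, and the input property resolves each class with Reader.with_config(...). OrdersApp, OrdersReader, my_config, my_writer, and the read() call are hypothetical illustrations, not names from this repository:

    from pyspark.sql import DataFrame

    from sparkle.application import Sparkle


    class OrdersApp(Sparkle):
        """Hypothetical Sparkle subclass that deduplicates an orders feed."""

        def process(self) -> DataFrame:
            # self.input maps each source name to a Reader instance built via
            # Reader.with_config(self.config, self.spark_session).
            orders = self.input["orders"].read()  # assumes the reader exposes read()
            return orders.dropDuplicates(["order_id"])


    app = OrdersApp(
        config=my_config,                  # a populated Config (fields elided here)
        readers={"orders": OrdersReader},  # hypothetical Reader subclass, passed as a class
        writers=[my_writer],               # any configured Writer
    )
    app.write(app.process())  # persists, writes to every writer, then unpersists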
11 changes: 1 addition & 10 deletions src/sparkle/config/__init__.py
@@ -1,8 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass
 from enum import Enum
-import os
-from sparkle.reader import Reader
 from .kafka_config import KafkaReaderConfig, KafkaWriterConfig
 from .iceberg_config import IcebergConfig
 from .database_config import TableConfig
@@ -25,7 +23,6 @@ class Config:
         version (str): The version of the application.
         database_bucket (str): The S3 bucket where the database is stored.
         checkpoints_bucket (str): The S3 bucket where the Spark checkpoints are stored.
-        inputs (dict[str, Type[Reader]]): A dictionary mapping input sources to Reader types.
         execution_environment (ExecutionEnvironment): The environment where the app is executed.
         filesystem_scheme (str): The file system scheme, default is 's3a://'.
         spark_trigger (str): The Spark trigger configuration in JSON format.
@@ -40,7 +37,6 @@ class Config:
     version: str
     database_bucket: str
     checkpoints_bucket: str
-    inputs: dict[str, type[Reader]]
     execution_environment: ExecutionEnvironment = ExecutionEnvironment.LOCAL
     filesystem_scheme: str = "s3a://"
     spark_trigger: str = '{"once": True}'
@@ -86,10 +82,6 @@ def get_local_spark_config(
             dict[str, str]: dictionary of Spark configurations for the local environment.
         """
         default_config = {
-            "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
-            "spark.jars.packages": "org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:1.3.1,"
-            "org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0,"
-            "org.apache.spark:spark-avro_2.12:3.3.0",
             "spark.sql.session.timeZone": "UTC",
             "spark.sql.catalog.local": "org.apache.iceberg.spark.SparkCatalog",
             "spark.sql.catalog.local.type": "hadoop",
@@ -112,11 +104,10 @@ def get_aws_spark_config(
             dict[str, str]: dictionary of Spark configurations for the AWS environment.
         """
         default_config = {
-            "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
             "spark.sql.catalog.glue_catalog": "org.apache.iceberg.spark.SparkCatalog",
             "spark.sql.catalog.glue_catalog.catalog-impl": "org.apache.iceberg.aws.glue.GlueCatalog",
             "spark.sql.catalog.glue_catalog.io-impl": "org.apache.iceberg.aws.s3.S3FileIO",
-            "spark.sql.catalog.glue_catalog.warehouse": "./tmp/warehouse",
+            # "spark.sql.catalog.glue_catalog.warehouse": "", TODO: Validate if needed
         }
         if extra_config:
             default_config.update(extra_config)
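Note that the local defaults above no longer pin the Iceberg extension or the Iceberg/Kafka/Avro packages. A job that still needs them can supply them through the constructor parameters from the first file; a sketch, assuming spark_extensions and spark_packages are wired into spark.sql.extensions and spark.jars.packages as their names suggest (coordinates copied from the removed lines, hypothetical names reused from the earlier sketch):

    app = OrdersApp(
        config=my_config,
        readers={"orders": OrdersReader},  # hypothetical, as in the sketch above
        writers=[my_writer],
        spark_extensions=[
            "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
        ],
        spark_packages=[
            "org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:1.3.1",
            "org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0",
            "org.apache.spark:spark-avro_2.12:3.3.0",
        ],
    )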
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -36,7 +36,7 @@ def spark_session() -> SparkSession:
"spark.sql.catalog.spark_catalog.type": "hive",
"spark.sql.catalog.local": "org.apache.iceberg.spark.SparkCatalog",
"spark.sql.catalog.local.type": "hadoop",
"spark.sql.catalog.local.warehouse": "./tmp/warehouse",
"spark.sql.catalog.local.warehouse": "./tmp/test_warehouse",
"spark.sql.defaultCatalog": "local",
}

@@ -47,7 +47,7 @@ def spark_session() -> SparkSession:

     spark_session = (
         SparkSession.builder.master("local[*]")
-        .appName("LocalSparkleApp")
+        .appName("LocalTestSparkleApp")
         .config(conf=spark_conf)
     )

2 changes: 1 addition & 1 deletion tests/unit/writer/test_iceberg_writer.py
@@ -9,7 +9,7 @@

TEST_DB = "default"
TEST_TABLE = "test_table"
WAREHOUSE = "./tmp/warehouse"
WAREHOUSE = "./tmp/test_warehouse"
CATALOG = "glue_catalog"


