Replace static extensions with params (#61)
* replace the static extension

Signed-off-by: YANGDB <[email protected]>

* replace the static spark extension with input args

Signed-off-by: YANGDB <[email protected]>

* update README.md with the additional spark extensions input argument

Signed-off-by: YANGDB <[email protected]>

* update README.md with the additional spark extensions input argument

Signed-off-by: YANGDB <[email protected]>

---------

Signed-off-by: YANGDB <[email protected]>
YANG-DB authored Oct 5, 2023
1 parent d3f72e6 commit e478410
Showing 3 changed files with 61 additions and 39 deletions.
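In practical terms, the commit stops hardcoding the Spark SQL extensions class in the jobs and takes it from the job's input arguments instead. A minimal sketch of that idea (the helper name and argument handling below are assumptions for illustration, not code from this commit):

```scala
import org.apache.spark.SparkConf

// Hypothetical helper: only register an extensions class list when one is supplied.
def confWithExtensions(appName: String, extensions: Option[String]): SparkConf = {
  val conf = new SparkConf().setAppName(appName)
  extensions.filter(_.nonEmpty).foreach(ext => conf.set("spark.sql.extensions", ext))
  conf
}

// e.g. confWithExtensions("FlintJob", Some("org.opensearch.flint.spark.FlintSparkExtensions"))
```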
spark-sql-application/README.md (1 addition, 0 deletions)
@@ -25,6 +25,7 @@ FlintJob is designed for EMR Serverless Spark, executing SQL queries and storing
+ Spark 3.3.1
+ Scala 2.12.15
+ flint-spark-integration
+ ppl-spark-integration

## Usage

FlintJob.scala
@@ -85,7 +85,6 @@ object FlintJob extends Logging {
def createSparkConf(): SparkConf = {
new SparkConf()
.setAppName("FlintJob")
.set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
}

def createSparkSession(conf: SparkConf): SparkSession = {
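With the hardcoded value removed, createSparkConf() above now only sets the application name. Since new SparkConf() also picks up spark.* properties from the driver's system properties, the extensions list can equally be supplied at submit time (for example --conf spark.sql.extensions=...) rather than in code. A small illustrative sketch, not part of the diff:

```scala
val conf = createSparkConf()
// None unless an extensions list was provided at submit time or set by the caller.
val extensions: Option[String] = conf.getOption("spark.sql.extensions")
```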
SQLJob.scala
@@ -12,59 +12,81 @@ import org.apache.spark.sql.types._
/**
* Spark SQL Application entrypoint
*
* @param args(0)
* sql query
* @param args(1)
* opensearch index name
* @param args(2-6)
* opensearch connection values required for flint-integration jar.
* host, port, scheme, auth, region respectively.
* @param args (0)
* sql query
* @param args (1)
* opensearch index name
* @param args (2-6)
* opensearch connection values required for flint-integration jar.
* host, port, scheme, auth, region respectively.
* @return
* write sql query result to given opensearch index
* write sql query result to given opensearch index
*/
case class JobConfig(
query: String,
index: String,
host: String,
port: String,
scheme: String,
auth: String,
region: String
)

object SQLJob {
def main(args: Array[String]) {
// Get the SQL query and Opensearch Config from the command line arguments
val query = args(0)
val index = args(1)
val host = args(2)
val port = args(3)
val scheme = args(4)
val auth = args(5)
val region = args(6)

val conf: SparkConf = new SparkConf()
private def parseArgs(args: Array[String]): JobConfig = {
if (args.length < 7) {
throw new IllegalArgumentException("Insufficient arguments provided! - args: [extensions, query, index, host, port, scheme, auth, region]")
}

JobConfig(
query = args(0),
index = args(1),
host = args(2),
port = args(3),
scheme = args(4),
auth = args(5),
region = args(6)
)
}

def createSparkConf(config: JobConfig): SparkConf = {
new SparkConf()
.setAppName("SQLJob")
.set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
.set("spark.datasource.flint.host", host)
.set("spark.datasource.flint.port", port)
.set("spark.datasource.flint.scheme", scheme)
.set("spark.datasource.flint.auth", auth)
.set("spark.datasource.flint.region", region)
.set("spark.datasource.flint.host", config.host)
.set("spark.datasource.flint.port", config.port)
.set("spark.datasource.flint.scheme", config.scheme)
.set("spark.datasource.flint.auth", config.auth)
.set("spark.datasource.flint.region", config.region)
}
def main(args: Array[String]) {
val config = parseArgs(args)

val sparkConf = createSparkConf(config)


// Create a SparkSession
val spark = SparkSession.builder().config(conf).enableHiveSupport().getOrCreate()
val spark = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()

try {
// Execute SQL query
val result: DataFrame = spark.sql(query)
val result: DataFrame = spark.sql(config.query)

// Get Data
val data = getFormattedData(result, spark)

// Write data to OpenSearch index
val aos = Map(
"host" -> host,
"port" -> port,
"scheme" -> scheme,
"auth" -> auth,
"region" -> region)
"host" -> config.host,
"port" -> config.port,
"scheme" -> config.scheme,
"auth" -> config.auth,
"region" -> config.region)

data.write
.format("flint")
.options(aos)
.mode("append")
.save(index)
.save(config.index)

} finally {
// Stop SparkSession
@@ -76,11 +98,11 @@ object SQLJob {
* Create a new formatted dataframe with json result, json schema and EMR_STEP_ID.
*
* @param result
* sql query result dataframe
* sql query result dataframe
* @param spark
* spark session
* spark session
* @return
* dataframe with result, schema and emr step id
* dataframe with result, schema and emr step id
*/
def getFormattedData(result: DataFrame, spark: SparkSession): DataFrame = {
// Create the schema dataframe
@@ -89,8 +111,8 @@
}
val resultSchema = spark.createDataFrame(spark.sparkContext.parallelize(schemaRows),
StructType(Seq(
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false))))
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false))))

// Define the data schema
val schema = StructType(Seq(
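For reference, the refactored SQLJob consumes seven positional arguments (query through region) in parseArgs; the IllegalArgumentException message also names extensions, but no extensions value is read in this hunk. A hypothetical invocation with placeholder values (inside SQLJob, since parseArgs is private):

```scala
// Placeholder values for illustration only.
val config = parseArgs(Array(
  "SELECT 1",    // query
  "sql-results", // index
  "localhost",   // host
  "9200",        // port
  "http",        // scheme
  "noauth",      // auth
  "us-west-2"    // region
))
val spark = SparkSession.builder().config(createSparkConf(config)).enableHiveSupport().getOrCreate()
```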
