Merge branch 'main' into pr/issues/524
YANG-DB authored Aug 29, 2024
2 parents 9ef0b3c + aa509ac commit c2a6394
Showing 21 changed files with 1,328 additions and 748 deletions.
@@ -0,0 +1,20 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

/**
* Trait for writing the result of a query execution to external data storage.
*/
trait QueryResultWriter {

/**
* Writes the given DataFrame, which represents the result of a query execution, to external
* data storage based on the provided FlintStatement metadata.
*/
def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
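
To make the contract concrete, here is a hypothetical, minimal QueryResultWriter (not part of this commit) that persists each result set to the file system with the plain Spark writer API; the output path layout and the assumption that FlintStatement exposes a public queryId are illustrative only.

package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

class ParquetQueryResultWriter(outputPath: String) extends QueryResultWriter {

  // Write one result set per query, keyed by the statement's queryId
  // (assumes queryId is a public member; the path layout is made up for the example).
  override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    dataFrame.write
      .mode("overwrite")
      .parquet(s"$outputPath/${flintStatement.queryId}")
  }
}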
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

/**
* Trait defining the interface for managing interactive sessions.
*/
trait SessionManager {

/**
* Retrieves metadata about the session manager.
*/
def getSessionContext: Map[String, Any]

/**
* Fetches the details of a specific session.
*/
def getSessionDetails(sessionId: String): Option[InteractiveSession]

/**
* Updates the details of a specific session.
*/
def updateSessionDetails(
sessionDetails: InteractiveSession,
updateMode: SessionUpdateMode): Unit

/**
* Records a heartbeat for a specific session to indicate it is still active.
*/
def recordHeartbeat(sessionId: String): Unit
}

object SessionUpdateMode extends Enumeration {
type SessionUpdateMode = Value
val UPDATE, UPSERT, UPDATE_IF = Value
}
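
For illustration only, a session manager could be as simple as an in-memory store, e.g. for unit tests. The sketch below is not part of this commit and assumes InteractiveSession exposes a sessionId member.

package org.apache.spark.sql

import scala.collection.concurrent.TrieMap

import org.opensearch.flint.common.model.InteractiveSession

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

class InMemorySessionManager extends SessionManager {

  // Session and heartbeat bookkeeping kept in process memory.
  private val sessions = TrieMap.empty[String, InteractiveSession]
  private val heartbeats = TrieMap.empty[String, Long]

  override def getSessionContext: Map[String, Any] = Map("storage" -> "in-memory")

  override def getSessionDetails(sessionId: String): Option[InteractiveSession] =
    sessions.get(sessionId)

  override def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit = updateMode match {
    case SessionUpdateMode.UPSERT =>
      // Assumes InteractiveSession exposes a sessionId member.
      sessions.put(sessionDetails.sessionId, sessionDetails)
    case SessionUpdateMode.UPDATE | SessionUpdateMode.UPDATE_IF =>
      // UPDATE only overwrites an existing entry; a real UPDATE_IF would also check
      // a sequence number or version for optimistic concurrency.
      if (sessions.contains(sessionDetails.sessionId)) {
        sessions.put(sessionDetails.sessionId, sessionDetails)
      }
  }

  override def recordHeartbeat(sessionId: String): Unit =
    heartbeats.put(sessionId, System.currentTimeMillis())
}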
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

/**
* Trait defining the interface for managing FlintStatement execution. For example, in FlintREPL,
* multiple FlintStatements run in a micro-batch within the same session.
*
* This interface can also apply to other Spark entry points such as FlintJob.
*/
trait StatementExecutionManager {

/**
* Prepares the execution of each individual statement.
*/
def prepareStatementExecution(): Either[String, Unit]

/**
* Executes a specific statement and returns the resulting Spark DataFrame.
*/
def executeStatement(statement: FlintStatement): DataFrame

/**
* Retrieves the next statement to be executed.
*/
def getNextStatement(): Option[FlintStatement]

/**
* Updates a specific statement.
*/
def updateStatement(statement: FlintStatement): Unit

/**
* Terminates the statement execution lifecycle.
*/
def terminateStatementsExecution(): Unit
}
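
The intended consumption pattern, sketched hypothetically below (this is not how FlintREPL wires the pieces together in this commit), is to prepare once, drain statements until none remain, and always terminate the lifecycle.

package org.apache.spark.sql

object StatementLoopSketch {

  // Prepare once, then drain statements until the manager reports none remaining,
  // releasing resources even if execution fails part-way.
  def runStatements(manager: StatementExecutionManager, writer: QueryResultWriter): Unit = {
    manager.prepareStatementExecution() match {
      case Left(error) =>
        throw new IllegalStateException(s"Failed to prepare statement execution: $error")
      case Right(_) =>
        try {
          var next = manager.getNextStatement()
          while (next.isDefined) {
            val statement = next.get
            val result = manager.executeStatement(statement)
            writer.writeDataFrame(result, statement)
            // Persist whatever state transition the manager recorded on the statement.
            manager.updateStatement(statement)
            next = manager.getNextStatement()
          }
        } finally {
          manager.terminateStatementsExecution()
        }
    }
  }
}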
@@ -65,7 +65,7 @@ class FlintStatement(

// Does not include context, which could contain sensitive information.
override def toString: String =
s"FlintStatement(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)"
s"FlintStatement(state=$state, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)"
}

object FlintStatement {
@@ -14,6 +14,8 @@ import org.json4s.JsonAST.{JArray, JString}
import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization

import org.apache.spark.internal.Logging

object SessionStates {
val RUNNING = "running"
val DEAD = "dead"
@@ -52,7 +54,8 @@ class InteractiveSession(
val excludedJobIds: Seq[String] = Seq.empty[String],
val error: Option[String] = None,
sessionContext: Map[String, Any] = Map.empty[String, Any])
extends ContextualDataStore {
extends ContextualDataStore
with Logging {
context = sessionContext // Initialize the context from the constructor

def running(): Unit = state = SessionStates.RUNNING
@@ -96,6 +99,7 @@ object InteractiveSession {
// Replace extractOpt with jsonOption and map
val excludeJobIds: Seq[String] = meta \ "excludeJobIds" match {
case JArray(lst) => lst.map(_.extract[String])
case JString(s) => Seq(s)
case _ => Seq.empty[String]
}
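
The added JString case means excludeJobIds in the session metadata can now be either a JSON array or a single string. A standalone sketch of the same parsing logic (the example JSON values are made up):

import org.json4s._
import org.json4s.native.JsonMethods.parse

object ExcludeJobIdsSketch {
  implicit val formats: Formats = DefaultFormats

  def excludeJobIds(metaJson: String): Seq[String] =
    parse(metaJson) \ "excludeJobIds" match {
      case JArray(lst) => lst.map(_.extract[String])
      case JString(s) => Seq(s)
      case _ => Seq.empty[String]
    }

  // excludeJobIds("""{"excludeJobIds": ["job-1", "job-2"]}""")  returns Seq("job-1", "job-2")
  // excludeJobIds("""{"excludeJobIds": "job-1"}""")             returns Seq("job-1")
}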

@@ -22,7 +22,6 @@
* Abstract OpenSearch Reader.
*/
public abstract class OpenSearchReader implements FlintReader {

@VisibleForTesting
/** Search request source builder. */
public final SearchRequest searchRequest;
@@ -235,6 +235,15 @@ object FlintSparkConf {
FlintConfig("spark.metadata.accessAWSCredentialsProvider")
.doc("AWS credentials provider for metadata access permission")
.createOptional()
val CUSTOM_SESSION_MANAGER =
FlintConfig("spark.flint.job.customSessionManager")
.createOptional()
val CUSTOM_STATEMENT_MANAGER =
FlintConfig("spark.flint.job.customStatementManager")
.createOptional()
val CUSTOM_QUERY_RESULT_WRITER =
FlintConfig("spark.flint.job.customQueryResultWriter")
.createOptional()
}
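
The three new spark.flint.job.custom* options are optional and, judging by their names, presumably take fully qualified class names of custom implementations; the snippet below is a hypothetical illustration (the class names are placeholders, not real classes).

import org.apache.spark.SparkConf

object CustomPluginConfSketch {
  // Placeholder class names; interpreting the values as fully qualified class
  // names is an assumption based on the option names.
  val conf: SparkConf = new SparkConf()
    .set("spark.flint.job.customSessionManager", "com.example.flint.MySessionManager")
    .set("spark.flint.job.customStatementManager", "com.example.flint.MyStatementExecutionManager")
    .set("spark.flint.job.customQueryResultWriter", "com.example.flint.MyQueryResultWriter")
}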

/**
@@ -98,9 +98,15 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
* JobOperator instance to accommodate specific runtime requirements.
*/
val job =
JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount)
job.envinromentProvider = new MockEnvironment(
Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
JobOperator(
appId,
jobRunId,
spark,
query,
dataSourceName,
resultIndex,
true,
streamingRunningCount)
job.terminateJVM = false
job.start()
}
@@ -169,7 +169,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
"spark.flint.job.queryLoopExecutionFrequency",
queryLoopExecutionFrequency.toString)

FlintREPL.envinromentProvider = new MockEnvironment(
FlintREPL.environmentProvider = new MockEnvironment(
Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
FlintREPL.enableHiveSupport = false
FlintREPL.terminateJVM = false
@@ -145,7 +145,6 @@ trait OpenSearchSuite extends BeforeAndAfterAll {

val response =
openSearchClient.bulk(request, RequestOptions.DEFAULT)

assume(
!response.hasFailures,
s"bulk index docs to $index failed: ${response.buildFailureMessage()}")
@@ -11,14 +11,13 @@ import scala.concurrent.duration.Duration
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}

case class CommandContext(
applicationId: String,
jobId: String,
spark: SparkSession,
dataSource: String,
resultIndex: String,
sessionId: String,
flintSessionIndexUpdater: OpenSearchUpdater,
osClient: OSClient,
sessionIndex: String,
jobId: String,
sessionManager: SessionManager,
queryResultWriter: QueryResultWriter,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long,
@@ -12,7 +12,6 @@ import org.opensearch.flint.core.storage.FlintReader
case class CommandState(
recordedLastActivityTime: Long,
recordedVerificationResult: VerificationResult,
flintReader: FlintReader,
futureMappingCheck: Future[Either[String, Unit]],
futurePrepareQueryExecution: Future[Either[String, Unit]],
executionContext: ExecutionContextExecutor,
recordedLastCanPickCheckTime: Long)
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types._
*/
object FlintJob extends Logging with FlintJobExecutor {
def main(args: Array[String]): Unit = {
val (queryOption, resultIndex) = parseArgs(args)
val (queryOption, resultIndexOption) = parseArgs(args)

val conf = createSparkConf()
val jobType = conf.get("spark.flint.job.type", "batch")
@@ -41,6 +41,9 @@ object FlintJob extends Logging with FlintJobExecutor {
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
@@ -52,13 +55,19 @@
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val applicationId =
environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")

val streamingRunningCount = new AtomicInteger(0)
val jobOperator =
JobOperator(
applicationId,
jobId,
createSparkSession(conf),
query,
dataSource,
resultIndex,
resultIndexOption.get,
jobType.equalsIgnoreCase("streaming"),
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
@@ -41,7 +41,7 @@ trait FlintJobExecutor {

var currentTimeProvider: TimeProvider = new RealTimeProvider()
var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory()
var envinromentProvider: EnvironmentProvider = new RealEnvironment()
var environmentProvider: EnvironmentProvider = new RealEnvironment()
var enableHiveSupport: Boolean = true
// terminate JVM in the presence of non-daemon threads before exiting
var terminateJVM = true
@@ -190,6 +190,7 @@
}
}

// scalastyle:off
/**
* Create a new formatted dataframe with json result, json schema and EMR_STEP_ID.
*
@@ -201,6 +202,8 @@
* dataframe with result, schema and emr step id
*/
def getFormattedData(
applicationId: String,
jobId: String,
result: DataFrame,
spark: SparkSession,
dataSource: String,
@@ -231,14 +234,13 @@
// after consumed the query result. Streaming query shuffle data is cleaned after each
// microBatch execution.
cleaner.cleanUp(spark)

// Create the data rows
val rows = Seq(
(
resultToSave,
resultSchemaToSave,
envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"),
envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
jobId,
applicationId,
dataSource,
"SUCCESS",
"",
@@ -254,6 +256,8 @@
}

def constructErrorDF(
applicationId: String,
jobId: String,
spark: SparkSession,
dataSource: String,
status: String,
@@ -270,8 +274,8 @@
(
null,
null,
envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"),
envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
jobId,
applicationId,
dataSource,
status.toUpperCase(Locale.ROOT),
error,
@@ -396,6 +400,8 @@
}

def executeQuery(
applicationId: String,
jobId: String,
spark: SparkSession,
query: String,
dataSource: String,
@@ -409,6 +415,8 @@
val result: DataFrame = spark.sql(query)
// Get Data
getFormattedData(
applicationId,
jobId,
result,
spark,
dataSource,
@@ -493,16 +501,21 @@
}
}

def parseArgs(args: Array[String]): (Option[String], String) = {
/**
* Before OS 2.13, the entry point received two arguments: query and result index. Starting
* from OS 2.13, query is optional for FlintREPL, and since Flint 0.5, result index is also
* optional for non-OpenSearch result persistence.
*/
def parseArgs(args: Array[String]): (Option[String], Option[String]) = {
args match {
case Array() =>
(None, None)
case Array(resultIndex) =>
(None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument
(None, Some(resultIndex))
case Array(query, resultIndex) =>
(
Some(query),
resultIndex
) // Before OS 2.13, there are two arguments, the second one is resultIndex
case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.")
(Some(query), Some(resultIndex))
case _ =>
logAndThrow("Unsupported number of arguments. Expected no more than two arguments.")
}
}
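
For reference, the new contract accepts zero, one, or two arguments (the values below are illustrative):

// parseArgs(Array.empty[String])                returns (None, None)
// parseArgs(Array("result_index"))              returns (None, Some("result_index"))
// parseArgs(Array("SELECT 1", "result_index"))  returns (Some("SELECT 1"), Some("result_index"))
// Three or more arguments fail with "Unsupported number of arguments. Expected no more than two arguments."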
