Clean up and fix UTs
Signed-off-by: Louis Chu <[email protected]>
noCharger committed Aug 28, 2024
1 parent 93f2e71 commit 28bacdb
Showing 12 changed files with 648 additions and 393 deletions.
@@ -14,6 +14,8 @@ import org.json4s.JsonAST.{JArray, JString}
import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization

import org.apache.spark.internal.Logging

object SessionStates {
val RUNNING = "running"
val DEAD = "dead"
@@ -52,7 +54,8 @@ class InteractiveSession(
val excludedJobIds: Seq[String] = Seq.empty[String],
val error: Option[String] = None,
sessionContext: Map[String, Any] = Map.empty[String, Any])
extends ContextualDataStore {
extends ContextualDataStore
with Logging {
context = sessionContext // Initialize the context from the constructor

def running(): Unit = state = SessionStates.RUNNING
@@ -96,6 +99,7 @@ object InteractiveSession {
// Replace extractOpt with jsonOption and map
val excludeJobIds: Seq[String] = meta \ "excludeJobIds" match {
case JArray(lst) => lst.map(_.extract[String])
case JString(s) => Seq(s)
case _ => Seq.empty[String]
}
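For reference, the lenient parsing added above accepts excludeJobIds either as a JSON array of strings or as a single string, with anything else falling back to an empty sequence. A minimal, self-contained json4s sketch of that behavior (the helper name and sample JSON below are illustrative only, not part of the commit):

```scala
import org.json4s._
import org.json4s.native.JsonMethods.parse

object ExcludeJobIdsSketch extends App {
  implicit val formats: Formats = DefaultFormats

  // Hypothetical helper mirroring the pattern match in InteractiveSession.
  def parseExcludeJobIds(json: String): Seq[String] =
    parse(json) \ "excludeJobIds" match {
      case JArray(lst) => lst.map(_.extract[String])
      case JString(s) => Seq(s)
      case _ => Seq.empty[String]
    }

  println(parseExcludeJobIds("""{"excludeJobIds": ["job-1", "job-2"]}""")) // List(job-1, job-2)
  println(parseExcludeJobIds("""{"excludeJobIds": "job-1"}""")) // List(job-1)
  println(parseExcludeJobIds("""{"other": 1}""")) // List()
}
```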

@@ -98,9 +98,15 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
* JobOperator instance to accommodate specific runtime requirements.
*/
val job =
JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount)
job.envinromentProvider = new MockEnvironment(
Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
JobOperator(
appId,
jobRunId,
spark,
query,
dataSourceName,
resultIndex,
true,
streamingRunningCount)
job.terminateJVM = false
job.start()
}
@@ -169,7 +169,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
"spark.flint.job.queryLoopExecutionFrequency",
queryLoopExecutionFrequency.toString)

FlintREPL.envinromentProvider = new MockEnvironment(
FlintREPL.environmentProvider = new MockEnvironment(
Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
FlintREPL.enableHiveSupport = false
FlintREPL.terminateJVM = false
@@ -11,14 +11,14 @@ import scala.concurrent.duration.Duration
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}

case class CommandContext(
val spark: SparkSession,
val dataSource: String,
val sessionId: String,
val sessionManager: SessionManager,
val jobId: String,
var statementsExecutionManager: StatementExecutionManager,
val queryResultWriter: QueryResultWriter,
val queryExecutionTimeout: Duration,
val inactivityLimitMillis: Long,
val queryWaitTimeMillis: Long,
val queryLoopExecutionFrequency: Long)
applicationId: String,
jobId: String,
spark: SparkSession,
dataSource: String,
sessionId: String,
sessionManager: SessionManager,
queryResultWriter: QueryResultWriter,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long,
queryLoopExecutionFrequency: Long)
@@ -55,9 +55,15 @@ object FlintJob extends Logging with FlintJobExecutor {
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val applicationId =
environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")

val streamingRunningCount = new AtomicInteger(0)
val jobOperator =
JobOperator(
applicationId,
jobId,
createSparkSession(conf),
query,
dataSource,
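The hunks above and the test changes earlier share one pattern: applicationId and jobId are resolved once at the entry point through an injectable environment provider (swapped for a MockEnvironment in the integration tests) and then passed down explicitly. A rough, self-contained sketch of that pattern, assuming the getEnvVar(name, default) shape shown in the diff; the class bodies below are illustrative re-implementations, not the project's actual sources:

```scala
// Trait/class names (EnvironmentProvider, RealEnvironment, MockEnvironment,
// getEnvVar) appear in the diff; these minimal bodies are assumptions.
trait EnvironmentProvider {
  def getEnvVar(name: String, default: String): String
}

// Production lookup: read the process environment, falling back to a default.
class RealEnvironment extends EnvironmentProvider {
  def getEnvVar(name: String, default: String): String = sys.env.getOrElse(name, default)
}

// Test double backed by a fixed map, as wired up in FlintREPLITSuite.
class MockEnvironment(values: Map[String, String]) extends EnvironmentProvider {
  def getEnvVar(name: String, default: String): String = values.getOrElse(name, default)
}

object EnvLookupSketch extends App {
  val env: EnvironmentProvider = new MockEnvironment(
    Map("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> "app-123", "SERVERLESS_EMR_JOB_ID" -> "job-456"))

  val applicationId = env.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
  val jobId = env.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")
  println(s"applicationId=$applicationId, jobId=$jobId") // applicationId=app-123, jobId=job-456
}
```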
@@ -41,7 +41,7 @@ trait FlintJobExecutor {

var currentTimeProvider: TimeProvider = new RealTimeProvider()
var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory()
var envinromentProvider: EnvironmentProvider = new RealEnvironment()
var environmentProvider: EnvironmentProvider = new RealEnvironment()
var enableHiveSupport: Boolean = true
// terminate JVM in the presence of non-daemon threads before exiting
var terminateJVM = true
@@ -190,6 +190,7 @@ trait FlintJobExecutor {
}
}

// scalastyle:off
/**
* Create a new formatted dataframe with json result, json schema and EMR_STEP_ID.
*
@@ -201,6 +202,8 @@
* dataframe with result, schema and emr step id
*/
def getFormattedData(
applicationId: String,
jobId: String,
result: DataFrame,
spark: SparkSession,
dataSource: String,
@@ -231,14 +234,13 @@
// after consumed the query result. Streaming query shuffle data is cleaned after each
// microBatch execution.
cleaner.cleanUp(spark)

// Create the data rows
val rows = Seq(
(
resultToSave,
resultSchemaToSave,
envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"),
envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
jobId,
applicationId,
dataSource,
"SUCCESS",
"",
@@ -254,6 +256,8 @@
}

def constructErrorDF(
applicationId: String,
jobId: String,
spark: SparkSession,
dataSource: String,
status: String,
@@ -270,8 +274,8 @@
(
null,
null,
envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"),
envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
jobId,
applicationId,
dataSource,
status.toUpperCase(Locale.ROOT),
error,
@@ -396,6 +400,8 @@
}

def executeQuery(
applicationId: String,
jobId: String,
spark: SparkSession,
query: String,
dataSource: String,
@@ -409,6 +415,8 @@
val result: DataFrame = spark.sql(query)
// Get Data
getFormattedData(
applicationId,
jobId,
result,
spark,
dataSource,
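Taken together, getFormattedData, constructErrorDF, and executeQuery now receive applicationId and jobId as parameters instead of reading the SERVERLESS_EMR_* variables inline. A standalone Spark sketch of the resulting shape of a result/error row; the column names and sample values are hypothetical, and the real methods use the full Flint result schema:

```scala
import org.apache.spark.sql.SparkSession

object ErrorRowSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("error-row-sketch").getOrCreate()
  import spark.implicits._

  // Hypothetical values: in FlintJob these are resolved once at startup and
  // threaded through, rather than looked up from the environment here.
  val applicationId = "app-123"
  val jobId = "job-456"
  val dataSource = "my_datasource"
  val error = "Syntax error in query"

  val df = Seq((jobId, applicationId, dataSource, "FAILED", error))
    .toDF("jobRunId", "applicationId", "dataSource", "status", "error")

  df.show(truncate = false)
  spark.stop()
}
```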