Skip to content

Commit

Permalink
Merge branch 'main' into terminate-streaming-job-if-data-deleted
Browse files Browse the repository at this point in the history
  • Loading branch information
dai-chen committed Aug 7, 2024
2 parents d2e3ca4 + 5be9be6 commit 68d4b72
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY
import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID}
import org.apache.spark.sql.util.MockEnvironment
import org.apache.spark.util.ThreadUtils
Expand Down Expand Up @@ -130,19 +131,20 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {

def createSession(jobId: String, excludeJobId: String): Unit = {
val docs = Seq(s"""{
| "state": "running",
| "lastUpdateTime": 1698796582978,
| "applicationId": "00fd777k3k3ls20p",
| "error": "",
| "sessionId": ${sessionId},
| "jobId": \"${jobId}\",
| "type": "session",
| "excludeJobIds": [\"${excludeJobId}\"]
|}""".stripMargin)
| "state": "running",
| "lastUpdateTime": 1698796582978,
| "applicationId": "00fd777k3k3ls20p",
| "error": "",
| "sessionId": ${sessionId},
| "jobId": \"${jobId}\",
| "type": "session",
| "excludeJobIds": [\"${excludeJobId}\"]
|}""".stripMargin)
index(requestIndex, oneNodeSetting, requestIndexMapping, docs)
}

def startREPL(): Future[Unit] = {
def startREPL(queryLoopExecutionFrequency: Long = DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
: Future[Unit] = {
val prefix = "flint-repl-test"
val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
Expand All @@ -164,6 +166,10 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort))
System.setProperty(REFRESH_POLICY.key, "true")

System.setProperty(
"spark.flint.job.queryLoopExecutionFrequency",
queryLoopExecutionFrequency.toString)

FlintREPL.envinromentProvider = new MockEnvironment(
Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
FlintREPL.enableHiveSupport = false
Expand Down Expand Up @@ -266,7 +272,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin
// submitted from last year. We won't pick it up
val lateSelectStatementId =
submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L)
submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L)

// clean up
val dropStatement =
Expand Down Expand Up @@ -485,6 +491,99 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
}
}

test("query loop should exit with inactivity timeout due to large query loop freq") {
try {
createSession(jobRunId, "")
threadLocalFuture.set(startREPL(5000L))
val createStatement =
s"""
| CREATE TABLE $testTable
| (
| name STRING,
| age INT
| )
| USING CSV
| OPTIONS (
| header 'false',
| delimiter '\\t'
| )
|""".stripMargin
submitQuery(s"${makeJsonCompliant(createStatement)}", "119")

val insertStatement =
s"""
| INSERT INTO $testTable
| VALUES ('Hello', 30)
| """.stripMargin
submitQuery(s"${makeJsonCompliant(insertStatement)}", "120")

val selectQueryId = "121"
val selectQueryStartTime = System.currentTimeMillis()
val selectQuery = s"SELECT name, age FROM $testTable".stripMargin
val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId)

val lateSelectQueryId = "122"
val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin
// old query
val lateSelectStatementId =
submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L)

// clean up
val dropStatement =
s"""DROP TABLE $testTable""".stripMargin
submitQuery(s"${makeJsonCompliant(dropStatement)}", "999")

val selectQueryValidation: REPLResult => Boolean = result => {
assert(
result.results.size == 1,
s"expected result size is 1, but got ${result.results.size}")
val expectedResult = "{'name':'Hello','age':30}"
assert(
result.results(0).equals(expectedResult),
s"expected result is $expectedResult, but got ${result.results(0)}")
assert(
result.schemas.size == 2,
s"expected schema size is 2, but got ${result.schemas.size}")
val expectedZerothSchema = "{'column_name':'name','data_type':'string'}"
assert(
result.schemas(0).equals(expectedZerothSchema),
s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}")
val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}"
assert(
result.schemas(1).equals(expectedFirstSchema),
s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}")
commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime)
successValidation(result)
true
}
pollForResultAndAssert(selectQueryValidation, selectQueryId)
assert(
!awaitConditionForStatementOrTimeout(
statement => {
statement.state == "success"
},
selectStatementId),
s"Fail to verify for $selectStatementId.")

assert(
awaitConditionForStatementOrTimeout(
statement => {
statement.state != "waiting"
},
lateSelectStatementId),
s"Fail to verify for $lateSelectStatementId.")
} catch {
case e: Exception =>
logError("Unexpected exception", e)
assert(false, "Unexpected exception")
} finally {
waitREPLStop(threadLocalFuture.get())
threadLocalFuture.remove()

// shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT.
}
}

/**
* JSON does not support raw newlines (\n) in string values. All newlines must be escaped or
* removed when inside a JSON string. The same goes for tab characters, which should be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ case class CommandContext(
jobId: String,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long)
queryWaitTimeMillis: Long,
queryLoopExecutionFrequency: Long)
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,20 @@ import org.opensearch.search.sort.SortOrder
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.FlintREPLConfConstants._
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.ThreadUtils

object FlintREPLConfConstants {
val HEARTBEAT_INTERVAL_MILLIS = 60000L
val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L
val INITIAL_DELAY_MILLIS = 3000L
val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L
}

/**
* Spark SQL Application entrypoint
*
Expand All @@ -48,13 +59,6 @@ import org.apache.spark.util.ThreadUtils
*/
object FlintREPL extends Logging with FlintJobExecutor {

private val HEARTBEAT_INTERVAL_MILLIS = 60000L
private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
val INITIAL_DELAY_MILLIS = 3000L
val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L

@volatile var earlyExitFlag: Boolean = false

def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = {
Expand Down Expand Up @@ -134,7 +138,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
SECONDS)
val queryWaitTimeoutMillis: Long =
conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)

val queryLoopExecutionFrequency: Long =
conf.getLong(
"spark.flint.job.queryLoopExecutionFrequency",
DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC)

Expand Down Expand Up @@ -199,7 +206,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobId,
queryExecutionTimeoutSecs,
inactivityLimitMillis,
queryWaitTimeoutMillis)
queryWaitTimeoutMillis,
queryLoopExecutionFrequency)
exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) {
queryLoop(commandContext)
}
Expand Down Expand Up @@ -342,7 +350,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
}

def queryLoop(commandContext: CommandContext): Unit = {
// 1 thread for updating heart beat
// 1 thread for async query execution
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

Expand Down Expand Up @@ -392,7 +400,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
flintReader.close()
}

Thread.sleep(100)
Thread.sleep(commandContext.queryLoopExecutionFrequency)
}
} finally {
if (threadPool != null) {
Expand Down Expand Up @@ -555,8 +563,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
while (canProceed) {
val currentTime = currentTimeProvider.currentEpochMillis()

// Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed
if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) {
// Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed
if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) {
canPickNextStatementResult =
canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
lastCanPickCheckTime = currentTime
Expand Down
Loading

0 comments on commit 68d4b72

Please sign in to comment.