Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add config on query loop execution frequency #411

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY
import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID}
import org.apache.spark.sql.util.MockEnvironment
import org.apache.spark.util.ThreadUtils
Expand Down Expand Up @@ -130,19 +131,20 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {

def createSession(jobId: String, excludeJobId: String): Unit = {
val docs = Seq(s"""{
| "state": "running",
| "lastUpdateTime": 1698796582978,
| "applicationId": "00fd777k3k3ls20p",
| "error": "",
| "sessionId": ${sessionId},
| "jobId": \"${jobId}\",
| "type": "session",
| "excludeJobIds": [\"${excludeJobId}\"]
|}""".stripMargin)
| "state": "running",
| "lastUpdateTime": 1698796582978,
| "applicationId": "00fd777k3k3ls20p",
| "error": "",
| "sessionId": ${sessionId},
| "jobId": \"${jobId}\",
| "type": "session",
| "excludeJobIds": [\"${excludeJobId}\"]
|}""".stripMargin)
index(requestIndex, oneNodeSetting, requestIndexMapping, docs)
}

def startREPL(): Future[Unit] = {
def startREPL(queryLoopExecutionFrequency: Long = DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
: Future[Unit] = {
val prefix = "flint-repl-test"
val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
Expand All @@ -164,6 +166,10 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort))
System.setProperty(REFRESH_POLICY.key, "true")

System.setProperty(
"spark.flint.job.queryLoopExecutionFrequency",
queryLoopExecutionFrequency.toString)

FlintREPL.envinromentProvider = new MockEnvironment(
Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))
FlintREPL.enableHiveSupport = false
Expand Down Expand Up @@ -266,7 +272,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin
// submitted from last year. We won't pick it up
val lateSelectStatementId =
submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L)
submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L)

// clean up
val dropStatement =
Expand Down Expand Up @@ -485,6 +491,99 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
}
}

test("query loop should exit with inactivity timeout due to large query loop freq") {
try {
createSession(jobRunId, "")
threadLocalFuture.set(startREPL(5000L))
val createStatement =
s"""
| CREATE TABLE $testTable
| (
| name STRING,
| age INT
| )
| USING CSV
| OPTIONS (
| header 'false',
| delimiter '\\t'
| )
|""".stripMargin
submitQuery(s"${makeJsonCompliant(createStatement)}", "119")

val insertStatement =
s"""
| INSERT INTO $testTable
| VALUES ('Hello', 30)
| """.stripMargin
submitQuery(s"${makeJsonCompliant(insertStatement)}", "120")

val selectQueryId = "121"
val selectQueryStartTime = System.currentTimeMillis()
val selectQuery = s"SELECT name, age FROM $testTable".stripMargin
val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId)

val lateSelectQueryId = "122"
val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin
// old query
val lateSelectStatementId =
submitQuery(s"${makeJsonCompliant(lateSelectQuery)}", lateSelectQueryId, 1672101970000L)

// clean up
val dropStatement =
s"""DROP TABLE $testTable""".stripMargin
submitQuery(s"${makeJsonCompliant(dropStatement)}", "999")

val selectQueryValidation: REPLResult => Boolean = result => {
assert(
result.results.size == 1,
s"expected result size is 1, but got ${result.results.size}")
val expectedResult = "{'name':'Hello','age':30}"
assert(
result.results(0).equals(expectedResult),
s"expected result is $expectedResult, but got ${result.results(0)}")
assert(
result.schemas.size == 2,
s"expected schema size is 2, but got ${result.schemas.size}")
val expectedZerothSchema = "{'column_name':'name','data_type':'string'}"
assert(
result.schemas(0).equals(expectedZerothSchema),
s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}")
val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}"
assert(
result.schemas(1).equals(expectedFirstSchema),
s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}")
commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime)
successValidation(result)
true
}
pollForResultAndAssert(selectQueryValidation, selectQueryId)
assert(
!awaitConditionForStatementOrTimeout(
statement => {
statement.state == "success"
},
selectStatementId),
s"Fail to verify for $selectStatementId.")

assert(
awaitConditionForStatementOrTimeout(
statement => {
statement.state != "waiting"
},
lateSelectStatementId),
s"Fail to verify for $lateSelectStatementId.")
} catch {
case e: Exception =>
logError("Unexpected exception", e)
assert(false, "Unexpected exception")
} finally {
waitREPLStop(threadLocalFuture.get())
threadLocalFuture.remove()

// shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT.
}
}

/**
* JSON does not support raw newlines (\n) in string values. All newlines must be escaped or
* removed when inside a JSON string. The same goes for tab characters, which should be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ case class CommandContext(
jobId: String,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long)
queryWaitTimeMillis: Long,
queryLoopExecutionFrequency: Long)
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,20 @@ import org.opensearch.search.sort.SortOrder
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.FlintREPLConfConstants._
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.ThreadUtils

object FlintREPLConfConstants {
val HEARTBEAT_INTERVAL_MILLIS = 60000L
val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
val DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY = 100L
val INITIAL_DELAY_MILLIS = 3000L
val EARLY_TERMINATION_CHECK_FREQUENCY = 60000L
}

/**
* Spark SQL Application entrypoint
*
Expand All @@ -48,13 +59,6 @@ import org.apache.spark.util.ThreadUtils
*/
object FlintREPL extends Logging with FlintJobExecutor {

private val HEARTBEAT_INTERVAL_MILLIS = 60000L
private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
val INITIAL_DELAY_MILLIS = 3000L
val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L

@volatile var earlyExitFlag: Boolean = false

def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = {
Expand Down Expand Up @@ -134,7 +138,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
SECONDS)
val queryWaitTimeoutMillis: Long =
conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)

val queryLoopExecutionFrequency: Long =
conf.getLong(
"spark.flint.job.queryLoopExecutionFrequency",
dai-chen marked this conversation as resolved.
Show resolved Hide resolved
DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)
val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC)

Expand Down Expand Up @@ -199,7 +206,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobId,
queryExecutionTimeoutSecs,
inactivityLimitMillis,
queryWaitTimeoutMillis)
queryWaitTimeoutMillis,
queryLoopExecutionFrequency)
exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) {
queryLoop(commandContext)
}
Expand Down Expand Up @@ -342,7 +350,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
}

def queryLoop(commandContext: CommandContext): Unit = {
// 1 thread for updating heart beat
// 1 thread for async query execution
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

Expand Down Expand Up @@ -392,7 +400,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
flintReader.close()
}

Thread.sleep(100)
Thread.sleep(commandContext.queryLoopExecutionFrequency)
}
} finally {
if (threadPool != null) {
Expand Down Expand Up @@ -555,8 +563,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
while (canProceed) {
val currentTime = currentTimeProvider.currentEpochMillis()

// Only call canPickNextStatement if EARLY_TERMIANTION_CHECK_FREQUENCY milliseconds have passed
if (currentTime - lastCanPickCheckTime > EARLY_TERMIANTION_CHECK_FREQUENCY) {
// Only call canPickNextStatement if EARLY_TERMINATION_CHECK_FREQUENCY milliseconds have passed
if (currentTime - lastCanPickCheckTime > EARLY_TERMINATION_CHECK_FREQUENCY) {
canPickNextStatementResult =
canPickNextStatement(sessionId, jobId, osClient, sessionIndex)
lastCanPickCheckTime = currentTime
Expand Down
Loading
Loading