Add statement timeout
Signed-off-by: Louis Chu <[email protected]>
noCharger committed Aug 8, 2024
1 parent 773ad22 commit 484fdaa
Showing 10 changed files with 79 additions and 83 deletions.
@@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

/**
* Provides a mutable map to store and retrieve contextual data using key-value pairs.
@@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import org.json4s.{Formats, NoTypeHints}
import org.json4s.JsonAST.JString
@@ -14,6 +14,7 @@ object StatementStates {
val RUNNING = "running"
val SUCCESS = "success"
val FAILED = "failed"
val TIMEOUT = "timeout"
val WAITING = "waiting"
}

@@ -50,6 +51,8 @@ class FlintStatement(
def running(): Unit = state = StatementStates.RUNNING
def complete(): Unit = state = StatementStates.SUCCESS
def fail(): Unit = state = StatementStates.FAILED
def timeout(): Unit = state = StatementStates.TIMEOUT

def isRunning: Boolean = state == StatementStates.RUNNING
def isComplete: Boolean = state == StatementStates.SUCCESS
def isFailed: Boolean = state == StatementStates.FAILED
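The FlintStatement changes above introduce a dedicated timeout terminal state alongside success and failure. Below is a minimal, self-contained sketch of the resulting lifecycle; StatementSketch and its no-argument constructor are illustrative stand-ins, not the repository's actual FlintStatement signature.

object StatementStates {
  val RUNNING = "running"
  val SUCCESS = "success"
  val FAILED = "failed"
  val TIMEOUT = "timeout"
  val WAITING = "waiting"
}

// Simplified stand-in for FlintStatement: only the state-transition surface is modeled.
class StatementSketch(var state: String = StatementStates.WAITING,
                      var error: Option[String] = None) {
  def running(): Unit = state = StatementStates.RUNNING
  def complete(): Unit = state = StatementStates.SUCCESS
  def fail(): Unit = state = StatementStates.FAILED
  def timeout(): Unit = state = StatementStates.TIMEOUT // new in this commit

  def isRunning: Boolean = state == StatementStates.RUNNING
  def isComplete: Boolean = state == StatementStates.SUCCESS
  def isFailed: Boolean = state == StatementStates.FAILED
}

object StatementTimeoutDemo extends App {
  // A query that exceeds its execution budget is marked "timeout" rather than
  // "failed", so the two outcomes stay distinguishable downstream.
  val stmt = new StatementSketch()
  stmt.running()
  stmt.timeout()
  stmt.error = Some("Executing SELECT 1 timed out")
  println(s"state=${stmt.state}, error=${stmt.error}")
}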
@@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import java.util.{Map => JavaMap}

@@ -16,9 +16,8 @@ import org.json4s.native.Serialization

object SessionStates {
val RUNNING = "running"
val COMPLETE = "complete"
val FAILED = "failed"
val WAITING = "waiting"
val DEAD = "dead"
val FAIL = "fail"
}

/**
@@ -56,10 +55,13 @@ class InteractiveSession(
extends ContextualDataStore {
context = sessionContext // Initialize the context from the constructor

def running(): Unit = state = SessionStates.RUNNING
def complete(): Unit = state = SessionStates.DEAD
def fail(): Unit = state = SessionStates.FAIL

def isRunning: Boolean = state == SessionStates.RUNNING
def isComplete: Boolean = state == SessionStates.COMPLETE
def isFailed: Boolean = state == SessionStates.FAILED
def isWaiting: Boolean = state == SessionStates.WAITING
def isComplete: Boolean = state == SessionStates.DEAD
def isFail: Boolean = state == SessionStates.FAIL

override def toString: String = {
val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]")
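InteractiveSession gets the same treatment: the raw state strings previously assigned and compared in FlintREPL ("dead", "fail") become lifecycle methods, and the unused COMPLETE, FAILED, and WAITING constants are dropped. A simplified sketch of the new surface follows; SessionSketch is a stand-in, since the real constructor takes session and job identifiers not shown here.

object SessionStates {
  val RUNNING = "running"
  val DEAD = "dead"
  val FAIL = "fail"
}

// Stand-in for InteractiveSession: only the lifecycle surface is modeled.
class SessionSketch(var state: String = SessionStates.RUNNING) {
  def running(): Unit = state = SessionStates.RUNNING
  def complete(): Unit = state = SessionStates.DEAD // a finished session is recorded as "dead"
  def fail(): Unit = state = SessionStates.FAIL

  def isRunning: Boolean = state == SessionStates.RUNNING
  def isComplete: Boolean = state == SessionStates.DEAD
  def isFail: Boolean = state == SessionStates.FAIL
}

// Callers such as FlintREPL can now write session.complete() and session.isFail
// instead of assigning and comparing raw strings (see the FlintREPL hunks below).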
@@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import java.util.{HashMap => JavaHashMap}

@@ -15,9 +15,9 @@ import scala.util.control.Breaks.{break, breakable}

import org.opensearch.OpenSearchStatusException
import org.opensearch.flint.OpenSearchSuite
import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
import org.opensearch.flint.core.{FlintClient, FlintOptions}
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater}
import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkFunSuite
@@ -10,8 +10,8 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
import org.opensearch.action.get.{GetRequest, GetResponse}
import org.opensearch.client.RequestOptions
import org.opensearch.flint.OpenSearchTransactionSuite
import org.opensearch.flint.common.model.InteractiveSession
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater}
import org.opensearch.flint.data.InteractiveSession
import org.scalatest.matchers.should.Matchers

class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
@@ -96,6 +96,24 @@ trait FlintJobExecutor {
}
}""".stripMargin

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

def createSparkConf(): SparkConf = {
val conf = new SparkConf().setAppName(getClass.getSimpleName)

@@ -203,24 +221,6 @@ trait FlintJobExecutor {
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false))))

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val resultToSave = result.toJSON.collect.toList
.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))

@@ -253,35 +253,17 @@ trait FlintJobExecutor {
spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
}

def getFailedData(
def constructErrorDF(
spark: SparkSession,
dataSource: String,
status: String,
error: String,
queryId: String,
query: String,
queryText: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))
startTime: Long): DataFrame = {

val endTime = timeProvider.currentEpochMillis()
val updateTime = currentTimeProvider.currentEpochMillis()

// Create the data rows
val rows = Seq(
Expand All @@ -291,14 +273,14 @@ trait FlintJobExecutor {
envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"),
envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
dataSource,
"FAILED",
status.toUpperCase(Locale.ROOT),
error,
queryId,
query,
queryText,
sessionId,
spark.conf.get(FlintSparkConf.JOB_TYPE.key),
endTime,
endTime - startTime))
updateTime,
updateTime - startTime))

// Create the DataFrame for data
spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
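The FlintJobExecutor hunks above hoist the result schema into a single trait-level val and fold getFailedData into constructErrorDF, which takes the status as a parameter instead of hard-coding "FAILED". Below is a compressed sketch of that pattern using only a subset of the columns; buildErrorRow, ErrorDFSketch, and the literal runtime values are illustrative, not the repository's API.

import java.util.Locale
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._

object ErrorDFSketch {
  // One shared schema instead of a copy inside every error-path method.
  val schema: StructType = StructType(Seq(
    StructField("status", StringType, nullable = true),
    StructField("error", StringType, nullable = true),
    StructField("updateTime", LongType, nullable = false),
    StructField("queryRunTime", LongType, nullable = true)))

  // Single builder parameterized by status ("FAILED", "TIMEOUT", ...).
  def buildErrorRow(spark: SparkSession, status: String, error: String, startTime: Long): DataFrame = {
    val updateTime = System.currentTimeMillis()
    val rows = Seq((status.toUpperCase(Locale.ROOT), error, updateTime, updateTime - startTime))
    spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
  }
}

// Hypothetical caller, mirroring the JobOperator change later in this diff:
//   dataToWrite = Some(ErrorDFSketch.buildErrorRow(spark, "TIMEOUT", error, startTime))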
@@ -18,13 +18,13 @@ import com.codahale.metrics.Timer
import org.json4s.native.Serialization
import org.opensearch.action.get.GetResponse
import org.opensearch.common.Strings
import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
import org.opensearch.flint.common.model.InteractiveSession.formats
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.logging.CustomLogging
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
import org.opensearch.flint.data.InteractiveSession.formats
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkConf
@@ -456,7 +456,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
.getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error))

updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId)
if (flintInstance.state.equals("fail")) {
if (flintInstance.isFail) {
recordSessionFailed(sessionTimerContext)
}
}
@@ -530,15 +530,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime: Long): DataFrame = {
flintStatement.fail()
flintStatement.error = Some(error)
super.getFailedData(
super.constructErrorDF(
spark,
dataSource,
flintStatement.state,
error,
flintStatement.queryId,
flintStatement.query,
sessionId,
startTime,
currentTimeProvider)
startTime)
}

def processQueryException(ex: Exception, flintStatement: FlintStatement): String = {
@@ -654,7 +654,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
error: String,
flintStatement: FlintStatement,
sessionId: String,
startTime: Long): Option[DataFrame] = {
startTime: Long): DataFrame = {
/*
* https://tinyurl.com/2ezs5xj9
*
@@ -668,14 +668,17 @@ object FlintREPL extends Logging with FlintJobExecutor {
* actions that require the computation of results that need to be collected or stored.
*/
spark.sparkContext.cancelJobGroup(flintStatement.queryId)
Some(
handleCommandFailureAndGetFailedData(
spark,
dataSource,
error,
flintStatement,
sessionId,
startTime))
flintStatement.timeout()
flintStatement.error = Some(error)
super.constructErrorDF(
spark,
dataSource,
flintStatement.state,
error,
flintStatement.queryId,
flintStatement.query,
sessionId,
startTime)
}

def executeAndHandle(
@@ -702,7 +705,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
case e: TimeoutException =>
val error = s"Executing ${flintStatement.query} timed out"
CustomLogging.logError(error, e)
handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)
Some(handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime))
case e: Exception =>
val error = processQueryException(e, flintStatement)
Some(
@@ -761,8 +764,14 @@ object FlintREPL extends Logging with FlintJobExecutor {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
CustomLogging.logError(error, e)
dataToWrite =
handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)
dataToWrite = Some(
handleCommandTimeout(
spark,
dataSource,
error,
flintStatement,
sessionId,
startTime))
case NonFatal(e) =>
val error = s"An unexpected error occurred: ${e.getMessage}"
CustomLogging.logError(error, e)
@@ -941,7 +950,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
sessionId: String,
sessionTimerContext: Timer.Context): Unit = {
val flintInstance = InteractiveSession.deserializeFromMap(source)
flintInstance.state = "dead"
flintInstance.complete()
flintSessionIndexUpdater.updateIf(
sessionId,
InteractiveSession.serializeWithoutJobId(
@@ -57,19 +57,19 @@ case class JobOperator(
dataToWrite = Some(mappingCheckResult match {
case Right(_) => data
case Left(error) =>
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime)
})
exceptionThrown = false
} catch {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
logError(error, e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
constructErrorDF(spark, dataSource, "TIMEOUT", error, "", query, "", startTime))
case e: Exception =>
val error = processQueryException(e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime))
} finally {
cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
}
@@ -22,8 +22,8 @@ import org.mockito.Mockito.{atLeastOnce, never, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.opensearch.action.get.GetResponse
import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater}
import org.opensearch.flint.data.FlintStatement
import org.opensearch.search.sort.SortOrder
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatestplus.mockito.MockitoSugar
@@ -230,7 +230,7 @@ class FlintREPLTest
verify(flintSessionIndexUpdater).updateIf(*, *, *, *)
}

test("Test getFailedData method") {
test("Test super.constructErrorDF should construct dataframe properly") {
// Define expected dataframe
val dataSourceName = "myGlueS3"
val expectedSchema = StructType(
@@ -288,7 +288,7 @@
"20",
currentTime - queryRunTime)
assertEqualDataframe(expected, result)
assert("failed" == flintStatement.state)
assert(flintStatement.isFailed)
assert(error == flintStatement.error.get)
} finally {
spark.close()
@@ -492,7 +492,7 @@
assert(result == expectedError)
}

test("handleGeneralException should handle MetaException with AccessDeniedException properly") {
test("processQueryException should handle MetaException with AccessDeniedException properly") {
val mockFlintCommand = mock[FlintStatement]

// Simulate the root cause being MetaException
