Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add statement timeout #539

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

/**
* Provides a mutable map to store and retrieve contextual data using key-value pairs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import org.json4s.{Formats, NoTypeHints}
import org.json4s.JsonAST.JString
Expand All @@ -14,6 +14,7 @@ object StatementStates {
val RUNNING = "running"
val SUCCESS = "success"
val FAILED = "failed"
val TIMEOUT = "timeout"
val WAITING = "waiting"
}

Expand Down Expand Up @@ -50,6 +51,8 @@ class FlintStatement(
def running(): Unit = state = StatementStates.RUNNING
def complete(): Unit = state = StatementStates.SUCCESS
def fail(): Unit = state = StatementStates.FAILED
def timeout(): Unit = state = StatementStates.TIMEOUT

def isRunning: Boolean = state == StatementStates.RUNNING
def isComplete: Boolean = state == StatementStates.SUCCESS
def isFailed: Boolean = state == StatementStates.FAILED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import java.util.{Map => JavaMap}

Expand All @@ -16,9 +16,8 @@ import org.json4s.native.Serialization

object SessionStates {
val RUNNING = "running"
val COMPLETE = "complete"
val FAILED = "failed"
val WAITING = "waiting"
val DEAD = "dead"
val FAIL = "fail"
}

/**
Expand Down Expand Up @@ -56,10 +55,13 @@ class InteractiveSession(
extends ContextualDataStore {
context = sessionContext // Initialize the context from the constructor

def running(): Unit = state = SessionStates.RUNNING
def complete(): Unit = state = SessionStates.DEAD
def fail(): Unit = state = SessionStates.FAIL

def isRunning: Boolean = state == SessionStates.RUNNING
def isComplete: Boolean = state == SessionStates.COMPLETE
def isFailed: Boolean = state == SessionStates.FAILED
def isWaiting: Boolean = state == SessionStates.WAITING
def isComplete: Boolean = state == SessionStates.DEAD
def isFail: Boolean = state == SessionStates.FAIL
noCharger marked this conversation as resolved.
Show resolved Hide resolved

override def toString: String = {
val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import java.util.{HashMap => JavaHashMap}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ import scala.util.control.Breaks.{break, breakable}

import org.opensearch.OpenSearchStatusException
import org.opensearch.flint.OpenSearchSuite
import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
import org.opensearch.flint.core.{FlintClient, FlintOptions}
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater}
import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkFunSuite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
import org.opensearch.action.get.{GetRequest, GetResponse}
import org.opensearch.client.RequestOptions
import org.opensearch.flint.OpenSearchTransactionSuite
import org.opensearch.flint.common.model.InteractiveSession
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater}
import org.opensearch.flint.data.InteractiveSession
import org.scalatest.matchers.should.Matchers

class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@ trait FlintJobExecutor {
}
}""".stripMargin

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

def createSparkConf(): SparkConf = {
val conf = new SparkConf().setAppName(getClass.getSimpleName)

Expand Down Expand Up @@ -203,24 +221,6 @@ trait FlintJobExecutor {
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false))))

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val resultToSave = result.toJSON.collect.toList
.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))

Expand Down Expand Up @@ -253,35 +253,17 @@ trait FlintJobExecutor {
spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
}

def getFailedData(
def constructErrorDF(
spark: SparkSession,
dataSource: String,
status: String,
error: String,
queryId: String,
query: String,
queryText: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))
startTime: Long): DataFrame = {

val endTime = timeProvider.currentEpochMillis()
val updateTime = currentTimeProvider.currentEpochMillis()

// Create the data rows
val rows = Seq(
Expand All @@ -291,14 +273,14 @@ trait FlintJobExecutor {
envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"),
envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
dataSource,
"FAILED",
status.toUpperCase(Locale.ROOT),
error,
queryId,
query,
queryText,
sessionId,
spark.conf.get(FlintSparkConf.JOB_TYPE.key),
endTime,
endTime - startTime))
updateTime,
updateTime - startTime))

// Create the DataFrame for data
spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ import com.codahale.metrics.Timer
import org.json4s.native.Serialization
import org.opensearch.action.get.GetResponse
import org.opensearch.common.Strings
import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
import org.opensearch.flint.common.model.InteractiveSession.formats
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.logging.CustomLogging
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
import org.opensearch.flint.data.InteractiveSession.formats
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkConf
Expand Down Expand Up @@ -456,7 +456,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
.getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error))

updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId)
if (flintInstance.state.equals("fail")) {
if (flintInstance.isFail) {
recordSessionFailed(sessionTimerContext)
}
}
Expand Down Expand Up @@ -530,15 +530,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime: Long): DataFrame = {
flintStatement.fail()
flintStatement.error = Some(error)
super.getFailedData(
super.constructErrorDF(
spark,
dataSource,
flintStatement.state,
error,
flintStatement.queryId,
flintStatement.query,
sessionId,
startTime,
currentTimeProvider)
startTime)
}

def processQueryException(ex: Exception, flintStatement: FlintStatement): String = {
Expand Down Expand Up @@ -654,7 +654,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
error: String,
flintStatement: FlintStatement,
sessionId: String,
startTime: Long): Option[DataFrame] = {
startTime: Long): DataFrame = {
/*
* https://tinyurl.com/2ezs5xj9
*
Expand All @@ -668,14 +668,17 @@ object FlintREPL extends Logging with FlintJobExecutor {
* actions that require the computation of results that need to be collected or stored.
*/
spark.sparkContext.cancelJobGroup(flintStatement.queryId)
Some(
handleCommandFailureAndGetFailedData(
spark,
dataSource,
error,
flintStatement,
sessionId,
startTime))
flintStatement.timeout()
dai-chen marked this conversation as resolved.
Show resolved Hide resolved
flintStatement.error = Some(error)
super.constructErrorDF(
spark,
dataSource,
flintStatement.state,
error,
flintStatement.queryId,
flintStatement.query,
sessionId,
startTime)
}

def executeAndHandle(
Expand All @@ -702,7 +705,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
case e: TimeoutException =>
val error = s"Executing ${flintStatement.query} timed out"
CustomLogging.logError(error, e)
handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)
Some(handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime))
case e: Exception =>
val error = processQueryException(e, flintStatement)
Some(
Expand Down Expand Up @@ -761,8 +764,14 @@ object FlintREPL extends Logging with FlintJobExecutor {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
CustomLogging.logError(error, e)
dataToWrite =
handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)
dataToWrite = Some(
handleCommandTimeout(
spark,
dataSource,
error,
flintStatement,
sessionId,
startTime))
case NonFatal(e) =>
val error = s"An unexpected error occurred: ${e.getMessage}"
CustomLogging.logError(error, e)
Expand Down Expand Up @@ -941,7 +950,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
sessionId: String,
sessionTimerContext: Timer.Context): Unit = {
val flintInstance = InteractiveSession.deserializeFromMap(source)
flintInstance.state = "dead"
flintInstance.complete()
flintSessionIndexUpdater.updateIf(
sessionId,
InteractiveSession.serializeWithoutJobId(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,19 @@ case class JobOperator(
dataToWrite = Some(mappingCheckResult match {
case Right(_) => data
case Left(error) =>
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime)
})
exceptionThrown = false
} catch {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
logError(error, e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
constructErrorDF(spark, dataSource, "TIMEOUT", error, "", query, "", startTime))
case e: Exception =>
val error = processQueryException(e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime))
} finally {
cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
}
Expand Down
Loading
Loading