Skip to content

Commit

Permalink
Add statement timeout (#539) (#550)
Browse files Browse the repository at this point in the history
  • Loading branch information
opensearch-trigger-bot[bot] authored Aug 9, 2024
1 parent d6e71fa commit 7b43ff2
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

/**
* Provides a mutable map to store and retrieve contextual data using key-value pairs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import java.util.Locale

import org.json4s.{Formats, NoTypeHints}
import org.json4s.JsonAST.JString
Expand All @@ -14,6 +16,7 @@ object StatementStates {
// String constants for the values a FlintStatement's `state` is assigned or
// compared against. All values are lowercase.
// TIMEOUT is the value assigned by FlintStatement.timeout().
// NOTE(review): no transition method assigning WAITING is visible in this
// excerpt; only the isWaiting predicate reads it — confirm where it is set.
val RUNNING = "running"
val SUCCESS = "success"
val FAILED = "failed"
val TIMEOUT = "timeout"
val WAITING = "waiting"
}

Expand Down Expand Up @@ -50,10 +53,15 @@ class FlintStatement(
// State transitions: each one overwrites the mutable `state` field with the
// corresponding StatementStates constant. Note that complete() stores SUCCESS.
def running(): Unit = state = StatementStates.RUNNING
def complete(): Unit = state = StatementStates.SUCCESS
def fail(): Unit = state = StatementStates.FAILED
// State predicates, exact-match (`==`) form.
// NOTE(review): isRunning/isComplete/isFailed/isWaiting appear twice in this
// excerpt (this `==` form and the equalsIgnoreCase form below). The text looks
// like a diff rendered without +/- markers — confirm which form is current.
def isRunning: Boolean = state == StatementStates.RUNNING
def isComplete: Boolean = state == StatementStates.SUCCESS
def isFailed: Boolean = state == StatementStates.FAILED
def isWaiting: Boolean = state == StatementStates.WAITING
// Marks the statement as timed out by storing StatementStates.TIMEOUT.
// No matching predicate for TIMEOUT is visible in this excerpt.
def timeout(): Unit = state = StatementStates.TIMEOUT

// State predicates, case-insensitive form: a `state` that differs from the
// constant only in letter case still matches.
def isRunning: Boolean = state.equalsIgnoreCase(StatementStates.RUNNING)

// True when `state` equals SUCCESS ignoring case (the value complete() stores).
def isComplete: Boolean = state.equalsIgnoreCase(StatementStates.SUCCESS)

// True when `state` equals FAILED ignoring case.
def isFailed: Boolean = state.equalsIgnoreCase(StatementStates.FAILED)

// True when `state` equals WAITING ignoring case.
def isWaiting: Boolean = state.equalsIgnoreCase(StatementStates.WAITING)

// Does not include context, which could contain sensitive information.
override def toString: String =
Expand All @@ -66,7 +74,7 @@ object FlintStatement {

def deserialize(statement: String): FlintStatement = {
val meta = parse(statement)
val state = (meta \ "state").extract[String]
val state = (meta \ "state").extract[String].toLowerCase(Locale.ROOT)
val query = (meta \ "query").extract[String]
val statementId = (meta \ "statementId").extract[String]
val queryId = (meta \ "queryId").extract[String]
Expand All @@ -82,6 +90,8 @@ object FlintStatement {
// Serializes a FlintStatement to a string via json4s `Serialization.write`.
// Only two keys are emitted: "state" and "error". An absent error
// (`error` is an Option) is written as the empty string.
// NOTE(review): two `Map(...)` arguments appear back to back below — one passing
// `state` through unchanged, one lowercasing it with Locale.ROOT. The text looks
// like a diff rendered without +/- markers — confirm which argument is current.
def serialize(flintStatement: FlintStatement): String = {
// we only need to modify state and error
Serialization.write(
Map("state" -> flintStatement.state, "error" -> flintStatement.error.getOrElse("")))
Map(
"state" -> flintStatement.state.toLowerCase(Locale.ROOT),
"error" -> flintStatement.error.getOrElse("")))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import java.util.{Map => JavaMap}
import java.util.{Locale, Map => JavaMap}

import scala.collection.JavaConverters._

Expand All @@ -16,9 +16,8 @@ import org.json4s.native.Serialization

// String constants for the values an InteractiveSession's `state` is assigned
// or compared against. All values are lowercase.
// InteractiveSession.complete() stores DEAD and fail() stores FAIL.
// NOTE(review): COMPLETE/FAILED/WAITING and DEAD/FAIL are listed together here,
// but the text looks like a diff rendered without +/- markers — confirm which
// constants exist in the current object.
object SessionStates {
val RUNNING = "running"
val COMPLETE = "complete"
val FAILED = "failed"
val WAITING = "waiting"
val DEAD = "dead"
val FAIL = "fail"
}

/**
Expand Down Expand Up @@ -56,10 +55,15 @@ class InteractiveSession(
extends ContextualDataStore {
context = sessionContext // Initialize the context from the constructor

// State predicates, exact-match (`==`) form.
// NOTE(review): isRunning and isComplete appear twice in this excerpt (this `==`
// form and the equalsIgnoreCase form below), and the two isComplete variants
// compare against different constants (COMPLETE vs DEAD). The text looks like a
// diff rendered without +/- markers — confirm which form is current.
def isRunning: Boolean = state == SessionStates.RUNNING
def isComplete: Boolean = state == SessionStates.COMPLETE
def isFailed: Boolean = state == SessionStates.FAILED
def isWaiting: Boolean = state == SessionStates.WAITING
// State transitions: each one overwrites the mutable `state` field.
// complete() stores DEAD (not a "complete" value) and fail() stores FAIL.
def running(): Unit = state = SessionStates.RUNNING
def complete(): Unit = state = SessionStates.DEAD
def fail(): Unit = state = SessionStates.FAIL

// State predicates, case-insensitive form: a `state` that differs from the
// constant only in letter case still matches.
def isRunning: Boolean = state.equalsIgnoreCase(SessionStates.RUNNING)

// True when `state` equals DEAD ignoring case — the value complete() stores.
def isComplete: Boolean = state.equalsIgnoreCase(SessionStates.DEAD)

// True when `state` equals FAIL ignoring case — the value fail() stores.
def isFail: Boolean = state.equalsIgnoreCase(SessionStates.FAIL)

override def toString: String = {
val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]")
Expand All @@ -77,7 +81,7 @@ object InteractiveSession {
def deserialize(job: String): InteractiveSession = {
val meta = parse(job)
val applicationId = (meta \ "applicationId").extract[String]
val state = (meta \ "state").extract[String]
val state = (meta \ "state").extract[String].toLowerCase(Locale.ROOT)
val jobId = (meta \ "jobId").extract[String]
val sessionId = (meta \ "sessionId").extract[String]
val lastUpdateTime = (meta \ "lastUpdateTime").extract[Long]
Expand Down Expand Up @@ -116,7 +120,7 @@ object InteractiveSession {
val scalaSource = source.asScala

val applicationId = scalaSource("applicationId").asInstanceOf[String]
val state = scalaSource("state").asInstanceOf[String]
val state = scalaSource("state").asInstanceOf[String].toLowerCase(Locale.ROOT)
val jobId = scalaSource("jobId").asInstanceOf[String]
val sessionId = scalaSource("sessionId").asInstanceOf[String]
val lastUpdateTime = scalaSource("lastUpdateTime").asInstanceOf[Long]
Expand Down Expand Up @@ -178,7 +182,7 @@ object InteractiveSession {
"sessionId" -> job.sessionId,
"error" -> job.error.getOrElse(""),
"applicationId" -> job.applicationId,
"state" -> job.state,
"state" -> job.state.toLowerCase(Locale.ROOT),
// update last update time
"lastUpdateTime" -> currentTime,
// Convert a Seq[String] into a comma-separated string, such as "id1,id2".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data
package org.opensearch.flint.common.model

import java.util.{HashMap => JavaHashMap}

Expand All @@ -21,7 +21,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers {
instance.applicationId shouldBe "app-123"
instance.jobId shouldBe "job-456"
instance.sessionId shouldBe "session-789"
instance.state shouldBe "RUNNING"
instance.state shouldBe "running"
instance.lastUpdateTime shouldBe 1620000000000L
instance.jobStartTime shouldBe 1620000001000L
instance.excludedJobIds should contain allOf ("job-101", "job-202")
Expand All @@ -44,7 +44,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers {
json should include(""""applicationId":"app-123"""")
json should not include (""""jobId":"job-456"""")
json should include(""""sessionId":"session-789"""")
json should include(""""state":"RUNNING"""")
json should include(""""state":"running"""")
json should include(s""""lastUpdateTime":$currentTime""")
json should include(
""""excludeJobIds":"job-101,job-202""""
Expand Down Expand Up @@ -149,7 +149,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers {
instance.applicationId shouldBe "app-123"
instance.jobId shouldBe "job-456"
instance.sessionId shouldBe "session-789"
instance.state shouldBe "RUNNING"
instance.state shouldBe "running"
instance.lastUpdateTime shouldBe 1620000000000L
instance.jobStartTime shouldBe 0L // Default or expected value for missing jobStartTime
instance.excludedJobIds should contain allOf ("job-101", "job-202")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ import scala.util.control.Breaks.{break, breakable}

import org.opensearch.OpenSearchStatusException
import org.opensearch.flint.OpenSearchSuite
import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
import org.opensearch.flint.core.{FlintClient, FlintOptions}
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater}
import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkFunSuite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
import org.opensearch.action.get.{GetRequest, GetResponse}
import org.opensearch.client.RequestOptions
import org.opensearch.flint.OpenSearchTransactionSuite
import org.opensearch.flint.common.model.InteractiveSession
import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater}
import org.opensearch.flint.data.InteractiveSession
import org.scalatest.matchers.should.Matchers

class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@ trait FlintJobExecutor {
}
}""".stripMargin

// Schema of the query-result DataFrame: 13 fields, declared in the order the
// result rows are built.
//  - "result" and "schema" are arrays of strings whose elements may be null.
//  - the remaining fields are strings, except the two Long fields at the end.
// The DataFrame builders visible in this trait use only the field names
// (`schema.fields.map(_.name)`) to name the columns of the created DataFrame.
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// updateTime is the only non-nullable field; queryRunTime, also a Long, stays nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

def createSparkConf(): SparkConf = {
val conf = new SparkConf().setAppName(getClass.getSimpleName)

Expand Down Expand Up @@ -203,24 +221,6 @@ trait FlintJobExecutor {
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false))))

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val resultToSave = result.toJSON.collect.toList
.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))

Expand Down Expand Up @@ -253,35 +253,17 @@ trait FlintJobExecutor {
spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
}

def getFailedData(
def constructErrorDF(
spark: SparkSession,
dataSource: String,
status: String,
error: String,
queryId: String,
query: String,
queryText: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))
startTime: Long): DataFrame = {

val endTime = timeProvider.currentEpochMillis()
val updateTime = currentTimeProvider.currentEpochMillis()

// Create the data rows
val rows = Seq(
Expand All @@ -291,14 +273,14 @@ trait FlintJobExecutor {
envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"),
envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
dataSource,
"FAILED",
status.toUpperCase(Locale.ROOT),
error,
queryId,
query,
queryText,
sessionId,
spark.conf.get(FlintSparkConf.JOB_TYPE.key),
endTime,
endTime - startTime))
updateTime,
updateTime - startTime))

// Create the DataFrame for data
spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
Expand Down
Loading

0 comments on commit 7b43ff2

Please sign in to comment.