Skip to content

Commit

Permalink
read excluded jobs and customize timeout setting
Browse files Browse the repository at this point in the history
Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo committed Nov 11, 2023
1 parent d828f39 commit b7ddd81
Show file tree
Hide file tree
Showing 14 changed files with 1,039 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

package org.opensearch.flint.app

import java.util.{Map => JavaMap}

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.json4s.{Formats, NoTypeHints}
import org.json4s.JsonAST.JString
import org.json4s.JsonAST.{JArray, JString}
import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization

Expand All @@ -16,10 +21,11 @@ class FlintInstance(
val jobId: String,
// sessionId is the session type doc id
val sessionId: String,
val state: String,
var state: String,
val lastUpdateTime: Long,
// We need jobStartTime to check if HMAC token is expired or not
val jobStartTime: Long,
val excludedJobIds: Seq[String] = Seq.empty[String],
val error: Option[String] = None) {}

object FlintInstance {
Expand All @@ -34,6 +40,16 @@ object FlintInstance {
val sessionId = (meta \ "sessionId").extract[String]
val lastUpdateTime = (meta \ "lastUpdateTime").extract[Long]
val jobStartTime = (meta \ "jobStartTime").extract[Long]
// To handle the possibility of excludeJobIds not being present,
// we use extractOpt which gives us an Option[Seq[String]].
// If it is not present, it will return None, which we can then
// convert to an empty Seq[String] using getOrElse.
// Replace extractOpt with jsonOption and map
val excludeJobIds: Seq[String] = meta \ "excludeJobIds" match {
case JArray(lst) => lst.map(_.extract[String])
case _ => Seq.empty[String]
}

val maybeError: Option[String] = (meta \ "error") match {
case JString(str) => Some(str)
case _ => None
Expand All @@ -46,6 +62,42 @@ object FlintInstance {
state,
lastUpdateTime,
jobStartTime,
excludeJobIds,
maybeError)
}

def deserializeFromMap(source: JavaMap[String, AnyRef]): FlintInstance = {
// Since we are dealing with JavaMap, we convert it to a Scala mutable Map for ease of use.
val scalaSource = source.asScala

val applicationId = scalaSource("applicationId").asInstanceOf[String]
val state = scalaSource("state").asInstanceOf[String]
val jobId = scalaSource("jobId").asInstanceOf[String]
val sessionId = scalaSource("sessionId").asInstanceOf[String]
val lastUpdateTime = scalaSource("lastUpdateTime").asInstanceOf[Long]
val jobStartTime = scalaSource("jobStartTime").asInstanceOf[Long]

// We safely handle the possibility of excludeJobIds being absent or not a list.
val excludeJobIds: Seq[String] = scalaSource.get("excludeJobIds") match {
case Some(lst: java.util.List[_]) => lst.asScala.toList.map(_.asInstanceOf[String])
case _ => Seq.empty[String]
}

// Handle error similarly, ensuring we get an Option[String].
val maybeError: Option[String] = scalaSource.get("error") match {
case Some(str: String) => Some(str)
case _ => None
}

// Construct a new FlintInstance with the extracted values.
new FlintInstance(
applicationId,
jobId,
sessionId,
state,
lastUpdateTime,
jobStartTime,
excludeJobIds,
maybeError)
}

Expand All @@ -60,6 +112,7 @@ object FlintInstance {
"state" -> job.state,
// update last update time
"lastUpdateTime" -> currentTime,
"excludeJobIds" -> job.excludedJobIds,
"jobStartTime" -> job.jobStartTime))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@

package org.opensearch.flint.app

import java.util.{HashMap => JavaHashMap}

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite

class FlintInstanceTest extends SparkFunSuite with Matchers {

test("deserialize should correctly parse a FlintInstance from JSON") {
test("deserialize should correctly parse a FlintInstance with excludedJobIds from JSON") {
val json =
"""{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000}"""
"""{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"excludeJobIds":["job-101","job-202"]}"""
val instance = FlintInstance.deserialize(json)

instance.applicationId shouldBe "app-123"
Expand All @@ -22,29 +24,41 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
instance.state shouldBe "RUNNING"
instance.lastUpdateTime shouldBe 1620000000000L
instance.jobStartTime shouldBe 1620000001000L
instance.excludedJobIds should contain allOf ("job-101", "job-202")
instance.error shouldBe None
}

test("serialize should correctly produce JSON from a FlintInstance") {
test("serialize should correctly produce JSON from a FlintInstance with excludedJobIds") {
val excludedJobIds = Seq("job-101", "job-202")
val instance = new FlintInstance(
"app-123",
"job-456",
"session-789",
"RUNNING",
1620000000000L,
1620000001000L)
1620000001000L,
excludedJobIds)
val currentTime = System.currentTimeMillis()
val json = FlintInstance.serialize(instance, currentTime)

json should include(""""applicationId":"app-123"""")
json should not include(""""jobId":"job-456"""")
json should not include (""""jobId":"job-456"""")
json should include(""""sessionId":"session-789"""")
json should include(""""state":"RUNNING"""")
json should include(s""""lastUpdateTime":$currentTime""")
json should include(""""excludeJobIds":["job-101","job-202"]""")
json should include(""""jobStartTime":1620000001000""")
json should include(""""error":""""")
}

test("deserialize should correctly handle an empty excludedJobIds field in JSON") {
val jsonWithoutExcludedJobIds =
"""{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000}"""
val instance = FlintInstance.deserialize(jsonWithoutExcludedJobIds)

instance.excludedJobIds shouldBe empty
}

test("deserialize should correctly handle error field in JSON") {
val jsonWithError =
"""{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"FAILED","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"error":"Some error occurred"}"""
Expand All @@ -61,10 +75,45 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
"FAILED",
1620000000000L,
1620000001000L,
Seq.empty[String],
Some("Some error occurred"))
val currentTime = System.currentTimeMillis()
val json = FlintInstance.serialize(instance, currentTime)

json should include(""""error":"Some error occurred"""")
}

test("deserializeFromMap should handle normal case") {
val sourceMap = new JavaHashMap[String, AnyRef]()
sourceMap.put("applicationId", "app1")
sourceMap.put("jobId", "job1")
sourceMap.put("sessionId", "session1")
sourceMap.put("state", "running")
sourceMap.put("lastUpdateTime", java.lang.Long.valueOf(1234567890L))
sourceMap.put("jobStartTime", java.lang.Long.valueOf(9876543210L))
sourceMap.put("excludeJobIds", java.util.Arrays.asList("job2", "job3"))
sourceMap.put("error", "An error occurred")

val result = FlintInstance.deserializeFromMap(sourceMap)

assert(result.applicationId == "app1")
assert(result.jobId == "job1")
assert(result.sessionId == "session1")
assert(result.state == "running")
assert(result.lastUpdateTime == 1234567890L)
assert(result.jobStartTime == 9876543210L)
assert(result.excludedJobIds == Seq("job2", "job3"))
assert(result.error.contains("An error occurred"))
}

test("deserializeFromMap should handle incorrect field types") {
val sourceMap = new JavaHashMap[String, AnyRef]()
sourceMap.put("applicationId", Integer.valueOf(123))
sourceMap.put("lastUpdateTime", "1234567890")

assertThrows[ClassCastException] {
FlintInstance.deserializeFromMap(sourceMap)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@
package org.apache.spark.sql

import scala.concurrent.{ExecutionContextExecutor, Future}
import scala.concurrent.duration.Duration

import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}

case class CommandContext(
flintReader: FlintReader,
spark: SparkSession,
dataSource: String,
resultIndex: String,
sessionId: String,
futureMappingCheck: Future[Either[String, Unit]],
executionContext: ExecutionContextExecutor,
flintSessionIndexUpdater: OpenSearchUpdater,
osClient: OSClient,
sessionIndex: String,
jobId: String)
jobId: String,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@

package org.apache.spark.sql

import scala.concurrent.{ExecutionContextExecutor, Future}

import org.opensearch.flint.core.storage.FlintReader

case class CommandState(
recordedLastActivityTime: Long,
recordedVerificationResult: VerificationResult)
recordedVerificationResult: VerificationResult,
flintReader: FlintReader,
futureMappingCheck: Future[Either[String, Unit]],
executionContext: ExecutionContextExecutor)
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ object FlintJob extends Logging with FlintJobExecutor {

var dataToWrite: Option[DataFrame] = None
val startTime = System.currentTimeMillis()
// osClient needs spark session to be created first to get FlintOptions initialized.
// Otherwise, we will have connection exception from EMR-S to OS.
val osClient = new OSClient(FlintSparkConf().flintOptions())
try {
// osClient needs spark session to be created first to get FlintOptions initialized.
// Otherwise, we will have connection exception from EMR-S to OS.
val osClient = new OSClient(FlintSparkConf().flintOptions())
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
}
Expand All @@ -81,7 +81,7 @@ object FlintJob extends Logging with FlintJobExecutor {
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
} finally {
dataToWrite.foreach(df => writeData(df, resultIndex))
dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
// Stop SparkSession if it is not streaming job
if (wait.equalsIgnoreCase("streaming")) {
spark.streams.awaitAnyTermination()
Expand Down
Loading

0 comments on commit b7ddd81

Please sign in to comment.