diff --git a/build.sbt b/build.sbt
index f12a19647..38d1a04da 100644
--- a/build.sbt
+++ b/build.sbt
@@ -48,7 +48,7 @@ lazy val commonSettings = Seq(
 // running `scalafmtAll` includes all subprojects under root
 lazy val root = (project in file("."))
-  .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest)
+  .aggregate(flintData, flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest)
   .disablePlugins(AssemblyPlugin)
   .settings(name := "flint", publish / skip := true)
 
@@ -84,6 +84,37 @@ lazy val flintCore = (project in file("flint-core"))
     libraryDependencies ++= deps(sparkVersion),
     publish / skip := true)
 
+lazy val flintData = (project in file("flint-data"))
+  .settings(
+    commonSettings,
+    name := "flint-data",
+    scalaVersion := scala212,
+    libraryDependencies ++= Seq(
+      "org.scalactic" %% "scalactic" % "3.2.15" % "test",
+      "org.scalatest" %% "scalatest" % "3.2.15" % "test",
+      "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
+      "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test",
+    ),
+    libraryDependencies ++= deps(sparkVersion),
+    publish / skip := true,
+    assembly / test := (Test / test).value,
+    assembly / assemblyOption ~= {
+      _.withIncludeScala(false)
+    },
+    assembly / assemblyMergeStrategy := {
+      case PathList(ps @ _*) if ps.last endsWith ("module-info.class") =>
+        MergeStrategy.discard
+      case PathList("module-info.class") => MergeStrategy.discard
+      case PathList("META-INF", "versions", xs @ _, "module-info.class") =>
+        MergeStrategy.discard
+      case x =>
+        val oldStrategy = (assembly / assemblyMergeStrategy).value
+        oldStrategy(x)
+    },
+  )
+  .enablePlugins(AssemblyPlugin)
+
+
 lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
   .enablePlugins(AssemblyPlugin, Antlr4Plugin)
   .settings(
@@ -121,7 +152,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
   assembly / test := (Test / test).value)
 
 lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
-  .dependsOn(flintCore)
+  .dependsOn(flintCore, flintData)
   .enablePlugins(AssemblyPlugin, Antlr4Plugin)
   .settings(
     commonSettings,
@@ -166,7 +197,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
 
 // Test assembly package with integration test.
 lazy val integtest = (project in file("integ-test"))
-  .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test")
+  .dependsOn(flintData % "test->test", flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test")
  .settings(
     commonSettings,
     name := "integ-test",
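The file moves below relocate the session and command models out of flint-spark-integration into the new flint-data module, renaming FlintInstance to InteractiveSession along the way. For downstream code the visible change is the package and class name; a minimal sketch of the import migration (the consumer file is hypothetical, not part of this diff):

// Before: the models lived in flint-spark-integration
// import org.opensearch.flint.app.{FlintCommand, FlintInstance}

// After: the models live in flint-data, and FlintInstance is now InteractiveSession
import org.opensearch.flint.data.{FlintCommand, InteractiveSession}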
diff --git a/flint-data/src/main/scala/org/opensearch/flint/data/ContextualData.scala b/flint-data/src/main/scala/org/opensearch/flint/data/ContextualData.scala
new file mode 100644
index 000000000..ee0dbbc7a
--- /dev/null
+++ b/flint-data/src/main/scala/org/opensearch/flint/data/ContextualData.scala
@@ -0,0 +1,39 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.data
+
+/**
+ * Provides a mutable map to store and retrieve contextual data using key-value pairs.
+ */
+trait ContextualData {
+
+  /** Holds the contextual data as key-value pairs. */
+  var context: Map[String, Any] = Map.empty
+
+  /**
+   * Adds a key-value pair to the context map.
+   *
+   * @param key
+   *   The key under which the value is stored.
+   * @param value
+   *   The data value to store.
+   */
+  def addContext(key: String, value: Any): Unit = {
+    context += (key -> value)
+  }
+
+  /**
+   * Retrieves the value associated with a key from the context map.
+   *
+   * @param key
+   *   The key whose value needs to be retrieved.
+   * @return
+   *   An Option containing the value if it exists, None otherwise.
+   */
+  def getContext(key: String): Option[Any] = {
+    context.get(key)
+  }
+}
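ContextualData is the mix-in that both FlintCommand and InteractiveSession below extend. A minimal usage sketch, assuming a hypothetical carrier class (not part of this diff):

import org.opensearch.flint.data.ContextualData

// Any class can mix in ContextualData to carry ad hoc key-value state.
class Carrier extends ContextualData

val c = new Carrier
c.addContext("datasource", "my_glue") // stores the pair in the context map
c.getContext("datasource")            // Some("my_glue")
c.getContext("missing")               // None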
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala b/flint-data/src/main/scala/org/opensearch/flint/data/FlintCommand.scala
similarity index 51%
rename from flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala
rename to flint-data/src/main/scala/org/opensearch/flint/data/FlintCommand.scala
index 7624c2c54..49602f0d5 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala
+++ b/flint-data/src/main/scala/org/opensearch/flint/data/FlintCommand.scala
@@ -3,13 +3,38 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-package org.opensearch.flint.app
+package org.opensearch.flint.data
 
 import org.json4s.{Formats, NoTypeHints}
 import org.json4s.JsonAST.JString
 import org.json4s.native.JsonMethods.parse
 import org.json4s.native.Serialization
 
+object CommandStates {
+  val Running = "running"
+  val Success = "success"
+  val Failed = "failed"
+  val Waiting = "waiting"
+}
+
+/**
+ * Represents a command processed in the Flint job.
+ *
+ * @param state
+ *   The current state of the command.
+ * @param query
+ *   SQL-like query string that the command will execute.
+ * @param statementId
+ *   Unique identifier for the statement.
+ * @param queryId
+ *   Unique identifier for the query.
+ * @param submitTime
+ *   Timestamp when the command was submitted.
+ * @param error
+ *   Optional error message if the command fails.
+ * @param commandContext
+ *   Additional context for the command as key-value pairs.
+ */
 class FlintCommand(
     var state: String,
     val query: String,
@@ -17,38 +42,22 @@ class FlintCommand(
     val statementId: String,
     val queryId: String,
     val submitTime: Long,
-    var error: Option[String] = None) {
-  def running(): Unit = {
-    state = "running"
-  }
-
-  def complete(): Unit = {
-    state = "success"
-  }
-
-  def fail(): Unit = {
-    state = "failed"
-  }
+    var error: Option[String] = None,
+    commandContext: Map[String, Any] = Map.empty[String, Any])
+    extends ContextualData {
+  context = commandContext
 
-  def isRunning(): Boolean = {
-    state == "running"
-  }
-
-  def isComplete(): Boolean = {
-    state == "success"
-  }
-
-  def isFailed(): Boolean = {
-    state == "failed"
-  }
-
-  def isWaiting(): Boolean = {
-    state == "waiting"
-  }
+  def running(): Unit = state = CommandStates.Running
+  def complete(): Unit = state = CommandStates.Success
+  def fail(): Unit = state = CommandStates.Failed
+  def isRunning: Boolean = state == CommandStates.Running
+  def isComplete: Boolean = state == CommandStates.Success
+  def isFailed: Boolean = state == CommandStates.Failed
+  def isWaiting: Boolean = state == CommandStates.Waiting
 
-  override def toString: String = {
+  // Does not include context, which could contain sensitive information.
+  override def toString: String =
     s"FlintCommand(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)"
-  }
 }
 
 object FlintCommand {
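A minimal sketch of the FlintCommand life cycle after this change, using the new CommandStates constants and the now-parameterless predicates (values illustrative):

val command = new FlintCommand(
  state = CommandStates.Waiting,
  query = "SELECT 1",
  statementId = "statement-1",
  queryId = "query-1",
  submitTime = System.currentTimeMillis())

command.running()          // state becomes CommandStates.Running
assert(command.isRunning)
command.complete()         // state becomes CommandStates.Success
assert(command.isComplete)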
s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)" } } -object FlintInstance { +object InteractiveSession { implicit val formats: Formats = Serialization.formats(NoTypeHints) - def deserialize(job: String): FlintInstance = { + def deserialize(job: String): InteractiveSession = { val meta = parse(job) val applicationId = (meta \ "applicationId").extract[String] val state = (meta \ "state").extract[String] @@ -64,7 +100,7 @@ object FlintInstance { case _ => None } - new FlintInstance( + new InteractiveSession( applicationId, jobId, sessionId, @@ -75,7 +111,7 @@ object FlintInstance { maybeError) } - def deserializeFromMap(source: JavaMap[String, AnyRef]): FlintInstance = { + def deserializeFromMap(source: JavaMap[String, AnyRef]): InteractiveSession = { // Since we are dealing with JavaMap, we convert it to a Scala mutable Map for ease of use. val scalaSource = source.asScala @@ -105,7 +141,7 @@ object FlintInstance { } // Construct a new FlintInstance with the extracted values. - new FlintInstance( + new InteractiveSession( applicationId, jobId, sessionId, @@ -133,7 +169,10 @@ object FlintInstance { * @return * serialized Flint session */ - def serialize(job: FlintInstance, currentTime: Long, includeJobId: Boolean = true): String = { + def serialize( + job: InteractiveSession, + currentTime: Long, + includeJobId: Boolean = true): String = { val baseMap = Map( "type" -> "session", "sessionId" -> job.sessionId, @@ -159,7 +198,7 @@ object FlintInstance { Serialization.write(resultMap) } - def serializeWithoutJobId(job: FlintInstance, currentTime: Long): String = { + def serializeWithoutJobId(job: InteractiveSession, currentTime: Long): String = { serialize(job, currentTime, includeJobId = false) } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala b/flint-data/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala similarity index 78% rename from flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala rename to flint-data/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala index 8ece6ba8a..d72628455 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala +++ b/flint-data/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.app +package org.opensearch.flint.data import java.util.{HashMap => JavaHashMap} @@ -11,12 +11,12 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite -class FlintInstanceTest extends SparkFunSuite with Matchers { +class InteractiveSessionTest extends SparkFunSuite with Matchers { test("deserialize should correctly parse a FlintInstance with excludedJobIds from JSON") { val json = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"excludeJobIds":["job-101","job-202"]}""" - val instance = FlintInstance.deserialize(json) + val instance = InteractiveSession.deserialize(json) instance.applicationId shouldBe "app-123" instance.jobId shouldBe "job-456" @@ -30,7 +30,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { test("serialize should correctly produce JSON from a FlintInstance with excludedJobIds") { val 
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala b/flint-data/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala
similarity index 78%
rename from flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala
rename to flint-data/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala
index 8ece6ba8a..d72628455 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala
+++ b/flint-data/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala
@@ -3,7 +3,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-package org.opensearch.flint.app
+package org.opensearch.flint.data
 
 import java.util.{HashMap => JavaHashMap}
 
@@ -11,12 +11,12 @@ import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.SparkFunSuite
 
-class FlintInstanceTest extends SparkFunSuite with Matchers {
+class InteractiveSessionTest extends SparkFunSuite with Matchers {
 
   test("deserialize should correctly parse a FlintInstance with excludedJobIds from JSON") {
     val json =
       """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"excludeJobIds":["job-101","job-202"]}"""
-    val instance = FlintInstance.deserialize(json)
+    val instance = InteractiveSession.deserialize(json)
 
     instance.applicationId shouldBe "app-123"
     instance.jobId shouldBe "job-456"
@@ -30,7 +30,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
 
   test("serialize should correctly produce JSON from a FlintInstance with excludedJobIds") {
     val excludedJobIds = Seq("job-101", "job-202")
-    val instance = new FlintInstance(
+    val instance = new InteractiveSession(
       "app-123",
       "job-456",
       "session-789",
@@ -39,7 +39,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
       1620000001000L,
       excludedJobIds)
     val currentTime = System.currentTimeMillis()
-    val json = FlintInstance.serializeWithoutJobId(instance, currentTime)
+    val json = InteractiveSession.serializeWithoutJobId(instance, currentTime)
 
     json should include(""""applicationId":"app-123"""")
     json should not include (""""jobId":"job-456"""")
@@ -56,7 +56,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
   test("deserialize should correctly handle an empty excludedJobIds field in JSON") {
     val jsonWithoutExcludedJobIds =
       """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000}"""
-    val instance = FlintInstance.deserialize(jsonWithoutExcludedJobIds)
+    val instance = InteractiveSession.deserialize(jsonWithoutExcludedJobIds)
 
     instance.excludedJobIds shouldBe empty
   }
@@ -64,13 +64,13 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
   test("deserialize should correctly handle error field in JSON") {
     val jsonWithError =
       """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"FAILED","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"error":"Some error occurred"}"""
-    val instance = FlintInstance.deserialize(jsonWithError)
+    val instance = InteractiveSession.deserialize(jsonWithError)
 
     instance.error shouldBe Some("Some error occurred")
   }
 
   test("serialize should include error when present in FlintInstance") {
-    val instance = new FlintInstance(
+    val instance = new InteractiveSession(
       "app-123",
       "job-456",
       "session-789",
@@ -80,7 +80,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
       Seq.empty[String],
       Some("Some error occurred"))
     val currentTime = System.currentTimeMillis()
-    val json = FlintInstance.serializeWithoutJobId(instance, currentTime)
+    val json = InteractiveSession.serializeWithoutJobId(instance, currentTime)
 
     json should include(""""error":"Some error occurred"""")
   }
@@ -96,7 +96,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
     sourceMap.put("excludeJobIds", java.util.Arrays.asList("job2", "job3"))
     sourceMap.put("error", "An error occurred")
 
-    val result = FlintInstance.deserializeFromMap(sourceMap)
+    val result = InteractiveSession.deserializeFromMap(sourceMap)
 
     assert(result.applicationId == "app1")
     assert(result.jobId == "job1")
@@ -114,7 +114,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
     sourceMap.put("lastUpdateTime", "1234567890")
 
     assertThrows[ClassCastException] {
-      FlintInstance.deserializeFromMap(sourceMap)
+      InteractiveSession.deserializeFromMap(sourceMap)
     }
   }
 
@@ -129,7 +129,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
     sourceMap.put("excludeJobIds", java.util.Arrays.asList("job2", "job3"))
     sourceMap.put("error", "An error occurred")
 
-    val result = FlintInstance.deserializeFromMap(sourceMap)
+    val result = InteractiveSession.deserializeFromMap(sourceMap)
 
     assert(result.applicationId == "app1")
     assert(result.jobId == "job1")
@@ -144,7 +144,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers {
 
   test("deserialize should correctly parse a FlintInstance without jobStartTime from JSON") {
     val json =
"""{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"excludeJobIds":["job-101","job-202"]}""" - val instance = FlintInstance.deserialize(json) + val instance = InteractiveSession.deserialize(json) instance.applicationId shouldBe "app-123" instance.jobId shouldBe "job-456" @@ -155,4 +155,23 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { instance.excludedJobIds should contain allOf ("job-101", "job-202") instance.error shouldBe None } + + test("sessionContext should add/get key-value pairs to/from the context") { + val session = + new InteractiveSession("app-123", "job-456", "session-789", "RUNNING", 1620000000000L) + session.context shouldBe empty + + session.addContext("key1", "value1") + session.addContext("key2", 42) + + session.context should contain("key1" -> "value1") + session.context should contain("key2" -> 42) + + session.getContext("key1") shouldBe Some("value1") + session.getContext("key2") shouldBe Some(42) + session.getContext("key3") shouldBe None // Test for a key that does not exist + + session.addContext("key1", "updatedValue") + session.getContext("key1") shouldBe Some("updatedValue") + } } diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala index f2d4be911..a4669138b 100644 --- a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -15,9 +15,9 @@ import scala.util.control.Breaks.{break, breakable} import org.opensearch.OpenSearchStatusException import org.opensearch.flint.OpenSearchSuite -import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.core.{FlintClient, FlintOptions} import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} +import org.opensearch.flint.data.{FlintCommand, InteractiveSession} import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkFunSuite @@ -559,15 +559,15 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { } def awaitConditionForSessionOrTimeout( - expected: FlintInstance => Boolean, + expected: InteractiveSession => Boolean, sessionId: String): Boolean = { - awaitConditionOrTimeout[FlintInstance]( + awaitConditionOrTimeout[InteractiveSession]( osClient, expected, sessionId, 10000, requestIndex, - FlintInstance.deserialize, + InteractiveSession.deserialize, "session") } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala index 3b317a0fe..fa7f75b81 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala @@ -10,15 +10,15 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.action.get.{GetRequest, GetResponse} import org.opensearch.client.RequestOptions import org.opensearch.flint.OpenSearchTransactionSuite -import org.opensearch.flint.app.FlintInstance import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater} +import org.opensearch.flint.data.InteractiveSession import org.scalatest.matchers.should.Matchers class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { val sessionId = "sessionId" val timestamp = 
diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
index 3b317a0fe..fa7f75b81 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
@@ -10,15 +10,15 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.opensearch.action.get.{GetRequest, GetResponse}
 import org.opensearch.client.RequestOptions
 import org.opensearch.flint.OpenSearchTransactionSuite
-import org.opensearch.flint.app.FlintInstance
 import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater}
+import org.opensearch.flint.data.InteractiveSession
 import org.scalatest.matchers.should.Matchers
 
 class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
   val sessionId = "sessionId"
   val timestamp = 1700090926955L
   val flintJob =
-    new FlintInstance(
+    new InteractiveSession(
       "applicationId",
       "jobId",
       sessionId,
@@ -38,7 +38,7 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
   }
 
   test("upsert flintJob should success") {
-    updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp))
+    updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp))
     getFlintInstance(sessionId)._2.lastUpdateTime shouldBe timestamp
   }
 
@@ -46,15 +46,15 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
     deleteIndex(testMetaLogIndex)
 
     the[IllegalStateException] thrownBy {
-      updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp))
+      updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp))
     }
   }
 
   test("update flintJob should success") {
-    updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp))
+    updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp))
 
     val newTimestamp = 1700090926956L
-    updater.update(sessionId, FlintInstance.serialize(flintJob, newTimestamp))
+    updater.update(sessionId, InteractiveSession.serialize(flintJob, newTimestamp))
     getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp
   }
 
@@ -62,25 +62,25 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
     deleteIndex(testMetaLogIndex)
 
     the[IllegalStateException] thrownBy {
-      updater.update(sessionId, FlintInstance.serialize(flintJob, timestamp))
+      updater.update(sessionId, InteractiveSession.serialize(flintJob, timestamp))
     }
   }
 
   test("updateIf flintJob should success") {
-    updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp))
+    updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp))
     val (resp, latest) = getFlintInstance(sessionId)
 
     val newTimestamp = 1700090926956L
     updater.updateIf(
       sessionId,
-      FlintInstance.serialize(latest, newTimestamp),
+      InteractiveSession.serialize(latest, newTimestamp),
       resp.getSeqNo,
       resp.getPrimaryTerm)
     getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp
   }
 
   test("index is deleted when updateIf flintJob should throw IllegalStateException") {
-    updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp))
+    updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp))
     val (resp, latest) = getFlintInstance(sessionId)
     deleteIndex(testMetaLogIndex)
 
@@ -88,15 +88,15 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
     the[IllegalStateException] thrownBy {
       updater.updateIf(
         sessionId,
-        FlintInstance.serialize(latest, timestamp),
+        InteractiveSession.serialize(latest, timestamp),
         resp.getSeqNo,
         resp.getPrimaryTerm)
     }
   }
 
-  def getFlintInstance(docId: String): (GetResponse, FlintInstance) = {
+  def getFlintInstance(docId: String): (GetResponse, InteractiveSession) = {
     val response =
       openSearchClient.get(new GetRequest(testMetaLogIndex, docId), RequestOptions.DEFAULT)
-    (response, FlintInstance.deserializeFromMap(response.getSourceAsMap))
+    (response, InteractiveSession.deserializeFromMap(response.getSourceAsMap))
   }
 }
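The suite above pins down the updater contract that FlintREPL relies on; a condensed sketch of that contract (index wiring elided, per the suite's setup):

val doc = InteractiveSession.serialize(flintJob, System.currentTimeMillis())
updater.upsert(sessionId, doc) // creates or replaces the session document
updater.update(sessionId, doc) // update/upsert throw IllegalStateException once the index is deleted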
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 36432f016..81543cdab 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -18,13 +18,13 @@ import com.codahale.metrics.Timer
 import org.json4s.native.Serialization
 import org.opensearch.action.get.GetResponse
 import org.opensearch.common.Strings
-import org.opensearch.flint.app.{FlintCommand, FlintInstance}
-import org.opensearch.flint.app.FlintInstance.formats
 import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.logging.CustomLogging
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
+import org.opensearch.flint.data.{FlintCommand, InteractiveSession}
+import org.opensearch.flint.data.InteractiveSession.formats
 import org.opensearch.search.sort.SortOrder
 
 import org.apache.spark.SparkConf
@@ -411,7 +411,7 @@
       excludeJobIds: Seq[String] = Seq.empty[String]): Unit = {
     val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId)
     val currentTime = currentTimeProvider.currentEpochMillis()
-    val flintJob = new FlintInstance(
+    val flintJob = new InteractiveSession(
       applicationId,
       jobId,
       sessionId,
@@ -421,9 +421,9 @@
       excludeJobIds)
     val serializedFlintInstance = if (includeJobId) {
-      FlintInstance.serialize(flintJob, currentTime, true)
+      InteractiveSession.serialize(flintJob, currentTime, true)
     } else {
-      FlintInstance.serializeWithoutJobId(flintJob, currentTime)
+      InteractiveSession.serializeWithoutJobId(flintJob, currentTime)
     }
     flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance)
     logInfo(
@@ -456,11 +456,11 @@
   private def getExistingFlintInstance(
       osClient: OSClient,
       sessionIndex: String,
-      sessionId: String): Option[FlintInstance] = Try(
+      sessionId: String): Option[InteractiveSession] = Try(
     osClient.getDoc(sessionIndex, sessionId)) match {
     case Success(getResponse) if getResponse.isExists() =>
       Option(getResponse.getSourceAsMap)
-        .map(FlintInstance.deserializeFromMap)
+        .map(InteractiveSession.deserializeFromMap)
     case Failure(exception) =>
       CustomLogging.logError(
         s"Failed to retrieve existing FlintInstance: ${exception.getMessage}",
@@ -474,7 +474,7 @@
       jobId: String,
       sessionId: String,
       jobStartTime: Long,
-      errorMessage: String): FlintInstance = new FlintInstance(
+      errorMessage: String): InteractiveSession = new InteractiveSession(
     applicationId,
     jobId,
     sessionId,
@@ -484,13 +484,13 @@
     error = Some(errorMessage))
 
   private def updateFlintInstance(
-      flintInstance: FlintInstance,
+      flintInstance: InteractiveSession,
       flintSessionIndexUpdater: OpenSearchUpdater,
       sessionId: String): Unit = {
     val currentTime = currentTimeProvider.currentEpochMillis()
     flintSessionIndexUpdater.upsert(
       sessionId,
-      FlintInstance.serializeWithoutJobId(flintInstance, currentTime))
+      InteractiveSession.serializeWithoutJobId(flintInstance, currentTime))
   }
 
   /**
@@ -622,7 +622,7 @@
       statementTimerContext: Timer.Context): Unit = {
     try {
       dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
-      if (flintCommand.isRunning() || flintCommand.isWaiting()) {
+      if (flintCommand.isRunning || flintCommand.isWaiting) {
         // we have set failed state in exception handling
         flintCommand.complete()
       }
@@ -932,11 +932,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
       flintSessionIndexUpdater: OpenSearchUpdater,
       sessionId: String,
       sessionTimerContext: Timer.Context): Unit = {
-    val flintInstance = FlintInstance.deserializeFromMap(source)
+    val flintInstance = InteractiveSession.deserializeFromMap(source)
     flintInstance.state = "dead"
     flintSessionIndexUpdater.updateIf(
       sessionId,
-      FlintInstance.serializeWithoutJobId(
+      InteractiveSession.serializeWithoutJobId(
         flintInstance,
         currentTimeProvider.currentEpochMillis()),
       getResponse.getSeqNo,
@@ -1137,9 +1137,9 @@
     if (statementRunningCount.get() > 0) {
       statementRunningCount.decrementAndGet()
     }
-    if (flintCommand.isComplete()) {
+    if (flintCommand.isComplete) {
       incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC)
-    } else if (flintCommand.isFailed()) {
+    } else if (flintCommand.isFailed) {
       incrementCounter(MetricConstants.STATEMENT_FAILED_METRIC)
     }
   }
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 1a6aea4f4..3c4d0ecc2 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -22,8 +22,8 @@ import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.opensearch.action.get.GetResponse
-import org.opensearch.flint.app.FlintCommand
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater}
+import org.opensearch.flint.data.FlintCommand
 import org.opensearch.search.sort.SortOrder
 import org.scalatestplus.mockito.MockitoSugar
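For completeness, a minimal sketch of rebuilding a session from an OpenSearch source map, the path FlintREPL and OpenSearchUpdaterSuite exercise via getSourceAsMap (field values illustrative; numeric fields must be boxed numbers rather than strings, or deserializeFromMap fails with ClassCastException, as the tests above assert):

import java.util.{HashMap => JavaHashMap}

val source = new JavaHashMap[String, AnyRef]()
source.put("applicationId", "app1")
source.put("jobId", "job1")
source.put("sessionId", "session1")
source.put("state", "running")
source.put("lastUpdateTime", java.lang.Long.valueOf(1700090926955L))
source.put("jobStartTime", java.lang.Long.valueOf(1700090926955L))

val restored = InteractiveSession.deserializeFromMap(source)
assert(restored.isRunning)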