From 80d8f6eb768ff1f73a11cef7e463d28f09434a49 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 12 Jun 2024 08:01:57 -0700 Subject: [PATCH] [Refactor] Introduce flint-commons for models and interfaces (#373) * [Refactor] Introduce flint-data for model and interface Signed-off-by: Louis Chu * Uppercase for constant Signed-off-by: Louis Chu * Address comments Signed-off-by: Louis Chu * Rename package Signed-off-by: Louis Chu --------- Signed-off-by: Louis Chu --- README.md | 3 +- build.sbt | 37 ++++- .../flint/data/ContextualDataStore.scala | 29 ++++ .../flint/data/FlintStatement.scala | 87 +++++++++++ .../flint/data/InteractiveSession.scala | 65 ++++++-- .../flint/data/InteractiveSessionTest.scala | 45 ++++-- .../opensearch/flint/app/FlintCommand.scala | 78 ---------- .../apache/spark/sql/FlintREPLITSuite.scala | 14 +- .../flint/core/OpenSearchUpdaterSuite.scala | 26 ++-- .../org/apache/spark/sql/FlintREPL.scala | 144 +++++++++--------- .../org/apache/spark/sql/FlintREPLTest.scala | 36 ++--- 11 files changed, 346 insertions(+), 218 deletions(-) create mode 100644 flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala create mode 100644 flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala rename flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala => flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala (76%) rename flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala => flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala (78%) delete mode 100644 flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala diff --git a/README.md b/README.md index 017b4a1c2..f9568838e 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # OpenSearch Flint -OpenSearch Flint is ... It consists of two modules: +OpenSearch Flint is ... It consists of four modules: - `flint-core`: a module that contains Flint specification and client. +- `flint-commons`: a module that provides a shared library of utilities and common functionalities, designed to easily extend Flint's capabilities. - `flint-spark-integration`: a module that provides Spark integration for Flint and derived dataset based on it. - `ppl-spark-integration`: a module that provides PPL query execution on top of Spark See [PPL repository](https://github.com/opensearch-project/piped-processing-language). 
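For context, the `flint-commons` module introduced in the README above packages the statement and session data models that previously lived in `flint-spark-integration` under `org.opensearch.flint.app`. A minimal usage sketch of the relocated `FlintStatement` model, based only on the constructor and state helpers visible in the diff below; the query and id literals are illustrative placeholders, not values from the repository:

import org.opensearch.flint.data.{FlintStatement, StatementStates}

// Build a statement the way a submitting client would (placeholder values).
val statement = new FlintStatement(
  state = StatementStates.WAITING,
  query = "SELECT 1",
  statementId = "stmt-1",
  queryId = "query-1",
  submitTime = System.currentTimeMillis())

statement.running()                           // state helpers added by this patch
val doc = FlintStatement.serialize(statement) // serialize writes back only state and error
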
diff --git a/build.sbt b/build.sbt index f12a19647..9d419decb 100644 --- a/build.sbt +++ b/build.sbt @@ -48,7 +48,7 @@ lazy val commonSettings = Seq( // running `scalafmtAll` includes all subprojects under root lazy val root = (project in file(".")) - .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest) + .aggregate(flintCommons, flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest) .disablePlugins(AssemblyPlugin) .settings(name := "flint", publish / skip := true) @@ -84,6 +84,37 @@ lazy val flintCore = (project in file("flint-core")) libraryDependencies ++= deps(sparkVersion), publish / skip := true) +lazy val flintCommons = (project in file("flint-commons")) + .settings( + commonSettings, + name := "flint-commons", + scalaVersion := scala212, + libraryDependencies ++= Seq( + "org.scalactic" %% "scalactic" % "3.2.15" % "test", + "org.scalatest" %% "scalatest" % "3.2.15" % "test", + "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", + "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test", + ), + libraryDependencies ++= deps(sparkVersion), + publish / skip := true, + assembly / test := (Test / test).value, + assembly / assemblyOption ~= { + _.withIncludeScala(false) + }, + assembly / assemblyMergeStrategy := { + case PathList(ps@_*) if ps.last endsWith ("module-info.class") => + MergeStrategy.discard + case PathList("module-info.class") => MergeStrategy.discard + case PathList("META-INF", "versions", xs@_, "module-info.class") => + MergeStrategy.discard + case x => + val oldStrategy = (assembly / assemblyMergeStrategy).value + oldStrategy(x) + }, + ) + .enablePlugins(AssemblyPlugin) + + lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) .enablePlugins(AssemblyPlugin, Antlr4Plugin) .settings( @@ -121,7 +152,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) assembly / test := (Test / test).value) lazy val flintSparkIntegration = (project in file("flint-spark-integration")) - .dependsOn(flintCore) + .dependsOn(flintCore, flintCommons) .enablePlugins(AssemblyPlugin, Antlr4Plugin) .settings( commonSettings, @@ -166,7 +197,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) // Test assembly package with integration test. lazy val integtest = (project in file("integ-test")) - .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test") + .dependsOn(flintCommons % "test->test", flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test") .settings( commonSettings, name := "integ-test", diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala new file mode 100644 index 000000000..109bf654a --- /dev/null +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.data + +/** + * Provides a mutable map to store and retrieve contextual data using key-value pairs. + */ +trait ContextualDataStore { + + /** Holds the contextual data as key-value pairs. */ + var context: Map[String, Any] = Map.empty + + /** + * Adds a key-value pair to the context map. 
+ */ + def setContextValue(key: String, value: Any): Unit = { + context += (key -> value) + } + + /** + * Retrieves the value associated with a key from the context map. + */ + def getContextValue(key: String): Option[Any] = { + context.get(key) + } +} diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala new file mode 100644 index 000000000..dbe73e9a5 --- /dev/null +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.data + +import org.json4s.{Formats, NoTypeHints} +import org.json4s.JsonAST.JString +import org.json4s.native.JsonMethods.parse +import org.json4s.native.Serialization + +object StatementStates { + val RUNNING = "running" + val SUCCESS = "success" + val FAILED = "failed" + val WAITING = "waiting" +} + +/** + * Represents a statement processed in the Flint job. + * + * @param state + * The current state of the statement. + * @param query + * SQL-like query string that the statement will execute. + * @param statementId + * Unique identifier for the type of statement. + * @param queryId + * Unique identifier for the query. + * @param submitTime + * Timestamp when the statement was submitted. + * @param error + * Optional error message if the statement fails. + * @param statementContext + * Additional context for the statement as key-value pairs. + */ +class FlintStatement( + var state: String, + val query: String, + // statementId is the statement type doc id + val statementId: String, + val queryId: String, + val submitTime: Long, + var error: Option[String] = None, + statementContext: Map[String, Any] = Map.empty[String, Any]) + extends ContextualDataStore { + context = statementContext + + def running(): Unit = state = StatementStates.RUNNING + def complete(): Unit = state = StatementStates.SUCCESS + def fail(): Unit = state = StatementStates.FAILED + def isRunning: Boolean = state == StatementStates.RUNNING + def isComplete: Boolean = state == StatementStates.SUCCESS + def isFailed: Boolean = state == StatementStates.FAILED + def isWaiting: Boolean = state == StatementStates.WAITING + + // Does not include context, which could contain sensitive information. 
+ override def toString: String = + s"FlintStatement(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" +} + +object FlintStatement { + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + def deserialize(statement: String): FlintStatement = { + val meta = parse(statement) + val state = (meta \ "state").extract[String] + val query = (meta \ "query").extract[String] + val statementId = (meta \ "statementId").extract[String] + val queryId = (meta \ "queryId").extract[String] + val submitTime = (meta \ "submitTime").extract[Long] + val maybeError: Option[String] = (meta \ "error") match { + case JString(str) => Some(str) + case _ => None + } + + new FlintStatement(state, query, statementId, queryId, submitTime, maybeError) + } + + def serialize(flintStatement: FlintStatement): String = { + // we only need to modify state and error + Serialization.write( + Map("state" -> flintStatement.state, "error" -> flintStatement.error.getOrElse(""))) + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala similarity index 76% rename from flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala rename to flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala index 9911a3b6c..c5eaee4f1 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala @@ -3,42 +3,78 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.app +package org.opensearch.flint.data import java.util.{Map => JavaMap} import scala.collection.JavaConverters._ -import scala.collection.mutable import org.json4s.{Formats, JNothing, JNull, NoTypeHints} import org.json4s.JsonAST.{JArray, JString} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization -// lastUpdateTime is added to FlintInstance to track the last update time of the instance. Its unit is millisecond. -class FlintInstance( +object SessionStates { + val RUNNING = "running" + val COMPLETE = "complete" + val FAILED = "failed" + val WAITING = "waiting" +} + +/** + * Represents an interactive session for job and state management. + * + * @param applicationId + * Unique identifier for the EMR-S application. + * @param jobId + * Identifier for the specific EMR-S job. + * @param sessionId + * Unique session identifier. + * @param state + * Current state of the session. + * @param lastUpdateTime + * Timestamp of the last update. + * @param jobStartTime + * Start time of the job. + * @param excludedJobIds + * List of job IDs that are excluded. + * @param error + * Optional error message. + * @param sessionContext + * Additional context for the session. 
+ */ +class InteractiveSession( val applicationId: String, val jobId: String, - // sessionId is the session type doc id val sessionId: String, var state: String, val lastUpdateTime: Long, val jobStartTime: Long = 0, val excludedJobIds: Seq[String] = Seq.empty[String], - val error: Option[String] = None) { + val error: Option[String] = None, + sessionContext: Map[String, Any] = Map.empty[String, Any]) + extends ContextualDataStore { + context = sessionContext // Initialize the context from the constructor + + def isRunning: Boolean = state == SessionStates.RUNNING + def isComplete: Boolean = state == SessionStates.COMPLETE + def isFailed: Boolean = state == SessionStates.FAILED + def isWaiting: Boolean = state == SessionStates.WAITING + override def toString: String = { val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") val errorStr = error.getOrElse("None") + // Does not include context, which could contain sensitive information. s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)" } } -object FlintInstance { +object InteractiveSession { implicit val formats: Formats = Serialization.formats(NoTypeHints) - def deserialize(job: String): FlintInstance = { + def deserialize(job: String): InteractiveSession = { val meta = parse(job) val applicationId = (meta \ "applicationId").extract[String] val state = (meta \ "state").extract[String] @@ -64,7 +100,7 @@ object FlintInstance { case _ => None } - new FlintInstance( + new InteractiveSession( applicationId, jobId, sessionId, @@ -75,7 +111,7 @@ object FlintInstance { maybeError) } - def deserializeFromMap(source: JavaMap[String, AnyRef]): FlintInstance = { + def deserializeFromMap(source: JavaMap[String, AnyRef]): InteractiveSession = { // Since we are dealing with JavaMap, we convert it to a Scala mutable Map for ease of use. val scalaSource = source.asScala @@ -105,7 +141,7 @@ object FlintInstance { } // Construct a new FlintInstance with the extracted values. 
- new FlintInstance( + new InteractiveSession( applicationId, jobId, sessionId, @@ -133,7 +169,10 @@ object FlintInstance { * @return * serialized Flint session */ - def serialize(job: FlintInstance, currentTime: Long, includeJobId: Boolean = true): String = { + def serialize( + job: InteractiveSession, + currentTime: Long, + includeJobId: Boolean = true): String = { val baseMap = Map( "type" -> "session", "sessionId" -> job.sessionId, @@ -159,7 +198,7 @@ object FlintInstance { Serialization.write(resultMap) } - def serializeWithoutJobId(job: FlintInstance, currentTime: Long): String = { + def serializeWithoutJobId(job: InteractiveSession, currentTime: Long): String = { serialize(job, currentTime, includeJobId = false) } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala b/flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala similarity index 78% rename from flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala rename to flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala index 8ece6ba8a..f69fe70b4 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala +++ b/flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.app +package org.opensearch.flint.data import java.util.{HashMap => JavaHashMap} @@ -11,12 +11,12 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite -class FlintInstanceTest extends SparkFunSuite with Matchers { +class InteractiveSessionTest extends SparkFunSuite with Matchers { test("deserialize should correctly parse a FlintInstance with excludedJobIds from JSON") { val json = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"excludeJobIds":["job-101","job-202"]}""" - val instance = FlintInstance.deserialize(json) + val instance = InteractiveSession.deserialize(json) instance.applicationId shouldBe "app-123" instance.jobId shouldBe "job-456" @@ -30,7 +30,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { test("serialize should correctly produce JSON from a FlintInstance with excludedJobIds") { val excludedJobIds = Seq("job-101", "job-202") - val instance = new FlintInstance( + val instance = new InteractiveSession( "app-123", "job-456", "session-789", @@ -39,7 +39,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { 1620000001000L, excludedJobIds) val currentTime = System.currentTimeMillis() - val json = FlintInstance.serializeWithoutJobId(instance, currentTime) + val json = InteractiveSession.serializeWithoutJobId(instance, currentTime) json should include(""""applicationId":"app-123"""") json should not include (""""jobId":"job-456"""") @@ -56,7 +56,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { test("deserialize should correctly handle an empty excludedJobIds field in JSON") { val jsonWithoutExcludedJobIds = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000}""" - val instance = FlintInstance.deserialize(jsonWithoutExcludedJobIds) + val instance = InteractiveSession.deserialize(jsonWithoutExcludedJobIds) instance.excludedJobIds shouldBe empty } @@ -64,13 +64,13 @@ class FlintInstanceTest 
extends SparkFunSuite with Matchers { test("deserialize should correctly handle error field in JSON") { val jsonWithError = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"FAILED","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"error":"Some error occurred"}""" - val instance = FlintInstance.deserialize(jsonWithError) + val instance = InteractiveSession.deserialize(jsonWithError) instance.error shouldBe Some("Some error occurred") } test("serialize should include error when present in FlintInstance") { - val instance = new FlintInstance( + val instance = new InteractiveSession( "app-123", "job-456", "session-789", @@ -80,7 +80,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { Seq.empty[String], Some("Some error occurred")) val currentTime = System.currentTimeMillis() - val json = FlintInstance.serializeWithoutJobId(instance, currentTime) + val json = InteractiveSession.serializeWithoutJobId(instance, currentTime) json should include(""""error":"Some error occurred"""") } @@ -96,7 +96,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { sourceMap.put("excludeJobIds", java.util.Arrays.asList("job2", "job3")) sourceMap.put("error", "An error occurred") - val result = FlintInstance.deserializeFromMap(sourceMap) + val result = InteractiveSession.deserializeFromMap(sourceMap) assert(result.applicationId == "app1") assert(result.jobId == "job1") @@ -114,7 +114,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { sourceMap.put("lastUpdateTime", "1234567890") assertThrows[ClassCastException] { - FlintInstance.deserializeFromMap(sourceMap) + InteractiveSession.deserializeFromMap(sourceMap) } } @@ -129,7 +129,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { sourceMap.put("excludeJobIds", java.util.Arrays.asList("job2", "job3")) sourceMap.put("error", "An error occurred") - val result = FlintInstance.deserializeFromMap(sourceMap) + val result = InteractiveSession.deserializeFromMap(sourceMap) assert(result.applicationId == "app1") assert(result.jobId == "job1") @@ -144,7 +144,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { test("deserialize should correctly parse a FlintInstance without jobStartTime from JSON") { val json = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"excludeJobIds":["job-101","job-202"]}""" - val instance = FlintInstance.deserialize(json) + val instance = InteractiveSession.deserialize(json) instance.applicationId shouldBe "app-123" instance.jobId shouldBe "job-456" @@ -155,4 +155,23 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { instance.excludedJobIds should contain allOf ("job-101", "job-202") instance.error shouldBe None } + + test("sessionContext should add/get key-value pairs to/from the context") { + val session = + new InteractiveSession("app-123", "job-456", "session-789", "RUNNING", 1620000000000L) + session.context shouldBe empty + + session.setContextValue("key1", "value1") + session.setContextValue("key2", 42) + + session.context should contain("key1" -> "value1") + session.context should contain("key2" -> 42) + + session.getContextValue("key1") shouldBe Some("value1") + session.getContextValue("key2") shouldBe Some(42) + session.getContextValue("key3") shouldBe None // Test for a key that does not exist + + session.setContextValue("key1", "updatedValue") + session.getContextValue("key1") shouldBe Some("updatedValue") + } } diff --git 
a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala deleted file mode 100644 index 7624c2c54..000000000 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.flint.app - -import org.json4s.{Formats, NoTypeHints} -import org.json4s.JsonAST.JString -import org.json4s.native.JsonMethods.parse -import org.json4s.native.Serialization - -class FlintCommand( - var state: String, - val query: String, - // statementId is the statement type doc id - val statementId: String, - val queryId: String, - val submitTime: Long, - var error: Option[String] = None) { - def running(): Unit = { - state = "running" - } - - def complete(): Unit = { - state = "success" - } - - def fail(): Unit = { - state = "failed" - } - - def isRunning(): Boolean = { - state == "running" - } - - def isComplete(): Boolean = { - state == "success" - } - - def isFailed(): Boolean = { - state == "failed" - } - - def isWaiting(): Boolean = { - state == "waiting" - } - - override def toString: String = { - s"FlintCommand(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" - } -} - -object FlintCommand { - - implicit val formats: Formats = Serialization.formats(NoTypeHints) - - def deserialize(command: String): FlintCommand = { - val meta = parse(command) - val state = (meta \ "state").extract[String] - val query = (meta \ "query").extract[String] - val statementId = (meta \ "statementId").extract[String] - val queryId = (meta \ "queryId").extract[String] - val submitTime = (meta \ "submitTime").extract[Long] - val maybeError: Option[String] = (meta \ "error") match { - case JString(str) => Some(str) - case _ => None - } - - new FlintCommand(state, query, statementId, queryId, submitTime, maybeError) - } - - def serialize(flintCommand: FlintCommand): String = { - // we only need to modify state and error - Serialization.write( - Map("state" -> flintCommand.state, "error" -> flintCommand.error.getOrElse(""))) - } -} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala index f2d4be911..1c0b27674 100644 --- a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -15,9 +15,9 @@ import scala.util.control.Breaks.{break, breakable} import org.opensearch.OpenSearchStatusException import org.opensearch.flint.OpenSearchSuite -import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.core.{FlintClient, FlintOptions} import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} +import org.opensearch.flint.data.{FlintStatement, InteractiveSession} import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkFunSuite @@ -546,28 +546,28 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { } def awaitConditionForStatementOrTimeout( - expected: FlintCommand => Boolean, + expected: FlintStatement => Boolean, statementId: String): Boolean = { - awaitConditionOrTimeout[FlintCommand]( + awaitConditionOrTimeout[FlintStatement]( osClient, expected, statementId, 10000, requestIndex, - 
FlintCommand.deserialize, + FlintStatement.deserialize, "statement") } def awaitConditionForSessionOrTimeout( - expected: FlintInstance => Boolean, + expected: InteractiveSession => Boolean, sessionId: String): Boolean = { - awaitConditionOrTimeout[FlintInstance]( + awaitConditionOrTimeout[InteractiveSession]( osClient, expected, sessionId, 10000, requestIndex, - FlintInstance.deserialize, + InteractiveSession.deserialize, "session") } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala index 3b317a0fe..fa7f75b81 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala @@ -10,15 +10,15 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.action.get.{GetRequest, GetResponse} import org.opensearch.client.RequestOptions import org.opensearch.flint.OpenSearchTransactionSuite -import org.opensearch.flint.app.FlintInstance import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater} +import org.opensearch.flint.data.InteractiveSession import org.scalatest.matchers.should.Matchers class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { val sessionId = "sessionId" val timestamp = 1700090926955L val flintJob = - new FlintInstance( + new InteractiveSession( "applicationId", "jobId", sessionId, @@ -38,7 +38,7 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { } test("upsert flintJob should success") { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) getFlintInstance(sessionId)._2.lastUpdateTime shouldBe timestamp } @@ -46,15 +46,15 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { deleteIndex(testMetaLogIndex) the[IllegalStateException] thrownBy { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) } } test("update flintJob should success") { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) val newTimestamp = 1700090926956L - updater.update(sessionId, FlintInstance.serialize(flintJob, newTimestamp)) + updater.update(sessionId, InteractiveSession.serialize(flintJob, newTimestamp)) getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp } @@ -62,25 +62,25 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { deleteIndex(testMetaLogIndex) the[IllegalStateException] thrownBy { - updater.update(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.update(sessionId, InteractiveSession.serialize(flintJob, timestamp)) } } test("updateIf flintJob should success") { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) val (resp, latest) = getFlintInstance(sessionId) val newTimestamp = 1700090926956L updater.updateIf( sessionId, - FlintInstance.serialize(latest, newTimestamp), + InteractiveSession.serialize(latest, newTimestamp), resp.getSeqNo, resp.getPrimaryTerm) getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp } test("index is deleted when 
updateIf flintJob should throw IllegalStateException") { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) val (resp, latest) = getFlintInstance(sessionId) deleteIndex(testMetaLogIndex) @@ -88,15 +88,15 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { the[IllegalStateException] thrownBy { updater.updateIf( sessionId, - FlintInstance.serialize(latest, timestamp), + InteractiveSession.serialize(latest, timestamp), resp.getSeqNo, resp.getPrimaryTerm) } } - def getFlintInstance(docId: String): (GetResponse, FlintInstance) = { + def getFlintInstance(docId: String): (GetResponse, InteractiveSession) = { val response = openSearchClient.get(new GetRequest(testMetaLogIndex, docId), RequestOptions.DEFAULT) - (response, FlintInstance.deserializeFromMap(response.getSourceAsMap)) + (response, InteractiveSession.deserializeFromMap(response.getSourceAsMap)) } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 36432f016..8cad8844b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -18,13 +18,13 @@ import com.codahale.metrics.Timer import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings -import org.opensearch.flint.app.{FlintCommand, FlintInstance} -import org.opensearch.flint.app.FlintInstance.formats import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.flint.data.{FlintStatement, InteractiveSession} +import org.opensearch.flint.data.InteractiveSession.formats import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf @@ -57,8 +57,8 @@ object FlintREPL extends Logging with FlintJobExecutor { @volatile var earlyExitFlag: Boolean = false - def updateSessionIndex(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { - updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) + def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { + updater.update(flintStatement.statementId, FlintStatement.serialize(flintStatement)) } private val sessionRunningCount = new AtomicInteger(0) @@ -411,7 +411,7 @@ object FlintREPL extends Logging with FlintJobExecutor { excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new FlintInstance( + val flintJob = new InteractiveSession( applicationId, jobId, sessionId, @@ -421,9 +421,9 @@ object FlintREPL extends Logging with FlintJobExecutor { excludeJobIds) val serializedFlintInstance = if (includeJobId) { - FlintInstance.serialize(flintJob, currentTime, true) + InteractiveSession.serialize(flintJob, currentTime, true) } else { - FlintInstance.serializeWithoutJobId(flintJob, currentTime) + InteractiveSession.serializeWithoutJobId(flintJob, currentTime) } flintSessionIndexUpdater.upsert(sessionId, 
serializedFlintInstance) logInfo( @@ -456,11 +456,11 @@ object FlintREPL extends Logging with FlintJobExecutor { private def getExistingFlintInstance( osClient: OSClient, sessionIndex: String, - sessionId: String): Option[FlintInstance] = Try( + sessionId: String): Option[InteractiveSession] = Try( osClient.getDoc(sessionIndex, sessionId)) match { case Success(getResponse) if getResponse.isExists() => Option(getResponse.getSourceAsMap) - .map(FlintInstance.deserializeFromMap) + .map(InteractiveSession.deserializeFromMap) case Failure(exception) => CustomLogging.logError( s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", @@ -474,7 +474,7 @@ object FlintREPL extends Logging with FlintJobExecutor { jobId: String, sessionId: String, jobStartTime: Long, - errorMessage: String): FlintInstance = new FlintInstance( + errorMessage: String): InteractiveSession = new InteractiveSession( applicationId, jobId, sessionId, @@ -484,19 +484,19 @@ object FlintREPL extends Logging with FlintJobExecutor { error = Some(errorMessage)) private def updateFlintInstance( - flintInstance: FlintInstance, + flintInstance: InteractiveSession, flintSessionIndexUpdater: OpenSearchUpdater, sessionId: String): Unit = { val currentTime = currentTimeProvider.currentEpochMillis() flintSessionIndexUpdater.upsert( sessionId, - FlintInstance.serializeWithoutJobId(flintInstance, currentTime)) + InteractiveSession.serializeWithoutJobId(flintInstance, currentTime)) } /** - * handling the case where a command's execution fails, updates the flintCommand with the error - * and failure status, and then write the result to result index. Thus, an error is written to - * both result index or statement model in request index + * handling the case where a command's execution fails, updates the flintStatement with the + * error and failure status, and then write the result to result index. 
Thus, an error is + * written to both result index or statement model in request index * * @param spark * spark session @@ -504,7 +504,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * data source * @param error * error message - * @param flintCommand + * @param flintStatement * flint command * @param sessionId * session id @@ -517,26 +517,26 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, dataSource: String, error: String, - flintCommand: FlintCommand, + flintStatement: FlintStatement, sessionId: String, startTime: Long): DataFrame = { - flintCommand.fail() - flintCommand.error = Some(error) + flintStatement.fail() + flintStatement.error = Some(error) super.getFailedData( spark, dataSource, error, - flintCommand.queryId, - flintCommand.query, + flintStatement.queryId, + flintStatement.query, sessionId, startTime, currentTimeProvider) } - def processQueryException(ex: Exception, flintCommand: FlintCommand): String = { + def processQueryException(ex: Exception, flintStatement: FlintStatement): String = { val error = super.processQueryException(ex) - flintCommand.fail() - flintCommand.error = Some(error) + flintStatement.fail() + flintStatement.error = Some(error) error } @@ -570,12 +570,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } else { val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater) + val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( recordedVerificationResult, spark, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -587,7 +587,7 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = returnedVerificationResult finalizeCommand( dataToWrite, - flintCommand, + flintStatement, resultIndex, flintSessionIndexUpdater, osClient, @@ -606,7 +606,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * * @param dataToWrite * data to write - * @param flintCommand + * @param flintStatement * flint command * @param resultIndex * result index @@ -615,28 +615,28 @@ object FlintREPL extends Logging with FlintJobExecutor { */ private def finalizeCommand( dataToWrite: Option[DataFrame], - flintCommand: FlintCommand, + flintStatement: FlintStatement, resultIndex: String, flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, statementTimerContext: Timer.Context): Unit = { try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) - if (flintCommand.isRunning() || flintCommand.isWaiting()) { + if (flintStatement.isRunning || flintStatement.isWaiting) { // we have set failed state in exception handling - flintCommand.complete() + flintStatement.complete() } - updateSessionIndex(flintCommand, flintSessionIndexUpdater) - recordStatementStateChange(flintCommand, statementTimerContext) + updateSessionIndex(flintStatement, flintSessionIndexUpdater) + recordStatementStateChange(flintStatement, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) case e: Exception => - val error = s"""Fail to write result of ${flintCommand}, cause: ${e.getMessage}""" + val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" CustomLogging.logError(error, e) - 
flintCommand.fail() - updateSessionIndex(flintCommand, flintSessionIndexUpdater) - recordStatementStateChange(flintCommand, statementTimerContext) + flintStatement.fail() + updateSessionIndex(flintStatement, flintSessionIndexUpdater) + recordStatementStateChange(flintStatement, statementTimerContext) } } @@ -644,7 +644,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, dataSource: String, error: String, - flintCommand: FlintCommand, + flintStatement: FlintStatement, sessionId: String, startTime: Long): Option[DataFrame] = { /* @@ -659,20 +659,20 @@ object FlintREPL extends Logging with FlintJobExecutor { * of Spark jobs. In the context of Spark SQL, this typically happens when we perform * actions that require the computation of results that need to be collected or stored. */ - spark.sparkContext.cancelJobGroup(flintCommand.queryId) + spark.sparkContext.cancelJobGroup(flintStatement.queryId) Some( handleCommandFailureAndGetFailedData( spark, dataSource, error, - flintCommand, + flintStatement, sessionId, startTime)) } def executeAndHandle( spark: SparkSession, - flintCommand: FlintCommand, + flintStatement: FlintStatement, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -683,7 +683,7 @@ object FlintREPL extends Logging with FlintJobExecutor { Some( executeQueryAsync( spark, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -692,17 +692,17 @@ object FlintREPL extends Logging with FlintJobExecutor { queryWaitTimeMillis)) } catch { case e: TimeoutException => - val error = s"Executing ${flintCommand.query} timed out" + val error = s"Executing ${flintStatement.query} timed out" CustomLogging.logError(error, e) - handleCommandTimeout(spark, dataSource, error, flintCommand, sessionId, startTime) + handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) case e: Exception => - val error = processQueryException(e, flintCommand) + val error = processQueryException(e, flintStatement) Some( handleCommandFailureAndGetFailedData( spark, dataSource, error, - flintCommand, + flintStatement, sessionId, startTime)) } @@ -711,14 +711,14 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processStatementOnVerification( recordedVerificationResult: VerificationResult, spark: SparkSession, - flintCommand: FlintCommand, + flintStatement: FlintStatement, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, futureMappingCheck: Future[Either[String, Unit]], resultIndex: String, queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long): (Option[DataFrame], VerificationResult) = { + queryWaitTimeMillis: Long) = { val startTime: Long = currentTimeProvider.currentEpochMillis() var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None @@ -730,7 +730,7 @@ object FlintREPL extends Logging with FlintJobExecutor { case Right(_) => dataToWrite = executeAndHandle( spark, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -745,7 +745,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, dataSource, error, - flintCommand, + flintStatement, sessionId, startTime)) } @@ -754,7 +754,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val error = s"Getting the mapping of index $resultIndex timed out" CustomLogging.logError(error, e) dataToWrite = - handleCommandTimeout(spark, dataSource, error, flintCommand, sessionId, startTime) + handleCommandTimeout(spark, 
dataSource, error, flintStatement, sessionId, startTime) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" CustomLogging.logError(error, e) @@ -763,7 +763,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, dataSource, error, - flintCommand, + flintStatement, sessionId, startTime)) } @@ -773,13 +773,13 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, dataSource, err, - flintCommand, + flintStatement, sessionId, startTime)) case VerifiedWithoutError => dataToWrite = executeAndHandle( spark, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -788,13 +788,13 @@ object FlintREPL extends Logging with FlintJobExecutor { queryWaitTimeMillis) } - logInfo(s"command complete: $flintCommand") + logInfo(s"command complete: $flintStatement") (dataToWrite, verificationResult) } def executeQueryAsync( spark: SparkSession, - flintCommand: FlintCommand, + flintStatement: FlintStatement, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -802,21 +802,21 @@ object FlintREPL extends Logging with FlintJobExecutor { queryExecutionTimeOut: Duration, queryWaitTimeMillis: Long): DataFrame = { if (currentTimeProvider - .currentEpochMillis() - flintCommand.submitTime > queryWaitTimeMillis) { + .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { handleCommandFailureAndGetFailedData( spark, dataSource, "wait timeout", - flintCommand, + flintStatement, sessionId, startTime) } else { val futureQueryExecution = Future { executeQuery( spark, - flintCommand.query, + flintStatement.query, dataSource, - flintCommand.queryId, + flintStatement.queryId, sessionId, false) }(executionContext) @@ -827,16 +827,16 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processCommandInitiation( flintReader: FlintReader, - flintSessionIndexUpdater: OpenSearchUpdater): FlintCommand = { + flintSessionIndexUpdater: OpenSearchUpdater): FlintStatement = { val command = flintReader.next() logDebug(s"raw command: $command") - val flintCommand = FlintCommand.deserialize(command) - logDebug(s"command: $flintCommand") - flintCommand.running() - logDebug(s"command running: $flintCommand") - updateSessionIndex(flintCommand, flintSessionIndexUpdater) + val flintStatement = FlintStatement.deserialize(command) + logDebug(s"command: $flintStatement") + flintStatement.running() + logDebug(s"command running: $flintStatement") + updateSessionIndex(flintStatement, flintSessionIndexUpdater) statementRunningCount.incrementAndGet() - flintCommand + flintStatement } private def createQueryReader( @@ -932,11 +932,11 @@ object FlintREPL extends Logging with FlintJobExecutor { flintSessionIndexUpdater: OpenSearchUpdater, sessionId: String, sessionTimerContext: Timer.Context): Unit = { - val flintInstance = FlintInstance.deserializeFromMap(source) + val flintInstance = InteractiveSession.deserializeFromMap(source) flintInstance.state = "dead" flintSessionIndexUpdater.updateIf( sessionId, - FlintInstance.serializeWithoutJobId( + InteractiveSession.serializeWithoutJobId( flintInstance, currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, @@ -1131,15 +1131,15 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def recordStatementStateChange( - flintCommand: FlintCommand, + flintStatement: FlintStatement, statementTimerContext: Timer.Context): Unit = { stopTimer(statementTimerContext) if (statementRunningCount.get() > 0) { statementRunningCount.decrementAndGet() } - if 
(flintCommand.isComplete()) { + if (flintStatement.isComplete) { incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC) - } else if (flintCommand.isFailed()) { + } else if (flintStatement.isFailed) { incrementCounter(MetricConstants.STATEMENT_FAILED_METRIC) } } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 1a6aea4f4..546cd8e97 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -22,8 +22,8 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse -import org.opensearch.flint.app.FlintCommand import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} +import org.opensearch.flint.data.FlintStatement import org.opensearch.search.sort.SortOrder import org.scalatestplus.mockito.MockitoSugar @@ -245,7 +245,7 @@ class FlintREPLTest val expected = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) - val flintCommand = new FlintCommand("failed", "select 1", "30", "10", currentTime, None) + val flintStatement = new FlintStatement("failed", "select 1", "30", "10", currentTime, None) try { FlintREPL.currentTimeProvider = new MockTimeProvider(currentTime) @@ -256,12 +256,12 @@ class FlintREPLTest spark, dataSourceName, error, - flintCommand, + flintStatement, "20", currentTime - queryRunTime) assertEqualDataframe(expected, result) - assert("failed" == flintCommand.state) - assert(error == flintCommand.error.get) + assert("failed" == flintStatement.state) + assert(error == flintStatement.error.get) } finally { spark.close() FlintREPL.currentTimeProvider = new RealTimeProvider() @@ -448,18 +448,18 @@ class FlintREPLTest exception.setErrorCode("AccessDeniedException") exception.setServiceName("AWSGlue") - val mockFlintCommand = mock[FlintCommand] + val mockFlintStatement = mock[FlintStatement] val expectedError = ( """{"Message":"Fail to read data from Glue. Cause: Access denied in AWS Glue service. Please check permissions. 
(Service: AWSGlue; """ + """Status Code: 400; Error Code: AccessDeniedException; Request ID: null; Proxy: null)",""" + """"ErrorSource":"AWSGlue","StatusCode":"400"}""" ) - val result = FlintREPL.processQueryException(exception, mockFlintCommand) + val result = FlintREPL.processQueryException(exception, mockFlintStatement) result shouldEqual expectedError - verify(mockFlintCommand).fail() - verify(mockFlintCommand).error = Some(expectedError) + verify(mockFlintStatement).fail() + verify(mockFlintStatement).error = Some(expectedError) assert(result == expectedError) } @@ -574,7 +574,7 @@ class FlintREPLTest test("executeAndHandle should handle TimeoutException properly") { val mockSparkSession = mock[SparkSession] - val mockFlintCommand = mock[FlintCommand] + val mockFlintStatement = mock[FlintStatement] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) @@ -588,8 +588,8 @@ class FlintREPLTest val startTime = System.currentTimeMillis() val expectedDataFrame = mock[DataFrame] - when(mockFlintCommand.query).thenReturn("SELECT 1") - when(mockFlintCommand.submitTime).thenReturn(Instant.now().toEpochMilli()) + when(mockFlintStatement.query).thenReturn("SELECT 1") + when(mockFlintStatement.submitTime).thenReturn(Instant.now().toEpochMilli()) // When the `sql` method is called, execute the custom Answer that introduces a delay when(mockSparkSession.sql(any[String])).thenAnswer(new Answer[DataFrame] { override def answer(invocation: InvocationOnMock): DataFrame = { @@ -610,7 +610,7 @@ class FlintREPLTest val result = FlintREPL.executeAndHandle( mockSparkSession, - mockFlintCommand, + mockFlintStatement, dataSource, sessionId, executionContext, @@ -631,8 +631,8 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) - val flintCommand = - new FlintCommand( + val flintStatement = + new FlintStatement( "Running", "select * from default.http_logs limit1 1", "10", @@ -661,7 +661,7 @@ class FlintREPLTest val result = FlintREPL.executeAndHandle( mockSparkSession, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -671,8 +671,8 @@ class FlintREPLTest // Verify that ParseException was caught and handled result should not be None // or result.isDefined shouldBe true - flintCommand.error should not be None - flintCommand.error.get should include("Syntax error:") + flintStatement.error should not be None + flintStatement.error.get should include("Syntax error:") } finally threadPool.shutdown() }
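
For reference, call sites migrate as follows: `org.opensearch.flint.app.FlintCommand` becomes `org.opensearch.flint.data.FlintStatement`, `org.opensearch.flint.app.FlintInstance` becomes `org.opensearch.flint.data.InteractiveSession`, and both models now mix in `ContextualDataStore` for request-scoped key/value context. A minimal sketch of the session side, using only methods present in this patch; the JSON literal and the context key are illustrative placeholders:

import org.opensearch.flint.data.InteractiveSession

// Deserialize a session document (field names follow the test fixtures above).
val json =
  """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789",
    |"state":"running","lastUpdateTime":1620000000000,"jobStartTime":1620000001000}""".stripMargin
val session = InteractiveSession.deserialize(json)

// Request-scoped context from ContextualDataStore; the key is a placeholder.
session.setContextValue("resultIndex", "query_execution_result")
assert(session.getContextValue("resultIndex").contains("query_execution_result"))
assert(session.isRunning) // SessionStates.RUNNING == "running"

// Serialize without the job id, mirroring the session update path in FlintREPL.
val updatedDoc = InteractiveSession.serializeWithoutJobId(session, System.currentTimeMillis())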