diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/ContextualDataStore.scala
similarity index 93%
rename from flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala
rename to flint-commons/src/main/scala/org/opensearch/flint/common/model/ContextualDataStore.scala
index 109bf654a..408216fad 100644
--- a/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala
+++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/ContextualDataStore.scala
@@ -3,7 +3,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-package org.opensearch.flint.data
+package org.opensearch.flint.common.model
 
 /**
  * Provides a mutable map to store and retrieve contextual data using key-value pairs.
diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala
similarity index 78%
rename from flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala
rename to flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala
index dbe73e9a5..bc8b38d9a 100644
--- a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala
+++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala
@@ -3,7 +3,9 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-package org.opensearch.flint.data
+package org.opensearch.flint.common.model
+
+import java.util.Locale
 
 import org.json4s.{Formats, NoTypeHints}
 import org.json4s.JsonAST.JString
@@ -14,6 +16,7 @@ object StatementStates {
   val RUNNING = "running"
   val SUCCESS = "success"
   val FAILED = "failed"
+  val TIMEOUT = "timeout"
   val WAITING = "waiting"
 }
 
@@ -50,10 +53,15 @@ class FlintStatement(
   def running(): Unit = state = StatementStates.RUNNING
   def complete(): Unit = state = StatementStates.SUCCESS
   def fail(): Unit = state = StatementStates.FAILED
-  def isRunning: Boolean = state == StatementStates.RUNNING
-  def isComplete: Boolean = state == StatementStates.SUCCESS
-  def isFailed: Boolean = state == StatementStates.FAILED
-  def isWaiting: Boolean = state == StatementStates.WAITING
+  def timeout(): Unit = state = StatementStates.TIMEOUT
+
+  def isRunning: Boolean = state.equalsIgnoreCase(StatementStates.RUNNING)
+
+  def isComplete: Boolean = state.equalsIgnoreCase(StatementStates.SUCCESS)
+
+  def isFailed: Boolean = state.equalsIgnoreCase(StatementStates.FAILED)
+
+  def isWaiting: Boolean = state.equalsIgnoreCase(StatementStates.WAITING)
 
   // Does not include context, which could contain sensitive information.
   override def toString: String =
@@ -66,7 +74,7 @@ object FlintStatement {
 
   def deserialize(statement: String): FlintStatement = {
     val meta = parse(statement)
-    val state = (meta \ "state").extract[String]
+    val state = (meta \ "state").extract[String].toLowerCase(Locale.ROOT)
     val query = (meta \ "query").extract[String]
     val statementId = (meta \ "statementId").extract[String]
     val queryId = (meta \ "queryId").extract[String]
@@ -82,6 +90,8 @@ object FlintStatement {
   def serialize(flintStatement: FlintStatement): String = {
     // we only need to modify state and error
     Serialization.write(
-      Map("state" -> flintStatement.state, "error" -> flintStatement.error.getOrElse("")))
+      Map(
+        "state" -> flintStatement.state.toLowerCase(Locale.ROOT),
+        "error" -> flintStatement.error.getOrElse("")))
   }
 }
diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala
similarity index 90%
rename from flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala
rename to flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala
index c5eaee4f1..9acdeab5f 100644
--- a/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala
+++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala
@@ -3,9 +3,9 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-package org.opensearch.flint.data
+package org.opensearch.flint.common.model
 
-import java.util.{Map => JavaMap}
+import java.util.{Locale, Map => JavaMap}
 
 import scala.collection.JavaConverters._
 
@@ -16,9 +16,8 @@ import org.json4s.native.Serialization
 
 object SessionStates {
   val RUNNING = "running"
-  val COMPLETE = "complete"
-  val FAILED = "failed"
-  val WAITING = "waiting"
+  val DEAD = "dead"
+  val FAIL = "fail"
 }
 
 /**
@@ -56,10 +55,15 @@ class InteractiveSession(
     extends ContextualDataStore {
   context = sessionContext // Initialize the context from the constructor
 
-  def isRunning: Boolean = state == SessionStates.RUNNING
-  def isComplete: Boolean = state == SessionStates.COMPLETE
-  def isFailed: Boolean = state == SessionStates.FAILED
-  def isWaiting: Boolean = state == SessionStates.WAITING
+  def running(): Unit = state = SessionStates.RUNNING
+  def complete(): Unit = state = SessionStates.DEAD
+  def fail(): Unit = state = SessionStates.FAIL
+
+  def isRunning: Boolean = state.equalsIgnoreCase(SessionStates.RUNNING)
+
+  def isComplete: Boolean = state.equalsIgnoreCase(SessionStates.DEAD)
+
+  def isFail: Boolean = state.equalsIgnoreCase(SessionStates.FAIL)
 
   override def toString: String = {
     val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]")
@@ -77,7 +81,7 @@ object InteractiveSession {
   def deserialize(job: String): InteractiveSession = {
     val meta = parse(job)
     val applicationId = (meta \ "applicationId").extract[String]
-    val state = (meta \ "state").extract[String]
+    val state = (meta \ "state").extract[String].toLowerCase(Locale.ROOT)
     val jobId = (meta \ "jobId").extract[String]
     val sessionId = (meta \ "sessionId").extract[String]
     val lastUpdateTime = (meta \ "lastUpdateTime").extract[Long]
@@ -116,7 +120,7 @@ object InteractiveSession {
     val scalaSource = source.asScala
 
     val applicationId = scalaSource("applicationId").asInstanceOf[String]
-    val state = scalaSource("state").asInstanceOf[String]
+    val state = scalaSource("state").asInstanceOf[String].toLowerCase(Locale.ROOT)
     val jobId = scalaSource("jobId").asInstanceOf[String]
     val sessionId = scalaSource("sessionId").asInstanceOf[String]
     val lastUpdateTime = scalaSource("lastUpdateTime").asInstanceOf[Long]
@@ -178,7 +182,7 @@ object InteractiveSession {
       "sessionId" -> job.sessionId,
       "error" -> job.error.getOrElse(""),
       "applicationId" -> job.applicationId,
-      "state" -> job.state,
+      "state" -> job.state.toLowerCase(Locale.ROOT),
       // update last update time
       "lastUpdateTime" -> currentTime,
       // Convert a Seq[String] into a comma-separated string, such as "id1,id2".
diff --git a/flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala b/flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala
similarity index 97%
rename from flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala
rename to flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala
index f69fe70b4..5f6b1fdc1 100644
--- a/flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala
+++ b/flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala
@@ -3,7 +3,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-package org.opensearch.flint.data
+package org.opensearch.flint.common.model
 
 import java.util.{HashMap => JavaHashMap}
 
@@ -21,7 +21,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers {
     instance.applicationId shouldBe "app-123"
     instance.jobId shouldBe "job-456"
     instance.sessionId shouldBe "session-789"
-    instance.state shouldBe "RUNNING"
+    instance.state shouldBe "running"
     instance.lastUpdateTime shouldBe 1620000000000L
     instance.jobStartTime shouldBe 1620000001000L
     instance.excludedJobIds should contain allOf ("job-101", "job-202")
@@ -44,7 +44,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers {
     json should include(""""applicationId":"app-123"""")
     json should not include (""""jobId":"job-456"""")
     json should include(""""sessionId":"session-789"""")
-    json should include(""""state":"RUNNING"""")
+    json should include(""""state":"running"""")
     json should include(s""""lastUpdateTime":$currentTime""")
     json should include(
       """"excludeJobIds":"job-101,job-202""""
@@ -149,7 +149,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers {
     instance.applicationId shouldBe "app-123"
     instance.jobId shouldBe "job-456"
     instance.sessionId shouldBe "session-789"
-    instance.state shouldBe "RUNNING"
+    instance.state shouldBe "running"
     instance.lastUpdateTime shouldBe 1620000000000L
     instance.jobStartTime shouldBe 0L // Default or expected value for missing jobStartTime
     instance.excludedJobIds should contain allOf ("job-101", "job-202")
diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
index 5c101ac2d..b75ff0ce9 100644
--- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
+++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
@@ -15,9 +15,9 @@ import scala.util.control.Breaks.{break, breakable}
 
 import org.opensearch.OpenSearchStatusException
 import org.opensearch.flint.OpenSearchSuite
+import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
 import org.opensearch.flint.core.{FlintClient, FlintOptions}
 import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater}
-import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
 import org.opensearch.search.sort.SortOrder
 
 import org.apache.spark.SparkFunSuite
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
index e6198496b..b235ecdd5 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
@@ -10,8 +10,8 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 import org.opensearch.action.get.{GetRequest, GetResponse}
 import org.opensearch.client.RequestOptions
 import org.opensearch.flint.OpenSearchTransactionSuite
+import org.opensearch.flint.common.model.InteractiveSession
 import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater}
-import org.opensearch.flint.data.InteractiveSession
 import org.scalatest.matchers.should.Matchers
 
 class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index f9cccf27a..37801a9e8 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -96,6 +96,24 @@ trait FlintJobExecutor {
       }
     }""".stripMargin
 
+  // Define the data schema
+  val schema = StructType(
+    Seq(
+      StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
+      StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
+      StructField("jobRunId", StringType, nullable = true),
+      StructField("applicationId", StringType, nullable = true),
+      StructField("dataSourceName", StringType, nullable = true),
+      StructField("status", StringType, nullable = true),
+      StructField("error", StringType, nullable = true),
+      StructField("queryId", StringType, nullable = true),
+      StructField("queryText", StringType, nullable = true),
+      StructField("sessionId", StringType, nullable = true),
+      StructField("jobType", StringType, nullable = true),
+      // number is not nullable
+      StructField("updateTime", LongType, nullable = false),
+      StructField("queryRunTime", LongType, nullable = true)))
+
   def createSparkConf(): SparkConf = {
     val conf = new SparkConf().setAppName(getClass.getSimpleName)
 
@@ -203,24 +221,6 @@ trait FlintJobExecutor {
         StructField("column_name", StringType, nullable = false),
         StructField("data_type", StringType, nullable = false))))
 
-    // Define the data schema
-    val schema = StructType(
-      Seq(
-        StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
-        StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
-        StructField("jobRunId", StringType, nullable = true),
-        StructField("applicationId", StringType, nullable = true),
-        StructField("dataSourceName", StringType, nullable = true),
-        StructField("status", StringType, nullable = true),
-        StructField("error", StringType, nullable = true),
-        StructField("queryId", StringType, nullable = true),
-        StructField("queryText", StringType, nullable = true),
-        StructField("sessionId", StringType, nullable = true),
-        StructField("jobType", StringType, nullable = true),
-        // number is not nullable
-        StructField("updateTime", LongType, nullable = false),
-        StructField("queryRunTime", LongType, nullable = true)))
-
     val resultToSave = result.toJSON.collect.toList
       .map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))
 
@@ -253,35 +253,17 @@ trait FlintJobExecutor {
     spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
   }
 
-  def getFailedData(
+  def constructErrorDF(
       spark: SparkSession,
       dataSource: String,
+      status: String,
       error: String,
       queryId: String,
-      query: String,
+      queryText: String,
       sessionId: String,
-      startTime: Long,
-      timeProvider: TimeProvider): DataFrame = {
-
-    // Define the data schema
-    val schema = StructType(
-      Seq(
-        StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
-        StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
-        StructField("jobRunId", StringType, nullable = true),
-        StructField("applicationId", StringType, nullable = true),
-        StructField("dataSourceName", StringType, nullable = true),
-        StructField("status", StringType, nullable = true),
-        StructField("error", StringType, nullable = true),
-        StructField("queryId", StringType, nullable = true),
-        StructField("queryText", StringType, nullable = true),
-        StructField("sessionId", StringType, nullable = true),
-        StructField("jobType", StringType, nullable = true),
-        // number is not nullable
-        StructField("updateTime", LongType, nullable = false),
-        StructField("queryRunTime", LongType, nullable = true)))
+      startTime: Long): DataFrame = {
 
-    val endTime = timeProvider.currentEpochMillis()
+    val updateTime = currentTimeProvider.currentEpochMillis()
 
     // Create the data rows
     val rows = Seq(
@@ -291,14 +273,14 @@
         envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"),
         envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"),
         dataSource,
-        "FAILED",
+        status.toUpperCase(Locale.ROOT),
         error,
         queryId,
-        query,
+        queryText,
        sessionId,
         spark.conf.get(FlintSparkConf.JOB_TYPE.key),
-        endTime,
-        endTime - startTime))
+        updateTime,
+        updateTime - startTime))
 
     // Create the DataFrame for data
     spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*)
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 782dd04c2..e6b8b11ce 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -18,13 +18,13 @@ import com.codahale.metrics.Timer
 import org.json4s.native.Serialization
 import org.opensearch.action.get.GetResponse
 import org.opensearch.common.Strings
+import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}
+import org.opensearch.flint.common.model.InteractiveSession.formats
 import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.logging.CustomLogging
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
-import org.opensearch.flint.data.{FlintStatement, InteractiveSession}
-import org.opensearch.flint.data.InteractiveSession.formats
 import org.opensearch.search.sort.SortOrder
 
 import org.apache.spark.SparkConf
@@ -456,7 +456,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error))
 
     updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId)
-    if (flintInstance.state.equals("fail")) {
+    if (flintInstance.isFail) {
      recordSessionFailed(sessionTimerContext)
     }
   }
@@ -530,15 +530,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
       startTime: Long): DataFrame = {
     flintStatement.fail()
     flintStatement.error = Some(error)
-    super.getFailedData(
+    super.constructErrorDF(
       spark,
       dataSource,
+      flintStatement.state,
       error,
       flintStatement.queryId,
       flintStatement.query,
       sessionId,
-      startTime,
-      currentTimeProvider)
+      startTime)
   }
 
   def processQueryException(ex: Exception, flintStatement: FlintStatement): String = {
@@ -654,7 +654,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       error: String,
       flintStatement: FlintStatement,
       sessionId: String,
-      startTime: Long): Option[DataFrame] = {
+      startTime: Long): DataFrame = {
     /*
      * https://tinyurl.com/2ezs5xj9
      *
@@ -668,14 +668,17 @@
      * actions that require the computation of results that need to be collected or stored.
      */
     spark.sparkContext.cancelJobGroup(flintStatement.queryId)
-    Some(
-      handleCommandFailureAndGetFailedData(
-        spark,
-        dataSource,
-        error,
-        flintStatement,
-        sessionId,
-        startTime))
+    flintStatement.timeout()
+    flintStatement.error = Some(error)
+    super.constructErrorDF(
+      spark,
+      dataSource,
+      flintStatement.state,
+      error,
+      flintStatement.queryId,
+      flintStatement.query,
+      sessionId,
+      startTime)
   }
 
   def executeAndHandle(
@@ -702,7 +705,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       case e: TimeoutException =>
         val error = s"Executing ${flintStatement.query} timed out"
         CustomLogging.logError(error, e)
-        handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)
+        Some(handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime))
       case e: Exception =>
         val error = processQueryException(e, flintStatement)
         Some(
@@ -761,8 +764,14 @@ object FlintREPL extends Logging with FlintJobExecutor {
      case e: TimeoutException =>
         val error = s"Getting the mapping of index $resultIndex timed out"
         CustomLogging.logError(error, e)
-        dataToWrite =
-          handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime)
+        dataToWrite = Some(
+          handleCommandTimeout(
+            spark,
+            dataSource,
+            error,
+            flintStatement,
+            sessionId,
+            startTime))
       case NonFatal(e) =>
         val error = s"An unexpected error occurred: ${e.getMessage}"
         CustomLogging.logError(error, e)
@@ -941,7 +950,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       sessionId: String,
       sessionTimerContext: Timer.Context): Unit = {
     val flintInstance = InteractiveSession.deserializeFromMap(source)
-    flintInstance.state = "dead"
+    flintInstance.complete()
     flintSessionIndexUpdater.updateIf(
       sessionId,
       InteractiveSession.serializeWithoutJobId(
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index f315dc836..c079b3e96 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -57,7 +57,7 @@ case class JobOperator(
       dataToWrite = Some(mappingCheckResult match {
         case Right(_) => data
         case Left(error) =>
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
+          constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime)
       })
       exceptionThrown = false
     } catch {
@@ -65,11 +65,11 @@ case class JobOperator(
      case e: TimeoutException =>
         val error = s"Getting the mapping of index $resultIndex timed out"
         logError(error, e)
         dataToWrite = Some(
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
+          constructErrorDF(spark, dataSource, "TIMEOUT", error, "", query, "", startTime))
       case e: Exception =>
         val error = processQueryException(e)
         dataToWrite = Some(
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
+          constructErrorDF(spark, dataSource, "FAILED", error, "", query, "", startTime))
     } finally {
       cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
     }
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index ef5db02dc..9c193fc9a 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -22,8 +22,8 @@ import org.mockito.Mockito.{atLeastOnce, never, times, verify, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.opensearch.action.get.GetResponse
+import org.opensearch.flint.common.model.FlintStatement
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater}
-import org.opensearch.flint.data.FlintStatement
 import org.opensearch.search.sort.SortOrder
 import org.scalatest.prop.TableDrivenPropertyChecks._
 import org.scalatestplus.mockito.MockitoSugar
@@ -230,7 +230,7 @@ class FlintREPLTest
     verify(flintSessionIndexUpdater).updateIf(*, *, *, *)
   }
 
-  test("Test getFailedData method") {
+  test("Test super.constructErrorDF should construct dataframe properly") {
     // Define expected dataframe
     val dataSourceName = "myGlueS3"
     val expectedSchema = StructType(
@@ -288,7 +288,7 @@ class FlintREPLTest
         "20",
         currentTime - queryRunTime)
       assertEqualDataframe(expected, result)
-      assert("failed" == flintStatement.state)
+      assert(flintStatement.isFailed)
       assert(error == flintStatement.error.get)
     } finally {
       spark.close()
@@ -492,7 +492,7 @@ class FlintREPLTest
     assert(result == expectedError)
   }
 
-  test("handleGeneralException should handle MetaException with AccessDeniedException properly") {
+  test("processQueryException should handle MetaException with AccessDeniedException properly") {
     val mockFlintCommand = mock[FlintStatement]
 
     // Simulate the root cause being MetaException
@@ -620,7 +620,6 @@ class FlintREPLTest
 
   test("executeAndHandle should handle TimeoutException properly") {
     val mockSparkSession = mock[SparkSession]
-    val mockFlintStatement = mock[FlintStatement]
     val mockConf = mock[RuntimeConfig]
     when(mockSparkSession.conf).thenReturn(mockConf)
     when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key))
@@ -633,9 +632,8 @@ class FlintREPLTest
     val sessionId = "someSessionId"
     val startTime = System.currentTimeMillis()
     val expectedDataFrame = mock[DataFrame]
-
-    when(mockFlintStatement.query).thenReturn("SELECT 1")
-    when(mockFlintStatement.submitTime).thenReturn(Instant.now().toEpochMilli())
+    val flintStatement =
+      new FlintStatement("running", "select 1", "30", "10", Instant.now().toEpochMilli(), None)
 
     // When the `sql` method is called, execute the custom Answer that introduces a delay
     when(mockSparkSession.sql(any[String])).thenAnswer(new Answer[DataFrame] {
       override def answer(invocation: InvocationOnMock): DataFrame = {
@@ -656,7 +654,7 @@ class FlintREPLTest
 
       val result = FlintREPL.executeAndHandle(
         mockSparkSession,
-        mockFlintStatement,
+        flintStatement,
         dataSource,
         sessionId,
         executionContext,
@@ -667,6 +665,8 @@ class FlintREPLTest
       verify(mockSparkSession, times(1)).sql(any[String])
       verify(sparkContext, times(1)).cancelJobGroup(any[String])
 
+      assert("timeout" == flintStatement.state)
+      assert(s"Executing ${flintStatement.query} timed out" == flintStatement.error.get)
      result should not be None
     } finally threadPool.shutdown()
   }
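
A minimal, illustrative Scala sketch of the reworked statement state model introduced by this diff (not part of the change itself). The constructor arguments mirror the new FlintStatement(...) call in FlintREPLTest; timeout(), the case-insensitive predicates, and the lower-cased serialize() are the behaviors added above.

    import java.time.Instant
    import org.opensearch.flint.common.model.{FlintStatement, StatementStates}

    // Argument order follows the FlintREPLTest constructor call:
    // (state, query, statementId, queryId, submitTime, error)
    val stmt =
      new FlintStatement("running", "select 1", "30", "10", Instant.now().toEpochMilli(), None)

    stmt.timeout()                                     // state becomes StatementStates.TIMEOUT
    stmt.error = Some("Executing select 1 timed out")

    assert(stmt.state == StatementStates.TIMEOUT)
    assert(!stmt.isRunning && !stmt.isFailed)          // predicates use equalsIgnoreCase
    // serialize() lower-cases the state, so the session index only ever sees "timeout"
    assert(FlintStatement.serialize(stmt).contains("\"state\":\"timeout\""))

Because deserialize() also lower-cases the stored state and the predicates compare case-insensitively, documents written with the old upper-case convention (e.g. "RUNNING") still round-trip and still satisfy isRunning/isFailed. The renamed constructErrorDF plays the role of the old getFailedData but takes the status explicitly ("FAILED" or "TIMEOUT" in JobOperator, flintStatement.state in FlintREPL) and reads the time from currentTimeProvider instead of a passed-in TimeProvider.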