From f258018a200330cae4e92aec2cf5adb4b8a72064 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 8 Aug 2024 14:54:18 -0700 Subject: [PATCH] Make states always in upper case --- .../flint/common/model/FlintStatement.scala | 22 ++++++++++--------- .../common/model/InteractiveSession.scala | 18 +++++++-------- .../common/model/InteractiveSessionTest.scala | 6 ++--- .../apache/spark/sql/FlintREPLITSuite.scala | 16 +++++++------- .../org/apache/spark/sql/FlintREPLTest.scala | 2 +- 5 files changed, 33 insertions(+), 31 deletions(-) diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala index 7715d790e..237e820dd 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala @@ -5,17 +5,19 @@ package org.opensearch.flint.common.model +import java.util.Locale + import org.json4s.{Formats, NoTypeHints} import org.json4s.JsonAST.JString import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization object StatementStates { - val RUNNING = "running" - val SUCCESS = "success" - val FAILED = "failed" - val TIMEOUT = "timeout" - val WAITING = "waiting" + val RUNNING = "RUNNING" + val SUCCESS = "SUCCESS" + val FAILED = "FAILED" + val TIMEOUT = "TIMEOUT" + val WAITING = "WAITING" } /** @@ -53,10 +55,10 @@ class FlintStatement( def fail(): Unit = state = StatementStates.FAILED def timeout(): Unit = state = StatementStates.TIMEOUT - def isRunning: Boolean = state == StatementStates.RUNNING - def isComplete: Boolean = state == StatementStates.SUCCESS - def isFailed: Boolean = state == StatementStates.FAILED - def isWaiting: Boolean = state == StatementStates.WAITING + def isRunning: Boolean = state.equalsIgnoreCase(StatementStates.RUNNING) + def isComplete: Boolean = state.equalsIgnoreCase(StatementStates.SUCCESS) + def isFailed: Boolean = state.equalsIgnoreCase(StatementStates.FAILED) + def isWaiting: Boolean = state.equalsIgnoreCase(StatementStates.WAITING) // Does not include context, which could contain sensitive information. override def toString: String = @@ -69,7 +71,7 @@ object FlintStatement { def deserialize(statement: String): FlintStatement = { val meta = parse(statement) - val state = (meta \ "state").extract[String] + val state = (meta \ "state").extract[String].toUpperCase(Locale.ROOT) val query = (meta \ "query").extract[String] val statementId = (meta \ "statementId").extract[String] val queryId = (meta \ "queryId").extract[String] diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala index f9392c7a9..4ce7dbad8 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/InteractiveSession.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.common.model -import java.util.{Map => JavaMap} +import java.util.{Locale, Map => JavaMap} import scala.collection.JavaConverters._ @@ -15,9 +15,9 @@ import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization object SessionStates { - val RUNNING = "running" - val DEAD = "dead" - val FAIL = "fail" + val RUNNING = "RUNNING" + val DEAD = "DEAD" + val FAIL = "FAIL" } /** @@ -59,9 +59,9 @@ class InteractiveSession( def complete(): Unit = state = SessionStates.DEAD def fail(): Unit = state = SessionStates.FAIL - def isRunning: Boolean = state == SessionStates.RUNNING - def isComplete: Boolean = state == SessionStates.DEAD - def isFail: Boolean = state == SessionStates.FAIL + def isRunning: Boolean = state.equalsIgnoreCase(SessionStates.RUNNING) + def isComplete: Boolean = state.equalsIgnoreCase(SessionStates.DEAD) + def isFail: Boolean = state.equalsIgnoreCase(SessionStates.FAIL) override def toString: String = { val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") @@ -79,7 +79,7 @@ object InteractiveSession { def deserialize(job: String): InteractiveSession = { val meta = parse(job) val applicationId = (meta \ "applicationId").extract[String] - val state = (meta \ "state").extract[String] + val state = (meta \ "state").extract[String].toUpperCase(Locale.ROOT) val jobId = (meta \ "jobId").extract[String] val sessionId = (meta \ "sessionId").extract[String] val lastUpdateTime = (meta \ "lastUpdateTime").extract[Long] @@ -118,7 +118,7 @@ object InteractiveSession { val scalaSource = source.asScala val applicationId = scalaSource("applicationId").asInstanceOf[String] - val state = scalaSource("state").asInstanceOf[String] + val state = scalaSource("state").asInstanceOf[String].toUpperCase(Locale.ROOT) val jobId = scalaSource("jobId").asInstanceOf[String] val sessionId = scalaSource("sessionId").asInstanceOf[String] val lastUpdateTime = scalaSource("lastUpdateTime").asInstanceOf[Long] diff --git a/flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala b/flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala index 82df98910..bdba0fc3e 100644 --- a/flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala +++ b/flint-commons/src/test/scala/org/opensearch/flint/common/model/InteractiveSessionTest.scala @@ -15,7 +15,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers { test("deserialize should correctly parse a FlintInstance with excludedJobIds from JSON") { val json = - """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"excludeJobIds":["job-101","job-202"]}""" + """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"running","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"excludeJobIds":["job-101","job-202"]}""" val instance = InteractiveSession.deserialize(json) instance.applicationId shouldBe "app-123" @@ -101,7 +101,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers { assert(result.applicationId == "app1") assert(result.jobId == "job1") assert(result.sessionId == "session1") - assert(result.state == "running") + assert(result.state == "RUNNING") assert(result.lastUpdateTime == 1234567890L) assert(result.jobStartTime == 9876543210L) assert(result.excludedJobIds == Seq("job2", "job3")) @@ -134,7 +134,7 @@ class InteractiveSessionTest extends SparkFunSuite with Matchers { assert(result.applicationId == "app1") assert(result.jobId == "job1") assert(result.sessionId == "session1") - assert(result.state == "running") + assert(result.state == "RUNNING") assert(result.lastUpdateTime == 1234567890L) assert(result.jobStartTime == 0L) // Default value for missing jobStartTime assert(result.excludedJobIds == Seq("job2", "job3")) diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index b75ff0ce9..c59b4b069 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -304,7 +304,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { assert( !awaitConditionForStatementOrTimeout( statement => { - statement.state == "success" + statement.isComplete }, selectStatementId), s"Fail to verify for $selectStatementId.") @@ -344,7 +344,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { assert( !awaitConditionForStatementOrTimeout( statement => { - statement.state == "success" + statement.isComplete }, descStatementId), s"Fail to verify for $descStatementId.") @@ -381,7 +381,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { assert( !awaitConditionForStatementOrTimeout( statement => { - statement.state == "success" + statement.isComplete }, showTableStatementId), s"Fail to verify for $showTableStatementId.") @@ -401,7 +401,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { assert( !awaitConditionForStatementOrTimeout( statement => { - statement.state == "failed" + statement.isFailed }, wrongSelectStatementId), s"Fail to verify for $wrongSelectStatementId.") @@ -410,7 +410,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { assert( awaitConditionForStatementOrTimeout( statement => { - statement.state != "waiting" + !statement.isWaiting }, lateSelectStatementId), s"Fail to verify for $lateSelectStatementId.") @@ -471,7 +471,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { case _ => fail(s"Statement error is: ${statement.error}") } - statement.state == "failed" + statement.isFailed }, createTableStatementId), s"Fail to verify for $createTableStatementId.") @@ -558,7 +558,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { assert( !awaitConditionForStatementOrTimeout( statement => { - statement.state == "success" + statement.isComplete }, selectStatementId), s"Fail to verify for $selectStatementId.") @@ -566,7 +566,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { assert( awaitConditionForStatementOrTimeout( statement => { - statement.state != "waiting" + !statement.isWaiting }, lateSelectStatementId), s"Fail to verify for $lateSelectStatementId.") diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 9c193fc9a..a0ffd1aa5 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -665,7 +665,7 @@ class FlintREPLTest verify(mockSparkSession, times(1)).sql(any[String]) verify(sparkContext, times(1)).cancelJobGroup(any[String]) - assert("timeout" == flintStatement.state) + assert("TIMEOUT" == flintStatement.state) assert(s"Executing ${flintStatement.query} timed out" == flintStatement.error.get) result should not be None } finally threadPool.shutdown()