From 7a4c206714215187757428d3df9d01a28559bd0d Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Sat, 14 Sep 2024 15:54:28 -0700 Subject: [PATCH] Add langType to FlintStatement model (#664) Signed-off-by: Louis Chu --- .../flint/common/model/FlintStatement.scala | 8 ++++-- .../flint/common/model/LangType.scala | 25 +++++++++++++++++++ .../apache/spark/sql/FlintREPLITSuite.scala | 1 + .../org/apache/spark/sql/JobOperator.scala | 9 ++++++- .../org/apache/spark/sql/FlintREPLTest.scala | 15 +++++++++-- 5 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 flint-commons/src/main/scala/org/opensearch/flint/common/model/LangType.scala diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala index 00876d46e..da07f435c 100644 --- a/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/FlintStatement.scala @@ -31,6 +31,8 @@ object StatementStates { * Unique identifier for the type of statement. * @param queryId * Unique identifier for the query. + * @param langType + * The language type of the query (e.g., "sql" or "ppl"). * @param submitTime * Timestamp when the statement was submitted. * @param error @@ -44,6 +46,7 @@ class FlintStatement( // statementId is the statement type doc id val statementId: String, val queryId: String, + val langType: String, val submitTime: Long, var error: Option[String] = None, statementContext: Map[String, Any] = Map.empty[String, Any]) @@ -65,7 +68,7 @@ class FlintStatement( // Does not include context, which could contain sensitive information. override def toString: String = - s"FlintStatement(state=$state, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" + s"FlintStatement(state=$state, statementId=$statementId, queryId=$queryId, langType=$langType, submitTime=$submitTime, error=$error)" } object FlintStatement { @@ -78,13 +81,14 @@ object FlintStatement { val query = (meta \ "query").extract[String] val statementId = (meta \ "statementId").extract[String] val queryId = (meta \ "queryId").extract[String] + val langType = (meta \ "lang").extract[String].toLowerCase(Locale.ROOT) val submitTime = (meta \ "submitTime").extract[Long] val maybeError: Option[String] = (meta \ "error") match { case JString(str) => Some(str) case _ => None } - new FlintStatement(state, query, statementId, queryId, submitTime, maybeError) + new FlintStatement(state, query, statementId, queryId, langType, submitTime, maybeError) } def serialize(flintStatement: FlintStatement): String = { diff --git a/flint-commons/src/main/scala/org/opensearch/flint/common/model/LangType.scala b/flint-commons/src/main/scala/org/opensearch/flint/common/model/LangType.scala new file mode 100644 index 000000000..746219f4e --- /dev/null +++ b/flint-commons/src/main/scala/org/opensearch/flint/common/model/LangType.scala @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.common.scheduler.model + +object LangType { + val SQL = "sql" + val PPL = "ppl" + + private val values = Seq(SQL, PPL) + + /** + * Get LangType from text. + * + * @param text + * input text + * @return + * Option[String] if found, None otherwise + */ + def fromString(text: String): Option[String] = { + values.find(_.equalsIgnoreCase(text)) + } +} diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala index 24f3c9a89..1ddfa540b 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -210,6 +210,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { | "state": "waiting", | "submitTime": $submitTime, | "type": "statement", + | "lang": "sql", | "statementId": "${statementId}", | "queryId": "${queryId}", | "dataSourceName": "${dataSourceName}" diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index cb4af86da..deee6eb1d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -13,6 +13,7 @@ import scala.concurrent.duration.{Duration, MINUTES} import scala.util.{Failure, Success, Try} import org.opensearch.flint.common.model.FlintStatement +import org.opensearch.flint.common.scheduler.model.LangType import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import org.opensearch.flint.spark.FlintSpark @@ -71,7 +72,13 @@ case class JobOperator( instantiateStatementExecutionManager(commandContext, resultIndex, osClient) val statement = - new FlintStatement("running", query, "", queryId, currentTimeProvider.currentEpochMillis()) + new FlintStatement( + "running", + query, + "", + queryId, + LangType.SQL, + currentTimeProvider.currentEpochMillis()) var exceptionThrown = true var error: String = null diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index a3f990a59..355bd9ede 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -23,6 +23,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} +import org.opensearch.flint.common.scheduler.model.LangType import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder import org.scalatest.prop.TableDrivenPropertyChecks._ @@ -269,7 +270,8 @@ class FlintREPLTest val expected = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) - val flintStatement = new FlintStatement("failed", "select 1", "30", "10", currentTime, None) + val flintStatement = + new FlintStatement("failed", "select 1", "30", "10", LangType.SQL, currentTime, None) try { FlintREPL.currentTimeProvider = new MockTimeProvider(currentTime) @@ -711,7 +713,14 @@ class FlintREPLTest val startTime = System.currentTimeMillis() val expectedDataFrame = mock[DataFrame] val flintStatement = - new FlintStatement("running", "select 1", "30", "10", Instant.now().toEpochMilli(), None) + new FlintStatement( + "running", + "select 1", + "30", + "10", + LangType.SQL, + Instant.now().toEpochMilli(), + None) // When the `sql` method is called, execute the custom Answer that introduces a delay when(mockSparkSession.sql(any[String])).thenAnswer(new Answer[DataFrame] { override def answer(invocation: InvocationOnMock): DataFrame = { @@ -782,6 +791,7 @@ class FlintREPLTest "select * from default.http_logs limit1 1", "10", "20", + LangType.SQL, Instant.now().toEpochMilli, None) val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) @@ -1325,6 +1335,7 @@ class FlintREPLTest "query": "SELECT * FROM table", "statementId": "stmt123", "queryId": "query456", + "lang": "sql", "submitTime": 1234567890, "error": "Some error" }