Skip to content

Commit

Permalink
Add langType to FlintStatement model (#664)
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <[email protected]>
  • Loading branch information
noCharger authored Sep 14, 2024
1 parent 0be5697 commit 7a4c206
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ object StatementStates {
* Unique identifier for the type of statement.
* @param queryId
* Unique identifier for the query.
* @param langType
* The language type of the query (e.g., "sql" or "ppl").
* @param submitTime
* Timestamp when the statement was submitted.
* @param error
Expand All @@ -44,6 +46,7 @@ class FlintStatement(
// statementId is the statement type doc id
val statementId: String,
val queryId: String,
val langType: String,
val submitTime: Long,
var error: Option[String] = None,
statementContext: Map[String, Any] = Map.empty[String, Any])
Expand All @@ -65,7 +68,7 @@ class FlintStatement(

// Does not include context, which could contain sensitive information.
override def toString: String =
s"FlintStatement(state=$state, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)"
s"FlintStatement(state=$state, statementId=$statementId, queryId=$queryId, langType=$langType, submitTime=$submitTime, error=$error)"
}

object FlintStatement {
Expand All @@ -78,13 +81,14 @@ object FlintStatement {
val query = (meta \ "query").extract[String]
val statementId = (meta \ "statementId").extract[String]
val queryId = (meta \ "queryId").extract[String]
val langType = (meta \ "lang").extract[String].toLowerCase(Locale.ROOT)
val submitTime = (meta \ "submitTime").extract[Long]
val maybeError: Option[String] = (meta \ "error") match {
case JString(str) => Some(str)
case _ => None
}

new FlintStatement(state, query, statementId, queryId, submitTime, maybeError)
new FlintStatement(state, query, statementId, queryId, langType, submitTime, maybeError)
}

def serialize(flintStatement: FlintStatement): String = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.common.scheduler.model

object LangType {
val SQL = "sql"
val PPL = "ppl"

private val values = Seq(SQL, PPL)

/**
* Get LangType from text.
*
* @param text
* input text
* @return
* Option[String] if found, None otherwise
*/
def fromString(text: String): Option[String] = {
values.find(_.equalsIgnoreCase(text))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
| "state": "waiting",
| "submitTime": $submitTime,
| "type": "statement",
| "lang": "sql",
| "statementId": "${statementId}",
| "queryId": "${queryId}",
| "dataSourceName": "${dataSourceName}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import scala.concurrent.duration.{Duration, MINUTES}
import scala.util.{Failure, Success, Try}

import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.common.scheduler.model.LangType
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.spark.FlintSpark
Expand Down Expand Up @@ -71,7 +72,13 @@ case class JobOperator(
instantiateStatementExecutionManager(commandContext, resultIndex, osClient)

val statement =
new FlintStatement("running", query, "", queryId, currentTimeProvider.currentEpochMillis())
new FlintStatement(
"running",
query,
"",
queryId,
LangType.SQL,
currentTimeProvider.currentEpochMillis())

var exceptionThrown = true
var error: String = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.opensearch.action.get.GetResponse
import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates}
import org.opensearch.flint.common.scheduler.model.LangType
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater}
import org.opensearch.search.sort.SortOrder
import org.scalatest.prop.TableDrivenPropertyChecks._
Expand Down Expand Up @@ -269,7 +270,8 @@ class FlintREPLTest
val expected =
spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema)

val flintStatement = new FlintStatement("failed", "select 1", "30", "10", currentTime, None)
val flintStatement =
new FlintStatement("failed", "select 1", "30", "10", LangType.SQL, currentTime, None)

try {
FlintREPL.currentTimeProvider = new MockTimeProvider(currentTime)
Expand Down Expand Up @@ -711,7 +713,14 @@ class FlintREPLTest
val startTime = System.currentTimeMillis()
val expectedDataFrame = mock[DataFrame]
val flintStatement =
new FlintStatement("running", "select 1", "30", "10", Instant.now().toEpochMilli(), None)
new FlintStatement(
"running",
"select 1",
"30",
"10",
LangType.SQL,
Instant.now().toEpochMilli(),
None)
// When the `sql` method is called, execute the custom Answer that introduces a delay
when(mockSparkSession.sql(any[String])).thenAnswer(new Answer[DataFrame] {
override def answer(invocation: InvocationOnMock): DataFrame = {
Expand Down Expand Up @@ -782,6 +791,7 @@ class FlintREPLTest
"select * from default.http_logs limit1 1",
"10",
"20",
LangType.SQL,
Instant.now().toEpochMilli,
None)
val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1)
Expand Down Expand Up @@ -1325,6 +1335,7 @@ class FlintREPLTest
"query": "SELECT * FROM table",
"statementId": "stmt123",
"queryId": "query456",
"lang": "sql",
"submitTime": 1234567890,
"error": "Some error"
}
Expand Down

0 comments on commit 7a4c206

Please sign in to comment.