Refactor REPL mode
refactor statement lifecycle

[WIP] Fix tests

fix some tests
noCharger committed Aug 9, 2024
1 parent 80d8f6e commit f6b7dc1
Showing 17 changed files with 1,033 additions and 819 deletions.
@@ -0,0 +1,16 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

trait QueryResultWriter {
  def reformatQueryResult(
      dataFrame: DataFrame,
      flintStatement: FlintStatement,
      queryExecutionContext: StatementExecutionContext): DataFrame
  def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
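The new QueryResultWriter trait splits result shaping from result persistence. A minimal sketch of an implementation follows, assuming it would be wired in through the new spark.flint.job.customQueryResultWriter option added later in this commit; the class name and method bodies are hypothetical, only the overridden signatures come from the trait:

class LoggingQueryResultWriter extends QueryResultWriter {

  import org.apache.spark.sql.functions.lit

  // Reshape the raw result before persistence, e.g. tag each row with the query id.
  override def reformatQueryResult(
      dataFrame: DataFrame,
      flintStatement: FlintStatement,
      queryExecutionContext: StatementExecutionContext): DataFrame =
    dataFrame.withColumn("queryId", lit(flintStatement.queryId))

  // "Persist" by printing; a real writer would write to a result index or table.
  override def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit =
    dataFrame.show(truncate = false)
}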
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

/**
* Trait defining the interface for managing interactive sessions.
*/
trait SessionManager {

  /**
   * Retrieves metadata about the session manager.
   */
  def getSessionManagerMetadata: Map[String, Any]

  /**
   * Fetches the details of a specific session.
   */
  def getSessionDetails(sessionId: String): Option[InteractiveSession]

  /**
   * Updates the details of a specific session.
   */
  def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit

  /**
   * Retrieves the next statement to be executed in a specific session.
   */
  def getNextStatement(sessionId: String): Option[FlintStatement]

  /**
   * Records a heartbeat for a specific session to indicate it is still active.
   */
  def recordHeartbeat(sessionId: String): Unit
}

object SessionUpdateMode extends Enumeration {
  type SessionUpdateMode = Value
  val UPDATE, UPSERT, UPDATE_IF = Value
}
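Read together, SessionManager is a small polling contract: heartbeat, fetch the next statement, and write session state back with an explicit update mode. A rough illustration of a caller (the function and log lines below are invented; only the trait methods and SessionUpdateMode values come from the code above):

def pollOnce(sessionManager: SessionManager, sessionId: String): Unit = {
  // Keep the session alive.
  sessionManager.recordHeartbeat(sessionId)

  // Pick up the next statement, if any, for execution.
  sessionManager.getNextStatement(sessionId) match {
    case Some(statement) => println(s"next statement: ${statement.statementId}")
    case None            => println("no pending statement")
  }

  // Persist any session state changes.
  sessionManager.getSessionDetails(sessionId).foreach { session =>
    sessionManager.updateSessionDetails(session, SessionUpdateMode.UPSERT)
  }
}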
@@ -5,20 +5,16 @@

package org.apache.spark.sql

import scala.concurrent.{ExecutionContextExecutor, Future}
import scala.concurrent.duration.Duration

import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}

case class CommandContext(
case class StatementExecutionContext(
spark: SparkSession,
dataSource: String,
resultIndex: String,
sessionId: String,
flintSessionIndexUpdater: OpenSearchUpdater,
osClient: OSClient,
sessionIndex: String,
jobId: String,
sessionId: String,
sessionManager: SessionManager,
statementLifecycleManager: StatementLifecycleManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long)
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

/**
* Trait defining the interface for managing the lifecycle of statements.
*/
trait StatementLifecycleManager {

  /**
   * Prepares the statement lifecycle.
   */
  def prepareStatementLifecycle(): Either[String, Unit]

  /**
   * Updates a specific statement.
   */
  def updateStatement(statement: FlintStatement): Unit

  /**
   * Terminates the statement lifecycle.
   */
  def terminateStatementLifecycle(): Unit
}
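A do-nothing implementation makes the intended call order concrete: prepare once, update per statement transition, terminate on shutdown. This is only a sketch; the class name and log lines are invented:

class NoOpStatementLifecycleManager extends StatementLifecycleManager {

  // Left(message) would signal that preparation failed and statements cannot run.
  override def prepareStatementLifecycle(): Either[String, Unit] = Right(())

  // Called whenever a statement's persisted state should be refreshed.
  override def updateStatement(statement: FlintStatement): Unit =
    println(s"updated statement ${statement.statementId}")

  // Release any resources tied to the statement lifecycle.
  override def terminateStatementLifecycle(): Unit =
    println("statement lifecycle terminated")
}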
@@ -42,6 +42,7 @@ class FlintStatement(
val statementId: String,
val queryId: String,
val submitTime: Long,
var queryStartTime: Option[Long] = Some(-1L),
var error: Option[String] = None,
statementContext: Map[String, Any] = Map.empty[String, Any])
extends ContextualDataStore {
@@ -76,7 +77,7 @@ object FlintStatement {
case _ => None
}

new FlintStatement(state, query, statementId, queryId, submitTime, maybeError)
new FlintStatement(state, query, statementId, queryId, submitTime, error = maybeError)
}

def serialize(flintStatement: FlintStatement): String = {
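The switch to error = maybeError is forced by the new constructor parameter: queryStartTime now sits between submitTime and error, so a positional sixth argument would bind to queryStartTime instead. For illustration, a caller that skips queryStartTime passes error by name (the first two positional values stand in for whatever precedes statementId in the constructor, and all values here are invented):

val statement = new FlintStatement(
  "running",
  "SELECT 1",
  statementId = "stmt-00001",
  queryId = "query-00001",
  submitTime = System.currentTimeMillis(),
  error = Some("parse failure")) // queryStartTime keeps its Some(-1L) default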
@@ -5,7 +5,7 @@

package org.opensearch.flint.data

import java.util.{Map => JavaMap}
import java.util.{List => JavaList, Map => JavaMap}

import scala.collection.JavaConverters._

@@ -16,9 +16,9 @@ import org.json4s.native.Serialization

object SessionStates {
val RUNNING = "running"
val COMPLETE = "complete"
val FAILED = "failed"
val WAITING = "waiting"
val DEAD = "dead"
val FAIL = "fail"
val NOT_STARTED = "not_started"
}

/**
@@ -57,9 +57,9 @@ class InteractiveSession(
context = sessionContext // Initialize the context from the constructor

def isRunning: Boolean = state == SessionStates.RUNNING
def isComplete: Boolean = state == SessionStates.COMPLETE
def isFailed: Boolean = state == SessionStates.FAILED
def isWaiting: Boolean = state == SessionStates.WAITING
def isDead: Boolean = state == SessionStates.DEAD
def isFail: Boolean = state == SessionStates.FAIL
def isNotStarted: Boolean = state == SessionStates.NOT_STARTED

override def toString: String = {
val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]")
@@ -129,10 +129,7 @@ object InteractiveSession {
}

// We safely handle the possibility of excludeJobIds being absent or not a list.
val excludeJobIds: Seq[String] = scalaSource.get("excludeJobIds") match {
case Some(lst: java.util.List[_]) => lst.asScala.toList.map(_.asInstanceOf[String])
case _ => Seq.empty[String]
}
val excludeJobIds: Seq[String] = parseExcludedJobIds(scalaSource.get("excludeJobIds"))

// Handle error similarly, ensuring we get an Option[String].
val maybeError: Option[String] = scalaSource.get("error") match {
@@ -201,4 +198,13 @@ object InteractiveSession {
def serializeWithoutJobId(job: InteractiveSession, currentTime: Long): String = {
serialize(job, currentTime, includeJobId = false)
}
private def parseExcludedJobIds(source: Option[Any]): Seq[String] = {
source match {
case Some(s: String) => Seq(s)
case Some(list: JavaList[_]) => list.asScala.toList.collect { case str: String => str }
case None => Seq.empty[String]
case _ =>
Seq.empty
}
}
}
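parseExcludedJobIds makes deserialization tolerant of excludeJobIds arriving as either a single string or a Java list; anything else collapses to an empty sequence. The helper is private, so the snippet below just re-states the same match on invented inputs to show the accepted shapes:

import java.util.{Arrays => JavaArrays, List => JavaList}
import scala.collection.JavaConverters._

def parseExcludedJobIdsDemo(source: Option[Any]): Seq[String] = source match {
  case Some(s: String)         => Seq(s)
  case Some(list: JavaList[_]) => list.asScala.toList.collect { case str: String => str }
  case _                       => Seq.empty
}

parseExcludedJobIdsDemo(Some("job-1"))                             // Seq("job-1")
parseExcludedJobIdsDemo(Some(JavaArrays.asList("job-1", "job-2"))) // Seq("job-1", "job-2")
parseExcludedJobIdsDemo(Some(42))                                  // Seq()
parseExcludedJobIdsDemo(None)                                      // Seq()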
@@ -203,6 +203,15 @@ object FlintSparkConf {
FlintConfig("spark.metadata.accessAWSCredentialsProvider")
.doc("AWS credentials provider for metadata access permission")
.createOptional()
val CUSTOM_SESSION_MANAGER =
FlintConfig("spark.flint.job.customSessionManager")
.createOptional()
val CUSTOM_STATEMENT_MANAGER =
FlintConfig("spark.flint.job.customStatementManager")
.createOptional()
val CUSTOM_QUERY_RESULT_WRITER =
FlintConfig("spark.flint.job.customQueryResultWriter")
.createOptional()
}
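If the three new options are meant to name pluggable implementations of the traits above, they would be supplied like any other Spark configuration; the class names below are invented placeholders, only the keys come from the definitions:

import org.apache.spark.SparkConf

// Hypothetical wiring of custom implementations via the new config keys.
val conf = new SparkConf()
  .set("spark.flint.job.customSessionManager", "com.example.MySessionManager")
  .set("spark.flint.job.customStatementManager", "com.example.MyStatementLifecycleManager")
  .set("spark.flint.job.customQueryResultWriter", "com.example.MyQueryResultWriter")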

/**
@@ -277,6 +286,9 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
SESSION_ID,
REQUEST_INDEX,
METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER,
CUSTOM_SESSION_MANAGER,
CUSTOM_STATEMENT_MANAGER,
CUSTOM_QUERY_RESULT_WRITER,
EXCLUDE_JOB_IDS)
.map(conf => (conf.optionKey, conf.readFrom(reader)))
.flatMap {
@@ -58,7 +58,7 @@ object FlintJob extends Logging with FlintJobExecutor {
createSparkSession(conf),
query,
dataSource,
resultIndex,
resultIndex.get,
jobType.equalsIgnoreCase("streaming"),
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
@@ -89,6 +89,24 @@ trait FlintJobExecutor {
}
}""".stripMargin

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))
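Hoisting the result schema to the trait level lets every result DataFrame (normal and error paths) share one definition. A hedged sketch of building a single row against it, with all field values invented; only the field order and types come from the schema above:

import org.apache.spark.sql.Row

val exampleRow = Row(
  null,             // result (array of JSON strings, absent here)
  null,             // schema (array of JSON strings, absent here)
  "job-run-id",     // jobRunId
  "application-id", // applicationId
  "my_datasource",  // dataSourceName
  "FAILED",         // status
  "example error",  // error
  "query-id",       // queryId
  "SELECT 1",       // queryText
  "session-id",     // sessionId
  "interactive",    // jobType
  1723190400000L,   // updateTime (not nullable)
  42L)              // queryRunTime

// val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(exampleRow)), schema)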

def createSparkConf(): SparkConf = {
new SparkConf()
.setAppName(getClass.getSimpleName)
@@ -175,7 +193,6 @@
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider,
cleaner: Cleaner): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
@@ -188,29 +205,11 @@
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false))))

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val resultToSave = result.toJSON.collect.toList
.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))

val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
val endTime = timeProvider.currentEpochMillis()
val endTime = currentTimeProvider.currentEpochMillis()

// https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data
// after consumed the query result. Streaming query shuffle data is cleaned after each
@@ -245,28 +244,9 @@
queryId: String,
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val endTime = timeProvider.currentEpochMillis()
startTime: Long): DataFrame = {

val endTime = currentTimeProvider.currentEpochMillis()

// Create the data rows
val rows = Seq(
@@ -419,7 +399,6 @@
query,
sessionId,
startTime,
currentTimeProvider,
CleanerFactory.cleaner(streaming))
}

@@ -485,16 +464,19 @@
}
}

def parseArgs(args: Array[String]): (Option[String], String) = {
def parseArgs(args: Array[String]): (Option[String], Option[String]) = {
args match {
case Array() =>
(None, None)
case Array(resultIndex) =>
(None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument
(None, Some(resultIndex)) // Starting from OS 2.13, resultIndex is the only argument
case Array(query, resultIndex) =>
(
Some(query),
resultIndex
Some(resultIndex)
) // Before OS 2.13, there are two arguments, the second one is resultIndex
case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.")
case _ =>
logAndThrow("Unsupported number of arguments. Expected no more than two arguments.")
}
}
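With the return type widened to (Option[String], Option[String]), callers now have to handle a missing result index as well as a missing query. For illustration:

parseArgs(Array())                              // (None, None)
parseArgs(Array("my_result_index"))             // (None, Some("my_result_index"))
parseArgs(Array("SELECT 1", "my_result_index")) // (Some("SELECT 1"), Some("my_result_index"))
// Any other number of arguments logs an error and throws.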

Expand Down
Loading

0 comments on commit f6b7dc1

Please sign in to comment.