REPL refactor
Signed-off-by: Louis Chu <[email protected]>
noCharger committed Aug 15, 2024
1 parent c733d3c commit 1b277ad
Showing 13 changed files with 813 additions and 524 deletions.
@@ -0,0 +1,12 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

trait QueryResultWriter {
  def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
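To make the new extension point concrete, here is a minimal, hypothetical implementation sketch (not part of this commit). It assumes FlintStatement exposes a queryId field, and it simply prints results instead of persisting them to an OpenSearch result index:

package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

// Hypothetical example: a QueryResultWriter that dumps results to stdout.
class ConsoleQueryResultWriter extends QueryResultWriter {
  override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    // Correlate the printed rows with the statement that produced them.
    println(s"Results for query ${flintStatement.queryId}:")
    dataFrame.show(truncate = false)
  }
}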
@@ -0,0 +1,48 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.sql

import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

/**
 * Trait defining the interface for managing interactive sessions.
 */
trait SessionManager {

  /**
   * Retrieves metadata about the session manager.
   */
  def getSessionContext: Map[String, Any]

  /**
   * Fetches the details of a specific session.
   */
  def getSessionDetails(sessionId: String): Option[InteractiveSession]

  /**
   * Updates the details of a specific session.
   */
  def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit

  /**
   * Retrieves the next statement to be executed in a specific session.
   */
  def getNextStatement(sessionId: String): Option[FlintStatement]

  /**
   * Records a heartbeat for a specific session to indicate it is still active.
   */
  def recordHeartbeat(sessionId: String): Unit
}

object SessionUpdateMode extends Enumeration {
  type SessionUpdateMode = Value
  val UPDATE, UPSERT, UPDATE_IF = Value
}
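As an illustration of this contract, here is a hypothetical in-memory SessionManager of the kind one might use in tests. It is not part of this commit, and it assumes InteractiveSession exposes a sessionId field:

package org.apache.spark.sql

import scala.collection.mutable

import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

class InMemorySessionManager extends SessionManager {
  private val sessions = mutable.Map.empty[String, InteractiveSession]
  private val statements = mutable.Map.empty[String, mutable.Queue[FlintStatement]]
  private val heartbeats = mutable.Map.empty[String, Long]

  override def getSessionContext: Map[String, Any] = Map.empty

  override def getSessionDetails(sessionId: String): Option[InteractiveSession] =
    sessions.get(sessionId)

  override def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit =
    updateMode match {
      case SessionUpdateMode.UPDATE | SessionUpdateMode.UPSERT =>
        sessions(sessionDetails.sessionId) = sessionDetails
      case SessionUpdateMode.UPDATE_IF =>
        // A real store would gate this write on optimistic-concurrency checks.
        sessions(sessionDetails.sessionId) = sessionDetails
    }

  override def getNextStatement(sessionId: String): Option[FlintStatement] =
    statements.get(sessionId).filter(_.nonEmpty).map(_.dequeue())

  override def recordHeartbeat(sessionId: String): Unit =
    heartbeats(sessionId) = System.currentTimeMillis()
}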
@@ -0,0 +1,31 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

/**
 * Trait defining the interface for managing the lifecycle of executing a FlintStatement.
 */
trait StatementLifecycleManager {

  /**
   * Prepares the statement lifecycle.
   */
  def prepareStatementLifecycle(): Either[String, Unit]

  // def executeStatement(statement: FlintStatement): DataFrame

  /**
   * Updates a specific statement.
   */
  def updateStatement(statement: FlintStatement): Unit

  /**
   * Terminates the statement lifecycle.
   */
  def terminateStatementLifecycle(): Unit
}
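Again purely for illustration (not in this commit), a no-op implementation shows the intended call sequence: prepare once, update per state transition, terminate at the end:

package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

// Hypothetical no-op lifecycle manager, useful as a test double.
class NoopStatementLifecycleManager extends StatementLifecycleManager {

  // Right(()) lets the REPL loop proceed; Left(error) is expected to abort it.
  override def prepareStatementLifecycle(): Either[String, Unit] = Right(())

  override def updateStatement(statement: FlintStatement): Unit = {
    // A real implementation would persist the statement's state transition
    // (e.g. waiting -> running -> success/failed) to a backing store.
  }

  override def terminateStatementLifecycle(): Unit = {}
}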
@@ -219,6 +219,15 @@ object FlintSparkConf {
     FlintConfig("spark.metadata.accessAWSCredentialsProvider")
       .doc("AWS credentials provider for metadata access permission")
       .createOptional()
+  val CUSTOM_SESSION_MANAGER =
+    FlintConfig("spark.flint.job.customSessionManager")
+      .createOptional()
+  val CUSTOM_STATEMENT_MANAGER =
+    FlintConfig("spark.flint.job.customStatementManager")
+      .createOptional()
+  val CUSTOM_QUERY_RESULT_WRITER =
+    FlintConfig("spark.flint.job.customQueryResultWriter")
+      .createOptional()
 }
 
 /**
…
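The three new optional settings appear intended to name pluggable implementations of the traits above. A hypothetical wiring sketch (the com.example class names are invented for illustration):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.flint.job.customSessionManager", "com.example.flint.MySessionManager")
  .set("spark.flint.job.customStatementManager", "com.example.flint.MyStatementManager")
  .set("spark.flint.job.customQueryResultWriter", "com.example.flint.MyQueryResultWriter")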
@@ -13,12 +13,11 @@ import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
 case class CommandContext(
     spark: SparkSession,
     dataSource: String,
-    resultIndex: String,
     sessionId: String,
-    flintSessionIndexUpdater: OpenSearchUpdater,
-    osClient: OSClient,
-    sessionIndex: String,
+    sessionManager: SessionManager,
     jobId: String,
+    statementLifecycleManager: StatementLifecycleManager,
+    queryResultWriter: QueryResultWriter,
     queryExecutionTimeout: Duration,
     inactivityLimitMillis: Long,
     queryWaitTimeMillis: Long,
…
@@ -12,7 +12,6 @@ import org.opensearch.flint.core.storage.FlintReader
 case class CommandState(
     recordedLastActivityTime: Long,
     recordedVerificationResult: VerificationResult,
-    flintReader: FlintReader,
     futureMappingCheck: Future[Either[String, Unit]],
     futurePrepareQueryExecution: Future[Either[String, Unit]],
     executionContext: ExecutionContextExecutor,
     recordedLastCanPickCheckTime: Long)
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types._
  */
 object FlintJob extends Logging with FlintJobExecutor {
   def main(args: Array[String]): Unit = {
-    val (queryOption, resultIndex) = parseArgs(args)
+    val (queryOption, resultIndexOption) = parseArgs(args)
 
     val conf = createSparkConf()
     val jobType = conf.get("spark.flint.job.type", "batch")
@@ -41,6 +41,9 @@ object FlintJob extends Logging with FlintJobExecutor {
     if (query.isEmpty) {
       logAndThrow(s"Query undefined for the ${jobType} job.")
     }
+    if (resultIndexOption.isEmpty) {
+      logAndThrow("resultIndex is not set")
+    }
     // https://github.com/opensearch-project/opensearch-spark/issues/138
     /*
      * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
@@ -58,7 +61,7 @@ object FlintJob extends Logging with FlintJobExecutor {
       createSparkSession(conf),
       query,
       dataSource,
-      resultIndex,
+      resultIndexOption.get,
       jobType.equalsIgnoreCase("streaming"),
       streamingRunningCount)
     registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
…
@@ -493,16 +493,21 @@ trait FlintJobExecutor {
     }
   }
 
-  def parseArgs(args: Array[String]): (Option[String], String) = {
+  /**
+   * Before OS 2.13, the entry point received two arguments: query and result index. Starting
+   * from OS 2.13, the query is optional for FlintREPL, and since Flint 0.5 the result index is
+   * also optional, to support non-OpenSearch result persistence.
+   */
+  def parseArgs(args: Array[String]): (Option[String], Option[String]) = {
     args match {
+      case Array() =>
+        (None, None)
       case Array(resultIndex) =>
-        (None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument
+        (None, Some(resultIndex))
       case Array(query, resultIndex) =>
-        (
-          Some(query),
-          resultIndex
-        ) // Before OS 2.13, there are two arguments, the second one is resultIndex
-      case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.")
+        (Some(query), Some(resultIndex))
+      case _ =>
+        logAndThrow("Unsupported number of arguments. Expected no more than two arguments.")
     }
   }

…
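The net effect of the new parseArgs, shown as illustrative calls:

parseArgs(Array.empty[String])              // (None, None): query and result index supplied elsewhere
parseArgs(Array("result_idx"))              // (None, Some("result_idx")): OS 2.13+ REPL style
parseArgs(Array("SELECT 1", "result_idx"))  // (Some("SELECT 1"), Some("result_idx")): pre-2.13 style
// Three or more arguments fall through to logAndThrow.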
… (diffs for the remaining 5 changed files not shown)
