refactor statement lifecycle
noCharger committed Jun 21, 2024
1 parent dab2343 commit 06b609d
Showing 17 changed files with 502 additions and 539 deletions.
@@ -8,5 +8,9 @@ package org.apache.spark.sql
import org.opensearch.flint.data.FlintStatement

trait QueryResultWriter {
def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
def reformatQueryResult(
dataFrame: DataFrame,
flintStatement: FlintStatement,
queryExecutionContext: StatementExecutionContext): DataFrame
def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
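Under the new contract, formatting and persistence are separate steps. The sketch below shows how an implementation could satisfy the reshaped trait; the class name, the added columns, and the opensearch-hadoop connector format string are illustrative assumptions, not part of this commit.

```scala
import org.apache.spark.sql.{DataFrame, QueryResultWriter, SaveMode, StatementExecutionContext}
import org.apache.spark.sql.functions.lit
import org.opensearch.flint.data.FlintStatement

// Hypothetical implementation of the reshaped trait.
class ExampleQueryResultWriter(resultIndex: String) extends QueryResultWriter {

  // Step 1: shape the raw result and attach statement/session metadata.
  override def reformatQueryResult(
      dataFrame: DataFrame,
      flintStatement: FlintStatement,
      queryExecutionContext: StatementExecutionContext): DataFrame =
    dataFrame
      .withColumn("queryId", lit(flintStatement.queryId))
      .withColumn("sessionId", lit(queryExecutionContext.sessionId))

  // Step 2: persist the already-formatted result. The connector format string is an
  // assumption (opensearch-hadoop); the real writer may persist differently.
  override def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit =
    dataFrame.write
      .format("org.opensearch.spark.sql")
      .mode(SaveMode.Append)
      .save(resultIndex)
}
```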
@@ -9,17 +9,40 @@ import org.opensearch.flint.data.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

/**
* Trait defining the interface for managing interactive sessions.
*/
trait SessionManager {

/**
* Retrieves metadata about the session manager.
*/
def getSessionManagerMetadata: Map[String, Any]

/**
* Fetches the details of a specific session.
*/
def getSessionDetails(sessionId: String): Option[InteractiveSession]

/**
* Updates the details of a specific session.
*/
def updateSessionDetails(
sessionDetails: InteractiveSession,
updateMode: SessionUpdateMode): Unit
def hasPendingStatement(sessionId: String): Boolean

/**
* Retrieves the next statement to be executed in a specific session.
*/
def getNextStatement(sessionId: String): Option[FlintStatement]

/**
* Records a heartbeat for a specific session to indicate it is still active.
*/
def recordHeartbeat(sessionId: String): Unit
}

object SessionUpdateMode extends Enumeration {
type SessionUpdateMode = Value
val Update, Upsert, UpdateIf = Value
val UPDATE, UPSERT, UPDATE_IF = Value
}
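For orientation, here is a rough sketch of how a polling loop might drive this interface; the `pollOnce` helper and the `execute` callback are hypothetical, and UPSERT appears only to show the renamed enumeration values in use.

```scala
import org.apache.spark.sql.{SessionManager, SessionUpdateMode}
import org.opensearch.flint.data.FlintStatement

// Hypothetical driver-loop helper; `execute` stands in for whatever actually runs a statement.
def pollOnce(
    sessionManager: SessionManager,
    sessionId: String,
    execute: FlintStatement => Unit): Unit = {
  // Keep the session marked as alive.
  sessionManager.recordHeartbeat(sessionId)

  sessionManager.getSessionDetails(sessionId).foreach { session =>
    // Pick up the next queued statement, if any.
    sessionManager.getNextStatement(sessionId).foreach(execute)
    // Write back session state; UPSERT creates the record when it does not exist yet.
    sessionManager.updateSessionDetails(session, SessionUpdateMode.UPSERT)
  }
}
```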
@@ -7,13 +7,14 @@ package org.apache.spark.sql

import scala.concurrent.duration.Duration

case class QueryExecutionContext(
case class StatementExecutionContext(
spark: SparkSession,
jobId: String,
sessionId: String,
sessionManager: SessionManager,
statementLifecycleManager: StatementLifecycleManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
resultIndex: String,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long)
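To make the renamed context concrete, a hypothetical wiring function is sketched below; every literal value is a placeholder and the manager/writer instances are assumed to be supplied by the caller.

```scala
import scala.concurrent.duration._

import org.apache.spark.sql.{QueryResultWriter, SessionManager, SparkSession,
  StatementExecutionContext, StatementLifecycleManager}

// Hypothetical wiring; all literal values are placeholders.
def buildContext(
    spark: SparkSession,
    sessionManager: SessionManager,
    lifecycleManager: StatementLifecycleManager,
    resultWriter: QueryResultWriter): StatementExecutionContext =
  StatementExecutionContext(
    spark = spark,
    jobId = "example-job-id",
    sessionId = "example-session-id",
    sessionManager = sessionManager,
    statementLifecycleManager = lifecycleManager,
    queryResultWriter = resultWriter,
    dataSource = "example_datasource",
    resultIndex = "query_execution_result",
    queryExecutionTimeout = 30.minutes,
    inactivityLimitMillis = 10 * 60 * 1000L,
    queryWaitTimeMillis = 60 * 1000L)
```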
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

/**
* Trait defining the interface for managing the lifecycle of statements.
*/
trait StatementLifecycleManager {

/**
* Prepares the statement lifecycle.
*/
def prepareStatementLifecycle(): Either[String, Unit]

/**
* Updates a specific statement.
*/
def updateStatement(statement: FlintStatement): Unit

/**
* Terminates the statement lifecycle.
*/
def terminateStatementLifecycle(): Unit
}
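A minimal, purely illustrative implementation of the new trait is shown below; the class name and the logging behavior are assumptions, and a real implementation would presumably read from and write to the statement store (for example, the request index).

```scala
import org.apache.spark.sql.StatementLifecycleManager
import org.opensearch.flint.data.FlintStatement

// Illustrative no-op/logging implementation of the lifecycle hooks.
class LoggingStatementLifecycleManager extends StatementLifecycleManager {

  // Perform any setup before statements are processed; Left(message) signals failure.
  override def prepareStatementLifecycle(): Either[String, Unit] = Right(())

  // Propagate the statement's latest state back to wherever statements are tracked.
  override def updateStatement(statement: FlintStatement): Unit =
    println(s"updating statement ${statement.statementId}")

  // Release resources once the session stops processing statements.
  override def terminateStatementLifecycle(): Unit = ()
}
```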

This file was deleted.

@@ -42,6 +42,7 @@ class FlintStatement(
val statementId: String,
val queryId: String,
val submitTime: Long,
var queryStartTime: Option[Long] = Some(-1L),
var error: Option[String] = None,
statementContext: Map[String, Any] = Map.empty[String, Any])
extends ContextualDataStore {
@@ -76,7 +77,7 @@ object FlintStatement {
case _ => None
}

new FlintStatement(state, query, statementId, queryId, submitTime, maybeError)
new FlintStatement(state, query, statementId, queryId, submitTime, error = maybeError)
}

def serialize(flintStatement: FlintStatement): String = {
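The new queryStartTime constructor field is a mutable Option, so callers can stamp it once execution actually begins. A hedged sketch of that usage follows; the helper name is hypothetical.

```scala
import org.opensearch.flint.data.FlintStatement

// Hypothetical helper: record when the query actually starts running.
def markQueryStarted(statement: FlintStatement): Unit = {
  statement.queryStartTime = Some(System.currentTimeMillis())
}
```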
@@ -49,21 +49,13 @@ public class FlintOptions implements Serializable {

public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider";

public static final String CUSTOM_SESSION_MANAGER = "customSessionManager";

public static final String CUSTOM_STATEMENT_MANAGER = "customStatementManager";

public static final String CUSTOM_QUERY_RESULT_WRITER = "customQueryResultWriter";

/**
* By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty. use DefaultAWSCredentialsProviderChain.
*/
public static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = "";

public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex";

public static final String FLINT_SESSION_ID = "spark.flint.job.sessionId";

/**
* Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
*/
@@ -145,18 +137,6 @@ public String getMetadataAccessAwsCredentialsProvider() {
return options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}

public String getCustomSessionManager() {
return options.getOrDefault(CUSTOM_SESSION_MANAGER, "");
}

public String getCustomStatementManager() {
return options.getOrDefault(CUSTOM_STATEMENT_MANAGER, "");
}

public String getCustomQueryResultWriter() {
return options.getOrDefault(CUSTOM_QUERY_RESULT_WRITER, "");
}

public String getUsername() {
return options.getOrDefault(USERNAME, "flint");
}
@@ -177,10 +157,6 @@ public String getSystemIndexName() {
return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, "");
}

public String getSessionId() {
return options.getOrDefault(FLINT_SESSION_ID, null);
}

public int getBatchBytes() {
// we did not expect this value could be large than 10mb = 10 * 1024 * 1024
return (int) org.apache.spark.network.util.JavaUtils
@@ -58,7 +58,7 @@ object FlintJob extends Logging with FlintJobExecutor {
createSparkSession(conf),
query,
dataSource,
resultIndex,
resultIndex.get,
jobType.equalsIgnoreCase("streaming"),
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
@@ -89,6 +89,24 @@ trait FlintJobExecutor {
}
}""".stripMargin

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

def createSparkConf(): SparkConf = {
new SparkConf()
.setAppName(getClass.getSimpleName)
@@ -175,7 +193,6 @@
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider,
cleaner: Cleaner): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
@@ -188,29 +205,11 @@
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false))))

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val resultToSave = result.toJSON.collect.toList
.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))

val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
val endTime = timeProvider.currentEpochMillis()
val endTime = currentTimeProvider.currentEpochMillis()

// https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data
// after the query result has been consumed. Streaming query shuffle data is cleaned after each
@@ -245,28 +244,9 @@
queryId: String,
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val endTime = timeProvider.currentEpochMillis()
startTime: Long): DataFrame = {

val endTime = currentTimeProvider.currentEpochMillis()

// Create the data rows
val rows = Seq(
@@ -419,7 +399,6 @@
query,
sessionId,
startTime,
currentTimeProvider,
CleanerFactory.cleaner(streaming))
}

@@ -485,16 +464,19 @@
}
}

def parseArgs(args: Array[String]): (Option[String], String) = {
def parseArgs(args: Array[String]): (Option[String], Option[String]) = {
args match {
case Array() =>
(None, None)
case Array(resultIndex) =>
(None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument
(None, Some(resultIndex)) // Starting from OS 2.13, resultIndex is the only argument
case Array(query, resultIndex) =>
(
Some(query),
resultIndex
Some(resultIndex)
) // Before OS 2.13, there were two arguments; the second one is resultIndex
case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.")
case _ =>
logAndThrow("Unsupported number of arguments. Expected no more than two arguments.")
}
}
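Since parseArgs now returns the result index as an Option as well, call sites (such as the FlintJob entry point, which uses resultIndex.get above) must decide how to handle its absence. A small illustrative helper follows; the helper name and the error message are assumptions.

```scala
// Illustrative only: unpack the (Option[String], Option[String]) pair and fail fast
// when a result index is required but was not supplied.
def resolveResultIndex(parsed: (Option[String], Option[String])): String =
  parsed match {
    case (_, Some(resultIndex)) => resultIndex
    case (_, None) =>
      throw new IllegalArgumentException("resultIndex is required but was not provided")
  }
```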
