refactor statement lifecycle
noCharger committed Jun 21, 2024
1 parent dab2343 commit f140511
Showing 18 changed files with 506 additions and 543 deletions.
8 changes: 4 additions & 4 deletions build.sbt
@@ -97,7 +97,7 @@ lazy val flintCommons = (project in file("flint-commons"))
),
libraryDependencies ++= deps(sparkVersion),
publish / skip := true,
assembly / test := (Test / test).value,
assembly / test := {},
assembly / assemblyOption ~= {
_.withIncludeScala(false)
},
@@ -149,7 +149,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
val oldStrategy = (assembly / assemblyMergeStrategy).value
oldStrategy(x)
},
assembly / test := (Test / test).value)
assembly / test := {})

lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
.dependsOn(flintCore, flintCommons)
@@ -193,7 +193,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
val cp = (assembly / fullClasspath).value
cp filter { file => file.data.getName.contains("LogsConnectorSpark")}
},
assembly / test := (Test / test).value)
assembly / test := {})

// Test assembly package with integration test.
lazy val integtest = (project in file("integ-test"))
@@ -269,7 +269,7 @@ lazy val sparkSqlApplication = (project in file("spark-sql-application"))
val oldStrategy = (assembly / assemblyMergeStrategy).value
oldStrategy(x)
},
assembly / test := (Test / test).value
assembly / test := {}
)

lazy val sparkSqlApplicationCosmetic = project
@@ -8,5 +8,9 @@ package org.apache.spark.sql
import org.opensearch.flint.data.FlintStatement

trait QueryResultWriter {
def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
def reformatQueryResult(
dataFrame: DataFrame,
flintStatement: FlintStatement,
queryExecutionContext: StatementExecutionContext): DataFrame
def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
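
A minimal sketch of how an implementation could satisfy the split interface, with result reformatting and persistence now separated; the class name, the enrichment column, and the Parquet destination are illustrative assumptions, not part of this commit:

package org.apache.spark.sql

import org.apache.spark.sql.functions.lit
import org.opensearch.flint.data.FlintStatement

// Hypothetical writer: enrich the result in reformatQueryResult, persist it in persistQueryResult.
class ExampleQueryResultWriter extends QueryResultWriter {

  override def reformatQueryResult(
      dataFrame: DataFrame,
      flintStatement: FlintStatement,
      queryExecutionContext: StatementExecutionContext): DataFrame =
    // Illustrative enrichment: tag every row with the originating query id.
    dataFrame.withColumn("queryId", lit(flintStatement.queryId))

  override def persistQueryResult(dataFrame: DataFrame, flintStatement: FlintStatement): Unit =
    // Illustrative persistence target; a real writer would use the configured result index.
    dataFrame.write.mode("append").parquet(s"/tmp/results/${flintStatement.queryId}")
}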
@@ -9,17 +9,40 @@ import org.opensearch.flint.data.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

/**
* Trait defining the interface for managing interactive sessions.
*/
trait SessionManager {

/**
* Retrieves metadata about the session manager.
*/
def getSessionManagerMetadata: Map[String, Any]

/**
* Fetches the details of a specific session.
*/
def getSessionDetails(sessionId: String): Option[InteractiveSession]

/**
* Updates the details of a specific session.
*/
def updateSessionDetails(
sessionDetails: InteractiveSession,
updateMode: SessionUpdateMode): Unit
def hasPendingStatement(sessionId: String): Boolean

/**
* Retrieves the next statement to be executed in a specific session.
*/
def getNextStatement(sessionId: String): Option[FlintStatement]

/**
* Records a heartbeat for a specific session to indicate it is still active.
*/
def recordHeartbeat(sessionId: String): Unit
}

object SessionUpdateMode extends Enumeration {
type SessionUpdateMode = Value
val Update, Upsert, UpdateIf = Value
val UPDATE, UPSERT, UPDATE_IF = Value
}
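
A minimal sketch of how a SessionManager implementation might branch on the renamed enum values; the object name and the branch comments are illustrative assumptions, and the branch bodies are left empty:

package org.apache.spark.sql

import org.apache.spark.sql.SessionUpdateMode._
import org.opensearch.flint.data.InteractiveSession

// Hypothetical dispatch on SessionUpdateMode; the branch bodies are placeholders.
object ExampleSessionUpdateDispatch {
  def update(sessionDetails: InteractiveSession, updateMode: SessionUpdateMode): Unit =
    updateMode match {
      case UPDATE    => // replace the fields of an existing session document
      case UPSERT    => // create the session document if it does not exist yet
      case UPDATE_IF => // conditional update, e.g. guarded by a sequence-number check
    }
}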
@@ -7,13 +7,14 @@ package org.apache.spark.sql

import scala.concurrent.duration.Duration

case class QueryExecutionContext(
case class StatementExecutionContext(
spark: SparkSession,
jobId: String,
sessionId: String,
sessionManager: SessionManager,
statementLifecycleManager: StatementLifecycleManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
resultIndex: String,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long)
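
A minimal sketch of assembling the renamed context from already-built collaborators; the ids, data source, index name, and timeouts below are placeholder values, not defaults taken from this commit:

package org.apache.spark.sql

import scala.concurrent.duration._

// Hypothetical assembly of a StatementExecutionContext; callers would pass real collaborators.
object ExampleContextBuilder {
  def build(
      spark: SparkSession,
      sessionManager: SessionManager,
      statementLifecycleManager: StatementLifecycleManager,
      queryResultWriter: QueryResultWriter): StatementExecutionContext =
    StatementExecutionContext(
      spark = spark,
      jobId = "job-1",                 // placeholder id
      sessionId = "session-1",         // placeholder id
      sessionManager = sessionManager,
      statementLifecycleManager = statementLifecycleManager,
      queryResultWriter = queryResultWriter,
      dataSource = "my_glue",          // placeholder data source name
      resultIndex = "query_results",   // placeholder result index
      queryExecutionTimeout = 30.minutes,
      inactivityLimitMillis = 10 * 60 * 1000L,
      queryWaitTimeMillis = 60 * 1000L)
}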
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

/**
* Trait defining the interface for managing the lifecycle of statements.
*/
trait StatementLifecycleManager {

/**
* Prepares the statement lifecycle.
*/
def prepareStatementLifecycle(): Either[String, Unit]

/**
* Updates a specific statement.
*/
def updateStatement(statement: FlintStatement): Unit

/**
* Terminates the statement lifecycle.
*/
def terminateStatementLifecycle(): Unit
}
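
A minimal no-op sketch of the new trait that only logs the three lifecycle hooks; the class name and logging behavior are illustrative assumptions:

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

import org.apache.spark.internal.Logging

// Hypothetical no-op implementation used here only to illustrate the lifecycle hooks.
class NoOpStatementLifecycleManager extends StatementLifecycleManager with Logging {

  override def prepareStatementLifecycle(): Either[String, Unit] = {
    logInfo("preparing statement lifecycle")
    Right(())
  }

  override def updateStatement(statement: FlintStatement): Unit =
    logInfo(s"updating statement ${statement.statementId}")

  override def terminateStatementLifecycle(): Unit =
    logInfo("terminating statement lifecycle")
}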

This file was deleted.

@@ -42,6 +42,7 @@ class FlintStatement(
val statementId: String,
val queryId: String,
val submitTime: Long,
var queryStartTime: Option[Long] = Some(-1L),
var error: Option[String] = None,
statementContext: Map[String, Any] = Map.empty[String, Any])
extends ContextualDataStore {
@@ -76,7 +77,7 @@ object FlintStatement {
case _ => None
}

new FlintStatement(state, query, statementId, queryId, submitTime, maybeError)
new FlintStatement(state, query, statementId, queryId, submitTime, error = maybeError)
}

def serialize(flintStatement: FlintStatement): String = {
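Because the new queryStartTime parameter now sits between submitTime and error in the constructor, the deserializer above passes error by name; a sketch of constructing a statement with the new field, assuming state and query are plain strings (all values illustrative):

import org.opensearch.flint.data.FlintStatement

val statement = new FlintStatement(
  "waiting",                     // state (illustrative)
  "SELECT 1",                    // query
  "statement-1",                 // statementId
  "query-1",                     // queryId
  System.currentTimeMillis(),    // submitTime
  queryStartTime = Some(System.currentTimeMillis()),
  error = None)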
@@ -49,21 +49,13 @@ public class FlintOptions implements Serializable {

public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider";

public static final String CUSTOM_SESSION_MANAGER = "customSessionManager";

public static final String CUSTOM_STATEMENT_MANAGER = "customStatementManager";

public static final String CUSTOM_QUERY_RESULT_WRITER = "customQueryResultWriter";

/**
* By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty. Use DefaultAWSCredentialsProviderChain.
*/
public static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = "";

public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex";

public static final String FLINT_SESSION_ID = "spark.flint.job.sessionId";

/**
* Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
*/
@@ -145,18 +137,6 @@ public String getMetadataAccessAwsCredentialsProvider() {
return options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}

public String getCustomSessionManager() {
return options.getOrDefault(CUSTOM_SESSION_MANAGER, "");
}

public String getCustomStatementManager() {
return options.getOrDefault(CUSTOM_STATEMENT_MANAGER, "");
}

public String getCustomQueryResultWriter() {
return options.getOrDefault(CUSTOM_QUERY_RESULT_WRITER, "");
}

public String getUsername() {
return options.getOrDefault(USERNAME, "flint");
}
@@ -177,10 +157,6 @@ public String getSystemIndexName() {
return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, "");
}

public String getSessionId() {
return options.getOrDefault(FLINT_SESSION_ID, null);
}

public int getBatchBytes() {
// we do not expect this value to be larger than 10 MB = 10 * 1024 * 1024
return (int) org.apache.spark.network.util.JavaUtils
@@ -58,7 +58,7 @@ object FlintJob extends Logging with FlintJobExecutor {
createSparkSession(conf),
query,
dataSource,
resultIndex,
resultIndex.get,
jobType.equalsIgnoreCase("streaming"),
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
@@ -89,6 +89,24 @@ trait FlintJobExecutor {
}
}""".stripMargin

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

def createSparkConf(): SparkConf = {
new SparkConf()
.setAppName(getClass.getSimpleName)
@@ -175,7 +193,6 @@
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider,
cleaner: Cleaner): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
@@ -188,29 +205,11 @@
StructField("column_name", StringType, nullable = false),
StructField("data_type", StringType, nullable = false))))

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val resultToSave = result.toJSON.collect.toList
.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))

val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
val endTime = timeProvider.currentEpochMillis()
val endTime = currentTimeProvider.currentEpochMillis()

// https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data
// after the query result is consumed. Streaming query shuffle data is cleaned after each
@@ -245,28 +244,9 @@
queryId: String,
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {

// Define the data schema
val schema = StructType(
Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
StructField("jobRunId", StringType, nullable = true),
StructField("applicationId", StringType, nullable = true),
StructField("dataSourceName", StringType, nullable = true),
StructField("status", StringType, nullable = true),
StructField("error", StringType, nullable = true),
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))

val endTime = timeProvider.currentEpochMillis()
startTime: Long): DataFrame = {

val endTime = currentTimeProvider.currentEpochMillis()

// Create the data rows
val rows = Seq(
@@ -419,7 +399,6 @@
query,
sessionId,
startTime,
currentTimeProvider,
CleanerFactory.cleaner(streaming))
}

@@ -485,16 +464,19 @@
}
}

def parseArgs(args: Array[String]): (Option[String], String) = {
def parseArgs(args: Array[String]): (Option[String], Option[String]) = {
args match {
case Array() =>
(None, None)
case Array(resultIndex) =>
(None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument
(None, Some(resultIndex)) // Starting from OS 2.13, resultIndex is the only argument
case Array(query, resultIndex) =>
(
Some(query),
resultIndex
Some(resultIndex)
) // Before OS 2.13, there are two arguments, the second one is resultIndex
case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.")
case _ =>
logAndThrow("Unsupported number of arguments. Expected no more than two arguments.")
}
}
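
A sketch of the shapes the relaxed parser now accepts, written as assertions against the tuple it returns (to be read inside a FlintJobExecutor implementation; not an actual test in this commit):

// Zero arguments: neither a query nor a result index.
assert(parseArgs(Array.empty[String]) == (None, None))
// One argument (OS 2.13+): only the result index.
assert(parseArgs(Array("my_result_index")) == (None, Some("my_result_index")))
// Two arguments (pre OS 2.13): the query followed by the result index.
assert(parseArgs(Array("SELECT 1", "my_result_index")) ==
  (Some("SELECT 1"), Some("my_result_index")))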
