refactor statement lifecycle
noCharger committed Jun 18, 2024
1 parent dab2343 commit 8b6fa08
Showing 13 changed files with 404 additions and 368 deletions.
@@ -8,5 +8,6 @@ package org.apache.spark.sql
import org.opensearch.flint.data.FlintStatement

trait QueryResultWriter {
def reformat(dataFrame: DataFrame, flintStatement: FlintStatement): DataFrame
def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
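
A minimal sketch of how the split reformat/write contract could be implemented; the class name, the writeToIndex function, and the added queryId column are illustrative assumptions, not code from this commit.

package org.apache.spark.sql

import org.apache.spark.sql.functions.lit
import org.opensearch.flint.data.FlintStatement

// Hypothetical implementation (not part of this commit), assuming some
// index-backed bulk writer is supplied as writeToIndex.
class OpenSearchQueryResultWriter(
    resultIndex: String,
    writeToIndex: (DataFrame, String) => Unit)
    extends QueryResultWriter {

  // Shape the raw query result before it is persisted, e.g. tag rows with the query id.
  override def reformat(dataFrame: DataFrame, flintStatement: FlintStatement): DataFrame =
    dataFrame.withColumn("queryId", lit(flintStatement.queryId))

  // Persist the already-reformatted result for the given statement.
  override def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit =
    writeToIndex(dataFrame, resultIndex)
}
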
@@ -15,11 +15,11 @@ trait SessionManager {
def updateSessionDetails(
sessionDetails: InteractiveSession,
updateMode: SessionUpdateMode): Unit
def hasPendingStatement(sessionId: String): Boolean
def getNextStatement(sessionId: String): Option[FlintStatement]
def recordHeartbeat(sessionId: String): Unit
}

object SessionUpdateMode extends Enumeration {
type SessionUpdateMode = Value
val Update, Upsert, UpdateIf = Value
val UPDATE, UPSERT, UPDATE_IF = Value
}
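
A hedged illustration of the renamed constants in use; sessionManager and sessionDetails stand in for any SessionManager implementation and are not taken from this diff.

// Hypothetical call sites (not from this commit): callers pick the write
// semantics via the renamed upper-snake-case constants.
sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPDATE)
sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPSERT)
sessionManager.updateSessionDetails(sessionDetails, updateMode = SessionUpdateMode.UPDATE_IF)
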
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

trait StatementLifecycleManager {
def initializeStatementLifecycle(): Either[String, Unit]
def updateStatement(statement: FlintStatement): Unit
def setLocalProperty(statement: FlintStatement): Unit
def terminateStatementLifecycle(): Unit
}
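
A hedged sketch of how a processing loop might drive the new lifecycle trait; the driver method, the running()/complete() state setters on FlintStatement, and the error handling are assumptions for illustration, not code from this commit.

// Hypothetical driver (not from this commit).
def runStatement(
    lifecycleManager: StatementLifecycleManager,
    statement: FlintStatement): Unit = {
  lifecycleManager.initializeStatementLifecycle() match {
    case Left(error) =>
      // Initialization failed (e.g. a backing index could not be verified).
      throw new IllegalStateException(error)
    case Right(_) =>
      try {
        // Tag the Spark context (job group, local properties) for this statement.
        lifecycleManager.setLocalProperty(statement)
        statement.running() // assumed FlintStatement state setter
        lifecycleManager.updateStatement(statement)
        // ... execute the query and write the result here ...
        statement.complete() // assumed FlintStatement state setter
        lifecycleManager.updateStatement(statement)
      } finally {
        lifecycleManager.terminateStatementLifecycle()
      }
  }
}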

This file was deleted.

@@ -167,16 +167,7 @@ trait FlintJobExecutor {
* @return
* dataframe with result, schema and emr step id
*/
def getFormattedData(
result: DataFrame,
spark: SparkSession,
dataSource: String,
queryId: String,
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider,
cleaner: Cleaner): DataFrame = {
def getFormattedData(
    result: DataFrame,
    spark: SparkSession,
    dataSource: String,
    queryId: String,
    query: String,
    sessionId: String,
    startTime: Long,
    cleaner: Cleaner): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
Row(field.name, field.dataType.typeName)
@@ -210,7 +201,7 @@ trait FlintJobExecutor {
.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))

val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
val endTime = timeProvider.currentEpochMillis()
val endTime = currentTimeProvider.currentEpochMillis()

// https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data
// after consumed the query result. Streaming query shuffle data is cleaned after each
@@ -348,31 +339,6 @@ trait FlintJobExecutor {
compareJson(inputJson, mappingJson) || compareJson(mappingJson, inputJson)
}

def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = {
try {
val existingSchema = osClient.getIndexMetadata(resultIndex)
if (!isSuperset(existingSchema, resultIndexMapping)) {
Left(s"The mapping of $resultIndex is incorrect.")
} else {
Right(())
}
} catch {
case e: IllegalStateException
if e.getCause != null &&
e.getCause.getMessage.contains("index_not_found_exception") =>
createResultIndex(osClient, resultIndex, resultIndexMapping)
case e: InterruptedException =>
val error = s"Interrupted by the main thread: ${e.getMessage}"
Thread.currentThread().interrupt() // Preserve the interrupt status
logError(error, e)
Left(error)
case e: Exception =>
val error = s"Failed to verify existing mapping: ${e.getMessage}"
logError(error, e)
Left(error)
}
}

def createResultIndex(
osClient: OSClient,
resultIndex: String,
@@ -411,16 +377,7 @@ trait FlintJobExecutor {
spark.sparkContext.setJobGroup(queryId, "Job group for " + queryId, interruptOnCancel = true)
val result: DataFrame = spark.sql(query)
// Get Data
getFormattedData(
result,
spark,
dataSource,
queryId,
query,
sessionId,
startTime,
currentTimeProvider,
CleanerFactory.cleaner(streaming))
getFormattedData(
  result,
  spark,
  dataSource,
  queryId,
  query,
  sessionId,
  startTime,
  CleanerFactory.cleaner(streaming))
}
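
For context, a hedged sketch of the time-source change: with the timeProvider parameter dropped, getFormattedData is assumed to read a trait-level currentTimeProvider; the member declaration and helper below are illustrative, not verbatim from this commit.

// Illustrative assumption (not from this commit): the executor keeps a
// replaceable clock, so getFormattedData no longer needs a timeProvider argument
// and tests can still swap the time source.
var currentTimeProvider: TimeProvider = new RealTimeProvider()

private def nowMillis(): Long = currentTimeProvider.currentEpochMillis()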

private def handleQueryException(
@@ -485,16 +442,16 @@ trait FlintJobExecutor {
}
}

def parseArgs(args: Array[String]): (Option[String], String) = {
def parseArgs(args: Array[String]): (Option[String], Option[String]) = {
args match {
case Array() =>
(None, None)
case Array(resultIndex) =>
(None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument
(None, Some(resultIndex)) // Starting from OS 2.13, resultIndex is the only argument
case Array(query, resultIndex) =>
(
Some(query),
resultIndex
) // Before OS 2.13, there are two arguments, the second one is resultIndex
case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.")
(Some(query), Some(resultIndex)) // Before OS 2.13, there are two arguments, the second one is resultIndex
case _ =>
logAndThrow("Unsupported number of arguments. Expected no more than two arguments.")
}
}
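
A small hedged example of consuming the widened return type, where both values are now optional; the fallback behavior shown here is an assumption about the caller, not part of this diff.

// Hypothetical caller (not from this commit).
val (queryOption, resultIndexOption) = parseArgs(args)

// With no positional arguments, the query is assumed to come from configuration;
// a missing result index is still treated as a fatal setup error here.
val query = queryOption.getOrElse("")
val resultIndex = resultIndexOption.getOrElse(
  logAndThrow("resultIndex is not set"))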

