Skip to content

Commit

Permalink
refactor statement lifecycle
Browse files Browse the repository at this point in the history
  • Loading branch information
noCharger committed Jun 20, 2024
1 parent dab2343 commit 3cc87db
Show file tree
Hide file tree
Showing 13 changed files with 416 additions and 358 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ package org.apache.spark.sql
import org.opensearch.flint.data.FlintStatement

/**
 * Abstraction for post-processing and persisting the result of a query run
 * for a [[FlintStatement]].
 */
trait QueryResultWriter {
// Reshape the raw query result into the form expected by the result store.
// NOTE(review): exact transformation is implementation-defined — confirm with implementors.
def reformat(dataFrame: DataFrame, flintStatement: FlintStatement): DataFrame
// Persist the (already reformatted) result; side-effecting, returns nothing.
def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ trait SessionManager {
/**
 * Persists the given interactive-session state using the supplied
 * [[SessionUpdateMode]] (plain update, upsert, or conditional update).
 */
def updateSessionDetails(
sessionDetails: InteractiveSession,
updateMode: SessionUpdateMode): Unit
// Whether the session identified by sessionId has a statement pending execution.
def hasPendingStatement(sessionId: String): Boolean
// Fetches the next statement to run for the session, or None when nothing is pending.
def getNextStatement(sessionId: String): Option[FlintStatement]
// Records a liveness heartbeat for the session; side-effecting, returns nothing.
def recordHeartbeat(sessionId: String): Unit
}

/**
 * Modes for persisting session state via `SessionManager.updateSessionDetails`:
 * a plain update, an upsert, or a conditional (update-if) write.
 */
object SessionUpdateMode extends Enumeration {
  type SessionUpdateMode = Value
  // NOTE(review): scala.Enumeration is discouraged in new code (a sealed ADT or
  // Scala 3 enum is type-safer), but the existing interface is preserved here.
  // The stale pre-rename value set (Update, Upsert, UpdateIf) — diff residue
  // duplicated alongside the renamed values — has been removed.
  val UPDATE, UPSERT, UPDATE_IF = Value
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

/**
 * Manages the lifecycle of a [[FlintStatement]]: one-time preparation,
 * per-statement state updates, and final teardown.
 */
trait StatementLifecycleManager {
// Prepares whatever the lifecycle needs before statements run;
// Left(error message) on failure, Right(()) on success.
def prepareStatementLifecycle(): Either[String, Unit]
// Persists the current state of the given statement.
def updateStatement(statement: FlintStatement): Unit
// Releases lifecycle resources; side-effecting, returns nothing.
def terminateStatementLifecycle(): Unit
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ trait FlintJobExecutor {
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider,
cleaner: Cleaner): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
Expand Down Expand Up @@ -210,7 +209,7 @@ trait FlintJobExecutor {
.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'"))

val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
val endTime = timeProvider.currentEpochMillis()
val endTime = currentTimeProvider.currentEpochMillis()

// https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data
// after the query result has been consumed. Streaming query shuffle data is cleaned after each
Expand Down Expand Up @@ -348,31 +347,6 @@ trait FlintJobExecutor {
compareJson(inputJson, mappingJson) || compareJson(mappingJson, inputJson)
}

/**
 * Verifies that `resultIndex` exists with a mapping compatible with the
 * expected result-index mapping, creating the index when it does not exist.
 *
 * @return Right(()) when the index is usable, Left(error message) otherwise.
 */
def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = {
  try {
    val currentMapping = osClient.getIndexMetadata(resultIndex)
    Either.cond(
      isSuperset(currentMapping, resultIndexMapping),
      (),
      s"The mapping of $resultIndex is incorrect.")
  } catch {
    // The index does not exist yet — create it with the expected mapping.
    case missing: IllegalStateException
        if missing.getCause != null &&
          missing.getCause.getMessage.contains("index_not_found_exception") =>
      createResultIndex(osClient, resultIndex, resultIndexMapping)
    case interrupted: InterruptedException =>
      val error = s"Interrupted by the main thread: ${interrupted.getMessage}"
      Thread.currentThread().interrupt() // Preserve the interrupt status
      logError(error, interrupted)
      Left(error)
    case other: Exception =>
      val error = s"Failed to verify existing mapping: ${other.getMessage}"
      logError(error, other)
      Left(error)
  }
}

def createResultIndex(
osClient: OSClient,
resultIndex: String,
Expand Down Expand Up @@ -419,7 +393,6 @@ trait FlintJobExecutor {
query,
sessionId,
startTime,
currentTimeProvider,
CleanerFactory.cleaner(streaming))
}

Expand Down Expand Up @@ -485,16 +458,19 @@ trait FlintJobExecutor {
}
}

/**
 * Parses job command-line arguments into (optional query, optional result index).
 *
 * Accepted shapes: no arguments; a single result index (OpenSearch 2.13+);
 * or a query followed by a result index (pre-2.13). Anything else is fatal.
 *
 * Note: the diff residue that merged the pre-change and post-change lines of
 * this method (two signatures, duplicated tuple elements and default cases)
 * has been resolved in favor of the post-commit version.
 *
 * @throws whatever `logAndThrow` raises for an unsupported argument count.
 */
def parseArgs(args: Array[String]): (Option[String], Option[String]) = {
  args match {
    case Array() =>
      (None, None)
    case Array(resultIndex) =>
      (None, Some(resultIndex)) // Starting from OS 2.13, resultIndex is the only argument
    case Array(query, resultIndex) =>
      (
        Some(query),
        Some(resultIndex)
      ) // Before OS 2.13, there are two arguments, the second one is resultIndex
    case _ =>
      logAndThrow("Unsupported number of arguments. Expected no more than two arguments.")
  }
}

Expand Down
Loading

0 comments on commit 3cc87db

Please sign in to comment.