Refactor REPL mode
noCharger committed Jun 13, 2024
1 parent 80d8f6e commit dab2343
Showing 10 changed files with 398 additions and 275 deletions.
@@ -0,0 +1,12 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

trait QueryResultWriter {
  def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
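
A minimal implementation sketch, for illustration only and not part of this commit: it assumes FlintStatement exposes a queryId field, and the class name and output layout are hypothetical.

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

// Hypothetical example: persist each statement's result as JSON under a per-query directory.
class JsonFileQueryResultWriter(outputDir: String) extends QueryResultWriter {
  override def write(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    // Assumes FlintStatement exposes a queryId field identifying the statement.
    dataFrame.write
      .mode("overwrite")
      .json(s"$outputDir/${flintStatement.queryId}")
  }
}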
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.{FlintStatement, InteractiveSession}

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

trait SessionManager {
  def getSessionManagerMetadata: Map[String, Any]
  def getSessionDetails(sessionId: String): Option[InteractiveSession]
  def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit
  def hasPendingStatement(sessionId: String): Boolean
  def recordHeartbeat(sessionId: String): Unit
}

object SessionUpdateMode extends Enumeration {
  type SessionUpdateMode = Value
  val Update, Upsert, UpdateIf = Value
}
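
For illustration, a hypothetical in-memory SessionManager (not part of this commit): it assumes InteractiveSession exposes a sessionId field, and it treats Update and UpdateIf as overwrites of existing entries while Upsert may also create them.

package org.apache.spark.sql

import scala.collection.concurrent.TrieMap

import org.opensearch.flint.data.InteractiveSession

import org.apache.spark.sql.SessionUpdateMode.SessionUpdateMode

// Sketch suitable for local testing; the real implementation would persist to the request index.
class InMemorySessionManager extends SessionManager {
  private val sessions = TrieMap.empty[String, InteractiveSession]
  private val heartbeats = TrieMap.empty[String, Long]

  override def getSessionManagerMetadata: Map[String, Any] =
    Map("type" -> "in-memory", "sessionCount" -> sessions.size)

  override def getSessionDetails(sessionId: String): Option[InteractiveSession] =
    sessions.get(sessionId)

  override def updateSessionDetails(
      sessionDetails: InteractiveSession,
      updateMode: SessionUpdateMode): Unit = updateMode match {
    // Upsert may create the entry; Update/UpdateIf only overwrite existing ones here.
    case SessionUpdateMode.Upsert =>
      sessions.put(sessionDetails.sessionId, sessionDetails)
    case _ if sessions.contains(sessionDetails.sessionId) =>
      sessions.put(sessionDetails.sessionId, sessionDetails)
    case _ => // no-op when the session does not exist yet
  }

  // A real implementation would check the request index for queued statements.
  override def hasPendingStatement(sessionId: String): Boolean = false

  override def recordHeartbeat(sessionId: String): Unit =
    heartbeats.put(sessionId, System.currentTimeMillis())
}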
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

trait StatementManager {
  def prepareCommandLifecycle(): Either[String, Unit]
  def initCommandLifecycle(sessionId: String): FlintStatement
  def closeCommandLifecycle(): Unit
  def updateCommandDetails(commandDetails: FlintStatement): Unit
}
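
A hypothetical sketch of the order in which a REPL loop might drive these calls (not part of this commit); it assumes FlintStatement exposes a query field holding the SQL text, reuses the QueryResultWriter trait above, and omits error handling.

package org.apache.spark.sql

import org.opensearch.flint.data.FlintStatement

object StatementLoopSketch {
  def runOnce(
      statementManager: StatementManager,
      queryResultWriter: QueryResultWriter,
      spark: SparkSession,
      sessionId: String): Unit = {
    statementManager.prepareCommandLifecycle() match {
      case Left(error) =>
        // Preparation failed (for example, a missing request index); skip this cycle.
        println(s"skipping statement: $error")
      case Right(_) =>
        // Claim the next statement, run it, persist the result, then close out.
        val statement: FlintStatement = statementManager.initCommandLifecycle(sessionId)
        val result = spark.sql(statement.query)
        queryResultWriter.write(result, statement)
        statementManager.updateCommandDetails(statement)
        statementManager.closeCommandLifecycle()
    }
  }
}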
@@ -5,7 +5,7 @@

package org.opensearch.flint.data

import java.util.{Map => JavaMap}
import java.util.{List => JavaList, Map => JavaMap}

import scala.collection.JavaConverters._

@@ -16,9 +16,9 @@ import org.json4s.native.Serialization

object SessionStates {
  val RUNNING = "running"
  val COMPLETE = "complete"
  val FAILED = "failed"
  val WAITING = "waiting"
  val DEAD = "dead"
  val FAIL = "fail"
  val NOT_STARTED = "not_started"
}

/**
@@ -57,9 +57,9 @@ class InteractiveSession(
  context = sessionContext // Initialize the context from the constructor

  def isRunning: Boolean = state == SessionStates.RUNNING
  def isComplete: Boolean = state == SessionStates.COMPLETE
  def isFailed: Boolean = state == SessionStates.FAILED
  def isWaiting: Boolean = state == SessionStates.WAITING
  def isDead: Boolean = state == SessionStates.DEAD
  def isFail: Boolean = state == SessionStates.FAIL
  def isNotStarted: Boolean = state == SessionStates.NOT_STARTED

  override def toString: String = {
    val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]")
@@ -129,10 +129,7 @@ object InteractiveSession {
}

// We safely handle the possibility of excludeJobIds being absent or not a list.
val excludeJobIds: Seq[String] = scalaSource.get("excludeJobIds") match {
case Some(lst: java.util.List[_]) => lst.asScala.toList.map(_.asInstanceOf[String])
case _ => Seq.empty[String]
}
val excludeJobIds: Seq[String] = parseExcludedJobIds(scalaSource.get("excludeJobIds"))

// Handle error similarly, ensuring we get an Option[String].
val maybeError: Option[String] = scalaSource.get("error") match {
@@ -201,4 +198,13 @@ object InteractiveSession {
  def serializeWithoutJobId(job: InteractiveSession, currentTime: Long): String = {
    serialize(job, currentTime, includeJobId = false)
  }
  private def parseExcludedJobIds(source: Option[Any]): Seq[String] = {
    source match {
      case Some(s: String) => Seq(s)
      case Some(list: JavaList[_]) => list.asScala.toList.collect { case str: String => str }
      case None => Seq.empty[String]
      case _ =>
        Seq.empty
    }
  }
}
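
For illustration, a standalone mirror of the new excludeJobIds parsing and the input shapes it accepts (the method above is private to InteractiveSession); this is an example script, not part of the commit.

import java.util.{Arrays => JArrays, List => JavaList}

import scala.collection.JavaConverters._

// Mirrors parseExcludedJobIds above, for illustration only.
def parseExcluded(source: Option[Any]): Seq[String] = source match {
  case Some(s: String) => Seq(s)
  case Some(list: JavaList[_]) => list.asScala.toList.collect { case str: String => str }
  case _ => Seq.empty
}

parseExcluded(Some("job-1"))                          // Seq("job-1")
parseExcluded(Some(JArrays.asList("job-1", "job-2"))) // Seq("job-1", "job-2")
parseExcluded(None)                                   // Seq.empty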
@@ -49,13 +49,21 @@ public class FlintOptions implements Serializable {

  public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider";

  public static final String CUSTOM_SESSION_MANAGER = "customSessionManager";

  public static final String CUSTOM_STATEMENT_MANAGER = "customStatementManager";

  public static final String CUSTOM_QUERY_RESULT_WRITER = "customQueryResultWriter";

  /**
   * By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty, so the DefaultAWSCredentialsProviderChain is used.
   */
  public static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = "";

  public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex";

  public static final String FLINT_SESSION_ID = "spark.flint.job.sessionId";

  /**
   * Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
   */
@@ -137,6 +145,18 @@ public String getMetadataAccessAwsCredentialsProvider() {
    return options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
  }

  public String getCustomSessionManager() {
    return options.getOrDefault(CUSTOM_SESSION_MANAGER, "");
  }

  public String getCustomStatementManager() {
    return options.getOrDefault(CUSTOM_STATEMENT_MANAGER, "");
  }

  public String getCustomQueryResultWriter() {
    return options.getOrDefault(CUSTOM_QUERY_RESULT_WRITER, "");
  }

  public String getUsername() {
    return options.getOrDefault(USERNAME, "flint");
  }
@@ -157,6 +177,10 @@ public String getSystemIndexName() {
    return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, "");
  }

  public String getSessionId() {
    return options.getOrDefault(FLINT_SESSION_ID, null);
  }

  public int getBatchBytes() {
    // We do not expect this value to be larger than 10 MB = 10 * 1024 * 1024.
    return (int) org.apache.spark.network.util.JavaUtils
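
As a usage sketch (not from this commit): the new options carry the fully qualified class names of pluggable components plus the session id. The example below assumes a FlintOptions(Map&lt;String, String&gt;) constructor in org.opensearch.flint.core, and the class names are hypothetical.

import java.util.{HashMap => JHashMap}

import org.opensearch.flint.core.FlintOptions

// Build the options map with the new keys and read them back.
val props = new JHashMap[String, String]()
props.put(FlintOptions.CUSTOM_SESSION_MANAGER, "com.example.MySessionManager")
props.put(FlintOptions.CUSTOM_STATEMENT_MANAGER, "com.example.MyStatementManager")
props.put(FlintOptions.CUSTOM_QUERY_RESULT_WRITER, "com.example.MyQueryResultWriter")
props.put(FlintOptions.FLINT_SESSION_ID, "session-123")

val options = new FlintOptions(props)
options.getCustomSessionManager()  // "com.example.MySessionManager"
options.getSessionId()             // "session-123"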
@@ -203,6 +203,15 @@ object FlintSparkConf {
FlintConfig("spark.metadata.accessAWSCredentialsProvider")
.doc("AWS credentials provider for metadata access permission")
.createOptional()
val CUSTOM_SESSION_MANAGER =
FlintConfig("spark.flint.job.customSessionManager")
.createOptional()
val CUSTOM_STATEMENT_MANAGER =
FlintConfig("spark.flint.job.customStatementManager")
.createOptional()
val CUSTOM_QUERY_RESULT_WRITER =
FlintConfig("spark.flint.job.customQueryResultWriter")
.createOptional()
}

/**
@@ -277,6 +286,9 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
SESSION_ID,
REQUEST_INDEX,
METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER,
CUSTOM_SESSION_MANAGER,
CUSTOM_STATEMENT_MANAGER,
CUSTOM_QUERY_RESULT_WRITER,
EXCLUDE_JOB_IDS)
.map(conf => (conf.optionKey, conf.readFrom(reader)))
.flatMap {
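
A usage sketch (not part of this commit) showing how the new spark.flint.job.* options might be supplied when building the Spark session; the implementation class names are hypothetical.

import org.apache.spark.sql.SparkSession

// Wire the custom implementations through Spark configuration;
// the keys match the FlintSparkConf entries added above.
val spark = SparkSession.builder()
  .appName("flint-repl")
  .config("spark.flint.job.customSessionManager", "com.example.MySessionManager")
  .config("spark.flint.job.customStatementManager", "com.example.MyStatementManager")
  .config("spark.flint.job.customQueryResultWriter", "com.example.MyQueryResultWriter")
  .getOrCreate()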