Commit 93f2e71

Clean up logs and add UTs

Signed-off-by: Louis Chu <[email protected]>
noCharger committed Aug 23, 2024
1 parent ce46922 commit 93f2e71
Showing 13 changed files with 107 additions and 114 deletions.
@@ -7,6 +7,14 @@ package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

+/**
+ * Trait for writing the result of a query execution to an external data storage.
+ */
trait QueryResultWriter {
+
+  /**
+   * Writes the given DataFrame, which represents the result of a query execution, to an external
+   * data storage based on the provided FlintStatement metadata.
+   */
  def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
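To make the contract above concrete, here is a minimal, hypothetical implementation sketch (not part of this commit). It assumes a class defined alongside the trait in the org.apache.spark.sql package and simply logs the materialized result size instead of persisting it, whereas the real QueryResultWriterImpl further down in this diff writes to OpenSearch.

import org.apache.spark.internal.Logging
import org.opensearch.flint.common.model.FlintStatement

// Illustrative sketch only: a writer that logs the row count rather than storing results.
class LoggingQueryResultWriter extends QueryResultWriter with Logging {
  override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    // count() is an action, so the query result is fully materialized here.
    logInfo(s"Query ${flintStatement.queryId} produced ${dataFrame.count()} rows")
  }
}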
@@ -8,10 +8,12 @@ package org.apache.spark.sql
import org.opensearch.flint.common.model.FlintStatement

/**
- * Trait defining the interface for managing FlintStatements executing in a micro-batch within
- * same session.
+ * Trait defining the interface for managing FlintStatement execution. For example, in FlintREPL,
+ * multiple FlintStatements are running in a micro-batch within same session.
+ *
+ * This interface can also apply to other spark entry point like FlintJob.
 */
-trait StatementsExecutionManager {
+trait StatementExecutionManager {

  /**
   * Prepares execution of each individual statement
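For orientation, a hypothetical usage sketch of the renamed trait follows. It relies only on members visible elsewhere in this diff (prepareStatementExecution, getNextStatement, updateStatement and FlintStatement.running()) and is far simpler than the actual driving loop in FlintREPL.

// Illustrative sketch only: drain the pending statements of one micro-batch.
def runMicroBatch(manager: StatementExecutionManager): Unit = {
  manager.prepareStatementExecution() match {
    case Right(_) =>
      Iterator
        .continually(manager.getNextStatement())
        .takeWhile(_.isDefined)
        .flatten
        .foreach { statement =>
          statement.running()
          manager.updateStatement(statement)
          // ... execute the query and hand the result to a QueryResultWriter here ...
        }
    case Left(error) =>
      throw new IllegalStateException(error)
  }
}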
@@ -65,7 +65,7 @@ class FlintStatement(

  // Does not include context, which could contain sensitive information.
  override def toString: String =
-    s"FlintStatement(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)"
+    s"FlintStatement(state=$state, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)"
}

object FlintStatement {
@@ -27,15 +27,11 @@
import org.opensearch.flint.core.RestHighLevelClientWrapper;
import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor;
import org.opensearch.flint.core.http.RetryableHttpAsyncClient;
-import java.util.logging.Logger;
-

/**
 * Utility functions to create {@link IRestHighLevelClient}.
 */
public class OpenSearchClientUtils {
-  private static final Logger LOG = Logger.getLogger(OpenSearchClientUtils.class.getName());
-

  /**
   * Metadata log index name prefix
@@ -17,14 +17,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
-import java.util.logging.Logger;

/**
 * Abstract OpenSearch Reader.
 */
public abstract class OpenSearchReader implements FlintReader {
-  private static final Logger LOG = Logger.getLogger(OpenSearchReader.class.getName());
-
  @VisibleForTesting
  /** Search request source builder. */
  public final SearchRequest searchRequest;
@@ -50,7 +47,6 @@ public OpenSearchReader(IRestHighLevelClient client, SearchRequest searchRequest
        return false;
      }
      List<SearchHit> searchHits = Arrays.asList(response.get().getHits().getHits());
-      LOG.info("Result sets: " + searchHits.size());
      iterator = searchHits.iterator();
    }
    return iterator.hasNext();
@@ -16,13 +16,12 @@ import org.opensearch.common.xcontent.XContentType
import org.opensearch.testcontainers.OpenSearchContainer
import org.scalatest.{BeforeAndAfterAll, Suite}

-import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT, IGNORE_DOC_ID_COLUMN, REFRESH_POLICY}

/**
 * Test required OpenSearch domain should extend OpenSearchSuite.
 */
-trait OpenSearchSuite extends BeforeAndAfterAll with Logging {
+trait OpenSearchSuite extends BeforeAndAfterAll {
  self: Suite =>

  protected lazy val container = new OpenSearchContainer()
@@ -146,7 +145,6 @@ trait OpenSearchSuite extends BeforeAndAfterAll with Logging {

    val response =
      openSearchClient.bulk(request, RequestOptions.DEFAULT)
-    logInfo(response.toString)
    assume(
      !response.hasFailures,
      s"bulk index docs to $index failed: ${response.buildFailureMessage()}")
@@ -16,7 +16,7 @@ case class CommandContext(
    val sessionId: String,
    val sessionManager: SessionManager,
    val jobId: String,
-    var statementsExecutionManager: StatementsExecutionManager,
+    var statementsExecutionManager: StatementExecutionManager,
    val queryResultWriter: QueryResultWriter,
    val queryExecutionTimeout: Duration,
    val inactivityLimitMillis: Long,
@@ -64,10 +64,10 @@ object FlintREPL extends Logging with FlintJobExecutor {

    // init SparkContext
    val conf: SparkConf = createSparkConf()
-    val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown")
+    val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "")

-    if (dataSource == "unknown") {
-      logInfo(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set")
+    if (dataSource.trim.isEmpty) {
+      logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set or is empty")
    }
    // https://github.com/opensearch-project/opensearch-spark/issues/138
    /*
@@ -323,7 +323,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
        .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) {
      logInfo(s"""Executing session with sessionId: ${sessionId}""")
      val statementsExecutionManager =
-        instantiateStatementsExecutionManager(
+        instantiateStatementExecutionManager(
          spark,
          sessionId,
          dataSource,
@@ -514,7 +514,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
    statementsExecutionManager.getNextStatement() match {
      case Some(flintStatement) =>
        flintStatement.running()
-        logDebug(s"command running: $flintStatement")
        statementsExecutionManager.updateStatement(flintStatement)
        statementRunningCount.incrementAndGet()

@@ -606,7 +605,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
  def executeAndHandle(
      spark: SparkSession,
      flintStatement: FlintStatement,
-      statementsExecutionManager: StatementsExecutionManager,
+      statementsExecutionManager: StatementExecutionManager,
      dataSource: String,
      sessionId: String,
      executionContext: ExecutionContextExecutor,
@@ -618,7 +617,7 @@
      executeQueryAsync(
        spark,
        flintStatement,
-        statementsExecutionManager: StatementsExecutionManager,
+        statementsExecutionManager: StatementExecutionManager,
        dataSource,
        sessionId,
        executionContext,
@@ -734,7 +733,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
  def executeQueryAsync(
      spark: SparkSession,
      flintStatement: FlintStatement,
-      statementsExecutionManager: StatementsExecutionManager,
+      statementsExecutionManager: StatementExecutionManager,
      dataSource: String,
      sessionId: String,
      executionContext: ExecutionContextExecutor,
@@ -919,12 +918,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
    result.getOrElse(throw new RuntimeException("Failed after retries"))
  }

-  private def getSessionId(conf: SparkConf): String = {
-    val sessionIdOption: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null))
-    if (sessionIdOption.isEmpty) {
-      logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set")
+  def getSessionId(conf: SparkConf): String = {
+    conf.getOption(FlintSparkConf.SESSION_ID.key) match {
+      case Some(sessionId) if sessionId.nonEmpty =>
+        sessionId
+      case _ =>
+        logAndThrow(s"${FlintSparkConf.SESSION_ID.key} is not set or is empty")
    }
-    sessionIdOption.get
  }

  private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
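Because getSessionId loses its private modifier here, it can now be unit-tested directly. The sketch below is illustrative only (not necessarily the UTs added by this commit); it uses plain ScalaTest and assumes logAndThrow surfaces an exception.

import org.apache.spark.SparkConf
import org.apache.spark.sql.FlintREPL
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

class GetSessionIdSuite extends AnyWordSpec with Matchers {

  "getSessionId" should {
    "return the configured session id" in {
      val conf = new SparkConf().set(FlintSparkConf.SESSION_ID.key, "session-1")
      FlintREPL.getSessionId(conf) shouldBe "session-1"
    }

    "fail when the session id is missing or empty" in {
      // The concrete exception type thrown by logAndThrow is an assumption here.
      an[Exception] should be thrownBy FlintREPL.getSessionId(new SparkConf())
      an[Exception] should be thrownBy FlintREPL.getSessionId(
        new SparkConf().set(FlintSparkConf.SESSION_ID.key, ""))
    }
  }
}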
@@ -956,13 +956,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
      spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""))
  }

-  private def instantiateStatementsExecutionManager(
+  private def instantiateStatementExecutionManager(
      spark: SparkSession,
      sessionId: String,
      dataSource: String,
-      context: Map[String, Any]): StatementsExecutionManager = {
+      context: Map[String, Any]): StatementExecutionManager = {
    instantiate(
-      new StatementsExecutionManagerImpl(spark, sessionId, dataSource, context),
+      new StatementExecutionManagerImpl(spark, sessionId, dataSource, context),
      spark.sparkContext.getConf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
      spark,
      sessionId)
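The instantiate helper used above falls back to the built-in StatementExecutionManagerImpl unless FlintSparkConf.CUSTOM_STATEMENT_MANAGER names a custom class, which is then constructed reflectively with the trailing arguments (spark, sessionId). Its body is not shown in this diff; the following is only a sketch of that default-or-configured pattern, not the actual implementation.

// Illustrative sketch of a reflective default-or-configured factory.
def instantiateSketch[T](defaultConstructor: => T, className: String, args: Any*): T =
  if (className.isEmpty) {
    defaultConstructor
  } else {
    val constructor = Class.forName(className).getConstructors.head
    constructor.newInstance(args.map(_.asInstanceOf[AnyRef]): _*).asInstanceOf[T]
  }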
@@ -121,7 +121,6 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
      case Success(response) =>
        IRestHighLevelClient.recordOperationSuccess(
          MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX)
-        logInfo(response.toString)
        response
      case Failure(e: Exception) =>
        IRestHighLevelClient.recordOperationFailure(
@@ -12,8 +12,8 @@ import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch

class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging {

-  val resultIndex = context("resultIndex").asInstanceOf[String]
-  val osClient = context("osClient").asInstanceOf[OSClient]
+  private val resultIndex = context("resultIndex").asInstanceOf[String]
+  private val osClient = context("osClient").asInstanceOf[OSClient]

  override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    writeDataFrameToOpensearch(dataFrame, resultIndex, osClient)
@@ -22,19 +22,19 @@ class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String])
    with FlintJobExecutor
    with Logging {

-  // we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found.
-  val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")
+  if (resultIndexOption.isEmpty) {
+    logAndThrow("resultIndex is not set")
+  }
+
+  // we don't allow default value for sessionIndex. Throw exception if key not found.
+  private val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")

  if (sessionIndex.isEmpty) {
    logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set")
  }

-  if (resultIndexOption.isEmpty) {
-    logAndThrow("resultIndex is not set")
-  }
-
-  val osClient = new OSClient(FlintSparkConf().flintOptions())
-  val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex)
+  private val osClient = new OSClient(FlintSparkConf().flintOptions())
+  private val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex)

  override def getSessionContext: Map[String, Any] = {
    Map(
@@ -10,50 +10,29 @@ import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintJob.{createResultIndex, isSuperset, resultIndexMapping}
+import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createResultIndex, isSuperset, resultIndexMapping}
import org.apache.spark.sql.FlintREPL.executeQuery

-class StatementsExecutionManagerImpl(
+class StatementExecutionManagerImpl(
    spark: SparkSession,
    sessionId: String,
    dataSource: String,
    context: Map[String, Any])
-    extends StatementsExecutionManager
+    extends StatementExecutionManager
    with Logging {

-  val sessionIndex = context("sessionIndex").asInstanceOf[String]
-  val resultIndex = context("resultIndex").asInstanceOf[String]
-  val osClient = context("osClient").asInstanceOf[OSClient]
-  val flintSessionIndexUpdater =
+  private val sessionIndex = context("sessionIndex").asInstanceOf[String]
+  private val resultIndex = context("resultIndex").asInstanceOf[String]
+  private val osClient = context("osClient").asInstanceOf[OSClient]
+  private val flintSessionIndexUpdater =
    context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater]

  // Using one reader client within same session will cause concurrency issue.
  // To resolve this move the reader creation and getNextStatement method to mirco-batch level
-  val flintReader = createOpenSearchQueryReader()
+  private val flintReader = createOpenSearchQueryReader()

  override def prepareStatementExecution(): Either[String, Unit] = {
-    try {
-      val existingSchema = osClient.getIndexMetadata(resultIndex)
-      if (!isSuperset(existingSchema, resultIndexMapping)) {
-        Left(s"The mapping of $resultIndex is incorrect.")
-      } else {
-        Right(())
-      }
-    } catch {
-      case e: IllegalStateException
-          if e.getCause != null &&
-            e.getCause.getMessage.contains("index_not_found_exception") =>
-        createResultIndex(osClient, resultIndex, resultIndexMapping)
-      case e: InterruptedException =>
-        val error = s"Interrupted by the main thread: ${e.getMessage}"
-        Thread.currentThread().interrupt() // Preserve the interrupt status
-        logError(error, e)
-        Left(error)
-      case e: Exception =>
-        val error = s"Failed to verify existing mapping: ${e.getMessage}"
-        logError(error, e)
-        Left(error)
-    }
+    checkAndCreateIndex(osClient, resultIndex)
  }
  override def updateStatement(statement: FlintStatement): Unit = {
    flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement))
@@ -65,9 +44,8 @@ class StatementsExecutionManagerImpl(
  override def getNextStatement(): Option[FlintStatement] = {
    if (flintReader.hasNext) {
      val rawStatement = flintReader.next()
-      logInfo(s"raw statement: $rawStatement")
      val flintStatement = FlintStatement.deserialize(rawStatement)
-      logInfo(s"statement: $flintStatement")
+      logInfo(s"Next statement to execute: $flintStatement")
      Some(flintStatement)
    } else {
      None
@@ -114,7 +92,6 @@ class StatementsExecutionManagerImpl(
      |  ]
      | }
      |}""".stripMargin
-
    val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC)
    flintReader
  }