[POC] Store state and error using QueryMetadataService (#608)
* Store state and error using QueryMetadataService

Signed-off-by: Tomoyuki Morita <[email protected]>

* Address comments

Signed-off-by: Tomoyuki Morita <[email protected]>

---------

Signed-off-by: Tomoyuki Morita <[email protected]>
ykmr1224 authored Aug 30, 2024
1 parent b03bdfd commit 7972895
Showing 10 changed files with 145 additions and 18 deletions.
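In short, this change threads a query ID through FlintJob and FlintREPL into JobOperator, which now reports the query lifecycle (running, then success or failed) together with any error string through a pluggable QueryMetadataService that is loaded reflectively from configuration. A minimal sketch of opting in, assuming the option keys introduced below; the implementation class name is a hypothetical placeholder:

import org.apache.spark.SparkConf

// Sketch: point the Flint job at a custom QueryMetadataService implementation.
// "com.example.MyQueryMetadataService" is a placeholder, not a class in this commit.
val conf = new SparkConf()
  .set("spark.flint.job.customQueryMetadataServiceClass", "com.example.MyQueryMetadataService")
  .set("spark.flint.job.queryId", "my-query-id") // read by FlintJob and FlintREPL below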
@@ -97,11 +97,13 @@ public class FlintOptions implements Serializable {
// New constants
public static final String CUSTOM_SESSION_MANAGER = "customSessionManager";
public static final String CUSTOM_COMMAND_LIFECYCLE_MANAGER = "customCommandLifecycleManager";

public static final String CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS = "customFlintMetadataLogServiceClass";

public static final String CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS = "customFlintIndexMetadataServiceClass";

+// TODO: This is POC specific
+public static final String CUSTOM_QUERY_METADATA_SERVICE = "spark.flint.job.customQueryMetadataServiceClass";

public FlintOptions(Map<String, String> options) {
this.options = options;
this.retryOptions = new FlintRetryOptions(options);
@@ -188,6 +190,11 @@ public String getCustomCommandLifecycleManager() {
return options.getOrDefault(CUSTOM_COMMAND_LIFECYCLE_MANAGER, "org.apache.spark.sql.CommandLifecycleManagerImpl");
}

+// TODO: This is POC specific
+public String getCustomQueryMetadataService() {
+return options.getOrDefault(CUSTOM_QUERY_METADATA_SERVICE, "org.apache.spark.sql.NoOpQueryMetadataService");
+}

public String getRequestMetadata() {
return options.getOrDefault("spark.flint.job.requestIndex", null);
}
@@ -0,0 +1,13 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data

object QueryState {
val WAITING = "waiting"
val RUNNING = "running"
val SUCCESS = "success"
val FAILED = "failed"
}
@@ -208,6 +208,10 @@ object FlintSparkConf {
FlintConfig(s"spark.flint.job.sessionId")
.doc("Flint session id")
.createOptional()
+val QUERY_ID =
+FlintConfig(s"spark.flint.job.queryId")
+.doc("Flint query id")
+.createOptional()
val REQUEST_INDEX =
FlintConfig(s"spark.flint.job.requestIndex")
.doc("Request index")
@@ -238,6 +242,11 @@ object FlintSparkConf {
val CUSTOM_QUERY_RESULT_WRITER =
FlintConfig("spark.flint.job.customQueryResultWriter")
.createOptional()

+// TODO: This is POC specific
+val CUSTOM_QUERY_METADATA_SERVICE =
+FlintConfig(FlintOptions.CUSTOM_QUERY_METADATA_SERVICE)
+.createWithDefault("org.apache.spark.sql.NoOpQueryMetadataService")
}

/**
@@ -307,6 +316,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
REPL_INACTIVITY_TIMEOUT_MILLIS,
CUSTOM_SESSION_MANAGER,
CUSTOM_COMMAND_LIFECYCLE_MANAGER,
+CUSTOM_QUERY_METADATA_SERVICE,
BATCH_BYTES)
.map(conf => (conf.optionKey, conf.readFrom(reader)))
.toMap
@@ -317,6 +327,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS,
CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS,
SESSION_ID,
+QUERY_ID,
REQUEST_INDEX,
METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER,
CUSTOM_QUERY_RESULT_WRITER,
@@ -13,7 +13,7 @@ import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
-import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_SESSION_MANAGER, MONITOR_INITIAL_DELAY_SECONDS, MONITOR_INTERVAL_SECONDS, MONITOR_MAX_ERROR_COUNT, REQUEST_INDEX}
+import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_QUERY_METADATA_SERVICE, CUSTOM_SESSION_MANAGER, MONITOR_INITIAL_DELAY_SECONDS, MONITOR_INTERVAL_SECONDS, MONITOR_MAX_ERROR_COUNT, REQUEST_INDEX}

class FlintSparkConfSuite extends FlintSuite {
test("test spark conf") {
@@ -116,6 +116,19 @@ class FlintSparkConfSuite extends FlintSuite {
}
}

test("test getCustomQueryMetadataService") {
withSparkConf(CUSTOM_QUERY_METADATA_SERVICE.key) {
// default value
val defaultFlintOptions = FlintSparkConf().flintOptions()
assert(
defaultFlintOptions.getCustomQueryMetadataService == "org.apache.spark.sql.NoOpQueryMetadataService")

setFlintSparkConf(CUSTOM_QUERY_METADATA_SERVICE, "custom.query.metadata.ClassName")
val flintOptions = FlintSparkConf().flintOptions()
assert(flintOptions.getCustomQueryMetadataService == "custom.query.metadata.ClassName")
}
}

test("test getRequestMetadata") {
withSparkConf(REQUEST_INDEX.key) {
// default value
@@ -41,6 +41,12 @@ object FlintJob extends Logging with FlintJobExecutor {
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}

+val queryId = conf.get("spark.flint.job.queryId")
+if (queryId.isEmpty) {
+logWarning("Query ID was not specified.")
+}

// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
@@ -59,6 +65,7 @@ object FlintJob extends Logging with FlintJobExecutor {
query,
dataSource,
resultIndex,
+queryId,
jobType.equalsIgnoreCase("streaming"),
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
@@ -89,12 +89,19 @@ object FlintREPL extends Logging with FlintJobExecutor {
logInfo(s"""streaming query ${query}""")
configDYNMaxExecutors(conf, jobType)
val streamingRunningCount = new AtomicInteger(0)

+val queryId = conf.get("spark.flint.job.queryId")
+if (queryId.isEmpty) {
+logWarning("Query ID was not specified.")
+}

val jobOperator =
JobOperator(
createSparkSession(conf),
query,
dataSource,
resultIndex,
+queryId,
true,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
@@ -14,17 +14,19 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
+import org.opensearch.flint.data.QueryState

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.util.ShuffleCleaner
+import org.apache.spark.sql.util.{CustomClassLoader, ShuffleCleaner}
import org.apache.spark.util.ThreadUtils

case class JobOperator(
spark: SparkSession,
query: String,
dataSource: String,
resultIndex: String,
+queryId: String,
streaming: Boolean,
streamingRunningCount: AtomicInteger)
extends Logging
@@ -42,10 +44,14 @@ case class JobOperator(
val startTime = System.currentTimeMillis()
streamingRunningCount.incrementAndGet()

+val flintSparkConf = FlintSparkConf()
+val queryMetadataService = CustomClassLoader(flintSparkConf).getQueryMetadataService()
+queryMetadataService.updateQueryState(queryId, QueryState.RUNNING, null)
// osClient needs the Spark session to be created first so that FlintOptions is initialized.
// Otherwise, we will get a connection exception from EMR-S to OpenSearch.
-val osClient = new OSClient(FlintSparkConf().flintOptions())
+val osClient = new OSClient(flintSparkConf.flintOptions())
var exceptionThrown = true
+var error: String = null
try {
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
Expand All @@ -61,30 +67,31 @@ case class JobOperator(
exceptionThrown = false
} catch {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
error = s"Getting the mapping of index $resultIndex timed out"
logError(error, e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
case e: Exception =>
-val error = processQueryException(e)
+error = processQueryException(e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
} finally {
-cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
+try {
+dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
+} catch {
+case e: Exception =>
+exceptionThrown = true
+error = s"Failed to write to result index. originalError='${error}'"
+logError(error, e)
+}
+val queryState = if (exceptionThrown) QueryState.FAILED else QueryState.SUCCESS
+queryMetadataService.updateQueryState(queryId, queryState, error);

+cleanUpResources(exceptionThrown, threadPool)
}
}

-def cleanUpResources(
-exceptionThrown: Boolean,
-threadPool: ThreadPoolExecutor,
-dataToWrite: Option[DataFrame],
-resultIndex: String,
-osClient: OSClient): Unit = {
-try {
-dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
-} catch {
-case e: Exception => logError("fail to write to result index", e)
-}
+def cleanUpResources(exceptionThrown: Boolean, threadPool: ThreadPoolExecutor): Unit = {

try {
// Wait for streaming job complete if no error and there is streaming job running
@@ -0,0 +1,21 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf

/**
* Temporary default implementation of QueryMetadataService. This should be replaced with an
* implementation that writes query status to an OpenSearch index.
*/
class NoOpQueryMetadataService(flintSparkConf: FlintSparkConf)
extends QueryMetadataService
with Logging {

override def updateQueryState(queryId: String, state: String, error: String): Unit =
logInfo(s"updateQueryState: queryId=${queryId}, state=`${state}`, error=`${error}`")
}
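NoOpQueryMetadataService only logs. As its comment notes, the intended replacement would persist state to an OpenSearch index; the following is a minimal sketch of that shape (not part of this commit), assuming the same package and imports as the class above, with the actual index write left as a hypothetical stub since the commit defines no write API for query metadata:

class OpenSearchQueryMetadataService(flintSparkConf: FlintSparkConf)
    extends QueryMetadataService
    with Logging {

  // CustomClassLoader (below) instantiates implementations reflectively through a
  // single-argument (FlintSparkConf) constructor, so this signature is required.
  override def updateQueryState(queryId: String, state: String, error: String): Unit = {
    // Hypothetical: index a document keyed by queryId so each update overwrites
    // the previous one, e.g. {"state": <state>, "error": <error>}.
    logInfo(s"would index query metadata: queryId=$queryId, state=$state, error=$error")
  }
}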
@@ -0,0 +1,11 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

/** Interface for updating query state and error. */
trait QueryMetadataService {
def updateQueryState(queryId: String, state: String, error: String): Unit
}
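Any implementation of this trait must also expose a single-argument (FlintSparkConf) constructor so that CustomClassLoader (next file) can instantiate it. A minimal in-memory sketch, e.g. for unit tests; the class and field names are hypothetical:

import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.sql.flint.config.FlintSparkConf

class InMemoryQueryMetadataService(flintSparkConf: FlintSparkConf) extends QueryMetadataService {
  // Last reported (state, error) per query ID.
  private val states = new ConcurrentHashMap[String, (String, String)]()

  override def updateQueryState(queryId: String, state: String, error: String): Unit =
    states.put(queryId, (state, error))
}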
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.apache.spark.sql.QueryMetadataService
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.Utils

case class CustomClassLoader(flintSparkConf: FlintSparkConf) {

def getQueryMetadataService(): QueryMetadataService = {
instantiateClass[QueryMetadataService](
flintSparkConf.flintOptions().getCustomQueryMetadataService)
}

private def instantiateClass[T](className: String): T = {
try {
val providerClass = Utils.classForName(className)
val ctor = providerClass.getDeclaredConstructor(classOf[FlintSparkConf])
ctor.setAccessible(true)
ctor.newInstance(flintSparkConf).asInstanceOf[T]
} catch {
case e: Exception =>
throw new RuntimeException(s"Failed to instantiate provider: $className", e)
}
}
}
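Usage mirrors JobOperator above; a minimal sketch, with a placeholder query ID:

val flintSparkConf = FlintSparkConf()
val queryMetadataService = CustomClassLoader(flintSparkConf).getQueryMetadataService()
queryMetadataService.updateQueryState("query-1", QueryState.RUNNING, null)

Loading through a reflective (FlintSparkConf) constructor keeps implementations swappable via configuration alone, at the cost of failing at runtime rather than compile time if the configured class or constructor is missing.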
