Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] Store state and error using QueryMetadataService #608

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,13 @@ public class FlintOptions implements Serializable {
// New constants
public static final String CUSTOM_SESSION_MANAGER = "customSessionManager";
public static final String CUSTOM_COMMAND_LIFECYCLE_MANAGER = "customCommandLifecycleManager";

public static final String CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS = "customFlintMetadataLogServiceClass";

public static final String CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS = "customFlintIndexMetadataServiceClass";

// TODO: This is POC specific
public static final String CUSTOM_QUERY_METADATA_SERVICE = "spark.flint.job.customQueryMetadataServiceClass";

public FlintOptions(Map<String, String> options) {
this.options = options;
this.retryOptions = new FlintRetryOptions(options);
Expand Down Expand Up @@ -188,6 +190,11 @@ public String getCustomCommandLifecycleManager() {
return options.getOrDefault(CUSTOM_COMMAND_LIFECYCLE_MANAGER, "org.apache.spark.sql.CommandLifecycleManagerImpl");
}

// TODO: This is POC specific
public String getCustomQueryMetadataService() {
return options.getOrDefault(CUSTOM_QUERY_METADATA_SERVICE, "org.apache.spark.sql.NoOpQueryMetadataService");
}

public String getRequestMetadata() {
return options.getOrDefault("spark.flint.job.requestIndex", null);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.data

object QueryState {
val WAITING = "waiting"
val RUNNING = "running"
val SUCCESS = "success"
val FAILED = "failed"
}
Comment on lines +8 to +13
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That cannot be referred from outside. I think we want to migrate to this (or with some other name) constants. (I wanted enum, but looked like Scala2 does not have simple way to implement enum)

Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ object FlintSparkConf {
FlintConfig(s"spark.flint.job.sessionId")
.doc("Flint session id")
.createOptional()
val QUERY_ID =
FlintConfig(s"spark.flint.job.queryId")
.doc("Flint query id")
.createOptional()
val REQUEST_INDEX =
FlintConfig(s"spark.flint.job.requestIndex")
.doc("Request index")
Expand Down Expand Up @@ -238,6 +242,11 @@ object FlintSparkConf {
val CUSTOM_QUERY_RESULT_WRITER =
FlintConfig("spark.flint.job.customQueryResultWriter")
.createOptional()

// TODO: This is POC specific
val CUSTOM_QUERY_METADATA_SERVICE =
FlintConfig(FlintOptions.CUSTOM_QUERY_METADATA_SERVICE)
.createWithDefault("org.apache.spark.sql.NoOpQueryMetadataService")
}

/**
Expand Down Expand Up @@ -307,6 +316,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
REPL_INACTIVITY_TIMEOUT_MILLIS,
CUSTOM_SESSION_MANAGER,
CUSTOM_COMMAND_LIFECYCLE_MANAGER,
CUSTOM_QUERY_METADATA_SERVICE,
BATCH_BYTES)
.map(conf => (conf.optionKey, conf.readFrom(reader)))
.toMap
Expand All @@ -317,6 +327,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
CUSTOM_FLINT_METADATA_LOG_SERVICE_CLASS,
CUSTOM_FLINT_INDEX_METADATA_SERVICE_CLASS,
SESSION_ID,
QUERY_ID,
REQUEST_INDEX,
METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER,
CUSTOM_QUERY_RESULT_WRITER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_SESSION_MANAGER, MONITOR_INITIAL_DELAY_SECONDS, MONITOR_INTERVAL_SECONDS, MONITOR_MAX_ERROR_COUNT, REQUEST_INDEX}
import org.apache.spark.sql.flint.config.FlintSparkConf.{CUSTOM_QUERY_METADATA_SERVICE, CUSTOM_SESSION_MANAGER, MONITOR_INITIAL_DELAY_SECONDS, MONITOR_INTERVAL_SECONDS, MONITOR_MAX_ERROR_COUNT, REQUEST_INDEX}

class FlintSparkConfSuite extends FlintSuite {
test("test spark conf") {
Expand Down Expand Up @@ -116,6 +116,19 @@ class FlintSparkConfSuite extends FlintSuite {
}
}

test("test getCustomQueryMetadataService") {
withSparkConf(CUSTOM_QUERY_METADATA_SERVICE.key) {
// default value
val defaultFlintOptions = FlintSparkConf().flintOptions()
assert(
defaultFlintOptions.getCustomQueryMetadataService == "org.apache.spark.sql.NoOpQueryMetadataService")

setFlintSparkConf(CUSTOM_QUERY_METADATA_SERVICE, "custom.query.metadata.ClassName")
val flintOptions = FlintSparkConf().flintOptions()
assert(flintOptions.getCustomQueryMetadataService == "custom.query.metadata.ClassName")
}
}

test("test getRequestMetadata") {
withSparkConf(REQUEST_INDEX.key) {
// default value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ object FlintJob extends Logging with FlintJobExecutor {
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}

val queryId = conf.get("spark.flint.job.queryId")
if (queryId.isEmpty) {
logWarning("Query ID was not specified.")
}

// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
Expand All @@ -59,6 +65,7 @@ object FlintJob extends Logging with FlintJobExecutor {
query,
dataSource,
resultIndex,
queryId,
jobType.equalsIgnoreCase("streaming"),
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,19 @@ object FlintREPL extends Logging with FlintJobExecutor {
logInfo(s"""streaming query ${query}""")
configDYNMaxExecutors(conf, jobType)
val streamingRunningCount = new AtomicInteger(0)

val queryId = conf.get("spark.flint.job.queryId")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unsure on FlintREPL having queryId here as well. Isn't queryId for interactive session per query? Gotten from flintCommand

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

queryId is already in FlintStatement / FlintCommand, considering reuse it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For streaming query, I think we cannot read from FlintStatement. Unless reading from env variable, we don't have any key to retrieve query info.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go ahead with this approach for POC.

if (queryId.isEmpty) {
logWarning("Query ID was not specified.")
}

val jobOperator =
JobOperator(
createSparkSession(conf),
query,
dataSource,
resultIndex,
queryId,
true,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.data.QueryState

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.ShuffleCleaner
import org.apache.spark.sql.util.{CustomClassLoader, ShuffleCleaner}
import org.apache.spark.util.ThreadUtils

case class JobOperator(
spark: SparkSession,
query: String,
dataSource: String,
resultIndex: String,
queryId: String,
streaming: Boolean,
streamingRunningCount: AtomicInteger)
extends Logging
Expand All @@ -42,10 +44,14 @@ case class JobOperator(
val startTime = System.currentTimeMillis()
streamingRunningCount.incrementAndGet()

val flintSparkConf = FlintSparkConf()
val queryMetadataService = CustomClassLoader(flintSparkConf).getQueryMetadataService()
queryMetadataService.updateQueryState(queryId, QueryState.RUNNING, null)
// osClient needs spark session to be created first to get FlintOptions initialized.
// Otherwise, we will have connection exception from EMR-S to OS.
val osClient = new OSClient(FlintSparkConf().flintOptions())
val osClient = new OSClient(flintSparkConf.flintOptions())
var exceptionThrown = true
var error: String = null
try {
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
Expand All @@ -61,30 +67,31 @@ case class JobOperator(
exceptionThrown = false
} catch {
case e: TimeoutException =>
val error = s"Getting the mapping of index $resultIndex timed out"
error = s"Getting the mapping of index $resultIndex timed out"
logError(error, e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
case e: Exception =>
val error = processQueryException(e)
error = processQueryException(e)
dataToWrite = Some(
getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
} finally {
cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
try {
dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
} catch {
case e: Exception =>
exceptionThrown = true
error = s"Failed to write to result index. originalError='${error}'"
logError(error, e)
}
val queryState = if (exceptionThrown) QueryState.FAILED else QueryState.SUCCESS
queryMetadataService.updateQueryState(queryId, queryState, error);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

@ykmr1224 ykmr1224 Aug 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CommandLifecycleManager seems to require reading FlintCommand using other methods, but as we don't have session, we cannot use it. And CommandLifecycleManager is designed for command lifecycle in session, which does not seem to fit into this use case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go ahead with this approach for POC.


cleanUpResources(exceptionThrown, threadPool)
}
}

def cleanUpResources(
exceptionThrown: Boolean,
threadPool: ThreadPoolExecutor,
dataToWrite: Option[DataFrame],
resultIndex: String,
osClient: OSClient): Unit = {
try {
dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
} catch {
case e: Exception => logError("fail to write to result index", e)
}
def cleanUpResources(exceptionThrown: Boolean, threadPool: ThreadPoolExecutor): Unit = {

try {
// Wait for streaming job complete if no error and there is streaming job running
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf

/**
* Temporary default implementation for QueryMetadataService. This should be replaced with an
* implementation which write status to OpenSearch index
*/
class NoOpQueryMetadataService(flintSparkConf: FlintSparkConf)
extends QueryMetadataService
with Logging {

override def updateQueryState(queryId: String, state: String, error: String): Unit =
logInfo(s"updateQueryState: queryId=${queryId}, state=`${state}`, error=`${error}`")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

/** Interface for updating query state and error. */
trait QueryMetadataService {
def updateQueryState(queryId: String, state: String, error: String): Unit
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.apache.spark.sql.QueryMetadataService
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.Utils

case class CustomClassLoader(flintSparkConf: FlintSparkConf) {

def getQueryMetadataService(): QueryMetadataService = {
instantiateClass[QueryMetadataService](
flintSparkConf.flintOptions().getCustomQueryMetadataService)
}

private def instantiateClass[T](className: String): T = {
try {
val providerClass = Utils.classForName(className)
val ctor = providerClass.getDeclaredConstructor(classOf[FlintSparkConf])
ctor.setAccessible(true)
ctor.newInstance(flintSparkConf).asInstanceOf[T]
} catch {
case e: Exception =>
throw new RuntimeException(s"Failed to instantiate provider: $className", e)
}
}
}
Loading