Commit
Add flint opensearch metrics
Signed-off-by: Vamsi Manohar <[email protected]>
vmmusings committed Dec 27, 2023
1 parent 9061eb9 commit a69be3f
Showing 5 changed files with 198 additions and 38 deletions.
FlintMetricSource.scala
@@ -0,0 +1,12 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.metrics.source

import com.codahale.metrics.MetricRegistry

class FlintMetricSource(val sourceName: String) extends Source {
  override val metricRegistry: MetricRegistry = new MetricRegistry
}
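
A minimal usage sketch, mirroring the FlintREPL wiring later in this diff (the source name "FlintMetricSource" is this commit's convention, not an enforced constant):

import org.apache.spark.SparkEnv
import org.apache.spark.metrics.source.FlintMetricSource

val source = new FlintMetricSource("FlintMetricSource")
SparkEnv.get.metricsSystem.registerSource(source) // Spark now reports this registry
source.metricRegistry.counter("ExampleErrorCount").inc() // get-or-create, then increment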
FlintClientBuilder.java
@@ -5,6 +5,7 @@

package org.opensearch.flint.core;

+import org.opensearch.flint.core.metrics.FlintOpensearchClientMetricsWrapper;
import org.opensearch.flint.core.storage.FlintOpenSearchClient;

/**
@@ -13,6 +14,6 @@
public class FlintClientBuilder {

  public static FlintClient build(FlintOptions options) {
-    return new FlintOpenSearchClient(options);
+    return new FlintOpensearchClientMetricsWrapper(new FlintOpenSearchClient(options));
  }
}
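
Call sites need no changes; every client obtained through the builder is now transparently decorated. A sketch, assuming FlintOptions still accepts a plain java.util.Map of settings:

val options = new FlintOptions(java.util.Collections.emptyMap())
val client = FlintClientBuilder.build(options) // a FlintOpensearchClientMetricsWrapper
client.exists("my_index") // any failure increments a Flint error counter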
FlintOpensearchClientMetricsWrapper.java
@@ -0,0 +1,147 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.flint.core.metrics;

import com.amazonaws.services.opensearch.model.AccessDeniedException;
import java.util.List;
import java.util.function.Supplier;
import org.apache.spark.SparkEnv;
import org.apache.spark.metrics.source.FlintMetricSource;
import org.apache.spark.metrics.source.Source;
import org.opensearch.client.RestHighLevelClient;
import org.opensearch.flint.core.FlintClient;
import org.opensearch.flint.core.metadata.FlintMetadata;
import org.opensearch.flint.core.metadata.log.OptimisticTransaction;
import org.opensearch.flint.core.metrics.reporter.DimensionedName;
import org.opensearch.flint.core.storage.FlintOpenSearchClient;
import org.opensearch.flint.core.storage.FlintReader;
import org.opensearch.flint.core.storage.FlintWriter;

/**
 * Wraps FlintOpenSearchClient and emits Spark metrics to FlintMetricSource
 * when OpenSearch calls fail.
 */
public class FlintOpensearchClientMetricsWrapper implements FlintClient {

  private final FlintOpenSearchClient delegate;

  public FlintOpensearchClientMetricsWrapper(FlintOpenSearchClient delegate) {
    this.delegate = delegate;
  }

  @Override
  public <T> OptimisticTransaction<T> startTransaction(String indexName, String dataSourceName) {
    return handleExceptions(() -> delegate.startTransaction(indexName, dataSourceName));
  }

  @Override
  public <T> OptimisticTransaction<T> startTransaction(String indexName, String dataSourceName,
      boolean forceInit) {
    return handleExceptions(() -> delegate.startTransaction(indexName, dataSourceName, forceInit));
  }

  @Override
  public void createIndex(String indexName, FlintMetadata metadata) {
    try {
      delegate.createIndex(indexName, metadata);
    } catch (AccessDeniedException exception) {
      handleAccessDeniedException();
      throw exception;
    } catch (Throwable t) {
      handleThrowable();
      throw t;
    }
  }

  @Override
  public boolean exists(String indexName) {
    return handleExceptions(() -> delegate.exists(indexName));
  }

  @Override
  public List<FlintMetadata> getAllIndexMetadata(String indexNamePattern) {
    return handleExceptions(() -> delegate.getAllIndexMetadata(indexNamePattern));
  }

  @Override
  public FlintMetadata getIndexMetadata(String indexName) {
    return handleExceptions(() -> delegate.getIndexMetadata(indexName));
  }

  @Override
  public void deleteIndex(String indexName) {
    try {
      delegate.deleteIndex(indexName);
    } catch (AccessDeniedException exception) {
      handleAccessDeniedException();
      throw exception;
    } catch (Throwable t) {
      handleThrowable();
      throw t;
    }
  }

  @Override
  public FlintReader createReader(String indexName, String query) {
    return handleExceptions(() -> delegate.createReader(indexName, query));
  }

  @Override
  public FlintWriter createWriter(String indexName) {
    return handleExceptions(() -> delegate.createWriter(indexName));
  }

  @Override
  public RestHighLevelClient createClient() {
    return handleExceptions(delegate::createClient);
  }

  private <T> T handleExceptions(Supplier<T> function) {
    try {
      return function.get();
    } catch (AccessDeniedException exception) {
      handleAccessDeniedException();
      throw exception;
    } catch (Throwable t) {
      handleThrowable();
      throw new RuntimeException(t);
    }
  }

  private void handleThrowable() {
    String clusterName = System.getenv("FLINT_AUTH_DOMAIN_IDENTIFIER");
    if (clusterName == null) {
      clusterName = "unknown";
    }
    DimensionedName metricName = DimensionedName.withName("FlintOpenSearchAccessError")
        .withDimension("domain_ident", clusterName)
        .build();
    publishMetric(metricName);
  }

  private void handleAccessDeniedException() {
    String clusterName = System.getenv("FLINT_AUTH_DOMAIN_IDENTIFIER");
    if (clusterName == null) {
      clusterName = "unknown";
    }
    DimensionedName metricName = DimensionedName.withName("FlintOpenSearchAccessDeniedError")
        .withDimension("domain_ident", clusterName)
        .build();
    publishMetric(metricName);
  }

  private void publishMetric(DimensionedName metricName) {
    // getSourcesByName returns a Seq[Source]; it is empty when FlintMetricSource
    // was never registered (e.g. outside a Spark job).
    scala.collection.Seq<Source> sources =
        SparkEnv.get().metricsSystem().getSourcesByName("FlintMetricSource");
    if (!sources.isEmpty()) {
      FlintMetricSource flintMetricSource = (FlintMetricSource) sources.head();
      // Get or create the counter for this dimensioned name, then count the failure.
      flintMetricSource.metricRegistry().counter(metricName.encode()).inc();
    }
  }

}
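
How a failure surfaces as a counter, end to end (a sketch reusing "options" and the registered "source" from the sketches above; "restricted_index" is hypothetical, and the counter exists only after the first failure):

val client = FlintClientBuilder.build(options)
try client.getIndexMetadata("restricted_index")
catch { case _: Throwable => () } // the wrapper has already published the metric

val name = DimensionedName
  .withName("FlintOpenSearchAccessError")
  .withDimension("domain_ident", sys.env.getOrElse("FLINT_AUTH_DOMAIN_IDENTIFIER", "unknown"))
  .build()
val errorCount = source.metricRegistry.getCounters.get(name.encode()).getCount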
FlintJobExecutor.scala
@@ -11,6 +11,7 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import com.amazonaws.services.s3.model.AmazonS3Exception
+import org.apache.commons.lang3.StringUtils
import org.opensearch.flint.core.FlintClient
import org.opensearch.flint.core.metadata.FlintMetadata
import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue}
@@ -84,6 +85,7 @@ trait FlintJobExecutor {
      .set(
        "spark.sql.extensions",
        "org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions")
+      .set("spark.metrics.namespace", StringUtils.EMPTY)
  }
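
Note: Spark's metrics namespace defaults to the application ID, so each run would otherwise emit metrics under a fresh prefix; an empty namespace keeps the counter names published above stable across jobs.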

  def createSparkSession(conf: SparkConf): SparkSession = {
FlintREPL.scala
@@ -18,10 +18,12 @@ import org.opensearch.action.get.GetResponse
import org.opensearch.common.Strings
import org.opensearch.flint.app.{FlintCommand, FlintInstance}
import org.opensearch.flint.app.FlintInstance.formats
+import org.opensearch.flint.core.metrics.reporter.DimensionedName
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}

-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
+import org.apache.spark.metrics.source.FlintMetricSource
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait}
import org.apache.spark.util.ThreadUtils
@@ -91,10 +93,30 @@ object FlintREPL extends Logging with FlintJobExecutor {
    }

    val spark = createSparkSession(conf)
+    val flintMetricSource = new FlintMetricSource("FlintMetricSource")
+    SparkEnv.get.metricsSystem.registerSource(flintMetricSource)
+
+    val dimensionedName1 = DimensionedName
+      .withName("FlintOpensearchErrorCount")
+      .withDimension("domain_ident", "88888:hello")
+      .build()
+    flintMetricSource.metricRegistry.counter(dimensionedName1.encode())
+    (1 to 10).foreach(_ =>
+      flintMetricSource.metricRegistry.getCounters().get(dimensionedName1.encode()).inc())
+    val dimensionedName2 = DimensionedName
+      .withName("OpensearchErrorCount")
+      .withDimension("domain_ident", "88888:hello")
+      .build()
+    flintMetricSource.metricRegistry.counter(dimensionedName2.encode())
+    (1 to 10).foreach(_ =>
+      flintMetricSource.metricRegistry.getCounters().get(dimensionedName2.encode()).inc())
+    logInfo("Reached after metrics")
    val osClient = new OSClient(FlintSparkConf().flintOptions())
    val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown")
    val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")

+    val doesVamsiExist = osClient.doesIndexExist("vamsi")
+    logInfo(s"""Vamsi Exists : $doesVamsiExist""")
    // Read the values from the Spark configuration or fall back to the default values
    val inactivityLimitMillis: Long =
      conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS)
@@ -108,17 +130,14 @@ object FlintREPL extends Logging with FlintJobExecutor {

    val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
    addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
-
-    // 1 thread for updating heart beat
-    val threadPool =
-      threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)
-
+    val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)
    val jobStartTime = currentTimeProvider.currentEpochMillis()
-    // update heart beat every 30 seconds
-    // OpenSearch triggers recovery after 1 minute outdated heart beat
-    var heartBeatFuture: ScheduledFuture[_] = null
-
    try {
-      heartBeatFuture = createHeartBeatUpdater(
+      // update heart beat every 30 seconds
+      // OpenSearch triggers recovery after 1 minute outdated heart beat
+      createHeartBeatUpdater(
        HEARTBEAT_INTERVAL_MILLIS,
        flintSessionIndexUpdater,
        sessionId.get,
@@ -167,23 +186,9 @@ object FlintREPL extends Logging with FlintJobExecutor {
        osClient,
        sessionIndex.get)
    } finally {
-      if (threadPool != null) {
-        heartBeatFuture.cancel(true) // Pass `true` to interrupt if running
-        threadPoolFactory.shutdownThreadPool(threadPool)
-      }
-
      spark.stop()
-
-      // Check for non-daemon threads that may prevent the driver from shutting down.
-      // Non-daemon threads other than the main thread indicate that the driver is still processing tasks,
-      // which may be due to unresolved bugs in dependencies or threads not being properly shut down.
-      if (threadPoolFactory.hasNonDaemonThreadsOtherThanMain) {
-        logInfo("A non-daemon thread in the driver is seen.")
-        // Exit the JVM to prevent resource leaks and potential emr-s job hung.
-        // A zero status code is used for a graceful shutdown without indicating an error.
-        // If exiting with non-zero status, emr-s job will fail.
-        // This is a part of the fault tolerance mechanism to handle such scenarios gracefully.
-        System.exit(0)
+      if (threadPool != null) {
+        threadPool.shutdown()
      }
    }
  }
@@ -283,10 +288,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
    // 1 thread for updating heart beat
    val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
    implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
-
-    var futureMappingCheck: Future[Either[String, Unit]] = null
    try {
-      futureMappingCheck = Future {
+      val futureMappingCheck = Future {
        checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex)
      }

@@ -334,7 +337,7 @@
      }
    } finally {
      if (threadPool != null) {
-        threadPoolFactory.shutdownThreadPool(threadPool)
+        threadPool.shutdown()
      }
    }
  }
@@ -511,6 +514,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
      } else if (!flintReader.hasNext) {
        canProceed = false
      } else {
+        lastActivityTime = currentTimeProvider.currentEpochMillis()
        val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater)

        val (dataToWrite, returnedVerificationResult) = processStatementOnVerification(
@@ -532,8 +536,6 @@
            resultIndex,
            flintSessionIndexUpdater,
            osClient)
-        // last query finish time is last activity time
-        lastActivityTime = currentTimeProvider.currentEpochMillis()
      }
    }

@@ -890,7 +892,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
      threadPool: ScheduledExecutorService,
      osClient: OSClient,
      sessionIndex: String,
-      initialDelayMillis: Long): ScheduledFuture[_] = {
+      initialDelayMillis: Long): Unit = {

    threadPool.scheduleAtFixedRate(
      new Runnable {
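
The Unit return type pairs with the removal of heartBeatFuture in the main block above: the heartbeat task is no longer cancelled individually, and is torn down with the pool via threadPool.shutdown().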
@@ -909,10 +911,6 @@
                "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(),
                "state" -> "running")))
          } catch {
-            case ie: InterruptedException =>
-              // Preserve the interrupt status
-              Thread.currentThread().interrupt()
-              logError("HeartBeatUpdater task was interrupted", ie)
            // maybe due to invalid sequence number or primary term
            case e: Exception =>
              logWarning(
