diff --git a/flint-core/src/main/scala/org/apache/spark/metrics/source/FlintMetricSource.scala b/flint-core/src/main/scala/org/apache/spark/metrics/source/FlintMetricSource.scala new file mode 100644 index 000000000..e22a61a51 --- /dev/null +++ b/flint-core/src/main/scala/org/apache/spark/metrics/source/FlintMetricSource.scala @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.metrics.source + +import com.codahale.metrics.MetricRegistry + +class FlintMetricSource(val sourceName: String) extends Source { + override val metricRegistry: MetricRegistry = new MetricRegistry +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClientBuilder.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClientBuilder.java index a0372a86f..8ca254d9b 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClientBuilder.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClientBuilder.java @@ -5,6 +5,7 @@ package org.opensearch.flint.core; +import org.opensearch.flint.core.metrics.FlintOpensearchClientMetricsWrapper; import org.opensearch.flint.core.storage.FlintOpenSearchClient; /** @@ -13,6 +14,6 @@ public class FlintClientBuilder { public static FlintClient build(FlintOptions options) { - return new FlintOpenSearchClient(options); + return new FlintOpensearchClientMetricsWrapper(new FlintOpenSearchClient(options)); } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metrics/FlintOpensearchClientMetricsWrapper.java b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/FlintOpensearchClientMetricsWrapper.java new file mode 100644 index 000000000..de05a70f4 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/FlintOpensearchClientMetricsWrapper.java @@ -0,0 +1,147 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics; + +import com.amazonaws.services.opensearch.model.AccessDeniedException; +import com.codahale.metrics.Counter; +import java.util.List; +import java.util.function.Supplier; +import org.apache.spark.SparkEnv; +import org.apache.spark.metrics.source.FlintMetricSource; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.flint.core.FlintClient; +import org.opensearch.flint.core.metadata.FlintMetadata; +import org.opensearch.flint.core.metadata.log.OptimisticTransaction; +import org.opensearch.flint.core.metrics.reporter.DimensionedName; +import org.opensearch.flint.core.storage.FlintOpenSearchClient; +import org.opensearch.flint.core.storage.FlintReader; +import org.opensearch.flint.core.storage.FlintWriter; + +/** + * This class wraps FlintOpensearchClient and emit spark metrics to FlintMetricSource. + */ +public class FlintOpensearchClientMetricsWrapper implements FlintClient { + + private final FlintOpenSearchClient delegate; + + public FlintOpensearchClientMetricsWrapper(FlintOpenSearchClient delegate) { + this.delegate = delegate; + } + + @Override + public OptimisticTransaction startTransaction(String indexName, String dataSourceName) { + return handleExceptions(() -> delegate.startTransaction(indexName, dataSourceName)); + } + + @Override + public OptimisticTransaction startTransaction(String indexName, String dataSourceName, + boolean forceInit) { + return handleExceptions(() -> delegate.startTransaction(indexName, dataSourceName, forceInit)); + } + + @Override + public void createIndex(String indexName, FlintMetadata metadata) { + try { + delegate.createIndex(indexName, metadata); + } catch (AccessDeniedException exception){ + handleAccessDeniedException(); + throw exception; + } catch (Throwable t) { + handleThrowable(); + throw t; + } + } + + @Override + public boolean exists(String indexName) { + return handleExceptions(() -> delegate.exists(indexName)); + } + + @Override + public List getAllIndexMetadata(String indexNamePattern) { + return handleExceptions(() -> delegate.getAllIndexMetadata(indexNamePattern)); + } + + @Override + public FlintMetadata getIndexMetadata(String indexName) { + return handleExceptions(() -> delegate.getIndexMetadata(indexName)); + } + + @Override + public void deleteIndex(String indexName) { + try { + delegate.deleteIndex(indexName); + } catch (AccessDeniedException exception){ + handleAccessDeniedException(); + throw exception; + } catch (Throwable t) { + handleThrowable(); + throw t; + } + } + + @Override + public FlintReader createReader(String indexName, String query) { + return handleExceptions(() -> delegate.createReader(indexName, query)); + } + + @Override + public FlintWriter createWriter(String indexName) { + return handleExceptions(() -> delegate.createWriter(indexName)); + } + + @Override + public RestHighLevelClient createClient() { + return handleExceptions(delegate::createClient); + } + + private T handleExceptions(Supplier function) { + try { + return function.get(); + } catch (AccessDeniedException exception) { + handleAccessDeniedException(); + throw exception; + } catch (Throwable t) { + handleThrowable(); + throw new RuntimeException(t); + } + } + + private void handleThrowable(){ + String clusterName = System.getenv("FLINT_AUTH_DOMAIN_IDENTIFIER"); + if (clusterName == null) { + clusterName = "unknown"; + } + DimensionedName metricName = DimensionedName.withName("FlintOpenSearchAccessError") + .withDimension("domain_ident", clusterName) + .build(); + publishMetric(metricName); + } + + private void handleAccessDeniedException() { + String clusterName = System.getenv("FLINT_AUTH_DOMAIN_IDENTIFIER"); + if (clusterName == null) { + clusterName = "unknown"; + } + DimensionedName metricName = DimensionedName.withName("FlintOpenSearchAccessDeniedError") + .withDimension("domain_ident", clusterName) + .build(); + publishMetric(metricName); + } + + private void publishMetric(DimensionedName metricName) { + FlintMetricSource flintMetricSource = + (FlintMetricSource) SparkEnv.get().metricsSystem().getSourcesByName("FlintMetricSource"); + if (flintMetricSource != null) { + Counter flintOpenSearchAccessError = + flintMetricSource.metricRegistry().getCounters().get(metricName.encode()); + if (flintOpenSearchAccessError == null) { + flintMetricSource.metricRegistry().counter(metricName.encode()); + } + } + } + +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index a44e70401..5fcb794f2 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -11,6 +11,7 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} import com.amazonaws.services.s3.model.AmazonS3Exception +import org.apache.commons.lang3.StringUtils import org.opensearch.flint.core.FlintClient import org.opensearch.flint.core.metadata.FlintMetadata import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue} @@ -84,6 +85,7 @@ trait FlintJobExecutor { .set( "spark.sql.extensions", "org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions") + .set("spark.metrics.namespace", StringUtils.EMPTY) } def createSparkSession(conf: SparkConf): SparkSession = { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 99085185c..52f9273ba 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -18,10 +18,12 @@ import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.app.FlintInstance.formats +import org.opensearch.flint.core.metrics.reporter.DimensionedName import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.FlintMetricSource import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils @@ -91,10 +93,30 @@ object FlintREPL extends Logging with FlintJobExecutor { } val spark = createSparkSession(conf) + val flintMetricSource = new FlintMetricSource("FlintMetricSource") + SparkEnv.get.metricsSystem.registerSource(flintMetricSource) + + val dimensionedName1 = DimensionedName + .withName("FlintOpensearchErrorCount") + .withDimension("domain_ident", "88888:hello") + .build() + flintMetricSource.metricRegistry.counter(dimensionedName1.encode()) + (1 to 10).foreach(_ => + flintMetricSource.metricRegistry.getCounters().get(dimensionedName1.encode()).inc()) + val dimensionedName2 = DimensionedName + .withName("OpensearchErrorCount") + .withDimension("domain_ident", "88888:hello") + .build(); + flintMetricSource.metricRegistry.counter(dimensionedName2.encode()) + (1 to 10).foreach(_ => + flintMetricSource.metricRegistry.getCounters().get(dimensionedName2.encode()).inc()) + logInfo( + "Reached after metrics") val osClient = new OSClient(FlintSparkConf().flintOptions()) val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") - + val doesVamsiExist = osClient.doesIndexExist("vamsi") + logInfo(s"""Vamsi Exists : $doesVamsiExist""") // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS) @@ -108,17 +130,14 @@ object FlintREPL extends Logging with FlintJobExecutor { val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) - // 1 thread for updating heart beat - val threadPool = - threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) - + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) val jobStartTime = currentTimeProvider.currentEpochMillis() - // update heart beat every 30 seconds - // OpenSearch triggers recovery after 1 minute outdated heart beat - var heartBeatFuture: ScheduledFuture[_] = null + try { - heartBeatFuture = createHeartBeatUpdater( + // update heart beat every 30 seconds + // OpenSearch triggers recovery after 1 minute outdated heart beat + createHeartBeatUpdater( HEARTBEAT_INTERVAL_MILLIS, flintSessionIndexUpdater, sessionId.get, @@ -167,23 +186,9 @@ object FlintREPL extends Logging with FlintJobExecutor { osClient, sessionIndex.get) } finally { - if (threadPool != null) { - heartBeatFuture.cancel(true) // Pass `true` to interrupt if running - threadPoolFactory.shutdownThreadPool(threadPool) - } - spark.stop() - - // Check for non-daemon threads that may prevent the driver from shutting down. - // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, - // which may be due to unresolved bugs in dependencies or threads not being properly shut down. - if (threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { - logInfo("A non-daemon thread in the driver is seen.") - // Exit the JVM to prevent resource leaks and potential emr-s job hung. - // A zero status code is used for a graceful shutdown without indicating an error. - // If exiting with non-zero status, emr-s job will fail. - // This is a part of the fault tolerance mechanism to handle such scenarios gracefully. - System.exit(0) + if (threadPool != null) { + threadPool.shutdown() } } } @@ -283,10 +288,8 @@ object FlintREPL extends Logging with FlintJobExecutor { // 1 thread for updating heart beat val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - - var futureMappingCheck: Future[Either[String, Unit]] = null try { - futureMappingCheck = Future { + val futureMappingCheck = Future { checkAndCreateIndex(commandContext.osClient, commandContext.resultIndex) } @@ -334,7 +337,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } } finally { if (threadPool != null) { - threadPoolFactory.shutdownThreadPool(threadPool) + threadPool.shutdown() } } } @@ -511,6 +514,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } else if (!flintReader.hasNext) { canProceed = false } else { + lastActivityTime = currentTimeProvider.currentEpochMillis() val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater) val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( @@ -532,8 +536,6 @@ object FlintREPL extends Logging with FlintJobExecutor { resultIndex, flintSessionIndexUpdater, osClient) - // last query finish time is last activity time - lastActivityTime = currentTimeProvider.currentEpochMillis() } } @@ -890,7 +892,7 @@ object FlintREPL extends Logging with FlintJobExecutor { threadPool: ScheduledExecutorService, osClient: OSClient, sessionIndex: String, - initialDelayMillis: Long): ScheduledFuture[_] = { + initialDelayMillis: Long): Unit = { threadPool.scheduleAtFixedRate( new Runnable { @@ -909,10 +911,6 @@ object FlintREPL extends Logging with FlintJobExecutor { "lastUpdateTime" -> currentTimeProvider.currentEpochMillis(), "state" -> "running"))) } catch { - case ie: InterruptedException => - // Preserve the interrupt status - Thread.currentThread().interrupt() - logError("HeartBeatUpdater task was interrupted", ie) // maybe due to invalid sequence number or primary term case e: Exception => logWarning(