From dd1a0d66758eeb98cdba3cffc83a4fec67c11897 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 15 Feb 2024 12:59:03 -0800 Subject: [PATCH] Refactor and add metrics for streaming job Signed-off-by: Louis Chu --- .../flint/core/metrics/MetricConstants.java | 28 ++++++++- .../flint/core/metrics/MetricsUtil.java | 57 +++++++++-------- .../core/metrics/reporter/DimensionUtils.java | 34 ++++++++++- .../flint/core/metrics/MetricsUtilTest.java | 37 ++++++++++- .../metrics/reporter/DimensionUtilsTest.java | 26 ++++++++ .../flint/spark/FlintSparkIndexMonitor.scala | 2 + .../scala/org/apache/spark/sql/FlintJob.scala | 8 ++- .../org/apache/spark/sql/FlintREPL.scala | 61 ++++++++++++------- .../org/apache/spark/sql/JobOperator.scala | 32 +++++++++- 9 files changed, 228 insertions(+), 57 deletions(-) diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 72b96a091..544a18c28 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -8,7 +8,7 @@ /** * This class defines custom metric constants used for monitoring flint operations. */ -public class MetricConstants { +public final class MetricConstants { /** * The prefix for all read-related metrics in OpenSearch. @@ -65,5 +65,29 @@ public class MetricConstants { /** * Metric name for tracking the processing time of statements. */ - public static final String STATEMENT_PROCESSING_TIME_METRIC = "STATEMENT.processingTime"; + public static final String STATEMENT_PROCESSING_TIME_METRIC = "statement.processingTime"; + + /** + * Metric for tracking the count of currently running streaming jobs. + */ + public static final String STREAMING_RUNNING_METRIC = "streaming.running.count"; + + /** + * Metric for tracking the count of streaming jobs that have failed. + */ + public static final String STREAMING_FAILED_METRIC = "streaming.failed.count"; + + /** + * Metric for tracking the count of streaming jobs that have completed successfully. + */ + public static final String STREAMING_SUCCESS_METRIC = "streaming.success.count"; + + /** + * Metric for tracking the count of failed heartbeat signals in streaming jobs. 
+ */ + public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count"; + + private MetricConstants() { + // Private constructor to prevent instantiation + } } \ No newline at end of file diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java index 13227a039..8e63992f5 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java @@ -6,12 +6,15 @@ package org.opensearch.flint.core.metrics; import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.Timer; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; import org.apache.spark.metrics.source.Source; import scala.collection.Seq; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; /** @@ -21,8 +24,8 @@ public final class MetricsUtil { private static final Logger LOG = Logger.getLogger(MetricsUtil.class.getName()); - // Private constructor to prevent instantiation private MetricsUtil() { + // Private constructor to prevent instantiation } /** @@ -60,10 +63,7 @@ public static void decrementCounter(String metricName) { */ public static Timer.Context getTimerContext(String metricName) { Timer timer = getOrCreateTimer(metricName); - if (timer != null) { - return timer.time(); - } - return null; + return timer != null ? timer.time() : null; } /** @@ -74,42 +74,47 @@ public static Timer.Context getTimerContext(String metricName) { * @return The elapsed time in nanoseconds since the timer was started, or {@code null} if the context was {@code null}. */ public static Long stopTimer(Timer.Context context) { - if (context != null) { - return context.stop(); + return context != null ? context.stop() : null; + } + + /** + * Registers a gauge metric with the provided name and value. + * The gauge will reflect the current value of the AtomicInteger provided. + * + * @param metricName The name of the gauge metric to register. + * @param value The AtomicInteger whose current value should be reflected by the gauge. + */ + public static void registerGauge(String metricName, final AtomicInteger value) { + MetricRegistry metricRegistry = getMetricRegistry(); + if (metricRegistry == null) { + LOG.warning("MetricRegistry not available, cannot register gauge: " + metricName); + return; } - return null; + metricRegistry.register(metricName, (Gauge) value::get); } // Retrieves or creates a new counter for the given metric name private static Counter getOrCreateCounter(String metricName) { - SparkEnv sparkEnv = SparkEnv.get(); - if (sparkEnv == null) { - LOG.warning("Spark environment not available, cannot instrument metric: " + metricName); - return null; - } - - FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv); - Counter counter = flintMetricSource.metricRegistry().getCounters().get(metricName); - if (counter == null) { - counter = flintMetricSource.metricRegistry().counter(metricName); - } - return counter; + MetricRegistry metricRegistry = getMetricRegistry(); + return metricRegistry != null ? 
metricRegistry.counter(metricName) : null; } // Retrieves or creates a new Timer for the given metric name private static Timer getOrCreateTimer(String metricName) { + MetricRegistry metricRegistry = getMetricRegistry(); + return metricRegistry != null ? metricRegistry.timer(metricName) : null; + } + + // Retrieves the MetricRegistry from the current Spark environment. + private static MetricRegistry getMetricRegistry() { SparkEnv sparkEnv = SparkEnv.get(); if (sparkEnv == null) { - LOG.warning("Spark environment not available, cannot instrument metric: " + metricName); + LOG.warning("Spark environment not available, cannot access MetricRegistry."); return null; } FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv); - Timer timer = flintMetricSource.metricRegistry().getTimers().get(metricName); - if (timer == null) { - timer = flintMetricSource.metricRegistry().timer(metricName); - } - return timer; + return flintMetricSource.metricRegistry(); } // Gets or initializes the FlintMetricSource diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java index ce7136507..6e3e90916 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java @@ -7,9 +7,13 @@ import java.util.Map; import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.Logger; + import org.apache.commons.lang.StringUtils; import com.amazonaws.services.cloudwatch.model.Dimension; +import org.apache.spark.SparkEnv; /** * Utility class for creating and managing CloudWatch dimensions for metrics reporting in Flint. @@ -18,7 +22,9 @@ * application ID, and more. */ public class DimensionUtils { + private static final Logger LOG = Logger.getLogger(DimensionUtils.class.getName()); private static final String DIMENSION_JOB_ID = "jobId"; + private static final String DIMENSION_JOB_TYPE = "jobType"; private static final String DIMENSION_APPLICATION_ID = "applicationId"; private static final String DIMENSION_APPLICATION_NAME = "applicationName"; private static final String DIMENSION_DOMAIN_ID = "domainId"; @@ -29,6 +35,8 @@ public class DimensionUtils { private static final Map> dimensionBuilders = Map.of( DIMENSION_INSTANCE_ROLE, DimensionUtils::getInstanceRoleDimension, DIMENSION_JOB_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_JOB_ID", DIMENSION_JOB_ID), + // TODO: Move FlintSparkConf into the core to prevent circular dependencies + DIMENSION_JOB_TYPE, ignored -> constructDimensionFromSparkConf(DIMENSION_JOB_TYPE, "spark.flint.job.type", UNKNOWN), DIMENSION_APPLICATION_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", DIMENSION_APPLICATION_ID), DIMENSION_APPLICATION_NAME, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_APPLICATION_NAME", DIMENSION_APPLICATION_NAME), DIMENSION_DOMAIN_ID, ignored -> getEnvironmentVariableDimension("FLINT_CLUSTER_NAME", DIMENSION_DOMAIN_ID) @@ -39,7 +47,7 @@ public class DimensionUtils { * builder exists for the dimension name, it is used; otherwise, a default dimension is constructed. * * @param dimensionName The name of the dimension to construct. - * @param parts Additional information that might be required by specific dimension builders. 
+ * @param metricNameParts Additional information that might be required by specific dimension builders. * @return A CloudWatch Dimension object. */ public static Dimension constructDimension(String dimensionName, String[] metricNameParts) { @@ -50,6 +58,30 @@ public static Dimension constructDimension(String dimensionName, String[] metric .apply(metricNameParts); } + /** + * Constructs a CloudWatch Dimension object using a specified Spark configuration key. + * + * @param dimensionName The name of the dimension to construct. + * @param sparkConfKey the Spark configuration key used to look up the value for the dimension. + * @param defaultValue the default value to use for the dimension if the Spark configuration key is not found or if the Spark environment is not available. + * @return A CloudWatch Dimension object. + * @throws Exception if an error occurs while accessing the Spark configuration. The exception is logged and then rethrown. + */ + public static Dimension constructDimensionFromSparkConf(String dimensionName, String sparkConfKey, String defaultValue) { + String propertyValue = defaultValue; + try { + if (SparkEnv.get() != null && SparkEnv.get().conf() != null) { + propertyValue = SparkEnv.get().conf().get(sparkConfKey, defaultValue); + } else { + LOG.warning("Spark environment or configuration is not available, defaulting to provided default value."); + } + } catch (Exception e) { + LOG.log(Level.SEVERE, "Error accessing Spark configuration with key: " + sparkConfKey + ", defaulting to provided default value.", e); + throw e; + } + return new Dimension().withName(dimensionName).withValue(propertyValue); + } + // This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137 // Since we don't have access to Spark Configuration here: we are relying on the presence of executorId as part of the metricName. 
public static boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) { diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java index 09dee5eef..3b8940536 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java @@ -1,6 +1,7 @@ package org.opensearch.flint.core.metrics; import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; import com.codahale.metrics.Timer; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; @@ -10,6 +11,7 @@ import org.mockito.Mockito; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; @@ -42,7 +44,7 @@ public void testIncrementDecrementCounter() { // Verify interactions verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); - verify(flintMetricSource, times(4)).metricRegistry(); + verify(flintMetricSource, times(3)).metricRegistry(); Counter counter = flintMetricSource.metricRegistry().getCounters().get(testMetric); Assertions.assertNotNull(counter); Assertions.assertEquals(counter.getCount(), 1); @@ -69,7 +71,7 @@ public void testStartStopTimer() { // Verify interactions verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); - verify(flintMetricSource, times(2)).metricRegistry(); + verify(flintMetricSource, times(1)).metricRegistry(); Timer timer = flintMetricSource.metricRegistry().getTimers().get(testMetric); Assertions.assertNotNull(timer); Assertions.assertEquals(timer.getCount(), 1L); @@ -78,4 +80,35 @@ public void testStartStopTimer() { throw new RuntimeException(e); } } + + @Test + public void testRegisterGaugeWhenMetricRegistryIsAvailable() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + // Mock SparkEnv + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + // Mock FlintMetricSource + FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) + .thenReturn(flintMetricSource); + + // Setup gauge + AtomicInteger testValue = new AtomicInteger(1); + String gaugeName = "test.gauge"; + MetricsUtil.registerGauge(gaugeName, testValue); + + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(flintMetricSource, times(1)).metricRegistry(); + + Gauge gauge = flintMetricSource.metricRegistry().getGauges().get(gaugeName); + Assertions.assertNotNull(gauge); + Assertions.assertEquals(gauge.getValue(), 1); + + testValue.incrementAndGet(); + testValue.incrementAndGet(); + testValue.decrementAndGet(); + Assertions.assertEquals(gauge.getValue(), 2); + } + } } \ No newline at end of file diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java index 94760fc37..2178c3c22 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java @@ -5,12 +5,21 @@ package org.opensearch.flint.core.metrics.reporter; +import 
org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.metrics.source.FlintMetricSource; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; import com.amazonaws.services.cloudwatch.model.Dimension; import org.junit.jupiter.api.function.Executable; +import org.mockito.MockedStatic; +import org.mockito.Mockito; import java.lang.reflect.Field; import java.util.Map; @@ -71,6 +80,23 @@ public void testGetDimensionsFromSystemEnv() throws NoSuchFieldException, Illega writeableEnvironmentVariables.remove("SERVERLESS_EMR_JOB_ID"); writeableEnvironmentVariables.remove("TEST_VAR"); } + } + + @Test + public void testConstructDimensionFromSparkConfWithAvailableConfig() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + SparkConf sparkConf = new SparkConf().set("spark.test.key", "testValue"); + when(sparkEnv.get().conf()).thenReturn(sparkConf); + Dimension result = DimensionUtils.constructDimensionFromSparkConf("testDimension", "spark.test.key", "defaultValue"); + // Assertions + assertEquals("testDimension", result.getName()); + assertEquals("testValue", result.getValue()); + + // Reset SparkEnv mock to not affect other tests + Mockito.reset(SparkEnv.get()); + } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index 5c4c7376c..2f44a28f4 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -12,6 +12,7 @@ import scala.sys.addShutdownHook import org.opensearch.flint.core.FlintClient import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -65,6 +66,7 @@ class FlintSparkIndexMonitor( } catch { case e: Throwable => logError("Failed to update index log entry", e) + MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC) } }, 15, // Delay to ensure final logging is complete first, otherwise version conflicts diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index df0bf5c4e..6f6de547c 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -19,6 +19,7 @@ import play.api.libs.json._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{StructField, _} /** @@ -41,7 +42,10 @@ object FlintJob extends Logging with FlintJobExecutor { val Array(query, resultIndex) = args val conf = createSparkConf() - val wait = conf.get("spark.flint.job.type", "continue") + val 
jobType = conf.get("spark.flint.job.type", "batch") + logInfo(s"""Job type is: ${jobType}""") + conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + val dataSource = conf.get("spark.flint.datasource.name", "") // https://github.com/opensearch-project/opensearch-spark/issues/138 /* @@ -58,7 +62,7 @@ object FlintJob extends Logging with FlintJobExecutor { query, dataSource, resultIndex, - wait.equalsIgnoreCase("streaming")) + jobType.equalsIgnoreCase("streaming")) jobOperator.start() } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 093ce1932..4fc0cad61 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -7,6 +7,7 @@ package org.apache.spark.sql import java.net.ConnectException import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} import scala.concurrent.duration._ @@ -21,7 +22,7 @@ import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.app.FlintInstance.formats import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.metrics.MetricConstants -import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, stopTimer} +import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder @@ -63,6 +64,9 @@ object FlintREPL extends Logging with FlintJobExecutor { updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) } + private val sessionRunningCount = new AtomicInteger(0) + private val statementRunningCount = new AtomicInteger(0) + def main(args: Array[String]) { val Array(query, resultIndex) = args if (Strings.isNullOrEmpty(resultIndex)) { @@ -81,9 +85,12 @@ object FlintREPL extends Logging with FlintJobExecutor { * Without this setup, Spark would not recognize names in the format `my_glue1.default`. 
*/ conf.set("spark.sql.defaultCatalog", dataSource) - val wait = conf.get(FlintSparkConf.JOB_TYPE.key, "continue") - if (wait.equalsIgnoreCase("streaming")) { + val jobType = conf.get(FlintSparkConf.JOB_TYPE.key, FlintSparkConf.JOB_TYPE.defaultValue.get) + logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""") + conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + + if (jobType.equalsIgnoreCase("streaming")) { logInfo(s"""streaming query ${query}""") val jobOperator = JobOperator(createSparkSession(conf), query, dataSource, resultIndex, true) @@ -133,6 +140,8 @@ object FlintREPL extends Logging with FlintJobExecutor { val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) + registerGauge(MetricConstants.REPL_RUNNING_METRIC, sessionRunningCount) + registerGauge(MetricConstants.STATEMENT_RUNNING_METRIC, statementRunningCount) val jobStartTime = currentTimeProvider.currentEpochMillis() // update heart beat every 30 seconds // OpenSearch triggers recovery after 1 minute outdated heart beat @@ -386,9 +395,9 @@ object FlintREPL extends Logging with FlintJobExecutor { FlintInstance.serializeWithoutJobId(flintJob, currentTime) } flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - incrementCounter(MetricConstants.REPL_RUNNING_METRIC) logInfo( s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") + sessionRunningCount.incrementAndGet() } def handleSessionError( @@ -400,7 +409,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, sessionIndex: String, - flintSessionContext: Timer.Context): Unit = { + sessionTimerContext: Timer.Context): Unit = { val error = s"Session error: ${e.getMessage}" logError(error, e) @@ -409,7 +418,7 @@ object FlintREPL extends Logging with FlintJobExecutor { updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) if (flintInstance.state.equals("fail")) { - recordSessionFailed(flintSessionContext) + recordSessionFailed(sessionTimerContext) } } @@ -588,7 +597,7 @@ object FlintREPL extends Logging with FlintJobExecutor { resultIndex: String, flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, - statementContext: Timer.Context): Unit = { + statementTimerContext: Timer.Context): Unit = { try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) if (flintCommand.isRunning() || flintCommand.isWaiting()) { @@ -596,7 +605,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintCommand.complete() } updateSessionIndex(flintCommand, flintSessionIndexUpdater) - recordStatementStateChange(flintCommand, statementContext) + recordStatementStateChange(flintCommand, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) @@ -605,7 +614,7 @@ object FlintREPL extends Logging with FlintJobExecutor { logError(error, e) flintCommand.fail() updateSessionIndex(flintCommand, flintSessionIndexUpdater) - recordStatementStateChange(flintCommand, statementContext) + recordStatementStateChange(flintCommand, statementTimerContext) } } @@ -801,7 +810,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintCommand.running() logDebug(s"command running: $flintCommand") updateSessionIndex(flintCommand, flintSessionIndexUpdater) - incrementCounter(MetricConstants.STATEMENT_RUNNING_METRIC) + 
statementRunningCount.incrementAndGet() flintCommand } @@ -855,7 +864,7 @@ object FlintREPL extends Logging with FlintJobExecutor { osClient: OSClient, sessionIndex: String, sessionId: String, - flintSessionContext: Timer.Context, + sessionTimerContext: Timer.Context, shutdownHookManager: ShutdownHookManagerTrait = DefaultShutdownHookManager): Unit = { shutdownHookManager.addShutdownHook(() => { @@ -885,7 +894,7 @@ object FlintREPL extends Logging with FlintJobExecutor { getResponse, flintSessionIndexUpdater, sessionId, - flintSessionContext) + sessionTimerContext) } }) } @@ -895,7 +904,7 @@ object FlintREPL extends Logging with FlintJobExecutor { getResponse: GetResponse, flintSessionIndexUpdater: OpenSearchUpdater, sessionId: String, - flintSessionContext: Timer.Context): Unit = { + sessionTimerContext: Timer.Context): Unit = { val flintInstance = FlintInstance.deserializeFromMap(source) flintInstance.state = "dead" flintSessionIndexUpdater.updateIf( @@ -905,7 +914,7 @@ object FlintREPL extends Logging with FlintJobExecutor { currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, getResponse.getPrimaryTerm) - recordSessionSuccess(flintSessionContext) + recordSessionSuccess(sessionTimerContext) } /** @@ -1068,23 +1077,29 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } - private def recordSessionSuccess(sessionContext: Timer.Context): Unit = { - stopTimer(sessionContext) - decrementCounter(MetricConstants.REPL_RUNNING_METRIC) + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { + stopTimer(sessionTimerContext) + if (sessionRunningCount.get() > 0) { + sessionRunningCount.decrementAndGet() + } incrementCounter(MetricConstants.REPL_SUCCESS_METRIC) } - private def recordSessionFailed(sessionContext: Timer.Context): Unit = { - stopTimer(sessionContext) - decrementCounter(MetricConstants.REPL_RUNNING_METRIC) + private def recordSessionFailed(sessionTimerContext: Timer.Context): Unit = { + stopTimer(sessionTimerContext) + if (sessionRunningCount.get() > 0) { + sessionRunningCount.decrementAndGet() + } incrementCounter(MetricConstants.REPL_FAILED_METRIC) } private def recordStatementStateChange( flintCommand: FlintCommand, - statementContext: Timer.Context): Unit = { - stopTimer(statementContext) - decrementCounter(MetricConstants.STATEMENT_RUNNING_METRIC) + statementTimerContext: Timer.Context): Unit = { + stopTimer(statementTimerContext) + if (statementRunningCount.get() > 0) { + statementRunningCount.decrementAndGet() + } if (flintCommand.isComplete()) { incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC) } else if (flintCommand.isFailed()) { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index a2edbe98e..002bfb3f1 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -6,11 +6,14 @@ package org.apache.spark.sql import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} import scala.util.{Failure, Success, Try} +import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.MetricsUtil.{incrementCounter, registerGauge} import 
org.opensearch.flint.core.storage.OpenSearchUpdater import org.apache.spark.SparkConf @@ -29,15 +32,21 @@ case class JobOperator( extends Logging with FlintJobExecutor { - // jvm shutdown hook + // JVM shutdown hook sys.addShutdownHook(stop()) + private val streamingRunningCount = new AtomicInteger(0) + def start(): Unit = { val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index") implicit val executionContext = ExecutionContext.fromExecutor(threadPool) var dataToWrite: Option[DataFrame] = None + + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) val startTime = System.currentTimeMillis() + streamingRunningCount.incrementAndGet() + // osClient needs spark session to be created first to get FlintOptions initialized. // Otherwise, we will have connection exception from EMR-S to OS. val osClient = new OSClient(FlintSparkConf().flintOptions()) @@ -98,6 +107,7 @@ case class JobOperator( } catch { case e: Exception => logError("Fail to close threadpool", e) } + recordStreamingCompletionStatus(exceptionThrown) } def stop(): Unit = { @@ -109,4 +119,24 @@ case class JobOperator( case Failure(e) => logError("unexpected error while stopping spark session", e) } } + + /** + * Records the completion of a streaming job by updating the appropriate metrics. This method + * decrements the running metric for streaming jobs and increments either the success or failure + * metric based on whether an exception was thrown. + * + * @param exceptionThrown + * Indicates whether an exception was thrown during the streaming job execution. + */ + private def recordStreamingCompletionStatus(exceptionThrown: Boolean): Unit = { + // Decrement the metric for running streaming jobs as the job is now completing. + if (streamingRunningCount.get() > 0) { + streamingRunningCount.decrementAndGet() + } + + exceptionThrown match { + case true => incrementCounter(MetricConstants.STREAMING_FAILED_METRIC) + case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC) + } + } }
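
Illustrative usage note (not part of the patch): the sketch below shows how the gauge and counter APIs introduced above are intended to be used together from a streaming job, mirroring the pattern JobOperator follows. The object name StreamingMetricsExample and the runStreaming callback are hypothetical; MetricsUtil.registerGauge, MetricsUtil.incrementCounter, and the MetricConstants names are the ones added in this change.

import java.util.concurrent.atomic.AtomicInteger

import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}

object StreamingMetricsExample {
  // Backs the streaming.running.count gauge; the value is read when metrics are polled.
  private val streamingRunningCount = new AtomicInteger(0)

  def run(runStreaming: () => Unit): Unit = {
    MetricsUtil.registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
    streamingRunningCount.incrementAndGet()
    try {
      runStreaming()
      MetricsUtil.incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC)
    } catch {
      case e: Throwable =>
        MetricsUtil.incrementCounter(MetricConstants.STREAMING_FAILED_METRIC)
        throw e
    } finally {
      // Guard against double-decrement, matching recordStreamingCompletionStatus above.
      if (streamingRunningCount.get() > 0) streamingRunningCount.decrementAndGet()
    }
  }
}

Because the gauge wraps the AtomicInteger directly, the running count reflects the live value whenever Spark's metrics system polls it, while the success and failed counters remain monotonic.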