diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java
index d34a3705d..72b96a091 100644
--- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java
+++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java
@@ -21,4 +21,49 @@ public class MetricConstants {
      * Similar to OS_READ_METRIC_PREFIX, this constant is used for categorizing and identifying metrics that pertain to write operations.
      */
     public static final String OS_WRITE_OP_METRIC_PREFIX = "opensearch.write";
+
+    /**
+     * Metric name for counting the errors encountered with Amazon S3 operations.
+     */
+    public static final String S3_ERR_CNT_METRIC = "s3.error.count";
+
+    /**
+     * Metric name for counting the number of sessions currently running.
+     */
+    public static final String REPL_RUNNING_METRIC = "session.running.count";
+
+    /**
+     * Metric name for counting the number of sessions that have failed.
+     */
+    public static final String REPL_FAILED_METRIC = "session.failed.count";
+
+    /**
+     * Metric name for counting the number of sessions that have successfully completed.
+     */
+    public static final String REPL_SUCCESS_METRIC = "session.success.count";
+
+    /**
+     * Metric name for tracking the processing time of sessions.
+     */
+    public static final String REPL_PROCESSING_TIME_METRIC = "session.processingTime";
+
+    /**
+     * Metric name for counting the number of statements currently running.
+     */
+    public static final String STATEMENT_RUNNING_METRIC = "statement.running.count";
+
+    /**
+     * Metric name for counting the number of statements that have failed.
+     */
+    public static final String STATEMENT_FAILED_METRIC = "statement.failed.count";
+
+    /**
+     * Metric name for counting the number of statements that have successfully completed.
+     */
+    public static final String STATEMENT_SUCCESS_METRIC = "statement.success.count";
+
+    /**
+     * Metric name for tracking the processing time of statements.
+     */
+    public static final String STATEMENT_PROCESSING_TIME_METRIC = "statement.processingTime";
 }
\ No newline at end of file
diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java
index 0edce3e36..13227a039 100644
--- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java
+++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java
@@ -6,6 +6,7 @@
 package org.opensearch.flint.core.metrics;
 
 import com.codahale.metrics.Counter;
+import com.codahale.metrics.Timer;
 import org.apache.spark.SparkEnv;
 import org.apache.spark.metrics.source.FlintMetricSource;
 import org.apache.spark.metrics.source.Source;
@@ -38,6 +39,47 @@ public static void incrementCounter(String metricName) {
         }
     }
 
+    /**
+     * Decrements the value of the specified metric counter by one, if the counter exists and its current count is greater than zero.
+     *
+     * @param metricName The name of the metric counter to be decremented.
+     */
+    public static void decrementCounter(String metricName) {
+        Counter counter = getOrCreateCounter(metricName);
+        if (counter != null && counter.getCount() > 0) {
+            counter.dec();
+        }
+    }
+
+    /**
+     * Retrieves a {@link Timer.Context} for the specified metric name, creating a new timer if one does not already exist.
+     * This context can be used to measure the duration of a particular operation or event.
+     *
+     * @param metricName The name of the metric timer to retrieve the context for.
+     * @return A {@link Timer.Context} instance for timing operations, or {@code null} if the timer could not be created or retrieved.
+     */
+    public static Timer.Context getTimerContext(String metricName) {
+        Timer timer = getOrCreateTimer(metricName);
+        if (timer != null) {
+            return timer.time();
+        }
+        return null;
+    }
+
+    /**
+     * Stops the timer associated with the given {@link Timer.Context}, recording the elapsed time since the timer was started
+     * and returning the duration. If the context is {@code null}, this method does nothing and returns {@code null}.
+     *
+     * @param context The {@link Timer.Context} to stop. May be {@code null}, in which case this method has no effect and returns {@code null}.
+     * @return The elapsed time in nanoseconds since the timer was started, or {@code null} if the context was {@code null}.
+     */
+    public static Long stopTimer(Timer.Context context) {
+        if (context != null) {
+            return context.stop();
+        }
+        return null;
+    }
+
     // Retrieves or creates a new counter for the given metric name
     private static Counter getOrCreateCounter(String metricName) {
         SparkEnv sparkEnv = SparkEnv.get();
@@ -54,6 +96,22 @@ private static Counter getOrCreateCounter(String metricName) {
         return counter;
     }
 
+    // Retrieves or creates a new Timer for the given metric name
+    private static Timer getOrCreateTimer(String metricName) {
+        SparkEnv sparkEnv = SparkEnv.get();
+        if (sparkEnv == null) {
+            LOG.warning("Spark environment not available, cannot instrument metric: " + metricName);
+            return null;
+        }
+
+        FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv);
+        Timer timer = flintMetricSource.metricRegistry().getTimers().get(metricName);
+        if (timer == null) {
+            timer = flintMetricSource.metricRegistry().timer(metricName);
+        }
+        return timer;
+    }
+
     // Gets or initializes the FlintMetricSource
     private static FlintMetricSource getOrInitFlintMetricSource(SparkEnv sparkEnv) {
         Seq metricSourceSeq = sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME());
diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java
index 8e646f446..09dee5eef 100644
--- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java
+++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java
@@ -1,5 +1,7 @@
 package org.opensearch.flint.core.metrics;
 
+import com.codahale.metrics.Counter;
+import com.codahale.metrics.Timer;
 import org.apache.spark.SparkEnv;
 import org.apache.spark.metrics.source.FlintMetricSource;
 import org.junit.Test;
@@ -7,6 +9,9 @@
 import org.mockito.MockedStatic;
 import org.mockito.Mockito;
 
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
@@ -18,7 +23,34 @@ public class MetricsUtilTest {
 
     @Test
-    public void incOpenSearchMetric() {
+    public void testIncrementDecrementCounter() {
+        try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) {
+            // Mock SparkEnv
+            SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS);
+            sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv);
+
+            // Mock FlintMetricSource
+            FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource());
+            when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head())
+                .thenReturn(flintMetricSource);
+
+            // Test the methods
+            String testMetric = "testPrefix.2xx.count";
+            MetricsUtil.incrementCounter(testMetric);
+            MetricsUtil.incrementCounter(testMetric);
+            MetricsUtil.decrementCounter(testMetric);
+
+            // Verify interactions
+            verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
+            verify(flintMetricSource, times(4)).metricRegistry();
+            Counter counter = flintMetricSource.metricRegistry().getCounters().get(testMetric);
+            Assertions.assertNotNull(counter);
+            Assertions.assertEquals(1, counter.getCount());
+        }
+    }
+
+    @Test
+    public void testStartStopTimer() {
         try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) {
             // Mock SparkEnv
             SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS);
@@ -29,14 +61,21 @@ public void incOpenSearchMetric() {
             when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head())
                 .thenReturn(flintMetricSource);
 
-            // Test the method
-            MetricsUtil.incrementCounter("testPrefix.2xx.count");
+            // Test the methods
+            String testMetric = "testPrefix.processingTime";
+            Timer.Context context = MetricsUtil.getTimerContext(testMetric);
+            TimeUnit.MILLISECONDS.sleep(500);
+            MetricsUtil.stopTimer(context);
 
             // Verify interactions
             verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
             verify(flintMetricSource, times(2)).metricRegistry();
-            Assertions.assertNotNull(
-                flintMetricSource.metricRegistry().getCounters().get("testPrefix.2xx.count"));
+            Timer timer = flintMetricSource.metricRegistry().getTimers().get(testMetric);
+            Assertions.assertNotNull(timer);
+            Assertions.assertEquals(1L, timer.getCount());
+            assertEquals(1.9, timer.getMeanRate(), 0.1);
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
         }
     }
 }
\ No newline at end of file
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index 4a3c03d9b..4aeb0db17 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -13,6 +13,8 @@ import scala.concurrent.duration.{Duration, MINUTES}
 import com.amazonaws.services.s3.model.AmazonS3Exception
 import org.opensearch.flint.core.FlintClient
 import org.opensearch.flint.core.metadata.FlintMetadata
+import org.opensearch.flint.core.metrics.MetricConstants
+import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
 import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue}
 
 import org.apache.spark.{SparkConf, SparkException}
@@ -401,6 +403,7 @@ trait FlintJobExecutor {
       case r: ParseException =>
         handleQueryException(r, "Syntax error", spark, dataSource, query, queryId, sessionId)
       case r: AmazonS3Exception =>
+        incrementCounter(MetricConstants.S3_ERR_CNT_METRIC)
         handleQueryException(
           r,
           "Fail to read data from S3. Cause",
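
For context on the `incrementCounter(MetricConstants.S3_ERR_CNT_METRIC)` call above: `MetricsUtil` bottoms out in Dropwizard's `MetricRegistry`, where counter lookups by name are idempotent, so independent error handlers can bump the same metric without coordination. A minimal standalone sketch of that behavior (the class name is illustrative, and the local registry stands in for the one owned by `FlintMetricSource`):

```java
import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricRegistry;

public class CounterSketch {
    public static void main(String[] args) {
        MetricRegistry registry = new MetricRegistry();

        // counter(name) is idempotent: every lookup under the same name
        // returns the same Counter, so scattered call sites (e.g. several
        // exception handlers) all update a single metric.
        Counter errors = registry.counter("s3.error.count");
        registry.counter("s3.error.count").inc();
        registry.counter("s3.error.count").inc();
        errors.dec(); // mirrors MetricsUtil.decrementCounter()

        System.out.println(errors.getCount()); // prints 1
    }
}
```

Note that `MetricsUtil.decrementCounter` additionally guards on `getCount() > 0`; a raw `Counter.dec()` has no floor and will happily drive the count negative.
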
Cause", diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 2a63653e3..093ce1932 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -13,12 +13,15 @@ import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal +import com.codahale.metrics.Timer import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.app.FlintInstance.formats import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder @@ -117,7 +120,14 @@ object FlintREPL extends Logging with FlintJobExecutor { conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) - addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) + val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) + + addShutdownHook( + flintSessionIndexUpdater, + osClient, + sessionIndex.get, + sessionId.get, + sessionTimerContext) // 1 thread for updating heart beat val threadPool = @@ -165,6 +175,7 @@ object FlintREPL extends Logging with FlintJobExecutor { exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { queryLoop(commandContext) } + recordSessionSuccess(sessionTimerContext) } catch { case e: Exception => handleSessionError( @@ -175,13 +186,14 @@ object FlintREPL extends Logging with FlintJobExecutor { jobStartTime, flintSessionIndexUpdater, osClient, - sessionIndex.get) + sessionIndex.get, + sessionTimerContext) } finally { if (threadPool != null) { heartBeatFuture.cancel(true) // Pass `true` to interrupt if running threadPoolFactory.shutdownThreadPool(threadPool) } - + stopTimer(sessionTimerContext) spark.stop() // Check for non-daemon threads that may prevent the driver from shutting down. 
@@ -374,6 +386,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       FlintInstance.serializeWithoutJobId(flintJob, currentTime)
     }
     flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance)
+    incrementCounter(MetricConstants.REPL_RUNNING_METRIC)
     logInfo(
       s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""")
   }
@@ -386,7 +399,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
       jobStartTime: Long,
       flintSessionIndexUpdater: OpenSearchUpdater,
       osClient: OSClient,
-      sessionIndex: String): Unit = {
+      sessionIndex: String,
+      flintSessionContext: Timer.Context): Unit = {
     val error = s"Session error: ${e.getMessage}"
     logError(error, e)
 
@@ -394,6 +408,9 @@ object FlintREPL extends Logging with FlintJobExecutor {
       .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error))
 
     updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId)
+    if (flintInstance.state.equals("fail")) {
+      recordSessionFailed(flintSessionContext)
+    }
   }
 
   private def getExistingFlintInstance(
@@ -520,6 +537,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
       } else if (!flintReader.hasNext) {
         canProceed = false
       } else {
+        val statementTimerContext = getTimerContext(
+          MetricConstants.STATEMENT_PROCESSING_TIME_METRIC)
         val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater)
 
         val (dataToWrite, returnedVerificationResult) = processStatementOnVerification(
@@ -540,7 +559,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
           flintCommand,
           resultIndex,
           flintSessionIndexUpdater,
-          osClient)
+          osClient,
+          statementTimerContext)
 
         // last query finish time is last activity time
         lastActivityTime = currentTimeProvider.currentEpochMillis()
       }
@@ -567,7 +587,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
       flintCommand: FlintCommand,
       resultIndex: String,
       flintSessionIndexUpdater: OpenSearchUpdater,
-      osClient: OSClient): Unit = {
+      osClient: OSClient,
+      statementContext: Timer.Context): Unit = {
     try {
       dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
       if (flintCommand.isRunning() || flintCommand.isWaiting()) {
@@ -575,6 +596,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
         flintCommand.complete()
       }
       updateSessionIndex(flintCommand, flintSessionIndexUpdater)
+      recordStatementStateChange(flintCommand, statementContext)
     } catch {
       // e.g., maybe due to authentication service connection issue
       // or invalid catalog (e.g., we are operating on data not defined in provided data source)
@@ -583,6 +605,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
         logError(error, e)
         flintCommand.fail()
         updateSessionIndex(flintCommand, flintSessionIndexUpdater)
+        recordStatementStateChange(flintCommand, statementContext)
     }
   }
 
@@ -778,6 +801,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
     flintCommand.running()
     logDebug(s"command running: $flintCommand")
     updateSessionIndex(flintCommand, flintSessionIndexUpdater)
+    incrementCounter(MetricConstants.STATEMENT_RUNNING_METRIC)
     flintCommand
   }
 
@@ -831,6 +855,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       osClient: OSClient,
       sessionIndex: String,
       sessionId: String,
+      flintSessionContext: Timer.Context,
       shutdownHookManager: ShutdownHookManagerTrait = DefaultShutdownHookManager): Unit = {
 
     shutdownHookManager.addShutdownHook(() => {
@@ -859,7 +884,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
         source,
         getResponse,
         flintSessionIndexUpdater,
-        sessionId)
+        sessionId,
+        flintSessionContext)
     }
   })
 }
 
@@ -868,7 +894,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
       source: java.util.Map[String, AnyRef],
       getResponse: GetResponse,
       flintSessionIndexUpdater: OpenSearchUpdater,
-      sessionId: String): Unit = {
+      sessionId: String,
+      flintSessionContext: Timer.Context): Unit = {
     val flintInstance = FlintInstance.deserializeFromMap(source)
     flintInstance.state = "dead"
     flintSessionIndexUpdater.updateIf(
@@ -878,6 +905,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
         currentTimeProvider.currentEpochMillis()),
       getResponse.getSeqNo,
       getResponse.getPrimaryTerm)
+    recordSessionSuccess(flintSessionContext)
   }
 
   /**
@@ -1039,4 +1067,28 @@ object FlintREPL extends Logging with FlintJobExecutor {
 
     result.getOrElse(throw new RuntimeException("Failed after retries"))
   }
+
+  private def recordSessionSuccess(sessionContext: Timer.Context): Unit = {
+    stopTimer(sessionContext)
+    decrementCounter(MetricConstants.REPL_RUNNING_METRIC)
+    incrementCounter(MetricConstants.REPL_SUCCESS_METRIC)
+  }
+
+  private def recordSessionFailed(sessionContext: Timer.Context): Unit = {
+    stopTimer(sessionContext)
+    decrementCounter(MetricConstants.REPL_RUNNING_METRIC)
+    incrementCounter(MetricConstants.REPL_FAILED_METRIC)
+  }
+
+  private def recordStatementStateChange(
+      flintCommand: FlintCommand,
+      statementContext: Timer.Context): Unit = {
+    stopTimer(statementContext)
+    decrementCounter(MetricConstants.STATEMENT_RUNNING_METRIC)
+    if (flintCommand.isComplete()) {
+      incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC)
+    } else if (flintCommand.isFailed()) {
+      incrementCounter(MetricConstants.STATEMENT_FAILED_METRIC)
+    }
+  }
 }
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 3e9d408e6..abae546b6 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -15,6 +15,7 @@ import scala.concurrent.duration._
 import scala.concurrent.duration.{Duration, MINUTES}
 import scala.reflect.runtime.universe.TypeTag
 
+import com.codahale.metrics.Timer
 import org.mockito.ArgumentMatchers.{eq => eqTo, _}
 import org.mockito.ArgumentMatchersSugar
 import org.mockito.Mockito._
@@ -82,6 +83,7 @@ class FlintREPLTest
     val getResponse = mock[GetResponse]
     val sessionIndex = "testIndex"
     val sessionId = "testSessionId"
+    val flintSessionContext = mock[Timer.Context]
 
     when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(true)
@@ -110,6 +112,7 @@ class FlintREPLTest
       osClient,
       sessionIndex,
       sessionId,
+      flintSessionContext,
       mockShutdownHookManager)
 
     verify(flintSessionIndexUpdater).updateIf(*, *, *, *)
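
For context on the timer plumbing used throughout FlintREPL: `getTimerContext` and `stopTimer` are null-safe wrappers around Dropwizard's `Timer.Context` lifecycle. A minimal sketch, again assuming a plain local `MetricRegistry` and an illustrative metric name:

```java
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;

public class TimerSketch {
    public static void main(String[] args) throws InterruptedException {
        MetricRegistry registry = new MetricRegistry();
        Timer timer = registry.timer("session.processingTime");

        Timer.Context context = timer.time(); // start the clock
        Thread.sleep(100);                    // the work being measured
        long elapsedNanos = context.stop();   // records one sample, returns nanos

        System.out.println(timer.getCount());             // 1 sample recorded
        System.out.println(elapsedNanos >= 100_000_000L); // true: at least 100 ms
    }
}
```

Because `Timer.Context.stop()` records a sample on every call, a context stopped both inside `recordSessionSuccess` and again by the `finally` block's `stopTimer` will record two samples for a single session; worth keeping in mind when reading the session processing-time histogram.
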