From 315352cc7fd4255b169bc167a936810ae70f2f69 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 16 Nov 2023 15:23:07 -0800 Subject: [PATCH 1/6] Improve Flint index monitor error handling (#158) * Don't throw exception and change to daemon thread pool Signed-off-by: Chen Dai * Add more IT Signed-off-by: Chen Dai * Add more IT Signed-off-by: Chen Dai * Add more IT Signed-off-by: Chen Dai --------- Signed-off-by: Chen Dai --- .../org/apache/spark/sql/flint/package.scala | 20 ++- .../flint/spark/FlintSparkIndexMonitor.scala | 9 +- .../spark/FlintSparkIndexMonitorITSuite.scala | 166 ++++++++++++++++++ 3 files changed, 190 insertions(+), 5 deletions(-) create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala index eba99b809..cf2cd2b6e 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala @@ -5,15 +5,33 @@ package org.apache.spark.sql +import java.util.concurrent.ScheduledExecutorService + import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.util.ShutdownHookManager +import org.apache.spark.util.{ShutdownHookManager, ThreadUtils} /** * Flint utility methods that rely on access to private code in Spark SQL package. */ package object flint { + /** + * Create daemon thread pool with the given thread group name and size. + * + * @param threadNamePrefix + * thread group name + * @param numThreads + * thread pool size + * @return + * thread pool executor + */ + def newDaemonThreadPoolScheduledExecutor( + threadNamePrefix: String, + numThreads: Int): ScheduledExecutorService = { + ThreadUtils.newDaemonThreadPoolScheduledExecutor(threadNamePrefix, numThreads) + } + /** * Add shutdown hook to SparkContext with default priority. * diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index 28e46cb29..5c4c7376c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import java.util.concurrent.{Executors, ScheduledExecutorService, ScheduledFuture, TimeUnit} +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} import scala.collection.concurrent.{Map, TrieMap} import scala.sys.addShutdownHook @@ -15,6 +15,7 @@ import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor /** * Flint Spark index state monitor. 
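What "change to daemon thread pool" buys: Executors.newScheduledThreadPool creates non-daemon threads, so a still-scheduled heartbeat task can keep the JVM alive after the Spark job finishes, whereas the daemon pool lets the application exit. A minimal sketch of the new wrapper in use (the println task and the 60-second interval are illustrative assumptions, not part of the patch):

import java.util.concurrent.TimeUnit
import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor

// Daemon threads do not block JVM shutdown even while a task is still scheduled.
val heartbeatPool = newDaemonThreadPoolScheduledExecutor("flint-index-heartbeat", 1)
heartbeatPool.scheduleWithFixedDelay(
  () => println("update index log entry"), // stand-in for the real monitor task
  15, 60, TimeUnit.SECONDS)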
@@ -62,9 +63,8 @@ class FlintSparkIndexMonitor( logInfo("Index monitor task is cancelled") } } catch { - case e: Exception => + case e: Throwable => logError("Failed to update index log entry", e) - throw new IllegalStateException("Failed to update index log entry") } }, 15, // Delay to ensure final logging is complete first, otherwise version conflicts @@ -100,7 +100,8 @@ object FlintSparkIndexMonitor extends Logging { * Thread-safe ExecutorService globally shared by all FlintSpark instance and will be shutdown * in Spark application upon exit. Non-final variable for test convenience. */ - var executor: ScheduledExecutorService = Executors.newScheduledThreadPool(1) + var executor: ScheduledExecutorService = + newDaemonThreadPoolScheduledExecutor("flint-index-heartbeat", 1) /** * Tracker that stores task future handle which is required to cancel the task in future. diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala new file mode 100644 index 000000000..4af147939 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import java.util.Base64 +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{doAnswer, spy} +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest +import org.opensearch.client.RequestOptions +import org.opensearch.flint.OpenSearchTransactionSuite +import org.opensearch.flint.spark.FlintSpark.RefreshMode._ +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor + +class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_index_monitor_test" + private val testFlintIndex = getSkippingIndexName(testTable) + private val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) + + override def beforeAll(): Unit = { + super.beforeAll() + createPartitionedTable(testTable) + + // Replace mock executor with real one and change its delay + val realExecutor = newDaemonThreadPoolScheduledExecutor("flint-index-heartbeat", 1) + FlintSparkIndexMonitor.executor = spy(realExecutor) + doAnswer(invocation => { + // Delay 5 seconds to wait for refresh index done + realExecutor.scheduleWithFixedDelay(invocation.getArgument(0), 5, 1, TimeUnit.SECONDS) + }).when(FlintSparkIndexMonitor.executor) + .scheduleWithFixedDelay(any[Runnable], any[Long], any[Long], any[TimeUnit]) + } + + override def beforeEach(): Unit = { + super.beforeEach() + flint + .skippingIndex() + .onTable(testTable) + .addValueSet("name") + .options(FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + .create() + flint.refreshIndex(testFlintIndex, INCREMENTAL) + + // Wait for refresh complete and another 5 seconds to make sure monitor thread start + val jobId = spark.streams.active.find(_.name == testFlintIndex).get.id.toString + awaitStreamingComplete(jobId) + Thread.sleep(5000L) + } + + 
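  // The spy configured in beforeAll intercepts scheduleWithFixedDelay and re-schedules the real
  // monitor task with a 5-second initial delay, so the streaming refresh started in beforeEach
  // can finish before the first heartbeat runs. FlintSparkIndexMonitor.executor is a non-final
  // var precisely to allow this kind of replacement in tests.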
override def afterEach(): Unit = { + // Cancel task to avoid conflict with delete operation since it runs frequently + FlintSparkIndexMonitor.indexMonitorTracker.values.foreach(_.cancel(true)) + FlintSparkIndexMonitor.indexMonitorTracker.clear() + + try { + flint.deleteIndex(testFlintIndex) + } catch { + // Index maybe end up with failed state in some test + case _: IllegalStateException => + openSearchClient + .indices() + .delete(new DeleteIndexRequest(testFlintIndex), RequestOptions.DEFAULT) + } finally { + super.afterEach() + } + } + + test("job start time should not change and last update time should keep updated") { + var (prevJobStartTime, prevLastUpdateTime) = getLatestTimestamp + 3 times { (jobStartTime, lastUpdateTime) => + jobStartTime shouldBe prevJobStartTime + lastUpdateTime should be > prevLastUpdateTime + prevLastUpdateTime = lastUpdateTime + } + } + + test("job start time should not change until recover index") { + val (prevJobStartTime, _) = getLatestTimestamp + + // Stop streaming job and wait for monitor task stopped + spark.streams.active.find(_.name == testFlintIndex).get.stop() + waitForMonitorTaskRun() + + // Restart streaming job and monitor task + flint.recoverIndex(testFlintIndex) + waitForMonitorTaskRun() + + val (jobStartTime, _) = getLatestTimestamp + jobStartTime should be > prevJobStartTime + } + + test("monitor task should terminate if streaming job inactive") { + val task = FlintSparkIndexMonitor.indexMonitorTracker(testFlintIndex) + + // Stop streaming job and wait for monitor task stopped + spark.streams.active.find(_.name == testFlintIndex).get.stop() + waitForMonitorTaskRun() + + // Index state transit to failed and task is cancelled + latestLogEntry(testLatestId) should contain("state" -> "failed") + task.isCancelled shouldBe true + } + + test("monitor task should not terminate if any exception") { + // Block write on metadata log index + setWriteBlockOnMetadataLogIndex(true) + waitForMonitorTaskRun() + + // Monitor task should stop working after blocking writes + var (_, prevLastUpdateTime) = getLatestTimestamp + 1 times { (_, lastUpdateTime) => + lastUpdateTime shouldBe prevLastUpdateTime + } + + // Unblock write and wait for monitor task attempt to update again + setWriteBlockOnMetadataLogIndex(false) + waitForMonitorTaskRun() + + // Monitor task continue working after unblocking write + 3 times { (_, lastUpdateTime) => + lastUpdateTime should be > prevLastUpdateTime + prevLastUpdateTime = lastUpdateTime + } + } + + private def getLatestTimestamp: (Long, Long) = { + val latest = latestLogEntry(testLatestId) + (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long]) + } + + private implicit class intWithTimes(n: Int) { + def times(f: (Long, Long) => Unit): Unit = { + 1 to n foreach { _ => + { + waitForMonitorTaskRun() + + val (jobStartTime, lastUpdateTime) = getLatestTimestamp + f(jobStartTime, lastUpdateTime) + } + } + } + } + + private def waitForMonitorTaskRun(): Unit = { + // Interval longer than monitor schedule to make sure it has finished another run + Thread.sleep(3000L) + } + + private def setWriteBlockOnMetadataLogIndex(isBlock: Boolean): Unit = { + val request = new UpdateSettingsRequest(testMetaLogIndex) + .settings(Map("blocks.write" -> isBlock).asJava) // Blocking write operations + openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT) + } +} From 8dc9584588f9a8ff812eff17e21c00df795b21bd Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Fri, 17 Nov 2023 14:50:17 -0800 Subject: [PATCH 2/6] 
Enhance Logging for Job Cleanup and Reduce REPL Inactivity Timeout (#160) * Enhance Logging for Streaming Job Cleanup and Reduce REPL Inactivity Timeout - Added detailed logging to improve visibility during streaming job cleanup. - Decreased REPL job inactivity timeout from 30 to 10 minutes.. Tested manually to ensure new logs are correctly displayed during streaming job cleanup. Signed-off-by: Kaituo Li * address Chen's comments Signed-off-by: Kaituo Li * add more logs and reuse JobOperator in FlintREPL Signed-off-by: Kaituo Li --------- Signed-off-by: Kaituo Li --- .../scala/org/apache/spark/sql/FlintJob.scala | 64 ++-------- .../apache/spark/sql/FlintJobExecutor.scala | 8 +- .../org/apache/spark/sql/FlintREPL.scala | 50 ++++---- .../org/apache/spark/sql/JobOperator.scala | 113 ++++++++++++++++++ .../org/apache/spark/sql/FlintREPLTest.scala | 2 +- 5 files changed, 163 insertions(+), 74 deletions(-) create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index cb1f5c1ca..750e228ef 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -8,9 +8,6 @@ package org.apache.spark.sql import java.util.Locale -import scala.concurrent.{ExecutionContext, Future, TimeoutException} -import scala.concurrent.duration.{Duration, MINUTES} - import org.opensearch.client.{RequestOptions, RestHighLevelClient} import org.opensearch.cluster.metadata.MappingMetadata import org.opensearch.common.settings.Settings @@ -22,9 +19,7 @@ import play.api.libs.json._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{StructField, _} -import org.apache.spark.util.ThreadUtils /** * Spark SQL Application entrypoint @@ -48,51 +43,18 @@ object FlintJob extends Logging with FlintJobExecutor { val conf = createSparkConf() val wait = conf.get("spark.flint.job.type", "continue") val dataSource = conf.get("spark.flint.datasource.name", "") - val spark = createSparkSession(conf) - - val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index") - implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - - var dataToWrite: Option[DataFrame] = None - val startTime = System.currentTimeMillis() - // osClient needs spark session to be created first to get FlintOptions initialized. - // Otherwise, we will have connection exception from EMR-S to OS. 
- val osClient = new OSClient(FlintSparkConf().flintOptions()) - var exceptionThrown = true - try { - val futureMappingCheck = Future { - checkAndCreateIndex(osClient, resultIndex) - } - val data = executeQuery(spark, query, dataSource, "", "") - - val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) - dataToWrite = Some(mappingCheckResult match { - case Right(_) => data - case Left(error) => - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider) - }) - exceptionThrown = false - } catch { - case e: TimeoutException => - val error = s"Getting the mapping of index $resultIndex timed out" - logError(error, e) - dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) - case e: Exception => - val error = processQueryException(e, spark, dataSource, query, "", "") - dataToWrite = Some( - getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) - } finally { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) - // Stop SparkSession if streaming job succeeds - if (!exceptionThrown && wait.equalsIgnoreCase("streaming")) { - // wait if any child thread to finish before the main thread terminates - spark.streams.awaitAnyTermination() - } else { - spark.stop() - } - - threadPool.shutdown() - } + // https://github.com/opensearch-project/opensearch-spark/issues/138 + /* + * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, + * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), + * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. + * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. + * Without this setup, Spark would not recognize names in the format `my_glue1.default`. 
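   * For example, with spark.flint.datasource.name=my_glue1, the conf.set call below results in
   * spark.sql.defaultCatalog=my_glue1, so a reference like `my_glue1.default.http_logs_plain`
   * resolves to catalog `my_glue1`, database `default`, table `http_logs_plain`.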
+ */ + conf.set("spark.sql.defaultCatalog", dataSource) + + val jobOperator = + JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming")) + jobOperator.start() } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 6e7dbb926..903bcaa09 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -6,6 +6,10 @@ package org.apache.spark.sql import java.util.Locale +import java.util.concurrent.ThreadPoolExecutor + +import scala.concurrent.{ExecutionContext, Future, TimeoutException} +import scala.concurrent.duration.{Duration, MINUTES} import com.amazonaws.services.s3.model.AmazonS3Exception import org.opensearch.flint.core.FlintClient @@ -14,11 +18,13 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{createIndex, getFormattedData, isSuperset, logError, logInfo} +import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType} import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider} +import org.apache.spark.util.ThreadUtils trait FlintJobExecutor { this: Logging => diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index edf7d62e6..f21a01d53 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.ThreadUtils object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 30 * 60 * 1000 + private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000 private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 @@ -63,30 +63,38 @@ object FlintREPL extends Logging with FlintJobExecutor { val conf: SparkConf = createSparkConf() val dataSource = conf.get("spark.flint.datasource.name", "unknown") // https://github.com/opensearch-project/opensearch-spark/issues/138 + /* + * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, + * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), + * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. + * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. 
+ * Without this setup, Spark would not recognize names in the format `my_glue1.default`. + */ conf.set("spark.sql.defaultCatalog", dataSource) val wait = conf.get("spark.flint.job.type", "continue") - // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) - val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null)) - - if (sessionIndex.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.requestIndex is not set") - } - if (sessionId.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.sessionId is not set") - } - - val spark = createSparkSession(conf) - val osClient = new OSClient(FlintSparkConf().flintOptions()) - val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") - val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") if (wait.equalsIgnoreCase("streaming")) { logInfo(s"""streaming query ${query}""") - val result = executeQuery(spark, query, dataSource, "", "") - writeDataFrameToOpensearch(result, resultIndex, osClient) - spark.streams.awaitAnyTermination() + val jobOperator = + JobOperator(conf, query, dataSource, resultIndex, true) + jobOperator.start() } else { + // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. + val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) + val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null)) + + if (sessionIndex.isEmpty) { + throw new IllegalArgumentException("spark.flint.job.requestIndex is not set") + } + if (sessionId.isEmpty) { + throw new IllegalArgumentException("spark.flint.job.sessionId is not set") + } + + val spark = createSparkSession(conf) + val osClient = new OSClient(FlintSparkConf().flintOptions()) + val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") + val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS) @@ -99,7 +107,7 @@ object FlintREPL extends Logging with FlintJobExecutor { conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) - createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) + addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) // 1 thread for updating heart beat val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) val jobStartTime = currentTimeProvider.currentEpochMillis() @@ -767,7 +775,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintReader } - def createShutdownHook( + def addShutdownHook( flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, sessionIndex: String, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala new file mode 100644 index 000000000..c60d250ea --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.apache.spark.sql + +import java.util.concurrent.ThreadPoolExecutor + +import scala.concurrent.{ExecutionContext, Future, TimeoutException} +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success, Try} + +import org.opensearch.flint.core.storage.OpenSearchUpdater + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.createSparkSession +import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, updateFlintInstanceBeforeShutdown} +import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.util.ThreadUtils + +case class JobOperator( + sparkConf: SparkConf, + query: String, + dataSource: String, + resultIndex: String, + streaming: Boolean) + extends Logging + with FlintJobExecutor { + private val spark = createSparkSession(sparkConf) + + // jvm shutdown hook + sys.addShutdownHook(stop()) + + def start(): Unit = { + val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index") + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + var dataToWrite: Option[DataFrame] = None + val startTime = System.currentTimeMillis() + // osClient needs spark session to be created first to get FlintOptions initialized. + // Otherwise, we will have connection exception from EMR-S to OS. + val osClient = new OSClient(FlintSparkConf().flintOptions()) + var exceptionThrown = true + try { + val futureMappingCheck = Future { + checkAndCreateIndex(osClient, resultIndex) + } + val data = executeQuery(spark, query, dataSource, "", "") + + val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) + dataToWrite = Some(mappingCheckResult match { + case Right(_) => data + case Left(error) => + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider) + }) + exceptionThrown = false + } catch { + case e: TimeoutException => + val error = s"Getting the mapping of index $resultIndex timed out" + logError(error, e) + dataToWrite = Some( + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + case e: Exception => + val error = processQueryException(e, spark, dataSource, query, "", "") + dataToWrite = Some( + getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) + } finally { + cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient) + } + } + + def cleanUpResources( + exceptionThrown: Boolean, + threadPool: ThreadPoolExecutor, + dataToWrite: Option[DataFrame], + resultIndex: String, + osClient: OSClient): Unit = { + try { + dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + } catch { + case e: Exception => logError("fail to write to result index", e) + } + + try { + // Stop SparkSession if streaming job succeeds + if (!exceptionThrown && streaming) { + // wait if any child thread to finish before the main thread terminates + spark.streams.awaitAnyTermination() + } + } catch { + case e: Exception => logError("streaming job failed", e) + } + + try { + threadPool.shutdown() + logInfo("shut down thread threadpool") + } catch { + case e: Exception => logError("Fail to close threadpool", e) + } + } + + def stop(): Unit = { + Try { + spark.stop() + logInfo("stopped spark session") + } match { + case Success(_) => + case Failure(e) => logError("unexpected error while stopping spark session", e) + } + } +} diff --git 
a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 704045e8a..8335f2a72 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -117,7 +117,7 @@ class FlintREPLTest } // Here, we're injecting our mockShutdownHookManager into the method - FlintREPL.createShutdownHook( + FlintREPL.addShutdownHook( flintSessionIndexUpdater, osClient, sessionIndex, From 6f16801e101d108ecb594bf0196bde074c19228b Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Sat, 18 Nov 2023 18:10:20 -0800 Subject: [PATCH 3/6] Enhance Session Job Handling and Heartbeat Mechanism (#168) * Enhance Session Job Handling and Heartbeat Mechanism This commit introduces two key improvements: 1. An initial delay is added to start the heartbeat mechanism, reducing the risk of concurrent update conflicts. 2. The 'jobId' in the session store is updated at the beginning to properly manage and evict old jobs for the same session. Testing Performed: 1. Manual sanity testing was conducted. 2. It was ensured that old jobs eventually exit after a new job for the same session starts. 3. The functionality of the heartbeat mechanism was verified to still work effectively. Signed-off-by: Kaituo Li * Override only when --conf excludeJobIds is not empty and excludeJobIds does not contain currentJobId. Signed-off-by: Kaituo Li * reformat code Signed-off-by: Kaituo Li --------- Signed-off-by: Kaituo Li --- .../opensearch/flint/app/FlintInstance.scala | 63 +++++++++++++------ .../flint/app/FlintInstanceTest.scala | 4 +- .../org/apache/spark/sql/FlintREPL.scala | 50 ++++++++++----- .../org/apache/spark/sql/FlintREPLTest.scala | 3 +- 4 files changed, 82 insertions(+), 38 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala index f9e8dd693..5af70b793 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala @@ -109,23 +109,50 @@ object FlintInstance { maybeError) } - def serialize(job: FlintInstance, currentTime: Long): String = { - // jobId is only readable by spark, thus we don't override jobId - Serialization.write( - Map( - "type" -> "session", - "sessionId" -> job.sessionId, - "error" -> job.error.getOrElse(""), - "applicationId" -> job.applicationId, - "state" -> job.state, - // update last update time - "lastUpdateTime" -> currentTime, - // Convert a Seq[String] into a comma-separated string, such as "id1,id2". - // This approach is chosen over serializing to an array format (e.g., ["id1", "id2"]) - // because it simplifies client-side processing. With a comma-separated string, - // clients can easily ignore this field if it's not in use, avoiding the need - // for array parsing logic. This makes the serialized data more straightforward to handle. - "excludeJobIds" -> job.excludedJobIds.mkString(","), - "jobStartTime" -> job.jobStartTime)) + /** + * After the initial setup, the 'jobId' is only readable by Spark, and it should not be + * overridden. We use 'jobId' to ensure that only one job can run per session. In the case of a + * new job for the same session, it will override the 'jobId' in the session document. 
The old + * job will periodically check the 'jobId.' If the read 'jobId' does not match the current + * 'jobId,' the old job will exit early. Therefore, it is crucial that old jobs do not overwrite + * the session store's 'jobId' field after the initial setup. + * + * @param job + * Flint session object + * @param currentTime + * current timestamp in milliseconds + * @param includeJobId + * flag indicating whether to include the "jobId" field in the serialization + * @return + * serialized Flint session + */ + def serialize(job: FlintInstance, currentTime: Long, includeJobId: Boolean = true): String = { + val baseMap = Map( + "type" -> "session", + "sessionId" -> job.sessionId, + "error" -> job.error.getOrElse(""), + "applicationId" -> job.applicationId, + "state" -> job.state, + // update last update time + "lastUpdateTime" -> currentTime, + // Convert a Seq[String] into a comma-separated string, such as "id1,id2". + // This approach is chosen over serializing to an array format (e.g., ["id1", "id2"]) + // because it simplifies client-side processing. With a comma-separated string, + // clients can easily ignore this field if it's not in use, avoiding the need + // for array parsing logic. This makes the serialized data more straightforward to handle. + "excludeJobIds" -> job.excludedJobIds.mkString(","), + "jobStartTime" -> job.jobStartTime) + + val resultMap = if (includeJobId) { + baseMap + ("jobId" -> job.jobId) + } else { + baseMap + } + + Serialization.write(resultMap) + } + + def serializeWithoutJobId(job: FlintInstance, currentTime: Long): String = { + serialize(job, currentTime, includeJobId = false) } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala index 12c2ae5bc..8ece6ba8a 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala @@ -39,7 +39,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { 1620000001000L, excludedJobIds) val currentTime = System.currentTimeMillis() - val json = FlintInstance.serialize(instance, currentTime) + val json = FlintInstance.serializeWithoutJobId(instance, currentTime) json should include(""""applicationId":"app-123"""") json should not include (""""jobId":"job-456"""") @@ -80,7 +80,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { Seq.empty[String], Some("Some error occurred")) val currentTime = System.currentTimeMillis() - val json = FlintInstance.serialize(instance, currentTime) + val json = FlintInstance.serializeWithoutJobId(instance, currentTime) json should include(""""error":"Some error occurred"""") } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index f21a01d53..28ce90d62 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -48,6 +48,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 + val INITIAL_DELAY_MILLIS = 3000L def update(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { 
updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) @@ -121,7 +122,8 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId.get, threadPool, osClient, - sessionIndex.get) + sessionIndex.get, + INITIAL_DELAY_MILLIS) if (setupFlintJobWithExclusionCheck( conf, @@ -322,18 +324,25 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionIndex: String, jobStartTime: Long, excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { - val flintJob = - new FlintInstance( - applicationId, - jobId, - sessionId, - "running", - currentTimeProvider.currentEpochMillis(), - jobStartTime, - excludeJobIds) - flintSessionIndexUpdater.upsert( + val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) + val currentTime = currentTimeProvider.currentEpochMillis() + val flintJob = new FlintInstance( + applicationId, + jobId, sessionId, - FlintInstance.serialize(flintJob, currentTimeProvider.currentEpochMillis())) + "running", + currentTime, + jobStartTime, + excludeJobIds) + + val serializedFlintInstance = if (includeJobId) { + FlintInstance.serialize(flintJob, currentTime, true) + } else { + FlintInstance.serializeWithoutJobId(flintJob, currentTime) + } + + flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) + logDebug( s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") } @@ -391,7 +400,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val currentTime = currentTimeProvider.currentEpochMillis() flintSessionIndexUpdater.upsert( sessionId, - FlintInstance.serialize(flintInstance, currentTime)) + FlintInstance.serializeWithoutJobId(flintInstance, currentTime)) } /** @@ -816,7 +825,9 @@ object FlintREPL extends Logging with FlintJobExecutor { flintSessionIndexUpdater.updateIf( sessionId, - FlintInstance.serialize(flintInstance, currentTimeProvider.currentEpochMillis()), + FlintInstance.serializeWithoutJobId( + flintInstance, + currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, getResponse.getPrimaryTerm) } @@ -833,6 +844,8 @@ object FlintREPL extends Logging with FlintJobExecutor { * the thread pool. * @param osClient * the OpenSearch client. 
+ * @param initialDelayMillis + * the intial delay to start heartbeat */ def createHeartBeatUpdater( currentInterval: Long, @@ -840,7 +853,8 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId: String, threadPool: ScheduledExecutorService, osClient: OSClient, - sessionIndex: String): Unit = { + sessionIndex: String, + initialDelayMillis: Long): Unit = { threadPool.scheduleAtFixedRate( new Runnable { @@ -853,7 +867,9 @@ object FlintREPL extends Logging with FlintJobExecutor { flintInstance.state = "running" flintSessionUpdater.updateIf( sessionId, - FlintInstance.serialize(flintInstance, currentTimeProvider.currentEpochMillis()), + FlintInstance.serializeWithoutJobId( + flintInstance, + currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, getResponse.getPrimaryTerm) } @@ -867,7 +883,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } } }, - 0L, + initialDelayMillis, currentInterval, java.util.concurrent.TimeUnit.MILLISECONDS) } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 8335f2a72..7b9fcc140 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -81,7 +81,8 @@ class FlintREPLTest "session1", threadPool, osClient, - "sessionIndex") + "sessionIndex", + 0) // Verifications verify(osClient, atLeastOnce()).getDoc("sessionIndex", "session1") From e201f09e0a2a3a0403e0fc3ffc3c5c95d1e1924c Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 21 Nov 2023 13:35:11 -0800 Subject: [PATCH 4/6] Request index not exist handling (#169) * handle index not found and not avaiable exception Signed-off-by: Peng Huo * add refresh=wait_until when create metalog Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- .../opensearch/flint/core/FlintClient.java | 12 +++ .../metadata/log/FlintMetadataLogEntry.scala | 77 +++++++++++++ .../core/storage/FlintOpenSearchClient.java | 35 ++++-- .../storage/FlintOpenSearchMetadataLog.java | 30 ++++-- .../flint/core/storage/OpenSearchUpdater.java | 19 +++- .../opensearch/flint/spark/FlintSpark.scala | 2 +- .../flint/OpenSearchTransactionSuite.scala | 30 +++++- .../core/FlintOpenSearchClientSuite.scala | 8 +- .../flint/core/FlintTransactionITSuite.scala | 67 +++++++++++- .../flint/core/OpenSearchUpdaterSuite.scala | 102 ++++++++++++++++++ .../spark/FlintSparkIndexJobITSuite.scala | 14 ++- .../spark/FlintSparkTransactionITSuite.scala | 12 ++- .../org/apache/spark/sql/FlintREPL.scala | 8 +- 13 files changed, 378 insertions(+), 38 deletions(-) create mode 100644 integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index c1f5d78c1..6cdf5187d 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -28,6 +28,18 @@ public interface FlintClient { */ OptimisticTransaction startTransaction(String indexName, String dataSourceName); + /** + * + * Start a new optimistic transaction. + * + * @param indexName index name + * @param dataSourceName TODO: read from elsewhere in future + * @param forceInit forceInit create empty translog if not exist. 
+ * @return transaction handle + */ + OptimisticTransaction startTransaction(String indexName, String dataSourceName, + boolean forceInit); + /** * Create a Flint index with the metadata given. * diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala index fea9974c6..eb93c7fde 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.core.metadata.log import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.IndexState +import org.opensearch.index.seqno.SequenceNumbers.{UNASSIGNED_PRIMARY_TERM, UNASSIGNED_SEQ_NO} /** * Flint metadata log entry. This is temporary and will merge field in FlintMetadata here and move @@ -92,4 +93,80 @@ object FlintMetadataLogEntry { .getOrElse(IndexState.UNKNOWN) } } + + val QUERY_EXECUTION_REQUEST_MAPPING: String = + """{ + | "dynamic": false, + | "properties": { + | "version": { + | "type": "keyword" + | }, + | "type": { + | "type": "keyword" + | }, + | "state": { + | "type": "keyword" + | }, + | "statementId": { + | "type": "keyword" + | }, + | "applicationId": { + | "type": "keyword" + | }, + | "sessionId": { + | "type": "keyword" + | }, + | "sessionType": { + | "type": "keyword" + | }, + | "error": { + | "type": "text" + | }, + | "lang": { + | "type": "keyword" + | }, + | "query": { + | "type": "text" + | }, + | "dataSourceName": { + | "type": "keyword" + | }, + | "submitTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "jobId": { + | "type": "keyword" + | }, + | "lastUpdateTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "queryId": { + | "type": "keyword" + | }, + | "excludeJobIds": { + | "type": "keyword" + | } + | } + |}""".stripMargin + + val QUERY_EXECUTION_REQUEST_SETTINGS: String = + """{ + | "index": { + | "number_of_shards": "1", + | "auto_expand_replicas": "0-2", + | "number_of_replicas": "0" + | } + |}""".stripMargin + + def failLogEntry(dataSourceName: String, error: String): FlintMetadataLogEntry = + FlintMetadataLogEntry( + "", + UNASSIGNED_SEQ_NO, + UNASSIGNED_PRIMARY_TERM, + 0L, + IndexState.FAILED, + dataSourceName, + error) } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 8652f8092..92a749d86 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -44,14 +44,15 @@ import org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor; import org.opensearch.flint.core.metadata.FlintMetadata; import org.opensearch.flint.core.metadata.log.DefaultOptimisticTransaction; +import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry; import org.opensearch.flint.core.metadata.log.OptimisticTransaction; -import org.opensearch.flint.core.metadata.log.OptimisticTransaction.NoOptimisticTransaction; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; 
import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; import scala.Option; +import scala.Some; /** * Flint client implementation for OpenSearch storage. @@ -80,34 +81,48 @@ public FlintOpenSearchClient(FlintOptions options) { } @Override - public OptimisticTransaction startTransaction(String indexName, String dataSourceName) { + public OptimisticTransaction startTransaction(String indexName, String dataSourceName, + boolean forceInit) { LOG.info("Starting transaction on index " + indexName + " and data source " + dataSourceName); String metaLogIndexName = dataSourceName.isEmpty() ? META_LOG_NAME_PREFIX : META_LOG_NAME_PREFIX + "_" + dataSourceName; - try (RestHighLevelClient client = createClient()) { if (client.indices().exists(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { LOG.info("Found metadata log index " + metaLogIndexName); - return new DefaultOptimisticTransaction<>(dataSourceName, - new FlintOpenSearchMetadataLog(this, indexName, metaLogIndexName)); } else { - LOG.info("Metadata log index not found " + metaLogIndexName); - return new NoOptimisticTransaction<>(); + if (forceInit) { + createIndex(metaLogIndexName, FlintMetadataLogEntry.QUERY_EXECUTION_REQUEST_MAPPING(), + Some.apply(FlintMetadataLogEntry.QUERY_EXECUTION_REQUEST_SETTINGS())); + } else { + String errorMsg = "Metadata log index not found " + metaLogIndexName; + LOG.warning(errorMsg); + throw new IllegalStateException(errorMsg); + } } + return new DefaultOptimisticTransaction<>(dataSourceName, + new FlintOpenSearchMetadataLog(this, indexName, metaLogIndexName)); } catch (IOException e) { throw new IllegalStateException("Failed to check if index metadata log index exists " + metaLogIndexName, e); } } + @Override + public OptimisticTransaction startTransaction(String indexName, String dataSourceName) { + return startTransaction(indexName, dataSourceName, false); + } + @Override public void createIndex(String indexName, FlintMetadata metadata) { LOG.info("Creating Flint index " + indexName + " with metadata " + metadata); + createIndex(indexName, metadata.getContent(), metadata.indexSettings()); + } + + protected void createIndex(String indexName, String mapping, Option settings) { + LOG.info("Creating Flint index " + indexName); String osIndexName = toLowercase(indexName); try (RestHighLevelClient client = createClient()) { CreateIndexRequest request = new CreateIndexRequest(osIndexName); - request.mapping(metadata.getContent(), XContentType.JSON); - - Option settings = metadata.indexSettings(); + request.mapping(mapping, XContentType.JSON); if (settings.isDefined()) { request.settings(settings.get(), XContentType.JSON); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java index 07029d608..f51e8a628 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java @@ -5,12 +5,6 @@ package org.opensearch.flint.core.storage; -import static org.opensearch.action.support.WriteRequest.RefreshPolicy; - -import java.io.IOException; -import java.util.Base64; -import java.util.Optional; -import java.util.logging.Logger; import org.opensearch.OpenSearchException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.get.GetRequest; @@ -19,11 +13,20 @@ import 
org.opensearch.action.update.UpdateRequest; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; +import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.common.xcontent.XContentType; import org.opensearch.flint.core.FlintClient; import org.opensearch.flint.core.metadata.log.FlintMetadataLog; import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry; +import java.io.IOException; +import java.util.Base64; +import java.util.Optional; +import java.util.logging.Logger; + +import static java.util.logging.Level.SEVERE; +import static org.opensearch.action.support.WriteRequest.RefreshPolicy; + /** * Flint metadata log in OpenSearch store. For now use single doc instead of maintaining history * of metadata log. @@ -57,6 +60,11 @@ public FlintOpenSearchMetadataLog(FlintClient flintClient, String flintIndexName public FlintMetadataLogEntry add(FlintMetadataLogEntry logEntry) { // TODO: use single doc for now. this will be always append in future. FlintMetadataLogEntry latest; + if (!exists()) { + String errorMsg = "Flint Metadata Log index not found " + metaLogIndexName; + LOG.log(SEVERE, errorMsg); + throw new IllegalStateException(errorMsg); + } if (logEntry.id().isEmpty()) { latest = createLogEntry(logEntry); } else { @@ -108,6 +116,7 @@ private FlintMetadataLogEntry createLogEntry(FlintMetadataLogEntry logEntry) { new IndexRequest() .index(metaLogIndexName) .id(logEntryWithId.id()) + .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) .source(logEntryWithId.toJson(), XContentType.JSON), RequestOptions.DEFAULT)); } @@ -148,6 +157,15 @@ private FlintMetadataLogEntry writeLogEntry( } } + private boolean exists() { + LOG.info("Checking if Flint index exists " + metaLogIndexName); + try (RestHighLevelClient client = flintClient.createClient()) { + return client.indices().exists(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT); + } catch (IOException e) { + throw new IllegalStateException("Failed to check if Flint index exists " + metaLogIndexName, e); + } + } + @FunctionalInterface public interface CheckedFunction { R apply(T t) throws IOException; diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java index 58963ab74..4a6424512 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java @@ -4,14 +4,17 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; +import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.common.xcontent.XContentType; import org.opensearch.flint.core.FlintClient; -import org.opensearch.flint.core.FlintClientBuilder; -import org.opensearch.flint.core.FlintOptions; import java.io.IOException; +import java.util.logging.Level; +import java.util.logging.Logger; public class OpenSearchUpdater { + private static final Logger LOG = Logger.getLogger(OpenSearchUpdater.class.getName()); + private final String indexName; private final FlintClient flintClient; @@ -28,6 +31,7 @@ public void upsert(String id, String doc) { // also, failure to close the client causes the job to be stuck in the running state as the client resource // is not released. 
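        // The assertIndexExist guard introduced in this change makes upsert/update/updateIf fail
        // fast with a descriptive IllegalStateException when the request index has been deleted,
        // rather than failing deeper inside the OpenSearch update call.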
try (RestHighLevelClient client = flintClient.createClient()) { + assertIndexExist(client, indexName); UpdateRequest updateRequest = new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) @@ -44,6 +48,7 @@ public void upsert(String id, String doc) { public void update(String id, String doc) { try (RestHighLevelClient client = flintClient.createClient()) { + assertIndexExist(client, indexName); UpdateRequest updateRequest = new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) @@ -59,6 +64,7 @@ public void update(String id, String doc) { public void updateIf(String id, String doc, long seqNo, long primaryTerm) { try (RestHighLevelClient client = flintClient.createClient()) { + assertIndexExist(client, indexName); UpdateRequest updateRequest = new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) @@ -73,4 +79,13 @@ public void updateIf(String id, String doc, long seqNo, long primaryTerm) { id), e); } } + + private void assertIndexExist(RestHighLevelClient client, String indexName) throws IOException { + LOG.info("Checking if index exists " + indexName); + if (!client.indices().exists(new GetIndexRequest(indexName), RequestOptions.DEFAULT)) { + String errorMsg = "Index not found " + indexName; + LOG.log(Level.SEVERE, errorMsg); + throw new IllegalStateException(errorMsg); + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index e9331113a..47ade0f87 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -104,7 +104,7 @@ class FlintSpark(val spark: SparkSession) extends Logging { val metadata = index.metadata() try { flintClient - .startTransaction(indexName, dataSourceName) + .startTransaction(indexName, dataSourceName, true) .initialLog(latest => latest.state == EMPTY || latest.state == DELETED) .transientLog(latest => latest.copy(state = CREATING)) .finalLog(latest => latest.copy(state = ACTIVE)) diff --git a/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala index 1e7077799..ba9acffd1 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala @@ -14,9 +14,10 @@ import org.opensearch.action.get.GetRequest import org.opensearch.action.index.IndexRequest import org.opensearch.action.update.UpdateRequest import org.opensearch.client.RequestOptions -import org.opensearch.client.indices.CreateIndexRequest +import org.opensearch.client.indices.{CreateIndexRequest, GetIndexRequest} import org.opensearch.common.xcontent.XContentType import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.{QUERY_EXECUTION_REQUEST_MAPPING, QUERY_EXECUTION_REQUEST_SETTINGS} import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.IndexState import org.opensearch.flint.core.storage.FlintOpenSearchClient._ import org.opensearch.flint.spark.FlintSparkSuite @@ -39,13 +40,15 @@ trait OpenSearchTransactionSuite extends FlintSparkSuite { super.beforeEach() openSearchClient .indices() - .create(new CreateIndexRequest(testMetaLogIndex), RequestOptions.DEFAULT) + .create( + new CreateIndexRequest(testMetaLogIndex) 
+ .mapping(QUERY_EXECUTION_REQUEST_MAPPING, XContentType.JSON) + .settings(QUERY_EXECUTION_REQUEST_SETTINGS, XContentType.JSON), + RequestOptions.DEFAULT) } override def afterEach(): Unit = { - openSearchClient - .indices() - .delete(new DeleteIndexRequest(testMetaLogIndex), RequestOptions.DEFAULT) + deleteIndex(testMetaLogIndex) super.afterEach() } @@ -71,4 +74,21 @@ trait OpenSearchTransactionSuite extends FlintSparkSuite { .doc(latest.copy(state = newState).toJson, XContentType.JSON), RequestOptions.DEFAULT) } + + def deleteIndex(indexName: String): Unit = { + if (openSearchClient + .indices() + .exists(new GetIndexRequest(indexName), RequestOptions.DEFAULT)) { + openSearchClient + .indices() + .delete(new DeleteIndexRequest(indexName), RequestOptions.DEFAULT) + } + } + + def indexMapping(): String = { + val response = + openSearchClient.indices.get(new GetIndexRequest(testMetaLogIndex), RequestOptions.DEFAULT) + + response.getMappings.get(testMetaLogIndex).source().toString + } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala index 9a762d9d6..7da67051d 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala @@ -16,7 +16,6 @@ import org.opensearch.client.opensearch.OpenSearchClient import org.opensearch.client.transport.rest_client.RestClientTransport import org.opensearch.flint.OpenSearchSuite import org.opensearch.flint.core.metadata.FlintMetadata -import org.opensearch.flint.core.metadata.log.OptimisticTransaction.NoOptimisticTransaction import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchScrollReader} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -31,9 +30,10 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M behavior of "Flint OpenSearch client" - it should "start no optimistic transaction if metadata log index doesn't exists" in { - val transaction = flintClient.startTransaction("test", "non-exist-index") - transaction shouldBe a[NoOptimisticTransaction[AnyRef]] + it should "throw IllegalStateException if metadata log index doesn't exists" in { + the[IllegalStateException] thrownBy { + flintClient.startTransaction("test", "non-exist-index") + } } it should "create index successfully" in { diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala index a8b5a1fa2..fa072898b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala @@ -9,6 +9,8 @@ import java.util.Base64 import scala.collection.JavaConverters.mapAsJavaMapConverter +import org.json4s.{Formats, NoTypeHints} +import org.json4s.native.{JsonMethods, Serialization} import org.opensearch.flint.OpenSearchTransactionSuite import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._ @@ -214,7 +216,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { latestLogEntry(testLatestId) should contain("state" -> "active") } - test("should not necessarily rollback if transaction operation failed but no transient 
action") { + test( + "should not necessarily rollback if transaction operation failed but no transient action") { // Use create index scenario in this test case the[IllegalStateException] thrownBy { flintClient @@ -227,4 +230,66 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { // Should rollback to initial empty log latestLogEntry(testLatestId) should contain("state" -> "empty") } + + test("forceInit translog, even index is deleted before startTransaction") { + deleteIndex(testMetaLogIndex) + flintClient + .startTransaction(testFlintIndex, testDataSourceName, true) + .initialLog(latest => { + latest.id shouldBe testLatestId + latest.state shouldBe EMPTY + latest.createTime shouldBe 0L + latest.dataSource shouldBe testDataSourceName + latest.error shouldBe "" + true + }) + .finalLog(latest => latest) + .commit(_ => {}) + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + (JsonMethods.parse(indexMapping()) \ "properties" \ "sessionId" \ "type") + .extract[String] should equal("keyword") + } + + test("should fail if index is deleted before initial operation") { + the[IllegalStateException] thrownBy { + flintClient + .startTransaction(testFlintIndex, testDataSourceName) + .initialLog(latest => { + deleteIndex(testMetaLogIndex) + true + }) + .transientLog(latest => latest.copy(state = CREATING)) + .finalLog(latest => latest.copy(state = ACTIVE)) + .commit(_ => {}) + } + } + + test("should fail if index is deleted before transient operation") { + the[IllegalStateException] thrownBy { + flintClient + .startTransaction(testFlintIndex, testDataSourceName) + .initialLog(latest => true) + .transientLog(latest => { + deleteIndex(testMetaLogIndex) + latest.copy(state = CREATING) + }) + .finalLog(latest => latest.copy(state = ACTIVE)) + .commit(_ => {}) + } + } + + test("should fail if index is deleted before final operation") { + the[IllegalStateException] thrownBy { + flintClient + .startTransaction(testFlintIndex, testDataSourceName) + .initialLog(latest => true) + .transientLog(latest => { latest.copy(state = CREATING) }) + .finalLog(latest => { + deleteIndex(testMetaLogIndex) + latest.copy(state = ACTIVE) + }) + .commit(_ => {}) + } + } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala new file mode 100644 index 000000000..3b317a0fe --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.opensearch.action.get.{GetRequest, GetResponse} +import org.opensearch.client.RequestOptions +import org.opensearch.flint.OpenSearchTransactionSuite +import org.opensearch.flint.app.FlintInstance +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater} +import org.scalatest.matchers.should.Matchers + +class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { + val sessionId = "sessionId" + val timestamp = 1700090926955L + val flintJob = + new FlintInstance( + "applicationId", + "jobId", + sessionId, + "running", + timestamp, + timestamp, + Seq("")) + var flintClient: FlintClient = _ + var updater: OpenSearchUpdater = _ + + override def beforeAll(): Unit = { + super.beforeAll() + flintClient = new FlintOpenSearchClient(new 
FlintOptions(openSearchOptions.asJava)); + updater = new OpenSearchUpdater( + testMetaLogIndex, + new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava))) + } + + test("upsert flintJob should success") { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + getFlintInstance(sessionId)._2.lastUpdateTime shouldBe timestamp + } + + test("index is deleted when upsert flintJob should throw IllegalStateException") { + deleteIndex(testMetaLogIndex) + + the[IllegalStateException] thrownBy { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + } + } + + test("update flintJob should success") { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + + val newTimestamp = 1700090926956L + updater.update(sessionId, FlintInstance.serialize(flintJob, newTimestamp)) + getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp + } + + test("index is deleted when update flintJob should throw IllegalStateException") { + deleteIndex(testMetaLogIndex) + + the[IllegalStateException] thrownBy { + updater.update(sessionId, FlintInstance.serialize(flintJob, timestamp)) + } + } + + test("updateIf flintJob should success") { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + val (resp, latest) = getFlintInstance(sessionId) + + val newTimestamp = 1700090926956L + updater.updateIf( + sessionId, + FlintInstance.serialize(latest, newTimestamp), + resp.getSeqNo, + resp.getPrimaryTerm) + getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp + } + + test("index is deleted when updateIf flintJob should throw IllegalStateException") { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + val (resp, latest) = getFlintInstance(sessionId) + + deleteIndex(testMetaLogIndex) + + the[IllegalStateException] thrownBy { + updater.updateIf( + sessionId, + FlintInstance.serialize(latest, timestamp), + resp.getSeqNo, + resp.getPrimaryTerm) + } + } + + def getFlintInstance(docId: String): (GetResponse, FlintInstance) = { + val response = + openSearchClient.get(new GetRequest(testMetaLogIndex, docId), RequestOptions.DEFAULT) + (response, FlintInstance.deserializeFromMap(response.getSourceAsMap)) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala index 365aab83d..8df2bc472 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala @@ -29,8 +29,18 @@ class FlintSparkIndexJobITSuite extends OpenSearchTransactionSuite with Matchers } override def afterEach(): Unit = { - super.afterEach() // must clean up metadata log first and then delete - flint.deleteIndex(testIndex) + + /** + * Todo, if state is not valid, will throw IllegalStateException. Should check flint + * .isRefresh before cleanup resource. Current solution, (1) try to delete flint index, (2) if + * failed, delete index itself. 
+ */ + try { + flint.deleteIndex(testIndex) + } catch { + case _: IllegalStateException => deleteIndex(testIndex) + } + super.afterEach() } test("recover should exit if index doesn't exist") { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala index 294449a48..56227533a 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala @@ -33,8 +33,18 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match } override def afterEach(): Unit = { + + /** + * Todo, if state is not valid, will throw IllegalStateException. Should check flint + * .isRefresh before cleanup resource. Current solution, (1) try to delete flint index, (2) if + * failed, delete index itself. + */ + try { + flint.deleteIndex(testFlintIndex) + } catch { + case _: IllegalStateException => deleteIndex(testFlintIndex) + } super.afterEach() - flint.deleteIndex(testFlintIndex) } test("create index") { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 28ce90d62..0f6c21786 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -6,14 +6,10 @@ package org.apache.spark.sql import java.net.ConnectException -import java.time.Instant -import java.util.Map import java.util.concurrent.ScheduledExecutorService -import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} -import scala.concurrent.duration._ -import scala.concurrent.duration.{Duration, MINUTES} +import scala.concurrent.duration.{Duration, MINUTES, _} import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal @@ -46,7 +42,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000 private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) - private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES) + private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 val INITIAL_DELAY_MILLIS = 3000L From 2f295ee6bd18eb42bb2465f4ea8cb1a61e7fb1bc Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Mon, 20 Nov 2023 18:11:17 -0800 Subject: [PATCH 5/6] Metrics Addition Signed-off-by: Vamsi Manohar --- build.sbt | 10 +- .../spark/metrics/sink/CloudWatchSink.java | 261 ++++++ .../DimensionedCloudWatchReporter.java | 819 ++++++++++++++++++ .../InvalidMetricsPropertyException.java | 20 + .../metrics/sink/CloudWatchSinkTests.java | 84 ++ .../DimensionedCloudWatchReporterTest.java | 544 ++++++++++++ project/plugins.sbt | 2 + 7 files changed, 1739 insertions(+), 1 deletion(-) create mode 100644 flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java create mode 100644 flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java create mode 100644 flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/InvalidMetricsPropertyException.java create mode 100644 
flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java create mode 100644 flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java diff --git a/build.sbt b/build.sbt index 938f19a64..ccb735ae0 100644 --- a/build.sbt +++ b/build.sbt @@ -63,7 +63,15 @@ lazy val flintCore = (project in file("flint-core")) "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test", - "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test"), + "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test", + "org.mockito" % "mockito-core" % "2.23.0" % "test", + "org.mockito" % "mockito-junit-jupiter" % "3.12.4" % "test", + "org.junit.jupiter" % "junit-jupiter-api" % "5.9.0" % "test", + "org.junit.jupiter" % "junit-jupiter-engine" % "5.9.0" % "test", + "com.google.truth" % "truth" % "1.1.5" % "test", + "net.aichler" % "jupiter-interface" % "0.11.1" % Test + ), + libraryDependencies ++= deps(sparkVersion), publish / skip := true) lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) diff --git a/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java b/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java new file mode 100644 index 000000000..293a05d4a --- /dev/null +++ b/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java @@ -0,0 +1,261 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.metrics.sink; + +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.regions.AwsRegionProvider; +import com.amazonaws.regions.DefaultAwsRegionProviderChain; +import com.amazonaws.regions.Regions; +import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsync; +import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsyncClient; +import com.codahale.metrics.MetricFilter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.ScheduledReporter; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.apache.spark.SecurityManager; +import org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter; +import org.opensearch.flint.core.metrics.reporter.InvalidMetricsPropertyException; + +/** + * Implementation of the Spark metrics {@link Sink} interface + * for reporting internal Spark metrics into CloudWatch. Spark's metric system uses DropWizard's + * metric library internally, so this class simply wraps the {@link DimensionedCloudWatchReporter} + * with the constructor and methods mandated for Spark metric Sinks. + * + * @see org.apache.spark.metrics.MetricsSystem + * @see ScheduledReporter + * @author kmccaw + */ +public class CloudWatchSink implements Sink { + + private final ScheduledReporter reporter; + + private final long pollingPeriod; + + private final boolean shouldParseInlineDimensions; + + private final boolean shouldAppendDropwizardTypeDimension; + + private final TimeUnit pollingTimeUnit; + + /** + * Constructor with the signature required by Spark, which loads the class through reflection. 
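For illustration only (this snippet is not part of the patch): Spark discovers a sink like this from metrics.properties, e.g. *.sink.cloudwatch.class=org.apache.spark.metrics.sink.CloudWatchSink with the remaining keys under the same prefix, and instantiates it reflectively. A minimal sketch of that contract, using the property keys defined further down in this file (namespace is required, the rest are optional):

    import java.util.Properties;
    import com.codahale.metrics.MetricRegistry;
    import org.apache.spark.metrics.sink.CloudWatchSink;

    public class CloudWatchSinkSketch {
        public static void main(String[] args) {
            // Same keys Spark passes through from metrics.properties (*.sink.<name>.<key>)
            Properties props = new Properties();
            props.setProperty("namespace", "FlintMetrics");   // required, otherwise InvalidMetricsPropertyException
            props.setProperty("awsRegion", "us-east-1");      // optional, falls back to the default region provider chain
            props.setProperty("pollingPeriod", "1");          // optional, defaults to 1
            props.setProperty("pollingTimeUnit", "MINUTES");  // optional, defaults to MINUTES

            // Spark's MetricsSystem performs the equivalent of this constructor call via reflection.
            CloudWatchSink sink = new CloudWatchSink(props, new MetricRegistry(), /* securityManager */ null);
            sink.start();
            sink.stop();
        }
    }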
+ * + * @see org.apache.spark.metrics.MetricsSystem + * @param properties Properties for this sink defined in Spark's "metrics.properties" configuration file. + * @param metricRegistry The DropWizard MetricRegistry used by Sparks {@link org.apache.spark.metrics.MetricsSystem} + * @param securityManager Unused argument; required by the Spark sink constructor signature. + */ + public CloudWatchSink( + final Properties properties, + final MetricRegistry metricRegistry, + final SecurityManager securityManager) { + // First extract properties defined in the Spark metrics configuration + + // Extract the required namespace property. This is used as the namespace + // for all metrics reported to CloudWatch + final Optional namespaceProperty = getProperty(properties, PropertyKeys.NAMESPACE); + if (!namespaceProperty.isPresent()) { + final String message = "CloudWatch Spark metrics sink requires '" + + PropertyKeys.NAMESPACE + "' property."; + throw new InvalidMetricsPropertyException(message); + } + + // Extract the optional AWS credentials. If either of the access or secret keys are + // missing in the properties, fall back to using the credentials of the EC2 instance. + final Optional accessKeyProperty = getProperty(properties, PropertyKeys.AWS_ACCESS_KEY_ID); + final Optional secretKeyProperty = getProperty(properties, PropertyKeys.AWS_SECRET_KEY); + final AWSCredentialsProvider awsCredentialsProvider; + if (accessKeyProperty.isPresent() && secretKeyProperty.isPresent()) { + final AWSCredentials awsCredentials = new BasicAWSCredentials( + accessKeyProperty.get(), + secretKeyProperty.get()); + awsCredentialsProvider = new AWSStaticCredentialsProvider(awsCredentials); + } else { + // If the AWS credentials aren't specified in the properties, fall back to using the + // DefaultAWSCredentialsProviderChain, which looks for credentials in the order + // (1) Environment Variables + // (2) Java System Properties + // (3) Credentials file at ~/.aws/credentials + // (4) AWS_CONTAINER_CREDENTIALS_RELATIVE_URI + // (5) EC2 Instance profile credentials + awsCredentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); + } + + // Extract the AWS region CloudWatch metrics should be reported to. + final Optional regionProperty = getProperty(properties, PropertyKeys.AWS_REGION); + final Regions awsRegion; + if (regionProperty.isPresent()) { + try { + awsRegion = Regions.fromName(regionProperty.get()); + } catch (IllegalArgumentException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + regionProperty.get(), + PropertyKeys.AWS_REGION); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + final AwsRegionProvider regionProvider = new DefaultAwsRegionProviderChain(); + awsRegion = Regions.fromName(regionProvider.getRegion()); + } + + // Extract the polling period, the interval at which metrics are reported. 
+ final Optional pollingPeriodProperty = getProperty(properties, PropertyKeys.POLLING_PERIOD); + if (pollingPeriodProperty.isPresent()) { + try { + final long parsedPollingPeriod = Long.parseLong(pollingPeriodProperty.get()); + // Confirm that the value of this property is a positive number + if (parsedPollingPeriod <= 0) { + final String message = String.format( + "The value (%s) of the \"%s\" CloudWatchSink metrics property is non-positive.", + pollingPeriodProperty.get(), + PropertyKeys.POLLING_PERIOD); + throw new InvalidMetricsPropertyException(message); + } + pollingPeriod = parsedPollingPeriod; + } catch (NumberFormatException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + pollingPeriodProperty.get(), + PropertyKeys.POLLING_PERIOD); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + pollingPeriod = PropertyDefaults.POLLING_PERIOD; + } + + final Optional pollingTimeUnitProperty = getProperty(properties, PropertyKeys.POLLING_TIME_UNIT); + if (pollingTimeUnitProperty.isPresent()) { + try { + pollingTimeUnit = TimeUnit.valueOf(pollingTimeUnitProperty.get().toUpperCase()); + } catch (IllegalArgumentException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + pollingTimeUnitProperty.get(), + PropertyKeys.POLLING_TIME_UNIT); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + pollingTimeUnit = PropertyDefaults.POLLING_PERIOD_TIME_UNIT; + } + + // Extract the inline dimension parsing setting. + final Optional shouldParseInlineDimensionsProperty = getProperty( + properties, + PropertyKeys.SHOULD_PARSE_INLINE_DIMENSIONS); + if (shouldParseInlineDimensionsProperty.isPresent()) { + try { + shouldParseInlineDimensions = Boolean.parseBoolean(shouldParseInlineDimensionsProperty.get()); + } catch (IllegalArgumentException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + shouldParseInlineDimensionsProperty.get(), + PropertyKeys.SHOULD_PARSE_INLINE_DIMENSIONS); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + shouldParseInlineDimensions = PropertyDefaults.SHOULD_PARSE_INLINE_DIMENSIONS; + } + + // Extract the setting to append dropwizard metrics types as a dimension + final Optional shouldAppendDropwizardTypeDimensionProperty = getProperty( + properties, + PropertyKeys.SHOULD_APPEND_DROPWIZARD_TYPE_DIMENSION); + if (shouldAppendDropwizardTypeDimensionProperty.isPresent()) { + try { + shouldAppendDropwizardTypeDimension = Boolean.parseBoolean(shouldAppendDropwizardTypeDimensionProperty.get()); + } catch (IllegalArgumentException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + shouldAppendDropwizardTypeDimensionProperty.get(), + PropertyKeys.SHOULD_APPEND_DROPWIZARD_TYPE_DIMENSION); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + shouldAppendDropwizardTypeDimension = PropertyDefaults.SHOULD_PARSE_INLINE_DIMENSIONS; + } + + final AmazonCloudWatchAsync cloudWatchClient = AmazonCloudWatchAsyncClient.asyncBuilder() + .withCredentials(awsCredentialsProvider) + .withRegion(awsRegion) + .build(); + + this.reporter = DimensionedCloudWatchReporter.forRegistry(metricRegistry, cloudWatchClient, namespaceProperty.get()) + .convertRatesTo(TimeUnit.SECONDS) + .convertDurationsTo(TimeUnit.MILLISECONDS) + 
.filter(MetricFilter.ALL) + .withPercentiles( + DimensionedCloudWatchReporter.Percentile.P50, + DimensionedCloudWatchReporter.Percentile.P75, + DimensionedCloudWatchReporter.Percentile.P99) + .withOneMinuteMeanRate() + .withFiveMinuteMeanRate() + .withFifteenMinuteMeanRate() + .withMeanRate() + .withArithmeticMean() + .withStdDev() + .withStatisticSet() + .withGlobalDimensions() + .withShouldParseDimensionsFromName(shouldParseInlineDimensions) + .withShouldAppendDropwizardTypeDimension(shouldAppendDropwizardTypeDimension) + .build(); + } + + @Override + public void start() { + reporter.start(pollingPeriod, pollingTimeUnit); + } + + @Override + public void stop() { + reporter.stop(); + } + + @Override + public void report() { + reporter.report(); + } + + /** + * Returns the value for specified property key as an Optional. + * @param properties + * @param key + * @return + */ + private static Optional getProperty(Properties properties, final String key) { + return Optional.ofNullable(properties.getProperty(key)); + } + + /** + * The keys used in the metrics properties configuration file. + */ + private static class PropertyKeys { + static final String NAMESPACE = "namespace"; + static final String AWS_ACCESS_KEY_ID = "awsAccessKeyId"; + static final String AWS_SECRET_KEY = "awsSecretKey"; + static final String AWS_REGION = "awsRegion"; + static final String POLLING_PERIOD = "pollingPeriod"; + static final String POLLING_TIME_UNIT = "pollingTimeUnit"; + static final String SHOULD_PARSE_INLINE_DIMENSIONS = "shouldParseInlineDimensions"; + static final String SHOULD_APPEND_DROPWIZARD_TYPE_DIMENSION = "shouldAppendDropwizardTypeDimension"; + } + + /** + * The default values for optional properties in the metrics properties configuration file. + */ + private static class PropertyDefaults { + static final long POLLING_PERIOD = 1; + static final TimeUnit POLLING_PERIOD_TIME_UNIT = TimeUnit.MINUTES; + static final boolean SHOULD_PARSE_INLINE_DIMENSIONS = false; + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java new file mode 100644 index 000000000..450fe0d0d --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java @@ -0,0 +1,819 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics.reporter; + +import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsync; +import com.amazonaws.services.cloudwatch.model.Dimension; +import com.amazonaws.services.cloudwatch.model.InvalidParameterValueException; +import com.amazonaws.services.cloudwatch.model.MetricDatum; +import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest; +import com.amazonaws.services.cloudwatch.model.PutMetricDataResult; +import com.amazonaws.services.cloudwatch.model.StandardUnit; +import com.amazonaws.services.cloudwatch.model.StatisticSet; +import com.amazonaws.util.StringUtils; +import com.codahale.metrics.Clock; +import com.codahale.metrics.Counter; +import com.codahale.metrics.Counting; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.Meter; +import com.codahale.metrics.Metered; +import com.codahale.metrics.MetricFilter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.ScheduledReporter; +import com.codahale.metrics.Snapshot; 
+import com.codahale.metrics.Timer; +import com.codahale.metrics.jvm.BufferPoolMetricSet; +import com.codahale.metrics.jvm.ClassLoadingGaugeSet; +import com.codahale.metrics.jvm.FileDescriptorRatioGauge; +import com.codahale.metrics.jvm.GarbageCollectorMetricSet; +import com.codahale.metrics.jvm.MemoryUsageGaugeSet; +import com.codahale.metrics.jvm.ThreadStatesGaugeSet; +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Date; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.LongStream; +import java.util.stream.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Reports metrics to Amazon's CloudWatch periodically. + *

+ * Use {@link Builder} to construct instances of this class. The {@link Builder} + * allows you to configure what aggregated metrics will be reported as a single {@link MetricDatum} to CloudWatch. + *

+ * There are a bunch of {@code with*} methods that provide a sufficient fine-grained control over what metrics + * should be reported. + * + * Forked from https://github.com/azagniotov/codahale-aggregated-metrics-cloudwatch-reporter. + */ +public class DimensionedCloudWatchReporter extends ScheduledReporter { + + private static final Logger LOGGER = LoggerFactory.getLogger(DimensionedCloudWatchReporter.class); + + // Visible for testing + public static final String DIMENSION_NAME_TYPE = "Type"; + + // Visible for testing + public static final String DIMENSION_GAUGE = "gauge"; + + // Visible for testing + public static final String DIMENSION_COUNT = "count"; + + // Visible for testing + public static final String DIMENSION_SNAPSHOT_SUMMARY = "snapshot-summary"; + + // Visible for testing + public static final String DIMENSION_SNAPSHOT_MEAN = "snapshot-mean"; + + // Visible for testing + public static final String DIMENSION_SNAPSHOT_STD_DEV = "snapshot-std-dev"; + + /** + * Amazon CloudWatch rejects values that are either too small or too large. + * Values must be in the range of 8.515920e-109 to 1.174271e+108 (Base 10) or 2e-360 to 2e360 (Base 2). + *

+ * In addition, special values (e.g., NaN, +Infinity, -Infinity) are not supported. + */ + private static final double SMALLEST_SENDABLE_VALUE = 8.515920e-109; + private static final double LARGEST_SENDABLE_VALUE = 1.174271e+108; + + /** + * Each CloudWatch API request may contain at maximum 20 datums + */ + private static final int MAXIMUM_DATUMS_PER_REQUEST = 20; + + /** + * We only submit the difference in counters since the last submission. This way we don't have to reset + * the counters within this application. + */ + private final Map lastPolledCounts; + + private final Builder builder; + private final String namespace; + private final AmazonCloudWatchAsync cloudWatchAsyncClient; + private final StandardUnit rateUnit; + private final StandardUnit durationUnit; + private final boolean shouldParseDimensionsFromName; + private final boolean shouldAppendDropwizardTypeDimension; + + private DimensionedCloudWatchReporter(final Builder builder) { + super(builder.metricRegistry, "coda-hale-metrics-cloud-watch-reporter", builder.metricFilter, builder.rateUnit, builder.durationUnit); + this.builder = builder; + this.namespace = builder.namespace; + this.cloudWatchAsyncClient = builder.cloudWatchAsyncClient; + this.lastPolledCounts = new ConcurrentHashMap<>(); + this.rateUnit = builder.cwRateUnit; + this.durationUnit = builder.cwDurationUnit; + this.shouldParseDimensionsFromName = builder.withShouldParseDimensionsFromName; + this.shouldAppendDropwizardTypeDimension = builder.withShouldAppendDropwizardTypeDimension; + } + + @Override + public void report(final SortedMap gauges, + final SortedMap counters, + final SortedMap histograms, + final SortedMap meters, + final SortedMap timers) { + + if (builder.withDryRun) { + LOGGER.warn("** Reporter is running in 'DRY RUN' mode **"); + } + + try { + final List metricData = new ArrayList<>( + gauges.size() + counters.size() + 10 * histograms.size() + 10 * timers.size()); + + for (final Map.Entry gaugeEntry : gauges.entrySet()) { + processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData); + } + + for (final Map.Entry counterEntry : counters.entrySet()) { + processCounter(counterEntry.getKey(), counterEntry.getValue(), metricData); + } + + for (final Map.Entry histogramEntry : histograms.entrySet()) { + processCounter(histogramEntry.getKey(), histogramEntry.getValue(), metricData); + processHistogram(histogramEntry.getKey(), histogramEntry.getValue(), metricData); + } + + for (final Map.Entry meterEntry : meters.entrySet()) { + processCounter(meterEntry.getKey(), meterEntry.getValue(), metricData); + processMeter(meterEntry.getKey(), meterEntry.getValue(), metricData); + } + + for (final Map.Entry timerEntry : timers.entrySet()) { + processCounter(timerEntry.getKey(), timerEntry.getValue(), metricData); + processMeter(timerEntry.getKey(), timerEntry.getValue(), metricData); + processTimer(timerEntry.getKey(), timerEntry.getValue(), metricData); + } + + final Collection> metricDataPartitions = partition(metricData, MAXIMUM_DATUMS_PER_REQUEST); + final List> cloudWatchFutures = new ArrayList<>(metricData.size()); + + for (final List partition : metricDataPartitions) { + final PutMetricDataRequest putMetricDataRequest = new PutMetricDataRequest() + .withNamespace(namespace) + .withMetricData(partition); + + if (builder.withDryRun) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dry run - constructed PutMetricDataRequest: {}", putMetricDataRequest); + } + } else { + 
cloudWatchFutures.add(cloudWatchAsyncClient.putMetricDataAsync(putMetricDataRequest)); + } + } + + for (final Future cloudWatchFuture : cloudWatchFutures) { + try { + cloudWatchFuture.get(); + } catch (final Exception e) { + LOGGER.error("Error reporting metrics to CloudWatch. The data in this CloudWatch API request " + + "may have been discarded, did not make it to CloudWatch.", e); + } + } + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Sent {} metric datums to CloudWatch. Namespace: {}, metric data {}", metricData.size(), namespace, metricData); + } + + } catch (final RuntimeException e) { + LOGGER.error("Error marshalling CloudWatch metrics.", e); + } + } + + @Override + public void stop() { + try { + super.stop(); + } catch (final Exception e) { + LOGGER.error("Error when stopping the reporter.", e); + } finally { + if (!builder.withDryRun) { + try { + cloudWatchAsyncClient.shutdown(); + } catch (final Exception e) { + LOGGER.error("Error shutting down AmazonCloudWatchAsync", cloudWatchAsyncClient, e); + } + } + } + } + + private void processGauge(final String metricName, final Gauge gauge, final List metricData) { + if (gauge.getValue() instanceof Number) { + final Number number = (Number) gauge.getValue(); + stageMetricDatum(true, metricName, number.doubleValue(), StandardUnit.None, DIMENSION_GAUGE, metricData); + } + } + + private void processCounter(final String metricName, final Counting counter, final List metricData) { + long currentCount = counter.getCount(); + Long lastCount = lastPolledCounts.get(counter); + lastPolledCounts.put(counter, currentCount); + + if (lastCount == null) { + lastCount = 0L; + } + + // Only submit metrics that have changed - let's save some money! + final long delta = currentCount - lastCount; + stageMetricDatum(true, metricName, delta, StandardUnit.Count, DIMENSION_COUNT, metricData); + } + + /** + * The rates of {@link Metered} are reported after being converted using the rate factor, which is deduced from + * the set rate unit + * + * @see Timer#getSnapshot + * @see #getRateUnit + * @see #convertRate(double) + */ + private void processMeter(final String metricName, final Metered meter, final List metricData) { + final String formattedRate = String.format("-rate [per-%s]", getRateUnit()); + stageMetricDatum(builder.withOneMinuteMeanRate, metricName, convertRate(meter.getOneMinuteRate()), rateUnit, "1-min-mean" + formattedRate, metricData); + stageMetricDatum(builder.withFiveMinuteMeanRate, metricName, convertRate(meter.getFiveMinuteRate()), rateUnit, "5-min-mean" + formattedRate, metricData); + stageMetricDatum(builder.withFifteenMinuteMeanRate, metricName, convertRate(meter.getFifteenMinuteRate()), rateUnit, "15-min-mean" + formattedRate, metricData); + stageMetricDatum(builder.withMeanRate, metricName, convertRate(meter.getMeanRate()), rateUnit, "mean" + formattedRate, metricData); + } + + /** + * The {@link Snapshot} values of {@link Timer} are reported as {@link StatisticSet} after conversion. The + * conversion is done using the duration factor, which is deduced from the set duration unit. + *
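A concrete reading of processCounter above (a sketch, not part of the patch; registry and reporter are assumed to be a MetricRegistry and the DimensionedCloudWatchReporter built from it): because lastPolledCounts remembers the previous reading, each report() stages only the change since the last run.

    Counter queries = registry.counter("queriesSubmitted");   // com.codahale.metrics.Counter
    queries.inc(7);
    reporter.report();   // stages a Count datum with value 7 (delta from 0)
    reporter.report();   // delta is 0, so nothing is staged unless zero-value submission is enabled
    queries.inc(3);
    reporter.report();   // stages a Count datum with value 3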

+ * Please note, the reported values are submitted only if they show some data (greater than zero) in order to: + *

+ * 1. save some money + * 2. prevent com.amazonaws.services.cloudwatch.model.InvalidParameterValueException if an empty {@link Snapshot} + * is submitted + *
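A short sketch of how that behavior is toggled from the builder (not part of the patch; registry and cloudWatchAsyncClient are assumed to exist): by default zero values and empty snapshots are skipped, and withZeroValuesSubmission() opts back in.

    DimensionedCloudWatchReporter reporter = DimensionedCloudWatchReporter
            .forRegistry(registry, cloudWatchAsyncClient, "FlintMetrics")
            .withZeroValuesSubmission()   // also report zero counts and empty snapshots
            .withStatisticSet()
            .withPercentiles(DimensionedCloudWatchReporter.Percentile.P50,
                             DimensionedCloudWatchReporter.Percentile.P99)
            .build();
    reporter.start(1, TimeUnit.MINUTES);   // java.util.concurrent.TimeUnit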

+ * If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted + * + * @see Timer#getSnapshot + * @see #getDurationUnit + * @see #convertDuration(double) + */ + private void processTimer(final String metricName, final Timer timer, final List metricData) { + final Snapshot snapshot = timer.getSnapshot(); + + if (builder.withZeroValuesSubmission || snapshot.size() > 0) { + for (final Percentile percentile : builder.percentiles) { + final double convertedDuration = convertDuration(snapshot.getValue(percentile.getQuantile())); + stageMetricDatum(true, metricName, convertedDuration, durationUnit, percentile.getDesc(), metricData); + } + } + + // prevent empty snapshot from causing InvalidParameterValueException + if (snapshot.size() > 0) { + final String formattedDuration = String.format(" [in-%s]", getDurationUnit()); + stageMetricDatum(builder.withArithmeticMean, metricName, convertDuration(snapshot.getMean()), durationUnit, DIMENSION_SNAPSHOT_MEAN + formattedDuration, metricData); + stageMetricDatum(builder.withStdDev, metricName, convertDuration(snapshot.getStdDev()), durationUnit, DIMENSION_SNAPSHOT_STD_DEV + formattedDuration, metricData); + stageMetricDatumWithConvertedSnapshot(builder.withStatisticSet, metricName, snapshot, durationUnit, metricData); + } + } + + /** + * The {@link Snapshot} values of {@link Histogram} are reported as {@link StatisticSet} raw. In other words, the + * conversion using the duration factor does NOT apply. + *
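As a worked example of that raw reporting (illustrative only, not code from the patch): a histogram whose snapshot holds the values 2, 4 and 6 is summarized by stageMetricDatumWithRawSnapshot further down roughly as

    StatisticSet summary = new StatisticSet()   // com.amazonaws.services.cloudwatch.model.StatisticSet
            .withSum(12.0)          // 2 + 4 + 6, no duration conversion applied
            .withSampleCount(3.0)
            .withMinimum(2.0)
            .withMaximum(6.0);      // staged with the Type=snapshot-summary dimension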

+ * Please note, the reported values are submitted only if they show some data (greater than zero) in order to: + *

+ * 1. save some money + * 2. prevent com.amazonaws.services.cloudwatch.model.InvalidParameterValueException if an empty {@link Snapshot} + * is submitted + *

+ * If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted + * + * @see Histogram#getSnapshot + */ + private void processHistogram(final String metricName, final Histogram histogram, final List metricData) { + final Snapshot snapshot = histogram.getSnapshot(); + + if (builder.withZeroValuesSubmission || snapshot.size() > 0) { + for (final Percentile percentile : builder.percentiles) { + final double value = snapshot.getValue(percentile.getQuantile()); + stageMetricDatum(true, metricName, value, StandardUnit.None, percentile.getDesc(), metricData); + } + } + + // prevent empty snapshot from causing InvalidParameterValueException + if (snapshot.size() > 0) { + stageMetricDatum(builder.withArithmeticMean, metricName, snapshot.getMean(), StandardUnit.None, DIMENSION_SNAPSHOT_MEAN, metricData); + stageMetricDatum(builder.withStdDev, metricName, snapshot.getStdDev(), StandardUnit.None, DIMENSION_SNAPSHOT_STD_DEV, metricData); + stageMetricDatumWithRawSnapshot(builder.withStatisticSet, metricName, snapshot, StandardUnit.None, metricData); + } + } + + /** + * Please note, the reported values submitted only if they show some data (greater than zero) in order to: + *

+ * 1. save some money + * 2. prevent com.amazonaws.services.cloudwatch.model.InvalidParameterValueException if an empty {@link Snapshot} + * is submitted + *

+ * If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted + */ + private void stageMetricDatum(final boolean metricConfigured, + final String metricName, + final double metricValue, + final StandardUnit standardUnit, + final String dimensionValue, + final List metricData) { + // Only submit metrics that show some data, so let's save some money + if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) { + final Set dimensions = new LinkedHashSet<>(builder.globalDimensions); + final String name; + if (shouldParseDimensionsFromName) { + final String[] nameParts = metricName.split(" "); + final StringBuilder nameBuilder = new StringBuilder(nameParts[0]); + int i = 1; + for (; i < nameParts.length; ++i) { + final String[] dimensionParts = nameParts[i].split("="); + if (dimensionParts.length == 2 + && !StringUtils.isNullOrEmpty(dimensionParts[0]) + && !StringUtils.isNullOrEmpty(dimensionParts[1])) { + final Dimension dimension = new Dimension(); + dimension.withName(dimensionParts[0]); + dimension.withValue(dimensionParts[1]); + dimensions.add(dimension); + } else { + nameBuilder.append(" "); + nameBuilder.append(nameParts[i]); + } + } + name = nameBuilder.toString(); + } else { + name = metricName; + } + + if (shouldAppendDropwizardTypeDimension) { + dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(dimensionValue)); + } + + metricData.add(new MetricDatum() + .withTimestamp(new Date(builder.clock.getTime())) + .withValue(cleanMetricValue(metricValue)) + .withMetricName(name) + .withDimensions(dimensions) + .withUnit(standardUnit)); + } + } + + private void stageMetricDatumWithConvertedSnapshot(final boolean metricConfigured, + final String metricName, + final Snapshot snapshot, + final StandardUnit standardUnit, + final List metricData) { + if (metricConfigured) { + double scaledSum = convertDuration(LongStream.of(snapshot.getValues()).sum()); + final StatisticSet statisticSet = new StatisticSet() + .withSum(scaledSum) + .withSampleCount((double) snapshot.size()) + .withMinimum(convertDuration(snapshot.getMin())) + .withMaximum(convertDuration(snapshot.getMax())); + + final Set dimensions = new LinkedHashSet<>(builder.globalDimensions); + dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_SNAPSHOT_SUMMARY)); + + metricData.add(new MetricDatum() + .withTimestamp(new Date(builder.clock.getTime())) + .withMetricName(metricName) + .withDimensions(dimensions) + .withStatisticValues(statisticSet) + .withUnit(standardUnit)); + } + } + + private void stageMetricDatumWithRawSnapshot(final boolean metricConfigured, + final String metricName, + final Snapshot snapshot, + final StandardUnit standardUnit, + final List metricData) { + if (metricConfigured) { + double total = LongStream.of(snapshot.getValues()).sum(); + final StatisticSet statisticSet = new StatisticSet() + .withSum(total) + .withSampleCount((double) snapshot.size()) + .withMinimum((double) snapshot.getMin()) + .withMaximum((double) snapshot.getMax()); + + final Set dimensions = new LinkedHashSet<>(builder.globalDimensions); + dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_SNAPSHOT_SUMMARY)); + + metricData.add(new MetricDatum() + .withTimestamp(new Date(builder.clock.getTime())) + .withMetricName(metricName) + .withDimensions(dimensions) + .withStatisticValues(statisticSet) + .withUnit(standardUnit)); + } + } + + private double cleanMetricValue(final double metricValue) { + double 
absoluteValue = Math.abs(metricValue); + if (absoluteValue < SMALLEST_SENDABLE_VALUE) { + // Allow 0 through untouched, everything else gets rounded to SMALLEST_SENDABLE_VALUE + if (absoluteValue > 0) { + if (metricValue < 0) { + return -SMALLEST_SENDABLE_VALUE; + } else { + return SMALLEST_SENDABLE_VALUE; + } + } + } else if (absoluteValue > LARGEST_SENDABLE_VALUE) { + if (metricValue < 0) { + return -LARGEST_SENDABLE_VALUE; + } else { + return LARGEST_SENDABLE_VALUE; + } + } + return metricValue; + } + + private static Collection> partition(final Collection wholeCollection, final int partitionSize) { + final int[] itemCounter = new int[]{0}; + + return wholeCollection.stream() + .collect(Collectors.groupingBy(item -> itemCounter[0]++ / partitionSize)) + .values(); + } + + /** + * Creates a new {@link Builder} that sends values from the given {@link MetricRegistry} to the given namespace + * using the given CloudWatch client. + * + * @param metricRegistry {@link MetricRegistry} instance + * @param client {@link AmazonCloudWatchAsync} instance + * @param namespace the namespace. Must be non-null and not empty. + * @return {@link Builder} instance + */ + public static Builder forRegistry( + final MetricRegistry metricRegistry, + final AmazonCloudWatchAsync client, + final String namespace) { + return new Builder(metricRegistry, client, namespace); + } + + public enum Percentile { + P50(0.50, "50%"), + P75(0.75, "75%"), + P95(0.95, "95%"), + P98(0.98, "98%"), + P99(0.99, "99%"), + P995(0.995, "99.5%"), + P999(0.999, "99.9%"); + + private final double quantile; + private final String desc; + + Percentile(final double quantile, final String desc) { + this.quantile = quantile; + this.desc = desc; + } + + public double getQuantile() { + return quantile; + } + + public String getDesc() { + return desc; + } + } + + public static class Builder { + + private final String namespace; + private final AmazonCloudWatchAsync cloudWatchAsyncClient; + private final MetricRegistry metricRegistry; + + private Percentile[] percentiles; + private boolean withOneMinuteMeanRate; + private boolean withFiveMinuteMeanRate; + private boolean withFifteenMinuteMeanRate; + private boolean withMeanRate; + private boolean withArithmeticMean; + private boolean withStdDev; + private boolean withDryRun; + private boolean withZeroValuesSubmission; + private boolean withStatisticSet; + private boolean withJvmMetrics; + private boolean withShouldParseDimensionsFromName; + private boolean withShouldAppendDropwizardTypeDimension=true; + private MetricFilter metricFilter; + private TimeUnit rateUnit; + private TimeUnit durationUnit; + private StandardUnit cwRateUnit; + private StandardUnit cwDurationUnit; + private Set globalDimensions; + private final Clock clock; + + private Builder( + final MetricRegistry metricRegistry, + final AmazonCloudWatchAsync cloudWatchAsyncClient, + final String namespace) { + this.metricRegistry = metricRegistry; + this.cloudWatchAsyncClient = cloudWatchAsyncClient; + this.namespace = namespace; + this.percentiles = new Percentile[]{Percentile.P75, Percentile.P95, Percentile.P999}; + this.metricFilter = MetricFilter.ALL; + this.rateUnit = TimeUnit.SECONDS; + this.durationUnit = TimeUnit.MILLISECONDS; + this.globalDimensions = new LinkedHashSet<>(); + this.cwRateUnit = toStandardUnit(rateUnit); + this.cwDurationUnit = toStandardUnit(durationUnit); + this.clock = Clock.defaultClock(); + } + + /** + * Convert rates to the given time unit. 
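For instance (a sketch, not part of the patch; registry is an assumed MetricRegistry wired to this reporter), the chosen rate unit changes both the reported value and the dimension suffix produced by processMeter above:

    Meter meter = registry.meter("queriesProcessed");   // com.codahale.metrics.Meter
    meter.mark(120);                                     // 120 events over roughly the first minute

    // convertRatesTo(TimeUnit.SECONDS) (the default) -> mean rate of about 2, dimension "mean-rate [per-second]"
    // convertRatesTo(TimeUnit.MINUTES)               -> the same rate scaled by 60, dimension "mean-rate [per-minute]"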
+ * + * @param rateUnit a unit of time + * @return {@code this} + */ + public Builder convertRatesTo(final TimeUnit rateUnit) { + this.rateUnit = rateUnit; + return this; + } + + /** + * Convert durations to the given time unit. + * + * @param durationUnit a unit of time + * @return {@code this} + */ + public Builder convertDurationsTo(final TimeUnit durationUnit) { + this.durationUnit = durationUnit; + return this; + } + + /** + * Only report metrics which match the given filter. + * + * @param metricFilter a {@link MetricFilter} + * @return {@code this} + */ + public Builder filter(final MetricFilter metricFilter) { + this.metricFilter = metricFilter; + return this; + } + + /** + * If the one minute rate should be sent for {@link Meter} and {@link Timer}. {@code false} by default. + *

+ * The rate values are converted before reporting based on the rate unit set + * + * @return {@code this} + * @see ScheduledReporter#convertRate(double) + * @see Meter#getOneMinuteRate() + * @see Timer#getOneMinuteRate() + */ + public Builder withOneMinuteMeanRate() { + withOneMinuteMeanRate = true; + return this; + } + + /** + * If the five minute rate should be sent for {@link Meter} and {@link Timer}. {@code false} by default. + *

+ * The rate values are converted before reporting based on the rate unit set + * + * @return {@code this} + * @see ScheduledReporter#convertRate(double) + * @see Meter#getFiveMinuteRate() + * @see Timer#getFiveMinuteRate() + */ + public Builder withFiveMinuteMeanRate() { + withFiveMinuteMeanRate = true; + return this; + } + + /** + * If the fifteen minute rate should be sent for {@link Meter} and {@link Timer}. {@code false} by default. + *

+ * The rate values are converted before reporting based on the rate unit set + * + * @return {@code this} + * @see ScheduledReporter#convertRate(double) + * @see Meter#getFifteenMinuteRate() + * @see Timer#getFifteenMinuteRate() + */ + public Builder withFifteenMinuteMeanRate() { + withFifteenMinuteMeanRate = true; + return this; + } + + /** + * If the mean rate should be sent for {@link Meter} and {@link Timer}. {@code false} by default. + *

+ * The rate values are converted before reporting based on the rate unit set + * + * @return {@code this} + * @see ScheduledReporter#convertRate(double) + * @see Meter#getMeanRate() + * @see Timer#getMeanRate() + */ + public Builder withMeanRate() { + withMeanRate = true; + return this; + } + + /** + * If the arithmetic mean of {@link Snapshot} values in {@link Histogram} and {@link Timer} should be sent. + * {@code false} by default. + *

+ * The {@link Timer#getSnapshot()} values are converted before reporting based on the duration unit set + * The {@link Histogram#getSnapshot()} values are reported as is + * + * @return {@code this} + * @see ScheduledReporter#convertDuration(double) + * @see Snapshot#getMean() + */ + public Builder withArithmeticMean() { + withArithmeticMean = true; + return this; + } + + /** + * If the standard deviation of {@link Snapshot} values in {@link Histogram} and {@link Timer} should be sent. + * {@code false} by default. + *

+ * The {@link Timer#getSnapshot()} values are converted before reporting based on the duration unit set + * The {@link Histogram#getSnapshot()} values are reported as is + * + * @return {@code this} + * @see ScheduledReporter#convertDuration(double) + * @see Snapshot#getStdDev() + */ + public Builder withStdDev() { + withStdDev = true; + return this; + } + + /** + * If lifetime {@link Snapshot} summary of {@link Histogram} and {@link Timer} should be translated + * to {@link StatisticSet} in the most direct way possible and reported. {@code false} by default. + *

+ * The {@link Snapshot} duration values are converted before reporting based on the duration unit set + * + * @return {@code this} + * @see ScheduledReporter#convertDuration(double) + */ + public Builder withStatisticSet() { + withStatisticSet = true; + return this; + } + + /** + * If JVM statistic should be reported. Supported metrics include: + *

+ * - Run count and elapsed times for all supported garbage collectors + * - Memory usage for all memory pools, including off-heap memory + * - Breakdown of thread states, including deadlocks + * - File descriptor usage + * - Buffer pool sizes and utilization (Java 7 only) + *

+ * {@code false} by default. + * + * @return {@code this} + */ + public Builder withJvmMetrics() { + withJvmMetrics = true; + return this; + } + + /** + * If CloudWatch dimensions should be parsed off the the metric name: + * + * {@code false} by default. + * + * @return {@code this} + */ + public Builder withShouldParseDimensionsFromName(final boolean value) { + withShouldParseDimensionsFromName = value; + return this; + } + + /** + * If the Dropwizard metric type should be reported as a CloudWatch dimension. + * + * {@code false} by default. + * + * @return {@code this} + */ + public Builder withShouldAppendDropwizardTypeDimension(final boolean value) { + withShouldAppendDropwizardTypeDimension = value; + return this; + } + + /** + * Does not actually POST to CloudWatch, logs the {@link PutMetricDataRequest putMetricDataRequest} instead. + * {@code false} by default. + * + * @return {@code this} + */ + public Builder withDryRun() { + withDryRun = true; + return this; + } + + /** + * POSTs to CloudWatch all values. Otherwise, the reporter does not POST values which are zero in order to save + * costs. Also, some users have been experiencing {@link InvalidParameterValueException} when submitting zero + * values. Please refer to: + * https://github.com/azagniotov/codahale-aggregated-metrics-cloudwatch-reporter/issues/4 + *
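To make the two dimension switches above concrete (a sketch, not part of the patch; registry is assumed to be wired to a reporter built with withShouldParseDimensionsFromName(true)):

    // The registered name carries an inline "name=value" token ...
    registry.counter("bytesProcessed table=http_logs").inc();

    // ... so the reporter stages the datum as
    //   MetricName: "bytesProcessed"
    //   Dimensions: table=http_logs (plus Type=count when the Dropwizard type dimension is enabled)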

+ * {@code false} by default. + * + * @return {@code this} + */ + public Builder withZeroValuesSubmission() { + withZeroValuesSubmission = true; + return this; + } + + /** + * The {@link Histogram} and {@link Timer} percentiles to send. If 0.5 is included, it'll be + * reported as median. This defaults to 0.75, 0.95 and 0.999. + *

+ * The {@link Timer#getSnapshot()} percentile values are converted before reporting based on the duration unit + * The {@link Histogram#getSnapshot()} percentile values are reported as is + * + * @param percentiles the percentiles to send. Replaces the default percentiles. + * @return {@code this} + */ + public Builder withPercentiles(final Percentile... percentiles) { + if (percentiles.length > 0) { + this.percentiles = percentiles; + } + return this; + } + + /** + * Global {@link Set} of {@link Dimension} to send with each {@link MetricDatum}. A dimension is a name/value + * pair that helps you to uniquely identify a metric. Every metric has specific characteristics that describe + * it, and you can think of dimensions as categories for those characteristics. + *

+ * Whenever you add a unique name/value pair to one of your metrics, you are creating a new metric. + * Defaults to {@code empty} {@link Set}. + * + * @param dimensions arguments in a form of {@code name=value}. The number of arguments is variable and may be + * zero. The maximum number of arguments is limited by the maximum dimension of a Java array + * as defined by the Java Virtual Machine Specification. Each {@code name=value} string + * will be converted to an instance of {@link Dimension} + * @return {@code this} + */ + public Builder withGlobalDimensions(final String... dimensions) { + for (final String pair : dimensions) { + final List splitted = Stream.of(pair.split("=")).map(String::trim).collect(Collectors.toList()); + this.globalDimensions.add(new Dimension().withName(splitted.get(0)).withValue(splitted.get(1))); + } + return this; + } + + public DimensionedCloudWatchReporter build() { + + if (withJvmMetrics) { + metricRegistry.register("jvm.uptime", (Gauge) () -> ManagementFactory.getRuntimeMXBean().getUptime()); + metricRegistry.register("jvm.current_time", (Gauge) clock::getTime); + metricRegistry.register("jvm.classes", new ClassLoadingGaugeSet()); + metricRegistry.register("jvm.fd_usage", new FileDescriptorRatioGauge()); + metricRegistry.register("jvm.buffers", new BufferPoolMetricSet(ManagementFactory.getPlatformMBeanServer())); + metricRegistry.register("jvm.gc", new GarbageCollectorMetricSet()); + metricRegistry.register("jvm.memory", new MemoryUsageGaugeSet()); + metricRegistry.register("jvm.thread-states", new ThreadStatesGaugeSet()); + } + + cwRateUnit = toStandardUnit(rateUnit); + cwDurationUnit = toStandardUnit(durationUnit); + + return new DimensionedCloudWatchReporter(this); + } + + private StandardUnit toStandardUnit(final TimeUnit timeUnit) { + switch (timeUnit) { + case SECONDS: + return StandardUnit.Seconds; + case MILLISECONDS: + return StandardUnit.Milliseconds; + case MICROSECONDS: + return StandardUnit.Microseconds; + default: + throw new IllegalArgumentException("Unsupported TimeUnit: " + timeUnit); + } + } + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/InvalidMetricsPropertyException.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/InvalidMetricsPropertyException.java new file mode 100644 index 000000000..56755d545 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/InvalidMetricsPropertyException.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics.reporter; + + +import java.io.Serializable; + +public class InvalidMetricsPropertyException extends RuntimeException implements Serializable { + + public InvalidMetricsPropertyException(final String message) { + super(message); + } + + public InvalidMetricsPropertyException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java b/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java new file mode 100644 index 000000000..6f87276a8 --- /dev/null +++ b/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package apache.spark.metrics.sink; + +import org.apache.spark.SecurityManager; +import com.codahale.metrics.MetricRegistry; +import 
org.apache.spark.metrics.sink.CloudWatchSink; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mockito; + +import java.util.Properties; +import org.opensearch.flint.core.metrics.reporter.InvalidMetricsPropertyException; + +class CloudWatchSinkTests { + private final MetricRegistry metricRegistry = Mockito.mock(MetricRegistry.class); + private final SecurityManager securityManager = Mockito.mock(SecurityManager.class); + + @Test + void should_throwException_when_namespacePropertyIsNotSet() { + final Properties properties = getDefaultValidProperties(); + properties.remove("namespace"); + final Executable executable = () -> { + final CloudWatchSink + cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + @Test + void should_throwException_when_awsPropertyIsInvalid() { + final Properties properties = getDefaultValidProperties(); + properties.setProperty("awsRegion", "someInvalidRegion"); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + @Test + void should_throwException_when_pollingPeriodPropertyIsNotANumber() { + final Properties properties = getDefaultValidProperties(); + properties.setProperty("pollingPeriod", "notANumber"); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + @Test + void should_throwException_when_pollingPeriodPropertyIsNegative() { + final Properties properties = getDefaultValidProperties(); + properties.setProperty("pollingPeriod", "-5"); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + @Test + void should_throwException_when_pollingTimeUnitPropertyIsInvalid() { + final Properties properties = getDefaultValidProperties(); + properties.setProperty("pollingTimeUnit", "notATimeUnitValue"); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + private Properties getDefaultValidProperties() { + final Properties properties = new Properties(); + properties.setProperty("namespace", "namespaceValue"); + properties.setProperty("awsAccessKeyId", "awsAccessKeyIdValue"); + properties.setProperty("awsSecretKey", "awsSecretKeyValue"); + properties.setProperty("awsRegion", "us-east-1"); + properties.setProperty("pollingPeriod", "1"); + properties.setProperty("pollingTimeUnit", "MINUTES"); + return properties; + } +} diff --git a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java b/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java new file mode 100644 index 000000000..991fd78b4 --- /dev/null +++ b/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java @@ -0,0 +1,544 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package opensearch.flint.core.metrics.reporter; + +import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsyncClient; +import com.amazonaws.services.cloudwatch.model.Dimension; +import com.amazonaws.services.cloudwatch.model.MetricDatum; +import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest; +import com.amazonaws.services.cloudwatch.model.PutMetricDataResult; +import com.codahale.metrics.EWMA; +import com.codahale.metrics.ExponentialMovingAverages; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SlidingWindowReservoir; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter; + +import static com.amazonaws.services.cloudwatch.model.StandardUnit.Count; +import static com.amazonaws.services.cloudwatch.model.StandardUnit.Microseconds; +import static com.amazonaws.services.cloudwatch.model.StandardUnit.Milliseconds; +import static com.amazonaws.services.cloudwatch.model.StandardUnit.None; +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_COUNT; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_GAUGE; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_NAME_TYPE; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_MEAN; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_STD_DEV; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_SUMMARY; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class DimensionedCloudWatchReporterTest { + + private static final String NAMESPACE = "namespace"; + private static final String ARBITRARY_COUNTER_NAME = "TheCounter"; + private static final String ARBITRARY_METER_NAME = "TheMeter"; + private static final String ARBITRARY_HISTOGRAM_NAME = "TheHistogram"; + private static final String ARBITRARY_TIMER_NAME = "TheTimer"; + private static final String ARBITRARY_GAUGE_NAME = "TheGauge"; + + @Mock + private AmazonCloudWatchAsyncClient mockAmazonCloudWatchAsyncClient; + + @Mock + private Future mockPutMetricDataResultFuture; + + @Captor + private ArgumentCaptor metricDataRequestCaptor; + + private MetricRegistry metricRegistry; + private DimensionedCloudWatchReporter.Builder reporterBuilder; + + @BeforeAll + public static void 
beforeClass() throws Exception { + reduceExponentialMovingAveragesDefaultTickInterval(); + } + + @BeforeEach + public void setUp() throws Exception { + metricRegistry = new MetricRegistry(); + reporterBuilder = DimensionedCloudWatchReporter.forRegistry(metricRegistry, mockAmazonCloudWatchAsyncClient, NAMESPACE); + when(mockAmazonCloudWatchAsyncClient.putMetricDataAsync(metricDataRequestCaptor.capture())).thenReturn(mockPutMetricDataResultFuture); + } + + @Test + public void shouldNotInvokeCloudWatchClientInDryRunMode() { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.withDryRun().build().report(); + + verify(mockAmazonCloudWatchAsyncClient, never()).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void shouldReportWithoutGlobalDimensionsWhenGlobalDimensionsNotConfigured() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.build().report(); // When 'withGlobalDimensions' was not called + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).hasSize(1); + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT)); + } + + @Test + public void reportedCounterShouldContainExpectedDimension() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT)); + } + + @Test + public void reportedCounterShouldContainDimensionEmbeddedInName() throws Exception { + final String DIMENSION_NAME = "some_dimension"; + final String DIMENSION_VALUE = "some_value"; + + metricRegistry.counter(ARBITRARY_COUNTER_NAME + " " + DIMENSION_NAME + "=" + DIMENSION_VALUE).inc(); + reporterBuilder.withShouldParseDimensionsFromName(true).build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME).withValue(DIMENSION_VALUE)); + } + + @Test + public void reportedGaugeShouldContainExpectedDimension() throws Exception { + metricRegistry.register(ARBITRARY_GAUGE_NAME, (Gauge) () -> 1L); + reporterBuilder.build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_GAUGE)); + } + + @Test + public void shouldNotReportGaugeWhenMetricValueNotOfTypeNumber() throws Exception { + metricRegistry.register(ARBITRARY_GAUGE_NAME, (Gauge) () -> "bad value type"); + reporterBuilder.build().report(); + + verify(mockAmazonCloudWatchAsyncClient, never()).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void neverReportMetersCountersGaugesWithZeroValues() throws Exception { + metricRegistry.register(ARBITRARY_GAUGE_NAME, (Gauge) () -> 0L); + metricRegistry.meter(ARBITRARY_METER_NAME).mark(0); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(0); + + buildReportWithSleep(reporterBuilder + .withArithmeticMean() + .withOneMinuteMeanRate() + .withFiveMinuteMeanRate() + .withFifteenMinuteMeanRate() + .withMeanRate()); + + verify(mockAmazonCloudWatchAsyncClient, never()).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void reportMetersCountersGaugesWithZeroValuesOnlyWhenConfigured() throws Exception { + 
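+        // By default (see the test above) zero-valued gauges, meters and counters are filtered out
+        // and no PutMetricData request is sent at all. withZeroValuesSubmission() is the opt-in
+        // that forces them through: the single captured request below is expected to contain only
+        // datums whose value is 0.0.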
metricRegistry.register(ARBITRARY_GAUGE_NAME, (Gauge) () -> 0L); + metricRegistry.meter(ARBITRARY_METER_NAME).mark(0); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(0); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(-1L, TimeUnit.NANOSECONDS); + + buildReportWithSleep(reporterBuilder + .withArithmeticMean() + .withOneMinuteMeanRate() + .withFiveMinuteMeanRate() + .withFifteenMinuteMeanRate() + .withZeroValuesSubmission() + .withMeanRate()); + + verify(mockAmazonCloudWatchAsyncClient, times(1)).putMetricDataAsync(metricDataRequestCaptor.capture()); + + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final List metricData = putMetricDataRequest.getMetricData(); + for (final MetricDatum metricDatum : metricData) { + assertThat(metricDatum.getValue()).isEqualTo(0.0); + } + } + + @Test + public void reportedMeterShouldContainExpectedOneMinuteMeanRateDimension() throws Exception { + metricRegistry.meter(ARBITRARY_METER_NAME).mark(1); + buildReportWithSleep(reporterBuilder.withOneMinuteMeanRate()); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("1-min-mean-rate [per-second]")); + } + + @Test + public void reportedMeterShouldContainExpectedFiveMinuteMeanRateDimension() throws Exception { + metricRegistry.meter(ARBITRARY_METER_NAME).mark(1); + buildReportWithSleep(reporterBuilder.withFiveMinuteMeanRate()); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("5-min-mean-rate [per-second]")); + } + + @Test + public void reportedMeterShouldContainExpectedFifteenMinuteMeanRateDimension() throws Exception { + metricRegistry.meter(ARBITRARY_METER_NAME).mark(1); + buildReportWithSleep(reporterBuilder.withFifteenMinuteMeanRate()); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("15-min-mean-rate [per-second]")); + } + + @Test + public void reportedMeterShouldContainExpectedMeanRateDimension() throws Exception { + metricRegistry.meter(ARBITRARY_METER_NAME).mark(1); + reporterBuilder.withMeanRate().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("mean-rate [per-second]")); + } + + @Test + public void reportedHistogramShouldContainExpectedArithmeticMeanDimension() throws Exception { + metricRegistry.histogram(ARBITRARY_HISTOGRAM_NAME).update(1); + reporterBuilder.withArithmeticMean().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_SNAPSHOT_MEAN)); + } + + @Test + public void reportedHistogramShouldContainExpectedStdDevDimension() throws Exception { + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(1); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(2); + reporterBuilder.withStdDev().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_SNAPSHOT_STD_DEV)); + } + + @Test + public void reportedTimerShouldContainExpectedArithmeticMeanDimension() throws Exception { + 
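+        // Timer snapshot statistics are reported in the reporter's duration unit (milliseconds by
+        // default, unless convertDurationsTo(...) overrides it), and the type dimension embeds that
+        // unit, so the 3 ms update below is expected to surface under the dimension value
+        // "snapshot-mean [in-milliseconds]".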
metricRegistry.timer(ARBITRARY_TIMER_NAME).update(3, TimeUnit.MILLISECONDS); + reporterBuilder.withArithmeticMean().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("snapshot-mean [in-milliseconds]")); + } + + @Test + public void reportedTimerShouldContainExpectedStdDevDimension() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(1, TimeUnit.MILLISECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(3, TimeUnit.MILLISECONDS); + reporterBuilder.withStdDev().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("snapshot-std-dev [in-milliseconds]")); + } + + @Test + public void shouldReportExpectedSingleGlobalDimension() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.withGlobalDimensions("Region=us-west-2").build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2")); + } + + @Test + public void shouldReportExpectedMultipleGlobalDimensions() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.withGlobalDimensions("Region=us-west-2", "Instance=stage").build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2")); + assertThat(dimensions).contains(new Dimension().withName("Instance").withValue("stage")); + } + + @Test + public void shouldNotReportDuplicateGlobalDimensions() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.withGlobalDimensions("Region=us-west-2", "Region=us-west-2").build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).containsNoDuplicates(); + } + + @Test + public void shouldReportExpectedCounterValue() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.build().report(); + + final MetricDatum metricDatum = firstMetricDatumFromCapturedRequest(); + + assertThat(metricDatum.getValue()).isWithin(1.0); + assertThat(metricDatum.getUnit()).isEqualTo(Count.toString()); + } + + @Test + public void shouldNotReportUnchangedCounterValue() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + final DimensionedCloudWatchReporter dimensionedCloudWatchReporter = reporterBuilder.build(); + + dimensionedCloudWatchReporter.report(); + MetricDatum metricDatum = firstMetricDatumFromCapturedRequest(); + assertThat(metricDatum.getValue().intValue()).isEqualTo(1); + metricDataRequestCaptor.getAllValues().clear(); + + dimensionedCloudWatchReporter.report(); + + verify(mockAmazonCloudWatchAsyncClient, times(1)).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void shouldReportCounterValueDelta() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + final DimensionedCloudWatchReporter dimensionedCloudWatchReporter = reporterBuilder.build(); + + dimensionedCloudWatchReporter.report(); + MetricDatum metricDatum = firstMetricDatumFromCapturedRequest(); + assertThat(metricDatum.getValue().intValue()).isEqualTo(2); + 
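+        // The reporter remembers the count it last submitted and reports only the delta on the
+        // next run: after six further increments the second report below should carry 6, not the
+        // cumulative total of 8.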
metricDataRequestCaptor.getAllValues().clear(); + + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + + dimensionedCloudWatchReporter.report(); + metricDatum = firstMetricDatumFromCapturedRequest(); + assertThat(metricDatum.getValue().intValue()).isEqualTo(6); + + verify(mockAmazonCloudWatchAsyncClient, times(2)).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void shouldReportArithmeticMeanAfterConversionByDefaultDurationWhenReportingTimer() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(1_000_000, TimeUnit.NANOSECONDS); + reporterBuilder.withArithmeticMean().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest("snapshot-mean [in-milliseconds]"); + + assertThat(metricData.getValue().intValue()).isEqualTo(1); + assertThat(metricData.getUnit()).isEqualTo(Milliseconds.toString()); + } + + @Test + public void shouldReportStdDevAfterConversionByDefaultDurationWhenReportingTimer() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(1_000_000, TimeUnit.NANOSECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(2_000_000, TimeUnit.NANOSECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(3_000_000, TimeUnit.NANOSECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(30_000_000, TimeUnit.NANOSECONDS); + reporterBuilder.withStdDev().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest("snapshot-std-dev [in-milliseconds]"); + + assertThat(metricData.getValue().intValue()).isEqualTo(12); + assertThat(metricData.getUnit()).isEqualTo(Milliseconds.toString()); + } + + @Test + public void shouldReportSnapshotValuesAfterConversionByCustomDurationWhenReportingTimer() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(1, TimeUnit.SECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(2, TimeUnit.SECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(3, TimeUnit.SECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(30, TimeUnit.SECONDS); + reporterBuilder.withStatisticSet().convertDurationsTo(TimeUnit.MICROSECONDS).build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_SUMMARY); + + assertThat(metricData.getStatisticValues().getSum().intValue()).isEqualTo(36_000_000); + assertThat(metricData.getStatisticValues().getMaximum().intValue()).isEqualTo(30_000_000); + assertThat(metricData.getStatisticValues().getMinimum().intValue()).isEqualTo(1_000_000); + assertThat(metricData.getStatisticValues().getSampleCount().intValue()).isEqualTo(4); + assertThat(metricData.getUnit()).isEqualTo(Microseconds.toString()); + } + + @Test + public void shouldReportArithmeticMeanWithoutConversionWhenReportingHistogram() throws Exception { + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(1); + reporterBuilder.withArithmeticMean().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_MEAN); + + assertThat(metricData.getValue().intValue()).isEqualTo(1); + assertThat(metricData.getUnit()).isEqualTo(None.toString()); + } + + @Test + public void 
shouldReportStdDevWithoutConversionWhenReportingHistogram() throws Exception { + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(1); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(2); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(3); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(30); + reporterBuilder.withStdDev().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_STD_DEV); + + assertThat(metricData.getValue().intValue()).isEqualTo(12); + assertThat(metricData.getUnit()).isEqualTo(None.toString()); + } + + @Test + public void shouldReportSnapshotValuesWithoutConversionWhenReportingHistogram() throws Exception { + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(1); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(2); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(3); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(30); + reporterBuilder.withStatisticSet().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_SUMMARY); + + assertThat(metricData.getStatisticValues().getSum().intValue()).isEqualTo(36); + assertThat(metricData.getStatisticValues().getMaximum().intValue()).isEqualTo(30); + assertThat(metricData.getStatisticValues().getMinimum().intValue()).isEqualTo(1); + assertThat(metricData.getStatisticValues().getSampleCount().intValue()).isEqualTo(4); + assertThat(metricData.getUnit()).isEqualTo(None.toString()); + } + + @Test + public void shouldReportHistogramSubsequentSnapshotValues_SumMaxMinValues() throws Exception { + DimensionedCloudWatchReporter reporter = reporterBuilder.withStatisticSet().build(); + + final Histogram slidingWindowHistogram = new Histogram(new SlidingWindowReservoir(4)); + metricRegistry.register("SlidingWindowHistogram", slidingWindowHistogram); + + slidingWindowHistogram.update(1); + slidingWindowHistogram.update(2); + slidingWindowHistogram.update(30); + reporter.report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_SUMMARY); + + assertThat(metricData.getStatisticValues().getMaximum().intValue()).isEqualTo(30); + assertThat(metricData.getStatisticValues().getMinimum().intValue()).isEqualTo(1); + assertThat(metricData.getStatisticValues().getSampleCount().intValue()).isEqualTo(3); + assertThat(metricData.getStatisticValues().getSum().intValue()).isEqualTo(33); + assertThat(metricData.getUnit()).isEqualTo(None.toString()); + + slidingWindowHistogram.update(4); + slidingWindowHistogram.update(100); + slidingWindowHistogram.update(5); + slidingWindowHistogram.update(6); + reporter.report(); + + final MetricDatum secondMetricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_SUMMARY); + + assertThat(secondMetricData.getStatisticValues().getMaximum().intValue()).isEqualTo(100); + assertThat(secondMetricData.getStatisticValues().getMinimum().intValue()).isEqualTo(4); + assertThat(secondMetricData.getStatisticValues().getSampleCount().intValue()).isEqualTo(4); + assertThat(secondMetricData.getStatisticValues().getSum().intValue()).isEqualTo(115); + assertThat(secondMetricData.getUnit()).isEqualTo(None.toString()); + + } 
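+
+    /**
+     * Illustrative sketch only, not exercised by any test: shows how the builder options covered
+     * above compose into a single reporter. It relies solely on APIs already used in this file;
+     * the metric names are arbitrary placeholders.
+     */
+    @SuppressWarnings("unused")
+    private void reportWithCombinedOptionsSketch() {
+        metricRegistry.counter("sketch.counter").inc();
+        metricRegistry.meter("sketch.meter").mark();
+        metricRegistry.timer("sketch.timer").update(5, TimeUnit.MILLISECONDS);
+
+        DimensionedCloudWatchReporter.forRegistry(metricRegistry, mockAmazonCloudWatchAsyncClient, NAMESPACE)
+            .withGlobalDimensions("Region=us-west-2", "Instance=stage")
+            .withMeanRate()
+            .withArithmeticMean()
+            .withStdDev()
+            .withStatisticSet()
+            .build()
+            .report();
+    }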
+ + private MetricDatum metricDatumByDimensionFromCapturedRequest(final String dimensionValue) { + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final List metricData = putMetricDataRequest.getMetricData(); + + final Optional metricDatumOptional = + metricData + .stream() + .filter(metricDatum -> metricDatum.getDimensions() + .contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(dimensionValue))) + .findFirst(); + + if (metricDatumOptional.isPresent()) { + return metricDatumOptional.get(); + } + + throw new IllegalStateException("Could not find MetricDatum for Dimension value: " + dimensionValue); + } + + private MetricDatum firstMetricDatumFromCapturedRequest() { + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + return putMetricDataRequest.getMetricData().get(0); + } + + private List firstMetricDatumDimensionsFromCapturedRequest() { + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final MetricDatum metricDatum = putMetricDataRequest.getMetricData().get(0); + return metricDatum.getDimensions(); + } + + private List allDimensionsFromCapturedRequest() { + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final List metricData = putMetricDataRequest.getMetricData(); + final List all = new LinkedList<>(); + for (final MetricDatum metricDatum : metricData) { + all.addAll(metricDatum.getDimensions()); + } + return all; + } + + private void buildReportWithSleep(final DimensionedCloudWatchReporter.Builder dimensionedCloudWatchReporterBuilder) throws InterruptedException { + final DimensionedCloudWatchReporter cloudWatchReporter = dimensionedCloudWatchReporterBuilder.build(); + Thread.sleep(10); + cloudWatchReporter.report(); + } + + /** + * This is a very ugly way to fool the {@link EWMA} by reducing the default tick interval + * in {@link ExponentialMovingAverages} from {@code 5} seconds to {@code 1} millisecond in order to ensure that + * exponentially-weighted moving average rates are populated. This helps to verify that all + * the expected {@link Dimension}s are present in {@link MetricDatum}. 
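+ *
+ * Note: the reflective clearing of the {@code final} modifier in {@code setFinalStaticField}
+ * relies on {@code java.lang.reflect.Field}'s internal {@code modifiers} field being reachable,
+ * which JDK 12 and later filter from reflection, so the trick only works on older JDKs.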
+ * + * @throws NoSuchFieldException + * @throws IllegalAccessException + * @see ExponentialMovingAverages#tickIfNecessary() + * @see MetricDatum#getDimensions() + */ + private static void reduceExponentialMovingAveragesDefaultTickInterval() throws NoSuchFieldException, IllegalAccessException { + setFinalStaticField(ExponentialMovingAverages.class, "TICK_INTERVAL", TimeUnit.MILLISECONDS.toNanos(1)); + } + + private static void setFinalStaticField(final Class clazz, final String fieldName, long value) throws NoSuchFieldException, IllegalAccessException { + final Field field = clazz.getDeclaredField(fieldName); + field.setAccessible(true); + final Field modifiers = field.getClass().getDeclaredField("modifiers"); + modifiers.setAccessible(true); + modifiers.setInt(field, field.getModifiers() & ~Modifier.FINAL); + field.set(null, value); + } + +} diff --git a/project/plugins.sbt b/project/plugins.sbt index 0fe5dd1ab..38550667b 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -8,3 +8,5 @@ addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") addSbtPlugin("com.lightbend.sbt" % "sbt-java-formatter" % "0.8.0") addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.1.0") addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.3") +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.10.0-RC1") +addSbtPlugin("net.aichler" % "sbt-jupiter-interface" % "0.11.1") \ No newline at end of file From be8202480be851f2aba6d13a2edb12a379b0cd58 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Wed, 22 Nov 2023 12:44:39 -0800 Subject: [PATCH 6/6] Migrate resutIndex and requestIndex state inconsistency issue (#174) Signed-off-by: Peng Huo --- .../main/scala/org/apache/spark/sql/FlintREPL.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 0f6c21786..674c0a75f 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -273,7 +273,8 @@ object FlintREPL extends Logging with FlintJobExecutor { var canPickUpNextStatement = true while (currentTimeProvider .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) { - logDebug(s"""read from ${commandContext.sessionIndex}""") + logInfo( + s"""read from ${commandContext.sessionIndex}, sessionId: $commandContext.sessionId""") val flintReader: FlintReader = createQueryReader( commandContext.osClient, @@ -339,7 +340,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - logDebug( + logInfo( s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") } @@ -524,6 +525,9 @@ object FlintREPL extends Logging with FlintJobExecutor { osClient: OSClient): Unit = { try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + // todo. it is migration plan to handle https://github + // .com/opensearch-project/sql/issues/2436. Remove sleep after issue fixed in plugin. 
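+      // The fixed 2-second pause below is a best-effort mitigation for the result-index /
+      // request-index inconsistency referenced above: it holds off marking the statement
+      // complete until shortly after the result document has been written.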
+      Thread.sleep(2000)
       if (flintCommand.isRunning() || flintCommand.isWaiting()) {
         // we have set failed state in exception handling
         flintCommand.complete()
       } else {
@@ -691,7 +695,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
         queryWaitTimeMillis)
     }
-    logDebug(s"command complete: $flintCommand")
+    logInfo(s"command complete: $flintCommand")
     (dataToWrite, verificationResult)
   }