diff --git a/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java b/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java index ce35c34e8..23205fe99 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java @@ -5,6 +5,7 @@ package org.opensearch.flint.core; +import org.opensearch.OpenSearchException; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.delete.DeleteRequest; @@ -26,6 +27,7 @@ import org.opensearch.client.indices.GetIndexResponse; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.client.RequestOptions; +import org.opensearch.flint.core.metrics.MetricsUtil; import java.io.Closeable; import java.io.IOException; @@ -52,11 +54,62 @@ public interface IRestHighLevelClient extends Closeable { IndexResponse index(IndexRequest indexRequest, RequestOptions options) throws IOException; - Boolean isIndexExists(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException; + Boolean doesIndexExist(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException; SearchResponse search(SearchRequest searchRequest, RequestOptions options) throws IOException; SearchResponse scroll(SearchScrollRequest searchScrollRequest, RequestOptions options) throws IOException; DocWriteResponse update(UpdateRequest updateRequest, RequestOptions options) throws IOException; + + + /** + * Records the success of an OpenSearch operation by incrementing the corresponding metric counter. + * This method constructs the metric name by appending ".200.count" to the provided metric name prefix. + * The metric name is then used to increment the counter, indicating a successful operation. + * + * @param metricNamePrefix the prefix for the metric name which is used to construct the full metric name for success + */ + static void recordOperationSuccess(String metricNamePrefix) { + String successMetricName = metricNamePrefix + ".2xx.count"; + MetricsUtil.incrementCounter(successMetricName); + } + + /** + * Records the failure of an OpenSearch operation by incrementing the corresponding metric counter. + * If the exception is an OpenSearchException with a specific status code (e.g., 403), + * it increments a metric specifically for that status code. + * Otherwise, it increments a general failure metric counter based on the status code category (e.g., 4xx, 5xx). + * + * @param metricNamePrefix the prefix for the metric name which is used to construct the full metric name for failure + * @param e the exception encountered during the operation, used to determine the type of failure + */ + static void recordOperationFailure(String metricNamePrefix, Exception e) { + OpenSearchException openSearchException = extractOpenSearchException(e); + int statusCode = openSearchException != null ? openSearchException.status().getStatus() : 500; + + if (statusCode == 403) { + String forbiddenErrorMetricName = metricNamePrefix + ".403.count"; + MetricsUtil.incrementCounter(forbiddenErrorMetricName); + } + + String failureMetricName = metricNamePrefix + "." + (statusCode / 100) + "xx.count"; + MetricsUtil.incrementCounter(failureMetricName); + } + + /** + * Extracts an OpenSearchException from the given Throwable. + * Checks if the Throwable is an instance of OpenSearchException or caused by one. + * + * @param ex the exception to be checked + * @return the extracted OpenSearchException, or null if not found + */ + private static OpenSearchException extractOpenSearchException(Throwable ex) { + if (ex instanceof OpenSearchException) { + return (OpenSearchException) ex; + } else if (ex.getCause() instanceof OpenSearchException) { + return (OpenSearchException) ex.getCause(); + } + return null; + } } diff --git a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java index 3556c7e24..bf48af52d 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java @@ -20,7 +20,6 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.OpenSearchException; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; @@ -28,7 +27,6 @@ import org.opensearch.client.indices.CreateIndexResponse; import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.client.indices.GetIndexResponse; -import org.opensearch.flint.core.metrics.MetricsUtil; import java.io.IOException; @@ -91,7 +89,7 @@ public IndexResponse index(IndexRequest indexRequest, RequestOptions options) th } @Override - public Boolean isIndexExists(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException { + public Boolean doesIndexExist(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException { return execute(OS_READ_OP_METRIC_PREFIX, () -> client.indices().exists(getIndexRequest, options)); } @@ -122,64 +120,14 @@ public UpdateResponse update(UpdateRequest updateRequest, RequestOptions options private T execute(String metricNamePrefix, IOCallable operation) throws IOException { try { T result = operation.call(); - recordOperationSuccess(metricNamePrefix); + IRestHighLevelClient.recordOperationSuccess(metricNamePrefix); return result; } catch (Exception e) { - recordOperationFailure(metricNamePrefix, e); + IRestHighLevelClient.recordOperationFailure(metricNamePrefix, e); throw e; } } - /** - * Records the success of an OpenSearch operation by incrementing the corresponding metric counter. - * This method constructs the metric name by appending ".200.count" to the provided metric name prefix. - * The metric name is then used to increment the counter, indicating a successful operation. - * - * @param metricNamePrefix the prefix for the metric name which is used to construct the full metric name for success - */ - private void recordOperationSuccess(String metricNamePrefix) { - String successMetricName = metricNamePrefix + ".2xx.count"; - MetricsUtil.incrementCounter(successMetricName); - } - - /** - * Records the failure of an OpenSearch operation by incrementing the corresponding metric counter. - * If the exception is an OpenSearchException with a specific status code (e.g., 403), - * it increments a metric specifically for that status code. - * Otherwise, it increments a general failure metric counter based on the status code category (e.g., 4xx, 5xx). - * - * @param metricNamePrefix the prefix for the metric name which is used to construct the full metric name for failure - * @param e the exception encountered during the operation, used to determine the type of failure - */ - private void recordOperationFailure(String metricNamePrefix, Exception e) { - OpenSearchException openSearchException = extractOpenSearchException(e); - int statusCode = openSearchException != null ? openSearchException.status().getStatus() : 500; - - if (statusCode == 403) { - String forbiddenErrorMetricName = metricNamePrefix + ".403.count"; - MetricsUtil.incrementCounter(forbiddenErrorMetricName); - } - - String failureMetricName = metricNamePrefix + "." + (statusCode / 100) + "xx.count"; - MetricsUtil.incrementCounter(failureMetricName); - } - - /** - * Extracts an OpenSearchException from the given Throwable. - * Checks if the Throwable is an instance of OpenSearchException or caused by one. - * - * @param ex the exception to be checked - * @return the extracted OpenSearchException, or null if not found - */ - private OpenSearchException extractOpenSearchException(Throwable ex) { - if (ex instanceof OpenSearchException) { - return (OpenSearchException) ex; - } else if (ex.getCause() instanceof OpenSearchException) { - return (OpenSearchException) ex.getCause(); - } - return null; - } - /** * Functional interface for operations that can throw IOException. * diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 72b96a091..6a081a740 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -8,7 +8,7 @@ /** * This class defines custom metric constants used for monitoring flint operations. */ -public class MetricConstants { +public final class MetricConstants { /** * The prefix for all read-related metrics in OpenSearch. @@ -47,6 +47,26 @@ public class MetricConstants { */ public static final String REPL_PROCESSING_TIME_METRIC = "session.processingTime"; + /** + * Prefix for metrics related to the request metadata read operations. + */ + public static final String REQUEST_METADATA_READ_METRIC_PREFIX = "request.metadata.read"; + + /** + * Prefix for metrics related to the request metadata write operations. + */ + public static final String REQUEST_METADATA_WRITE_METRIC_PREFIX = "request.metadata.write"; + + /** + * Metric name for counting failed heartbeat operations on request metadata. + */ + public static final String REQUEST_METADATA_HEARTBEAT_FAILED_METRIC = "request.metadata.heartbeat.failed.count"; + + /** + * Prefix for metrics related to the result metadata write operations. + */ + public static final String RESULT_METADATA_WRITE_METRIC_PREFIX = "result.metadata.write"; + /** * Metric name for counting the number of statements currently running. */ @@ -65,5 +85,29 @@ public class MetricConstants { /** * Metric name for tracking the processing time of statements. */ - public static final String STATEMENT_PROCESSING_TIME_METRIC = "STATEMENT.processingTime"; + public static final String STATEMENT_PROCESSING_TIME_METRIC = "statement.processingTime"; + + /** + * Metric for tracking the count of currently running streaming jobs. + */ + public static final String STREAMING_RUNNING_METRIC = "streaming.running.count"; + + /** + * Metric for tracking the count of streaming jobs that have failed. + */ + public static final String STREAMING_FAILED_METRIC = "streaming.failed.count"; + + /** + * Metric for tracking the count of streaming jobs that have completed successfully. + */ + public static final String STREAMING_SUCCESS_METRIC = "streaming.success.count"; + + /** + * Metric for tracking the count of failed heartbeat signals in streaming jobs. + */ + public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count"; + + private MetricConstants() { + // Private constructor to prevent instantiation + } } \ No newline at end of file diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java index 13227a039..8e63992f5 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java @@ -6,12 +6,15 @@ package org.opensearch.flint.core.metrics; import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.Timer; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; import org.apache.spark.metrics.source.Source; import scala.collection.Seq; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; /** @@ -21,8 +24,8 @@ public final class MetricsUtil { private static final Logger LOG = Logger.getLogger(MetricsUtil.class.getName()); - // Private constructor to prevent instantiation private MetricsUtil() { + // Private constructor to prevent instantiation } /** @@ -60,10 +63,7 @@ public static void decrementCounter(String metricName) { */ public static Timer.Context getTimerContext(String metricName) { Timer timer = getOrCreateTimer(metricName); - if (timer != null) { - return timer.time(); - } - return null; + return timer != null ? timer.time() : null; } /** @@ -74,42 +74,47 @@ public static Timer.Context getTimerContext(String metricName) { * @return The elapsed time in nanoseconds since the timer was started, or {@code null} if the context was {@code null}. */ public static Long stopTimer(Timer.Context context) { - if (context != null) { - return context.stop(); + return context != null ? context.stop() : null; + } + + /** + * Registers a gauge metric with the provided name and value. + * The gauge will reflect the current value of the AtomicInteger provided. + * + * @param metricName The name of the gauge metric to register. + * @param value The AtomicInteger whose current value should be reflected by the gauge. + */ + public static void registerGauge(String metricName, final AtomicInteger value) { + MetricRegistry metricRegistry = getMetricRegistry(); + if (metricRegistry == null) { + LOG.warning("MetricRegistry not available, cannot register gauge: " + metricName); + return; } - return null; + metricRegistry.register(metricName, (Gauge) value::get); } // Retrieves or creates a new counter for the given metric name private static Counter getOrCreateCounter(String metricName) { - SparkEnv sparkEnv = SparkEnv.get(); - if (sparkEnv == null) { - LOG.warning("Spark environment not available, cannot instrument metric: " + metricName); - return null; - } - - FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv); - Counter counter = flintMetricSource.metricRegistry().getCounters().get(metricName); - if (counter == null) { - counter = flintMetricSource.metricRegistry().counter(metricName); - } - return counter; + MetricRegistry metricRegistry = getMetricRegistry(); + return metricRegistry != null ? metricRegistry.counter(metricName) : null; } // Retrieves or creates a new Timer for the given metric name private static Timer getOrCreateTimer(String metricName) { + MetricRegistry metricRegistry = getMetricRegistry(); + return metricRegistry != null ? metricRegistry.timer(metricName) : null; + } + + // Retrieves the MetricRegistry from the current Spark environment. + private static MetricRegistry getMetricRegistry() { SparkEnv sparkEnv = SparkEnv.get(); if (sparkEnv == null) { - LOG.warning("Spark environment not available, cannot instrument metric: " + metricName); + LOG.warning("Spark environment not available, cannot access MetricRegistry."); return null; } FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv); - Timer timer = flintMetricSource.metricRegistry().getTimers().get(metricName); - if (timer == null) { - timer = flintMetricSource.metricRegistry().timer(metricName); - } - return timer; + return flintMetricSource.metricRegistry(); } // Gets or initializes the FlintMetricSource diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java index ce7136507..6e3e90916 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java @@ -7,9 +7,13 @@ import java.util.Map; import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.Logger; + import org.apache.commons.lang.StringUtils; import com.amazonaws.services.cloudwatch.model.Dimension; +import org.apache.spark.SparkEnv; /** * Utility class for creating and managing CloudWatch dimensions for metrics reporting in Flint. @@ -18,7 +22,9 @@ * application ID, and more. */ public class DimensionUtils { + private static final Logger LOG = Logger.getLogger(DimensionUtils.class.getName()); private static final String DIMENSION_JOB_ID = "jobId"; + private static final String DIMENSION_JOB_TYPE = "jobType"; private static final String DIMENSION_APPLICATION_ID = "applicationId"; private static final String DIMENSION_APPLICATION_NAME = "applicationName"; private static final String DIMENSION_DOMAIN_ID = "domainId"; @@ -29,6 +35,8 @@ public class DimensionUtils { private static final Map> dimensionBuilders = Map.of( DIMENSION_INSTANCE_ROLE, DimensionUtils::getInstanceRoleDimension, DIMENSION_JOB_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_JOB_ID", DIMENSION_JOB_ID), + // TODO: Move FlintSparkConf into the core to prevent circular dependencies + DIMENSION_JOB_TYPE, ignored -> constructDimensionFromSparkConf(DIMENSION_JOB_TYPE, "spark.flint.job.type", UNKNOWN), DIMENSION_APPLICATION_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", DIMENSION_APPLICATION_ID), DIMENSION_APPLICATION_NAME, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_APPLICATION_NAME", DIMENSION_APPLICATION_NAME), DIMENSION_DOMAIN_ID, ignored -> getEnvironmentVariableDimension("FLINT_CLUSTER_NAME", DIMENSION_DOMAIN_ID) @@ -39,7 +47,7 @@ public class DimensionUtils { * builder exists for the dimension name, it is used; otherwise, a default dimension is constructed. * * @param dimensionName The name of the dimension to construct. - * @param parts Additional information that might be required by specific dimension builders. + * @param metricNameParts Additional information that might be required by specific dimension builders. * @return A CloudWatch Dimension object. */ public static Dimension constructDimension(String dimensionName, String[] metricNameParts) { @@ -50,6 +58,30 @@ public static Dimension constructDimension(String dimensionName, String[] metric .apply(metricNameParts); } + /** + * Constructs a CloudWatch Dimension object using a specified Spark configuration key. + * + * @param dimensionName The name of the dimension to construct. + * @param sparkConfKey the Spark configuration key used to look up the value for the dimension. + * @param defaultValue the default value to use for the dimension if the Spark configuration key is not found or if the Spark environment is not available. + * @return A CloudWatch Dimension object. + * @throws Exception if an error occurs while accessing the Spark configuration. The exception is logged and then rethrown. + */ + public static Dimension constructDimensionFromSparkConf(String dimensionName, String sparkConfKey, String defaultValue) { + String propertyValue = defaultValue; + try { + if (SparkEnv.get() != null && SparkEnv.get().conf() != null) { + propertyValue = SparkEnv.get().conf().get(sparkConfKey, defaultValue); + } else { + LOG.warning("Spark environment or configuration is not available, defaulting to provided default value."); + } + } catch (Exception e) { + LOG.log(Level.SEVERE, "Error accessing Spark configuration with key: " + sparkConfKey + ", defaulting to provided default value.", e); + throw e; + } + return new Dimension().withName(dimensionName).withValue(propertyValue); + } + // This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137 // Since we don't have access to Spark Configuration here: we are relying on the presence of executorId as part of the metricName. public static boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) { diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 29ebad206..45aedbaa6 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -99,7 +99,7 @@ public OptimisticTransaction startTransaction(String indexName, String da String metaLogIndexName = dataSourceName.isEmpty() ? META_LOG_NAME_PREFIX : META_LOG_NAME_PREFIX + "_" + dataSourceName; try (IRestHighLevelClient client = createClient()) { - if (client.isIndexExists(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { + if (client.doesIndexExist(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { LOG.info("Found metadata log index " + metaLogIndexName); } else { if (forceInit) { @@ -149,7 +149,7 @@ public boolean exists(String indexName) { LOG.info("Checking if Flint index exists " + indexName); String osIndexName = sanitizeIndexName(indexName); try (IRestHighLevelClient client = createClient()) { - return client.isIndexExists(new GetIndexRequest(osIndexName), RequestOptions.DEFAULT); + return client.doesIndexExist(new GetIndexRequest(osIndexName), RequestOptions.DEFAULT); } catch (IOException e) { throw new IllegalStateException("Failed to check if Flint index exists " + osIndexName, e); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java index 9c1502b29..7195ae177 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java @@ -175,7 +175,7 @@ private FlintMetadataLogEntry writeLogEntry( private boolean exists() { LOG.info("Checking if Flint index exists " + metaLogIndexName); try (IRestHighLevelClient client = flintClient.createClient()) { - return client.isIndexExists(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT); + return client.doesIndexExist(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT); } catch (IOException e) { throw new IllegalStateException("Failed to check if Flint index exists " + metaLogIndexName, e); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java index 349e5c126..19ce6ce8b 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java @@ -23,6 +23,8 @@ import java.util.logging.Level; import java.util.logging.Logger; +import static org.opensearch.flint.core.metrics.MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX; + /** * {@link OpenSearchReader} using search. https://opensearch.org/docs/latest/api-reference/search/ */ @@ -37,8 +39,15 @@ public OpenSearchQueryReader(IRestHighLevelClient client, String indexName, Sear /** * search. */ - Optional search(SearchRequest request) throws IOException { - return Optional.of(client.search(request, RequestOptions.DEFAULT)); + Optional search(SearchRequest request) { + Optional response = Optional.empty(); + try { + response = Optional.of(client.search(request, RequestOptions.DEFAULT)); + IRestHighLevelClient.recordOperationSuccess(REQUEST_METADATA_READ_METRIC_PREFIX); + } catch (Exception e) { + IRestHighLevelClient.recordOperationFailure(REQUEST_METADATA_READ_METRIC_PREFIX, e); + } + return response; } /** diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java index ffe771b15..0d84b4956 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java @@ -12,80 +12,93 @@ import java.util.logging.Level; import java.util.logging.Logger; +import static org.opensearch.flint.core.metrics.MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX; +import static org.opensearch.flint.core.metrics.MetricConstants.REQUEST_METADATA_WRITE_METRIC_PREFIX; + +/** + * Provides functionality for updating and upserting documents in an OpenSearch index. + * This class utilizes FlintClient for managing connections to OpenSearch and performs + * document updates and upserts with optional optimistic concurrency control. + */ public class OpenSearchUpdater { private static final Logger LOG = Logger.getLogger(OpenSearchUpdater.class.getName()); private final String indexName; - private final FlintClient flintClient; - public OpenSearchUpdater(String indexName, FlintClient flintClient) { this.indexName = indexName; this.flintClient = flintClient; } public void upsert(String id, String doc) { - // we might need to keep the updater for a long time. Reusing the client may not work as the temporary - // credentials may expire. - // also, failure to close the client causes the job to be stuck in the running state as the client resource - // is not released. - try (IRestHighLevelClient client = flintClient.createClient()) { - assertIndexExist(client, indexName); - UpdateRequest - updateRequest = - new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL) - .docAsUpsert(true); - client.update(updateRequest, RequestOptions.DEFAULT); - } catch (IOException e) { - throw new RuntimeException(String.format( - "Failed to execute update request on index: %s, id: %s", - indexName, - id), e); - } + updateDocument(id, doc, true, -1, -1); } public void update(String id, String doc) { - try (IRestHighLevelClient client = flintClient.createClient()) { - assertIndexExist(client, indexName); - UpdateRequest - updateRequest = - new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - client.update(updateRequest, RequestOptions.DEFAULT); - } catch (IOException e) { - throw new RuntimeException(String.format( - "Failed to execute update request on index: %s, id: %s", - indexName, - id), e); - } + updateDocument(id, doc, false, -1, -1); } public void updateIf(String id, String doc, long seqNo, long primaryTerm) { + updateDocument(id, doc, false, seqNo, primaryTerm); + } + + /** + * Internal method for updating or upserting a document with optional optimistic concurrency control. + * + * @param id The document ID. + * @param doc The document content in JSON format. + * @param upsert Flag indicating whether to upsert the document. + * @param seqNo The sequence number for optimistic concurrency control. + * @param primaryTerm The primary term for optimistic concurrency control. + */ + private void updateDocument(String id, String doc, boolean upsert, long seqNo, long primaryTerm) { + // we might need to keep the updater for a long time. Reusing the client may not work as the temporary + // credentials may expire. + // also, failure to close the client causes the job to be stuck in the running state as the client resource + // is not released. try (IRestHighLevelClient client = flintClient.createClient()) { assertIndexExist(client, indexName); - UpdateRequest - updateRequest = - new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm); - client.update(updateRequest, RequestOptions.DEFAULT); + UpdateRequest updateRequest = new UpdateRequest(indexName, id) + .doc(doc, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + + if (upsert) { + updateRequest.docAsUpsert(true); + } + if (seqNo >= 0 && primaryTerm >= 0) { + updateRequest.setIfSeqNo(seqNo).setIfPrimaryTerm(primaryTerm); + } + + try { + client.update(updateRequest, RequestOptions.DEFAULT); + IRestHighLevelClient.recordOperationSuccess(REQUEST_METADATA_WRITE_METRIC_PREFIX); + } catch (Exception e) { + IRestHighLevelClient.recordOperationFailure(REQUEST_METADATA_WRITE_METRIC_PREFIX, e); + } } catch (IOException e) { throw new RuntimeException(String.format( "Failed to execute update request on index: %s, id: %s", - indexName, - id), e); + indexName, id), e); } } private void assertIndexExist(IRestHighLevelClient client, String indexName) throws IOException { - LOG.info("Checking if index exists " + indexName); - if (!client.isIndexExists(new GetIndexRequest(indexName), RequestOptions.DEFAULT)) { - String errorMsg = "Index not found " + indexName; + LOG.info("Checking if index exists: " + indexName); + boolean exists; + try { + exists = client.doesIndexExist(new GetIndexRequest(indexName), RequestOptions.DEFAULT); + IRestHighLevelClient.recordOperationSuccess(REQUEST_METADATA_READ_METRIC_PREFIX); + } catch (Exception e) { + IRestHighLevelClient.recordOperationFailure(REQUEST_METADATA_READ_METRIC_PREFIX, e); + throw e; + } + + if (!exists) { + String errorMsg = "Index not found: " + indexName; LOG.log(Level.SEVERE, errorMsg); throw new IllegalStateException(errorMsg); } } } + diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java index 09dee5eef..3b8940536 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java @@ -1,6 +1,7 @@ package org.opensearch.flint.core.metrics; import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; import com.codahale.metrics.Timer; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; @@ -10,6 +11,7 @@ import org.mockito.Mockito; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; @@ -42,7 +44,7 @@ public void testIncrementDecrementCounter() { // Verify interactions verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); - verify(flintMetricSource, times(4)).metricRegistry(); + verify(flintMetricSource, times(3)).metricRegistry(); Counter counter = flintMetricSource.metricRegistry().getCounters().get(testMetric); Assertions.assertNotNull(counter); Assertions.assertEquals(counter.getCount(), 1); @@ -69,7 +71,7 @@ public void testStartStopTimer() { // Verify interactions verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); - verify(flintMetricSource, times(2)).metricRegistry(); + verify(flintMetricSource, times(1)).metricRegistry(); Timer timer = flintMetricSource.metricRegistry().getTimers().get(testMetric); Assertions.assertNotNull(timer); Assertions.assertEquals(timer.getCount(), 1L); @@ -78,4 +80,35 @@ public void testStartStopTimer() { throw new RuntimeException(e); } } + + @Test + public void testRegisterGaugeWhenMetricRegistryIsAvailable() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + // Mock SparkEnv + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + // Mock FlintMetricSource + FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) + .thenReturn(flintMetricSource); + + // Setup gauge + AtomicInteger testValue = new AtomicInteger(1); + String gaugeName = "test.gauge"; + MetricsUtil.registerGauge(gaugeName, testValue); + + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(flintMetricSource, times(1)).metricRegistry(); + + Gauge gauge = flintMetricSource.metricRegistry().getGauges().get(gaugeName); + Assertions.assertNotNull(gauge); + Assertions.assertEquals(gauge.getValue(), 1); + + testValue.incrementAndGet(); + testValue.incrementAndGet(); + testValue.decrementAndGet(); + Assertions.assertEquals(gauge.getValue(), 2); + } + } } \ No newline at end of file diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java index 94760fc37..2178c3c22 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java @@ -5,12 +5,21 @@ package org.opensearch.flint.core.metrics.reporter; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.metrics.source.FlintMetricSource; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; import com.amazonaws.services.cloudwatch.model.Dimension; import org.junit.jupiter.api.function.Executable; +import org.mockito.MockedStatic; +import org.mockito.Mockito; import java.lang.reflect.Field; import java.util.Map; @@ -71,6 +80,23 @@ public void testGetDimensionsFromSystemEnv() throws NoSuchFieldException, Illega writeableEnvironmentVariables.remove("SERVERLESS_EMR_JOB_ID"); writeableEnvironmentVariables.remove("TEST_VAR"); } + } + + @Test + public void testConstructDimensionFromSparkConfWithAvailableConfig() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + SparkConf sparkConf = new SparkConf().set("spark.test.key", "testValue"); + when(sparkEnv.get().conf()).thenReturn(sparkConf); + Dimension result = DimensionUtils.constructDimensionFromSparkConf("testDimension", "spark.test.key", "defaultValue"); + // Assertions + assertEquals("testDimension", result.getName()); + assertEquals("testValue", result.getValue()); + + // Reset SparkEnv mock to not affect other tests + Mockito.reset(SparkEnv.get()); + } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index 5c4c7376c..2f44a28f4 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -12,6 +12,7 @@ import scala.sys.addShutdownHook import org.opensearch.flint.core.FlintClient import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -65,6 +66,7 @@ class FlintSparkIndexMonitor( } catch { case e: Throwable => logError("Failed to update index log entry", e) + MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC) } }, 15, // Delay to ensure final logging is complete first, otherwise version conflicts diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala index f070ef3ab..86bf567f5 100644 --- a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -5,6 +5,8 @@ package org.apache.spark.sql +import java.util.concurrent.atomic.AtomicInteger + import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.{Duration, MINUTES} @@ -67,10 +69,11 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { val prefix = "flint-job-test" val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + val streamingRunningCount = new AtomicInteger(0) val futureResult = Future { val job = - JobOperator(spark, query, dataSourceName, resultIndex, true) + JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount) job.envinromentProvider = new MockEnvironment( Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index df0bf5c4e..32747e20f 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -7,6 +7,7 @@ package org.apache.spark.sql import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger import org.opensearch.client.{RequestOptions, RestHighLevelClient} import org.opensearch.cluster.metadata.MappingMetadata @@ -14,11 +15,14 @@ import org.opensearch.common.settings.Settings import org.opensearch.common.xcontent.XContentType import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge import play.api.libs.json._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{StructField, _} /** @@ -41,7 +45,10 @@ object FlintJob extends Logging with FlintJobExecutor { val Array(query, resultIndex) = args val conf = createSparkConf() - val wait = conf.get("spark.flint.job.type", "continue") + val jobType = conf.get("spark.flint.job.type", "batch") + logInfo(s"""Job type is: ${jobType}""") + conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + val dataSource = conf.get("spark.flint.datasource.name", "") // https://github.com/opensearch-project/opensearch-spark/issues/138 /* @@ -52,13 +59,16 @@ object FlintJob extends Logging with FlintJobExecutor { * Without this setup, Spark would not recognize names in the format `my_glue1.default`. */ conf.set("spark.sql.defaultCatalog", dataSource) + val streamingRunningCount = new AtomicInteger(0) val jobOperator = JobOperator( createSparkSession(conf), query, dataSource, resultIndex, - wait.equalsIgnoreCase("streaming")) + jobType.equalsIgnoreCase("streaming"), + streamingRunningCount) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) jobOperator.start() } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 4aeb0db17..1814a8d8e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -11,7 +11,7 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} import com.amazonaws.services.s3.model.AmazonS3Exception -import org.opensearch.flint.core.FlintClient +import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient} import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter @@ -19,7 +19,6 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch} import org.apache.spark.sql.FlintREPL.envinromentProvider import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.DataSource @@ -100,10 +99,19 @@ trait FlintJobExecutor { } private def writeData(resultData: DataFrame, resultIndex: String): Unit = { - resultData.write - .format("flint") - .mode("append") - .save(resultIndex) + try { + resultData.write + .format("flint") + .mode("append") + .save(resultIndex) + IRestHighLevelClient.recordOperationSuccess( + MetricConstants.RESULT_METADATA_WRITE_METRIC_PREFIX) + } catch { + case e: Exception => + IRestHighLevelClient.recordOperationFailure( + MetricConstants.RESULT_METADATA_WRITE_METRIC_PREFIX, + e) + } } /** @@ -123,7 +131,7 @@ trait FlintJobExecutor { if (osClient.doesIndexExist(resultIndex)) { writeData(resultData, resultIndex) } else { - createIndex(osClient, resultIndex, resultIndexMapping) + createResultIndex(osClient, resultIndex, resultIndexMapping) writeData(resultData, resultIndex) } } @@ -321,7 +329,7 @@ trait FlintJobExecutor { case e: IllegalStateException if e.getCause != null && e.getCause.getMessage.contains("index_not_found_exception") => - createIndex(osClient, resultIndex, resultIndexMapping) + createResultIndex(osClient, resultIndex, resultIndexMapping) case e: InterruptedException => val error = s"Interrupted by the main thread: ${e.getMessage}" Thread.currentThread().interrupt() // Preserve the interrupt status @@ -334,7 +342,7 @@ trait FlintJobExecutor { } } - def createIndex( + def createResultIndex( osClient: OSClient, resultIndex: String, mapping: String): Either[String, Unit] = { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 093ce1932..d30669cca 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -7,6 +7,7 @@ package org.apache.spark.sql import java.net.ConnectException import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} import scala.concurrent.duration._ @@ -21,7 +22,7 @@ import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.app.FlintInstance.formats import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.metrics.MetricConstants -import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, stopTimer} +import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} import org.opensearch.search.sort.SortOrder @@ -63,6 +64,9 @@ object FlintREPL extends Logging with FlintJobExecutor { updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) } + private val sessionRunningCount = new AtomicInteger(0) + private val statementRunningCount = new AtomicInteger(0) + def main(args: Array[String]) { val Array(query, resultIndex) = args if (Strings.isNullOrEmpty(resultIndex)) { @@ -81,12 +85,23 @@ object FlintREPL extends Logging with FlintJobExecutor { * Without this setup, Spark would not recognize names in the format `my_glue1.default`. */ conf.set("spark.sql.defaultCatalog", dataSource) - val wait = conf.get(FlintSparkConf.JOB_TYPE.key, "continue") - if (wait.equalsIgnoreCase("streaming")) { + val jobType = conf.get(FlintSparkConf.JOB_TYPE.key, FlintSparkConf.JOB_TYPE.defaultValue.get) + logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""") + conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + + if (jobType.equalsIgnoreCase("streaming")) { logInfo(s"""streaming query ${query}""") + val streamingRunningCount = new AtomicInteger(0) val jobOperator = - JobOperator(createSparkSession(conf), query, dataSource, resultIndex, true) + JobOperator( + createSparkSession(conf), + query, + dataSource, + resultIndex, + true, + streamingRunningCount) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) jobOperator.start() } else { // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. @@ -133,6 +148,8 @@ object FlintREPL extends Logging with FlintJobExecutor { val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) + registerGauge(MetricConstants.REPL_RUNNING_METRIC, sessionRunningCount) + registerGauge(MetricConstants.STATEMENT_RUNNING_METRIC, statementRunningCount) val jobStartTime = currentTimeProvider.currentEpochMillis() // update heart beat every 30 seconds // OpenSearch triggers recovery after 1 minute outdated heart beat @@ -386,9 +403,9 @@ object FlintREPL extends Logging with FlintJobExecutor { FlintInstance.serializeWithoutJobId(flintJob, currentTime) } flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) - incrementCounter(MetricConstants.REPL_RUNNING_METRIC) logInfo( s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") + sessionRunningCount.incrementAndGet() } def handleSessionError( @@ -400,7 +417,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, sessionIndex: String, - flintSessionContext: Timer.Context): Unit = { + sessionTimerContext: Timer.Context): Unit = { val error = s"Session error: ${e.getMessage}" logError(error, e) @@ -409,7 +426,7 @@ object FlintREPL extends Logging with FlintJobExecutor { updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) if (flintInstance.state.equals("fail")) { - recordSessionFailed(flintSessionContext) + recordSessionFailed(sessionTimerContext) } } @@ -588,7 +605,7 @@ object FlintREPL extends Logging with FlintJobExecutor { resultIndex: String, flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, - statementContext: Timer.Context): Unit = { + statementTimerContext: Timer.Context): Unit = { try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) if (flintCommand.isRunning() || flintCommand.isWaiting()) { @@ -596,7 +613,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintCommand.complete() } updateSessionIndex(flintCommand, flintSessionIndexUpdater) - recordStatementStateChange(flintCommand, statementContext) + recordStatementStateChange(flintCommand, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) @@ -605,7 +622,7 @@ object FlintREPL extends Logging with FlintJobExecutor { logError(error, e) flintCommand.fail() updateSessionIndex(flintCommand, flintSessionIndexUpdater) - recordStatementStateChange(flintCommand, statementContext) + recordStatementStateChange(flintCommand, statementTimerContext) } } @@ -801,7 +818,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintCommand.running() logDebug(s"command running: $flintCommand") updateSessionIndex(flintCommand, flintSessionIndexUpdater) - incrementCounter(MetricConstants.STATEMENT_RUNNING_METRIC) + statementRunningCount.incrementAndGet() flintCommand } @@ -855,7 +872,7 @@ object FlintREPL extends Logging with FlintJobExecutor { osClient: OSClient, sessionIndex: String, sessionId: String, - flintSessionContext: Timer.Context, + sessionTimerContext: Timer.Context, shutdownHookManager: ShutdownHookManagerTrait = DefaultShutdownHookManager): Unit = { shutdownHookManager.addShutdownHook(() => { @@ -885,7 +902,7 @@ object FlintREPL extends Logging with FlintJobExecutor { getResponse, flintSessionIndexUpdater, sessionId, - flintSessionContext) + sessionTimerContext) } }) } @@ -895,7 +912,7 @@ object FlintREPL extends Logging with FlintJobExecutor { getResponse: GetResponse, flintSessionIndexUpdater: OpenSearchUpdater, sessionId: String, - flintSessionContext: Timer.Context): Unit = { + sessionTimerContext: Timer.Context): Unit = { val flintInstance = FlintInstance.deserializeFromMap(source) flintInstance.state = "dead" flintSessionIndexUpdater.updateIf( @@ -905,7 +922,7 @@ object FlintREPL extends Logging with FlintJobExecutor { currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, getResponse.getPrimaryTerm) - recordSessionSuccess(flintSessionContext) + recordSessionSuccess(sessionTimerContext) } /** @@ -953,11 +970,17 @@ object FlintREPL extends Logging with FlintJobExecutor { // Preserve the interrupt status Thread.currentThread().interrupt() logError("HeartBeatUpdater task was interrupted", ie) + incrementCounter( + MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC + ) // Record heartbeat failure metric // maybe due to invalid sequence number or primary term case e: Exception => logWarning( s"""Fail to update the last update time of the flint instance ${sessionId}""", e) + incrementCounter( + MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC + ) // Record heartbeat failure metric } } }, @@ -1068,23 +1091,29 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } - private def recordSessionSuccess(sessionContext: Timer.Context): Unit = { - stopTimer(sessionContext) - decrementCounter(MetricConstants.REPL_RUNNING_METRIC) + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { + stopTimer(sessionTimerContext) + if (sessionRunningCount.get() > 0) { + sessionRunningCount.decrementAndGet() + } incrementCounter(MetricConstants.REPL_SUCCESS_METRIC) } - private def recordSessionFailed(sessionContext: Timer.Context): Unit = { - stopTimer(sessionContext) - decrementCounter(MetricConstants.REPL_RUNNING_METRIC) + private def recordSessionFailed(sessionTimerContext: Timer.Context): Unit = { + stopTimer(sessionTimerContext) + if (sessionRunningCount.get() > 0) { + sessionRunningCount.decrementAndGet() + } incrementCounter(MetricConstants.REPL_FAILED_METRIC) } private def recordStatementStateChange( flintCommand: FlintCommand, - statementContext: Timer.Context): Unit = { - stopTimer(statementContext) - decrementCounter(MetricConstants.STATEMENT_RUNNING_METRIC) + statementTimerContext: Timer.Context): Unit = { + stopTimer(statementTimerContext) + if (statementRunningCount.get() > 0) { + statementRunningCount.decrementAndGet() + } if (flintCommand.isComplete()) { incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC) } else if (flintCommand.isFailed()) { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index a2edbe98e..bbaceb15d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -6,11 +6,14 @@ package org.apache.spark.sql import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} import scala.util.{Failure, Success, Try} +import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import org.opensearch.flint.core.storage.OpenSearchUpdater import org.apache.spark.SparkConf @@ -25,11 +28,12 @@ case class JobOperator( query: String, dataSource: String, resultIndex: String, - streaming: Boolean) + streaming: Boolean, + streamingRunningCount: AtomicInteger) extends Logging with FlintJobExecutor { - // jvm shutdown hook + // JVM shutdown hook sys.addShutdownHook(stop()) def start(): Unit = { @@ -37,7 +41,10 @@ case class JobOperator( implicit val executionContext = ExecutionContext.fromExecutor(threadPool) var dataToWrite: Option[DataFrame] = None + val startTime = System.currentTimeMillis() + streamingRunningCount.incrementAndGet() + // osClient needs spark session to be created first to get FlintOptions initialized. // Otherwise, we will have connection exception from EMR-S to OS. val osClient = new OSClient(FlintSparkConf().flintOptions()) @@ -98,6 +105,7 @@ case class JobOperator( } catch { case e: Exception => logError("Fail to close threadpool", e) } + recordStreamingCompletionStatus(exceptionThrown) } def stop(): Unit = { @@ -109,4 +117,24 @@ case class JobOperator( case Failure(e) => logError("unexpected error while stopping spark session", e) } } + + /** + * Records the completion of a streaming job by updating the appropriate metrics. This method + * decrements the running metric for streaming jobs and increments either the success or failure + * metric based on whether an exception was thrown. + * + * @param exceptionThrown + * Indicates whether an exception was thrown during the streaming job execution. + */ + private def recordStreamingCompletionStatus(exceptionThrown: Boolean): Unit = { + // Decrement the metric for running streaming jobs as the job is now completing. + if (streamingRunningCount.get() > 0) { + streamingRunningCount.decrementAndGet() + } + + exceptionThrown match { + case true => incrementCounter(MetricConstants.STREAMING_FAILED_METRIC) + case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC) + } + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala index cd784e704..f5e4ec2be 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala @@ -9,6 +9,8 @@ import java.io.IOException import java.util.ArrayList import java.util.Locale +import scala.util.{Failure, Success, Try} + import org.opensearch.action.get.{GetRequest, GetResponse} import org.opensearch.action.search.{SearchRequest, SearchResponse} import org.opensearch.client.{RequestOptions, RestHighLevelClient} @@ -18,7 +20,8 @@ import org.opensearch.common.Strings import org.opensearch.common.settings.Settings import org.opensearch.common.xcontent.{NamedXContentRegistry, XContentParser, XContentType} import org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS -import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} +import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions, IRestHighLevelClient} +import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.storage.{FlintReader, OpenSearchQueryReader, OpenSearchScrollReader, OpenSearchUpdater} import org.opensearch.index.query.{AbstractQueryBuilder, MatchAllQueryBuilder, QueryBuilder} import org.opensearch.plugins.SearchPlugin @@ -74,8 +77,13 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { try { client.createIndex(request, RequestOptions.DEFAULT) logInfo(s"create $osIndexName successfully") + IRestHighLevelClient.recordOperationSuccess( + MetricConstants.RESULT_METADATA_WRITE_METRIC_PREFIX) } catch { case e: Exception => + IRestHighLevelClient.recordOperationFailure( + MetricConstants.RESULT_METADATA_WRITE_METRIC_PREFIX, + e) throw new IllegalStateException(s"Failed to create index $osIndexName", e); } } @@ -109,11 +117,17 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { def getDoc(osIndexName: String, id: String): GetResponse = { using(flintClient.createClient()) { client => - try { - val request = new GetRequest(osIndexName, id) - client.get(request, RequestOptions.DEFAULT) - } catch { - case e: Exception => + val request = new GetRequest(osIndexName, id) + val result = Try(client.get(request, RequestOptions.DEFAULT)) + result match { + case Success(response) => + IRestHighLevelClient.recordOperationSuccess( + MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX) + response + case Failure(e: Exception) => + IRestHighLevelClient.recordOperationFailure( + MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX, + e) throw new IllegalStateException( String.format( Locale.ROOT, @@ -146,7 +160,7 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { using(flintClient.createClient()) { client => try { val request = new GetIndexRequest(indexName) - client.isIndexExists(request, RequestOptions.DEFAULT) + client.doesIndexExist(request, RequestOptions.DEFAULT) } catch { case e: Exception => throw new IllegalStateException(s"Failed to check if index $indexName exists", e)