diff --git a/.gitignore b/.gitignore index 34602292e..bc1705ce3 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,5 @@ logs/ *.class *.log *.zip + +gen/ diff --git a/README.md b/README.md index b2709e348..83189dcb1 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Version compatibility: |---------------|-------------|---------------|---------------|------------| | 0.1.0 | 11+ | 3.3.1 | 2.12.14 | 2.6+ | | 0.2.0 | 11+ | 3.3.1 | 2.12.14 | 2.6+ | +| 0.3.0 | 11+ | 3.3.2 | 2.12.14 | 2.6+ | ## Flint Extension Usage @@ -50,7 +51,7 @@ sbt clean standaloneCosmetic/publishM2 ``` then add org.opensearch:opensearch-spark_2.12 when run spark application, for example, ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark_2.12:0.2.0-SNAPSHOT" +bin/spark-shell --packages "org.opensearch:opensearch-spark_2.12:0.3.0-SNAPSHOT" ``` ### PPL Build & Run @@ -62,7 +63,7 @@ sbt clean sparkPPLCosmetic/publishM2 ``` then add org.opensearch:opensearch-spark_2.12 when run spark application, for example, ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.2.0-SNAPSHOT" +bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.3.0-SNAPSHOT" ``` ## Code of Conduct diff --git a/build.sbt b/build.sbt index 48e4bca5b..95324fc99 100644 --- a/build.sbt +++ b/build.sbt @@ -10,7 +10,7 @@ lazy val opensearchVersion = "2.6.0" ThisBuild / organization := "org.opensearch" -ThisBuild / version := "0.2.0-SNAPSHOT" +ThisBuild / version := "0.3.0-SNAPSHOT" ThisBuild / scalaVersion := scala212 @@ -42,8 +42,9 @@ lazy val commonSettings = Seq( testScalastyle := (Test / scalastyle).toTask("").value, Test / test := ((Test / test) dependsOn testScalastyle).value) +// running `scalafmtAll` includes all subprojects under root lazy val root = (project in file(".")) - .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication) + .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest) .disablePlugins(AssemblyPlugin) .settings(name := "flint", publish / skip := true) @@ -159,7 +160,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) // Test assembly package with integration test. 
lazy val integtest = (project in file("integ-test")) - .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test" ) + .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test") .settings( commonSettings, name := "integ-test", @@ -175,7 +176,9 @@ lazy val integtest = (project in file("integ-test")) "org.opensearch.client" % "opensearch-java" % "2.6.0" % "test" exclude ("com.fasterxml.jackson.core", "jackson-databind")), libraryDependencies ++= deps(sparkVersion), - Test / fullClasspath ++= Seq((flintSparkIntegration / assembly).value, (pplSparkIntegration / assembly).value)) + Test / fullClasspath ++= Seq((flintSparkIntegration / assembly).value, (pplSparkIntegration / assembly).value, + (sparkSqlApplication / assembly).value + )) lazy val standaloneCosmetic = project .settings( diff --git a/docs/PPL-on-Spark.md b/docs/PPL-on-Spark.md index f421d5679..a7e539d23 100644 --- a/docs/PPL-on-Spark.md +++ b/docs/PPL-on-Spark.md @@ -34,7 +34,7 @@ sbt clean sparkPPLCosmetic/publishM2 ``` then add org.opensearch:opensearch-spark_2.12 when run spark application, for example, ``` -bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.2.0-SNAPSHOT" +bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.3.0-SNAPSHOT" ``` ### PPL Extension Usage @@ -46,7 +46,7 @@ spark-sql --conf "spark.sql.extensions=org.opensearch.flint.FlintPPLSparkExtensi ``` ### Running With both Flint & PPL Extensions -In order to make use of both flint and ppl extension, one can simply add both jars (`org.opensearch:opensearch-spark-ppl_2.12:0.2.0-SNAPSHOT`,`org.opensearch:opensearch-spark_2.12:0.2.0-SNAPSHOT`) to the cluster's +In order to make use of both flint and ppl extension, one can simply add both jars (`org.opensearch:opensearch-spark-ppl_2.12:0.3.0-SNAPSHOT`,`org.opensearch:opensearch-spark_2.12:0.3.0-SNAPSHOT`) to the cluster's classpath. Next need to configure both extensions : diff --git a/docs/index.md b/docs/index.md index 756e4bba4..31147aed4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -54,7 +54,7 @@ Currently, Flint metadata is only static configuration without version control a ```json { - "version": "0.2.0", + "version": "0.3.0", "name": "...", "kind": "skipping", "source": "...", @@ -570,7 +570,7 @@ For now, only single or conjunct conditions (conditions connected by AND) in WHE ### AWS EMR Spark Integration - Using execution role Flint use [DefaultAWSCredentialsProviderChain](https://docs.aws.amazon.com/AWSJavaSDK/latest/javadoc/com/amazonaws/auth/DefaultAWSCredentialsProviderChain.html). When running in EMR Spark, Flint use executionRole credentials ``` ---conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.2.0-SNAPSHOT \ +--conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.3.0-SNAPSHOT \ --conf spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots \ --conf spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ @@ -612,7 +612,7 @@ Flint use [DefaultAWSCredentialsProviderChain](https://docs.aws.amazon.com/AWSJa ``` 3. Set the spark.datasource.flint.customAWSCredentialsProvider property with value as com.amazonaws.emr.AssumeRoleAWSCredentialsProvider. Set the environment variable ASSUME_ROLE_CREDENTIALS_ROLE_ARN with the ARN value of CrossAccountRoleB. 
``` ---conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.2.0-SNAPSHOT \ +--conf spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.3.0-SNAPSHOT \ --conf spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots \ --conf spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64 \ diff --git a/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java b/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java index ce35c34e8..23205fe99 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java @@ -5,6 +5,7 @@ package org.opensearch.flint.core; +import org.opensearch.OpenSearchException; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.delete.DeleteRequest; @@ -26,6 +27,7 @@ import org.opensearch.client.indices.GetIndexResponse; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.client.RequestOptions; +import org.opensearch.flint.core.metrics.MetricsUtil; import java.io.Closeable; import java.io.IOException; @@ -52,11 +54,62 @@ public interface IRestHighLevelClient extends Closeable { IndexResponse index(IndexRequest indexRequest, RequestOptions options) throws IOException; - Boolean isIndexExists(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException; + Boolean doesIndexExist(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException; SearchResponse search(SearchRequest searchRequest, RequestOptions options) throws IOException; SearchResponse scroll(SearchScrollRequest searchScrollRequest, RequestOptions options) throws IOException; DocWriteResponse update(UpdateRequest updateRequest, RequestOptions options) throws IOException; + + + /** + * Records the success of an OpenSearch operation by incrementing the corresponding metric counter. + * This method constructs the metric name by appending ".2xx.count" to the provided metric name prefix. + * The metric name is then used to increment the counter, indicating a successful operation. + * + * @param metricNamePrefix the prefix for the metric name which is used to construct the full metric name for success + */ + static void recordOperationSuccess(String metricNamePrefix) { + String successMetricName = metricNamePrefix + ".2xx.count"; + MetricsUtil.incrementCounter(successMetricName); + } + + /** + * Records the failure of an OpenSearch operation by incrementing the corresponding metric counter. + * If the exception is an OpenSearchException with a specific status code (e.g., 403), + * it increments a metric specifically for that status code. + * Otherwise, it increments a general failure metric counter based on the status code category (e.g., 4xx, 5xx). + * + * @param metricNamePrefix the prefix for the metric name which is used to construct the full metric name for failure + * @param e the exception encountered during the operation, used to determine the type of failure + */ + static void recordOperationFailure(String metricNamePrefix, Exception e) { + OpenSearchException openSearchException = extractOpenSearchException(e); + int statusCode = openSearchException != null ? 
openSearchException.status().getStatus() : 500; + + if (statusCode == 403) { + String forbiddenErrorMetricName = metricNamePrefix + ".403.count"; + MetricsUtil.incrementCounter(forbiddenErrorMetricName); + } + + String failureMetricName = metricNamePrefix + "." + (statusCode / 100) + "xx.count"; + MetricsUtil.incrementCounter(failureMetricName); + } + + /** + * Extracts an OpenSearchException from the given Throwable. + * Checks if the Throwable is an instance of OpenSearchException or caused by one. + * + * @param ex the exception to be checked + * @return the extracted OpenSearchException, or null if not found + */ + private static OpenSearchException extractOpenSearchException(Throwable ex) { + if (ex instanceof OpenSearchException) { + return (OpenSearchException) ex; + } else if (ex.getCause() instanceof OpenSearchException) { + return (OpenSearchException) ex.getCause(); + } + return null; + } } diff --git a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java index 3556c7e24..bf48af52d 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java @@ -20,7 +20,6 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.OpenSearchException; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; @@ -28,7 +27,6 @@ import org.opensearch.client.indices.CreateIndexResponse; import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.client.indices.GetIndexResponse; -import org.opensearch.flint.core.metrics.MetricsUtil; import java.io.IOException; @@ -91,7 +89,7 @@ public IndexResponse index(IndexRequest indexRequest, RequestOptions options) th } @Override - public Boolean isIndexExists(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException { + public Boolean doesIndexExist(GetIndexRequest getIndexRequest, RequestOptions options) throws IOException { return execute(OS_READ_OP_METRIC_PREFIX, () -> client.indices().exists(getIndexRequest, options)); } @@ -122,64 +120,14 @@ public UpdateResponse update(UpdateRequest updateRequest, RequestOptions options private T execute(String metricNamePrefix, IOCallable operation) throws IOException { try { T result = operation.call(); - recordOperationSuccess(metricNamePrefix); + IRestHighLevelClient.recordOperationSuccess(metricNamePrefix); return result; } catch (Exception e) { - recordOperationFailure(metricNamePrefix, e); + IRestHighLevelClient.recordOperationFailure(metricNamePrefix, e); throw e; } } - /** - * Records the success of an OpenSearch operation by incrementing the corresponding metric counter. - * This method constructs the metric name by appending ".200.count" to the provided metric name prefix. - * The metric name is then used to increment the counter, indicating a successful operation. 
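The static helpers moved onto `IRestHighLevelClient` derive counter names purely from a metric prefix and the HTTP status class of the failure. A minimal sketch of that naming rule in plain Java, with no Spark or OpenSearch dependencies (the `failureMetricNames` helper below is hypothetical and only mirrors the logic in `recordOperationFailure`):

```java
import java.util.ArrayList;
import java.util.List;

public class MetricNameDemo {

    // Mirrors recordOperationFailure: an optional dedicated 403 counter plus a status-class counter.
    static List<String> failureMetricNames(String metricNamePrefix, int statusCode) {
        List<String> names = new ArrayList<>();
        if (statusCode == 403) {
            names.add(metricNamePrefix + ".403.count");
        }
        names.add(metricNamePrefix + "." + (statusCode / 100) + "xx.count");
        return names;
    }

    public static void main(String[] args) {
        // Success always increments "<prefix>.2xx.count", e.g. "opensearch.write.2xx.count".
        System.out.println(failureMetricNames("opensearch.write", 403)); // [opensearch.write.403.count, opensearch.write.4xx.count]
        System.out.println(failureMetricNames("opensearch.write", 500)); // [opensearch.write.5xx.count]
    }
}
```

Because non-OpenSearch exceptions default to status 500, every failure lands in exactly one status-class bucket, with 403s additionally counted on their own.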
- * - * @param metricNamePrefix the prefix for the metric name which is used to construct the full metric name for success - */ - private void recordOperationSuccess(String metricNamePrefix) { - String successMetricName = metricNamePrefix + ".2xx.count"; - MetricsUtil.incrementCounter(successMetricName); - } - - /** - * Records the failure of an OpenSearch operation by incrementing the corresponding metric counter. - * If the exception is an OpenSearchException with a specific status code (e.g., 403), - * it increments a metric specifically for that status code. - * Otherwise, it increments a general failure metric counter based on the status code category (e.g., 4xx, 5xx). - * - * @param metricNamePrefix the prefix for the metric name which is used to construct the full metric name for failure - * @param e the exception encountered during the operation, used to determine the type of failure - */ - private void recordOperationFailure(String metricNamePrefix, Exception e) { - OpenSearchException openSearchException = extractOpenSearchException(e); - int statusCode = openSearchException != null ? openSearchException.status().getStatus() : 500; - - if (statusCode == 403) { - String forbiddenErrorMetricName = metricNamePrefix + ".403.count"; - MetricsUtil.incrementCounter(forbiddenErrorMetricName); - } - - String failureMetricName = metricNamePrefix + "." + (statusCode / 100) + "xx.count"; - MetricsUtil.incrementCounter(failureMetricName); - } - - /** - * Extracts an OpenSearchException from the given Throwable. - * Checks if the Throwable is an instance of OpenSearchException or caused by one. - * - * @param ex the exception to be checked - * @return the extracted OpenSearchException, or null if not found - */ - private OpenSearchException extractOpenSearchException(Throwable ex) { - if (ex instanceof OpenSearchException) { - return (OpenSearchException) ex; - } else if (ex.getCause() instanceof OpenSearchException) { - return (OpenSearchException) ex.getCause(); - } - return null; - } - /** * Functional interface for operations that can throw IOException. * diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index d34a3705d..6a081a740 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -8,7 +8,7 @@ /** * This class defines custom metric constants used for monitoring flint operations. */ -public class MetricConstants { +public final class MetricConstants { /** * The prefix for all read-related metrics in OpenSearch. @@ -21,4 +21,93 @@ public class MetricConstants { * Similar to OS_READ_METRIC_PREFIX, this constant is used for categorizing and identifying metrics that pertain to write operations. */ public static final String OS_WRITE_OP_METRIC_PREFIX = "opensearch.write"; + + /** + * Metric name for counting the errors encountered with Amazon S3 operations. + */ + public static final String S3_ERR_CNT_METRIC = "s3.error.count"; + + /** + * Metric name for counting the number of sessions currently running. + */ + public static final String REPL_RUNNING_METRIC = "session.running.count"; + + /** + * Metric name for counting the number of sessions that have failed. + */ + public static final String REPL_FAILED_METRIC = "session.failed.count"; + + /** + * Metric name for counting the number of sessions that have successfully completed. 
+ */ + public static final String REPL_SUCCESS_METRIC = "session.success.count"; + + /** + * Metric name for tracking the processing time of sessions. + */ + public static final String REPL_PROCESSING_TIME_METRIC = "session.processingTime"; + + /** + * Prefix for metrics related to the request metadata read operations. + */ + public static final String REQUEST_METADATA_READ_METRIC_PREFIX = "request.metadata.read"; + + /** + * Prefix for metrics related to the request metadata write operations. + */ + public static final String REQUEST_METADATA_WRITE_METRIC_PREFIX = "request.metadata.write"; + + /** + * Metric name for counting failed heartbeat operations on request metadata. + */ + public static final String REQUEST_METADATA_HEARTBEAT_FAILED_METRIC = "request.metadata.heartbeat.failed.count"; + + /** + * Prefix for metrics related to the result metadata write operations. + */ + public static final String RESULT_METADATA_WRITE_METRIC_PREFIX = "result.metadata.write"; + + /** + * Metric name for counting the number of statements currently running. + */ + public static final String STATEMENT_RUNNING_METRIC = "statement.running.count"; + + /** + * Metric name for counting the number of statements that have failed. + */ + public static final String STATEMENT_FAILED_METRIC = "statement.failed.count"; + + /** + * Metric name for counting the number of statements that have successfully completed. + */ + public static final String STATEMENT_SUCCESS_METRIC = "statement.success.count"; + + /** + * Metric name for tracking the processing time of statements. + */ + public static final String STATEMENT_PROCESSING_TIME_METRIC = "statement.processingTime"; + + /** + * Metric for tracking the count of currently running streaming jobs. + */ + public static final String STREAMING_RUNNING_METRIC = "streaming.running.count"; + + /** + * Metric for tracking the count of streaming jobs that have failed. + */ + public static final String STREAMING_FAILED_METRIC = "streaming.failed.count"; + + /** + * Metric for tracking the count of streaming jobs that have completed successfully. + */ + public static final String STREAMING_SUCCESS_METRIC = "streaming.success.count"; + + /** + * Metric for tracking the count of failed heartbeat signals in streaming jobs. 
+ */ + public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count"; + + private MetricConstants() { + // Private constructor to prevent instantiation + } } \ No newline at end of file diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java index 0edce3e36..8e63992f5 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java @@ -6,11 +6,15 @@ package org.opensearch.flint.core.metrics; import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Timer; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; import org.apache.spark.metrics.source.Source; import scala.collection.Seq; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; /** @@ -20,8 +24,8 @@ public final class MetricsUtil { private static final Logger LOG = Logger.getLogger(MetricsUtil.class.getName()); - // Private constructor to prevent instantiation private MetricsUtil() { + // Private constructor to prevent instantiation } /** @@ -38,20 +42,79 @@ public static void incrementCounter(String metricName) { } } + /** + * Decrements the value of the specified metric counter by one, if the counter exists and its current count is greater than zero. + * + * @param metricName The name of the metric counter to be decremented. + */ + public static void decrementCounter(String metricName) { + Counter counter = getOrCreateCounter(metricName); + if (counter != null && counter.getCount() > 0) { + counter.dec(); + } + } + + /** + * Retrieves a {@link Timer.Context} for the specified metric name, creating a new timer if one does not already exist. + * This context can be used to measure the duration of a particular operation or event. + * + * @param metricName The name of the metric timer to retrieve the context for. + * @return A {@link Timer.Context} instance for timing operations, or {@code null} if the timer could not be created or retrieved. + */ + public static Timer.Context getTimerContext(String metricName) { + Timer timer = getOrCreateTimer(metricName); + return timer != null ? timer.time() : null; + } + + /** + * Stops the timer associated with the given {@link Timer.Context}, effectively recording the elapsed time since the timer was started + * and returning the duration. If the context is {@code null}, this method does nothing and returns {@code null}. + * + * @param context The {@link Timer.Context} to stop. May be {@code null}, in which case this method has no effect and returns {@code null}. + * @return The elapsed time in nanoseconds since the timer was started, or {@code null} if the context was {@code null}. + */ + public static Long stopTimer(Timer.Context context) { + return context != null ? context.stop() : null; + } + + /** + * Registers a gauge metric with the provided name and value. + * The gauge will reflect the current value of the AtomicInteger provided. + * + * @param metricName The name of the gauge metric to register. + * @param value The AtomicInteger whose current value should be reflected by the gauge. 
+ */ + public static void registerGauge(String metricName, final AtomicInteger value) { + MetricRegistry metricRegistry = getMetricRegistry(); + if (metricRegistry == null) { + LOG.warning("MetricRegistry not available, cannot register gauge: " + metricName); + return; + } + metricRegistry.register(metricName, (Gauge) value::get); + } + // Retrieves or creates a new counter for the given metric name private static Counter getOrCreateCounter(String metricName) { + MetricRegistry metricRegistry = getMetricRegistry(); + return metricRegistry != null ? metricRegistry.counter(metricName) : null; + } + + // Retrieves or creates a new Timer for the given metric name + private static Timer getOrCreateTimer(String metricName) { + MetricRegistry metricRegistry = getMetricRegistry(); + return metricRegistry != null ? metricRegistry.timer(metricName) : null; + } + + // Retrieves the MetricRegistry from the current Spark environment. + private static MetricRegistry getMetricRegistry() { SparkEnv sparkEnv = SparkEnv.get(); if (sparkEnv == null) { - LOG.warning("Spark environment not available, cannot instrument metric: " + metricName); + LOG.warning("Spark environment not available, cannot access MetricRegistry."); return null; } FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv); - Counter counter = flintMetricSource.metricRegistry().getCounters().get(metricName); - if (counter == null) { - counter = flintMetricSource.metricRegistry().counter(metricName); - } - return counter; + return flintMetricSource.metricRegistry(); } // Gets or initializes the FlintMetricSource diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java index ce7136507..6e3e90916 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java @@ -7,9 +7,13 @@ import java.util.Map; import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.Logger; + import org.apache.commons.lang.StringUtils; import com.amazonaws.services.cloudwatch.model.Dimension; +import org.apache.spark.SparkEnv; /** * Utility class for creating and managing CloudWatch dimensions for metrics reporting in Flint. @@ -18,7 +22,9 @@ * application ID, and more. 
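With the registry lookup factored out, `MetricsUtil` now exposes counters, timers, and gauges behind the same `FlintMetricSource`. A hedged usage sketch follows; it assumes the code runs inside a Spark driver or executor where `SparkEnv.get()` is available (otherwise each call logs a warning and becomes a no-op, and the timer context is `null`), and the pairing of metric names with instrument types is only illustrative:

```java
import com.codahale.metrics.Timer;
import java.util.concurrent.atomic.AtomicInteger;
import org.opensearch.flint.core.metrics.MetricsUtil;

public class MetricsUsageSketch {

    // Backing value for a gauge; the registered gauge always reports the current value.
    private static final AtomicInteger RUNNING = new AtomicInteger(0);

    public static void run() throws InterruptedException {
        MetricsUtil.registerGauge("streaming.running.count", RUNNING);
        RUNNING.incrementAndGet();                                // gauge now reads 1

        MetricsUtil.incrementCounter("statement.running.count");
        Timer.Context ctx = MetricsUtil.getTimerContext("statement.processingTime");

        Thread.sleep(100);                                        // the work being timed

        Long elapsedNanos = MetricsUtil.stopTimer(ctx);           // null if no SparkEnv was available
        MetricsUtil.decrementCounter("statement.running.count");  // only decrements while count > 0
        MetricsUtil.incrementCounter("statement.success.count");
        RUNNING.decrementAndGet();                                // gauge back to 0
        System.out.println("statement took " + elapsedNanos + " ns");
    }
}
```

The constants in `MetricConstants` (e.g. `STATEMENT_PROCESSING_TIME_METRIC`) are the intended way to reference these names; string literals are used above only to keep the sketch self-contained.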
*/ public class DimensionUtils { + private static final Logger LOG = Logger.getLogger(DimensionUtils.class.getName()); private static final String DIMENSION_JOB_ID = "jobId"; + private static final String DIMENSION_JOB_TYPE = "jobType"; private static final String DIMENSION_APPLICATION_ID = "applicationId"; private static final String DIMENSION_APPLICATION_NAME = "applicationName"; private static final String DIMENSION_DOMAIN_ID = "domainId"; @@ -29,6 +35,8 @@ public class DimensionUtils { private static final Map> dimensionBuilders = Map.of( DIMENSION_INSTANCE_ROLE, DimensionUtils::getInstanceRoleDimension, DIMENSION_JOB_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_JOB_ID", DIMENSION_JOB_ID), + // TODO: Move FlintSparkConf into the core to prevent circular dependencies + DIMENSION_JOB_TYPE, ignored -> constructDimensionFromSparkConf(DIMENSION_JOB_TYPE, "spark.flint.job.type", UNKNOWN), DIMENSION_APPLICATION_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", DIMENSION_APPLICATION_ID), DIMENSION_APPLICATION_NAME, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_APPLICATION_NAME", DIMENSION_APPLICATION_NAME), DIMENSION_DOMAIN_ID, ignored -> getEnvironmentVariableDimension("FLINT_CLUSTER_NAME", DIMENSION_DOMAIN_ID) @@ -39,7 +47,7 @@ public class DimensionUtils { * builder exists for the dimension name, it is used; otherwise, a default dimension is constructed. * * @param dimensionName The name of the dimension to construct. - * @param parts Additional information that might be required by specific dimension builders. + * @param metricNameParts Additional information that might be required by specific dimension builders. * @return A CloudWatch Dimension object. */ public static Dimension constructDimension(String dimensionName, String[] metricNameParts) { @@ -50,6 +58,30 @@ public static Dimension constructDimension(String dimensionName, String[] metric .apply(metricNameParts); } + /** + * Constructs a CloudWatch Dimension object using a specified Spark configuration key. + * + * @param dimensionName The name of the dimension to construct. + * @param sparkConfKey the Spark configuration key used to look up the value for the dimension. + * @param defaultValue the default value to use for the dimension if the Spark configuration key is not found or if the Spark environment is not available. + * @return A CloudWatch Dimension object. + * @throws Exception if an error occurs while accessing the Spark configuration. The exception is logged and then rethrown. + */ + public static Dimension constructDimensionFromSparkConf(String dimensionName, String sparkConfKey, String defaultValue) { + String propertyValue = defaultValue; + try { + if (SparkEnv.get() != null && SparkEnv.get().conf() != null) { + propertyValue = SparkEnv.get().conf().get(sparkConfKey, defaultValue); + } else { + LOG.warning("Spark environment or configuration is not available, defaulting to provided default value."); + } + } catch (Exception e) { + LOG.log(Level.SEVERE, "Error accessing Spark configuration with key: " + sparkConfKey + ", defaulting to provided default value.", e); + throw e; + } + return new Dimension().withName(dimensionName).withValue(propertyValue); + } + // This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137 // Since we don't have access to Spark Configuration here: we are relying on the presence of executorId as part of the metricName. 
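The new `jobType` dimension is resolved from Spark configuration rather than from environment variables. A small sketch of how the public helper behaves, assuming `flint-core` and the AWS CloudWatch SDK are on the classpath; the `"UNKNOWN"` literal stands in for the `UNKNOWN` constant, whose value is not shown in this hunk:

```java
import com.amazonaws.services.cloudwatch.model.Dimension;
import org.opensearch.flint.core.metrics.reporter.DimensionUtils;

public class JobTypeDimensionSketch {
    public static void main(String[] args) {
        // Without a live SparkEnv (e.g. in a plain unit test), the helper logs a warning
        // and falls back to the supplied default value.
        Dimension jobType = DimensionUtils.constructDimensionFromSparkConf(
            "jobType", "spark.flint.job.type", "UNKNOWN");
        System.out.println(jobType.getName() + "=" + jobType.getValue()); // jobType=UNKNOWN

        // Inside a Flint job, spark.flint.job.type is populated via FlintSparkConf.JOB_TYPE
        // ("interactive" by default, "streaming" for streaming jobs), so the dimension
        // carries that value instead.
    }
}
```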
public static boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) { diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index c1c5491ed..410d896d2 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -77,6 +77,8 @@ public class FlintOptions implements Serializable { public static final int DEFAULT_SOCKET_TIMEOUT_MILLIS = 60000; + public static final int DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000; + public FlintOptions(Map options) { this.options = options; this.retryOptions = new FlintRetryOptions(options); diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintVersion.scala b/flint-core/src/main/scala/org/opensearch/flint/core/FlintVersion.scala index 2ff83d4b4..d226df24d 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintVersion.scala +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintVersion.scala @@ -16,6 +16,7 @@ case class FlintVersion(version: String) { object FlintVersion { val V_0_1_0: FlintVersion = FlintVersion("0.1.0") val V_0_2_0: FlintVersion = FlintVersion("0.2.0") + val V_0_3_0: FlintVersion = FlintVersion("0.3.0") - def current(): FlintVersion = V_0_2_0 + def current(): FlintVersion = V_0_3_0 } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 29ebad206..45aedbaa6 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -99,7 +99,7 @@ public OptimisticTransaction startTransaction(String indexName, String da String metaLogIndexName = dataSourceName.isEmpty() ? 
META_LOG_NAME_PREFIX : META_LOG_NAME_PREFIX + "_" + dataSourceName; try (IRestHighLevelClient client = createClient()) { - if (client.isIndexExists(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { + if (client.doesIndexExist(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { LOG.info("Found metadata log index " + metaLogIndexName); } else { if (forceInit) { @@ -149,7 +149,7 @@ public boolean exists(String indexName) { LOG.info("Checking if Flint index exists " + indexName); String osIndexName = sanitizeIndexName(indexName); try (IRestHighLevelClient client = createClient()) { - return client.isIndexExists(new GetIndexRequest(osIndexName), RequestOptions.DEFAULT); + return client.doesIndexExist(new GetIndexRequest(osIndexName), RequestOptions.DEFAULT); } catch (IOException e) { throw new IllegalStateException("Failed to check if Flint index exists " + osIndexName, e); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java index 9c1502b29..7195ae177 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java @@ -175,7 +175,7 @@ private FlintMetadataLogEntry writeLogEntry( private boolean exists() { LOG.info("Checking if Flint index exists " + metaLogIndexName); try (IRestHighLevelClient client = flintClient.createClient()) { - return client.isIndexExists(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT); + return client.doesIndexExist(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT); } catch (IOException e) { throw new IllegalStateException("Failed to check if Flint index exists " + metaLogIndexName, e); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java new file mode 100644 index 000000000..19ce6ce8b --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.ClearScrollRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.Strings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.flint.core.FlintOptions; +import org.opensearch.flint.core.IRestHighLevelClient; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.opensearch.flint.core.metrics.MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX; + +/** + * {@link OpenSearchReader} using search. 
https://opensearch.org/docs/latest/api-reference/search/ + */ +public class OpenSearchQueryReader extends OpenSearchReader { + + private static final Logger LOG = Logger.getLogger(OpenSearchQueryReader.class.getName()); + + public OpenSearchQueryReader(IRestHighLevelClient client, String indexName, SearchSourceBuilder searchSourceBuilder) { + super(client, new SearchRequest().indices(indexName).source(searchSourceBuilder)); + } + + /** + * search. + */ + Optional search(SearchRequest request) { + Optional response = Optional.empty(); + try { + response = Optional.of(client.search(request, RequestOptions.DEFAULT)); + IRestHighLevelClient.recordOperationSuccess(REQUEST_METADATA_READ_METRIC_PREFIX); + } catch (Exception e) { + IRestHighLevelClient.recordOperationFailure(REQUEST_METADATA_READ_METRIC_PREFIX, e); + } + return response; + } + + /** + * nothing to clean + */ + void clean() throws IOException {} +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java index c70d327fe..e2e831bd0 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java @@ -5,6 +5,7 @@ package org.opensearch.flint.core.storage; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.flint.core.IRestHighLevelClient; @@ -48,6 +49,13 @@ public OpenSearchReader(IRestHighLevelClient client, SearchRequest searchRequest iterator = searchHits.iterator(); } return iterator.hasNext(); + } catch (OpenSearchStatusException e) { + // e.g., org.opensearch.OpenSearchStatusException: OpenSearch exception [type=index_not_found_exception, reason=no such index [query_results2]] + if (e.getMessage() != null && (e.getMessage().contains("index_not_found_exception"))) { + return false; + } else { + throw e; + } } catch (IOException e) { // todo. log error. throw new RuntimeException(e); diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java index ffe771b15..0d84b4956 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java @@ -12,80 +12,93 @@ import java.util.logging.Level; import java.util.logging.Logger; +import static org.opensearch.flint.core.metrics.MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX; +import static org.opensearch.flint.core.metrics.MetricConstants.REQUEST_METADATA_WRITE_METRIC_PREFIX; + +/** + * Provides functionality for updating and upserting documents in an OpenSearch index. + * This class utilizes FlintClient for managing connections to OpenSearch and performs + * document updates and upserts with optional optimistic concurrency control. + */ public class OpenSearchUpdater { private static final Logger LOG = Logger.getLogger(OpenSearchUpdater.class.getName()); private final String indexName; - private final FlintClient flintClient; - public OpenSearchUpdater(String indexName, FlintClient flintClient) { this.indexName = indexName; this.flintClient = flintClient; } public void upsert(String id, String doc) { - // we might need to keep the updater for a long time. 
Reusing the client may not work as the temporary - // credentials may expire. - // also, failure to close the client causes the job to be stuck in the running state as the client resource - // is not released. - try (IRestHighLevelClient client = flintClient.createClient()) { - assertIndexExist(client, indexName); - UpdateRequest - updateRequest = - new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL) - .docAsUpsert(true); - client.update(updateRequest, RequestOptions.DEFAULT); - } catch (IOException e) { - throw new RuntimeException(String.format( - "Failed to execute update request on index: %s, id: %s", - indexName, - id), e); - } + updateDocument(id, doc, true, -1, -1); } public void update(String id, String doc) { - try (IRestHighLevelClient client = flintClient.createClient()) { - assertIndexExist(client, indexName); - UpdateRequest - updateRequest = - new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - client.update(updateRequest, RequestOptions.DEFAULT); - } catch (IOException e) { - throw new RuntimeException(String.format( - "Failed to execute update request on index: %s, id: %s", - indexName, - id), e); - } + updateDocument(id, doc, false, -1, -1); } public void updateIf(String id, String doc, long seqNo, long primaryTerm) { + updateDocument(id, doc, false, seqNo, primaryTerm); + } + + /** + * Internal method for updating or upserting a document with optional optimistic concurrency control. + * + * @param id The document ID. + * @param doc The document content in JSON format. + * @param upsert Flag indicating whether to upsert the document. + * @param seqNo The sequence number for optimistic concurrency control. + * @param primaryTerm The primary term for optimistic concurrency control. + */ + private void updateDocument(String id, String doc, boolean upsert, long seqNo, long primaryTerm) { + // we might need to keep the updater for a long time. Reusing the client may not work as the temporary + // credentials may expire. + // also, failure to close the client causes the job to be stuck in the running state as the client resource + // is not released. 
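After this refactor, `upsert`, `update`, and `updateIf` all delegate to the single `updateDocument` method, with negative `seqNo`/`primaryTerm` values meaning no optimistic-concurrency check. A hedged caller-side sketch (the index name, document ids, and JSON bodies are illustrative):

```java
import org.opensearch.flint.core.FlintClient;
import org.opensearch.flint.core.storage.OpenSearchUpdater;

public class UpdaterUsageSketch {

    static void example(FlintClient flintClient) {
        OpenSearchUpdater updater = new OpenSearchUpdater("flint_ql_sessions", flintClient);

        // Upsert: docAsUpsert(true), so the document is created if it does not exist yet.
        updater.upsert("session-1", "{\"state\": \"running\"}");

        // Plain update: same WAIT_UNTIL refresh policy, but no docAsUpsert and no concurrency check.
        updater.update("session-1", "{\"state\": \"success\"}");

        // Conditional update: applied only if the stored seqNo/primaryTerm still match;
        // a concurrent writer causes OpenSearch to reject the request.
        updater.updateIf("session-1", "{\"state\": \"success\"}", 42L, 1L);
    }
}
```

Both the existence check and the update itself are now wrapped in the request-metadata read/write metrics, so failures surface under names such as `request.metadata.write.4xx.count`.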
try (IRestHighLevelClient client = flintClient.createClient()) { assertIndexExist(client, indexName); - UpdateRequest - updateRequest = - new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm); - client.update(updateRequest, RequestOptions.DEFAULT); + UpdateRequest updateRequest = new UpdateRequest(indexName, id) + .doc(doc, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + + if (upsert) { + updateRequest.docAsUpsert(true); + } + if (seqNo >= 0 && primaryTerm >= 0) { + updateRequest.setIfSeqNo(seqNo).setIfPrimaryTerm(primaryTerm); + } + + try { + client.update(updateRequest, RequestOptions.DEFAULT); + IRestHighLevelClient.recordOperationSuccess(REQUEST_METADATA_WRITE_METRIC_PREFIX); + } catch (Exception e) { + IRestHighLevelClient.recordOperationFailure(REQUEST_METADATA_WRITE_METRIC_PREFIX, e); + } } catch (IOException e) { throw new RuntimeException(String.format( "Failed to execute update request on index: %s, id: %s", - indexName, - id), e); + indexName, id), e); } } private void assertIndexExist(IRestHighLevelClient client, String indexName) throws IOException { - LOG.info("Checking if index exists " + indexName); - if (!client.isIndexExists(new GetIndexRequest(indexName), RequestOptions.DEFAULT)) { - String errorMsg = "Index not found " + indexName; + LOG.info("Checking if index exists: " + indexName); + boolean exists; + try { + exists = client.doesIndexExist(new GetIndexRequest(indexName), RequestOptions.DEFAULT); + IRestHighLevelClient.recordOperationSuccess(REQUEST_METADATA_READ_METRIC_PREFIX); + } catch (Exception e) { + IRestHighLevelClient.recordOperationFailure(REQUEST_METADATA_READ_METRIC_PREFIX, e); + throw e; + } + + if (!exists) { + String errorMsg = "Index not found: " + indexName; LOG.log(Level.SEVERE, errorMsg); throw new IllegalStateException(errorMsg); } } } + diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java index 8e646f446..3b8940536 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java @@ -1,5 +1,8 @@ package org.opensearch.flint.core.metrics; +import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Timer; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; import org.junit.Test; @@ -7,6 +10,10 @@ import org.mockito.MockedStatic; import org.mockito.Mockito; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; @@ -18,7 +25,34 @@ public class MetricsUtilTest { @Test - public void incOpenSearchMetric() { + public void testIncrementDecrementCounter() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + // Mock SparkEnv + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + // Mock FlintMetricSource + FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) + 
.thenReturn(flintMetricSource); + + // Test the methods + String testMetric = "testPrefix.2xx.count"; + MetricsUtil.incrementCounter(testMetric); + MetricsUtil.incrementCounter(testMetric); + MetricsUtil.decrementCounter(testMetric); + + // Verify interactions + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(flintMetricSource, times(3)).metricRegistry(); + Counter counter = flintMetricSource.metricRegistry().getCounters().get(testMetric); + Assertions.assertNotNull(counter); + Assertions.assertEquals(counter.getCount(), 1); + } + } + + @Test + public void testStartStopTimer() { try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { // Mock SparkEnv SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); @@ -29,14 +63,52 @@ public void incOpenSearchMetric() { when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) .thenReturn(flintMetricSource); - // Test the method - MetricsUtil.incrementCounter("testPrefix.2xx.count"); + // Test the methods + String testMetric = "testPrefix.processingTime"; + Timer.Context context = MetricsUtil.getTimerContext(testMetric); + TimeUnit.MILLISECONDS.sleep(500); + MetricsUtil.stopTimer(context); // Verify interactions verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); - verify(flintMetricSource, times(2)).metricRegistry(); - Assertions.assertNotNull( - flintMetricSource.metricRegistry().getCounters().get("testPrefix.2xx.count")); + verify(flintMetricSource, times(1)).metricRegistry(); + Timer timer = flintMetricSource.metricRegistry().getTimers().get(testMetric); + Assertions.assertNotNull(timer); + Assertions.assertEquals(timer.getCount(), 1L); + assertEquals(1.9, timer.getMeanRate(), 0.1); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Test + public void testRegisterGaugeWhenMetricRegistryIsAvailable() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + // Mock SparkEnv + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + // Mock FlintMetricSource + FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) + .thenReturn(flintMetricSource); + + // Setup gauge + AtomicInteger testValue = new AtomicInteger(1); + String gaugeName = "test.gauge"; + MetricsUtil.registerGauge(gaugeName, testValue); + + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(flintMetricSource, times(1)).metricRegistry(); + + Gauge gauge = flintMetricSource.metricRegistry().getGauges().get(gaugeName); + Assertions.assertNotNull(gauge); + Assertions.assertEquals(gauge.getValue(), 1); + + testValue.incrementAndGet(); + testValue.incrementAndGet(); + testValue.decrementAndGet(); + Assertions.assertEquals(gauge.getValue(), 2); } } } \ No newline at end of file diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java index 7fab8c346..2178c3c22 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java @@ -5,12 +5,21 @@ package org.opensearch.flint.core.metrics.reporter; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import 
org.apache.spark.metrics.source.FlintMetricSource; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; import com.amazonaws.services.cloudwatch.model.Dimension; import org.junit.jupiter.api.function.Executable; +import org.mockito.MockedStatic; +import org.mockito.Mockito; import java.lang.reflect.Field; import java.util.Map; @@ -57,13 +66,37 @@ public void testGetDimensionsFromSystemEnv() throws NoSuchFieldException, Illega Field field = classOfMap.getDeclaredField("m"); field.setAccessible(true); Map writeableEnvironmentVariables = (Map)field.get(System.getenv()); - writeableEnvironmentVariables.put("TEST_VAR", "dummy1"); - writeableEnvironmentVariables.put("SERVERLESS_EMR_JOB_ID", "dummy2"); - Dimension result1 = DimensionUtils.constructDimension("TEST_VAR", parts); - assertEquals("TEST_VAR", result1.getName()); - assertEquals("dummy1", result1.getValue()); - Dimension result2 = DimensionUtils.constructDimension("jobId", parts); - assertEquals("jobId", result2.getName()); - assertEquals("dummy2", result2.getValue()); + try { + writeableEnvironmentVariables.put("TEST_VAR", "dummy1"); + writeableEnvironmentVariables.put("SERVERLESS_EMR_JOB_ID", "dummy2"); + Dimension result1 = DimensionUtils.constructDimension("TEST_VAR", parts); + assertEquals("TEST_VAR", result1.getName()); + assertEquals("dummy1", result1.getValue()); + Dimension result2 = DimensionUtils.constructDimension("jobId", parts); + assertEquals("jobId", result2.getName()); + assertEquals("dummy2", result2.getValue()); + } finally { + // since system environment is shared by other tests. Make sure to remove them before exiting. 
+ writeableEnvironmentVariables.remove("SERVERLESS_EMR_JOB_ID"); + writeableEnvironmentVariables.remove("TEST_VAR"); + } + } + + @Test + public void testConstructDimensionFromSparkConfWithAvailableConfig() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + SparkConf sparkConf = new SparkConf().set("spark.test.key", "testValue"); + when(sparkEnv.get().conf()).thenReturn(sparkConf); + + Dimension result = DimensionUtils.constructDimensionFromSparkConf("testDimension", "spark.test.key", "defaultValue"); + // Assertions + assertEquals("testDimension", result.getName()); + assertEquals("testValue", result.getValue()); + + // Reset SparkEnv mock to not affect other tests + Mockito.reset(SparkEnv.get()); + } } } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index fd998d46d..359994c56 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -146,7 +146,30 @@ object FlintSparkConf { .datasourceOption() .doc("socket duration in milliseconds") .createWithDefault(String.valueOf(FlintOptions.DEFAULT_SOCKET_TIMEOUT_MILLIS)) - + val DATA_SOURCE_NAME = + FlintConfig(s"spark.flint.datasource.name") + .doc("data source name") + .createOptional() + val JOB_TYPE = + FlintConfig(s"spark.flint.job.type") + .doc("Flint job type. Including interactive and streaming") + .createWithDefault("interactive") + val SESSION_ID = + FlintConfig(s"spark.flint.job.sessionId") + .doc("Flint session id") + .createOptional() + val REQUEST_INDEX = + FlintConfig(s"spark.flint.job.requestIndex") + .doc("Request index") + .createOptional() + val EXCLUDE_JOB_IDS = + FlintConfig(s"spark.flint.deployment.excludeJobs") + .doc("Exclude job ids") + .createOptional() + val REPL_INACTIVITY_TIMEOUT_MILLIS = + FlintConfig(s"spark.flint.job.inactivityLimitMillis") + .doc("inactivity timeout") + .createWithDefault(String.valueOf(FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS)) } /** @@ -196,11 +219,18 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable CUSTOM_AWS_CREDENTIALS_PROVIDER, USERNAME, PASSWORD, - SOCKET_TIMEOUT_MILLIS) + SOCKET_TIMEOUT_MILLIS, + JOB_TYPE, + REPL_INACTIVITY_TIMEOUT_MILLIS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .toMap - val optionsWithoutDefault = Seq(RETRYABLE_EXCEPTION_CLASS_NAMES) + val optionsWithoutDefault = Seq( + RETRYABLE_EXCEPTION_CLASS_NAMES, + DATA_SOURCE_NAME, + SESSION_ID, + REQUEST_INDEX, + EXCLUDE_JOB_IDS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .flatMap { case (_, None) => None diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala index 5af70b793..9911a3b6c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala @@ -25,7 +25,14 @@ class FlintInstance( val lastUpdateTime: Long, val jobStartTime: Long = 0, val excludedJobIds: Seq[String] = Seq.empty[String], - val error: Option[String] = None) {} + val error: Option[String] = None) { + override def 
toString: String = { + val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") + val errorStr = error.getOrElse("None") + s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + + s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)" + } +} object FlintInstance { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index 5c4c7376c..2f44a28f4 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -12,6 +12,7 @@ import scala.sys.addShutdownHook import org.opensearch.flint.core.FlintClient import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -65,6 +66,7 @@ class FlintSparkIndexMonitor( } catch { case e: Throwable => logError("Failed to update index log entry", e) + MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC) } }, 15, // Delay to ensure final logging is complete first, otherwise version conflicts diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala new file mode 100644 index 000000000..86bf567f5 --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -0,0 +1,263 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success} +import scala.util.control.Breaks.{break, breakable} + +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.spark.FlintSparkSuite +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName +import org.scalatest.matchers.must.Matchers.{defined, have} +import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} + +import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE +import org.apache.spark.sql.util.MockEnvironment +import org.apache.spark.util.ThreadUtils + +class FlintJobITSuite extends FlintSparkSuite with JobTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.skipping_sql_test" + private val testIndex = getSkippingIndexName(testTable) + val resultIndex = "query_results2" + val appId = "00feq82b752mbt0p" + val dataSourceName = "my_glue1" + var osClient: OSClient = _ + val threadLocalFuture = new ThreadLocal[Future[Unit]]() + + override def beforeAll(): Unit = { + super.beforeAll() + // initialized after the container is started + osClient = new OSClient(new FlintOptions(openSearchOptions.asJava)) + createPartitionedMultiRowTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + + deleteTestIndex(testIndex) + + waitJobStop(threadLocalFuture.get()) + + threadLocalFuture.remove() + } + + def waitJobStop(future: Future[Unit]): Unit = { + 
try { + val activeJob = spark.streams.active.find(_.name == testIndex) + if (activeJob.isDefined) { + activeJob.get.stop() + } + ThreadUtils.awaitResult(future, Duration(1, MINUTES)) + } catch { + case e: Exception => + e.printStackTrace() + assert(false, "failure waiting for job to finish") + } + } + + def startJob(query: String, jobRunId: String): Future[Unit] = { + val prefix = "flint-job-test" + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + val streamingRunningCount = new AtomicInteger(0) + + val futureResult = Future { + val job = + JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount) + job.envinromentProvider = new MockEnvironment( + Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + + job.start() + } + futureResult.onComplete { + case Success(result) => logInfo(s"Success result: $result") + case Failure(ex) => + ex.printStackTrace() + assert(false, s"An error has occurred: ${ex.getMessage}") + } + futureResult + } + + test("create skipping index with auto refresh") { + val query = + s""" + | CREATE SKIPPING INDEX ON $testTable + | ( + | year PARTITION, + | name VALUE_SET, + | age MIN_MAX + | ) + | WITH (auto_refresh = true) + | """.stripMargin + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080q" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + + assert(result.status == "SUCCESS", s"expected status is SUCCESS, but got ${result.status}") + assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}") + assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}") + + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + + val activeJob = spark.streams.active.find(_.name == testIndex) + activeJob shouldBe defined + failAfter(streamingTimeout) { + activeJob.get.processAllAvailable() + } + val indexData = spark.read.format(FLINT_DATASOURCE).load(testIndex) + flint.describeIndex(testIndex) shouldBe defined + indexData.count() shouldBe 2 + } + + test("create skipping index with non-existent table") { + val query = + s""" + | CREATE SKIPPING INDEX ON testTable + | ( + | year PARTITION, + | name VALUE_SET, + | age MIN_MAX + | ) + | WITH (auto_refresh = true) + | """.stripMargin + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080r" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + + assert(result.status == "FAILED", s"expected status is FAILED, but got ${result.status}") + assert(!result.error.isEmpty, s"we expect error, but got ${result.error}") + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + } + + test("describe skipping index") { + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year") + .addValueSet("name") + .addMinMax("age") + .create() + + val 
queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080s" + val query = s"DESC SKIPPING INDEX ON $testTable" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 3, + s"expected result size is 3, but got ${result.results.size}") + val expectedResult0 = + "{'indexed_col_name':'year','data_type':'int','skip_type':'PARTITION'}" + assert( + result.results(0) == expectedResult0, + s"expected result size is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = + "{'indexed_col_name':'name','data_type':'string','skip_type':'VALUE_SET'}" + assert( + result.results(1) == expectedResult1, + s"expected result size is $expectedResult1, but got ${result.results(1)}") + val expectedResult2 = "{'indexed_col_name':'age','data_type':'int','skip_type':'MIN_MAX'}" + assert( + result.results(2) == expectedResult2, + s"expected result size is $expectedResult2, but got ${result.results(2)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'indexed_col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected 0th field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected 1st field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'skip_type','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected 2nd field is $expectedSecondSchema, but got ${result.schemas(2)}") + + assert(result.status == "SUCCESS", s"expected status is FAILED, but got ${result.status}") + assert(result.error.isEmpty, s"we expect error, but got ${result.error}") + + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + } + + def commonAssert( + result: REPLResult, + jobRunId: String, + query: String, + queryStartTime: Long): Unit = { + assert( + result.jobRunId == jobRunId, + s"expected jobRunId is $jobRunId, but got ${result.jobRunId}") + assert( + result.applicationId == appId, + s"expected applicationId is $appId, but got ${result.applicationId}") + assert( + result.dataSourceName == dataSourceName, + s"expected data source is $dataSourceName, but got ${result.dataSourceName}") + val actualQueryText = normalizeString(result.queryText) + val expectedQueryText = normalizeString(query) + assert( + actualQueryText == expectedQueryText, + s"expected query is $expectedQueryText, but got $actualQueryText") + assert(result.sessionId.isEmpty, s"we don't expect session id, but got ${result.sessionId}") + assert( + result.updateTime > queryStartTime, + s"expect that update time is ${result.updateTime} later than query start time $queryStartTime, but it is not") + assert( + result.queryRunTime > 0, + s"expected query run time is positive, but got ${result.queryRunTime}") + assert( + result.queryRunTime < System.currentTimeMillis() - queryStartTime, + s"expected query run time ${result.queryRunTime} should be less than ${System + .currentTimeMillis() - queryStartTime}, but it is not") + assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}") + } + + def pollForResultAndAssert(expected: REPLResult => Boolean, jobId: String): Unit = { + 
pollForResultAndAssert( + osClient, + expected, + "jobRunId", + jobId, + streamingTimeout.toMillis, + resultIndex) + } +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala new file mode 100644 index 000000000..9a2afc71e --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -0,0 +1,573 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success, Try} +import scala.util.control.Breaks.{break, breakable} + +import org.opensearch.OpenSearchStatusException +import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.app.{FlintCommand, FlintInstance} +import org.opensearch.flint.core.{FlintClient, FlintOptions} +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} +import org.apache.spark.sql.util.MockEnvironment +import org.apache.spark.util.ThreadUtils + +class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { + + var flintClient: FlintClient = _ + var osClient: OSClient = _ + var updater: OpenSearchUpdater = _ + val requestIndex = "flint_ql_sessions" + val resultIndex = "query_results2" + val jobRunId = "00ff4o3b5091080q" + val appId = "00feq82b752mbt0p" + val dataSourceName = "my_glue1" + val sessionId = "10" + val requestIndexMapping = + """ { + | "properties": { + | "applicationId": { + | "type": "keyword" + | }, + | "dataSourceName": { + | "type": "keyword" + | }, + | "error": { + | "type": "text" + | }, + | "excludeJobIds": { + | "type": "text", + | "fields": { + | "keyword": { + | "type": "keyword", + | "ignore_above": 256 + | } + | } + | }, + | "if_primary_term": { + | "type": "long" + | }, + | "if_seq_no": { + | "type": "long" + | }, + | "jobId": { + | "type": "keyword" + | }, + | "jobStartTime": { + | "type": "long" + | }, + | "lang": { + | "type": "keyword" + | }, + | "lastUpdateTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "query": { + | "type": "text" + | }, + | "queryId": { + | "type": "text", + | "fields": { + | "keyword": { + | "type": "keyword", + | "ignore_above": 256 + | } + | } + | }, + | "sessionId": { + | "type": "keyword" + | }, + | "state": { + | "type": "keyword" + | }, + | "statementId": { + | "type": "keyword" + | }, + | "submitTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "type": { + | "type": "keyword" + | } + | } + | } + |""".stripMargin + val testTable = dataSourceName + ".default.flint_sql_test" + + // use a thread-local variable to store and manage the future in beforeEach and afterEach + val threadLocalFuture = new ThreadLocal[Future[Unit]]() + + override def beforeAll(): Unit = { + super.beforeAll() + + flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)); + osClient = new OSClient(new FlintOptions(openSearchOptions.asJava)) + updater = new OpenSearchUpdater( + 
requestIndex, + new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava))) + + } + + override def afterEach(): Unit = { + flintClient.deleteIndex(requestIndex) + super.afterEach() + } + + def createSession(jobId: String, excludeJobId: String): Unit = { + val docs = Seq(s"""{ + | "state": "running", + | "lastUpdateTime": 1698796582978, + | "applicationId": "00fd777k3k3ls20p", + | "error": "", + | "sessionId": ${sessionId}, + | "jobId": \"${jobId}\", + | "type": "session", + | "excludeJobIds": [\"${excludeJobId}\"] + |}""".stripMargin) + index(requestIndex, oneNodeSetting, requestIndexMapping, docs) + } + + def startREPL(): Future[Unit] = { + val prefix = "flint-repl-test" + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + val futureResult = Future { + // SparkConf's constructor creates a SparkConf that loads defaults from system properties and the classpath. + // Read SparkConf.getSystemProperties + System.setProperty(DATA_SOURCE_NAME.key, "my_glue1") + System.setProperty(JOB_TYPE.key, "interactive") + System.setProperty(SESSION_ID.key, sessionId) + System.setProperty(REQUEST_INDEX.key, requestIndex) + System.setProperty(EXCLUDE_JOB_IDS.key, "00fer5qo32fa080q") + System.setProperty(REPL_INACTIVITY_TIMEOUT_MILLIS.key, "5000") + System.setProperty( + s"spark.sql.catalog.my_glue1", + "org.opensearch.sql.FlintDelegatingSessionCatalog") + System.setProperty("spark.master", "local") + System.setProperty(HOST_ENDPOINT.key, openSearchHost) + System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort)) + System.setProperty(REFRESH_POLICY.key, "true") + + FlintREPL.envinromentProvider = new MockEnvironment( + Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + FlintREPL.enableHiveSupport = false + FlintREPL.terminateJVM = false + FlintREPL.main(Array("select 1", resultIndex)) + } + futureResult.onComplete { + case Success(result) => logInfo(s"Success result: $result") + case Failure(ex) => + ex.printStackTrace() + assert(false, s"An error has occurred: ${ex.getMessage}") + } + futureResult + } + + def waitREPLStop(future: Future[Unit]): Unit = { + try { + ThreadUtils.awaitResult(future, Duration(1, MINUTES)) + } catch { + case e: Exception => + e.printStackTrace() + assert(false, "failure waiting for REPL to finish") + } + } + + def submitQuery(query: String, queryId: String): String = { + submitQuery(query, queryId, System.currentTimeMillis()) + } + + def submitQuery(query: String, queryId: String, submitTime: Long): String = { + val statementId = UUID.randomUUID().toString + + updater.upsert( + statementId, + s"""{ + | "sessionId": "${sessionId}", + | "query": "${query}", + | "applicationId": "00fd775baqpu4g0p", + | "state": "waiting", + | "submitTime": $submitTime, + | "type": "statement", + | "statementId": "${statementId}", + | "queryId": "${queryId}", + | "dataSourceName": "${dataSourceName}" + |}""".stripMargin) + statementId + } + + test("sanity") { + try { + createSession(jobRunId, "") + threadLocalFuture.set(startREPL()) + + val createStatement = + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\\t' + | ) + |""".stripMargin + submitQuery(s"${makeJsonCompliant(createStatement)}", "99") + + val insertStatement = + s""" + | INSERT INTO $testTable + | VALUES ('Hello', 30) + | """.stripMargin + submitQuery(s"${makeJsonCompliant(insertStatement)}", 
"100") + + val selectQueryId = "101" + val selectQueryStartTime = System.currentTimeMillis() + val selectQuery = s"SELECT name, age FROM $testTable".stripMargin + val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) + + val describeStatement = s"DESC $testTable".stripMargin + val descQueryId = "102" + val descStartTime = System.currentTimeMillis() + val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) + + val showTableStatement = + s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" + val showQueryId = "103" + val showStartTime = System.currentTimeMillis() + val showTableStatementId = + submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) + + val wrongSelectQueryId = "104" + val wrongSelectQueryStartTime = System.currentTimeMillis() + val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin + val wrongSelectStatementId = + submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) + + val lateSelectQueryId = "105" + val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin + // submitted from last year. We won't pick it up + val lateSelectStatementId = + submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) + + // clean up + val dropStatement = + s"""DROP TABLE $testTable""".stripMargin + submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") + + val selectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = "{'name':'Hello','age':30}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 2, + s"expected schema size is 2, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) + successValidation(result) + true + } + pollForResultAndAssert(selectQueryValidation, selectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + selectStatementId), + s"Fail to verify for $selectStatementId.") + + val descValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 2, + s"expected result size is 2, but got ${result.results.size}") + val expectedResult0 = "{'col_name':'name','data_type':'string'}" + assert( + result.results(0).equals(expectedResult0), + s"expected result is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = "{'col_name':'age','data_type':'int'}" + assert( + result.results(1).equals(expectedResult1), + s"expected result is $expectedResult1, but got ${result.results(1)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val 
expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, descQueryId, describeStatement, descStartTime) + successValidation(result) + true + } + pollForResultAndAssert(descValidation, descQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + descStatementId), + s"Fail to verify for $descStatementId.") + + val showValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = + "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, showQueryId, showTableStatement, showStartTime) + successValidation(result) + true + } + pollForResultAndAssert(showValidation, showQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + showTableStatementId), + s"Fail to verify for $showTableStatementId.") + + val wrongSelectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + commonValidation(result, wrongSelectQueryId, wrongSelectQuery, wrongSelectQueryStartTime) + failureValidation(result) + true + } + pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "failed" + }, + wrongSelectStatementId), + s"Fail to verify for $wrongSelectStatementId.") + + // expect time out as this statement should not be picked up + assert( + awaitConditionForStatementOrTimeout( + statement => { + statement.state != "waiting" + }, + lateSelectStatementId), + s"Fail to verify for $lateSelectStatementId.") + } catch { + case e: Exception => + logError("Unexpected exception", e) + assert(false, "Unexpected exception") + } finally { + waitREPLStop(threadLocalFuture.get()) + threadLocalFuture.remove() + + // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. 
+ } + } + + /** + * JSON does not support raw newlines (\n) in string values. All newlines must be escaped or + * removed when inside a JSON string. The same goes for tab characters, which should be + * represented as \\t. + * + * Here, I replace the newlines with spaces and escape tab characters that is being included in + * the JSON. + * + * @param sqlQuery + * @return + */ + def makeJsonCompliant(sqlQuery: String): String = { + sqlQuery.replaceAll("\n", " ").replaceAll("\t", "\\\\t") + } + + def commonValidation( + result: REPLResult, + expectedQueryId: String, + expectedStatement: String, + queryStartTime: Long): Unit = { + assert( + result.jobRunId.equals(jobRunId), + s"expected job id is $jobRunId, but got ${result.jobRunId}") + assert( + result.applicationId.equals(appId), + s"expected app id is $appId, but got ${result.applicationId}") + assert( + result.dataSourceName.equals(dataSourceName), + s"expected data source is $dataSourceName, but got ${result.dataSourceName}") + assert( + result.queryId.equals(expectedQueryId), + s"expected query id is $expectedQueryId, but got ${result.queryId}") + assert( + result.queryText.equals(expectedStatement), + s"expected query is $expectedStatement, but got ${result.queryText}") + assert( + result.sessionId.equals(sessionId), + s"expected session id is $sessionId, but got ${result.sessionId}") + assert( + result.updateTime > queryStartTime, + s"expect that update time is ${result.updateTime} later than query start time $queryStartTime, but it is not") + assert( + result.queryRunTime > 0, + s"expected query run time is positive, but got ${result.queryRunTime}") + assert( + result.queryRunTime < System.currentTimeMillis() - queryStartTime, + s"expected query run time ${result.queryRunTime} should be less than ${System + .currentTimeMillis() - queryStartTime}, but it is not") + } + + def successValidation(result: REPLResult): Unit = { + assert( + result.status.equals("SUCCESS"), + s"expected status is SUCCESS, but got ${result.status}") + assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}") + } + + def failureValidation(result: REPLResult): Unit = { + assert(result.status.equals("FAILED"), s"expected status is FAILED, but got ${result.status}") + assert(!result.error.isEmpty, s"we expect error, but got nothing") + } + + def pollForResultAndAssert(expected: REPLResult => Boolean, queryId: String): Unit = { + pollForResultAndAssert(osClient, expected, "queryId", queryId, 60000, resultIndex) + } + + /** + * Repeatedly polls a resource until a specified condition is met or a timeout occurs. + * + * This method continuously checks a resource for a specific condition. If the condition is met + * within the timeout period, the polling stops. If the timeout period is exceeded without the + * condition being met, an assertion error is thrown. + * + * @param osClient + * The OSClient used to poll the resource. + * @param condition + * A function that takes an instance of type T and returns a Boolean. This function defines + * the condition to be met. + * @param id + * The unique identifier of the resource to be polled. + * @param timeoutMillis + * The maximum amount of time (in milliseconds) to wait for the condition to be met. + * @param index + * The index in which the resource resides. + * @param deserialize + * A function that deserializes a String into an instance of type T. + * @param logType + * A descriptive string for logging purposes, indicating the type of resource being polled. 
+ * @return + * whether timeout happened + * @throws OpenSearchStatusException + * if there's an issue fetching the resource. + */ + def awaitConditionOrTimeout[T]( + osClient: OSClient, + expected: T => Boolean, + id: String, + timeoutMillis: Long, + index: String, + deserialize: String => T, + logType: String): Boolean = { + val getResponse = osClient.getDoc(index, id) + val startTime = System.currentTimeMillis() + breakable { + while (System.currentTimeMillis() - startTime < timeoutMillis) { + logInfo(s"Check $logType for $id") + try { + if (getResponse.isExists()) { + val instance = deserialize(getResponse.getSourceAsString) + logInfo(s"$logType $id: $instance") + if (expected(instance)) { + break + } + } + } catch { + case e: OpenSearchStatusException => logError(s"Exception while fetching $logType", e) + } + Thread.sleep(2000) // 2 seconds + } + } + System.currentTimeMillis() - startTime >= timeoutMillis + } + + def awaitConditionForStatementOrTimeout( + expected: FlintCommand => Boolean, + statementId: String): Boolean = { + awaitConditionOrTimeout[FlintCommand]( + osClient, + expected, + statementId, + 10000, + requestIndex, + FlintCommand.deserialize, + "statement") + } + + def awaitConditionForSessionOrTimeout( + expected: FlintInstance => Boolean, + sessionId: String): Boolean = { + awaitConditionOrTimeout[FlintInstance]( + osClient, + expected, + sessionId, + 10000, + requestIndex, + FlintInstance.deserialize, + "session") + } +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala b/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala new file mode 100644 index 000000000..563997b7f --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success} +import scala.util.control.Breaks._ + +import org.opensearch.OpenSearchStatusException +import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.core.FlintOptions +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.internal.Logging + +/** + * We use a self-type annotation (self: OpenSearchSuite =>) to specify that it must be mixed into + * a class that also mixes in OpenSearchSuite. 
This way, JobTest can still use the + * openSearchOptions field, + */ +trait JobTest extends Logging { self: OpenSearchSuite => + + def pollForResultAndAssert( + osClient: OSClient, + expected: REPLResult => Boolean, + idField: String, + idValue: String, + timeoutMillis: Long, + resultIndex: String): Unit = { + val query = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "$idField": "$idValue" + | } + | } + | ] + | } + |}""".stripMargin + val resultReader = osClient.createQueryReader(resultIndex, query, "updateTime", SortOrder.ASC) + + val startTime = System.currentTimeMillis() + breakable { + while (System.currentTimeMillis() - startTime < timeoutMillis) { + logInfo(s"Check result for $idValue") + try { + if (resultReader.hasNext()) { + REPLResult.deserialize(resultReader.next()) match { + case Success(replResult) => + logInfo(s"repl result: $replResult") + assert(expected(replResult), s"{$query} failed.") + case Failure(exception) => + assert(false, "Failed to deserialize: " + exception.getMessage) + } + break + } + } catch { + case e: OpenSearchStatusException => logError("Exception while querying for result", e) + } + + Thread.sleep(2000) // 2 seconds + } + if (System.currentTimeMillis() - startTime >= timeoutMillis) { + assert( + false, + s"Timeout occurred after $timeoutMillis milliseconds waiting for query result.") + } + } + } + + /** + * Used to preprocess multi-line queries before comparing them as serialized and deserialized + * queries might have different characters. + * @param s + * input + * @return + * normalized input by replacing all space, tab, ane newlines with single spaces. + */ + def normalizeString(s: String): String = { + // \\s+ is a regular expression that matches one or more whitespace characters, including spaces, tabs, and newlines. 
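+    // For example, "SELECT name,\n\tage FROM t" and "SELECT  name, age  FROM t" both
+    // normalize to "SELECT name, age FROM t", so logically identical queries compare equal.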
+ s.replaceAll("\\s+", " ") + } // Replace all whitespace characters with empty string +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala b/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala new file mode 100644 index 000000000..34dc2595c --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.util.Try + +import org.json4s.{DefaultFormats, Formats} +import org.json4s.native.JsonMethods.parse + +class REPLResult( + val results: Seq[String], + val schemas: Seq[String], + val jobRunId: String, + val applicationId: String, + val dataSourceName: String, + val status: String, + val error: String, + val queryId: String, + val queryText: String, + val sessionId: String, + val updateTime: Long, + val queryRunTime: Long) { + override def toString: String = { + s"REPLResult(results=$results, schemas=$schemas, jobRunId=$jobRunId, applicationId=$applicationId, " + + s"dataSourceName=$dataSourceName, status=$status, error=$error, queryId=$queryId, queryText=$queryText, " + + s"sessionId=$sessionId, updateTime=$updateTime, queryRunTime=$queryRunTime)" + } +} + +object REPLResult { + implicit val formats: Formats = DefaultFormats + + def deserialize(jsonString: String): Try[REPLResult] = Try { + val json = parse(jsonString) + + new REPLResult( + results = (json \ "result").extract[Seq[String]], + schemas = (json \ "schema").extract[Seq[String]], + jobRunId = (json \ "jobRunId").extract[String], + applicationId = (json \ "applicationId").extract[String], + dataSourceName = (json \ "dataSourceName").extract[String], + status = (json \ "status").extract[String], + error = (json \ "error").extract[String], + queryId = (json \ "queryId").extract[String], + queryText = (json \ "queryText").extract[String], + sessionId = (json \ "sessionId").extract[String], + updateTime = (json \ "updateTime").extract[Long], + queryRunTime = (json \ "queryRunTime").extract[Long]) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 7af1c2639..4ab3a983b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -51,6 +51,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit protected def deleteTestIndex(testIndexNames: String*): Unit = { testIndexNames.foreach(testIndex => { + /** * Todo, if state is not valid, will throw IllegalStateException. Should check flint * .isRefresh before cleanup resource. 
Current solution, (1) try to delete flint index, (2) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala index 61564546e..575f09362 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala @@ -142,7 +142,8 @@ class FlintSparkPPLCorrelationITSuite assert( thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ") } - test("create failing ppl correlation query with no scope - due to mismatch fields to mappings test") { + test( + "create failing ppl correlation query with no scope - due to mismatch fields to mappings test") { val thrown = intercept[IllegalStateException] { val frame = sql(s""" | source = $testTable1, $testTable2| correlate exact fields(name, country) mapping($testTable1.name = $testTable2.name) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index 62ff50fb6..32c1baa0a 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -331,6 +331,7 @@ class FlintSparkPPLFiltersITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + /** * | age_span | country | average_age | * |:---------|:--------|:------------| diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 750e228ef..32747e20f 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -7,6 +7,7 @@ package org.apache.spark.sql import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger import org.opensearch.client.{RequestOptions, RestHighLevelClient} import org.opensearch.cluster.metadata.MappingMetadata @@ -14,11 +15,14 @@ import org.opensearch.common.settings.Settings import org.opensearch.common.xcontent.XContentType import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge import play.api.libs.json._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{StructField, _} /** @@ -41,7 +45,10 @@ object FlintJob extends Logging with FlintJobExecutor { val Array(query, resultIndex) = args val conf = createSparkConf() - val wait = conf.get("spark.flint.job.type", "continue") + val jobType = conf.get("spark.flint.job.type", "batch") + logInfo(s"""Job type is: ${jobType}""") + conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + val dataSource = conf.get("spark.flint.datasource.name", "") // https://github.com/opensearch-project/opensearch-spark/issues/138 /* @@ -52,9 +59,16 @@ object FlintJob extends Logging with 
FlintJobExecutor { * Without this setup, Spark would not recognize names in the format `my_glue1.default`. */ conf.set("spark.sql.defaultCatalog", dataSource) - + val streamingRunningCount = new AtomicInteger(0) val jobOperator = - JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming")) + JobOperator( + createSparkSession(conf), + query, + dataSource, + resultIndex, + jobType.equalsIgnoreCase("streaming"), + streamingRunningCount) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) jobOperator.start() } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index a44e70401..1814a8d8e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -11,18 +11,20 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} import com.amazonaws.services.s3.model.AmazonS3Exception -import org.opensearch.flint.core.FlintClient +import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient} import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue} import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch} +import org.apache.spark.sql.FlintREPL.envinromentProvider import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType} -import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider} +import org.apache.spark.sql.util.{DefaultThreadPoolFactory, EnvironmentProvider, RealEnvironment, RealTimeProvider, ThreadPoolFactory, TimeProvider} import org.apache.spark.util.ThreadUtils trait FlintJobExecutor { @@ -30,6 +32,8 @@ trait FlintJobExecutor { var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() + var envinromentProvider: EnvironmentProvider = new RealEnvironment() + var enableHiveSupport: Boolean = true // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, val resultIndexMapping = @@ -87,14 +91,27 @@ trait FlintJobExecutor { } def createSparkSession(conf: SparkConf): SparkSession = { - SparkSession.builder().config(conf).enableHiveSupport().getOrCreate() + val builder = SparkSession.builder().config(conf) + if (enableHiveSupport) { + builder.enableHiveSupport() + } + builder.getOrCreate() } private def writeData(resultData: DataFrame, resultIndex: String): Unit = { - resultData.write - .format("flint") - .mode("append") - .save(resultIndex) + try { + resultData.write + .format("flint") + .mode("append") + .save(resultIndex) + IRestHighLevelClient.recordOperationSuccess( + 
MetricConstants.RESULT_METADATA_WRITE_METRIC_PREFIX) + } catch { + case e: Exception => + IRestHighLevelClient.recordOperationFailure( + MetricConstants.RESULT_METADATA_WRITE_METRIC_PREFIX, + e) + } } /** @@ -114,7 +131,7 @@ trait FlintJobExecutor { if (osClient.doesIndexExist(resultIndex)) { writeData(resultData, resultIndex) } else { - createIndex(osClient, resultIndex, resultIndexMapping) + createResultIndex(osClient, resultIndex, resultIndexMapping) writeData(resultData, resultIndex) } } @@ -177,8 +194,8 @@ trait FlintJobExecutor { ( resultToSave, resultSchemaToSave, - sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), - sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, "SUCCESS", "", @@ -226,8 +243,8 @@ trait FlintJobExecutor { ( null, null, - sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), - sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, "FAILED", error, @@ -310,8 +327,9 @@ trait FlintJobExecutor { } } catch { case e: IllegalStateException - if e.getCause().getMessage().contains("index_not_found_exception") => - createIndex(osClient, resultIndex, resultIndexMapping) + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => + createResultIndex(osClient, resultIndex, resultIndexMapping) case e: InterruptedException => val error = s"Interrupted by the main thread: ${e.getMessage}" Thread.currentThread().interrupt() // Preserve the interrupt status @@ -324,7 +342,7 @@ trait FlintJobExecutor { } } - def createIndex( + def createResultIndex( osClient: OSClient, resultIndex: String, mapping: String): Either[String, Unit] = { @@ -393,6 +411,7 @@ trait FlintJobExecutor { case r: ParseException => handleQueryException(r, "Syntax error", spark, dataSource, query, queryId, sessionId) case r: AmazonS3Exception => + incrementCounter(MetricConstants.S3_ERR_CNT_METRIC) handleQueryException( r, "Fail to read data from S3. 
Cause", diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 6c3fd957d..d30669cca 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -7,22 +7,30 @@ package org.apache.spark.sql import java.net.ConnectException import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException} import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal +import com.codahale.metrics.Timer import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.app.FlintInstance.formats +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.createSparkSession import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.sql.flint.config.FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils @@ -42,17 +50,23 @@ import org.apache.spark.util.ThreadUtils object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000 private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 val INITIAL_DELAY_MILLIS = 3000L val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L + @volatile var earlyExitFlag: Boolean = false + // termiante JVM in the presence non-deamon thread before exiting + var terminateJVM = true + def updateSessionIndex(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) } + private val sessionRunningCount = new AtomicInteger(0) + private val statementRunningCount = new AtomicInteger(0) + def main(args: Array[String]) { val Array(query, resultIndex) = args if (Strings.isNullOrEmpty(resultIndex)) { @@ -61,7 +75,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // init SparkContext val conf: SparkConf = createSparkConf() - val dataSource = conf.get("spark.flint.datasource.name", "unknown") + val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -71,33 +85,47 @@ object FlintREPL extends Logging with FlintJobExecutor { * Without this setup, Spark would not recognize names in the format `my_glue1.default`. 
*/ conf.set("spark.sql.defaultCatalog", dataSource) - val wait = conf.get("spark.flint.job.type", "continue") - if (wait.equalsIgnoreCase("streaming")) { + val jobType = conf.get(FlintSparkConf.JOB_TYPE.key, FlintSparkConf.JOB_TYPE.defaultValue.get) + logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""") + conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + + if (jobType.equalsIgnoreCase("streaming")) { logInfo(s"""streaming query ${query}""") + val streamingRunningCount = new AtomicInteger(0) val jobOperator = - JobOperator(conf, query, dataSource, resultIndex, true) + JobOperator( + createSparkSession(conf), + query, + dataSource, + resultIndex, + true, + streamingRunningCount) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) jobOperator.start() } else { // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) - val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null)) + val sessionIndex: Option[String] = Option(conf.get(FlintSparkConf.REQUEST_INDEX.key, null)) + val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) if (sessionIndex.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.requestIndex is not set") + throw new IllegalArgumentException(FlintSparkConf.REQUEST_INDEX.key + " is not set") } if (sessionId.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.sessionId is not set") + throw new IllegalArgumentException(FlintSparkConf.SESSION_ID.key + " is not set") } val spark = createSparkSession(conf) val osClient = new OSClient(FlintSparkConf().flintOptions()) - val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") - val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val applicationId = + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = - conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS) + conf.getLong( + FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS.key, + FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS) val queryExecutionTimeoutSecs: Duration = Duration( conf.getLong( "spark.flint.job.queryExecutionTimeoutSec", @@ -107,12 +135,21 @@ object FlintREPL extends Logging with FlintJobExecutor { conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS) val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) - addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get) + val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) + + addShutdownHook( + flintSessionIndexUpdater, + osClient, + sessionIndex.get, + sessionId.get, + sessionTimerContext) // 1 thread for updating heart beat val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1) + registerGauge(MetricConstants.REPL_RUNNING_METRIC, sessionRunningCount) + registerGauge(MetricConstants.STATEMENT_RUNNING_METRIC, statementRunningCount) val jobStartTime = currentTimeProvider.currentEpochMillis() // update heart beat every 30 seconds // OpenSearch triggers recovery after 1 minute outdated heart beat @@ -136,6 +173,7 @@ object 
FlintREPL extends Logging with FlintJobExecutor { applicationId, flintSessionIndexUpdater, jobStartTime)) { + earlyExitFlag = true return } @@ -151,10 +189,10 @@ object FlintREPL extends Logging with FlintJobExecutor { queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis) - exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { queryLoop(commandContext) } + recordSessionSuccess(sessionTimerContext) } catch { case e: Exception => handleSessionError( @@ -165,24 +203,25 @@ object FlintREPL extends Logging with FlintJobExecutor { jobStartTime, flintSessionIndexUpdater, osClient, - sessionIndex.get) + sessionIndex.get, + sessionTimerContext) } finally { if (threadPool != null) { heartBeatFuture.cancel(true) // Pass `true` to interrupt if running threadPoolFactory.shutdownThreadPool(threadPool) } - + stopTimer(sessionTimerContext) spark.stop() // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, // which may be due to unresolved bugs in dependencies or threads not being properly shut down. - if (threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { + if (terminateJVM && threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { logInfo("A non-daemon thread in the driver is seen.") // Exit the JVM to prevent resource leaks and potential emr-s job hung. // A zero status code is used for a graceful shutdown without indicating an error. // If exiting with non-zero status, emr-s job will fail. - // This is a part of the fault tolerance mechanism to handle such scenarios gracefully. + // This is a part of the fault tolerance mechanism to handle such scenarios gracefully System.exit(0) } } @@ -232,7 +271,7 @@ object FlintREPL extends Logging with FlintJobExecutor { applicationId: String, flintSessionIndexUpdater: OpenSearchUpdater, jobStartTime: Long): Boolean = { - val confExcludeJobsOpt = conf.getOption("spark.flint.deployment.excludeJobs") + val confExcludeJobsOpt = conf.getOption(FlintSparkConf.EXCLUDE_JOB_IDS.key) confExcludeJobsOpt match { case None => @@ -366,6 +405,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) logInfo( s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""") + sessionRunningCount.incrementAndGet() } def handleSessionError( @@ -376,7 +416,8 @@ object FlintREPL extends Logging with FlintJobExecutor { jobStartTime: Long, flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, - sessionIndex: String): Unit = { + sessionIndex: String, + sessionTimerContext: Timer.Context): Unit = { val error = s"Session error: ${e.getMessage}" logError(error, e) @@ -384,6 +425,9 @@ object FlintREPL extends Logging with FlintJobExecutor { .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) updateFlintInstance(flintInstance, flintSessionIndexUpdater, sessionId) + if (flintInstance.state.equals("fail")) { + recordSessionFailed(sessionTimerContext) + } } private def getExistingFlintInstance( @@ -505,10 +549,13 @@ object FlintREPL extends Logging with FlintJobExecutor { } if (!canPickNextStatementResult) { + earlyExitFlag = true canProceed = false } else if (!flintReader.hasNext) { canProceed = false } else { + val statementTimerContext = getTimerContext( + MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) val flintCommand = processCommandInitiation(flintReader, 
flintSessionIndexUpdater) val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( @@ -529,7 +576,8 @@ object FlintREPL extends Logging with FlintJobExecutor { flintCommand, resultIndex, flintSessionIndexUpdater, - osClient) + osClient, + statementTimerContext) // last query finish time is last activity time lastActivityTime = currentTimeProvider.currentEpochMillis() } @@ -556,17 +604,16 @@ object FlintREPL extends Logging with FlintJobExecutor { flintCommand: FlintCommand, resultIndex: String, flintSessionIndexUpdater: OpenSearchUpdater, - osClient: OSClient): Unit = { + osClient: OSClient, + statementTimerContext: Timer.Context): Unit = { try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) - // todo. it is migration plan to handle https://github - // .com/opensearch-project/sql/issues/2436. Remove sleep after issue fixed in plugin. - Thread.sleep(2000) if (flintCommand.isRunning() || flintCommand.isWaiting()) { // we have set failed state in exception handling flintCommand.complete() } updateSessionIndex(flintCommand, flintSessionIndexUpdater) + recordStatementStateChange(flintCommand, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) @@ -575,6 +622,7 @@ object FlintREPL extends Logging with FlintJobExecutor { logError(error, e) flintCommand.fail() updateSessionIndex(flintCommand, flintSessionIndexUpdater) + recordStatementStateChange(flintCommand, statementTimerContext) } } @@ -770,6 +818,7 @@ object FlintREPL extends Logging with FlintJobExecutor { flintCommand.running() logDebug(s"command running: $flintCommand") updateSessionIndex(flintCommand, flintSessionIndexUpdater) + statementRunningCount.incrementAndGet() flintCommand } @@ -814,7 +863,7 @@ object FlintREPL extends Logging with FlintJobExecutor { | } |}""".stripMargin - val flintReader = osClient.createReader(sessionIndex, dsl, "submitTime") + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) flintReader } @@ -823,6 +872,7 @@ object FlintREPL extends Logging with FlintJobExecutor { osClient: OSClient, sessionIndex: String, sessionId: String, + sessionTimerContext: Timer.Context, shutdownHookManager: ShutdownHookManagerTrait = DefaultShutdownHookManager): Unit = { shutdownHookManager.addShutdownHook(() => { @@ -838,12 +888,21 @@ object FlintREPL extends Logging with FlintJobExecutor { } val state = Option(source.get("state")).map(_.asInstanceOf[String]) - if (state.isDefined && state.get != "dead" && state.get != "fail") { + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. 
+ if (!earlyExitFlag && state.isDefined && state.get != "dead" && state.get != "fail") { updateFlintInstanceBeforeShutdown( source, getResponse, flintSessionIndexUpdater, - sessionId) + sessionId, + sessionTimerContext) } }) } @@ -852,7 +911,8 @@ object FlintREPL extends Logging with FlintJobExecutor { source: java.util.Map[String, AnyRef], getResponse: GetResponse, flintSessionIndexUpdater: OpenSearchUpdater, - sessionId: String): Unit = { + sessionId: String, + sessionTimerContext: Timer.Context): Unit = { val flintInstance = FlintInstance.deserializeFromMap(source) flintInstance.state = "dead" flintSessionIndexUpdater.updateIf( @@ -862,6 +922,7 @@ object FlintREPL extends Logging with FlintJobExecutor { currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, getResponse.getPrimaryTerm) + recordSessionSuccess(sessionTimerContext) } /** @@ -909,11 +970,17 @@ object FlintREPL extends Logging with FlintJobExecutor { // Preserve the interrupt status Thread.currentThread().interrupt() logError("HeartBeatUpdater task was interrupted", ie) + incrementCounter( + MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC + ) // Record heartbeat failure metric // maybe due to invalid sequence number or primary term case e: Exception => logWarning( s"""Fail to update the last update time of the flint instance ${sessionId}""", e) + incrementCounter( + MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC + ) // Record heartbeat failure metric } } }, @@ -1023,4 +1090,34 @@ object FlintREPL extends Logging with FlintJobExecutor { result.getOrElse(throw new RuntimeException("Failed after retries")) } + + private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { + stopTimer(sessionTimerContext) + if (sessionRunningCount.get() > 0) { + sessionRunningCount.decrementAndGet() + } + incrementCounter(MetricConstants.REPL_SUCCESS_METRIC) + } + + private def recordSessionFailed(sessionTimerContext: Timer.Context): Unit = { + stopTimer(sessionTimerContext) + if (sessionRunningCount.get() > 0) { + sessionRunningCount.decrementAndGet() + } + incrementCounter(MetricConstants.REPL_FAILED_METRIC) + } + + private def recordStatementStateChange( + flintCommand: FlintCommand, + statementTimerContext: Timer.Context): Unit = { + stopTimer(statementTimerContext) + if (statementRunningCount.get() > 0) { + statementRunningCount.decrementAndGet() + } + if (flintCommand.isComplete()) { + incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC) + } else if (flintCommand.isFailed()) { + incrementCounter(MetricConstants.STATEMENT_FAILED_METRIC) + } + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index a702d2c64..bbaceb15d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -6,11 +6,14 @@ package org.apache.spark.sql import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} import scala.util.{Failure, Success, Try} +import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import org.opensearch.flint.core.storage.OpenSearchUpdater import org.apache.spark.SparkConf @@ -21,16 +24,16 @@ import 
org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.ThreadUtils case class JobOperator( - sparkConf: SparkConf, + spark: SparkSession, query: String, dataSource: String, resultIndex: String, - streaming: Boolean) + streaming: Boolean, + streamingRunningCount: AtomicInteger) extends Logging with FlintJobExecutor { - private val spark = createSparkSession(sparkConf) - // jvm shutdown hook + // JVM shutdown hook sys.addShutdownHook(stop()) def start(): Unit = { @@ -38,7 +41,10 @@ case class JobOperator( implicit val executionContext = ExecutionContext.fromExecutor(threadPool) var dataToWrite: Option[DataFrame] = None + val startTime = System.currentTimeMillis() + streamingRunningCount.incrementAndGet() + // osClient needs spark session to be created first to get FlintOptions initialized. // Otherwise, we will have connection exception from EMR-S to OS. val osClient = new OSClient(FlintSparkConf().flintOptions()) @@ -99,6 +105,7 @@ case class JobOperator( } catch { case e: Exception => logError("Fail to close threadpool", e) } + recordStreamingCompletionStatus(exceptionThrown) } def stop(): Unit = { @@ -110,4 +117,24 @@ case class JobOperator( case Failure(e) => logError("unexpected error while stopping spark session", e) } } + + /** + * Records the completion of a streaming job by updating the appropriate metrics. This method + * decrements the running metric for streaming jobs and increments either the success or failure + * metric based on whether an exception was thrown. + * + * @param exceptionThrown + * Indicates whether an exception was thrown during the streaming job execution. + */ + private def recordStreamingCompletionStatus(exceptionThrown: Boolean): Unit = { + // Decrement the metric for running streaming jobs as the job is now completing. 
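+    // The guard below only decrements while the gauge is positive, so a repeated or
+    // unbalanced completion path cannot push the running-jobs count below zero.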
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
index e2e44bddd..f5e4ec2be 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
@@ -9,7 +9,10 @@ import java.io.IOException
 import java.util.ArrayList
 import java.util.Locale
 
+import scala.util.{Failure, Success, Try}
+
 import org.opensearch.action.get.{GetRequest, GetResponse}
+import org.opensearch.action.search.{SearchRequest, SearchResponse}
 import org.opensearch.client.{RequestOptions, RestHighLevelClient}
 import org.opensearch.client.indices.{CreateIndexRequest, GetIndexRequest, GetIndexResponse}
 import org.opensearch.client.indices.CreateIndexRequest
@@ -17,8 +20,9 @@ import org.opensearch.common.Strings
 import org.opensearch.common.settings.Settings
 import org.opensearch.common.xcontent.{NamedXContentRegistry, XContentParser, XContentType}
 import org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS
-import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
-import org.opensearch.flint.core.storage.{FlintReader, OpenSearchScrollReader, OpenSearchUpdater}
+import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions, IRestHighLevelClient}
+import org.opensearch.flint.core.metrics.MetricConstants
+import org.opensearch.flint.core.storage.{FlintReader, OpenSearchQueryReader, OpenSearchScrollReader, OpenSearchUpdater}
 import org.opensearch.index.query.{AbstractQueryBuilder, MatchAllQueryBuilder, QueryBuilder}
 import org.opensearch.plugins.SearchPlugin
 import org.opensearch.search.SearchModule
@@ -73,8 +77,13 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
     try {
       client.createIndex(request, RequestOptions.DEFAULT)
       logInfo(s"create $osIndexName successfully")
+      IRestHighLevelClient.recordOperationSuccess(
+        MetricConstants.RESULT_METADATA_WRITE_METRIC_PREFIX)
     } catch {
       case e: Exception =>
+        IRestHighLevelClient.recordOperationFailure(
+          MetricConstants.RESULT_METADATA_WRITE_METRIC_PREFIX,
+          e)
         throw new IllegalStateException(s"Failed to create index $osIndexName", e);
     }
   }
@@ -108,23 +117,29 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
   def getDoc(osIndexName: String, id: String): GetResponse = {
     using(flintClient.createClient()) { client =>
-      try {
-        val request = new GetRequest(osIndexName, id)
-        client.get(request, RequestOptions.DEFAULT)
-      } catch {
-        case e: Exception =>
+      val request = new GetRequest(osIndexName, id)
+      val result = Try(client.get(request, RequestOptions.DEFAULT))
+      result match {
+        case Success(response) =>
+          IRestHighLevelClient.recordOperationSuccess(
+            MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX)
+          response
+        case Failure(e: Exception) =>
+          IRestHighLevelClient.recordOperationFailure(
+            MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX,
+            e)
           throw new IllegalStateException(
             String.format(
               Locale.ROOT,
               "Failed to retrieve doc %s from index %s",
-              osIndexName,
-              id),
+              id,
+              osIndexName),
             e)
       }
     }
   }
 
-  def createReader(indexName: String, query: String, sort: String): FlintReader = try {
+  def createScrollReader(indexName: String, query: String, sort: String): FlintReader = try {
     var queryBuilder: QueryBuilder = new MatchAllQueryBuilder
     if (!Strings.isNullOrEmpty(query)) {
       val parser =
@@ -145,11 +160,31 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
     using(flintClient.createClient()) { client =>
       try {
         val request = new GetIndexRequest(indexName)
-        client.isIndexExists(request, RequestOptions.DEFAULT)
+        client.doesIndexExist(request, RequestOptions.DEFAULT)
       } catch {
         case e: Exception =>
          throw new IllegalStateException(s"Failed to check if index $indexName exists", e)
       }
     }
   }
+
+  def createQueryReader(
+      indexName: String,
+      query: String,
+      sort: String,
+      sortOrder: SortOrder): FlintReader = try {
+    var queryBuilder: QueryBuilder = new MatchAllQueryBuilder
+    if (!Strings.isNullOrEmpty(query)) {
+      val parser =
+        XContentType.JSON.xContent.createParser(xContentRegistry, IGNORE_DEPRECATIONS, query)
+      queryBuilder = AbstractQueryBuilder.parseInnerQueryBuilder(parser)
+    }
+    new OpenSearchQueryReader(
+      flintClient.createClient(),
+      indexName,
+      new SearchSourceBuilder().query(queryBuilder).sort(sort, sortOrder))
+  } catch {
+    case e: IOException =>
+      throw new RuntimeException(e)
+  }
 }
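`createQueryReader` mirrors the renamed `createScrollReader` but lets the caller choose the sort order; the updated tests later in this diff stub it with `SortOrder.ASC`. A rough usage sketch follows, assuming `FlintReader` is consumed as an iterator of raw JSON documents via `hasNext`/`next`/`close`, and using made-up index, query, and sort-field values.

```scala
import org.apache.spark.sql.OSClient
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.opensearch.search.sort.SortOrder

object QueryReaderSketch {
  // Assumes an active SparkSession with Flint settings so FlintSparkConf() can resolve FlintOptions.
  def readWaitingStatements(): Unit = {
    val osClient = new OSClient(FlintSparkConf().flintOptions())
    val reader = osClient.createQueryReader(
      indexName = ".query_execution_request_mys3", // hypothetical session index name
      query = """{"term": {"state": "waiting"}}""", // inner query object, parsed by parseInnerQueryBuilder
      sort = "submitTime", // hypothetical sort field
      sortOrder = SortOrder.ASC)
    try {
      while (reader.hasNext) {
        val doc = reader.next() // one statement document as raw JSON
        // hand the document off to the statement-processing loop here
      }
    } finally {
      reader.close()
    }
  }
}
```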
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala
new file mode 100644
index 000000000..5b1c4e2df
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala
@@ -0,0 +1,24 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql.util
+
+/**
+ * Trait defining an interface for fetching environment variables.
+ */
+trait EnvironmentProvider {
+
+  /**
+   * Retrieves the value of an environment variable.
+   *
+   * @param name
+   *   The name of the environment variable.
+   * @param default
+   *   The default value to return if the environment variable is not set.
+   * @return
+   *   The value of the environment variable if it exists, otherwise the default value.
+   */
+  def getEnvVar(name: String, default: String): String
+}
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala
new file mode 100644
index 000000000..bf5eafce5
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala
@@ -0,0 +1,27 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql.util
+
+/**
+ * An implementation of `EnvironmentProvider` that fetches actual environment variables from the
+ * system.
+ */
+class RealEnvironment extends EnvironmentProvider {
+
+  /**
+   * Retrieves the value of an environment variable from the system or returns a default value if
+   * not present.
+   *
+   * @param name
+   *   The name of the environment variable.
+   * @param default
+   *   The default value to return if the environment variable is not set in the system.
+   * @return
+   *   The value of the environment variable if it exists in the system, otherwise the default
+   *   value.
+   */
+  def getEnvVar(name: String, default: String): String = sys.env.getOrElse(name, default)
+}
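The point of the `EnvironmentProvider` indirection is that production code can default to `RealEnvironment` while tests inject a canned map (see `MockEnvironment` at the end of this diff) instead of mutating the process environment. A minimal sketch of that injection style, with an illustrative consumer class and a made-up variable name:

```scala
import org.apache.spark.sql.util.{EnvironmentProvider, RealEnvironment}

// Illustrative consumer; "FLINT_CLUSTER_NAME" is a made-up variable name.
class ClusterNameResolver(env: EnvironmentProvider = new RealEnvironment) {
  def clusterName: String = env.getEnvVar("FLINT_CLUSTER_NAME", "unknown-cluster")
}

object ClusterNameResolverExample {
  def main(args: Array[String]): Unit = {
    val resolver = new ClusterNameResolver() // uses the real system environment
    println(resolver.clusterName)
  }
}
```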
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index c3d027102..abae546b6 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -15,6 +15,8 @@ import scala.concurrent.duration._
 import scala.concurrent.duration.{Duration, MINUTES}
 import scala.reflect.runtime.universe.TypeTag
 
+import com.codahale.metrics.Timer
+import org.mockito.ArgumentMatchers.{eq => eqTo, _}
 import org.mockito.ArgumentMatchersSugar
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
@@ -22,12 +24,13 @@ import org.mockito.stubbing.Answer
 import org.opensearch.action.get.GetResponse
 import org.opensearch.flint.app.FlintCommand
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater}
+import org.opensearch.search.sort.SortOrder
 import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.{ArrayType, LongType, NullType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, StructType}
 import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait}
 import org.apache.spark.util.ThreadUtils
@@ -80,6 +83,7 @@ class FlintREPLTest
     val getResponse = mock[GetResponse]
     val sessionIndex = "testIndex"
     val sessionId = "testSessionId"
+    val flintSessionContext = mock[Timer.Context]
 
     when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(true)
@@ -108,6 +112,7 @@ class FlintREPLTest
       osClient,
       sessionIndex,
       sessionId,
+      flintSessionContext,
       mockShutdownHookManager)
 
     verify(flintSessionIndexUpdater).updateIf(*, *, *, *)
@@ -411,7 +416,8 @@ class FlintREPLTest
         new ConnectException(
           "Timeout connecting to [search-foo-1-bar.eu-west-1.es.amazonaws.com:443]"))
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     when(mockReader.hasNext).thenThrow(exception)
 
     val maxRetries = 1
@@ -686,7 +692,8 @@ class FlintREPLTest
   test("queryLoop continue until inactivity limit is reached") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     when(mockReader.hasNext).thenReturn(false)
 
     val resultIndex = "testResultIndex"
@@ -736,7 +743,8 @@ class FlintREPLTest
   test("queryLoop should stop when canPickUpNextStatement is false") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     when(mockReader.hasNext).thenReturn(true)
 
     val resultIndex = "testResultIndex"
@@ -790,7 +798,8 @@ class FlintREPLTest
   test("queryLoop should properly shut down the thread pool after execution") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     when(mockReader.hasNext).thenReturn(false)
 
     val resultIndex = "testResultIndex"
@@ -838,7 +847,8 @@ class FlintREPLTest
   test("queryLoop handle exceptions within the loop gracefully") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     // Simulate an exception thrown when hasNext is called
     when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception"))
@@ -889,7 +899,8 @@ class FlintREPLTest
   test("queryLoop should correctly update loop control variables") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(*, *)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(false)
@@ -958,7 +969,8 @@ class FlintREPLTest
   test("queryLoop should execute loop without processing any commands") {
     val mockReader = mock[FlintReader]
     val osClient = mock[OSClient]
-    when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader)
+    when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
+      .thenReturn(mockReader)
     val getResponse = mock[GetResponse]
     when(osClient.getDoc(*, *)).thenReturn(getResponse)
     when(getResponse.isExists()).thenReturn(false)
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala
new file mode 100644
index 000000000..b6f3e3c97
--- /dev/null
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql.util
+
+/**
+ * A mock implementation of `EnvironmentProvider` for use in tests, where environment variables
+ * can be predefined.
+ *
+ * @param inputMap
+ *   A map representing the environment variables (name -> value).
+ */
+class MockEnvironment(inputMap: Map[String, String]) extends EnvironmentProvider {
+
+  /**
+   * Retrieves the value of an environment variable from the input map or returns a default value
+   * if not present.
+   *
+   * @param name
+   *   The name of the environment variable.
+   * @param default
+   *   The default value to return if the environment variable is not set in the input map.
+   * @return
+   *   The value of the environment variable from the input map if it exists, otherwise the
+   *   default value.
+   */
+  def getEnvVar(name: String, default: String): String = inputMap.getOrElse(name, default)
+}