From 7f8f0524fdb7b1135b29f1f933ac09cf62a113ca Mon Sep 17 00:00:00 2001
From: Vamsi Manohar <reddyvam@amazon.com>
Date: Tue, 28 Nov 2023 11:12:13 -0800
Subject: [PATCH] Add Flint OpenSearch metrics
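
Wrap FlintOpenSearchClient so that OpenSearch access failures are counted
through a Spark metrics Source ("FlintMetricSource"), and add default
CloudWatch dimensions (jobId, applicationId, domainId, instance) derived
from environment variables in the CloudWatch reporter.

A rough sketch of the metric naming scheme used by the wrapper ("mydomain"
below is a placeholder value, not a real domain identifier):

    // Dimensions are encoded into the metric name with "##" as the
    // key/value separator and decoded again by DimensionedName.decode().
    DimensionedName metric = DimensionedName.withName("FlintOpenSearchAccessError")
        .withDimension("domain_ident", "mydomain")
        .build();
    metric.encode();  // -> "FlintOpenSearchAccessError[domain_ident##mydomain]"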

Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
---
 .../DimensionedCloudWatchReporter.java        |  55 ++++++-
 .../metrics/reporter/DimensionedName.java     |   5 +-
 .../metrics/source/FlintMetricSource.scala    |  12 ++
 .../flint/core/FlintClientBuilder.java        |   3 +-
 .../FlintOpensearchClientMetricsWrapper.java  | 138 ++++++++++++++++++
 .../DimensionedCloudWatchReporterTest.java    |  20 ++-
 .../org/apache/spark/sql/FlintREPL.scala      |   8 +-
 7 files changed, 232 insertions(+), 9 deletions(-)
 create mode 100644 flint-core/src/main/scala/org/apache/spark/metrics/source/FlintMetricSource.scala
 create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/metrics/FlintOpensearchClientMetricsWrapper.java

diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java
index 3ad627a98..70e2ad4e6 100644
--- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java
+++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java
@@ -13,7 +13,6 @@
 import com.amazonaws.services.cloudwatch.model.PutMetricDataResult;
 import com.amazonaws.services.cloudwatch.model.StandardUnit;
 import com.amazonaws.services.cloudwatch.model.StatisticSet;
-import com.amazonaws.util.StringUtils;
 import com.codahale.metrics.Clock;
 import com.codahale.metrics.Counter;
 import com.codahale.metrics.Counting;
@@ -36,6 +35,7 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Date;
+import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
@@ -47,6 +47,9 @@
 import java.util.stream.Collectors;
 import java.util.stream.LongStream;
 import java.util.stream.Stream;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.jetbrains.annotations.NotNull;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -83,6 +86,16 @@ public class DimensionedCloudWatchReporter extends ScheduledReporter {
     // Visible for testing
     public static final String DIMENSION_SNAPSHOT_STD_DEV = "snapshot-std-dev";
 
+    public static final String DIMENSION_JOB_ID = "jobId";
+
+    public static final String DIMENSION_APPLICATION_ID = "applicationId";
+
+    public static final String DIMENSION_DOMAIN_ID = "domainId";
+
+    public static final String DIMENSION_INSTANCE_TYPE = "instance";
+
+    public static final String UNKNOWN = "unknown";
+
     /**
      * Amazon CloudWatch rejects values that are either too small or too large.
      * Values must be in the range of 8.515920e-109 to 1.174271e+108 (Base 10) or 2e-360 to 2e360 (Base 2).
@@ -135,7 +148,6 @@ public void report(final SortedMap<String, Gauge> gauges,
         if (builder.withDryRun) {
             LOGGER.warn("** Reporter is running in 'DRY RUN' mode **");
         }
-
         try {
             final List<MetricDatum> metricData = new ArrayList<>(
                     gauges.size() + counters.size() + 10 * histograms.size() + 10 * timers.size());
@@ -339,6 +351,8 @@ private void stageMetricDatum(final boolean metricConfigured,
         if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) {
             final DimensionedName dimensionedName = DimensionedName.decode(metricName);
             final Set<Dimension> dimensions = new LinkedHashSet<>(builder.globalDimensions);
+            Pair<String, Set<Dimension>> finalNameAndDefaultDimensions = getFinalMetricNameAndDefaultDimensions(dimensionedName);
+            dimensions.addAll(finalNameAndDefaultDimensions.getRight());
             dimensions.addAll(dimensionedName.getDimensions());
             if (shouldAppendDropwizardTypeDimension) {
                 dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(dimensionValue));
@@ -347,12 +361,47 @@ private void stageMetricDatum(final boolean metricConfigured,
             metricData.add(new MetricDatum()
                     .withTimestamp(new Date(builder.clock.getTime()))
                     .withValue(cleanMetricValue(metricValue))
-                    .withMetricName(dimensionedName.getName())
+                    .withMetricName(finalNameAndDefaultDimensions.getLeft())
                     .withDimensions(dimensions)
                     .withUnit(standardUnit));
         }
     }
 
+    @NotNull
+    private Pair<String, Set<Dimension>> getFinalMetricNameAndDefaultDimensions(DimensionedName dimensionedName) {
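+        // Default dimensions are read from environment variables; "unknown" is used when a
+        // variable is not set (e.g. when running outside EMR Serverless or in tests).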
+        final String jobId = System.getenv().getOrDefault("SERVERLESS_EMR_JOB_ID", UNKNOWN);
+        final String applicationId = System.getenv().getOrDefault("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", UNKNOWN);
+        final String domainId = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", UNKNOWN);
+        final Dimension jobDimension = new Dimension().withName(DIMENSION_JOB_ID).withValue(jobId);
+        final Dimension applicationDimension = new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(applicationId);
+        final Dimension domainIdDimension = new Dimension().withName(DIMENSION_DOMAIN_ID).withValue(domainId);
+        Dimension instanceDimension = new Dimension().withName(DIMENSION_INSTANCE_TYPE).withValue(UNKNOWN);
+        String name = dimensionedName.getName();
+        String finalMetricName = name;
+        String[] parts = name.split("\\.");
+        if (doesNameConsistOfMetricNameSpace(parts, jobId)) {
+            finalMetricName = Stream.of(parts).skip(2).collect(Collectors.joining("."));
+            // For executors only the numeric executor id follows the job id in the metric name,
+            // hence the numeric check; otherwise the instance is the driver.
+            if (StringUtils.isNumeric(parts[1])) {
+                instanceDimension = new Dimension().withName(DIMENSION_INSTANCE_TYPE).withValue("executor" + parts[1]);
+            } else {
+                instanceDimension = new Dimension().withName(DIMENSION_INSTANCE_TYPE).withValue(parts[1]);
+            }
+        }
+        Set<Dimension> dimensions = new HashSet<>();
+        dimensions.add(jobDimension);
+        dimensions.add(applicationDimension);
+        dimensions.add(instanceDimension);
+        dimensions.add(domainIdDimension);
+        return Pair.of(finalMetricName, dimensions);
+    }
+
+    private boolean doesNameConsistOfMetricNameSpace(String[] metricNameParts, String jobId) {
+        return metricNameParts[0].equals(jobId);
+    }
+
     private void stageMetricDatumWithConvertedSnapshot(final boolean metricConfigured,
                                                        final String metricName,
                                                        final Snapshot snapshot,
diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedName.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedName.java
index 839e4c28f..da3a446d4 100644
--- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedName.java
+++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedName.java
@@ -28,7 +28,8 @@ public static DimensionedName decode(final String encodedDimensionedName) {
         if (matcher.find() && matcher.groupCount() == 2) {
             final DimensionedNameBuilder builder = new DimensionedNameBuilder(matcher.group(1).trim());
             for (String t : matcher.group(2).split(",")) {
-                final String[] keyAndValue = t.split(":");
+                // "##" acts as a distinct key/value separator (dimension values may themselves contain ':').
+                final String[] keyAndValue = t.split("##");
                 builder.withDimension(keyAndValue[0].trim(), keyAndValue[1].trim());
             }
             return builder.build();
@@ -59,7 +60,7 @@ public synchronized String encode() {
                 final StringBuilder sb = new StringBuilder(this.name);
                 sb.append('[');
                 sb.append(this.dimensions.values().stream()
-                        .map(dimension -> dimension.getName() + ":" + dimension.getValue())
+                        .map(dimension -> dimension.getName() + "##" + dimension.getValue())
                         .collect(Collectors.joining(",")));
                 sb.append(']');
 
diff --git a/flint-core/src/main/scala/org/apache/spark/metrics/source/FlintMetricSource.scala b/flint-core/src/main/scala/org/apache/spark/metrics/source/FlintMetricSource.scala
new file mode 100644
index 000000000..e22a61a51
--- /dev/null
+++ b/flint-core/src/main/scala/org/apache/spark/metrics/source/FlintMetricSource.scala
@@ -0,0 +1,12 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.metrics.source
+
+import com.codahale.metrics.MetricRegistry
+
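+/**
+ * Spark metrics Source backed by its own MetricRegistry. FlintREPL registers an instance
+ * under the name "FlintMetricSource", which FlintOpensearchClientMetricsWrapper looks up
+ * to publish OpenSearch access error counters.
+ */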
+class FlintMetricSource(val sourceName: String) extends Source {
+  override val metricRegistry: MetricRegistry = new MetricRegistry
+}
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClientBuilder.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClientBuilder.java
index a0372a86f..8ca254d9b 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClientBuilder.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClientBuilder.java
@@ -5,6 +5,7 @@
 
 package org.opensearch.flint.core;
 
+import org.opensearch.flint.core.metrics.FlintOpensearchClientMetricsWrapper;
 import org.opensearch.flint.core.storage.FlintOpenSearchClient;
 
 /**
@@ -13,6 +14,6 @@
 public class FlintClientBuilder {
 
   public static FlintClient build(FlintOptions options) {
-    return new FlintOpenSearchClient(options);
+    return new FlintOpensearchClientMetricsWrapper(new FlintOpenSearchClient(options));
   }
 }
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metrics/FlintOpensearchClientMetricsWrapper.java b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/FlintOpensearchClientMetricsWrapper.java
new file mode 100644
index 000000000..21fbfff21
--- /dev/null
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/FlintOpensearchClientMetricsWrapper.java
@@ -0,0 +1,138 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.core.metrics;
+
+import com.codahale.metrics.Counter;
+import java.util.List;
+import java.util.function.Supplier;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.metrics.source.FlintMetricSource;
+import org.opensearch.client.RestHighLevelClient;
+import org.opensearch.flint.core.FlintClient;
+import org.opensearch.flint.core.metadata.FlintMetadata;
+import org.opensearch.flint.core.metadata.log.OptimisticTransaction;
+import org.opensearch.flint.core.metrics.reporter.DimensionedName;
+import org.opensearch.flint.core.storage.FlintOpenSearchClient;
+import org.opensearch.flint.core.storage.FlintReader;
+import org.opensearch.flint.core.storage.FlintWriter;
+
+/**
+ * This class wraps FlintOpenSearchClient and emits Spark metrics to FlintMetricSource.
+ */
+public class FlintOpensearchClientMetricsWrapper implements FlintClient {
+
+  private final FlintOpenSearchClient delegate;
+
+  public FlintOpensearchClientMetricsWrapper(FlintOpenSearchClient delegate) {
+    this.delegate = delegate;
+  }
+
+  @Override
+  public <T> OptimisticTransaction<T> startTransaction(String indexName, String dataSourceName) {
+    return handleExceptions(() -> delegate.startTransaction(indexName, dataSourceName));
+  }
+
+  @Override
+  public <T> OptimisticTransaction<T> startTransaction(String indexName, String dataSourceName,
+                                                       boolean forceInit) {
+    return handleExceptions(() -> delegate.startTransaction(indexName, dataSourceName, forceInit));
+  }
+
+  @Override
+  public void createIndex(String indexName, FlintMetadata metadata) {
+    try {
+      delegate.createIndex(indexName, metadata);
+    } catch (Throwable t) {
+      handleThrowable();
+      throw t;
+    }
+  }
+
+  @Override
+  public boolean exists(String indexName) {
+    return handleExceptions(() -> delegate.exists(indexName));
+  }
+
+  @Override
+  public List<FlintMetadata> getAllIndexMetadata(String indexNamePattern) {
+    return handleExceptions(() -> delegate.getAllIndexMetadata(indexNamePattern));
+  }
+
+  @Override
+  public FlintMetadata getIndexMetadata(String indexName) {
+    return handleExceptions(() -> delegate.getIndexMetadata(indexName));
+  }
+
+  @Override
+  public void deleteIndex(String indexName) {
+    try {
+      delegate.deleteIndex(indexName);
+    } catch (Throwable t) {
+      handleThrowable();
+      throw t;
+    }
+  }
+
+  @Override
+  public FlintReader createReader(String indexName, String query) {
+    return handleExceptions(() -> delegate.createReader(indexName, query));
+  }
+
+  @Override
+  public FlintWriter createWriter(String indexName) {
+    return handleExceptions(() -> delegate.createWriter(indexName));
+  }
+
+  @Override
+  public RestHighLevelClient createClient() {
+    return handleExceptions(delegate::createClient);
+  }
+
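+  // Counts the failure via handleThrowable() and rethrows; the Throwable is wrapped in a
+  // RuntimeException because Supplier cannot declare checked exceptions.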
+  private <T> T handleExceptions(Supplier<T> function) {
+    try {
+      return function.get();
+    } catch (Throwable t) {
+      handleThrowable();
+      throw new RuntimeException(t);
+    }
+  }
+
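+  /**
+   * Records an OpenSearch access failure by incrementing the "FlintOpenSearchAccessError"
+   * counter, tagged with the domain identifier from FLINT_OPENSEARCH_DOMAIN_IDENTIFIER
+   * ("unknown" when the variable is not set).
+   */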
+  private void handleThrowable() {
+    String clusterName = System.getenv("FLINT_OPENSEARCH_DOMAIN_IDENTIFIER");
+    if (clusterName == null) {
+      clusterName = "unknown";
+    }
+    DimensionedName metricName = DimensionedName.withName("FlintOpenSearchAccessError")
+        .withDimension("domain_ident", clusterName)
+        .build();
+    publishMetric(metricName);
+  }
+
+  private void handleAccessDeniedException() {
+    String clusterName = System.getenv("FLINT_AUTH_DOMAIN_IDENTIFIER");
+    if (clusterName == null) {
+      clusterName = "unknown";
+    }
+    DimensionedName metricName = DimensionedName.withName("FlintOpenSearchAccessDeniedError")
+        .withDimension("domain_ident", clusterName)
+        .build();
+    publishMetric(metricName);
+  }
+
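+  /**
+   * Increments a counter named after the encoded DimensionedName in the FlintMetricSource
+   * registry. This is a no-op when the "FlintMetricSource" source has not been registered
+   * with the Spark metrics system.
+   */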
+  private void publishMetric(DimensionedName metricName) {
+    FlintMetricSource flintMetricSource =
+        (FlintMetricSource) SparkEnv.get().metricsSystem().getSourcesByName("FlintMetricSource");
+    if (flintMetricSource != null) {
+      // MetricRegistry#counter is get-or-create, so no explicit existence check is needed.
+      flintMetricSource.metricRegistry().counter(metricName.encode()).inc();
+    }
+  }
+
+}
diff --git a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java b/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java
index 83df15067..0524888dd 100644
--- a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java
+++ b/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java
@@ -47,12 +47,17 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_APPLICATION_ID;
 import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_COUNT;
+import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_DOMAIN_ID;
 import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_GAUGE;
+import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_INSTANCE_TYPE;
+import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_JOB_ID;
 import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_NAME_TYPE;
 import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_MEAN;
 import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_STD_DEV;
 import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_SUMMARY;
+import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.UNKNOWN;
 
 @ExtendWith(MockitoExtension.class)
 @MockitoSettings(strictness = Strictness.LENIENT)
@@ -104,10 +109,12 @@ public void shouldReportWithoutGlobalDimensionsWhenGlobalDimensionsNotConfigured
 
         final List<Dimension> dimensions = firstMetricDatumDimensionsFromCapturedRequest();
 
-        assertThat(dimensions).hasSize(1);
+        assertThat(dimensions).hasSize(5);
         assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT));
+        assertDefaultDimensionsWithUnknownValue(dimensions);
     }
 
     @Test
     public void reportedCounterShouldContainExpectedDimension() throws Exception {
         metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc();
@@ -116,6 +123,7 @@ public void reportedCounterShouldContainExpectedDimension() throws Exception {
         final List<Dimension> dimensions = firstMetricDatumDimensionsFromCapturedRequest();
 
         assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT));
+        assertDefaultDimensionsWithUnknownValue(dimensions);
     }
 
     @Test
@@ -126,6 +134,7 @@ public void reportedGaugeShouldContainExpectedDimension() throws Exception {
         final List<Dimension> dimensions = firstMetricDatumDimensionsFromCapturedRequest();
 
         assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_GAUGE));
+        assertDefaultDimensionsWithUnknownValue(dimensions);
     }
 
     @Test
@@ -475,6 +484,15 @@ public void shouldReportExpectedGlobalAndCustomDimensions() throws Exception {
         assertThat(dimensions).contains(new Dimension().withName("key2").withValue("value2"));
     }
 
+    private void assertDefaultDimensionsWithUnknownValue(List<Dimension> dimensions) {
+        assertThat(dimensions).contains(new Dimension().withName(DIMENSION_JOB_ID).withValue(UNKNOWN));
+        assertThat(dimensions).contains(new Dimension().withName(DIMENSION_INSTANCE_TYPE).withValue(UNKNOWN));
+        assertThat(dimensions).contains(new Dimension().withName(DIMENSION_DOMAIN_ID).withValue(UNKNOWN));
+        assertThat(dimensions).contains(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(UNKNOWN));
+    }
+
     private MetricDatum metricDatumByDimensionFromCapturedRequest(final String dimensionValue) {
         final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue();
         final List<MetricDatum> metricData = putMetricDataRequest.getMetricData();
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 99085185c..bb3cdb3c0 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -18,10 +18,12 @@ import org.opensearch.action.get.GetResponse
 import org.opensearch.common.Strings
 import org.opensearch.flint.app.{FlintCommand, FlintInstance}
 import org.opensearch.flint.app.FlintInstance.formats
+import org.opensearch.flint.core.metrics.reporter.DimensionedName
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
+import org.apache.spark.metrics.source.FlintMetricSource
 import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait}
 import org.apache.spark.util.ThreadUtils
@@ -91,10 +93,12 @@ object FlintREPL extends Logging with FlintJobExecutor {
       }
 
       val spark = createSparkSession(conf)
+      // This is the metric source used everywhere, including the searchServicesFireFlower library.
+      val flintMetricSource = new FlintMetricSource("FlintMetricSource")
+      SparkEnv.get.metricsSystem.registerSource(flintMetricSource)
       val osClient = new OSClient(FlintSparkConf().flintOptions())
       val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown")
       val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
-
       // Read the values from the Spark configuration or fall back to the default values
       val inactivityLimitMillis: Long =
         conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS)