Refactor and add metrics for streaming job
Signed-off-by: Louis Chu <[email protected]>
noCharger committed Feb 18, 2024
1 parent 9c34a1d commit dd1a0d6
Showing 9 changed files with 228 additions and 57 deletions.
@@ -8,7 +8,7 @@
/**
* This class defines custom metric constants used for monitoring flint operations.
*/
-public class MetricConstants {
+public final class MetricConstants {

/**
* The prefix for all read-related metrics in OpenSearch.
@@ -65,5 +65,29 @@ public class MetricConstants {
/**
* Metric name for tracking the processing time of statements.
*/
public static final String STATEMENT_PROCESSING_TIME_METRIC = "STATEMENT.processingTime";
public static final String STATEMENT_PROCESSING_TIME_METRIC = "statement.processingTime";

/**
* Metric for tracking the count of currently running streaming jobs.
*/
public static final String STREAMING_RUNNING_METRIC = "streaming.running.count";

/**
* Metric for tracking the count of streaming jobs that have failed.
*/
public static final String STREAMING_FAILED_METRIC = "streaming.failed.count";

/**
* Metric for tracking the count of streaming jobs that have completed successfully.
*/
public static final String STREAMING_SUCCESS_METRIC = "streaming.success.count";

/**
* Metric for tracking the count of failed heartbeat signals in streaming jobs.
*/
public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count";

private MetricConstants() {
// Private constructor to prevent instantiation
}
}
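
A minimal sketch of how these new streaming constants could pair with the MetricsUtil counter helpers touched in this change; runStreamingQuery is a hypothetical placeholder, and the wiring is illustrative rather than the job's actual call sites:

    import org.opensearch.flint.core.metrics.MetricConstants;
    import org.opensearch.flint.core.metrics.MetricsUtil;

    public class StreamingJobMetricsSketch {
        public static void main(String[] args) {
            // Count the job as running for its whole lifetime, then record the outcome.
            MetricsUtil.incrementCounter(MetricConstants.STREAMING_RUNNING_METRIC);
            try {
                runStreamingQuery(); // hypothetical workload, not part of this commit
                MetricsUtil.incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC);
            } catch (Exception e) {
                MetricsUtil.incrementCounter(MetricConstants.STREAMING_FAILED_METRIC);
            } finally {
                MetricsUtil.decrementCounter(MetricConstants.STREAMING_RUNNING_METRIC);
            }
        }

        private static void runStreamingQuery() {
            // Placeholder for the actual streaming work.
        }
    }

All three counters resolve through the shared registry that the refactored MetricsUtil below obtains via getMetricRegistry().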
@@ -6,12 +6,15 @@
package org.opensearch.flint.core.metrics;

import com.codahale.metrics.Counter;
+import com.codahale.metrics.Gauge;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import org.apache.spark.SparkEnv;
import org.apache.spark.metrics.source.FlintMetricSource;
import org.apache.spark.metrics.source.Source;
import scala.collection.Seq;

+import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;

/**
@@ -21,8 +24,8 @@ public final class MetricsUtil {

private static final Logger LOG = Logger.getLogger(MetricsUtil.class.getName());

-    // Private constructor to prevent instantiation
    private MetricsUtil() {
+        // Private constructor to prevent instantiation
    }

/**
@@ -60,10 +63,7 @@ public static void decrementCounter(String metricName) {
*/
public static Timer.Context getTimerContext(String metricName) {
Timer timer = getOrCreateTimer(metricName);
-        if (timer != null) {
-            return timer.time();
-        }
-        return null;
+        return timer != null ? timer.time() : null;
}

/**
@@ -74,42 +74,47 @@ public static Timer.Context getTimerContext(String metricName) {
* @return The elapsed time in nanoseconds since the timer was started, or {@code null} if the context was {@code null}.
*/
    public static Long stopTimer(Timer.Context context) {
-        if (context != null) {
-            return context.stop();
-        }
-        return null;
+        return context != null ? context.stop() : null;
    }

+    /**
+     * Registers a gauge metric with the provided name and value.
+     * The gauge will reflect the current value of the AtomicInteger provided.
+     *
+     * @param metricName The name of the gauge metric to register.
+     * @param value The AtomicInteger whose current value should be reflected by the gauge.
+     */
+    public static void registerGauge(String metricName, final AtomicInteger value) {
+        MetricRegistry metricRegistry = getMetricRegistry();
+        if (metricRegistry == null) {
+            LOG.warning("MetricRegistry not available, cannot register gauge: " + metricName);
+            return;
+        }
+        metricRegistry.register(metricName, (Gauge<Integer>) value::get);
+    }

// Retrieves or creates a new counter for the given metric name
private static Counter getOrCreateCounter(String metricName) {
-        SparkEnv sparkEnv = SparkEnv.get();
-        if (sparkEnv == null) {
-            LOG.warning("Spark environment not available, cannot instrument metric: " + metricName);
-            return null;
-        }
-
-        FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv);
-        Counter counter = flintMetricSource.metricRegistry().getCounters().get(metricName);
-        if (counter == null) {
-            counter = flintMetricSource.metricRegistry().counter(metricName);
-        }
-        return counter;
+        MetricRegistry metricRegistry = getMetricRegistry();
+        return metricRegistry != null ? metricRegistry.counter(metricName) : null;
}

// Retrieves or creates a new Timer for the given metric name
private static Timer getOrCreateTimer(String metricName) {
+        MetricRegistry metricRegistry = getMetricRegistry();
+        return metricRegistry != null ? metricRegistry.timer(metricName) : null;
+    }
+
+    // Retrieves the MetricRegistry from the current Spark environment.
+    private static MetricRegistry getMetricRegistry() {
        SparkEnv sparkEnv = SparkEnv.get();
        if (sparkEnv == null) {
-            LOG.warning("Spark environment not available, cannot instrument metric: " + metricName);
+            LOG.warning("Spark environment not available, cannot access MetricRegistry.");
            return null;
        }

        FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv);
-        Timer timer = flintMetricSource.metricRegistry().getTimers().get(metricName);
-        if (timer == null) {
-            timer = flintMetricSource.metricRegistry().timer(metricName);
-        }
-        return timer;
+        return flintMetricSource.metricRegistry();
    }

// Gets or initializes the FlintMetricSource
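
The new registerGauge helper keeps a live reference to the supplied AtomicInteger, so callers mutate the integer and the gauge follows without re-registration. A minimal usage sketch, assuming a Spark environment is available (otherwise the helper only logs a warning) and using a made-up metric name:

    import java.util.concurrent.atomic.AtomicInteger;
    import org.opensearch.flint.core.metrics.MetricsUtil;

    public class GaugeSketch {
        // Shared state that the gauge reads on every metrics poll.
        private static final AtomicInteger activeJobs = new AtomicInteger(0);

        public static void main(String[] args) {
            // "myapp.jobs.active" is a hypothetical name, not one of the new constants.
            MetricsUtil.registerGauge("myapp.jobs.active", activeJobs);

            activeJobs.incrementAndGet(); // a job started; the gauge now reports 1
            activeJobs.decrementAndGet(); // the job finished; the gauge reports 0 again
        }
    }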
@@ -7,9 +7,13 @@

import java.util.Map;
import java.util.function.Function;
+import java.util.logging.Level;
+import java.util.logging.Logger;

import org.apache.commons.lang.StringUtils;

import com.amazonaws.services.cloudwatch.model.Dimension;
+import org.apache.spark.SparkEnv;

/**
* Utility class for creating and managing CloudWatch dimensions for metrics reporting in Flint.
@@ -18,7 +22,9 @@
* application ID, and more.
*/
public class DimensionUtils {
+    private static final Logger LOG = Logger.getLogger(DimensionUtils.class.getName());
private static final String DIMENSION_JOB_ID = "jobId";
+    private static final String DIMENSION_JOB_TYPE = "jobType";
private static final String DIMENSION_APPLICATION_ID = "applicationId";
private static final String DIMENSION_APPLICATION_NAME = "applicationName";
private static final String DIMENSION_DOMAIN_ID = "domainId";
@@ -29,6 +35,8 @@ public class DimensionUtils {
private static final Map<String, Function<String[], Dimension>> dimensionBuilders = Map.of(
DIMENSION_INSTANCE_ROLE, DimensionUtils::getInstanceRoleDimension,
DIMENSION_JOB_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_JOB_ID", DIMENSION_JOB_ID),
+        // TODO: Move FlintSparkConf into the core to prevent circular dependencies
+        DIMENSION_JOB_TYPE, ignored -> constructDimensionFromSparkConf(DIMENSION_JOB_TYPE, "spark.flint.job.type", UNKNOWN),
DIMENSION_APPLICATION_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", DIMENSION_APPLICATION_ID),
DIMENSION_APPLICATION_NAME, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_APPLICATION_NAME", DIMENSION_APPLICATION_NAME),
DIMENSION_DOMAIN_ID, ignored -> getEnvironmentVariableDimension("FLINT_CLUSTER_NAME", DIMENSION_DOMAIN_ID)
@@ -39,7 +47,7 @@ public class DimensionUtils {
* builder exists for the dimension name, it is used; otherwise, a default dimension is constructed.
*
* @param dimensionName The name of the dimension to construct.
-     * @param parts Additional information that might be required by specific dimension builders.
+     * @param metricNameParts Additional information that might be required by specific dimension builders.
* @return A CloudWatch Dimension object.
*/
public static Dimension constructDimension(String dimensionName, String[] metricNameParts) {
@@ -50,6 +58,30 @@ public static Dimension constructDimension(String dimensionName, String[] metric
.apply(metricNameParts);
}

+    /**
+     * Constructs a CloudWatch Dimension object using a specified Spark configuration key.
+     *
+     * @param dimensionName The name of the dimension to construct.
+     * @param sparkConfKey the Spark configuration key used to look up the value for the dimension.
+     * @param defaultValue the default value to use for the dimension if the Spark configuration key is not found or if the Spark environment is not available.
+     * @return A CloudWatch Dimension object.
+     * @throws Exception if an error occurs while accessing the Spark configuration. The exception is logged and then rethrown.
+     */
+    public static Dimension constructDimensionFromSparkConf(String dimensionName, String sparkConfKey, String defaultValue) {
+        String propertyValue = defaultValue;
+        try {
+            if (SparkEnv.get() != null && SparkEnv.get().conf() != null) {
+                propertyValue = SparkEnv.get().conf().get(sparkConfKey, defaultValue);
+            } else {
+                LOG.warning("Spark environment or configuration is not available, defaulting to provided default value.");
+            }
+        } catch (Exception e) {
+            LOG.log(Level.SEVERE, "Error accessing Spark configuration with key: " + sparkConfKey + ", defaulting to provided default value.", e);
+            throw e;
+        }
+        return new Dimension().withName(dimensionName).withValue(propertyValue);
+    }

// This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137
// Since we don't have access to Spark Configuration here: we are relying on the presence of executorId as part of the metricName.
public static boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) {
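
A small sketch of how the new constructDimensionFromSparkConf resolves a dimension value. "UNKNOWN" is simply the caller-supplied default here (mirroring the UNKNOWN constant the jobType builder passes above); with no running Spark environment the method logs a warning and falls back to it:

    import com.amazonaws.services.cloudwatch.model.Dimension;
    import org.opensearch.flint.core.metrics.reporter.DimensionUtils;

    public class JobTypeDimensionSketch {
        public static void main(String[] args) {
            // Reads "spark.flint.job.type" from the live SparkConf when one exists;
            // in a plain JVM (no SparkEnv) this returns the default value instead.
            Dimension jobType = DimensionUtils.constructDimensionFromSparkConf(
                    "jobType", "spark.flint.job.type", "UNKNOWN");
            System.out.println(jobType.getName() + " = " + jobType.getValue());
        }
    }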
@@ -1,6 +1,7 @@
package org.opensearch.flint.core.metrics;

import com.codahale.metrics.Counter;
+import com.codahale.metrics.Gauge;
import com.codahale.metrics.Timer;
import org.apache.spark.SparkEnv;
import org.apache.spark.metrics.source.FlintMetricSource;
@@ -10,6 +11,7 @@
import org.mockito.Mockito;

import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
@@ -42,7 +44,7 @@ public void testIncrementDecrementCounter() {

// Verify interactions
verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
-            verify(flintMetricSource, times(4)).metricRegistry();
+            verify(flintMetricSource, times(3)).metricRegistry();
Counter counter = flintMetricSource.metricRegistry().getCounters().get(testMetric);
Assertions.assertNotNull(counter);
Assertions.assertEquals(counter.getCount(), 1);
@@ -69,7 +71,7 @@ public void testStartStopTimer() {

// Verify interactions
verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
-            verify(flintMetricSource, times(2)).metricRegistry();
+            verify(flintMetricSource, times(1)).metricRegistry();
Timer timer = flintMetricSource.metricRegistry().getTimers().get(testMetric);
Assertions.assertNotNull(timer);
Assertions.assertEquals(timer.getCount(), 1L);
@@ -78,4 +80,35 @@ public void testStartStopTimer() {
throw new RuntimeException(e);
}
}

+    @Test
+    public void testRegisterGaugeWhenMetricRegistryIsAvailable() {
+        try (MockedStatic<SparkEnv> sparkEnvMock = mockStatic(SparkEnv.class)) {
+            // Mock SparkEnv
+            SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS);
+            sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv);
+
+            // Mock FlintMetricSource
+            FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource());
+            when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head())
+                .thenReturn(flintMetricSource);
+
+            // Setup gauge
+            AtomicInteger testValue = new AtomicInteger(1);
+            String gaugeName = "test.gauge";
+            MetricsUtil.registerGauge(gaugeName, testValue);
+
+            verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
+            verify(flintMetricSource, times(1)).metricRegistry();
+
+            Gauge gauge = flintMetricSource.metricRegistry().getGauges().get(gaugeName);
+            Assertions.assertNotNull(gauge);
+            Assertions.assertEquals(gauge.getValue(), 1);
+
+            testValue.incrementAndGet();
+            testValue.incrementAndGet();
+            testValue.decrementAndGet();
+            Assertions.assertEquals(gauge.getValue(), 2);
+        }
+    }
}
@@ -5,12 +5,21 @@

package org.opensearch.flint.core.metrics.reporter;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.metrics.source.FlintMetricSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;

import com.amazonaws.services.cloudwatch.model.Dimension;
import org.junit.jupiter.api.function.Executable;
import org.mockito.MockedStatic;
import org.mockito.Mockito;

import java.lang.reflect.Field;
import java.util.Map;
@@ -71,6 +80,23 @@ public void testGetDimensionsFromSystemEnv() throws NoSuchFieldException, Illega
writeableEnvironmentVariables.remove("SERVERLESS_EMR_JOB_ID");
writeableEnvironmentVariables.remove("TEST_VAR");
}
}

+    @Test
+    public void testConstructDimensionFromSparkConfWithAvailableConfig() {
+        try (MockedStatic<SparkEnv> sparkEnvMock = mockStatic(SparkEnv.class)) {
+            SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS);
+            sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv);
+            SparkConf sparkConf = new SparkConf().set("spark.test.key", "testValue");
+            when(sparkEnv.get().conf()).thenReturn(sparkConf);
+
+            Dimension result = DimensionUtils.constructDimensionFromSparkConf("testDimension", "spark.test.key", "defaultValue");
+            // Assertions
+            assertEquals("testDimension", result.getName());
+            assertEquals("testValue", result.getValue());
+
+            // Reset SparkEnv mock to not affect other tests
+            Mockito.reset(SparkEnv.get());
+        }
+    }
}
@@ -12,6 +12,7 @@ import scala.sys.addShutdownHook

import org.opensearch.flint.core.FlintClient
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING}
+import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
@@ -65,6 +66,7 @@ class FlintSparkIndexMonitor(
} catch {
case e: Throwable =>
logError("Failed to update index log entry", e)
+            MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC)
}
},
15, // Delay to ensure final logging is complete first, otherwise version conflicts
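
The shape of the monitor's change, a scheduled heartbeat that counts its own failures, can be sketched standalone as below; updateIndexLogEntry is a hypothetical stand-in for the monitor's real metadata-log update, and the 60-second period is an assumed value (only the 15-second initial delay comes from the snippet above):

    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    import java.util.concurrent.TimeUnit;
    import org.opensearch.flint.core.metrics.MetricConstants;
    import org.opensearch.flint.core.metrics.MetricsUtil;

    public class HeartbeatSketch {
        public static void main(String[] args) {
            ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
            scheduler.scheduleWithFixedDelay(() -> {
                try {
                    updateIndexLogEntry(); // hypothetical stand-in for the heartbeat work
                } catch (Throwable t) {
                    // Mirror the monitor: every failed heartbeat bumps the new counter.
                    MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC);
                }
            }, 15, 60, TimeUnit.SECONDS);
        }

        private static void updateIndexLogEntry() {
            // Placeholder: the real monitor refreshes the Flint metadata log entry here.
        }
    }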
@@ -19,6 +19,7 @@ import play.api.libs.json._
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{StructField, _}

/**
@@ -41,7 +42,10 @@ object FlintJob extends Logging with FlintJobExecutor {
val Array(query, resultIndex) = args

val conf = createSparkConf()
-    val wait = conf.get("spark.flint.job.type", "continue")
+    val jobType = conf.get("spark.flint.job.type", "batch")
+    logInfo(s"""Job type is: ${jobType}""")
+    conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
@@ -58,7 +62,7 @@ object FlintJob extends Logging with FlintJobExecutor {
query,
dataSource,
resultIndex,
-        wait.equalsIgnoreCase("streaming"))
+        jobType.equalsIgnoreCase("streaming"))
jobOperator.start()
}
}
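
A minimal sketch of the new default resolution: the old default "continue" never matched "streaming", so unset configs already ran as batch; naming the value jobType with a "batch" default makes that intent explicit, and FlintJob now also records it in FlintSparkConf.JOB_TYPE:

    import org.apache.spark.SparkConf;

    public class JobTypeResolutionSketch {
        public static void main(String[] args) {
            SparkConf conf = new SparkConf(); // no spark.flint.job.type set
            String jobType = conf.get("spark.flint.job.type", "batch");
            boolean streaming = jobType.equalsIgnoreCase("streaming");
            System.out.println(jobType + " -> streaming=" + streaming); // prints: batch -> streaming=false
        }
    }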