Refactor and add metrics for streaming job
Signed-off-by: Louis Chu <[email protected]>
noCharger committed Feb 18, 2024
1 parent 9c34a1d commit 01ecd3a
Showing 10 changed files with 246 additions and 60 deletions.
File: MetricConstants.java
@@ -8,7 +8,7 @@
/**
* This class defines custom metric constants used for monitoring flint operations.
*/
-public class MetricConstants {
+public final class MetricConstants {

/**
* The prefix for all read-related metrics in OpenSearch.
@@ -65,5 +65,29 @@ public class MetricConstants {
/**
* Metric name for tracking the processing time of statements.
*/
public static final String STATEMENT_PROCESSING_TIME_METRIC = "STATEMENT.processingTime";
public static final String STATEMENT_PROCESSING_TIME_METRIC = "statement.processingTime";

+/**
+ * Metric for tracking the count of currently running streaming jobs.
+ */
+public static final String STREAMING_RUNNING_METRIC = "streaming.running.count";
+
+/**
+ * Metric for tracking the count of streaming jobs that have failed.
+ */
+public static final String STREAMING_FAILED_METRIC = "streaming.failed.count";
+
+/**
+ * Metric for tracking the count of streaming jobs that have completed successfully.
+ */
+public static final String STREAMING_SUCCESS_METRIC = "streaming.success.count";
+
+/**
+ * Metric for tracking the count of failed heartbeat signals in streaming jobs.
+ */
+public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count";
+
+private MetricConstants() {
+    // Private constructor to prevent instantiation
+}
}
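Note: the four streaming constants added above are consumed through the MetricsUtil helper refactored in the next file. As orientation, here is a minimal sketch of how a job driver could emit them; the lifecycle wrapper itself is hypothetical, and only the MetricsUtil calls and constant names come from this commit.

import java.util.concurrent.atomic.AtomicInteger;

import org.opensearch.flint.core.metrics.MetricConstants;
import org.opensearch.flint.core.metrics.MetricsUtil;

public class StreamingJobLifecycleSketch {

    // Backing value for the streaming.running.count gauge (hypothetical driver state)
    private static final AtomicInteger streamingRunningCount = new AtomicInteger(0);

    public static void run(Runnable streamingJob) {
        // The gauge reports whatever value the AtomicInteger holds when polled
        MetricsUtil.registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount);

        streamingRunningCount.incrementAndGet();
        try {
            streamingJob.run();
            MetricsUtil.incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC);
        } catch (RuntimeException e) {
            MetricsUtil.incrementCounter(MetricConstants.STREAMING_FAILED_METRIC);
            throw e;
        } finally {
            streamingRunningCount.decrementAndGet();
        }
    }
}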
File: MetricsUtil.java
@@ -6,12 +6,15 @@
package org.opensearch.flint.core.metrics;

import com.codahale.metrics.Counter;
+import com.codahale.metrics.Gauge;
+import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import org.apache.spark.SparkEnv;
import org.apache.spark.metrics.source.FlintMetricSource;
import org.apache.spark.metrics.source.Source;
import scala.collection.Seq;

+import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;

/**
@@ -21,8 +24,8 @@ public final class MetricsUtil {

private static final Logger LOG = Logger.getLogger(MetricsUtil.class.getName());

-// Private constructor to prevent instantiation
private MetricsUtil() {
+    // Private constructor to prevent instantiation
}

/**
@@ -60,10 +63,7 @@ public static void decrementCounter(String metricName) {
*/
public static Timer.Context getTimerContext(String metricName) {
    Timer timer = getOrCreateTimer(metricName);
-    if (timer != null) {
-        return timer.time();
-    }
-    return null;
+    return timer != null ? timer.time() : null;
}

/**
@@ -74,42 +74,47 @@ public static Timer.Context getTimerContext(String metricName) {
* @return The elapsed time in nanoseconds since the timer was started, or {@code null} if the context was {@code null}.
*/
public static Long stopTimer(Timer.Context context) {
-    if (context != null) {
-        return context.stop();
-    }
-    return null;
+    return context != null ? context.stop() : null;
}
+
+/**
+ * Registers a gauge metric with the provided name and value.
+ * The gauge will reflect the current value of the AtomicInteger provided.
+ *
+ * @param metricName The name of the gauge metric to register.
+ * @param value The AtomicInteger whose current value should be reflected by the gauge.
+ */
+public static void registerGauge(String metricName, final AtomicInteger value) {
+    MetricRegistry metricRegistry = getMetricRegistry();
+    if (metricRegistry == null) {
+        LOG.warning("MetricRegistry not available, cannot register gauge: " + metricName);
+        return;
+    }
+    metricRegistry.register(metricName, (Gauge<Integer>) value::get);
+}

// Retrieves or creates a new counter for the given metric name
private static Counter getOrCreateCounter(String metricName) {
-    SparkEnv sparkEnv = SparkEnv.get();
-    if (sparkEnv == null) {
-        LOG.warning("Spark environment not available, cannot instrument metric: " + metricName);
-        return null;
-    }
-
-    FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv);
-    Counter counter = flintMetricSource.metricRegistry().getCounters().get(metricName);
-    if (counter == null) {
-        counter = flintMetricSource.metricRegistry().counter(metricName);
-    }
-    return counter;
+    MetricRegistry metricRegistry = getMetricRegistry();
+    return metricRegistry != null ? metricRegistry.counter(metricName) : null;
}

// Retrieves or creates a new Timer for the given metric name
private static Timer getOrCreateTimer(String metricName) {
+    MetricRegistry metricRegistry = getMetricRegistry();
+    return metricRegistry != null ? metricRegistry.timer(metricName) : null;
+}
+
+// Retrieves the MetricRegistry from the current Spark environment.
+private static MetricRegistry getMetricRegistry() {
    SparkEnv sparkEnv = SparkEnv.get();
    if (sparkEnv == null) {
-        LOG.warning("Spark environment not available, cannot instrument metric: " + metricName);
+        LOG.warning("Spark environment not available, cannot access MetricRegistry.");
        return null;
    }

    FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv);
-    Timer timer = flintMetricSource.metricRegistry().getTimers().get(metricName);
-    if (timer == null) {
-        timer = flintMetricSource.metricRegistry().timer(metricName);
-    }
-    return timer;
+    return flintMetricSource.metricRegistry();
}

// Gets or initializes the FlintMetricSource
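With the registry lookup now centralized in getMetricRegistry, each public helper degrades gracefully (returning null or doing nothing) when SparkEnv is absent, so callers can time code unconditionally. A minimal usage sketch of the null-safe timer pair; the wrapper method itself is illustrative, not part of the commit.

import com.codahale.metrics.Timer;

import org.opensearch.flint.core.metrics.MetricConstants;
import org.opensearch.flint.core.metrics.MetricsUtil;

public class StatementTimingSketch {

    // Times one statement; safe to call even when no Spark environment is running
    public static void executeTimed(Runnable statement) {
        // Null when the MetricRegistry is unavailable
        Timer.Context context = MetricsUtil.getTimerContext(MetricConstants.STATEMENT_PROCESSING_TIME_METRIC);
        try {
            statement.run();
        } finally {
            // stopTimer passes a null context through, so no guard is needed here
            MetricsUtil.stopTimer(context);
        }
    }
}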
File: DimensionUtils.java
@@ -7,9 +7,13 @@

import java.util.Map;
import java.util.function.Function;
+import java.util.logging.Level;
+import java.util.logging.Logger;

import org.apache.commons.lang.StringUtils;

import com.amazonaws.services.cloudwatch.model.Dimension;
+import org.apache.spark.SparkEnv;

/**
* Utility class for creating and managing CloudWatch dimensions for metrics reporting in Flint.
@@ -18,7 +22,9 @@
* application ID, and more.
*/
public class DimensionUtils {
+private static final Logger LOG = Logger.getLogger(DimensionUtils.class.getName());
private static final String DIMENSION_JOB_ID = "jobId";
+private static final String DIMENSION_JOB_TYPE = "jobType";
private static final String DIMENSION_APPLICATION_ID = "applicationId";
private static final String DIMENSION_APPLICATION_NAME = "applicationName";
private static final String DIMENSION_DOMAIN_ID = "domainId";
@@ -29,6 +35,8 @@ public class DimensionUtils {
private static final Map<String, Function<String[], Dimension>> dimensionBuilders = Map.of(
DIMENSION_INSTANCE_ROLE, DimensionUtils::getInstanceRoleDimension,
DIMENSION_JOB_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_JOB_ID", DIMENSION_JOB_ID),
+    // TODO: Move FlintSparkConf into the core to prevent circular dependencies
+    DIMENSION_JOB_TYPE, ignored -> constructDimensionFromSparkConf(DIMENSION_JOB_TYPE, "spark.flint.job.type", UNKNOWN),
DIMENSION_APPLICATION_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", DIMENSION_APPLICATION_ID),
DIMENSION_APPLICATION_NAME, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_APPLICATION_NAME", DIMENSION_APPLICATION_NAME),
DIMENSION_DOMAIN_ID, ignored -> getEnvironmentVariableDimension("FLINT_CLUSTER_NAME", DIMENSION_DOMAIN_ID)
@@ -39,7 +47,7 @@
* builder exists for the dimension name, it is used; otherwise, a default dimension is constructed.
*
* @param dimensionName The name of the dimension to construct.
- * @param parts Additional information that might be required by specific dimension builders.
+ * @param metricNameParts Additional information that might be required by specific dimension builders.
* @return A CloudWatch Dimension object.
*/
public static Dimension constructDimension(String dimensionName, String[] metricNameParts) {
@@ -50,6 +58,30 @@ public static Dimension constructDimension(String dimensionName, String[] metricNameParts) {
.apply(metricNameParts);
}

+/**
+ * Constructs a CloudWatch Dimension object using a specified Spark configuration key.
+ *
+ * @param dimensionName The name of the dimension to construct.
+ * @param sparkConfKey the Spark configuration key used to look up the value for the dimension.
+ * @param defaultValue the default value to use for the dimension if the Spark configuration key is not found or if the Spark environment is not available.
+ * @return A CloudWatch Dimension object.
+ * @throws Exception if an error occurs while accessing the Spark configuration. The exception is logged and then rethrown.
+ */
+public static Dimension constructDimensionFromSparkConf(String dimensionName, String sparkConfKey, String defaultValue) {
+    String propertyValue = defaultValue;
+    try {
+        if (SparkEnv.get() != null && SparkEnv.get().conf() != null) {
+            propertyValue = SparkEnv.get().conf().get(sparkConfKey, defaultValue);
+        } else {
+            LOG.warning("Spark environment or configuration is not available, defaulting to provided default value.");
+        }
+    } catch (Exception e) {
+        LOG.log(Level.SEVERE, "Error accessing Spark configuration with key: " + sparkConfKey + ", defaulting to provided default value.", e);
+        throw e;
+    }
+    return new Dimension().withName(dimensionName).withValue(propertyValue);
+}

// This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137
// Since we don't have access to Spark Configuration here: we are relying on the presence of executorId as part of the metricName.
public static boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) {
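For orientation, a sketch of the new builder at a call site. The key and dimension name match the jobType entry above; the printed values, and the assumption that the UNKNOWN constant resolves to the string "UNKNOWN", are illustrative.

import com.amazonaws.services.cloudwatch.model.Dimension;

import org.opensearch.flint.core.metrics.reporter.DimensionUtils;

public class JobTypeDimensionSketch {

    public static void main(String[] args) {
        // Reads spark.flint.job.type from the active SparkConf and falls back to
        // the default when SparkEnv or the key is absent
        Dimension jobType = DimensionUtils.constructDimensionFromSparkConf(
                "jobType", "spark.flint.job.type", "UNKNOWN");

        // e.g. "jobType=streaming" on a streaming driver, "jobType=UNKNOWN" otherwise
        System.out.println(jobType.getName() + "=" + jobType.getValue());
    }
}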
File: MetricsUtilTest.java
@@ -1,6 +1,7 @@
package org.opensearch.flint.core.metrics;

import com.codahale.metrics.Counter;
+import com.codahale.metrics.Gauge;
import com.codahale.metrics.Timer;
import org.apache.spark.SparkEnv;
import org.apache.spark.metrics.source.FlintMetricSource;
@@ -10,6 +11,7 @@
import org.mockito.Mockito;

import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
@@ -42,7 +44,7 @@ public void testIncrementDecrementCounter() {

// Verify interactions
verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
-verify(flintMetricSource, times(4)).metricRegistry();
+verify(flintMetricSource, times(3)).metricRegistry();
Counter counter = flintMetricSource.metricRegistry().getCounters().get(testMetric);
Assertions.assertNotNull(counter);
Assertions.assertEquals(counter.getCount(), 1);
@@ -69,7 +71,7 @@ public void testStartStopTimer() {

// Verify interactions
verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
-verify(flintMetricSource, times(2)).metricRegistry();
+verify(flintMetricSource, times(1)).metricRegistry();
Timer timer = flintMetricSource.metricRegistry().getTimers().get(testMetric);
Assertions.assertNotNull(timer);
Assertions.assertEquals(timer.getCount(), 1L);
@@ -78,4 +80,35 @@
throw new RuntimeException(e);
}
}

+@Test
+public void testRegisterGaugeWhenMetricRegistryIsAvailable() {
+    try (MockedStatic<SparkEnv> sparkEnvMock = mockStatic(SparkEnv.class)) {
+        // Mock SparkEnv
+        SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS);
+        sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv);
+
+        // Mock FlintMetricSource
+        FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource());
+        when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head())
+                .thenReturn(flintMetricSource);
+
+        // Setup gauge
+        AtomicInteger testValue = new AtomicInteger(1);
+        String gaugeName = "test.gauge";
+        MetricsUtil.registerGauge(gaugeName, testValue);
+
+        verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
+        verify(flintMetricSource, times(1)).metricRegistry();
+
+        Gauge gauge = flintMetricSource.metricRegistry().getGauges().get(gaugeName);
+        Assertions.assertNotNull(gauge);
+        Assertions.assertEquals(gauge.getValue(), 1);
+
+        testValue.incrementAndGet();
+        testValue.incrementAndGet();
+        testValue.decrementAndGet();
+        Assertions.assertEquals(gauge.getValue(), 2);
+    }
+}
}
File: DimensionUtilsTest.java
@@ -5,12 +5,21 @@

package org.opensearch.flint.core.metrics.reporter;

+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.metrics.source.FlintMetricSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.when;

import com.amazonaws.services.cloudwatch.model.Dimension;
import org.junit.jupiter.api.function.Executable;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;

import java.lang.reflect.Field;
import java.util.Map;
@@ -71,6 +80,23 @@ public void testGetDimensionsFromSystemEnv() throws NoSuchFieldException, IllegalAccessException {
writeableEnvironmentVariables.remove("SERVERLESS_EMR_JOB_ID");
writeableEnvironmentVariables.remove("TEST_VAR");
}
}

+@Test
+public void testConstructDimensionFromSparkConfWithAvailableConfig() {
+    try (MockedStatic<SparkEnv> sparkEnvMock = mockStatic(SparkEnv.class)) {
+        SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS);
+        sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv);
+        SparkConf sparkConf = new SparkConf().set("spark.test.key", "testValue");
+        when(sparkEnv.get().conf()).thenReturn(sparkConf);
+
+        Dimension result = DimensionUtils.constructDimensionFromSparkConf("testDimension", "spark.test.key", "defaultValue");
+        // Assertions
+        assertEquals("testDimension", result.getName());
+        assertEquals("testValue", result.getValue());
+
+        // Reset SparkEnv mock to not affect other tests
+        Mockito.reset(SparkEnv.get());
+    }
+}
}
File: FlintSparkIndexMonitor.scala
@@ -12,6 +12,7 @@ import scala.sys.addShutdownHook

import org.opensearch.flint.core.FlintClient
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING}
+import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
@@ -65,6 +66,7 @@ class FlintSparkIndexMonitor(
} catch {
case e: Throwable =>
logError("Failed to update index log entry", e)
+MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC)
}
},
15, // Delay to ensure final logging is complete first, otherwise version conflicts
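The monitor's scheduled heartbeat now records every failed log update on the new counter. The same pattern in a standalone Java sketch; the executor setup and delays are illustrative, and only the counter call reflects this commit.

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import org.opensearch.flint.core.metrics.MetricConstants;
import org.opensearch.flint.core.metrics.MetricsUtil;

public class HeartbeatSketch {

    public static void start(Runnable updateIndexLogEntry) {
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        executor.scheduleWithFixedDelay(() -> {
            try {
                updateIndexLogEntry.run();
            } catch (Throwable t) {
                // Each failed heartbeat becomes a data point on streaming.heartbeat.failed.count
                MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC);
            }
        }, 15, 60, TimeUnit.SECONDS);
    }
}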
File: FlintJobITSuite.scala
@@ -5,6 +5,8 @@

package org.apache.spark.sql

+import java.util.concurrent.atomic.AtomicInteger
+
import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.{Duration, MINUTES}
@@ -67,10 +69,11 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
val prefix = "flint-job-test"
val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+val streamingRunningCount = new AtomicInteger(0)

val futureResult = Future {
val job =
-    JobOperator(spark, query, dataSourceName, resultIndex, true)
+    JobOperator(spark, query, dataSourceName, resultIndex, true, streamingRunningCount)
job.envinromentProvider = new MockEnvironment(
Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId))

(Diffs for the remaining 3 changed files did not load in this view.)