Add read/write bytes metrics (opensearch-project#803)
* Add read/write bytes metrics

Signed-off-by: Tomoyuki Morita <[email protected]>

* Add unit test

Signed-off-by: Tomoyuki Morita <[email protected]>

* Address comments

Signed-off-by: Tomoyuki Morita <[email protected]>

---------

Signed-off-by: Tomoyuki Morita <[email protected]>
ykmr1224 authored and 14yapkc1 committed Dec 11, 2024
1 parent 6c9aa09 commit d25c8a8
Showing 11 changed files with 337 additions and 23 deletions.
1 change: 1 addition & 0 deletions build.sbt
@@ -89,6 +89,7 @@ lazy val flintCore = (project in file("flint-core"))
"com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593"
exclude("com.fasterxml.jackson.core", "jackson-databind"),
"software.amazon.awssdk" % "auth-crt" % "2.28.10",
"org.projectlombok" % "lombok" % "1.18.30" % "provided",
"org.scalactic" %% "scalactic" % "3.2.15" % "test",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
org/opensearch/flint/core/metrics/HistoricGauge.java (new file)
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.metrics;

import com.codahale.metrics.Gauge;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Value;

/**
 * Gauge which stores historic data points with timestamps.
 * This is used to emit a separate data point per request, rather than a single aggregated metric.
 */
public class HistoricGauge implements Gauge<Long> {
@AllArgsConstructor
@Value
public static class DataPoint {
Long value;
long timestamp;
}

private final List<DataPoint> dataPoints = Collections.synchronizedList(new LinkedList<>());

/**
 * Returns the value of the first data point without removing it.
 * @return the first data point's value, or null if there are no data points
 */
@Override
public Long getValue() {
if (!dataPoints.isEmpty()) {
return dataPoints.get(0).value;
} else {
return null;
}
}

/**
 * Adds a new data point with the current timestamp attached.
 * @param value metric value
 */
public void addDataPoint(Long value) {
dataPoints.add(new DataPoint(value, System.currentTimeMillis()));
}

/**
 * Returns a copy of the accumulated data points and removes them from the internal list.
 * @return copy of the data points
 */
public List<DataPoint> pollDataPoints() {
int size = dataPoints.size();
List<DataPoint> result = new ArrayList<>(dataPoints.subList(0, size));
if (size > 0) {
dataPoints.subList(0, size).clear();
}
return result;
}
}
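For orientation, here is a minimal sketch of how a HistoricGauge can be registered and drained through a Dropwizard MetricRegistry; the metric name and values are illustrative, not part of this commit:

```scala
import com.codahale.metrics.MetricRegistry
import org.opensearch.flint.core.metrics.HistoricGauge

object HistoricGaugeExample extends App {
  val registry = new MetricRegistry()
  // Hypothetical metric name; register() returns the gauge it was given.
  val gauge = registry.register("example.bytesRead.count", new HistoricGauge())

  // Each call records a value tagged with the current timestamp.
  gauge.addDataPoint(1024L)
  gauge.addDataPoint(2048L)

  // A reporter drains all accumulated points; a second poll returns an empty list.
  gauge.pollDataPoints().forEach { p =>
    println(s"value=${p.getValue} timestamp=${p.getTimestamp}")
  }
}
```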
org/opensearch/flint/core/metrics/MetricConstants.java
@@ -117,6 +117,26 @@ public final class MetricConstants
*/
public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime";

/**
* Metric for tracking the total bytes read from input
*/
public static final String INPUT_TOTAL_BYTES_READ = "input.totalBytesRead.count";

/**
* Metric for tracking the total records read from input
*/
public static final String INPUT_TOTAL_RECORDS_READ = "input.totalRecordsRead.count";

/**
* Metric for tracking the total bytes written to output
*/
public static final String OUTPUT_TOTAL_BYTES_WRITTEN = "output.totalBytesWritten.count";

/**
* Metric for tracking the total records written to output
*/
public static final String OUTPUT_TOTAL_RECORDS_WRITTEN = "output.totalRecordsWritten.count";

private MetricConstants() {
// Private constructor to prevent instantiation
}
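A hedged sketch of how these new constants are consumed; it mirrors the emitMetrics calls in the Scala listener added later in this commit, with made-up byte counts:

```scala
import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}

object EmitTotalsExample {
  def emit(): Unit = {
    // One data point per job run; the CloudWatch reporter later drains each
    // point and publishes it with the timestamp at which it was recorded.
    MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_BYTES_READ, 4096L)
    MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_BYTES_WRITTEN, 1024L)
  }
}
```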
org/opensearch/flint/core/metrics/MetricsUtil.java
@@ -75,6 +75,15 @@ public static void decrementCounter(String metricName, boolean isIndexMetric) {
}
}

/**
 * Sets a counter to an absolute value by clearing its current count and
 * incrementing it by n.
 *
 * @param metricName The name of the metric counter.
 * @param isIndexMetric Whether this is an index-level metric.
 * @param n The value to set the counter to.
 */
public static void setCounter(String metricName, boolean isIndexMetric, long n) {
Counter counter = getOrCreateCounter(metricName, isIndexMetric);
if (counter != null) {
counter.dec(counter.getCount());
counter.inc(n);
LOG.info("counter: " + counter.getCount());
}
}
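A hypothetical call, for illustration; note the dec/inc pair is not atomic, so a concurrent reader may briefly observe an intermediate count:

```scala
import org.opensearch.flint.core.metrics.MetricsUtil

object SetCounterExample {
  // "example.requests.count" is a made-up metric name; the call overwrites
  // the counter's value rather than applying a delta.
  def run(): Unit = MetricsUtil.setCounter("example.requests.count", false, 42L)
}
```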

/**
* Retrieves a {@link Timer.Context} for the specified metric name, creating a new timer if one does not already exist.
*
@@ -111,6 +120,24 @@ public static Timer getTimer(String metricName, boolean isIndexMetric) {
return getOrCreateTimer(metricName, isIndexMetric);
}

/**
 * Adds a data point to the HistoricGauge metric with the provided name,
 * creating the gauge if it does not already exist.
 *
 * @param metricName The name of the HistoricGauge metric.
 * @param value The value to record.
 */
public static void addHistoricGauge(String metricName, final long value) {
HistoricGauge historicGauge = getOrCreateHistoricGauge(metricName);
if (historicGauge != null) {
historicGauge.addDataPoint(value);
}
}

private static HistoricGauge getOrCreateHistoricGauge(String metricName) {
MetricRegistry metricRegistry = getMetricRegistry(false);
return metricRegistry != null ? metricRegistry.gauge(metricName, HistoricGauge::new) : null;
}

/**
* Registers a gauge metric with the provided name and value.
*
DimensionedCloudWatchReporter.java
@@ -47,6 +47,7 @@
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups;
import org.opensearch.flint.core.metrics.HistoricGauge;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -145,7 +146,11 @@ public void report(final SortedMap<String, Gauge> gauges,
gauges.size() + counters.size() + 10 * histograms.size() + 10 * timers.size());

for (final Map.Entry<String, Gauge> gaugeEntry : gauges.entrySet()) {
processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData);
if (gaugeEntry.getValue() instanceof HistoricGauge) {
processHistoricGauge(gaugeEntry.getKey(), (HistoricGauge) gaugeEntry.getValue(), metricData);
} else {
processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData);
}
}

for (final Map.Entry<String, Counter> counterEntry : counters.entrySet()) {
@@ -227,6 +232,13 @@ private void processGauge(final String metricName, final Gauge gauge, final List
}
}

private void processHistoricGauge(final String metricName, final HistoricGauge gauge, final List<MetricDatum> metricData) {
for (HistoricGauge.DataPoint dataPoint: gauge.pollDataPoints()) {
stageMetricDatum(true, metricName, dataPoint.getValue().doubleValue(), StandardUnit.None, DIMENSION_GAUGE, metricData,
dataPoint.getTimestamp());
}
}
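Unlike processGauge, which stamps each datum with the reporter clock's current time (via the shorter stageMetricDatum overload below), each drained DataPoint here is staged with the timestamp recorded when the point was added, so CloudWatch receives one datum per request at its original observation time.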

private void processCounter(final String metricName, final Counting counter, final List<MetricDatum> metricData) {
long currentCount = counter.getCount();
Long lastCount = lastPolledCounts.get(counter);
@@ -333,12 +345,25 @@ private void processHistogram(final String metricName, final Histogram histogram
* <p>
* If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted
*/
private void stageMetricDatum(final boolean metricConfigured,
final String metricName,
final double metricValue,
final StandardUnit standardUnit,
final String dimensionValue,
final List<MetricDatum> metricData
) {
stageMetricDatum(metricConfigured, metricName, metricValue, standardUnit,
dimensionValue, metricData, builder.clock.getTime());
}

private void stageMetricDatum(final boolean metricConfigured,
final String metricName,
final double metricValue,
final StandardUnit standardUnit,
final String dimensionValue,
final List<MetricDatum> metricData) {
final List<MetricDatum> metricData,
final Long timestamp
) {
// Only submit metrics that show some data, so let's save some money
if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) {
final DimensionedName dimensionedName = DimensionedName.decode(metricName);
@@ -351,7 +376,7 @@ private void stageMetricDatum(final boolean metricConfigured,
MetricInfo metricInfo = getMetricInfo(dimensionedName, dimensions);
for (Set<Dimension> dimensionSet : metricInfo.getDimensionSets()) {
MetricDatum datum = new MetricDatum()
.withTimestamp(new Date(builder.clock.getTime()))
.withTimestamp(new Date(timestamp))
.withValue(cleanMetricValue(metricValue))
.withMetricName(metricInfo.getMetricName())
.withDimensions(dimensionSet)
org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala (new file)
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.metrics

import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

/**
 * Spark listener that accumulates bytesRead/bytesWritten and recordsRead/recordsWritten
 * task metrics and emits the totals as historic gauges.
 */
class ReadWriteBytesSparkListener extends SparkListener with Logging {
var bytesRead: Long = 0
var recordsRead: Long = 0
var bytesWritten: Long = 0
var recordsWritten: Long = 0

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val inputMetrics = taskEnd.taskMetrics.inputMetrics
val outputMetrics = taskEnd.taskMetrics.outputMetrics
val ids = s"(${taskEnd.taskInfo.taskId}, ${taskEnd.taskInfo.partitionId})"
logInfo(
s"${ids} Input: bytesRead=${inputMetrics.bytesRead}, recordsRead=${inputMetrics.recordsRead}")
logInfo(
s"${ids} Output: bytesWritten=${outputMetrics.bytesWritten}, recordsWritten=${outputMetrics.recordsWritten}")

bytesRead += inputMetrics.bytesRead
recordsRead += inputMetrics.recordsRead
bytesWritten += outputMetrics.bytesWritten
recordsWritten += outputMetrics.recordsWritten
}

def emitMetrics(): Unit = {
logInfo(s"Input: totalBytesRead=${bytesRead}, totalRecordsRead=${recordsRead}")
logInfo(s"Output: totalBytesWritten=${bytesWritten}, totalRecordsWritten=${recordsWritten}")
MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_BYTES_READ, bytesRead)
MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_RECORDS_READ, recordsRead)
MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_BYTES_WRITTEN, bytesWritten)
MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_RECORDS_WRITTEN, recordsWritten)
}
}

object ReadWriteBytesSparkListener {
def withMetrics[T](spark: SparkSession, lambda: () => T): T = {
val listener = new ReadWriteBytesSparkListener()
spark.sparkContext.addSparkListener(listener)

val result = lambda()

spark.sparkContext.removeSparkListener(listener)
listener.emitMetrics()

result
}
}
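A sketch of how withMetrics might wrap a job, assuming a live SparkSession; the input path is hypothetical:

```scala
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.core.metrics.ReadWriteBytesSparkListener

object WithMetricsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("example").getOrCreate()

    // Tasks triggered inside the lambda are observed by the listener;
    // read/write totals are emitted as historic gauges once it returns.
    val rowCount = ReadWriteBytesSparkListener.withMetrics(spark, () =>
      spark.read.parquet("s3://example-bucket/input").count() // hypothetical path
    )
    println(s"rows read: $rowCount")
  }
}
```

One design observation: the listener is removed only on the success path, so a failure inside the lambda would leave it attached; a try/finally around the body would also cover that case.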
org/opensearch/flint/core/metrics/HistoricGaugeTest.java (new file)
@@ -0,0 +1,79 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.metrics;

import org.junit.Test;
import static org.junit.Assert.*;
import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint;

import java.util.List;

public class HistoricGaugeTest {

@Test
public void testGetValue_EmptyGauge_ShouldReturnNull() {
HistoricGauge gauge = new HistoricGauge();
assertNull(gauge.getValue());
}

@Test
public void testGetValue_WithSingleDataPoint_ShouldReturnFirstValue() {
HistoricGauge gauge = new HistoricGauge();
Long value = 100L;
gauge.addDataPoint(value);

assertEquals(value, gauge.getValue());
}

@Test
public void testGetValue_WithMultipleDataPoints_ShouldReturnFirstValue() {
HistoricGauge gauge = new HistoricGauge();
Long firstValue = 100L;
Long secondValue = 200L;
gauge.addDataPoint(firstValue);
gauge.addDataPoint(secondValue);

assertEquals(firstValue, gauge.getValue());
}

@Test
public void testPollDataPoints_WithMultipleDataPoints_ShouldReturnAndClearDataPoints() {
HistoricGauge gauge = new HistoricGauge();
gauge.addDataPoint(100L);
gauge.addDataPoint(200L);
gauge.addDataPoint(300L);

List<DataPoint> dataPoints = gauge.pollDataPoints();

assertEquals(3, dataPoints.size());
assertEquals(Long.valueOf(100L), dataPoints.get(0).getValue());
assertEquals(Long.valueOf(200L), dataPoints.get(1).getValue());
assertEquals(Long.valueOf(300L), dataPoints.get(2).getValue());

assertTrue(gauge.pollDataPoints().isEmpty());
}

@Test
public void testAddDataPoint_ShouldAddDataPointWithCorrectValueAndTimestamp() {
HistoricGauge gauge = new HistoricGauge();
Long value = 100L;
gauge.addDataPoint(value);

List<DataPoint> dataPoints = gauge.pollDataPoints();

assertEquals(1, dataPoints.size());
assertEquals(value, dataPoints.get(0).getValue());
assertTrue(dataPoints.get(0).getTimestamp() > 0);
}

@Test
public void testPollDataPoints_EmptyGauge_ShouldReturnEmptyList() {
HistoricGauge gauge = new HistoricGauge();
List<DataPoint> dataPoints = gauge.pollDataPoints();

assertTrue(dataPoints.isEmpty());
}
}
MetricsUtilTest.java
@@ -4,7 +4,7 @@
import com.codahale.metrics.Gauge;
import com.codahale.metrics.Timer;
import java.time.Duration;
import java.time.temporal.TemporalUnit;
import java.util.List;
import org.apache.spark.SparkEnv;
import org.apache.spark.metrics.source.FlintMetricSource;
import org.apache.spark.metrics.source.FlintIndexMetricSource;
@@ -16,6 +16,7 @@

import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
@@ -199,4 +200,31 @@ public void testDefaultBehavior() {
Assertions.assertNotNull(flintMetricSource.metricRegistry().getGauges().get(testGaugeMetric));
}
}

@Test
public void testAddHistoricGauge() {
try (MockedStatic<SparkEnv> sparkEnvMock = mockStatic(SparkEnv.class)) {
SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS);
sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv);

String sourceName = FlintMetricSource.FLINT_METRIC_SOURCE_NAME();
Source metricSource = Mockito.spy(new FlintMetricSource());
when(sparkEnv.metricsSystem().getSourcesByName(sourceName).head()).thenReturn(metricSource);

long value1 = 100L;
long value2 = 200L;
String gaugeName = "test.gauge";
MetricsUtil.addHistoricGauge(gaugeName, value1);
MetricsUtil.addHistoricGauge(gaugeName, value2);

verify(sparkEnv.metricsSystem(), times(0)).registerSource(any());
verify(metricSource, times(2)).metricRegistry();

HistoricGauge gauge = (HistoricGauge) metricSource.metricRegistry().getGauges().get(gaugeName);
Assertions.assertNotNull(gauge);
List<DataPoint> dataPoints = gauge.pollDataPoints();
Assertions.assertEquals(value1, dataPoints.get(0).getValue());
Assertions.assertEquals(value2, dataPoints.get(1).getValue());
}
}
}
(Diff truncated: the remaining 3 of the 11 changed files are not shown.)