diff --git a/build.sbt b/build.sbt index f7653c50c..30858e8d6 100644 --- a/build.sbt +++ b/build.sbt @@ -89,6 +89,7 @@ lazy val flintCore = (project in file("flint-core")) "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), "software.amazon.awssdk" % "auth-crt" % "2.28.10", + "org.projectlombok" % "lombok" % "1.18.30" % "provided", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java new file mode 100644 index 000000000..181bf8575 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/HistoricGauge.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics; + +import com.codahale.metrics.Gauge; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Value; + +/** + * Gauge which stores historic data points with timestamps. + * This is used for emitting separate data points per request, instead of single aggregated metrics. + */ +public class HistoricGauge implements Gauge { + @AllArgsConstructor + @Value + public static class DataPoint { + Long value; + long timestamp; + } + + private final List dataPoints = Collections.synchronizedList(new LinkedList<>()); + + /** + * This method will just return first value. + * @return first value + */ + @Override + public Long getValue() { + if (!dataPoints.isEmpty()) { + return dataPoints.get(0).value; + } else { + return null; + } + } + + /** + * Add new data point. Current time stamp will be attached to the data point. + * @param value metric value + */ + public void addDataPoint(Long value) { + dataPoints.add(new DataPoint(value, System.currentTimeMillis())); + } + + /** + * Return copy of dataPoints and remove them from internal list + * @return copy of the data points + */ + public List pollDataPoints() { + int size = dataPoints.size(); + List result = new ArrayList<>(dataPoints.subList(0, size)); + if (size > 0) { + dataPoints.subList(0, size).clear(); + } + return result; + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 3a72c1d5a..427fab9fe 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -117,6 +117,26 @@ public final class MetricConstants { */ public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime"; + /** + * Metric for tracking the total bytes read from input + */ + public static final String INPUT_TOTAL_BYTES_READ = "input.totalBytesRead.count"; + + /** + * Metric for tracking the total records read from input + */ + public static final String INPUT_TOTAL_RECORDS_READ = "input.totalRecordsRead.count"; + + /** + * Metric for tracking the total bytes written to output + */ + public static final String OUTPUT_TOTAL_BYTES_WRITTEN = "output.totalBytesWritten.count"; + + /** + * Metric for tracking the total records written to output + */ + public static final String OUTPUT_TOTAL_RECORDS_WRITTEN = "output.totalRecordsWritten.count"; + private MetricConstants() { // Private constructor to prevent instantiation } diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java index ab1207ccc..511c18664 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java @@ -75,6 +75,15 @@ public static void decrementCounter(String metricName, boolean isIndexMetric) { } } + public static void setCounter(String metricName, boolean isIndexMetric, long n) { + Counter counter = getOrCreateCounter(metricName, isIndexMetric); + if (counter != null) { + counter.dec(counter.getCount()); + counter.inc(n); + LOG.info("counter: " + counter.getCount()); + } + } + /** * Retrieves a {@link Timer.Context} for the specified metric name, creating a new timer if one does not already exist. * @@ -111,6 +120,24 @@ public static Timer getTimer(String metricName, boolean isIndexMetric) { return getOrCreateTimer(metricName, isIndexMetric); } + /** + * Registers a HistoricGauge metric with the provided name and value. + * + * @param metricName The name of the HistoricGauge metric to register. + * @param value The value to be stored + */ + public static void addHistoricGauge(String metricName, final long value) { + HistoricGauge historicGauge = getOrCreateHistoricGauge(metricName); + if (historicGauge != null) { + historicGauge.addDataPoint(value); + } + } + + private static HistoricGauge getOrCreateHistoricGauge(String metricName) { + MetricRegistry metricRegistry = getMetricRegistry(false); + return metricRegistry != null ? metricRegistry.gauge(metricName, HistoricGauge::new) : null; + } + /** * Registers a gauge metric with the provided name and value. * diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java index a5ea190c5..9104e1b34 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java @@ -47,6 +47,7 @@ import java.util.stream.LongStream; import java.util.stream.Stream; import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups; +import org.opensearch.flint.core.metrics.HistoricGauge; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -145,7 +146,11 @@ public void report(final SortedMap gauges, gauges.size() + counters.size() + 10 * histograms.size() + 10 * timers.size()); for (final Map.Entry gaugeEntry : gauges.entrySet()) { - processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData); + if (gaugeEntry.getValue() instanceof HistoricGauge) { + processHistoricGauge(gaugeEntry.getKey(), (HistoricGauge) gaugeEntry.getValue(), metricData); + } else { + processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData); + } } for (final Map.Entry counterEntry : counters.entrySet()) { @@ -227,6 +232,13 @@ private void processGauge(final String metricName, final Gauge gauge, final List } } + private void processHistoricGauge(final String metricName, final HistoricGauge gauge, final List metricData) { + for (HistoricGauge.DataPoint dataPoint: gauge.pollDataPoints()) { + stageMetricDatum(true, metricName, dataPoint.getValue().doubleValue(), StandardUnit.None, DIMENSION_GAUGE, metricData, + dataPoint.getTimestamp()); + } + } + private void processCounter(final String metricName, final Counting counter, final List metricData) { long currentCount = counter.getCount(); Long lastCount = lastPolledCounts.get(counter); @@ -333,12 +345,25 @@ private void processHistogram(final String metricName, final Histogram histogram *

* If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted */ + private void stageMetricDatum(final boolean metricConfigured, + final String metricName, + final double metricValue, + final StandardUnit standardUnit, + final String dimensionValue, + final List metricData + ) { + stageMetricDatum(metricConfigured, metricName, metricValue, standardUnit, + dimensionValue, metricData, builder.clock.getTime()); + } + private void stageMetricDatum(final boolean metricConfigured, final String metricName, final double metricValue, final StandardUnit standardUnit, final String dimensionValue, - final List metricData) { + final List metricData, + final Long timestamp + ) { // Only submit metrics that show some data, so let's save some money if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) { final DimensionedName dimensionedName = DimensionedName.decode(metricName); @@ -351,7 +376,7 @@ private void stageMetricDatum(final boolean metricConfigured, MetricInfo metricInfo = getMetricInfo(dimensionedName, dimensions); for (Set dimensionSet : metricInfo.getDimensionSets()) { MetricDatum datum = new MetricDatum() - .withTimestamp(new Date(builder.clock.getTime())) + .withTimestamp(new Date(timestamp)) .withValue(cleanMetricValue(metricValue)) .withMetricName(metricInfo.getMetricName()) .withDimensions(dimensionSet) diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala new file mode 100644 index 000000000..bfafd3eb3 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metrics/ReadWriteBytesSparkListener.scala @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics + +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.sql.SparkSession + +/** + * Collect and emit bytesRead/Written and recordsRead/Written metrics + */ +class ReadWriteBytesSparkListener extends SparkListener with Logging { + var bytesRead: Long = 0 + var recordsRead: Long = 0 + var bytesWritten: Long = 0 + var recordsWritten: Long = 0 + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val inputMetrics = taskEnd.taskMetrics.inputMetrics + val outputMetrics = taskEnd.taskMetrics.outputMetrics + val ids = s"(${taskEnd.taskInfo.taskId}, ${taskEnd.taskInfo.partitionId})" + logInfo( + s"${ids} Input: bytesRead=${inputMetrics.bytesRead}, recordsRead=${inputMetrics.recordsRead}") + logInfo( + s"${ids} Output: bytesWritten=${outputMetrics.bytesWritten}, recordsWritten=${outputMetrics.recordsWritten}") + + bytesRead += inputMetrics.bytesRead + recordsRead += inputMetrics.recordsRead + bytesWritten += outputMetrics.bytesWritten + recordsWritten += outputMetrics.recordsWritten + } + + def emitMetrics(): Unit = { + logInfo(s"Input: totalBytesRead=${bytesRead}, totalRecordsRead=${recordsRead}") + logInfo(s"Output: totalBytesWritten=${bytesWritten}, totalRecordsWritten=${recordsWritten}") + MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_BYTES_READ, bytesRead) + MetricsUtil.addHistoricGauge(MetricConstants.INPUT_TOTAL_RECORDS_READ, recordsRead) + MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_BYTES_WRITTEN, bytesWritten) + MetricsUtil.addHistoricGauge(MetricConstants.OUTPUT_TOTAL_RECORDS_WRITTEN, recordsWritten) + } +} + +object ReadWriteBytesSparkListener { + def withMetrics[T](spark: SparkSession, lambda: () => T): T = { + val listener = new ReadWriteBytesSparkListener() + spark.sparkContext.addSparkListener(listener) + + val result = lambda() + + spark.sparkContext.removeSparkListener(listener) + listener.emitMetrics() + + result + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java new file mode 100644 index 000000000..f3d842af2 --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/HistoricGaugeTest.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics; + +import org.junit.Test; +import static org.junit.Assert.*; +import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint; + +import java.util.List; + +public class HistoricGaugeTest { + + @Test + public void testGetValue_EmptyGauge_ShouldReturnNull() { + HistoricGauge gauge= new HistoricGauge(); + assertNull(gauge.getValue()); + } + + @Test + public void testGetValue_WithSingleDataPoint_ShouldReturnFirstValue() { + HistoricGauge gauge= new HistoricGauge(); + Long value = 100L; + gauge.addDataPoint(value); + + assertEquals(value, gauge.getValue()); + } + + @Test + public void testGetValue_WithMultipleDataPoints_ShouldReturnFirstValue() { + HistoricGauge gauge= new HistoricGauge(); + Long firstValue = 100L; + Long secondValue = 200L; + gauge.addDataPoint(firstValue); + gauge.addDataPoint(secondValue); + + assertEquals(firstValue, gauge.getValue()); + } + + @Test + public void testPollDataPoints_WithMultipleDataPoints_ShouldReturnAndClearDataPoints() { + HistoricGauge gauge= new HistoricGauge(); + gauge.addDataPoint(100L); + gauge.addDataPoint(200L); + gauge.addDataPoint(300L); + + List dataPoints = gauge.pollDataPoints(); + + assertEquals(3, dataPoints.size()); + assertEquals(Long.valueOf(100L), dataPoints.get(0).getValue()); + assertEquals(Long.valueOf(200L), dataPoints.get(1).getValue()); + assertEquals(Long.valueOf(300L), dataPoints.get(2).getValue()); + + assertTrue(gauge.pollDataPoints().isEmpty()); + } + + @Test + public void testAddDataPoint_ShouldAddDataPointWithCorrectValueAndTimestamp() { + HistoricGauge gauge= new HistoricGauge(); + Long value = 100L; + gauge.addDataPoint(value); + + List dataPoints = gauge.pollDataPoints(); + + assertEquals(1, dataPoints.size()); + assertEquals(value, dataPoints.get(0).getValue()); + assertTrue(dataPoints.get(0).getTimestamp() > 0); + } + + @Test + public void testPollDataPoints_EmptyGauge_ShouldReturnEmptyList() { + HistoricGauge gauge= new HistoricGauge(); + List dataPoints = gauge.pollDataPoints(); + + assertTrue(dataPoints.isEmpty()); + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java index b5470b6be..70b51ed63 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java @@ -4,7 +4,7 @@ import com.codahale.metrics.Gauge; import com.codahale.metrics.Timer; import java.time.Duration; -import java.time.temporal.TemporalUnit; +import java.util.List; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; import org.apache.spark.metrics.source.FlintIndexMetricSource; @@ -16,6 +16,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import org.opensearch.flint.core.metrics.HistoricGauge.DataPoint; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; @@ -199,4 +200,31 @@ public void testDefaultBehavior() { Assertions.assertNotNull(flintMetricSource.metricRegistry().getGauges().get(testGaugeMetric)); } } + + @Test + public void testAddHistoricGauge() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + String sourceName = FlintMetricSource.FLINT_METRIC_SOURCE_NAME(); + Source metricSource = Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(sourceName).head()).thenReturn(metricSource); + + long value1 = 100L; + long value2 = 200L; + String gaugeName = "test.gauge"; + MetricsUtil.addHistoricGauge(gaugeName, value1); + MetricsUtil.addHistoricGauge(gaugeName, value2); + + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(metricSource, times(2)).metricRegistry(); + + HistoricGauge gauge = (HistoricGauge)metricSource.metricRegistry().getGauges().get(gaugeName); + Assertions.assertNotNull(gauge); + List dataPoints = gauge.pollDataPoints(); + Assertions.assertEquals(value1, dataPoints.get(0).getValue()); + Assertions.assertEquals(value2, dataPoints.get(1).getValue()); + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala index d343fd999..bedeeba54 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.spark.refresh import java.util.Collections +import org.opensearch.flint.core.metrics.ReadWriteBytesSparkListener import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions, FlintSparkValidationHelper} import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode} @@ -67,15 +68,17 @@ class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) // Flint index has specialized logic and capability for incremental refresh case refresh: StreamingRefresh => logInfo("Start refreshing index in streaming style") - val job = - refresh - .buildStream(spark) - .writeStream - .queryName(indexName) - .format(FLINT_DATASOURCE) - .options(flintSparkConf.properties) - .addSinkOptions(options, flintSparkConf) - .start(indexName) + val job = ReadWriteBytesSparkListener.withMetrics( + spark, + () => + refresh + .buildStream(spark) + .writeStream + .queryName(indexName) + .format(FLINT_DATASOURCE) + .options(flintSparkConf.properties) + .addSinkOptions(options, flintSparkConf) + .start(indexName)) Some(job.id.toString) // Otherwise, fall back to foreachBatch + batch refresh diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index cdeebe663..0978e6898 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -17,7 +17,7 @@ import com.codahale.metrics.Timer import org.opensearch.flint.common.model.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging -import org.opensearch.flint.core.metrics.MetricConstants +import org.opensearch.flint.core.metrics.{MetricConstants, ReadWriteBytesSparkListener} import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.apache.spark.SparkConf @@ -525,12 +525,16 @@ object FlintREPL extends Logging with FlintJobExecutor { val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) val (dataToWrite, returnedVerificationResult) = - processStatementOnVerification( - statementExecutionManager, - queryResultWriter, - flintStatement, - state, - context) + ReadWriteBytesSparkListener.withMetrics( + spark, + () => { + processStatementOnVerification( + statementExecutionManager, + queryResultWriter, + flintStatement, + state, + context) + }) verificationResult = returnedVerificationResult finalizeCommand( diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 01d8cb05c..6cdbdb16d 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -14,7 +14,7 @@ import scala.util.{Failure, Success, Try} import org.opensearch.flint.common.model.FlintStatement import org.opensearch.flint.common.scheduler.model.LangType -import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} +import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil, ReadWriteBytesSparkListener} import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import org.opensearch.flint.spark.FlintSpark @@ -70,6 +70,9 @@ case class JobOperator( val statementExecutionManager = instantiateStatementExecutionManager(commandContext, resultIndex, osClient) + val readWriteBytesSparkListener = new ReadWriteBytesSparkListener() + sparkSession.sparkContext.addSparkListener(readWriteBytesSparkListener) + val statement = new FlintStatement( "running", @@ -137,6 +140,8 @@ case class JobOperator( startTime)) } finally { emitQueryExecutionTimeMetric(startTime) + readWriteBytesSparkListener.emitMetrics() + sparkSession.sparkContext.removeSparkListener(readWriteBytesSparkListener) try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) @@ -202,8 +207,9 @@ case class JobOperator( private def emitQueryExecutionTimeMetric(startTime: Long): Unit = { MetricsUtil - .getTimer(MetricConstants.QUERY_EXECUTION_TIME_METRIC, false) - .update(System.currentTimeMillis() - startTime, TimeUnit.MILLISECONDS); + .addHistoricGauge( + MetricConstants.QUERY_EXECUTION_TIME_METRIC, + System.currentTimeMillis() - startTime) } def stop(): Unit = {