diff --git a/build.sbt b/build.sbt index 5d2113f2b..0798092af 100644 --- a/build.sbt +++ b/build.sbt @@ -64,7 +64,15 @@ lazy val flintCore = (project in file("flint-core")) "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test", - "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test"), + "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test", + "org.mockito" % "mockito-core" % "2.23.0" % "test", + "org.mockito" % "mockito-junit-jupiter" % "3.12.4" % "test", + "org.junit.jupiter" % "junit-jupiter-api" % "5.9.0" % "test", + "org.junit.jupiter" % "junit-jupiter-engine" % "5.9.0" % "test", + "com.google.truth" % "truth" % "1.1.5" % "test", + "net.aichler" % "jupiter-interface" % "0.11.1" % Test + ), + libraryDependencies ++= deps(sparkVersion), publish / skip := true) lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) diff --git a/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java b/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java new file mode 100644 index 000000000..293a05d4a --- /dev/null +++ b/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java @@ -0,0 +1,261 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.metrics.sink; + +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.regions.AwsRegionProvider; +import com.amazonaws.regions.DefaultAwsRegionProviderChain; +import com.amazonaws.regions.Regions; +import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsync; +import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsyncClient; +import com.codahale.metrics.MetricFilter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.ScheduledReporter; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.apache.spark.SecurityManager; +import org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter; +import org.opensearch.flint.core.metrics.reporter.InvalidMetricsPropertyException; + +/** + * Implementation of the Spark metrics {@link Sink} interface + * for reporting internal Spark metrics into CloudWatch. Spark's metric system uses DropWizard's + * metric library internally, so this class simply wraps the {@link DimensionedCloudWatchReporter} + * with the constructor and methods mandated for Spark metric Sinks. + * + * @see org.apache.spark.metrics.MetricsSystem + * @see ScheduledReporter + * @author kmccaw + */ +public class CloudWatchSink implements Sink { + + private final ScheduledReporter reporter; + + private final long pollingPeriod; + + private final boolean shouldParseInlineDimensions; + + private final boolean shouldAppendDropwizardTypeDimension; + + private final TimeUnit pollingTimeUnit; + + /** + * Constructor with the signature required by Spark, which loads the class through reflection. + * + * @see org.apache.spark.metrics.MetricsSystem + * @param properties Properties for this sink defined in Spark's "metrics.properties" configuration file. 
+ * @param metricRegistry The DropWizard MetricRegistry used by Sparks {@link org.apache.spark.metrics.MetricsSystem} + * @param securityManager Unused argument; required by the Spark sink constructor signature. + */ + public CloudWatchSink( + final Properties properties, + final MetricRegistry metricRegistry, + final SecurityManager securityManager) { + // First extract properties defined in the Spark metrics configuration + + // Extract the required namespace property. This is used as the namespace + // for all metrics reported to CloudWatch + final Optional namespaceProperty = getProperty(properties, PropertyKeys.NAMESPACE); + if (!namespaceProperty.isPresent()) { + final String message = "CloudWatch Spark metrics sink requires '" + + PropertyKeys.NAMESPACE + "' property."; + throw new InvalidMetricsPropertyException(message); + } + + // Extract the optional AWS credentials. If either of the access or secret keys are + // missing in the properties, fall back to using the credentials of the EC2 instance. + final Optional accessKeyProperty = getProperty(properties, PropertyKeys.AWS_ACCESS_KEY_ID); + final Optional secretKeyProperty = getProperty(properties, PropertyKeys.AWS_SECRET_KEY); + final AWSCredentialsProvider awsCredentialsProvider; + if (accessKeyProperty.isPresent() && secretKeyProperty.isPresent()) { + final AWSCredentials awsCredentials = new BasicAWSCredentials( + accessKeyProperty.get(), + secretKeyProperty.get()); + awsCredentialsProvider = new AWSStaticCredentialsProvider(awsCredentials); + } else { + // If the AWS credentials aren't specified in the properties, fall back to using the + // DefaultAWSCredentialsProviderChain, which looks for credentials in the order + // (1) Environment Variables + // (2) Java System Properties + // (3) Credentials file at ~/.aws/credentials + // (4) AWS_CONTAINER_CREDENTIALS_RELATIVE_URI + // (5) EC2 Instance profile credentials + awsCredentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); + } + + // Extract the AWS region CloudWatch metrics should be reported to. + final Optional regionProperty = getProperty(properties, PropertyKeys.AWS_REGION); + final Regions awsRegion; + if (regionProperty.isPresent()) { + try { + awsRegion = Regions.fromName(regionProperty.get()); + } catch (IllegalArgumentException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + regionProperty.get(), + PropertyKeys.AWS_REGION); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + final AwsRegionProvider regionProvider = new DefaultAwsRegionProviderChain(); + awsRegion = Regions.fromName(regionProvider.getRegion()); + } + + // Extract the polling period, the interval at which metrics are reported. 
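// For example (values are illustrative): pollingPeriod=5 together with pollingTimeUnit=MINUTES
// makes the reporter publish every five minutes; when either property is absent, the
// PropertyDefaults below supply 1 and MINUTES respectively.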
+ final Optional pollingPeriodProperty = getProperty(properties, PropertyKeys.POLLING_PERIOD); + if (pollingPeriodProperty.isPresent()) { + try { + final long parsedPollingPeriod = Long.parseLong(pollingPeriodProperty.get()); + // Confirm that the value of this property is a positive number + if (parsedPollingPeriod <= 0) { + final String message = String.format( + "The value (%s) of the \"%s\" CloudWatchSink metrics property is non-positive.", + pollingPeriodProperty.get(), + PropertyKeys.POLLING_PERIOD); + throw new InvalidMetricsPropertyException(message); + } + pollingPeriod = parsedPollingPeriod; + } catch (NumberFormatException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + pollingPeriodProperty.get(), + PropertyKeys.POLLING_PERIOD); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + pollingPeriod = PropertyDefaults.POLLING_PERIOD; + } + + final Optional pollingTimeUnitProperty = getProperty(properties, PropertyKeys.POLLING_TIME_UNIT); + if (pollingTimeUnitProperty.isPresent()) { + try { + pollingTimeUnit = TimeUnit.valueOf(pollingTimeUnitProperty.get().toUpperCase()); + } catch (IllegalArgumentException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + pollingTimeUnitProperty.get(), + PropertyKeys.POLLING_TIME_UNIT); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + pollingTimeUnit = PropertyDefaults.POLLING_PERIOD_TIME_UNIT; + } + + // Extract the inline dimension parsing setting. + final Optional shouldParseInlineDimensionsProperty = getProperty( + properties, + PropertyKeys.SHOULD_PARSE_INLINE_DIMENSIONS); + if (shouldParseInlineDimensionsProperty.isPresent()) { + try { + shouldParseInlineDimensions = Boolean.parseBoolean(shouldParseInlineDimensionsProperty.get()); + } catch (IllegalArgumentException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + shouldParseInlineDimensionsProperty.get(), + PropertyKeys.SHOULD_PARSE_INLINE_DIMENSIONS); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + shouldParseInlineDimensions = PropertyDefaults.SHOULD_PARSE_INLINE_DIMENSIONS; + } + + // Extract the setting to append dropwizard metrics types as a dimension + final Optional shouldAppendDropwizardTypeDimensionProperty = getProperty( + properties, + PropertyKeys.SHOULD_APPEND_DROPWIZARD_TYPE_DIMENSION); + if (shouldAppendDropwizardTypeDimensionProperty.isPresent()) { + try { + shouldAppendDropwizardTypeDimension = Boolean.parseBoolean(shouldAppendDropwizardTypeDimensionProperty.get()); + } catch (IllegalArgumentException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + shouldAppendDropwizardTypeDimensionProperty.get(), + PropertyKeys.SHOULD_APPEND_DROPWIZARD_TYPE_DIMENSION); + throw new InvalidMetricsPropertyException(message, e); + } + } else { + shouldAppendDropwizardTypeDimension = PropertyDefaults.SHOULD_PARSE_INLINE_DIMENSIONS; + } + + final AmazonCloudWatchAsync cloudWatchClient = AmazonCloudWatchAsyncClient.asyncBuilder() + .withCredentials(awsCredentialsProvider) + .withRegion(awsRegion) + .build(); + + this.reporter = DimensionedCloudWatchReporter.forRegistry(metricRegistry, cloudWatchClient, namespaceProperty.get()) + .convertRatesTo(TimeUnit.SECONDS) + .convertDurationsTo(TimeUnit.MILLISECONDS) + 
.filter(MetricFilter.ALL) + .withPercentiles( + DimensionedCloudWatchReporter.Percentile.P50, + DimensionedCloudWatchReporter.Percentile.P75, + DimensionedCloudWatchReporter.Percentile.P99) + .withOneMinuteMeanRate() + .withFiveMinuteMeanRate() + .withFifteenMinuteMeanRate() + .withMeanRate() + .withArithmeticMean() + .withStdDev() + .withStatisticSet() + .withGlobalDimensions() + .withShouldParseDimensionsFromName(shouldParseInlineDimensions) + .withShouldAppendDropwizardTypeDimension(shouldAppendDropwizardTypeDimension) + .build(); + } + + @Override + public void start() { + reporter.start(pollingPeriod, pollingTimeUnit); + } + + @Override + public void stop() { + reporter.stop(); + } + + @Override + public void report() { + reporter.report(); + } + + /** + * Returns the value for specified property key as an Optional. + * @param properties + * @param key + * @return + */ + private static Optional getProperty(Properties properties, final String key) { + return Optional.ofNullable(properties.getProperty(key)); + } + + /** + * The keys used in the metrics properties configuration file. + */ + private static class PropertyKeys { + static final String NAMESPACE = "namespace"; + static final String AWS_ACCESS_KEY_ID = "awsAccessKeyId"; + static final String AWS_SECRET_KEY = "awsSecretKey"; + static final String AWS_REGION = "awsRegion"; + static final String POLLING_PERIOD = "pollingPeriod"; + static final String POLLING_TIME_UNIT = "pollingTimeUnit"; + static final String SHOULD_PARSE_INLINE_DIMENSIONS = "shouldParseInlineDimensions"; + static final String SHOULD_APPEND_DROPWIZARD_TYPE_DIMENSION = "shouldAppendDropwizardTypeDimension"; + } + + /** + * The default values for optional properties in the metrics properties configuration file. + */ + private static class PropertyDefaults { + static final long POLLING_PERIOD = 1; + static final TimeUnit POLLING_PERIOD_TIME_UNIT = TimeUnit.MINUTES; + static final boolean SHOULD_PARSE_INLINE_DIMENSIONS = false; + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java new file mode 100644 index 000000000..450fe0d0d --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java @@ -0,0 +1,819 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics.reporter; + +import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsync; +import com.amazonaws.services.cloudwatch.model.Dimension; +import com.amazonaws.services.cloudwatch.model.InvalidParameterValueException; +import com.amazonaws.services.cloudwatch.model.MetricDatum; +import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest; +import com.amazonaws.services.cloudwatch.model.PutMetricDataResult; +import com.amazonaws.services.cloudwatch.model.StandardUnit; +import com.amazonaws.services.cloudwatch.model.StatisticSet; +import com.amazonaws.util.StringUtils; +import com.codahale.metrics.Clock; +import com.codahale.metrics.Counter; +import com.codahale.metrics.Counting; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.Meter; +import com.codahale.metrics.Metered; +import com.codahale.metrics.MetricFilter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.ScheduledReporter; +import com.codahale.metrics.Snapshot; 
+import com.codahale.metrics.Timer; +import com.codahale.metrics.jvm.BufferPoolMetricSet; +import com.codahale.metrics.jvm.ClassLoadingGaugeSet; +import com.codahale.metrics.jvm.FileDescriptorRatioGauge; +import com.codahale.metrics.jvm.GarbageCollectorMetricSet; +import com.codahale.metrics.jvm.MemoryUsageGaugeSet; +import com.codahale.metrics.jvm.ThreadStatesGaugeSet; +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Date; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.LongStream; +import java.util.stream.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Reports metrics to Amazon's CloudWatch periodically. + *
<p>
+ * Use {@link Builder} to construct instances of this class. The {@link Builder} + * allows you to configure what aggregated metrics will be reported as a single {@link MetricDatum} to CloudWatch. + *
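As a minimal illustrative sketch (assuming a MetricRegistry and AmazonCloudWatchAsync client are already available; the namespace string is an example value), the Builder is wired up much like CloudWatchSink does above:

    ScheduledReporter reporter = DimensionedCloudWatchReporter
        .forRegistry(metricRegistry, cloudWatchAsyncClient, "ExampleNamespace")
        .convertRatesTo(TimeUnit.SECONDS)
        .convertDurationsTo(TimeUnit.MILLISECONDS)
        .withPercentiles(DimensionedCloudWatchReporter.Percentile.P50,
                         DimensionedCloudWatchReporter.Percentile.P99)
        .withStatisticSet()
        .build();
    reporter.start(1, TimeUnit.MINUTES); // report once per minute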
<p>
+ * There are a bunch of {@code with*} methods that provide sufficiently fine-grained control over what metrics + * should be reported. + * + * Forked from https://github.com/azagniotov/codahale-aggregated-metrics-cloudwatch-reporter. + */ +public class DimensionedCloudWatchReporter extends ScheduledReporter { + + private static final Logger LOGGER = LoggerFactory.getLogger(DimensionedCloudWatchReporter.class); + + // Visible for testing + public static final String DIMENSION_NAME_TYPE = "Type"; + + // Visible for testing + public static final String DIMENSION_GAUGE = "gauge"; + + // Visible for testing + public static final String DIMENSION_COUNT = "count"; + + // Visible for testing + public static final String DIMENSION_SNAPSHOT_SUMMARY = "snapshot-summary"; + + // Visible for testing + public static final String DIMENSION_SNAPSHOT_MEAN = "snapshot-mean"; + + // Visible for testing + public static final String DIMENSION_SNAPSHOT_STD_DEV = "snapshot-std-dev"; + + /** + * Amazon CloudWatch rejects values that are either too small or too large. + * Values must be in the range of 8.515920e-109 to 1.174271e+108 (Base 10) or 2e-360 to 2e360 (Base 2). + *
<p>
+ * In addition, special values (e.g., NaN, +Infinity, -Infinity) are not supported. + */ + private static final double SMALLEST_SENDABLE_VALUE = 8.515920e-109; + private static final double LARGEST_SENDABLE_VALUE = 1.174271e+108; + + /** + * Each CloudWatch API request may contain at maximum 20 datums + */ + private static final int MAXIMUM_DATUMS_PER_REQUEST = 20; + + /** + * We only submit the difference in counters since the last submission. This way we don't have to reset + * the counters within this application. + */ + private final Map lastPolledCounts; + + private final Builder builder; + private final String namespace; + private final AmazonCloudWatchAsync cloudWatchAsyncClient; + private final StandardUnit rateUnit; + private final StandardUnit durationUnit; + private final boolean shouldParseDimensionsFromName; + private final boolean shouldAppendDropwizardTypeDimension; + + private DimensionedCloudWatchReporter(final Builder builder) { + super(builder.metricRegistry, "coda-hale-metrics-cloud-watch-reporter", builder.metricFilter, builder.rateUnit, builder.durationUnit); + this.builder = builder; + this.namespace = builder.namespace; + this.cloudWatchAsyncClient = builder.cloudWatchAsyncClient; + this.lastPolledCounts = new ConcurrentHashMap<>(); + this.rateUnit = builder.cwRateUnit; + this.durationUnit = builder.cwDurationUnit; + this.shouldParseDimensionsFromName = builder.withShouldParseDimensionsFromName; + this.shouldAppendDropwizardTypeDimension = builder.withShouldAppendDropwizardTypeDimension; + } + + @Override + public void report(final SortedMap gauges, + final SortedMap counters, + final SortedMap histograms, + final SortedMap meters, + final SortedMap timers) { + + if (builder.withDryRun) { + LOGGER.warn("** Reporter is running in 'DRY RUN' mode **"); + } + + try { + final List metricData = new ArrayList<>( + gauges.size() + counters.size() + 10 * histograms.size() + 10 * timers.size()); + + for (final Map.Entry gaugeEntry : gauges.entrySet()) { + processGauge(gaugeEntry.getKey(), gaugeEntry.getValue(), metricData); + } + + for (final Map.Entry counterEntry : counters.entrySet()) { + processCounter(counterEntry.getKey(), counterEntry.getValue(), metricData); + } + + for (final Map.Entry histogramEntry : histograms.entrySet()) { + processCounter(histogramEntry.getKey(), histogramEntry.getValue(), metricData); + processHistogram(histogramEntry.getKey(), histogramEntry.getValue(), metricData); + } + + for (final Map.Entry meterEntry : meters.entrySet()) { + processCounter(meterEntry.getKey(), meterEntry.getValue(), metricData); + processMeter(meterEntry.getKey(), meterEntry.getValue(), metricData); + } + + for (final Map.Entry timerEntry : timers.entrySet()) { + processCounter(timerEntry.getKey(), timerEntry.getValue(), metricData); + processMeter(timerEntry.getKey(), timerEntry.getValue(), metricData); + processTimer(timerEntry.getKey(), timerEntry.getValue(), metricData); + } + + final Collection> metricDataPartitions = partition(metricData, MAXIMUM_DATUMS_PER_REQUEST); + final List> cloudWatchFutures = new ArrayList<>(metricData.size()); + + for (final List partition : metricDataPartitions) { + final PutMetricDataRequest putMetricDataRequest = new PutMetricDataRequest() + .withNamespace(namespace) + .withMetricData(partition); + + if (builder.withDryRun) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dry run - constructed PutMetricDataRequest: {}", putMetricDataRequest); + } + } else { + 
cloudWatchFutures.add(cloudWatchAsyncClient.putMetricDataAsync(putMetricDataRequest)); + } + } + + for (final Future cloudWatchFuture : cloudWatchFutures) { + try { + cloudWatchFuture.get(); + } catch (final Exception e) { + LOGGER.error("Error reporting metrics to CloudWatch. The data in this CloudWatch API request " + + "may have been discarded, did not make it to CloudWatch.", e); + } + } + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Sent {} metric datums to CloudWatch. Namespace: {}, metric data {}", metricData.size(), namespace, metricData); + } + + } catch (final RuntimeException e) { + LOGGER.error("Error marshalling CloudWatch metrics.", e); + } + } + + @Override + public void stop() { + try { + super.stop(); + } catch (final Exception e) { + LOGGER.error("Error when stopping the reporter.", e); + } finally { + if (!builder.withDryRun) { + try { + cloudWatchAsyncClient.shutdown(); + } catch (final Exception e) { + LOGGER.error("Error shutting down AmazonCloudWatchAsync", cloudWatchAsyncClient, e); + } + } + } + } + + private void processGauge(final String metricName, final Gauge gauge, final List metricData) { + if (gauge.getValue() instanceof Number) { + final Number number = (Number) gauge.getValue(); + stageMetricDatum(true, metricName, number.doubleValue(), StandardUnit.None, DIMENSION_GAUGE, metricData); + } + } + + private void processCounter(final String metricName, final Counting counter, final List metricData) { + long currentCount = counter.getCount(); + Long lastCount = lastPolledCounts.get(counter); + lastPolledCounts.put(counter, currentCount); + + if (lastCount == null) { + lastCount = 0L; + } + + // Only submit metrics that have changed - let's save some money! + final long delta = currentCount - lastCount; + stageMetricDatum(true, metricName, delta, StandardUnit.Count, DIMENSION_COUNT, metricData); + } + + /** + * The rates of {@link Metered} are reported after being converted using the rate factor, which is deduced from + * the set rate unit + * + * @see Timer#getSnapshot + * @see #getRateUnit + * @see #convertRate(double) + */ + private void processMeter(final String metricName, final Metered meter, final List metricData) { + final String formattedRate = String.format("-rate [per-%s]", getRateUnit()); + stageMetricDatum(builder.withOneMinuteMeanRate, metricName, convertRate(meter.getOneMinuteRate()), rateUnit, "1-min-mean" + formattedRate, metricData); + stageMetricDatum(builder.withFiveMinuteMeanRate, metricName, convertRate(meter.getFiveMinuteRate()), rateUnit, "5-min-mean" + formattedRate, metricData); + stageMetricDatum(builder.withFifteenMinuteMeanRate, metricName, convertRate(meter.getFifteenMinuteRate()), rateUnit, "15-min-mean" + formattedRate, metricData); + stageMetricDatum(builder.withMeanRate, metricName, convertRate(meter.getMeanRate()), rateUnit, "mean" + formattedRate, metricData); + } + + /** + * The {@link Snapshot} values of {@link Timer} are reported as {@link StatisticSet} after conversion. The + * conversion is done using the duration factor, which is deduced from the set duration unit. + *
<p>
+ * Please note, the reported values are submitted only if they show some data (greater than zero) in order to: + *
<p>
+ * 1. save some money + * 2. prevent com.amazonaws.services.cloudwatch.model.InvalidParameterValueException if an empty {@link Snapshot} + * is submitted + *
<p>
+ * If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted + * + * @see Timer#getSnapshot + * @see #getDurationUnit + * @see #convertDuration(double) + */ + private void processTimer(final String metricName, final Timer timer, final List metricData) { + final Snapshot snapshot = timer.getSnapshot(); + + if (builder.withZeroValuesSubmission || snapshot.size() > 0) { + for (final Percentile percentile : builder.percentiles) { + final double convertedDuration = convertDuration(snapshot.getValue(percentile.getQuantile())); + stageMetricDatum(true, metricName, convertedDuration, durationUnit, percentile.getDesc(), metricData); + } + } + + // prevent empty snapshot from causing InvalidParameterValueException + if (snapshot.size() > 0) { + final String formattedDuration = String.format(" [in-%s]", getDurationUnit()); + stageMetricDatum(builder.withArithmeticMean, metricName, convertDuration(snapshot.getMean()), durationUnit, DIMENSION_SNAPSHOT_MEAN + formattedDuration, metricData); + stageMetricDatum(builder.withStdDev, metricName, convertDuration(snapshot.getStdDev()), durationUnit, DIMENSION_SNAPSHOT_STD_DEV + formattedDuration, metricData); + stageMetricDatumWithConvertedSnapshot(builder.withStatisticSet, metricName, snapshot, durationUnit, metricData); + } + } + + /** + * The {@link Snapshot} values of {@link Histogram} are reported as {@link StatisticSet} raw. In other words, the + * conversion using the duration factor does NOT apply. + *
<p>
+ * Please note, the reported values are submitted only if they show some data (greater than zero) in order to: + *
<p>
+ * 1. save some money + * 2. prevent com.amazonaws.services.cloudwatch.model.InvalidParameterValueException if an empty {@link Snapshot} + * is submitted + *
<p>
+ * If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted + * + * @see Histogram#getSnapshot + */ + private void processHistogram(final String metricName, final Histogram histogram, final List metricData) { + final Snapshot snapshot = histogram.getSnapshot(); + + if (builder.withZeroValuesSubmission || snapshot.size() > 0) { + for (final Percentile percentile : builder.percentiles) { + final double value = snapshot.getValue(percentile.getQuantile()); + stageMetricDatum(true, metricName, value, StandardUnit.None, percentile.getDesc(), metricData); + } + } + + // prevent empty snapshot from causing InvalidParameterValueException + if (snapshot.size() > 0) { + stageMetricDatum(builder.withArithmeticMean, metricName, snapshot.getMean(), StandardUnit.None, DIMENSION_SNAPSHOT_MEAN, metricData); + stageMetricDatum(builder.withStdDev, metricName, snapshot.getStdDev(), StandardUnit.None, DIMENSION_SNAPSHOT_STD_DEV, metricData); + stageMetricDatumWithRawSnapshot(builder.withStatisticSet, metricName, snapshot, StandardUnit.None, metricData); + } + } + + /** + * Please note, the reported values are submitted only if they show some data (greater than zero) in order to: + *
<p>
+ * 1. save some money + * 2. prevent com.amazonaws.services.cloudwatch.model.InvalidParameterValueException if an empty {@link Snapshot} + * is submitted + *
<p>
+ * If {@link Builder#withZeroValuesSubmission()} is {@code true}, then all values will be submitted + */ + private void stageMetricDatum(final boolean metricConfigured, + final String metricName, + final double metricValue, + final StandardUnit standardUnit, + final String dimensionValue, + final List metricData) { + // Only submit metrics that show some data, so let's save some money + if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) { + final Set dimensions = new LinkedHashSet<>(builder.globalDimensions); + final String name; + if (shouldParseDimensionsFromName) { + final String[] nameParts = metricName.split(" "); + final StringBuilder nameBuilder = new StringBuilder(nameParts[0]); + int i = 1; + for (; i < nameParts.length; ++i) { + final String[] dimensionParts = nameParts[i].split("="); + if (dimensionParts.length == 2 + && !StringUtils.isNullOrEmpty(dimensionParts[0]) + && !StringUtils.isNullOrEmpty(dimensionParts[1])) { + final Dimension dimension = new Dimension(); + dimension.withName(dimensionParts[0]); + dimension.withValue(dimensionParts[1]); + dimensions.add(dimension); + } else { + nameBuilder.append(" "); + nameBuilder.append(nameParts[i]); + } + } + name = nameBuilder.toString(); + } else { + name = metricName; + } + + if (shouldAppendDropwizardTypeDimension) { + dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(dimensionValue)); + } + + metricData.add(new MetricDatum() + .withTimestamp(new Date(builder.clock.getTime())) + .withValue(cleanMetricValue(metricValue)) + .withMetricName(name) + .withDimensions(dimensions) + .withUnit(standardUnit)); + } + } + + private void stageMetricDatumWithConvertedSnapshot(final boolean metricConfigured, + final String metricName, + final Snapshot snapshot, + final StandardUnit standardUnit, + final List metricData) { + if (metricConfigured) { + double scaledSum = convertDuration(LongStream.of(snapshot.getValues()).sum()); + final StatisticSet statisticSet = new StatisticSet() + .withSum(scaledSum) + .withSampleCount((double) snapshot.size()) + .withMinimum(convertDuration(snapshot.getMin())) + .withMaximum(convertDuration(snapshot.getMax())); + + final Set dimensions = new LinkedHashSet<>(builder.globalDimensions); + dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_SNAPSHOT_SUMMARY)); + + metricData.add(new MetricDatum() + .withTimestamp(new Date(builder.clock.getTime())) + .withMetricName(metricName) + .withDimensions(dimensions) + .withStatisticValues(statisticSet) + .withUnit(standardUnit)); + } + } + + private void stageMetricDatumWithRawSnapshot(final boolean metricConfigured, + final String metricName, + final Snapshot snapshot, + final StandardUnit standardUnit, + final List metricData) { + if (metricConfigured) { + double total = LongStream.of(snapshot.getValues()).sum(); + final StatisticSet statisticSet = new StatisticSet() + .withSum(total) + .withSampleCount((double) snapshot.size()) + .withMinimum((double) snapshot.getMin()) + .withMaximum((double) snapshot.getMax()); + + final Set dimensions = new LinkedHashSet<>(builder.globalDimensions); + dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_SNAPSHOT_SUMMARY)); + + metricData.add(new MetricDatum() + .withTimestamp(new Date(builder.clock.getTime())) + .withMetricName(metricName) + .withDimensions(dimensions) + .withStatisticValues(statisticSet) + .withUnit(standardUnit)); + } + } + + private double cleanMetricValue(final double metricValue) { + double 
absoluteValue = Math.abs(metricValue); + if (absoluteValue < SMALLEST_SENDABLE_VALUE) { + // Allow 0 through untouched, everything else gets rounded to SMALLEST_SENDABLE_VALUE + if (absoluteValue > 0) { + if (metricValue < 0) { + return -SMALLEST_SENDABLE_VALUE; + } else { + return SMALLEST_SENDABLE_VALUE; + } + } + } else if (absoluteValue > LARGEST_SENDABLE_VALUE) { + if (metricValue < 0) { + return -LARGEST_SENDABLE_VALUE; + } else { + return LARGEST_SENDABLE_VALUE; + } + } + return metricValue; + } + + private static Collection> partition(final Collection wholeCollection, final int partitionSize) { + final int[] itemCounter = new int[]{0}; + + return wholeCollection.stream() + .collect(Collectors.groupingBy(item -> itemCounter[0]++ / partitionSize)) + .values(); + } + + /** + * Creates a new {@link Builder} that sends values from the given {@link MetricRegistry} to the given namespace + * using the given CloudWatch client. + * + * @param metricRegistry {@link MetricRegistry} instance + * @param client {@link AmazonCloudWatchAsync} instance + * @param namespace the namespace. Must be non-null and not empty. + * @return {@link Builder} instance + */ + public static Builder forRegistry( + final MetricRegistry metricRegistry, + final AmazonCloudWatchAsync client, + final String namespace) { + return new Builder(metricRegistry, client, namespace); + } + + public enum Percentile { + P50(0.50, "50%"), + P75(0.75, "75%"), + P95(0.95, "95%"), + P98(0.98, "98%"), + P99(0.99, "99%"), + P995(0.995, "99.5%"), + P999(0.999, "99.9%"); + + private final double quantile; + private final String desc; + + Percentile(final double quantile, final String desc) { + this.quantile = quantile; + this.desc = desc; + } + + public double getQuantile() { + return quantile; + } + + public String getDesc() { + return desc; + } + } + + public static class Builder { + + private final String namespace; + private final AmazonCloudWatchAsync cloudWatchAsyncClient; + private final MetricRegistry metricRegistry; + + private Percentile[] percentiles; + private boolean withOneMinuteMeanRate; + private boolean withFiveMinuteMeanRate; + private boolean withFifteenMinuteMeanRate; + private boolean withMeanRate; + private boolean withArithmeticMean; + private boolean withStdDev; + private boolean withDryRun; + private boolean withZeroValuesSubmission; + private boolean withStatisticSet; + private boolean withJvmMetrics; + private boolean withShouldParseDimensionsFromName; + private boolean withShouldAppendDropwizardTypeDimension=true; + private MetricFilter metricFilter; + private TimeUnit rateUnit; + private TimeUnit durationUnit; + private StandardUnit cwRateUnit; + private StandardUnit cwDurationUnit; + private Set globalDimensions; + private final Clock clock; + + private Builder( + final MetricRegistry metricRegistry, + final AmazonCloudWatchAsync cloudWatchAsyncClient, + final String namespace) { + this.metricRegistry = metricRegistry; + this.cloudWatchAsyncClient = cloudWatchAsyncClient; + this.namespace = namespace; + this.percentiles = new Percentile[]{Percentile.P75, Percentile.P95, Percentile.P999}; + this.metricFilter = MetricFilter.ALL; + this.rateUnit = TimeUnit.SECONDS; + this.durationUnit = TimeUnit.MILLISECONDS; + this.globalDimensions = new LinkedHashSet<>(); + this.cwRateUnit = toStandardUnit(rateUnit); + this.cwDurationUnit = toStandardUnit(durationUnit); + this.clock = Clock.defaultClock(); + } + + /** + * Convert rates to the given time unit. 
+ * + * @param rateUnit a unit of time + * @return {@code this} + */ + public Builder convertRatesTo(final TimeUnit rateUnit) { + this.rateUnit = rateUnit; + return this; + } + + /** + * Convert durations to the given time unit. + * + * @param durationUnit a unit of time + * @return {@code this} + */ + public Builder convertDurationsTo(final TimeUnit durationUnit) { + this.durationUnit = durationUnit; + return this; + } + + /** + * Only report metrics which match the given filter. + * + * @param metricFilter a {@link MetricFilter} + * @return {@code this} + */ + public Builder filter(final MetricFilter metricFilter) { + this.metricFilter = metricFilter; + return this; + } + + /** + * If the one minute rate should be sent for {@link Meter} and {@link Timer}. {@code false} by default. + *
<p>
+ * The rate values are converted before reporting based on the rate unit set + * + * @return {@code this} + * @see ScheduledReporter#convertRate(double) + * @see Meter#getOneMinuteRate() + * @see Timer#getOneMinuteRate() + */ + public Builder withOneMinuteMeanRate() { + withOneMinuteMeanRate = true; + return this; + } + + /** + * If the five minute rate should be sent for {@link Meter} and {@link Timer}. {@code false} by default. + *
<p>
+ * The rate values are converted before reporting based on the rate unit set + * + * @return {@code this} + * @see ScheduledReporter#convertRate(double) + * @see Meter#getFiveMinuteRate() + * @see Timer#getFiveMinuteRate() + */ + public Builder withFiveMinuteMeanRate() { + withFiveMinuteMeanRate = true; + return this; + } + + /** + * If the fifteen minute rate should be sent for {@link Meter} and {@link Timer}. {@code false} by default. + *
<p>
+ * The rate values are converted before reporting based on the rate unit set + * + * @return {@code this} + * @see ScheduledReporter#convertRate(double) + * @see Meter#getFifteenMinuteRate() + * @see Timer#getFifteenMinuteRate() + */ + public Builder withFifteenMinuteMeanRate() { + withFifteenMinuteMeanRate = true; + return this; + } + + /** + * If the mean rate should be sent for {@link Meter} and {@link Timer}. {@code false} by default. + *
<p>
+ * The rate values are converted before reporting based on the rate unit set + * + * @return {@code this} + * @see ScheduledReporter#convertRate(double) + * @see Meter#getMeanRate() + * @see Timer#getMeanRate() + */ + public Builder withMeanRate() { + withMeanRate = true; + return this; + } + + /** + * If the arithmetic mean of {@link Snapshot} values in {@link Histogram} and {@link Timer} should be sent. + * {@code false} by default. + *
<p>
+ * The {@link Timer#getSnapshot()} values are converted before reporting based on the duration unit set + * The {@link Histogram#getSnapshot()} values are reported as is + * + * @return {@code this} + * @see ScheduledReporter#convertDuration(double) + * @see Snapshot#getMean() + */ + public Builder withArithmeticMean() { + withArithmeticMean = true; + return this; + } + + /** + * If the standard deviation of {@link Snapshot} values in {@link Histogram} and {@link Timer} should be sent. + * {@code false} by default. + *
<p>
+ * The {@link Timer#getSnapshot()} values are converted before reporting based on the duration unit set + * The {@link Histogram#getSnapshot()} values are reported as is + * + * @return {@code this} + * @see ScheduledReporter#convertDuration(double) + * @see Snapshot#getStdDev() + */ + public Builder withStdDev() { + withStdDev = true; + return this; + } + + /** + * If lifetime {@link Snapshot} summary of {@link Histogram} and {@link Timer} should be translated + * to {@link StatisticSet} in the most direct way possible and reported. {@code false} by default. + *
<p>
+ * The {@link Snapshot} duration values are converted before reporting based on the duration unit set + * + * @return {@code this} + * @see ScheduledReporter#convertDuration(double) + */ + public Builder withStatisticSet() { + withStatisticSet = true; + return this; + } + + /** + * If JVM statistic should be reported. Supported metrics include: + *
<p>
+ * - Run count and elapsed times for all supported garbage collectors + * - Memory usage for all memory pools, including off-heap memory + * - Breakdown of thread states, including deadlocks + * - File descriptor usage + * - Buffer pool sizes and utilization (Java 7 only) + *
<p>
+ * {@code false} by default. + * + * @return {@code this} + */ + public Builder withJvmMetrics() { + withJvmMetrics = true; + return this; + } + + /** + * If CloudWatch dimensions should be parsed off the the metric name: + * + * {@code false} by default. + * + * @return {@code this} + */ + public Builder withShouldParseDimensionsFromName(final boolean value) { + withShouldParseDimensionsFromName = value; + return this; + } + + /** + * If the Dropwizard metric type should be reported as a CloudWatch dimension. + * + * {@code false} by default. + * + * @return {@code this} + */ + public Builder withShouldAppendDropwizardTypeDimension(final boolean value) { + withShouldAppendDropwizardTypeDimension = value; + return this; + } + + /** + * Does not actually POST to CloudWatch, logs the {@link PutMetricDataRequest putMetricDataRequest} instead. + * {@code false} by default. + * + * @return {@code this} + */ + public Builder withDryRun() { + withDryRun = true; + return this; + } + + /** + * POSTs to CloudWatch all values. Otherwise, the reporter does not POST values which are zero in order to save + * costs. Also, some users have been experiencing {@link InvalidParameterValueException} when submitting zero + * values. Please refer to: + * https://github.com/azagniotov/codahale-aggregated-metrics-cloudwatch-reporter/issues/4 + *
<p>
+ * {@code false} by default. + * + * @return {@code this} + */ + public Builder withZeroValuesSubmission() { + withZeroValuesSubmission = true; + return this; + } + + /** + * The {@link Histogram} and {@link Timer} percentiles to send. If 0.5 is included, it'll be + * reported as median. This defaults to 0.75, 0.95 and 0.999. + *
<p>
+ * The {@link Timer#getSnapshot()} percentile values are converted before reporting based on the duration unit + * The {@link Histogram#getSnapshot()} percentile values are reported as is + * + * @param percentiles the percentiles to send. Replaces the default percentiles. + * @return {@code this} + */ + public Builder withPercentiles(final Percentile... percentiles) { + if (percentiles.length > 0) { + this.percentiles = percentiles; + } + return this; + } + + /** + * Global {@link Set} of {@link Dimension} to send with each {@link MetricDatum}. A dimension is a name/value + * pair that helps you to uniquely identify a metric. Every metric has specific characteristics that describe + * it, and you can think of dimensions as categories for those characteristics. + *
<p>
+ * Whenever you add a unique name/value pair to one of your metrics, you are creating a new metric. + * Defaults to {@code empty} {@link Set}. + * + * @param dimensions arguments in a form of {@code name=value}. The number of arguments is variable and may be + * zero. The maximum number of arguments is limited by the maximum dimension of a Java array + * as defined by the Java Virtual Machine Specification. Each {@code name=value} string + * will be converted to an instance of {@link Dimension} + * @return {@code this} + */ + public Builder withGlobalDimensions(final String... dimensions) { + for (final String pair : dimensions) { + final List splitted = Stream.of(pair.split("=")).map(String::trim).collect(Collectors.toList()); + this.globalDimensions.add(new Dimension().withName(splitted.get(0)).withValue(splitted.get(1))); + } + return this; + } + + public DimensionedCloudWatchReporter build() { + + if (withJvmMetrics) { + metricRegistry.register("jvm.uptime", (Gauge) () -> ManagementFactory.getRuntimeMXBean().getUptime()); + metricRegistry.register("jvm.current_time", (Gauge) clock::getTime); + metricRegistry.register("jvm.classes", new ClassLoadingGaugeSet()); + metricRegistry.register("jvm.fd_usage", new FileDescriptorRatioGauge()); + metricRegistry.register("jvm.buffers", new BufferPoolMetricSet(ManagementFactory.getPlatformMBeanServer())); + metricRegistry.register("jvm.gc", new GarbageCollectorMetricSet()); + metricRegistry.register("jvm.memory", new MemoryUsageGaugeSet()); + metricRegistry.register("jvm.thread-states", new ThreadStatesGaugeSet()); + } + + cwRateUnit = toStandardUnit(rateUnit); + cwDurationUnit = toStandardUnit(durationUnit); + + return new DimensionedCloudWatchReporter(this); + } + + private StandardUnit toStandardUnit(final TimeUnit timeUnit) { + switch (timeUnit) { + case SECONDS: + return StandardUnit.Seconds; + case MILLISECONDS: + return StandardUnit.Milliseconds; + case MICROSECONDS: + return StandardUnit.Microseconds; + default: + throw new IllegalArgumentException("Unsupported TimeUnit: " + timeUnit); + } + } + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/InvalidMetricsPropertyException.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/InvalidMetricsPropertyException.java new file mode 100644 index 000000000..56755d545 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/InvalidMetricsPropertyException.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics.reporter; + + +import java.io.Serializable; + +public class InvalidMetricsPropertyException extends RuntimeException implements Serializable { + + public InvalidMetricsPropertyException(final String message) { + super(message); + } + + public InvalidMetricsPropertyException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index c1f5d78c1..6cdf5187d 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -28,6 +28,18 @@ public interface FlintClient { */ OptimisticTransaction startTransaction(String indexName, String dataSourceName); + /** + * + * Start a new optimistic transaction. 
+ * + * @param indexName index name + * @param dataSourceName TODO: read from elsewhere in future + * @param forceInit forceInit create empty translog if not exist. + * @return transaction handle + */ + OptimisticTransaction startTransaction(String indexName, String dataSourceName, + boolean forceInit); + /** * Create a Flint index with the metadata given. * diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala index fea9974c6..eb93c7fde 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogEntry.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.core.metadata.log import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.IndexState +import org.opensearch.index.seqno.SequenceNumbers.{UNASSIGNED_PRIMARY_TERM, UNASSIGNED_SEQ_NO} /** * Flint metadata log entry. This is temporary and will merge field in FlintMetadata here and move @@ -92,4 +93,80 @@ object FlintMetadataLogEntry { .getOrElse(IndexState.UNKNOWN) } } + + val QUERY_EXECUTION_REQUEST_MAPPING: String = + """{ + | "dynamic": false, + | "properties": { + | "version": { + | "type": "keyword" + | }, + | "type": { + | "type": "keyword" + | }, + | "state": { + | "type": "keyword" + | }, + | "statementId": { + | "type": "keyword" + | }, + | "applicationId": { + | "type": "keyword" + | }, + | "sessionId": { + | "type": "keyword" + | }, + | "sessionType": { + | "type": "keyword" + | }, + | "error": { + | "type": "text" + | }, + | "lang": { + | "type": "keyword" + | }, + | "query": { + | "type": "text" + | }, + | "dataSourceName": { + | "type": "keyword" + | }, + | "submitTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "jobId": { + | "type": "keyword" + | }, + | "lastUpdateTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "queryId": { + | "type": "keyword" + | }, + | "excludeJobIds": { + | "type": "keyword" + | } + | } + |}""".stripMargin + + val QUERY_EXECUTION_REQUEST_SETTINGS: String = + """{ + | "index": { + | "number_of_shards": "1", + | "auto_expand_replicas": "0-2", + | "number_of_replicas": "0" + | } + |}""".stripMargin + + def failLogEntry(dataSourceName: String, error: String): FlintMetadataLogEntry = + FlintMetadataLogEntry( + "", + UNASSIGNED_SEQ_NO, + UNASSIGNED_PRIMARY_TERM, + 0L, + IndexState.FAILED, + dataSourceName, + error) } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 5dd761cb3..e3ac49607 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -46,14 +46,15 @@ import org.opensearch.flint.core.http.RetryableHttpAsyncClient; import org.opensearch.flint.core.metadata.FlintMetadata; import org.opensearch.flint.core.metadata.log.DefaultOptimisticTransaction; +import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry; import org.opensearch.flint.core.metadata.log.OptimisticTransaction; -import org.opensearch.flint.core.metadata.log.OptimisticTransaction.NoOptimisticTransaction; 
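For context, a hypothetical caller of the new forceInit overload declared above (the index and data source names here are invented): passing true bootstraps the metadata log index using the mapping and settings now defined in FlintMetadataLogEntry, whereas the two-argument overload throws IllegalStateException when that index is missing.

    // Illustrative only: create the metadata log index for data source "glue" if it does not exist yet.
    OptimisticTransaction<?> tx =
        flintClient.startTransaction("flint_glue_default_skipping_index", "glue", true);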
import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; import scala.Option; +import scala.Some; /** * Flint client implementation for OpenSearch storage. @@ -82,34 +83,48 @@ public FlintOpenSearchClient(FlintOptions options) { } @Override - public OptimisticTransaction startTransaction(String indexName, String dataSourceName) { + public OptimisticTransaction startTransaction(String indexName, String dataSourceName, + boolean forceInit) { LOG.info("Starting transaction on index " + indexName + " and data source " + dataSourceName); String metaLogIndexName = dataSourceName.isEmpty() ? META_LOG_NAME_PREFIX : META_LOG_NAME_PREFIX + "_" + dataSourceName; - try (RestHighLevelClient client = createClient()) { if (client.indices().exists(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { LOG.info("Found metadata log index " + metaLogIndexName); - return new DefaultOptimisticTransaction<>(dataSourceName, - new FlintOpenSearchMetadataLog(this, indexName, metaLogIndexName)); } else { - LOG.info("Metadata log index not found " + metaLogIndexName); - return new NoOptimisticTransaction<>(); + if (forceInit) { + createIndex(metaLogIndexName, FlintMetadataLogEntry.QUERY_EXECUTION_REQUEST_MAPPING(), + Some.apply(FlintMetadataLogEntry.QUERY_EXECUTION_REQUEST_SETTINGS())); + } else { + String errorMsg = "Metadata log index not found " + metaLogIndexName; + LOG.warning(errorMsg); + throw new IllegalStateException(errorMsg); + } } + return new DefaultOptimisticTransaction<>(dataSourceName, + new FlintOpenSearchMetadataLog(this, indexName, metaLogIndexName)); } catch (IOException e) { throw new IllegalStateException("Failed to check if index metadata log index exists " + metaLogIndexName, e); } } + @Override + public OptimisticTransaction startTransaction(String indexName, String dataSourceName) { + return startTransaction(indexName, dataSourceName, false); + } + @Override public void createIndex(String indexName, FlintMetadata metadata) { LOG.info("Creating Flint index " + indexName + " with metadata " + metadata); + createIndex(indexName, metadata.getContent(), metadata.indexSettings()); + } + + protected void createIndex(String indexName, String mapping, Option settings) { + LOG.info("Creating Flint index " + indexName); String osIndexName = toLowercase(indexName); try (RestHighLevelClient client = createClient()) { CreateIndexRequest request = new CreateIndexRequest(osIndexName); - request.mapping(metadata.getContent(), XContentType.JSON); - - Option settings = metadata.indexSettings(); + request.mapping(mapping, XContentType.JSON); if (settings.isDefined()) { request.settings(settings.get(), XContentType.JSON); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java index 07029d608..f51e8a628 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java @@ -5,12 +5,6 @@ package org.opensearch.flint.core.storage; -import static org.opensearch.action.support.WriteRequest.RefreshPolicy; - -import java.io.IOException; -import java.util.Base64; -import java.util.Optional; -import java.util.logging.Logger; import 
org.opensearch.OpenSearchException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.get.GetRequest; @@ -19,11 +13,20 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; +import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.common.xcontent.XContentType; import org.opensearch.flint.core.FlintClient; import org.opensearch.flint.core.metadata.log.FlintMetadataLog; import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry; +import java.io.IOException; +import java.util.Base64; +import java.util.Optional; +import java.util.logging.Logger; + +import static java.util.logging.Level.SEVERE; +import static org.opensearch.action.support.WriteRequest.RefreshPolicy; + /** * Flint metadata log in OpenSearch store. For now use single doc instead of maintaining history * of metadata log. @@ -57,6 +60,11 @@ public FlintOpenSearchMetadataLog(FlintClient flintClient, String flintIndexName public FlintMetadataLogEntry add(FlintMetadataLogEntry logEntry) { // TODO: use single doc for now. this will be always append in future. FlintMetadataLogEntry latest; + if (!exists()) { + String errorMsg = "Flint Metadata Log index not found " + metaLogIndexName; + LOG.log(SEVERE, errorMsg); + throw new IllegalStateException(errorMsg); + } if (logEntry.id().isEmpty()) { latest = createLogEntry(logEntry); } else { @@ -108,6 +116,7 @@ private FlintMetadataLogEntry createLogEntry(FlintMetadataLogEntry logEntry) { new IndexRequest() .index(metaLogIndexName) .id(logEntryWithId.id()) + .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) .source(logEntryWithId.toJson(), XContentType.JSON), RequestOptions.DEFAULT)); } @@ -148,6 +157,15 @@ private FlintMetadataLogEntry writeLogEntry( } } + private boolean exists() { + LOG.info("Checking if Flint index exists " + metaLogIndexName); + try (RestHighLevelClient client = flintClient.createClient()) { + return client.indices().exists(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT); + } catch (IOException e) { + throw new IllegalStateException("Failed to check if Flint index exists " + metaLogIndexName, e); + } + } + @FunctionalInterface public interface CheckedFunction { R apply(T t) throws IOException; diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java index 58963ab74..4a6424512 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java @@ -4,14 +4,17 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; +import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.common.xcontent.XContentType; import org.opensearch.flint.core.FlintClient; -import org.opensearch.flint.core.FlintClientBuilder; -import org.opensearch.flint.core.FlintOptions; import java.io.IOException; +import java.util.logging.Level; +import java.util.logging.Logger; public class OpenSearchUpdater { + private static final Logger LOG = Logger.getLogger(OpenSearchUpdater.class.getName()); + private final String indexName; private final FlintClient flintClient; @@ -28,6 +31,7 @@ public void upsert(String id, String doc) { // also, failure to close the client causes the job to be stuck in the running 
state as the client resource // is not released. try (RestHighLevelClient client = flintClient.createClient()) { + assertIndexExist(client, indexName); UpdateRequest updateRequest = new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) @@ -44,6 +48,7 @@ public void upsert(String id, String doc) { public void update(String id, String doc) { try (RestHighLevelClient client = flintClient.createClient()) { + assertIndexExist(client, indexName); UpdateRequest updateRequest = new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) @@ -59,6 +64,7 @@ public void update(String id, String doc) { public void updateIf(String id, String doc, long seqNo, long primaryTerm) { try (RestHighLevelClient client = flintClient.createClient()) { + assertIndexExist(client, indexName); UpdateRequest updateRequest = new UpdateRequest(indexName, id).doc(doc, XContentType.JSON) @@ -73,4 +79,13 @@ public void updateIf(String id, String doc, long seqNo, long primaryTerm) { id), e); } } + + private void assertIndexExist(RestHighLevelClient client, String indexName) throws IOException { + LOG.info("Checking if index exists " + indexName); + if (!client.indices().exists(new GetIndexRequest(indexName), RequestOptions.DEFAULT)) { + String errorMsg = "Index not found " + indexName; + LOG.log(Level.SEVERE, errorMsg); + throw new IllegalStateException(errorMsg); + } + } } diff --git a/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java b/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java new file mode 100644 index 000000000..6f87276a8 --- /dev/null +++ b/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package apache.spark.metrics.sink; + +import org.apache.spark.SecurityManager; +import com.codahale.metrics.MetricRegistry; +import org.apache.spark.metrics.sink.CloudWatchSink; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mockito; + +import java.util.Properties; +import org.opensearch.flint.core.metrics.reporter.InvalidMetricsPropertyException; + +class CloudWatchSinkTests { + private final MetricRegistry metricRegistry = Mockito.mock(MetricRegistry.class); + private final SecurityManager securityManager = Mockito.mock(SecurityManager.class); + + @Test + void should_throwException_when_namespacePropertyIsNotSet() { + final Properties properties = getDefaultValidProperties(); + properties.remove("namespace"); + final Executable executable = () -> { + final CloudWatchSink + cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + @Test + void should_throwException_when_awsPropertyIsInvalid() { + final Properties properties = getDefaultValidProperties(); + properties.setProperty("awsRegion", "someInvalidRegion"); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + @Test + void should_throwException_when_pollingPeriodPropertyIsNotANumber() { + final Properties properties = getDefaultValidProperties(); + properties.setProperty("pollingPeriod", "notANumber"); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new 
CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + @Test + void should_throwException_when_pollingPeriodPropertyIsNegative() { + final Properties properties = getDefaultValidProperties(); + properties.setProperty("pollingPeriod", "-5"); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + @Test + void should_throwException_when_pollingTimeUnitPropertyIsInvalid() { + final Properties properties = getDefaultValidProperties(); + properties.setProperty("pollingTimeUnit", "notATimeUnitValue"); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + } + + private Properties getDefaultValidProperties() { + final Properties properties = new Properties(); + properties.setProperty("namespace", "namespaceValue"); + properties.setProperty("awsAccessKeyId", "awsAccessKeyIdValue"); + properties.setProperty("awsSecretKey", "awsSecretKeyValue"); + properties.setProperty("awsRegion", "us-east-1"); + properties.setProperty("pollingPeriod", "1"); + properties.setProperty("pollingTimeUnit", "MINUTES"); + return properties; + } +} diff --git a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java b/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java new file mode 100644 index 000000000..991fd78b4 --- /dev/null +++ b/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java @@ -0,0 +1,544 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package opensearch.flint.core.metrics.reporter; + +import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsyncClient; +import com.amazonaws.services.cloudwatch.model.Dimension; +import com.amazonaws.services.cloudwatch.model.MetricDatum; +import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest; +import com.amazonaws.services.cloudwatch.model.PutMetricDataResult; +import com.codahale.metrics.EWMA; +import com.codahale.metrics.ExponentialMovingAverages; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SlidingWindowReservoir; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter; + +import static com.amazonaws.services.cloudwatch.model.StandardUnit.Count; +import static com.amazonaws.services.cloudwatch.model.StandardUnit.Microseconds; +import static 
com.amazonaws.services.cloudwatch.model.StandardUnit.Milliseconds; +import static com.amazonaws.services.cloudwatch.model.StandardUnit.None; +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_COUNT; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_GAUGE; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_NAME_TYPE; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_MEAN; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_STD_DEV; +import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_SUMMARY; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class DimensionedCloudWatchReporterTest { + + private static final String NAMESPACE = "namespace"; + private static final String ARBITRARY_COUNTER_NAME = "TheCounter"; + private static final String ARBITRARY_METER_NAME = "TheMeter"; + private static final String ARBITRARY_HISTOGRAM_NAME = "TheHistogram"; + private static final String ARBITRARY_TIMER_NAME = "TheTimer"; + private static final String ARBITRARY_GAUGE_NAME = "TheGauge"; + + @Mock + private AmazonCloudWatchAsyncClient mockAmazonCloudWatchAsyncClient; + + @Mock + private Future mockPutMetricDataResultFuture; + + @Captor + private ArgumentCaptor metricDataRequestCaptor; + + private MetricRegistry metricRegistry; + private DimensionedCloudWatchReporter.Builder reporterBuilder; + + @BeforeAll + public static void beforeClass() throws Exception { + reduceExponentialMovingAveragesDefaultTickInterval(); + } + + @BeforeEach + public void setUp() throws Exception { + metricRegistry = new MetricRegistry(); + reporterBuilder = DimensionedCloudWatchReporter.forRegistry(metricRegistry, mockAmazonCloudWatchAsyncClient, NAMESPACE); + when(mockAmazonCloudWatchAsyncClient.putMetricDataAsync(metricDataRequestCaptor.capture())).thenReturn(mockPutMetricDataResultFuture); + } + + @Test + public void shouldNotInvokeCloudWatchClientInDryRunMode() { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.withDryRun().build().report(); + + verify(mockAmazonCloudWatchAsyncClient, never()).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void shouldReportWithoutGlobalDimensionsWhenGlobalDimensionsNotConfigured() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.build().report(); // When 'withGlobalDimensions' was not called + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).hasSize(1); + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT)); + } + + @Test + public void reportedCounterShouldContainExpectedDimension() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT)); + 
} + + @Test + public void reportedCounterShouldContainDimensionEmbeddedInName() throws Exception { + final String DIMENSION_NAME = "some_dimension"; + final String DIMENSION_VALUE = "some_value"; + + metricRegistry.counter(ARBITRARY_COUNTER_NAME + " " + DIMENSION_NAME + "=" + DIMENSION_VALUE).inc(); + reporterBuilder.withShouldParseDimensionsFromName(true).build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME).withValue(DIMENSION_VALUE)); + } + + @Test + public void reportedGaugeShouldContainExpectedDimension() throws Exception { + metricRegistry.register(ARBITRARY_GAUGE_NAME, (Gauge) () -> 1L); + reporterBuilder.build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_GAUGE)); + } + + @Test + public void shouldNotReportGaugeWhenMetricValueNotOfTypeNumber() throws Exception { + metricRegistry.register(ARBITRARY_GAUGE_NAME, (Gauge) () -> "bad value type"); + reporterBuilder.build().report(); + + verify(mockAmazonCloudWatchAsyncClient, never()).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void neverReportMetersCountersGaugesWithZeroValues() throws Exception { + metricRegistry.register(ARBITRARY_GAUGE_NAME, (Gauge) () -> 0L); + metricRegistry.meter(ARBITRARY_METER_NAME).mark(0); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(0); + + buildReportWithSleep(reporterBuilder + .withArithmeticMean() + .withOneMinuteMeanRate() + .withFiveMinuteMeanRate() + .withFifteenMinuteMeanRate() + .withMeanRate()); + + verify(mockAmazonCloudWatchAsyncClient, never()).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void reportMetersCountersGaugesWithZeroValuesOnlyWhenConfigured() throws Exception { + metricRegistry.register(ARBITRARY_GAUGE_NAME, (Gauge) () -> 0L); + metricRegistry.meter(ARBITRARY_METER_NAME).mark(0); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(0); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(-1L, TimeUnit.NANOSECONDS); + + buildReportWithSleep(reporterBuilder + .withArithmeticMean() + .withOneMinuteMeanRate() + .withFiveMinuteMeanRate() + .withFifteenMinuteMeanRate() + .withZeroValuesSubmission() + .withMeanRate()); + + verify(mockAmazonCloudWatchAsyncClient, times(1)).putMetricDataAsync(metricDataRequestCaptor.capture()); + + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final List metricData = putMetricDataRequest.getMetricData(); + for (final MetricDatum metricDatum : metricData) { + assertThat(metricDatum.getValue()).isEqualTo(0.0); + } + } + + @Test + public void reportedMeterShouldContainExpectedOneMinuteMeanRateDimension() throws Exception { + metricRegistry.meter(ARBITRARY_METER_NAME).mark(1); + buildReportWithSleep(reporterBuilder.withOneMinuteMeanRate()); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("1-min-mean-rate [per-second]")); + } + + @Test + public void reportedMeterShouldContainExpectedFiveMinuteMeanRateDimension() throws Exception { + metricRegistry.meter(ARBITRARY_METER_NAME).mark(1); + buildReportWithSleep(reporterBuilder.withFiveMinuteMeanRate()); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new 
Dimension().withName(DIMENSION_NAME_TYPE).withValue("5-min-mean-rate [per-second]")); + } + + @Test + public void reportedMeterShouldContainExpectedFifteenMinuteMeanRateDimension() throws Exception { + metricRegistry.meter(ARBITRARY_METER_NAME).mark(1); + buildReportWithSleep(reporterBuilder.withFifteenMinuteMeanRate()); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("15-min-mean-rate [per-second]")); + } + + @Test + public void reportedMeterShouldContainExpectedMeanRateDimension() throws Exception { + metricRegistry.meter(ARBITRARY_METER_NAME).mark(1); + reporterBuilder.withMeanRate().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("mean-rate [per-second]")); + } + + @Test + public void reportedHistogramShouldContainExpectedArithmeticMeanDimension() throws Exception { + metricRegistry.histogram(ARBITRARY_HISTOGRAM_NAME).update(1); + reporterBuilder.withArithmeticMean().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_SNAPSHOT_MEAN)); + } + + @Test + public void reportedHistogramShouldContainExpectedStdDevDimension() throws Exception { + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(1); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(2); + reporterBuilder.withStdDev().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_SNAPSHOT_STD_DEV)); + } + + @Test + public void reportedTimerShouldContainExpectedArithmeticMeanDimension() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(3, TimeUnit.MILLISECONDS); + reporterBuilder.withArithmeticMean().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("snapshot-mean [in-milliseconds]")); + } + + @Test + public void reportedTimerShouldContainExpectedStdDevDimension() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(1, TimeUnit.MILLISECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(3, TimeUnit.MILLISECONDS); + reporterBuilder.withStdDev().build().report(); + + final List dimensions = allDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue("snapshot-std-dev [in-milliseconds]")); + } + + @Test + public void shouldReportExpectedSingleGlobalDimension() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.withGlobalDimensions("Region=us-west-2").build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2")); + } + + @Test + public void shouldReportExpectedMultipleGlobalDimensions() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.withGlobalDimensions("Region=us-west-2", "Instance=stage").build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).contains(new 
Dimension().withName("Region").withValue("us-west-2")); + assertThat(dimensions).contains(new Dimension().withName("Instance").withValue("stage")); + } + + @Test + public void shouldNotReportDuplicateGlobalDimensions() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.withGlobalDimensions("Region=us-west-2", "Region=us-west-2").build().report(); + + final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); + + assertThat(dimensions).containsNoDuplicates(); + } + + @Test + public void shouldReportExpectedCounterValue() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + reporterBuilder.build().report(); + + final MetricDatum metricDatum = firstMetricDatumFromCapturedRequest(); + + assertThat(metricDatum.getValue()).isWithin(1.0); + assertThat(metricDatum.getUnit()).isEqualTo(Count.toString()); + } + + @Test + public void shouldNotReportUnchangedCounterValue() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + final DimensionedCloudWatchReporter dimensionedCloudWatchReporter = reporterBuilder.build(); + + dimensionedCloudWatchReporter.report(); + MetricDatum metricDatum = firstMetricDatumFromCapturedRequest(); + assertThat(metricDatum.getValue().intValue()).isEqualTo(1); + metricDataRequestCaptor.getAllValues().clear(); + + dimensionedCloudWatchReporter.report(); + + verify(mockAmazonCloudWatchAsyncClient, times(1)).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void shouldReportCounterValueDelta() throws Exception { + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + final DimensionedCloudWatchReporter dimensionedCloudWatchReporter = reporterBuilder.build(); + + dimensionedCloudWatchReporter.report(); + MetricDatum metricDatum = firstMetricDatumFromCapturedRequest(); + assertThat(metricDatum.getValue().intValue()).isEqualTo(2); + metricDataRequestCaptor.getAllValues().clear(); + + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + metricRegistry.counter(ARBITRARY_COUNTER_NAME).inc(); + + dimensionedCloudWatchReporter.report(); + metricDatum = firstMetricDatumFromCapturedRequest(); + assertThat(metricDatum.getValue().intValue()).isEqualTo(6); + + verify(mockAmazonCloudWatchAsyncClient, times(2)).putMetricDataAsync(any(PutMetricDataRequest.class)); + } + + @Test + public void shouldReportArithmeticMeanAfterConversionByDefaultDurationWhenReportingTimer() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(1_000_000, TimeUnit.NANOSECONDS); + reporterBuilder.withArithmeticMean().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest("snapshot-mean [in-milliseconds]"); + + assertThat(metricData.getValue().intValue()).isEqualTo(1); + assertThat(metricData.getUnit()).isEqualTo(Milliseconds.toString()); + } + + @Test + public void shouldReportStdDevAfterConversionByDefaultDurationWhenReportingTimer() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(1_000_000, TimeUnit.NANOSECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(2_000_000, TimeUnit.NANOSECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(3_000_000, TimeUnit.NANOSECONDS); + 
metricRegistry.timer(ARBITRARY_TIMER_NAME).update(30_000_000, TimeUnit.NANOSECONDS); + reporterBuilder.withStdDev().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest("snapshot-std-dev [in-milliseconds]"); + + assertThat(metricData.getValue().intValue()).isEqualTo(12); + assertThat(metricData.getUnit()).isEqualTo(Milliseconds.toString()); + } + + @Test + public void shouldReportSnapshotValuesAfterConversionByCustomDurationWhenReportingTimer() throws Exception { + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(1, TimeUnit.SECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(2, TimeUnit.SECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(3, TimeUnit.SECONDS); + metricRegistry.timer(ARBITRARY_TIMER_NAME).update(30, TimeUnit.SECONDS); + reporterBuilder.withStatisticSet().convertDurationsTo(TimeUnit.MICROSECONDS).build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_SUMMARY); + + assertThat(metricData.getStatisticValues().getSum().intValue()).isEqualTo(36_000_000); + assertThat(metricData.getStatisticValues().getMaximum().intValue()).isEqualTo(30_000_000); + assertThat(metricData.getStatisticValues().getMinimum().intValue()).isEqualTo(1_000_000); + assertThat(metricData.getStatisticValues().getSampleCount().intValue()).isEqualTo(4); + assertThat(metricData.getUnit()).isEqualTo(Microseconds.toString()); + } + + @Test + public void shouldReportArithmeticMeanWithoutConversionWhenReportingHistogram() throws Exception { + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(1); + reporterBuilder.withArithmeticMean().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_MEAN); + + assertThat(metricData.getValue().intValue()).isEqualTo(1); + assertThat(metricData.getUnit()).isEqualTo(None.toString()); + } + + @Test + public void shouldReportStdDevWithoutConversionWhenReportingHistogram() throws Exception { + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(1); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(2); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(3); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(30); + reporterBuilder.withStdDev().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_STD_DEV); + + assertThat(metricData.getValue().intValue()).isEqualTo(12); + assertThat(metricData.getUnit()).isEqualTo(None.toString()); + } + + @Test + public void shouldReportSnapshotValuesWithoutConversionWhenReportingHistogram() throws Exception { + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(1); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(2); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(3); + metricRegistry.histogram(DimensionedCloudWatchReporterTest.ARBITRARY_HISTOGRAM_NAME).update(30); + reporterBuilder.withStatisticSet().build().report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_SUMMARY); + + assertThat(metricData.getStatisticValues().getSum().intValue()).isEqualTo(36); + 
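// Samples 1, 2, 3 and 30 give sum 36, max 30, min 1 and a sample count of 4;
// histogram values carry no time unit, so they are reported without conversion (unit None).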
assertThat(metricData.getStatisticValues().getMaximum().intValue()).isEqualTo(30); + assertThat(metricData.getStatisticValues().getMinimum().intValue()).isEqualTo(1); + assertThat(metricData.getStatisticValues().getSampleCount().intValue()).isEqualTo(4); + assertThat(metricData.getUnit()).isEqualTo(None.toString()); + } + + @Test + public void shouldReportHistogramSubsequentSnapshotValues_SumMaxMinValues() throws Exception { + DimensionedCloudWatchReporter reporter = reporterBuilder.withStatisticSet().build(); + + final Histogram slidingWindowHistogram = new Histogram(new SlidingWindowReservoir(4)); + metricRegistry.register("SlidingWindowHistogram", slidingWindowHistogram); + + slidingWindowHistogram.update(1); + slidingWindowHistogram.update(2); + slidingWindowHistogram.update(30); + reporter.report(); + + final MetricDatum metricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_SUMMARY); + + assertThat(metricData.getStatisticValues().getMaximum().intValue()).isEqualTo(30); + assertThat(metricData.getStatisticValues().getMinimum().intValue()).isEqualTo(1); + assertThat(metricData.getStatisticValues().getSampleCount().intValue()).isEqualTo(3); + assertThat(metricData.getStatisticValues().getSum().intValue()).isEqualTo(33); + assertThat(metricData.getUnit()).isEqualTo(None.toString()); + + slidingWindowHistogram.update(4); + slidingWindowHistogram.update(100); + slidingWindowHistogram.update(5); + slidingWindowHistogram.update(6); + reporter.report(); + + final MetricDatum secondMetricData = metricDatumByDimensionFromCapturedRequest(DIMENSION_SNAPSHOT_SUMMARY); + + assertThat(secondMetricData.getStatisticValues().getMaximum().intValue()).isEqualTo(100); + assertThat(secondMetricData.getStatisticValues().getMinimum().intValue()).isEqualTo(4); + assertThat(secondMetricData.getStatisticValues().getSampleCount().intValue()).isEqualTo(4); + assertThat(secondMetricData.getStatisticValues().getSum().intValue()).isEqualTo(115); + assertThat(secondMetricData.getUnit()).isEqualTo(None.toString()); + + } + + private MetricDatum metricDatumByDimensionFromCapturedRequest(final String dimensionValue) { + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final List metricData = putMetricDataRequest.getMetricData(); + + final Optional metricDatumOptional = + metricData + .stream() + .filter(metricDatum -> metricDatum.getDimensions() + .contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(dimensionValue))) + .findFirst(); + + if (metricDatumOptional.isPresent()) { + return metricDatumOptional.get(); + } + + throw new IllegalStateException("Could not find MetricDatum for Dimension value: " + dimensionValue); + } + + private MetricDatum firstMetricDatumFromCapturedRequest() { + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + return putMetricDataRequest.getMetricData().get(0); + } + + private List firstMetricDatumDimensionsFromCapturedRequest() { + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final MetricDatum metricDatum = putMetricDataRequest.getMetricData().get(0); + return metricDatum.getDimensions(); + } + + private List allDimensionsFromCapturedRequest() { + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final List metricData = putMetricDataRequest.getMetricData(); + final List all = new LinkedList<>(); + for (final MetricDatum metricDatum : metricData) { + all.addAll(metricDatum.getDimensions()); + } + return all; 
+ } + + private void buildReportWithSleep(final DimensionedCloudWatchReporter.Builder dimensionedCloudWatchReporterBuilder) throws InterruptedException { + final DimensionedCloudWatchReporter cloudWatchReporter = dimensionedCloudWatchReporterBuilder.build(); + Thread.sleep(10); + cloudWatchReporter.report(); + } + + /** + * This is a very ugly way to fool the {@link EWMA} by reducing the default tick interval + * in {@link ExponentialMovingAverages} from {@code 5} seconds to {@code 1} millisecond in order to ensure that + * exponentially-weighted moving average rates are populated. This helps to verify that all + * the expected {@link Dimension}s are present in {@link MetricDatum}. + * + * @throws NoSuchFieldException + * @throws IllegalAccessException + * @see ExponentialMovingAverages#tickIfNecessary() + * @see MetricDatum#getDimensions() + */ + private static void reduceExponentialMovingAveragesDefaultTickInterval() throws NoSuchFieldException, IllegalAccessException { + setFinalStaticField(ExponentialMovingAverages.class, "TICK_INTERVAL", TimeUnit.MILLISECONDS.toNanos(1)); + } + + private static void setFinalStaticField(final Class clazz, final String fieldName, long value) throws NoSuchFieldException, IllegalAccessException { + final Field field = clazz.getDeclaredField(fieldName); + field.setAccessible(true); + final Field modifiers = field.getClass().getDeclaredField("modifiers"); + modifiers.setAccessible(true); + modifiers.setInt(field, field.getModifiers() & ~Modifier.FINAL); + field.set(null, value); + } + +} diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala index eba99b809..cf2cd2b6e 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala @@ -5,15 +5,33 @@ package org.apache.spark.sql +import java.util.concurrent.ScheduledExecutorService + import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.util.ShutdownHookManager +import org.apache.spark.util.{ShutdownHookManager, ThreadUtils} /** * Flint utility methods that rely on access to private code in Spark SQL package. */ package object flint { + /** + * Create daemon thread pool with the given thread group name and size. + * + * @param threadNamePrefix + * thread group name + * @param numThreads + * thread pool size + * @return + * thread pool executor + */ + def newDaemonThreadPoolScheduledExecutor( + threadNamePrefix: String, + numThreads: Int): ScheduledExecutorService = { + ThreadUtils.newDaemonThreadPoolScheduledExecutor(threadNamePrefix, numThreads) + } + /** * Add shutdown hook to SparkContext with default priority. 
* diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala index f9e8dd693..5af70b793 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala @@ -109,23 +109,50 @@ object FlintInstance { maybeError) } - def serialize(job: FlintInstance, currentTime: Long): String = { - // jobId is only readable by spark, thus we don't override jobId - Serialization.write( - Map( - "type" -> "session", - "sessionId" -> job.sessionId, - "error" -> job.error.getOrElse(""), - "applicationId" -> job.applicationId, - "state" -> job.state, - // update last update time - "lastUpdateTime" -> currentTime, - // Convert a Seq[String] into a comma-separated string, such as "id1,id2". - // This approach is chosen over serializing to an array format (e.g., ["id1", "id2"]) - // because it simplifies client-side processing. With a comma-separated string, - // clients can easily ignore this field if it's not in use, avoiding the need - // for array parsing logic. This makes the serialized data more straightforward to handle. - "excludeJobIds" -> job.excludedJobIds.mkString(","), - "jobStartTime" -> job.jobStartTime)) + /** + * After the initial setup, the 'jobId' is only readable by Spark, and it should not be + * overridden. We use 'jobId' to ensure that only one job can run per session. In the case of a + * new job for the same session, it will override the 'jobId' in the session document. The old + * job will periodically check the 'jobId.' If the read 'jobId' does not match the current + * 'jobId,' the old job will exit early. Therefore, it is crucial that old jobs do not overwrite + * the session store's 'jobId' field after the initial setup. + * + * @param job + * Flint session object + * @param currentTime + * current timestamp in milliseconds + * @param includeJobId + * flag indicating whether to include the "jobId" field in the serialization + * @return + * serialized Flint session + */ + def serialize(job: FlintInstance, currentTime: Long, includeJobId: Boolean = true): String = { + val baseMap = Map( + "type" -> "session", + "sessionId" -> job.sessionId, + "error" -> job.error.getOrElse(""), + "applicationId" -> job.applicationId, + "state" -> job.state, + // update last update time + "lastUpdateTime" -> currentTime, + // Convert a Seq[String] into a comma-separated string, such as "id1,id2". + // This approach is chosen over serializing to an array format (e.g., ["id1", "id2"]) + // because it simplifies client-side processing. With a comma-separated string, + // clients can easily ignore this field if it's not in use, avoiding the need + // for array parsing logic. This makes the serialized data more straightforward to handle. 
+ "excludeJobIds" -> job.excludedJobIds.mkString(","), + "jobStartTime" -> job.jobStartTime) + + val resultMap = if (includeJobId) { + baseMap + ("jobId" -> job.jobId) + } else { + baseMap + } + + Serialization.write(resultMap) + } + + def serializeWithoutJobId(job: FlintInstance, currentTime: Long): String = { + serialize(job, currentTime, includeJobId = false) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index e9331113a..47ade0f87 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -104,7 +104,7 @@ class FlintSpark(val spark: SparkSession) extends Logging { val metadata = index.metadata() try { flintClient - .startTransaction(indexName, dataSourceName) + .startTransaction(indexName, dataSourceName, true) .initialLog(latest => latest.state == EMPTY || latest.state == DELETED) .transientLog(latest => latest.copy(state = CREATING)) .finalLog(latest => latest.copy(state = ACTIVE)) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index 28e46cb29..5c4c7376c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import java.util.concurrent.{Executors, ScheduledExecutorService, ScheduledFuture, TimeUnit} +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} import scala.collection.concurrent.{Map, TrieMap} import scala.sys.addShutdownHook @@ -15,6 +15,7 @@ import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor /** * Flint Spark index state monitor. @@ -62,9 +63,8 @@ class FlintSparkIndexMonitor( logInfo("Index monitor task is cancelled") } } catch { - case e: Exception => + case e: Throwable => logError("Failed to update index log entry", e) - throw new IllegalStateException("Failed to update index log entry") } }, 15, // Delay to ensure final logging is complete first, otherwise version conflicts @@ -100,7 +100,8 @@ object FlintSparkIndexMonitor extends Logging { * Thread-safe ExecutorService globally shared by all FlintSpark instance and will be shutdown * in Spark application upon exit. Non-final variable for test convenience. */ - var executor: ScheduledExecutorService = Executors.newScheduledThreadPool(1) + var executor: ScheduledExecutorService = + newDaemonThreadPoolScheduledExecutor("flint-index-heartbeat", 1) /** * Tracker that stores task future handle which is required to cancel the task in future. 
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala index 12c2ae5bc..8ece6ba8a 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala @@ -39,7 +39,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { 1620000001000L, excludedJobIds) val currentTime = System.currentTimeMillis() - val json = FlintInstance.serialize(instance, currentTime) + val json = FlintInstance.serializeWithoutJobId(instance, currentTime) json should include(""""applicationId":"app-123"""") json should not include (""""jobId":"job-456"""") @@ -80,7 +80,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { Seq.empty[String], Some("Some error occurred")) val currentTime = System.currentTimeMillis() - val json = FlintInstance.serialize(instance, currentTime) + val json = FlintInstance.serializeWithoutJobId(instance, currentTime) json should include(""""error":"Some error occurred"""") } diff --git a/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala index 1e7077799..ba9acffd1 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala @@ -14,9 +14,10 @@ import org.opensearch.action.get.GetRequest import org.opensearch.action.index.IndexRequest import org.opensearch.action.update.UpdateRequest import org.opensearch.client.RequestOptions -import org.opensearch.client.indices.CreateIndexRequest +import org.opensearch.client.indices.{CreateIndexRequest, GetIndexRequest} import org.opensearch.common.xcontent.XContentType import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.{QUERY_EXECUTION_REQUEST_MAPPING, QUERY_EXECUTION_REQUEST_SETTINGS} import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.IndexState import org.opensearch.flint.core.storage.FlintOpenSearchClient._ import org.opensearch.flint.spark.FlintSparkSuite @@ -39,13 +40,15 @@ trait OpenSearchTransactionSuite extends FlintSparkSuite { super.beforeEach() openSearchClient .indices() - .create(new CreateIndexRequest(testMetaLogIndex), RequestOptions.DEFAULT) + .create( + new CreateIndexRequest(testMetaLogIndex) + .mapping(QUERY_EXECUTION_REQUEST_MAPPING, XContentType.JSON) + .settings(QUERY_EXECUTION_REQUEST_SETTINGS, XContentType.JSON), + RequestOptions.DEFAULT) } override def afterEach(): Unit = { - openSearchClient - .indices() - .delete(new DeleteIndexRequest(testMetaLogIndex), RequestOptions.DEFAULT) + deleteIndex(testMetaLogIndex) super.afterEach() } @@ -71,4 +74,21 @@ trait OpenSearchTransactionSuite extends FlintSparkSuite { .doc(latest.copy(state = newState).toJson, XContentType.JSON), RequestOptions.DEFAULT) } + + def deleteIndex(indexName: String): Unit = { + if (openSearchClient + .indices() + .exists(new GetIndexRequest(indexName), RequestOptions.DEFAULT)) { + openSearchClient + .indices() + .delete(new DeleteIndexRequest(indexName), RequestOptions.DEFAULT) + } + } + + def indexMapping(): String = { + val response = + openSearchClient.indices.get(new GetIndexRequest(testMetaLogIndex), RequestOptions.DEFAULT) + + 
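// Return the raw mapping source of the metadata log index so tests (such as the forceInit
// transaction test below) can assert on field types, e.g. that sessionId is mapped as keyword.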
response.getMappings.get(testMetaLogIndex).source().toString + } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala index 9a762d9d6..7da67051d 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala @@ -16,7 +16,6 @@ import org.opensearch.client.opensearch.OpenSearchClient import org.opensearch.client.transport.rest_client.RestClientTransport import org.opensearch.flint.OpenSearchSuite import org.opensearch.flint.core.metadata.FlintMetadata -import org.opensearch.flint.core.metadata.log.OptimisticTransaction.NoOptimisticTransaction import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchScrollReader} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -31,9 +30,10 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M behavior of "Flint OpenSearch client" - it should "start no optimistic transaction if metadata log index doesn't exists" in { - val transaction = flintClient.startTransaction("test", "non-exist-index") - transaction shouldBe a[NoOptimisticTransaction[AnyRef]] + it should "throw IllegalStateException if metadata log index doesn't exist" in { + the[IllegalStateException] thrownBy { + flintClient.startTransaction("test", "non-exist-index") + } } it should "create index successfully" in { diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala index a8b5a1fa2..fa072898b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala @@ -9,6 +9,8 @@ import java.util.Base64 import scala.collection.JavaConverters.mapAsJavaMapConverter +import org.json4s.{Formats, NoTypeHints} +import org.json4s.native.{JsonMethods, Serialization} import org.opensearch.flint.OpenSearchTransactionSuite import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._ @@ -214,7 +216,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { latestLogEntry(testLatestId) should contain("state" -> "active") } - test("should not necessarily rollback if transaction operation failed but no transient action") { + test( + "should not necessarily rollback if transaction operation failed but no transient action") { // Use create index scenario in this test case the[IllegalStateException] thrownBy { flintClient @@ -227,4 +230,66 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { // Should rollback to initial empty log latestLogEntry(testLatestId) should contain("state" -> "empty") } + + test("forceInit transaction log, even if index is deleted before startTransaction") { + deleteIndex(testMetaLogIndex) + flintClient + .startTransaction(testFlintIndex, testDataSourceName, true) + .initialLog(latest => { + latest.id shouldBe testLatestId + latest.state shouldBe EMPTY + latest.createTime shouldBe 0L + latest.dataSource shouldBe testDataSourceName + latest.error shouldBe "" + true + }) + .finalLog(latest => latest) + .commit(_ => {}) + + implicit val formats: Formats =
Serialization.formats(NoTypeHints) + (JsonMethods.parse(indexMapping()) \ "properties" \ "sessionId" \ "type") + .extract[String] should equal("keyword") + } + + test("should fail if index is deleted before initial operation") { + the[IllegalStateException] thrownBy { + flintClient + .startTransaction(testFlintIndex, testDataSourceName) + .initialLog(latest => { + deleteIndex(testMetaLogIndex) + true + }) + .transientLog(latest => latest.copy(state = CREATING)) + .finalLog(latest => latest.copy(state = ACTIVE)) + .commit(_ => {}) + } + } + + test("should fail if index is deleted before transient operation") { + the[IllegalStateException] thrownBy { + flintClient + .startTransaction(testFlintIndex, testDataSourceName) + .initialLog(latest => true) + .transientLog(latest => { + deleteIndex(testMetaLogIndex) + latest.copy(state = CREATING) + }) + .finalLog(latest => latest.copy(state = ACTIVE)) + .commit(_ => {}) + } + } + + test("should fail if index is deleted before final operation") { + the[IllegalStateException] thrownBy { + flintClient + .startTransaction(testFlintIndex, testDataSourceName) + .initialLog(latest => true) + .transientLog(latest => { latest.copy(state = CREATING) }) + .finalLog(latest => { + deleteIndex(testMetaLogIndex) + latest.copy(state = ACTIVE) + }) + .commit(_ => {}) + } + } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala new file mode 100644 index 000000000..3b317a0fe --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.opensearch.action.get.{GetRequest, GetResponse} +import org.opensearch.client.RequestOptions +import org.opensearch.flint.OpenSearchTransactionSuite +import org.opensearch.flint.app.FlintInstance +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater} +import org.scalatest.matchers.should.Matchers + +class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { + val sessionId = "sessionId" + val timestamp = 1700090926955L + val flintJob = + new FlintInstance( + "applicationId", + "jobId", + sessionId, + "running", + timestamp, + timestamp, + Seq("")) + var flintClient: FlintClient = _ + var updater: OpenSearchUpdater = _ + + override def beforeAll(): Unit = { + super.beforeAll() + flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)); + updater = new OpenSearchUpdater( + testMetaLogIndex, + new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava))) + } + + test("upsert flintJob should success") { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + getFlintInstance(sessionId)._2.lastUpdateTime shouldBe timestamp + } + + test("index is deleted when upsert flintJob should throw IllegalStateException") { + deleteIndex(testMetaLogIndex) + + the[IllegalStateException] thrownBy { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + } + } + + test("update flintJob should success") { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + + val newTimestamp = 1700090926956L + updater.update(sessionId, FlintInstance.serialize(flintJob, newTimestamp)) + getFlintInstance(sessionId)._2.lastUpdateTime shouldBe 
newTimestamp + } + + test("index is deleted when update flintJob should throw IllegalStateException") { + deleteIndex(testMetaLogIndex) + + the[IllegalStateException] thrownBy { + updater.update(sessionId, FlintInstance.serialize(flintJob, timestamp)) + } + } + + test("updateIf flintJob should success") { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + val (resp, latest) = getFlintInstance(sessionId) + + val newTimestamp = 1700090926956L + updater.updateIf( + sessionId, + FlintInstance.serialize(latest, newTimestamp), + resp.getSeqNo, + resp.getPrimaryTerm) + getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp + } + + test("index is deleted when updateIf flintJob should throw IllegalStateException") { + updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + val (resp, latest) = getFlintInstance(sessionId) + + deleteIndex(testMetaLogIndex) + + the[IllegalStateException] thrownBy { + updater.updateIf( + sessionId, + FlintInstance.serialize(latest, timestamp), + resp.getSeqNo, + resp.getPrimaryTerm) + } + } + + def getFlintInstance(docId: String): (GetResponse, FlintInstance) = { + val response = + openSearchClient.get(new GetRequest(testMetaLogIndex, docId), RequestOptions.DEFAULT) + (response, FlintInstance.deserializeFromMap(response.getSourceAsMap)) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala index 365aab83d..8df2bc472 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala @@ -29,8 +29,18 @@ class FlintSparkIndexJobITSuite extends OpenSearchTransactionSuite with Matchers } override def afterEach(): Unit = { - super.afterEach() // must clean up metadata log first and then delete - flint.deleteIndex(testIndex) + + /** + * Todo, if state is not valid, will throw IllegalStateException. Should check flint + * .isRefresh before cleanup resource. Current solution, (1) try to delete flint index, (2) if + * failed, delete index itself. 
+ */ + try { + flint.deleteIndex(testIndex) + } catch { + case _: IllegalStateException => deleteIndex(testIndex) + } + super.afterEach() } test("recover should exit if index doesn't exist") { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala new file mode 100644 index 000000000..4af147939 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import java.util.Base64 +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{doAnswer, spy} +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest +import org.opensearch.client.RequestOptions +import org.opensearch.flint.OpenSearchTransactionSuite +import org.opensearch.flint.spark.FlintSpark.RefreshMode._ +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor + +class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_index_monitor_test" + private val testFlintIndex = getSkippingIndexName(testTable) + private val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) + + override def beforeAll(): Unit = { + super.beforeAll() + createPartitionedTable(testTable) + + // Replace mock executor with real one and change its delay + val realExecutor = newDaemonThreadPoolScheduledExecutor("flint-index-heartbeat", 1) + FlintSparkIndexMonitor.executor = spy(realExecutor) + doAnswer(invocation => { + // Delay 5 seconds to wait for refresh index done + realExecutor.scheduleWithFixedDelay(invocation.getArgument(0), 5, 1, TimeUnit.SECONDS) + }).when(FlintSparkIndexMonitor.executor) + .scheduleWithFixedDelay(any[Runnable], any[Long], any[Long], any[TimeUnit]) + } + + override def beforeEach(): Unit = { + super.beforeEach() + flint + .skippingIndex() + .onTable(testTable) + .addValueSet("name") + .options(FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + .create() + flint.refreshIndex(testFlintIndex, INCREMENTAL) + + // Wait for refresh complete and another 5 seconds to make sure monitor thread start + val jobId = spark.streams.active.find(_.name == testFlintIndex).get.id.toString + awaitStreamingComplete(jobId) + Thread.sleep(5000L) + } + + override def afterEach(): Unit = { + // Cancel task to avoid conflict with delete operation since it runs frequently + FlintSparkIndexMonitor.indexMonitorTracker.values.foreach(_.cancel(true)) + FlintSparkIndexMonitor.indexMonitorTracker.clear() + + try { + flint.deleteIndex(testFlintIndex) + } catch { + // Index maybe end up with failed state in some test + case _: IllegalStateException => + openSearchClient + .indices() + .delete(new DeleteIndexRequest(testFlintIndex), RequestOptions.DEFAULT) + } finally { + super.afterEach() + } + } + + test("job start time should not change and last update time should keep updated") { + var (prevJobStartTime, prevLastUpdateTime) = 
getLatestTimestamp + 3 times { (jobStartTime, lastUpdateTime) => + jobStartTime shouldBe prevJobStartTime + lastUpdateTime should be > prevLastUpdateTime + prevLastUpdateTime = lastUpdateTime + } + } + + test("job start time should not change until recover index") { + val (prevJobStartTime, _) = getLatestTimestamp + + // Stop streaming job and wait for monitor task stopped + spark.streams.active.find(_.name == testFlintIndex).get.stop() + waitForMonitorTaskRun() + + // Restart streaming job and monitor task + flint.recoverIndex(testFlintIndex) + waitForMonitorTaskRun() + + val (jobStartTime, _) = getLatestTimestamp + jobStartTime should be > prevJobStartTime + } + + test("monitor task should terminate if streaming job inactive") { + val task = FlintSparkIndexMonitor.indexMonitorTracker(testFlintIndex) + + // Stop streaming job and wait for monitor task stopped + spark.streams.active.find(_.name == testFlintIndex).get.stop() + waitForMonitorTaskRun() + + // Index state transit to failed and task is cancelled + latestLogEntry(testLatestId) should contain("state" -> "failed") + task.isCancelled shouldBe true + } + + test("monitor task should not terminate if any exception") { + // Block write on metadata log index + setWriteBlockOnMetadataLogIndex(true) + waitForMonitorTaskRun() + + // Monitor task should stop working after blocking writes + var (_, prevLastUpdateTime) = getLatestTimestamp + 1 times { (_, lastUpdateTime) => + lastUpdateTime shouldBe prevLastUpdateTime + } + + // Unblock write and wait for monitor task attempt to update again + setWriteBlockOnMetadataLogIndex(false) + waitForMonitorTaskRun() + + // Monitor task continue working after unblocking write + 3 times { (_, lastUpdateTime) => + lastUpdateTime should be > prevLastUpdateTime + prevLastUpdateTime = lastUpdateTime + } + } + + private def getLatestTimestamp: (Long, Long) = { + val latest = latestLogEntry(testLatestId) + (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long]) + } + + private implicit class intWithTimes(n: Int) { + def times(f: (Long, Long) => Unit): Unit = { + 1 to n foreach { _ => + { + waitForMonitorTaskRun() + + val (jobStartTime, lastUpdateTime) = getLatestTimestamp + f(jobStartTime, lastUpdateTime) + } + } + } + } + + private def waitForMonitorTaskRun(): Unit = { + // Interval longer than monitor schedule to make sure it has finished another run + Thread.sleep(3000L) + } + + private def setWriteBlockOnMetadataLogIndex(isBlock: Boolean): Unit = { + val request = new UpdateSettingsRequest(testMetaLogIndex) + .settings(Map("blocks.write" -> isBlock).asJava) // Blocking write operations + openSearchClient.indices().putSettings(request, RequestOptions.DEFAULT) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala index 294449a48..56227533a 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala @@ -33,8 +33,18 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match } override def afterEach(): Unit = { + + /** + * Todo, if state is not valid, will throw IllegalStateException. Should check flint + * .isRefresh before cleanup resource. Current solution, (1) try to delete flint index, (2) if + * failed, delete index itself. 
+     */
+    try {
+      flint.deleteIndex(testFlintIndex)
+    } catch {
+      case _: IllegalStateException => deleteIndex(testFlintIndex)
+    }
     super.afterEach()
-    flint.deleteIndex(testFlintIndex)
   }
 
   test("create index") {
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 0fe5dd1ab..38550667b 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -8,3 +8,5 @@ addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
 addSbtPlugin("com.lightbend.sbt" % "sbt-java-formatter" % "0.8.0")
 addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.1.0")
 addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.3")
+addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.10.0-RC1")
+addSbtPlugin("net.aichler" % "sbt-jupiter-interface" % "0.11.1")
\ No newline at end of file
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
index cb1f5c1ca..750e228ef 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -8,9 +8,6 @@ package org.apache.spark.sql
 
 import java.util.Locale
 
-import scala.concurrent.{ExecutionContext, Future, TimeoutException}
-import scala.concurrent.duration.{Duration, MINUTES}
-
 import org.opensearch.client.{RequestOptions, RestHighLevelClient}
 import org.opensearch.cluster.metadata.MappingMetadata
 import org.opensearch.common.settings.Settings
@@ -22,9 +19,7 @@ import play.api.libs.json._
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.types.{StructField, _}
-import org.apache.spark.util.ThreadUtils
 
 /**
  * Spark SQL Application entrypoint
@@ -48,51 +43,18 @@ object FlintJob extends Logging with FlintJobExecutor {
     val conf = createSparkConf()
     val wait = conf.get("spark.flint.job.type", "continue")
     val dataSource = conf.get("spark.flint.datasource.name", "")
-    val spark = createSparkSession(conf)
-
-    val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
-    implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
-
-    var dataToWrite: Option[DataFrame] = None
-    val startTime = System.currentTimeMillis()
-    // osClient needs spark session to be created first to get FlintOptions initialized.
-    // Otherwise, we will have connection exception from EMR-S to OS.
-    val osClient = new OSClient(FlintSparkConf().flintOptions())
-    var exceptionThrown = true
-    try {
-      val futureMappingCheck = Future {
-        checkAndCreateIndex(osClient, resultIndex)
-      }
-      val data = executeQuery(spark, query, dataSource, "", "")
-
-      val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
-      dataToWrite = Some(mappingCheckResult match {
-        case Right(_) => data
-        case Left(error) =>
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
-      })
-      exceptionThrown = false
-    } catch {
-      case e: TimeoutException =>
-        val error = s"Getting the mapping of index $resultIndex timed out"
-        logError(error, e)
-        dataToWrite = Some(
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
-      case e: Exception =>
-        val error = processQueryException(e, spark, dataSource, query, "", "")
-        dataToWrite = Some(
-          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
-    } finally {
-      dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
-      // Stop SparkSession if streaming job succeeds
-      if (!exceptionThrown && wait.equalsIgnoreCase("streaming")) {
-        // wait if any child thread to finish before the main thread terminates
-        spark.streams.awaitAnyTermination()
-      } else {
-        spark.stop()
-      }
-
-      threadPool.shutdown()
-    }
+    // https://github.com/opensearch-project/opensearch-spark/issues/138
+    /*
+     * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
+     * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
+     * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
+     * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
+     * Without this setup, Spark would not recognize names in the format `my_glue1.default`.
+     */
+    conf.set("spark.sql.defaultCatalog", dataSource)
+
+    val jobOperator =
+      JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming"))
+    jobOperator.start()
   }
 }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index 6e7dbb926..903bcaa09 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -6,6 +6,10 @@ package org.apache.spark.sql
 
 import java.util.Locale
+import java.util.concurrent.ThreadPoolExecutor
+
+import scala.concurrent.{ExecutionContext, Future, TimeoutException}
+import scala.concurrent.duration.{Duration, MINUTES}
 
 import com.amazonaws.services.s3.model.AmazonS3Exception
 import org.opensearch.flint.core.FlintClient
@@ -14,11 +18,13 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue}
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintJob.{createIndex, getFormattedData, isSuperset, logError, logInfo}
+import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
 import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider}
+import org.apache.spark.util.ThreadUtils
 
 trait FlintJobExecutor {
   this: Logging =>
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index edf7d62e6..674c0a75f 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -6,14 +6,10 @@ package org.apache.spark.sql
 
 import java.net.ConnectException
-import java.time.Instant
-import java.util.Map
 import java.util.concurrent.ScheduledExecutorService
 
-import scala.collection.JavaConverters._
 import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException}
-import scala.concurrent.duration._
-import scala.concurrent.duration.{Duration, MINUTES}
+import scala.concurrent.duration.{Duration, MINUTES, _}
 import scala.util.{Failure, Success, Try}
 import scala.util.control.NonFatal
@@ -44,10 +40,11 @@ import org.apache.spark.util.ThreadUtils
 object FlintREPL extends Logging with FlintJobExecutor {
   private val HEARTBEAT_INTERVAL_MILLIS = 60000L
-  private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 30 * 60 * 1000
+  private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000
   private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES)
-  private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(10, MINUTES)
+  private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES)
   private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000
+  val INITIAL_DELAY_MILLIS = 3000L
 
   def update(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = {
     updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand))
   }
@@ -63,30 +60,38 @@ object FlintREPL extends Logging with FlintJobExecutor {
     val conf: SparkConf = createSparkConf()
     val dataSource = conf.get("spark.flint.datasource.name", "unknown")
     // https://github.com/opensearch-project/opensearch-spark/issues/138
+    /*
+     * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
+     * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
+     * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
+     * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
+     * Without this setup, Spark would not recognize names in the format `my_glue1.default`.
+     */
     conf.set("spark.sql.defaultCatalog", dataSource)
     val wait = conf.get("spark.flint.job.type", "continue")
-    // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found.
-    val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null))
-    val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null))
-
-    if (sessionIndex.isEmpty) {
-      throw new IllegalArgumentException("spark.flint.job.requestIndex is not set")
-    }
-    if (sessionId.isEmpty) {
-      throw new IllegalArgumentException("spark.flint.job.sessionId is not set")
-    }
-
-    val spark = createSparkSession(conf)
-    val osClient = new OSClient(FlintSparkConf().flintOptions())
-    val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown")
-    val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
 
     if (wait.equalsIgnoreCase("streaming")) {
       logInfo(s"""streaming query ${query}""")
-      val result = executeQuery(spark, query, dataSource, "", "")
-      writeDataFrameToOpensearch(result, resultIndex, osClient)
-      spark.streams.awaitAnyTermination()
+      val jobOperator =
+        JobOperator(conf, query, dataSource, resultIndex, true)
+      jobOperator.start()
     } else {
+      // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found.
+      val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null))
+      val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null))
+
+      if (sessionIndex.isEmpty) {
+        throw new IllegalArgumentException("spark.flint.job.requestIndex is not set")
+      }
+      if (sessionId.isEmpty) {
+        throw new IllegalArgumentException("spark.flint.job.sessionId is not set")
+      }
+
+      val spark = createSparkSession(conf)
+      val osClient = new OSClient(FlintSparkConf().flintOptions())
+      val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown")
+      val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
+
       // Read the values from the Spark configuration or fall back to the default values
       val inactivityLimitMillis: Long =
         conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS)
@@ -99,7 +104,7 @@
         conf.getLong("spark.flint.job.queryWaitTimeoutMillis", DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS)
 
       val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get)
-      createShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
+      addShutdownHook(flintSessionIndexUpdater, osClient, sessionIndex.get, sessionId.get)
       // 1 thread for updating heart beat
       val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl-heartbeat", 1)
       val jobStartTime = currentTimeProvider.currentEpochMillis()
@@ -113,7 +118,8 @@
         sessionId.get,
         threadPool,
         osClient,
-        sessionIndex.get)
+        sessionIndex.get,
+        INITIAL_DELAY_MILLIS)
 
       if (setupFlintJobWithExclusionCheck(
           conf,
@@ -267,7 +273,8 @@
     var canPickUpNextStatement = true
     while (currentTimeProvider
         .currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) {
-      logDebug(s"""read from ${commandContext.sessionIndex}""")
+      logInfo(
+        s"""read from ${commandContext.sessionIndex}, sessionId: ${commandContext.sessionId}""")
       val flintReader: FlintReader =
         createQueryReader(
           commandContext.osClient,
@@ -314,19 +321,26 @@
       sessionIndex: String,
       jobStartTime: Long,
       excludeJobIds: Seq[String] = Seq.empty[String]): Unit = {
-    val flintJob =
-      new FlintInstance(
-        applicationId,
-        jobId,
-        sessionId,
-        "running",
-        currentTimeProvider.currentEpochMillis(),
-        jobStartTime,
-        excludeJobIds)
-    flintSessionIndexUpdater.upsert(
+    val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId)
+    val currentTime = currentTimeProvider.currentEpochMillis()
+    val flintJob = new FlintInstance(
+      applicationId,
+      jobId,
       sessionId,
-      FlintInstance.serialize(flintJob, currentTimeProvider.currentEpochMillis()))
-    logDebug(
+      "running",
+      currentTime,
+      jobStartTime,
+      excludeJobIds)
+
+    val serializedFlintInstance = if (includeJobId) {
+      FlintInstance.serialize(flintJob, currentTime, true)
+    } else {
+      FlintInstance.serializeWithoutJobId(flintJob, currentTime)
+    }
+
+    flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance)
+
+    logInfo(
       s"""Updated job: {"jobid": ${flintJob.jobId}, "sessionId": ${flintJob.sessionId}} from $sessionIndex""")
   }
 
@@ -383,7 +397,7 @@
     val currentTime = currentTimeProvider.currentEpochMillis()
     flintSessionIndexUpdater.upsert(
       sessionId,
-      FlintInstance.serialize(flintInstance, currentTime))
+      FlintInstance.serializeWithoutJobId(flintInstance, currentTime))
   }
 
   /**
@@ -511,6 +525,9 @@
       osClient: OSClient): Unit = {
     try {
       dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
+      // TODO: part of the migration plan for https://github.com/opensearch-project/sql/issues/2436.
+      // Remove this sleep once the issue is fixed in the plugin.
+      Thread.sleep(2000)
       if (flintCommand.isRunning() || flintCommand.isWaiting()) {
         // we have set failed state in exception handling
         flintCommand.complete()
@@ -678,7 +695,7 @@
         queryWaitTimeMillis)
     }
 
-    logDebug(s"command complete: $flintCommand")
+    logInfo(s"command complete: $flintCommand")
     (dataToWrite, verificationResult)
   }
 
@@ -767,7 +784,7 @@
     flintReader
   }
 
-  def createShutdownHook(
+  def addShutdownHook(
      flintSessionIndexUpdater: OpenSearchUpdater,
      osClient: OSClient,
      sessionIndex: String,
@@ -808,7 +825,9 @@
 
       flintSessionIndexUpdater.updateIf(
         sessionId,
-        FlintInstance.serialize(flintInstance, currentTimeProvider.currentEpochMillis()),
+        FlintInstance.serializeWithoutJobId(
+          flintInstance,
+          currentTimeProvider.currentEpochMillis()),
         getResponse.getSeqNo,
         getResponse.getPrimaryTerm)
     }
@@ -825,6 +844,8 @@
    *   the thread pool.
    * @param osClient
    *   the OpenSearch client.
+   * @param initialDelayMillis
+   *   the initial delay to start the heartbeat
    */
   def createHeartBeatUpdater(
       currentInterval: Long,
@@ -832,7 +853,8 @@
       sessionId: String,
       threadPool: ScheduledExecutorService,
       osClient: OSClient,
-      sessionIndex: String): Unit = {
+      sessionIndex: String,
+      initialDelayMillis: Long): Unit = {
 
     threadPool.scheduleAtFixedRate(
       new Runnable {
@@ -845,7 +867,9 @@
            flintInstance.state = "running"
            flintSessionUpdater.updateIf(
              sessionId,
-             FlintInstance.serialize(flintInstance, currentTimeProvider.currentEpochMillis()),
+             FlintInstance.serializeWithoutJobId(
+               flintInstance,
+               currentTimeProvider.currentEpochMillis()),
              getResponse.getSeqNo,
              getResponse.getPrimaryTerm)
          }
@@ -859,7 +883,7 @@
          }
        }
      },
-      0L,
+      initialDelayMillis,
      currentInterval,
      java.util.concurrent.TimeUnit.MILLISECONDS)
   }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
new file mode 100644
index 000000000..c60d250ea
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -0,0 +1,113 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.ThreadPoolExecutor
+
+import scala.concurrent.{ExecutionContext, Future, TimeoutException}
+import scala.concurrent.duration.{Duration, MINUTES}
+import scala.util.{Failure, Success, Try}
+
+import org.opensearch.flint.core.storage.OpenSearchUpdater
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.FlintJob.createSparkSession
+import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, updateFlintInstanceBeforeShutdown}
+import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.util.ThreadUtils
+
+case class JobOperator(
+    sparkConf: SparkConf,
+    query: String,
+    dataSource: String,
+    resultIndex: String,
+    streaming: Boolean)
+    extends Logging
+    with FlintJobExecutor {
+  private val spark = createSparkSession(sparkConf)
+
+  // jvm shutdown hook
+  sys.addShutdownHook(stop())
+
+  def start(): Unit = {
+    val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index")
+    implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+
+    var dataToWrite: Option[DataFrame] = None
+    val startTime = System.currentTimeMillis()
+    // osClient needs spark session to be created first to get FlintOptions initialized.
+    // Otherwise, we will have connection exception from EMR-S to OS.
+    val osClient = new OSClient(FlintSparkConf().flintOptions())
+    var exceptionThrown = true
+    try {
+      val futureMappingCheck = Future {
+        checkAndCreateIndex(osClient, resultIndex)
+      }
+      val data = executeQuery(spark, query, dataSource, "", "")
+
+      val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
+      dataToWrite = Some(mappingCheckResult match {
+        case Right(_) => data
+        case Left(error) =>
+          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)
+      })
+      exceptionThrown = false
+    } catch {
+      case e: TimeoutException =>
+        val error = s"Getting the mapping of index $resultIndex timed out"
+        logError(error, e)
+        dataToWrite = Some(
+          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
+      case e: Exception =>
+        val error = processQueryException(e, spark, dataSource, query, "", "")
+        dataToWrite = Some(
+          getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider))
+    } finally {
+      cleanUpResources(exceptionThrown, threadPool, dataToWrite, resultIndex, osClient)
+    }
+  }
+
+  def cleanUpResources(
+      exceptionThrown: Boolean,
+      threadPool: ThreadPoolExecutor,
+      dataToWrite: Option[DataFrame],
+      resultIndex: String,
+      osClient: OSClient): Unit = {
+    try {
+      dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient))
+    } catch {
+      case e: Exception => logError("failed to write to result index", e)
+    }
+
+    try {
+      // Stop SparkSession if streaming job succeeds
+      if (!exceptionThrown && streaming) {
+        // wait for any child threads to finish before the main thread terminates
+        spark.streams.awaitAnyTermination()
+      }
+    } catch {
+      case e: Exception => logError("streaming job failed", e)
+    }
+
+    try {
+      threadPool.shutdown()
+      logInfo("shut down thread pool")
+    } catch {
+      case e: Exception => logError("failed to close thread pool", e)
+    }
+  }
+
+  def stop(): Unit = {
+    Try {
+      spark.stop()
+      logInfo("stopped spark session")
+    } match {
+      case Success(_) =>
+      case Failure(e) => logError("unexpected error while stopping spark session", e)
+    }
+  }
+}
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
index 704045e8a..7b9fcc140 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala
@@ -81,7 +81,8 @@ class FlintREPLTest
       "session1",
       threadPool,
       osClient,
-      "sessionIndex")
+      "sessionIndex",
+      0)
 
     // Verifications
     verify(osClient, atLeastOnce()).getDoc("sessionIndex", "session1")
@@ -117,7 +118,7 @@
     }
 
     // Here, we're injecting our mockShutdownHookManager into the method
-    FlintREPL.createShutdownHook(
+    FlintREPL.addShutdownHook(
      flintSessionIndexUpdater,
      osClient,
      sessionIndex,