+ private static final Map<String, Function<String[], Dimension>> dimensionBuilders = Map.of(
+ DIMENSION_INSTANCE_ROLE, DimensionUtils::getInstanceRoleDimension,
+ DIMENSION_JOB_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_JOB_ID", DIMENSION_JOB_ID),
+ DIMENSION_APPLICATION_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", DIMENSION_APPLICATION_ID),
+ DIMENSION_APPLICATION_NAME, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_APPLICATION_NAME", DIMENSION_APPLICATION_NAME),
+ DIMENSION_DOMAIN_ID, ignored -> getEnvironmentVariableDimension("FLINT_CLUSTER_NAME", DIMENSION_DOMAIN_ID)
+ );
+
+ /**
+ * Constructs a CloudWatch Dimension object based on the provided dimension name. If a specific
+ * builder exists for the dimension name, it is used; otherwise, a default dimension is constructed.
+ *
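+ * For example (illustrative): given metric name parts {"unknown", "1", "TestSource", "usedMemory"},
+ * constructDimension("instanceRole", parts) returns a Dimension named "instanceRole" with value
+ * "executor", while an unregistered name such as "appName" falls back to getDefaultDimension and
+ * reads the "appName" environment variable, defaulting to UNKNOWN when unset.
+ *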
+ * @param dimensionName The name of the dimension to construct.
+ * @param metricNameParts Additional information that might be required by specific dimension builders.
+ * @return A CloudWatch Dimension object.
+ */
+ public static Dimension constructDimension(String dimensionName, String[] metricNameParts) {
+ if (!doesNameConsistsOfMetricNameSpace(metricNameParts)) {
+ throw new IllegalArgumentException("The provided metric name parts do not consist of a valid metric namespace.");
+ }
+ return dimensionBuilders.getOrDefault(dimensionName, ignored -> getDefaultDimension(dimensionName))
+ .apply(metricNameParts);
+ }
+
+ // This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137
+ // Since we don't have access to the Spark configuration here, we rely on the presence of the executorId as part of the metric name.
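+ // For example, "unknown.driver.LiveListenerBus.numEventsPosted" and "unknown.1.NettyBlockTransfer.usedDirectMemory"
+ // both carry the namespace prefix ("driver" or a numeric executor ID in the second position), whereas a bare "myCounter" does not.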
+ public static boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) {
+ return metricNameParts.length >= 3
+ && (metricNameParts[1].equals("driver") || StringUtils.isNumeric(metricNameParts[1]));
+ }
+
+ /**
+ * Generates a Dimension object representing the instance role (either executor or driver) based on the
+ * metric name parts provided.
+ *
+ * @param parts Metric name parts; the second element is either a numeric executor ID (reported as "executor") or the role name itself, such as "driver".
+ * @return A Dimension object with the instance role.
+ */
+ private static Dimension getInstanceRoleDimension(String[] parts) {
+ String value = StringUtils.isNumeric(parts[1]) ? "executor" : parts[1];
+ return new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue(value);
+ }
+
+ /**
+ * Constructs a Dimension object using a system environment variable. If the environment variable is not found,
+ * it uses a predefined "UNKNOWN" value.
+ *
+ * @param envVarName The name of the environment variable to use for the dimension's value.
+ * @param dimensionName The name of the dimension.
+ * @return A Dimension object populated with the appropriate name and value.
+ */
+ private static Dimension getEnvironmentVariableDimension(String envVarName, String dimensionName) {
+ String value = System.getenv().getOrDefault(envVarName, UNKNOWN);
+ return new Dimension().withName(dimensionName).withValue(value);
+ }
+
+ /**
+ * Provides a generic mechanism to construct a Dimension object with an environment variable value
+ * or a default value if the environment variable is not set.
+ *
+ * @param dimensionName The name of the dimension for which to retrieve the value.
+ * @return A Dimension object populated with the dimension name and its corresponding value.
+ */
+ private static Dimension getDefaultDimension(String dimensionName) {
+ return getEnvironmentVariableDimension(dimensionName, dimensionName);
+ }
+}
\ No newline at end of file
diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java
index a47fa70ce..e16eb0021 100644
--- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java
+++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java
@@ -35,7 +35,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
-import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@@ -48,9 +47,12 @@
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.lang.StringUtils;
+import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import static org.opensearch.flint.core.metrics.reporter.DimensionUtils.constructDimension;
+
/**
* Reports metrics to Amazon's CloudWatch periodically.
*
@@ -84,16 +86,6 @@ public class DimensionedCloudWatchReporter extends ScheduledReporter {
// Visible for testing
public static final String DIMENSION_SNAPSHOT_STD_DEV = "snapshot-std-dev";
- public static final String DIMENSION_JOB_ID = "jobId";
-
- public static final String DIMENSION_APPLICATION_ID = "applicationId";
-
- public static final String DIMENSION_DOMAIN_ID = "domainId";
-
- public static final String DIMENSION_INSTANCE_ROLE = "instanceRole";
-
- public static final String UNKNOWN = "unknown";
-
/**
* Amazon CloudWatch rejects values that are either too small or too large.
* Values must be in the range of 8.515920e-109 to 1.174271e+108 (Base 10) or 2e-360 to 2e360 (Base 2).
@@ -103,6 +95,8 @@ public class DimensionedCloudWatchReporter extends ScheduledReporter {
private static final double SMALLEST_SENDABLE_VALUE = 8.515920e-109;
private static final double LARGEST_SENDABLE_VALUE = 1.174271e+108;
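+ // Caches constructed dimensions by name so repeated reports reuse the same Dimension instances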
+ private static Map<String, Dimension> constructedDimensions;
+
/**
* Each CloudWatch API request may contain at maximum 20 datums
*/
@@ -133,6 +127,7 @@ private DimensionedCloudWatchReporter(final Builder builder) {
this.durationUnit = builder.cwDurationUnit;
this.shouldParseDimensionsFromName = builder.withShouldParseDimensionsFromName;
this.shouldAppendDropwizardTypeDimension = builder.withShouldAppendDropwizardTypeDimension;
+ this.constructedDimensions = new ConcurrentHashMap<>();
this.filter = MetricFilter.ALL;
}
@@ -349,34 +344,89 @@ private void stageMetricDatum(final boolean metricConfigured,
// Only submit metrics that show some data, so let's save some money
if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) {
final DimensionedName dimensionedName = DimensionedName.decode(metricName);
+ // Add global dimensions for all metrics
final Set<Dimension> dimensions = new LinkedHashSet<>(builder.globalDimensions);
- MetricInfo metricInfo = getMetricInfo(dimensionedName);
- dimensions.addAll(metricInfo.getDimensions());
if (shouldAppendDropwizardTypeDimension) {
dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(dimensionValue));
}
- metricData.add(new MetricDatum()
- .withTimestamp(new Date(builder.clock.getTime()))
- .withValue(cleanMetricValue(metricValue))
- .withMetricName(metricInfo.getMetricName())
- .withDimensions(dimensions)
- .withUnit(standardUnit));
+ MetricInfo metricInfo = getMetricInfo(dimensionedName, dimensions);
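+ // Emit one datum per dimension set so the metric is reported at every configured granularity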
+ for (Set<Dimension> dimensionSet : metricInfo.getDimensionSets()) {
+ MetricDatum datum = new MetricDatum()
+ .withTimestamp(new Date(builder.clock.getTime()))
+ .withValue(cleanMetricValue(metricValue))
+ .withMetricName(metricInfo.getMetricName())
+ .withDimensions(dimensionSet)
+ .withUnit(standardUnit);
+ metricData.add(datum);
+ }
}
}
- public MetricInfo getMetricInfo(DimensionedName dimensionedName) {
+ /**
+ * Constructs a {@link MetricInfo} object based on the provided {@link DimensionedName} and a set of additional dimensions.
+ * This method processes the metric name contained within {@code dimensionedName} to potentially modify it based on naming conventions
+ * and extracts or generates additional dimension sets for detailed metrics reporting.
+ *
+ * If no specific naming convention is detected, the original set of dimensions is used as is. The method finally encapsulates the metric name
+ * and the collection of dimension sets in a {@link MetricInfo} object and returns it.
+ *
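+ * For example (illustrative): a counter named "unknown.1.TestSource.usedMemory" is renamed to
+ * "TestSource.usedMemory", and if dimension groups are configured for "TestSource" one dimension
+ * set is produced per configured group; otherwise the single externally provided set is used.
+ *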
+ * @param dimensionedName An instance of {@link DimensionedName} containing the original metric name and any directly associated dimensions.
+ * @param dimensions A set of {@link Dimension} objects provided externally that should be associated with the metric.
+ * @return A {@link MetricInfo} object containing the processed metric name and a list of dimension sets for metrics reporting.
+ */
+ private MetricInfo getMetricInfo(DimensionedName dimensionedName, Set<Dimension> dimensions) {
+ // Add dimensions from dimensionedName
+ dimensions.addAll(dimensionedName.getDimensions());
+
String metricName = dimensionedName.getName();
String[] parts = metricName.split("\\.");
- Set dimensions = new HashSet<>();
- if (doesNameConsistsOfMetricNameSpace(parts)) {
+ List<Set<Dimension>> dimensionSets = new ArrayList<>();
+ if (DimensionUtils.doesNameConsistsOfMetricNameSpace(parts)) {
metricName = constructMetricName(parts);
- addInstanceRoleDimension(dimensions, parts);
+ // Get dimension sets corresponding to a specific metric source
+ constructDimensionSets(dimensionSets, parts);
+ // Add the externally provided dimensions into each of the constructed dimension sets
+ for (Set<Dimension> dimensionSet : dimensionSets) {
+ dimensionSet.addAll(dimensions);
+ }
+ }
+
+ if (dimensionSets.isEmpty()) {
+ dimensionSets.add(dimensions);
+ }
+ return new MetricInfo(metricName, dimensionSets);
+ }
+
+ /**
+ * Populates a list of dimension sets based on the metric source name extracted from the metric's parts
+ * and predefined dimension groupings. This method aims to create detailed and structured dimension
+ * sets for metrics, enhancing the granularity and relevance of metric reporting.
+ *
+ * If no predefined dimension groups exist for the metric source, or if the dimension name groups are
+ * not initialized, the method exits without modifying the dimension sets list.
+ *
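+ * For example (hypothetical configuration): dimension groups {"TestSource": [["appName", "instanceRole"], ["appName"]]}
+ * yield two dimension sets for every metric reported by the "TestSource" source, so each value is
+ * emitted as two datums at different granularity.
+ *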
+ * @param dimensionSets A list to be populated with sets of {@link Dimension} objects, each representing
+ * a group of dimensions relevant to the metric's source.
+ * @param parts An array of strings derived from splitting the metric's name, used to extract information
+ * like the metric source name and to construct dimensions based on naming conventions.
+ */
+ private void constructDimensionSets(List<Set<Dimension>> dimensionSets, String[] parts) {
+ String metricSourceName = parts[2];
+ if (builder.dimensionNameGroups == null || builder.dimensionNameGroups.getDimensionGroups() == null || !builder.dimensionNameGroups.getDimensionGroups().containsKey(metricSourceName)) {
+ return;
+ }
+
+ for (List<String> dimensionNames : builder.dimensionNameGroups.getDimensionGroups().get(metricSourceName)) {
+ Set<Dimension> dimensions = new LinkedHashSet<>();
+ for (String dimensionName: dimensionNames) {
+ constructedDimensions.putIfAbsent(dimensionName, constructDimension(dimensionName, parts));
+ dimensions.add(constructedDimensions.get(dimensionName));
+ }
+ dimensionSets.add(dimensions);
}
- addDefaultDimensionsForSparkJobMetrics(dimensions);
- dimensions.addAll(dimensionedName.getDimensions());
- return new MetricInfo(metricName, dimensions);
}
/**
@@ -393,31 +443,6 @@ private String constructMetricName(String[] metricNameParts) {
return Stream.of(metricNameParts).skip(partsToSkip).collect(Collectors.joining("."));
}
- // These dimensions are for all metrics
- // TODO: Remove EMR-S specific env vars https://github.com/opensearch-project/opensearch-spark/issues/231
- private static void addDefaultDimensionsForSparkJobMetrics(Set<Dimension> dimensions) {
- final String jobId = System.getenv().getOrDefault("SERVERLESS_EMR_JOB_ID", UNKNOWN);
- final String applicationId = System.getenv().getOrDefault("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", UNKNOWN);
- dimensions.add(new Dimension().withName(DIMENSION_JOB_ID).withValue(jobId));
- dimensions.add(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(applicationId));
- }
-
- private static void addInstanceRoleDimension(Set<Dimension> dimensions, String[] parts) {
- Dimension instanceRoleDimension;
- if (StringUtils.isNumeric(parts[1])) {
- instanceRoleDimension = new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue("executor");
- } else {
- instanceRoleDimension = new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue(parts[1]);
- }
- dimensions.add(instanceRoleDimension);
- }
- // This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137
- // Since we don't have access to Spark Configuration here: we are relying on the presence of executorId as part of the metricName.
- private boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) {
- return metricNameParts.length >= 3
- && (metricNameParts[1].equals("driver") || StringUtils.isNumeric(metricNameParts[1]));
- }
-
private void stageMetricDatumWithConvertedSnapshot(final boolean metricConfigured,
final String metricName,
final Snapshot snapshot,
@@ -545,19 +570,19 @@ public String getDesc() {
public static class MetricInfo {
private String metricName;
- private Set<Dimension> dimensions;
+ private List<Set<Dimension>> dimensionSets;
- public MetricInfo(String metricName, Set<Dimension> dimensions) {
+ public MetricInfo(String metricName, List<Set<Dimension>> dimensionSets) {
this.metricName = metricName;
- this.dimensions = dimensions;
+ this.dimensionSets = dimensionSets;
}
public String getMetricName() {
return metricName;
}
- public Set<Dimension> getDimensions() {
- return dimensions;
+ public List<Set<Dimension>> getDimensionSets() {
+ return dimensionSets;
}
}
@@ -587,6 +612,7 @@ public static class Builder {
private StandardUnit cwRateUnit;
private StandardUnit cwDurationUnit;
private Set<Dimension> globalDimensions;
+ private DimensionNameGroups dimensionNameGroups;
private final Clock clock;
private Builder(
@@ -787,6 +813,11 @@ public Builder withShouldAppendDropwizardTypeDimension(final boolean value) {
return this;
}
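+ /**
+ * Sets the dimension name groups that define, per metric source, which dimension combinations
+ * each metric datum is reported with.
+ */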
+ public Builder withDimensionNameGroups(final DimensionNameGroups dimensionNameGroups) {
+ this.dimensionNameGroups = dimensionNameGroups;
+ return this;
+ }
+
/**
* Does not actually POST to CloudWatch, logs the {@link PutMetricDataRequest putMetricDataRequest} instead.
* {@code false} by default.
diff --git a/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java b/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTest.java
similarity index 62%
rename from flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java
rename to flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTest.java
index 6f87276a8..db2948858 100644
--- a/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java
+++ b/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTest.java
@@ -16,7 +16,10 @@
import java.util.Properties;
import org.opensearch.flint.core.metrics.reporter.InvalidMetricsPropertyException;
-class CloudWatchSinkTests {
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.fail;
+
+class CloudWatchSinkTest {
private final MetricRegistry metricRegistry = Mockito.mock(MetricRegistry.class);
private final SecurityManager securityManager = Mockito.mock(SecurityManager.class);
@@ -71,6 +74,44 @@ void should_throwException_when_pollingTimeUnitPropertyIsInvalid() {
Assertions.assertThrows(InvalidMetricsPropertyException.class, executable);
}
+ @Test
+ void should_throwException_when_DimensionGroupsPropertyIsInvalid() {
+ final Properties properties = getDefaultValidProperties();
+ String jsonString = "{\"dimensionGroups\":[{\"MetricSource1\":{}}, [\"Dimension1\",\"Dimension2\",\"Dimension3\"]]}]}";
+ properties.setProperty("dimensionGroups", jsonString);
+ final Executable executable = () -> {
+ final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager);
+ };
+ InvalidMetricsPropertyException exception = Assertions.assertThrows(InvalidMetricsPropertyException.class, executable);
+ StringBuilder expectedMessageBuilder = new StringBuilder();
+ expectedMessageBuilder.append("Unable to parse value (")
+ .append(jsonString)
+ .append(") for the \"dimensionGroups\" CloudWatchSink metrics property.");
+ Assertions.assertEquals(expectedMessageBuilder.toString(), exception.getMessage());
+ }
+
+ @Test
+ public void should_CreateCloudWatchSink_When_dimensionGroupsPropertyIsValid() {
+ final Properties properties = getDefaultValidProperties();
+ String jsonString = "{"
+ + "\"dimensionGroups\": {"
+ + "\"MetricSource1\": [[\"DimensionA1\", \"DimensionA2\"], [\"DimensionA1\"]],"
+ + "\"MetricSource2\": [[\"DimensionB1\"], [\"DimensionB2\", \"DimensionB3\", \"DimensionB4\"]],"
+ + "\"MetricSource3\": [[\"DimensionC1\", \"DimensionC2\", \"DimensionC3\"], [\"DimensionC4\"], [\"DimensionC5\", \"DimensionC6\"]]"
+ + "}"
+ + "}";
+ properties.setProperty("dimensionGroups", jsonString);
+
+ CloudWatchSink cloudWatchSink = null;
+ try {
+ cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager);
+ } catch (Exception e) {
+ fail("Should not have thrown any exception, but threw: " + e.getMessage());
+ }
+
+ assertNotNull("CloudWatchSink should be created", cloudWatchSink);
+ }
+
private Properties getDefaultValidProperties() {
final Properties properties = new Properties();
properties.setProperty("namespace", "namespaceValue");
diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java
new file mode 100644
index 000000000..7fab8c346
--- /dev/null
+++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.core.metrics.reporter;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import com.amazonaws.services.cloudwatch.model.Dimension;
+import org.junit.jupiter.api.function.Executable;
+
+import java.lang.reflect.Field;
+import java.util.Map;
+
+public class DimensionUtilsTest {
+ private static final String[] parts = {"someMetric", "123", "dummySource"};
+
+ @Test
+ void testConstructDimensionThrowsIllegalArgumentException() {
+ String dimensionName = "InvalidDimension";
+ String[] metricNameParts = {};
+
+ final Executable executable = () -> {
+ DimensionUtils.constructDimension(dimensionName, metricNameParts);
+ };
+ IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, executable);
+ Assertions.assertEquals("The provided metric name parts do not consist of a valid metric namespace.", exception.getMessage());
+ }
+ @Test
+ public void testGetInstanceRoleDimensionWithExecutor() {
+ Dimension result = DimensionUtils.constructDimension("instanceRole", parts);
+ assertEquals("instanceRole", result.getName());
+ assertEquals("executor", result.getValue());
+ }
+
+ @Test
+ public void testGetInstanceRoleDimensionWithRoleName() {
+ String[] parts = {"someMetric", "driver", "dummySource"};
+ Dimension result = DimensionUtils.constructDimension("instanceRole", parts);
+ assertEquals("instanceRole", result.getName());
+ assertEquals("driver", result.getValue());
+ }
+
+ @Test
+ public void testGetDefaultDimensionWithUnknown() {
+ Dimension result = DimensionUtils.constructDimension("nonExistentDimension", parts);
+ assertEquals("nonExistentDimension", result.getName());
+ assertEquals("UNKNOWN", result.getValue());
+ }
+
+ @Test
+ public void testGetDimensionsFromSystemEnv() throws NoSuchFieldException, IllegalAccessException {
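+ // Test-only hack: obtain the mutable backing map of the unmodifiable System.getenv() view via
+ // reflection so environment values can be injected; may require opened java.util internals on newer JDKs.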
+ Class<?> classOfMap = System.getenv().getClass();
+ Field field = classOfMap.getDeclaredField("m");
+ field.setAccessible(true);
+ Map<String, String> writeableEnvironmentVariables = (Map<String, String>) field.get(System.getenv());
+ writeableEnvironmentVariables.put("TEST_VAR", "dummy1");
+ writeableEnvironmentVariables.put("SERVERLESS_EMR_JOB_ID", "dummy2");
+ Dimension result1 = DimensionUtils.constructDimension("TEST_VAR", parts);
+ assertEquals("TEST_VAR", result1.getName());
+ assertEquals("dummy1", result1.getValue());
+ Dimension result2 = DimensionUtils.constructDimension("jobId", parts);
+ assertEquals("jobId", result2.getName());
+ assertEquals("dummy2", result2.getValue());
+ }
+}
diff --git a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java
similarity index 92%
rename from flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java
rename to flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java
index 4774bcc0b..db58993ef 100644
--- a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java
+++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java
@@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
-package opensearch.flint.core.metrics.reporter;
+package org.opensearch.flint.core.metrics.reporter;
import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsyncClient;
import com.amazonaws.services.cloudwatch.model.Dimension;
@@ -16,8 +16,15 @@
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SlidingWindowReservoir;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
import java.util.HashSet;
+import java.util.Map;
import java.util.Set;
+
+import org.apache.spark.metrics.sink.CloudWatchSink;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -36,8 +43,6 @@
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
-import org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter;
-import org.opensearch.flint.core.metrics.reporter.DimensionedName;
import static com.amazonaws.services.cloudwatch.model.StandardUnit.Count;
import static com.amazonaws.services.cloudwatch.model.StandardUnit.Microseconds;
@@ -49,16 +54,12 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
-import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_APPLICATION_ID;
import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_COUNT;
import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_GAUGE;
-import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_INSTANCE_ROLE;
-import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_JOB_ID;
import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_NAME_TYPE;
import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_MEAN;
import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_STD_DEV;
import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_SUMMARY;
-import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.UNKNOWN;
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
@@ -110,9 +111,8 @@ public void shouldReportWithoutGlobalDimensionsWhenGlobalDimensionsNotConfigured
final List<Dimension> dimensions = firstMetricDatumDimensionsFromCapturedRequest();
- assertThat(dimensions).hasSize(3);
+ assertThat(dimensions).hasSize(1);
assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT));
- assertDefaultDimensionsWithUnknownValue(dimensions);
}
@@ -124,7 +124,6 @@ public void reportedCounterShouldContainExpectedDimension() throws Exception {
final List<Dimension> dimensions = firstMetricDatumDimensionsFromCapturedRequest();
assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT));
- assertDefaultDimensionsWithUnknownValue(dimensions);
}
@Test
@@ -135,7 +134,6 @@ public void reportedGaugeShouldContainExpectedDimension() throws Exception {
final List<Dimension> dimensions = firstMetricDatumDimensionsFromCapturedRequest();
assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_GAUGE));
- assertDefaultDimensionsWithUnknownValue(dimensions);
}
@Test
@@ -483,7 +481,6 @@ public void shouldReportExpectedGlobalAndCustomDimensions() throws Exception {
assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2"));
assertThat(dimensions).contains(new Dimension().withName("key1").withValue("value1"));
assertThat(dimensions).contains(new Dimension().withName("key2").withValue("value2"));
- assertDefaultDimensionsWithUnknownValue(dimensions);
}
@Test
@@ -495,16 +492,14 @@ public void shouldParseDimensionedNamePrefixedWithMetricNameSpaceDriverMetric()
.build().encode()).inc();
reporterBuilder.withGlobalDimensions("Region=us-west-2").build().report();
final DimensionedCloudWatchReporter.MetricInfo metricInfo = firstMetricDatumInfoFromCapturedRequest();
- Set<Dimension> dimensions = metricInfo.getDimensions();
+ List<Set<Dimension>> dimensionSets = metricInfo.getDimensionSets();
+ Set<Dimension> dimensions = dimensionSets.get(0);
assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2"));
assertThat(dimensions).contains(new Dimension().withName("key1").withValue("value1"));
assertThat(dimensions).contains(new Dimension().withName("key2").withValue("value2"));
- assertThat(dimensions).contains(new Dimension().withName(DIMENSION_JOB_ID).withValue(UNKNOWN));
- assertThat(dimensions).contains(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(UNKNOWN));
- assertThat(dimensions).contains(new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue("driver"));
assertThat(metricInfo.getMetricName()).isEqualTo("LiveListenerBus.listenerProcessingTime.org.apache.spark.HeartbeatReceiver");
}
- @Test
+ @Test
public void shouldParseDimensionedNamePrefixedWithMetricNameSpaceExecutorMetric() throws Exception {
//setting jobId as unknown to invoke name parsing.
metricRegistry.counter(DimensionedName.withName("unknown.1.NettyBlockTransfer.shuffle-client.usedDirectMemory")
@@ -514,23 +509,44 @@ public void shouldParseDimensionedNamePrefixedWithMetricNameSpaceExecutorMetric(
reporterBuilder.withGlobalDimensions("Region=us-west-2").build().report();
final DimensionedCloudWatchReporter.MetricInfo metricInfo = firstMetricDatumInfoFromCapturedRequest();
- Set<Dimension> dimensions = metricInfo.getDimensions();
+ Set<Dimension> dimensions = metricInfo.getDimensionSets().get(0);
assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2"));
assertThat(dimensions).contains(new Dimension().withName("key1").withValue("value1"));
assertThat(dimensions).contains(new Dimension().withName("key2").withValue("value2"));
- assertThat(dimensions).contains(new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue( "executor"));
- assertThat(dimensions).contains(new Dimension().withName(DIMENSION_JOB_ID).withValue(UNKNOWN));
- assertThat(dimensions).contains(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(UNKNOWN));
assertThat(metricInfo.getMetricName()).isEqualTo("NettyBlockTransfer.shuffle-client.usedDirectMemory");
}
+ @Test
+ public void shouldConsumeMultipleMetricDatumWithDimensionGroups() throws Exception {
+ // Setup
+ String metricSourceName = "TestSource";
+ Map<String, List<List<String>>> dimensionGroups = new HashMap<>();
+ dimensionGroups.put(metricSourceName, Arrays.asList(
+ Arrays.asList("appName", "instanceRole"),
+ Arrays.asList("appName")
+ ));
+
+ metricRegistry.counter(DimensionedName.withName("unknown.1.TestSource.shuffle-client.usedDirectMemory")
+ .build().encode()).inc();
+
+ CloudWatchSink.DimensionNameGroups dimensionNameGroups = new CloudWatchSink.DimensionNameGroups();
+ dimensionNameGroups.setDimensionGroups(dimensionGroups);
+ reporterBuilder.withDimensionNameGroups(dimensionNameGroups).build().report();
+ final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue();
+ final List<MetricDatum> metricDatums = putMetricDataRequest.getMetricData();
+ assertThat(metricDatums).hasSize(2);
- private void assertDefaultDimensionsWithUnknownValue(List<Dimension> dimensions) {
- assertThat(dimensions).contains(new Dimension().withName(DIMENSION_JOB_ID).withValue(UNKNOWN));
- assertThat(dimensions).contains(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(UNKNOWN));
- }
+ MetricDatum metricDatum1 = metricDatums.get(0);
+ Set<Dimension> dimensions1 = new HashSet<>(metricDatum1.getDimensions());
+ assertThat(dimensions1).contains(new Dimension().withName("appName").withValue("UNKNOWN"));
+ assertThat(dimensions1).contains(new Dimension().withName("instanceRole").withValue("executor"));
+ MetricDatum metricDatum2 = metricDatums.get(1);
+ Set<Dimension> dimensions2 = new HashSet<>(metricDatum2.getDimensions());
+ assertThat(dimensions2).contains(new Dimension().withName("appName").withValue("UNKNOWN"));
+ assertThat(dimensions2).doesNotContain(new Dimension().withName("instanceRole").withValue("executor"));
+ }
private MetricDatum metricDatumByDimensionFromCapturedRequest(final String dimensionValue) {
final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue();
@@ -564,7 +580,10 @@ private List<Dimension> firstMetricDatumDimensionsFromCapturedRequest() {
private DimensionedCloudWatchReporter.MetricInfo firstMetricDatumInfoFromCapturedRequest() {
final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue();
final MetricDatum metricDatum = putMetricDataRequest.getMetricData().get(0);
- return new DimensionedCloudWatchReporter.MetricInfo(metricDatum.getMetricName(), new HashSet<>(metricDatum.getDimensions()));
+ Set<Dimension> dimensions = new HashSet<>(metricDatum.getDimensions());
+ List<Set<Dimension>> dimensionSet = new ArrayList<>();
+ dimensionSet.add(dimensions);
+ return new DimensionedCloudWatchReporter.MetricInfo(metricDatum.getMetricName(), dimensionSet);
}
private List<Dimension> allDimensionsFromCapturedRequest() {
diff --git a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
similarity index 97%
rename from flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
rename to flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
index d6145545d..6bc6a9c2d 100644
--- a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
+++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
@@ -1,4 +1,4 @@
-package opensearch.flint.core.metrics.reporter;
+package org.opensearch.flint.core.metrics.reporter;
import static org.hamcrest.CoreMatchers.hasItems;
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index dc85affb1..fba818a0f 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -12,20 +12,19 @@ import org.json4s.native.Serialization
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._
import org.opensearch.flint.core.metadata.log.OptimisticTransaction.NO_LOG_ENTRY
-import org.opensearch.flint.spark.FlintSpark.RefreshMode.{AUTO, MANUAL, RefreshMode}
-import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, ID_COLUMN, StreamingRefresh}
+import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
+import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh
+import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.AUTO
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Row, SparkSession}
-import org.apache.spark.sql.SaveMode._
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.flint.config.FlintSparkConf.{CHECKPOINT_MANDATORY, DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN}
-import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger}
+import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN}
/**
* Flint Spark integration API entrypoint.
@@ -130,8 +129,6 @@ class FlintSpark(val spark: SparkSession) extends Logging {
*
* @param indexName
* index name
- * @param mode
- * refresh mode
* @return
* refreshing job ID (empty if batch job for now)
*/
@@ -139,7 +136,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
logInfo(s"Refreshing Flint index $indexName")
val index = describeIndex(indexName)
.getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist"))
- val mode = if (index.options.autoRefresh()) AUTO else MANUAL
+ val indexRefresh = FlintSparkIndexRefresh.create(indexName, index)
try {
flintClient
@@ -149,17 +146,16 @@ class FlintSpark(val spark: SparkSession) extends Logging {
latest.copy(state = REFRESHING, createTime = System.currentTimeMillis()))
.finalLog(latest => {
// Change state to active if full, otherwise update index state regularly
- if (mode == MANUAL) {
- logInfo("Updating index state to active")
- latest.copy(state = ACTIVE)
- } else {
- // Schedule regular update and return log entry as refreshing state
+ if (indexRefresh.refreshMode == AUTO) {
logInfo("Scheduling index state monitor")
flintIndexMonitor.startMonitor(indexName)
latest
+ } else {
+ logInfo("Updating index state to active")
+ latest.copy(state = ACTIVE)
}
})
- .commit(_ => doRefreshIndex(index, indexName, mode))
+ .commit(_ => indexRefresh.start(spark, flintSparkConf))
} catch {
case e: Exception =>
logError("Failed to refresh Flint index", e)
@@ -292,7 +288,10 @@ class FlintSpark(val spark: SparkSession) extends Logging {
flintIndexMonitor.startMonitor(indexName)
latest.copy(state = REFRESHING)
})
- .commit(_ => doRefreshIndex(index.get, indexName, AUTO))
+ .commit(_ =>
+ FlintSparkIndexRefresh
+ .create(indexName, index.get)
+ .start(spark, flintSparkConf))
logInfo("Recovery complete")
true
@@ -333,67 +332,6 @@ class FlintSpark(val spark: SparkSession) extends Logging {
spark.read.format(FLINT_DATASOURCE).load(indexName)
}
- // TODO: move to separate class
- private def doRefreshIndex(
- index: FlintSparkIndex,
- indexName: String,
- mode: RefreshMode): Option[String] = {
- logInfo(s"Refreshing index $indexName in $mode mode")
- val options = index.options
- val tableName = index.metadata().source
-
- // Batch refresh Flint index from the given source data frame
- def batchRefresh(df: Option[DataFrame] = None): Unit = {
- index
- .build(spark, df)
- .write
- .format(FLINT_DATASOURCE)
- .options(flintSparkConf.properties)
- .mode(Overwrite)
- .save(indexName)
- }
-
- val jobId = mode match {
- case MANUAL =>
- logInfo("Start refreshing index in batch style")
- batchRefresh()
- None
-
- // Flint index has specialized logic and capability for incremental refresh
- case AUTO if index.isInstanceOf[StreamingRefresh] =>
- logInfo("Start refreshing index in streaming style")
- val job =
- index
- .asInstanceOf[StreamingRefresh]
- .buildStream(spark)
- .writeStream
- .queryName(indexName)
- .format(FLINT_DATASOURCE)
- .options(flintSparkConf.properties)
- .addSinkOptions(options)
- .start(indexName)
- Some(job.id.toString)
-
- // Otherwise, fall back to foreachBatch + batch refresh
- case AUTO =>
- logInfo("Start refreshing index in foreach streaming style")
- val job = spark.readStream
- .options(options.extraSourceOptions(tableName))
- .table(quotedTableName(tableName))
- .writeStream
- .queryName(indexName)
- .addSinkOptions(options)
- .foreachBatch { (batchDF: DataFrame, _: Long) =>
- batchRefresh(Some(batchDF))
- }
- .start()
- Some(job.id.toString)
- }
-
- logInfo("Refresh index complete")
- jobId
- }
-
private def stopRefreshingJob(indexName: String): Unit = {
logInfo(s"Terminating refreshing job $indexName")
val job = spark.streams.active.find(_.name == indexName)
@@ -403,48 +341,4 @@ class FlintSpark(val spark: SparkSession) extends Logging {
logWarning("Refreshing job not found")
}
}
-
- // Using Scala implicit class to avoid breaking method chaining of Spark data frame fluent API
- private implicit class FlintDataStreamWriter(val dataStream: DataStreamWriter[Row]) {
-
- def addSinkOptions(options: FlintSparkIndexOptions): DataStreamWriter[Row] = {
- dataStream
- .addCheckpointLocation(options.checkpointLocation())
- .addRefreshInterval(options.refreshInterval())
- .addOutputMode(options.outputMode())
- .options(options.extraSinkOptions())
- }
-
- def addCheckpointLocation(checkpointLocation: Option[String]): DataStreamWriter[Row] = {
- checkpointLocation match {
- case Some(location) => dataStream.option("checkpointLocation", location)
- case None if flintSparkConf.isCheckpointMandatory =>
- throw new IllegalStateException(
- s"Checkpoint location is mandatory for incremental refresh if ${CHECKPOINT_MANDATORY.key} enabled")
- case _ => dataStream
- }
- }
-
- def addRefreshInterval(refreshInterval: Option[String]): DataStreamWriter[Row] = {
- refreshInterval
- .map(interval => dataStream.trigger(Trigger.ProcessingTime(interval)))
- .getOrElse(dataStream)
- }
-
- def addOutputMode(outputMode: Option[String]): DataStreamWriter[Row] = {
- outputMode.map(dataStream.outputMode).getOrElse(dataStream)
- }
- }
-}
-
-object FlintSpark {
-
- /**
- * Index refresh mode: FULL: refresh on current source data in batch style at one shot
- * INCREMENTAL: auto refresh on new data in continuous streaming style
- */
- object RefreshMode extends Enumeration {
- type RefreshMode = Value
- val MANUAL, AUTO = Value
- }
}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala
index ffb479b54..9107a8a66 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala
@@ -8,7 +8,7 @@ package org.opensearch.flint.spark
import org.json4s.{Formats, NoTypeHints}
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
-import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, WATERMARK_DELAY}
+import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, WATERMARK_DELAY}
import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames
/**
@@ -39,6 +39,15 @@ case class FlintSparkIndexOptions(options: Map[String, String]) {
*/
def refreshInterval(): Option[String] = getOptionValue(REFRESH_INTERVAL)
+ /**
+ * Whether the refresh is incremental or full. This only applies to manual refresh.
+ *
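+ * For example (illustrative): options Map("auto_refresh" -> "false", "incremental_refresh" -> "true",
+ * "checkpoint_location" -> "s3://test/") select the incremental refresh implementation on manual refresh.
+ *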
+ * @return
+ * incremental option value
+ */
+ def incrementalRefresh(): Boolean =
+ getOptionValue(INCREMENTAL_REFRESH).getOrElse("false").toBoolean
+
/**
* The checkpoint location which may be required by Flint index's refresh.
*
@@ -103,6 +112,9 @@ case class FlintSparkIndexOptions(options: Map[String, String]) {
if (!options.contains(AUTO_REFRESH.toString)) {
map += (AUTO_REFRESH.toString -> autoRefresh().toString)
}
+ if (!options.contains(INCREMENTAL_REFRESH.toString)) {
+ map += (INCREMENTAL_REFRESH.toString -> incrementalRefresh().toString)
+ }
map.result()
}
@@ -131,6 +143,7 @@ object FlintSparkIndexOptions {
type OptionName = Value
val AUTO_REFRESH: OptionName.Value = Value("auto_refresh")
val REFRESH_INTERVAL: OptionName.Value = Value("refresh_interval")
+ val INCREMENTAL_REFRESH: OptionName.Value = Value("incremental_refresh")
val CHECKPOINT_LOCATION: OptionName.Value = Value("checkpoint_location")
val WATERMARK_DELAY: OptionName.Value = Value("watermark_delay")
val OUTPUT_MODE: OptionName.Value = Value("output_mode")
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala
new file mode 100644
index 000000000..09428f80d
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala
@@ -0,0 +1,111 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.refresh
+
+import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions}
+import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh}
+import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode}
+
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
+import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.flint.config.FlintSparkConf.CHECKPOINT_MANDATORY
+import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger}
+
+/**
+ * Index refresh that automatically refreshes the index according to the index options provided.
+ *
+ * @param indexName
+ * Flint index name
+ * @param index
+ * Flint index
+ */
+class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) extends FlintSparkIndexRefresh {
+
+ override def refreshMode: RefreshMode = AUTO
+
+ override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
+ val options = index.options
+ val tableName = index.metadata().source
+ index match {
+ // Flint index has specialized logic and capability for incremental refresh
+ case refresh: StreamingRefresh =>
+ logInfo("Start refreshing index in streaming style")
+ val job =
+ refresh
+ .buildStream(spark)
+ .writeStream
+ .queryName(indexName)
+ .format(FLINT_DATASOURCE)
+ .options(flintSparkConf.properties)
+ .addSinkOptions(options, flintSparkConf)
+ .start(indexName)
+ Some(job.id.toString)
+
+ // Otherwise, fall back to foreachBatch + batch refresh
+ case _ =>
+ logInfo("Start refreshing index in foreach streaming style")
+ val job = spark.readStream
+ .options(options.extraSourceOptions(tableName))
+ .table(quotedTableName(tableName))
+ .writeStream
+ .queryName(indexName)
+ .addSinkOptions(options, flintSparkConf)
+ .foreachBatch { (batchDF: DataFrame, _: Long) =>
+ new FullIndexRefresh(indexName, index, Some(batchDF))
+ .start(spark, flintSparkConf)
+ () // discard the return value above and return Unit so the right foreachBatch overload is resolved
+ }
+ .start()
+ Some(job.id.toString)
+ }
+ }
+
+ // Using Scala implicit class to avoid breaking method chaining of Spark data frame fluent API
+ private implicit class FlintDataStreamWriter(val dataStream: DataStreamWriter[Row]) {
+
+ def addSinkOptions(
+ options: FlintSparkIndexOptions,
+ flintSparkConf: FlintSparkConf): DataStreamWriter[Row] = {
+ dataStream
+ .addCheckpointLocation(options.checkpointLocation(), flintSparkConf.isCheckpointMandatory)
+ .addRefreshInterval(options.refreshInterval())
+ .addAvailableNowTrigger(options.incrementalRefresh())
+ .addOutputMode(options.outputMode())
+ .options(options.extraSinkOptions())
+ }
+
+ def addCheckpointLocation(
+ checkpointLocation: Option[String],
+ isCheckpointMandatory: Boolean): DataStreamWriter[Row] = {
+ checkpointLocation match {
+ case Some(location) => dataStream.option("checkpointLocation", location)
+ case None if isCheckpointMandatory =>
+ throw new IllegalStateException(
+ s"Checkpoint location is mandatory for incremental refresh if ${CHECKPOINT_MANDATORY.key} enabled")
+ case _ => dataStream
+ }
+ }
+
+ def addRefreshInterval(refreshInterval: Option[String]): DataStreamWriter[Row] = {
+ refreshInterval
+ .map(interval => dataStream.trigger(Trigger.ProcessingTime(interval)))
+ .getOrElse(dataStream)
+ }
+
+ def addAvailableNowTrigger(incrementalRefresh: Boolean): DataStreamWriter[Row] = {
+ if (incrementalRefresh) {
+ dataStream.trigger(Trigger.AvailableNow())
+ } else {
+ dataStream
+ }
+ }
+
+ def addOutputMode(outputMode: Option[String]): DataStreamWriter[Row] = {
+ outputMode.map(dataStream.outputMode).getOrElse(dataStream)
+ }
+ }
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FlintSparkIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FlintSparkIndexRefresh.scala
new file mode 100644
index 000000000..3c929d8e3
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FlintSparkIndexRefresh.scala
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.refresh
+
+import org.opensearch.flint.spark.FlintSparkIndex
+import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.RefreshMode
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.flint.config.FlintSparkConf
+
+/**
+ * Flint Spark index refresh that syncs index data with the source in the style defined by the
+ * concrete implementation class.
+ */
+trait FlintSparkIndexRefresh extends Logging {
+
+ /**
+ * @return
+ * refresh mode
+ */
+ def refreshMode: RefreshMode
+
+ /**
+ * Start refreshing the index.
+ *
+ * @param spark
+ * Spark session to submit job
+ * @param flintSparkConf
+ * Flint Spark configuration
+ * @return
+ * optional Spark job ID
+ */
+ def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String]
+}
+
+object FlintSparkIndexRefresh {
+
+ /** Index refresh mode */
+ object RefreshMode extends Enumeration {
+ type RefreshMode = Value
+ val AUTO, FULL, INCREMENTAL = Value
+ }
+
+ /**
+ * Create concrete index refresh implementation for the given index.
+ *
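+ * Auto refresh takes precedence: an enabled auto_refresh option yields an [[AutoIndexRefresh]],
+ * otherwise an enabled incremental_refresh option yields an [[IncrementalIndexRefresh]], and a
+ * [[FullIndexRefresh]] is returned by default.
+ *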
+ * @param indexName
+ * Flint index name
+ * @param index
+ * Flint index
+ * @return
+ * index refresh
+ */
+ def create(indexName: String, index: FlintSparkIndex): FlintSparkIndexRefresh = {
+ val options = index.options
+ if (options.autoRefresh()) {
+ new AutoIndexRefresh(indexName, index)
+ } else if (options.incrementalRefresh()) {
+ new IncrementalIndexRefresh(indexName, index)
+ } else {
+ new FullIndexRefresh(indexName, index)
+ }
+ }
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FullIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FullIndexRefresh.scala
new file mode 100644
index 000000000..be09c2c36
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FullIndexRefresh.scala
@@ -0,0 +1,45 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.refresh
+
+import org.opensearch.flint.spark.FlintSparkIndex
+import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{FULL, RefreshMode}
+
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.SaveMode.Overwrite
+import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
+import org.apache.spark.sql.flint.config.FlintSparkConf
+
+/**
+ * Index refresh that fully refreshes the index from the given source data frame.
+ *
+ * @param indexName
+ * Flint index name
+ * @param index
+ * Flint index
+ * @param source
+ * refresh from this data frame (representing a micro batch) if present, otherwise from the beginning
+ */
+class FullIndexRefresh(
+ indexName: String,
+ index: FlintSparkIndex,
+ source: Option[DataFrame] = None)
+ extends FlintSparkIndexRefresh {
+
+ override def refreshMode: RefreshMode = FULL
+
+ override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
+ logInfo(s"Start refreshing index $indexName in full mode")
+ index
+ .build(spark, source)
+ .write
+ .format(FLINT_DATASOURCE)
+ .options(flintSparkConf.properties)
+ .mode(Overwrite)
+ .save(indexName)
+ None
+ }
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala
new file mode 100644
index 000000000..418ada902
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala
@@ -0,0 +1,45 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.refresh
+
+import org.opensearch.flint.spark.FlintSparkIndex
+import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{INCREMENTAL, RefreshMode}
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.flint.config.FlintSparkConf
+
+/**
+ * Index refresh that incrementally refreshes the index from the last checkpoint.
+ *
+ * @param indexName
+ * Flint index name
+ * @param index
+ * Flint index
+ */
+class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex)
+ extends FlintSparkIndexRefresh {
+
+ override def refreshMode: RefreshMode = INCREMENTAL
+
+ override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
+ logInfo(s"Start refreshing index $indexName in incremental mode")
+
+ // TODO: move this, together with other checks, into a validation method in the future
+ if (index.options.checkpointLocation().isEmpty) {
+ throw new IllegalStateException("Checkpoint location is required by incremental refresh")
+ }
+
+ // Reuse auto refresh which uses AvailableNow trigger and will stop once complete
+ val jobId =
+ new AutoIndexRefresh(indexName, index)
+ .start(spark, flintSparkConf)
+
+ spark.streams
+ .get(jobId.get)
+ .awaitTermination()
+ None
+ }
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala
index eae401a69..14fa21240 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala
@@ -7,7 +7,6 @@ package org.opensearch.flint.spark.sql.covering
import org.antlr.v4.runtime.tree.RuleNode
import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala
index a67803a18..5b31890bb 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala
@@ -9,7 +9,6 @@ import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
import org.antlr.v4.runtime.tree.RuleNode
import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
index 73bff5cba..9b638f36f 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala
@@ -9,7 +9,6 @@ import scala.collection.JavaConverters.collectionAsScalaIterableConverter
import org.antlr.v4.runtime.tree.RuleNode
import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala
index b678096ca..212d91e13 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala
@@ -15,6 +15,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
test("should return lowercase name as option name") {
AUTO_REFRESH.toString shouldBe "auto_refresh"
REFRESH_INTERVAL.toString shouldBe "refresh_interval"
+ INCREMENTAL_REFRESH.toString shouldBe "incremental_refresh"
CHECKPOINT_LOCATION.toString shouldBe "checkpoint_location"
WATERMARK_DELAY.toString shouldBe "watermark_delay"
OUTPUT_MODE.toString shouldBe "output_mode"
@@ -27,6 +28,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
Map(
"auto_refresh" -> "true",
"refresh_interval" -> "1 Minute",
+ "incremental_refresh" -> "true",
"checkpoint_location" -> "s3://test/",
"watermark_delay" -> "30 Seconds",
"output_mode" -> "complete",
@@ -44,6 +46,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
options.autoRefresh() shouldBe true
options.refreshInterval() shouldBe Some("1 Minute")
+ options.incrementalRefresh() shouldBe true
options.checkpointLocation() shouldBe Some("s3://test/")
options.watermarkDelay() shouldBe Some("30 Seconds")
options.outputMode() shouldBe Some("complete")
@@ -85,6 +88,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
options.optionsWithDefault shouldBe Map(
"auto_refresh" -> "false",
+ "incremental_refresh" -> "false",
"refresh_interval" -> "1 Minute")
}
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexRefreshSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexRefreshSuite.scala
new file mode 100644
index 000000000..e9226e1c8
--- /dev/null
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexRefreshSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.mockito.Mockito.{when, RETURNS_DEEP_STUBS}
+import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh
+import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode._
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.mockito.MockitoSugar.mock
+
+import org.apache.spark.FlintSuite
+
+class FlintSparkIndexRefreshSuite extends FlintSuite with Matchers {
+
+ /** Test index name */
+ val indexName: String = "test"
+
+ /** Mock Flint index */
+ var index: FlintSparkIndex = _
+
+ override def beforeEach(): Unit = {
+ index = mock[FlintSparkIndex](RETURNS_DEEP_STUBS)
+ }
+
+ test("should auto refresh if auto refresh option enabled") {
+ when(index.options.autoRefresh()).thenReturn(true)
+
+ val refresh = FlintSparkIndexRefresh.create(indexName, index)
+ refresh.refreshMode shouldBe AUTO
+ }
+
+ test("should full refresh if both auto and incremental refresh option disabled") {
+ when(index.options.autoRefresh()).thenReturn(false)
+ when(index.options.incrementalRefresh()).thenReturn(false)
+
+ val refresh = FlintSparkIndexRefresh.create(indexName, index)
+ refresh.refreshMode shouldBe FULL
+ }
+
+ test(
+ "should incremental refresh if auto refresh disabled but incremental refresh option enabled") {
+ when(index.options.autoRefresh()).thenReturn(false)
+ when(index.options.incrementalRefresh()).thenReturn(true)
+
+ val refresh = FlintSparkIndexRefresh.create(indexName, index)
+ refresh.refreshMode shouldBe INCREMENTAL
+ }
+}
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
index b7746d44a..c1df42883 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
@@ -5,12 +5,12 @@
package org.opensearch.flint.spark.mv
-import scala.collection.JavaConverters.mapAsJavaMapConverter
+import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter}
import org.opensearch.flint.spark.FlintSparkIndexOptions
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan}
-import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the}
+import org.scalatest.matchers.should.Matchers.{contain, convertToAnyShouldWrapper, the}
import org.scalatestplus.mockito.MockitoSugar.mock
import org.apache.spark.FlintSuite
@@ -77,9 +77,8 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
Map("test_col" -> "integer"),
indexOptions)
- mv.metadata().options shouldBe Map(
- "auto_refresh" -> "true",
- "index_settings" -> indexSettings).asJava
+ mv.metadata().options.asScala should contain allOf ("auto_refresh" -> "true",
+ "index_settings" -> indexSettings)
mv.metadata().indexSettings shouldBe Some(indexSettings)
}
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala
index d1996359f..a77d261cd 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala
@@ -59,7 +59,10 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite {
| "columnType": "int"
| }],
| "source": "spark_catalog.default.ci_test",
- | "options": { "auto_refresh": "false" },
+ | "options": {
+ | "auto_refresh": "false",
+ | "incremental_refresh": "false"
+ | },
| "properties": {
| "filterCondition": "age > 30"
| }
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala
index 450da14c9..3c9e06257 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala
@@ -136,7 +136,7 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {
}
}
- test("create covering index with manual refresh") {
+ test("create covering index with full refresh") {
sql(s"""
| CREATE INDEX $testIndex ON $testTable
| (name, age)
@@ -151,6 +151,35 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {
indexData.count() shouldBe 2
}
+ test("create covering index with incremental refresh") {
+ withTempDir { checkpointDir =>
+ sql(s"""
+ | CREATE INDEX $testIndex ON $testTable
+ | (name, age)
+ | WITH (
+ | incremental_refresh = true,
+ | checkpoint_location = '${checkpointDir.getAbsolutePath}'
+ | )
+ | """.stripMargin)
+
+ // Refresh all present source data as of now
+ sql(s"REFRESH INDEX $testIndex ON $testTable")
+ flint.queryIndex(testFlintIndex).count() shouldBe 2
+
+ // New data won't be refreshed until refresh statement triggered
+ sql(s"""
+ | INSERT INTO $testTable
+ | PARTITION (year=2023, month=5)
+ | VALUES ('Hello', 50, 'Vancouver')
+ |""".stripMargin)
+ flint.queryIndex(testFlintIndex).count() shouldBe 2
+
+ // New data is refreshed incrementally
+ sql(s"REFRESH INDEX $testIndex ON $testTable")
+ flint.queryIndex(testFlintIndex).count() shouldBe 3
+ }
+ }
+
test("create covering index on table without database name") {
sql(s"CREATE INDEX $testIndex ON covering_sql_test (name)")
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
index 4df6dc55b..586b4e877 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
@@ -73,6 +73,7 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
| }],
| "options": {
| "auto_refresh": "true",
+ | "incremental_refresh": "false",
| "checkpoint_location": "s3://test/",
| "watermark_delay": "30 Seconds"
| },
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala
index ed702c7a1..20b7f3d55 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala
@@ -129,7 +129,7 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite {
(settings \ "index.number_of_replicas").extract[String] shouldBe "2"
}
- test("create materialized view with manual refresh") {
+ test("create materialized view with full refresh") {
sql(s"""
| CREATE MATERIALIZED VIEW $testMvName
| AS $testQuery
@@ -146,6 +146,35 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite {
indexData.count() shouldBe 4
}
+ test("create materialized view with incremental refresh") {
+ withTempDir { checkpointDir =>
+ sql(s"""
+ | CREATE MATERIALIZED VIEW $testMvName
+ | AS $testQuery
+ | WITH (
+ | incremental_refresh = true,
+ | checkpoint_location = '${checkpointDir.getAbsolutePath}',
+ | watermark_delay = '1 Second'
+ | )
+ | """.stripMargin)
+
+ // Refresh all present source data as of now
+ sql(s"REFRESH MATERIALIZED VIEW $testMvName")
+ flint.queryIndex(testFlintIndex).count() shouldBe 3
+
+ // New data won't be refreshed until refresh statement triggered
+ sql(s"""
+ | INSERT INTO $testTable VALUES
+ | (TIMESTAMP '2023-10-01 04:00:00', 'F', 25, 'Vancouver')
+ | """.stripMargin)
+ flint.queryIndex(testFlintIndex).count() shouldBe 3
+
+ // New data is refreshed incrementally
+ sql(s"REFRESH MATERIALIZED VIEW $testMvName")
+ flint.queryIndex(testFlintIndex).count() shouldBe 4
+ }
+ }
+
test("create materialized view if not exists") {
sql(s"CREATE MATERIALIZED VIEW IF NOT EXISTS $testMvName AS $testQuery")
flint.describeIndex(testFlintIndex) shouldBe defined
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
index 789b07c0c..99c4b9a42 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
@@ -7,11 +7,14 @@ package org.opensearch.flint.spark
import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson
import org.json4s.native.JsonMethods._
+import org.opensearch.client.RequestOptions
import org.opensearch.flint.core.FlintVersion.current
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain
+import org.opensearch.index.query.QueryBuilders
+import org.opensearch.index.reindex.DeleteByQueryRequest
import org.scalatest.matchers.{Matcher, MatchResult}
import org.scalatest.matchers.must.Matchers._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
@@ -28,9 +31,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
private val testTable = "spark_catalog.default.test"
private val testIndex = getSkippingIndexName(testTable)
- override def beforeAll(): Unit = {
- super.beforeAll()
-
+ override def beforeEach(): Unit = {
+ super.beforeEach()
createPartitionedMultiRowTable(testTable)
}
@@ -39,6 +41,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
// Delete all test indices
deleteTestIndex(testIndex)
+ sql(s"DROP TABLE $testTable")
}
test("create skipping index with metadata successfully") {
@@ -93,7 +96,10 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
| "columnType": "string"
| }],
| "source": "spark_catalog.default.test",
- | "options": { "auto_refresh": "false" },
+ | "options": {
+ | "auto_refresh": "false",
+ | "incremental_refresh": "false"
+ | },
| "properties": {}
| },
| "properties": {
@@ -122,7 +128,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
| }
|""".stripMargin)
- index.get.options shouldBe FlintSparkIndexOptions(Map("auto_refresh" -> "false"))
+ index.get.options shouldBe FlintSparkIndexOptions(
+ Map("auto_refresh" -> "false", "incremental_refresh" -> "false"))
}
test("create skipping index with index options successfully") {
@@ -143,6 +150,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
optionJson should matchJson("""
| {
| "auto_refresh": "true",
+ | "incremental_refresh": "false",
| "refresh_interval": "1 Minute",
| "checkpoint_location": "s3a://test/",
| "index_settings": "{\"number_of_shards\": 3,\"number_of_replicas\": 2}"
@@ -185,6 +193,51 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
}
test("incremental refresh skipping index successfully") {
+ withTempDir { checkpointDir =>
+ flint
+ .skippingIndex()
+ .onTable(testTable)
+ .addPartitions("year", "month")
+ .options(
+ FlintSparkIndexOptions(
+ Map(
+ "incremental_refresh" -> "true",
+ "checkpoint_location" -> checkpointDir.getAbsolutePath)))
+ .create()
+
+ flint.refreshIndex(testIndex) shouldBe empty
+ flint.queryIndex(testIndex).collect().toSet should have size 2
+
+ // Delete all index data intentionally and generate a new source file
+ openSearchClient.deleteByQuery(
+ new DeleteByQueryRequest(testIndex).setQuery(QueryBuilders.matchAllQuery()),
+ RequestOptions.DEFAULT)
+ sql(s"""
+ | INSERT INTO $testTable
+ | PARTITION (year=2023, month=4)
+ | VALUES ('Hello', 35, 'Vancouver')
+ | """.stripMargin)
+
+ // Expect to only refresh the new file
+ flint.refreshIndex(testIndex) shouldBe empty
+ flint.queryIndex(testIndex).collect().toSet should have size 1
+ }
+ }
+
+ test("should fail if incremental refresh without checkpoint location") {
+ flint
+ .skippingIndex()
+ .onTable(testTable)
+ .addPartitions("year", "month")
+ .options(FlintSparkIndexOptions(Map("incremental_refresh" -> "true")))
+ .create()
+
+ assertThrows[IllegalStateException] {
+ flint.refreshIndex(testIndex)
+ }
+ }
+
+ test("auto refresh skipping index successfully") {
// Create Flint index and wait for complete
flint
.skippingIndex()
@@ -581,7 +634,10 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
| "columnType": "struct"
| }],
| "source": "$testTable",
- | "options": { "auto_refresh": "false" },
+ | "options": {
+ | "auto_refresh": "false",
+ | "incremental_refresh": "false"
+ | },
| "properties": {}
| },
| "properties": {
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala
index 3f94762a5..ca14a555c 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala
@@ -27,7 +27,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite {
private val testTable = "spark_catalog.default.skipping_sql_test"
private val testIndex = getSkippingIndexName(testTable)
- override def beforeAll(): Unit = {
+ override def beforeEach(): Unit = {
- super.beforeAll()
+ super.beforeEach()
createPartitionedMultiRowTable(testTable)
@@ -37,6 +37,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite {
super.afterEach()
deleteTestIndex(testIndex)
+ sql(s"DROP TABLE $testTable")
}
test("create skipping index with auto refresh") {
@@ -142,7 +143,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite {
}
}
- test("create skipping index with manual refresh") {
+ test("create skipping index with full refresh") {
sql(s"""
| CREATE SKIPPING INDEX ON $testTable
| (
@@ -161,6 +162,35 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite {
indexData.count() shouldBe 2
}
+ test("create skipping index with incremental refresh") {
+ withTempDir { checkpointDir =>
+ sql(s"""
+ | CREATE SKIPPING INDEX ON $testTable
+ | ( year PARTITION )
+ | WITH (
+ | incremental_refresh = true,
+ | checkpoint_location = '${checkpointDir.getAbsolutePath}'
+ | )
+ | """.stripMargin)
+
+ // Refresh all present source data as of now
+ sql(s"REFRESH SKIPPING INDEX ON $testTable")
+ flint.queryIndex(testIndex).count() shouldBe 2
+
+ // New data won't be refreshed until refresh statement triggered
+ sql(s"""
+ | INSERT INTO $testTable
+ | PARTITION (year=2023, month=5)
+ | VALUES ('Hello', 50, 'Vancouver')
+ |""".stripMargin)
+ flint.queryIndex(testIndex).count() shouldBe 2
+
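+ // New data is refreshed incrementally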
+ sql(s"REFRESH SKIPPING INDEX ON $testTable")
+ flint.queryIndex(testIndex).count() shouldBe 3
+ }
+ }
+
test("should fail if refresh an auto refresh skipping index") {
sql(s"""
| CREATE SKIPPING INDEX ON $testTable
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala
index a2b93648e..b27275539 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala
@@ -64,7 +64,7 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match
(parse(mapping) \ "_meta" \ "latestId").extract[String] shouldBe testLatestId
}
- test("manual refresh index") {
+ test("full refresh index") {
flint
.skippingIndex()
.onTable(testTable)
@@ -78,6 +78,27 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match
}
test("incremental refresh index") {
+ withTempDir { checkpointDir =>
+ flint
+ .skippingIndex()
+ .onTable(testTable)
+ .addPartitions("year", "month")
+ .options(
+ FlintSparkIndexOptions(
+ Map(
+ "incremental_refresh" -> "true",
+ "checkpoint_location" -> checkpointDir.getAbsolutePath)))
+ .create()
+ flint.refreshIndex(testFlintIndex)
+
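+ // Verify the transaction log entry is active and records a job start time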
+ val latest = latestLogEntry(testLatestId)
+ latest should contain("state" -> "active")
+ latest("jobStartTime").asInstanceOf[Number].longValue() should be > 0L
+ }
+ }
+
+ test("auto refresh index") {
flint
.skippingIndex()
.onTable(testTable)