diff --git a/build.sbt b/build.sbt index 73fb481a6..f7653c50c 100644 --- a/build.sbt +++ b/build.sbt @@ -88,7 +88,7 @@ lazy val flintCore = (project in file("flint-core")) exclude ("com.fasterxml.jackson.core", "jackson-databind"), "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), - "software.amazon.awssdk" % "auth-crt" % "2.28.10" % "provided", + "software.amazon.awssdk" % "auth-crt" % "2.28.10", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", @@ -117,7 +117,7 @@ lazy val flintCommons = (project in file("flint-commons")) "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test", - "org.projectlombok" % "lombok" % "1.18.30", + "org.projectlombok" % "lombok" % "1.18.30" % "provided", ), libraryDependencies ++= deps(sparkVersion), publish / skip := true, diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index c553d483f..8e6cbaae9 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -2,6 +2,11 @@ #### **Describe** - `describe table` This command is equal to the `DESCRIBE EXTENDED table` SQL command +- `describe schema.table` +- `` describe schema.`table` `` +- `describe catalog.schema.table` +- `` describe catalog.schema.`table` `` +- `` describe `catalog`.`schema`.`table` `` #### **Explain** - `explain simple | source = table | where a = 1 | fields a,b,c` @@ -268,7 +273,7 @@ _- **Limitation: "REPLACE" or "APPEND" clause must contain "AS"**_ **SQL Migration examples with IN-Subquery PPL:** -1. tpch q4 (in-subquery with aggregation) +tpch q4 (in-subquery with aggregation) ```sql select o_orderpriority, @@ -304,52 +309,21 @@ source = orders | fields o_orderpriority, order_count ``` -2.tpch q20 (nested in-subquery) -```sql -select - s_name, - s_address -from - supplier, - nation -where - s_suppkey in ( - select - ps_suppkey - from - partsupp - where - ps_partkey in ( - select - p_partkey - from - part - where - p_name like 'forest%' - ) - ) - and s_nationkey = n_nationkey - and n_name = 'CANADA' -order by - s_name -``` +#### **ExistsSubquery** +[See additional command details](ppl-subquery-command.md) + +Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2 +- `source = outer | where exists [ source = inner | where a = c ]` +- `source = outer | where not exists [ source = inner | where a = c ]` +- `source = outer | where exists [ source = inner | where a = c and b = d ]` +- `source = outer | where not exists [ source = inner | where a = c and b = d ]` +- `source = outer | where exists [ source = inner1 | where a = c and exists [ source = inner2 | where c = e ] ]` (nested) +- `source = outer | where exists [ source = inner1 | where a = c | where exists [ source = inner2 | where c = e ] ]` (nested) +- `source = outer | where exists [ source = inner | where c > 10 ]` (uncorrelated exists) +- `source = outer | where not exists [ source = inner | where c > 10 ]` (uncorrelated exists) +- `source = outer | where exists [ source = inner ] | eval l = "Bala" | fields l` (special uncorrelated exists) + -Rewritten by PPL InSubquery query: -```sql -source = supplier -| where s_suppkey IN [ - source = partsupp - | where ps_partkey IN [ - source = part - | where like(p_name, "forest%") - | fields p_partkey - ] - | fields ps_suppkey - ] -| inner join left=l right=r on s_nationkey = n_nationkey and n_name = 'CANADA' - nation -| sort s_name -``` #### **ScalarSubquery** [See additional command details](ppl-subquery-command.md) diff --git a/docs/ppl-lang/ppl-subquery-command.md b/docs/ppl-lang/ppl-subquery-command.md index 1762306d2..ac0f98fe8 100644 --- a/docs/ppl-lang/ppl-subquery-command.md +++ b/docs/ppl-lang/ppl-subquery-command.md @@ -112,6 +112,58 @@ source = supplier | sort s_name ``` +**ExistsSubquery usage** + +Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2 + +- `source = outer | where exists [ source = inner | where a = c ]` +- `source = outer | where not exists [ source = inner | where a = c ]` +- `source = outer | where exists [ source = inner | where a = c and b = d ]` +- `source = outer | where not exists [ source = inner | where a = c and b = d ]` +- `source = outer | where exists [ source = inner1 | where a = c and exists [ source = inner2 | where c = e ] ]` (nested) +- `source = outer | where exists [ source = inner1 | where a = c | where exists [ source = inner2 | where c = e ] ]` (nested) +- `source = outer | where exists [ source = inner | where c > 10 ]` (uncorrelated exists) +- `source = outer | where not exists [ source = inner | where c > 10 ]` (uncorrelated exists) +- `source = outer | where exists [ source = inner ] | eval l = "nonEmpty" | fields l` (special uncorrelated exists) + +**_SQL Migration examples with Exists-Subquery PPL:_** + +tpch q4 (exists subquery with aggregation) +```sql +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + l_orderkey + from + lineitem + where l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority +``` +Rewritten by PPL ExistsSubquery query: +```sql +source = orders +| where o_orderdate >= "1993-07-01" and o_orderdate < "1993-10-01" + and exists [ + source = lineitem + | where l_orderkey = o_orderkey and l_commitdate < l_receiptdate + ] +| stats count(1) as order_count by o_orderpriority +| sort o_orderpriority +| fields o_orderpriority, order_count +``` + **ScalarSubquery usage** Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested @@ -191,14 +243,14 @@ source = spark_catalog.default.outer ### **Additional Context** -The most cases in the description is to request a `InSubquery` expression. +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expression. The common usage of subquery expression is in `where` clause: The `where` command syntax is: ``` | where ``` -So the subquery in description is part of boolean expression, such as +So the subquery is part of boolean expression, such as ```sql | where orders.order_id in (subquery source=returns | where return_reason="damaged" | return order_id) @@ -217,10 +269,11 @@ In issue description is a `ScalarSubquery`: ```sql source=employees | join source=sales on employees.employee_id = sales.employee_id -| where sales.sale_amount > (subquery source=targets | where target_met="true" | return target_value) +| where sales.sale_amount > [ source=targets | where target_met="true" | fields target_value ] ``` -Recall the join command doc: https://github.com/opensearch-project/opensearch-spark/blob/main/docs/PPL-Join-command.md#more-examples, the example is a subquery/subsearch **plan**, rather than a **expression**. +But `RelationSubquery` is not a subquery expression, it is a subquery plan. +[Recall the join command doc](ppl-join-command.md), the example is a subquery/subsearch **plan**, rather than a **expression**. ```sql SEARCH source=customer @@ -245,7 +298,32 @@ SEARCH Apply the syntax here and simply into ```sql -search | left join on (subquery search ...) +search | left join on [ search ... ] ``` -The `(subquery search ...)` is not a `expression`, it's `plan`, similar to the `relation` plan \ No newline at end of file +The `[ search ...]` is not a `expression`, it's `plan`, similar to the `relation` plan + +**Uncorrelated Subquery** + +An uncorrelated subquery is independent of the outer query. It is executed once, and the result is used by the outer query. +It's **less common** when using `ExistsSubquery` because `ExistsSubquery` typically checks for the presence of rows that are dependent on the outer query’s row. + +There is a very special exists subquery which highlight by `(special uncorrelated exists)`: +```sql +SELECT 'nonEmpty' +FROM outer + WHERE EXISTS ( + SELECT * + FROM inner + ); +``` +Rewritten by PPL ExistsSubquery query: +```sql +source = outer +| where exists [ + source = inner + ] +| eval l = "nonEmpty" +| fields l +``` +This query just print "nonEmpty" if the inner table is not empty. \ No newline at end of file diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java index 8e63992f5..81a482d5e 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricsUtil.java @@ -11,6 +11,7 @@ import com.codahale.metrics.Timer; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; +import org.apache.spark.metrics.source.FlintIndexMetricSource; import org.apache.spark.metrics.source.Source; import scala.collection.Seq; @@ -33,10 +34,20 @@ private MetricsUtil() { * If the counter does not exist, it is created before being incremented. * * @param metricName The name of the metric for which the counter is incremented. - * This name is used to retrieve or create the counter. */ public static void incrementCounter(String metricName) { - Counter counter = getOrCreateCounter(metricName); + incrementCounter(metricName, false); + } + + /** + * Increments the Counter metric associated with the given metric name. + * If the counter does not exist, it is created before being incremented. + * + * @param metricName The name of the metric for which the counter is incremented. + * @param isIndexMetric Whether this metric is an index-specific metric. + */ + public static void incrementCounter(String metricName, boolean isIndexMetric) { + Counter counter = getOrCreateCounter(metricName, isIndexMetric); if (counter != null) { counter.inc(); } @@ -48,7 +59,17 @@ public static void incrementCounter(String metricName) { * @param metricName The name of the metric counter to be decremented. */ public static void decrementCounter(String metricName) { - Counter counter = getOrCreateCounter(metricName); + decrementCounter(metricName, false); + } + + /** + * Decrements the value of the specified metric counter by one, if the counter exists and its current count is greater than zero. + * + * @param metricName The name of the metric counter to be decremented. + * @param isIndexMetric Whether this metric is an index-specific metric. + */ + public static void decrementCounter(String metricName, boolean isIndexMetric) { + Counter counter = getOrCreateCounter(metricName, isIndexMetric); if (counter != null && counter.getCount() > 0) { counter.dec(); } @@ -56,21 +77,30 @@ public static void decrementCounter(String metricName) { /** * Retrieves a {@link Timer.Context} for the specified metric name, creating a new timer if one does not already exist. - * This context can be used to measure the duration of a particular operation or event. * * @param metricName The name of the metric timer to retrieve the context for. * @return A {@link Timer.Context} instance for timing operations, or {@code null} if the timer could not be created or retrieved. */ public static Timer.Context getTimerContext(String metricName) { - Timer timer = getOrCreateTimer(metricName); + return getTimerContext(metricName, false); + } + + /** + * Retrieves a {@link Timer.Context} for the specified metric name, creating a new timer if one does not already exist. + * + * @param metricName The name of the metric timer to retrieve the context for. + * @param isIndexMetric Whether this metric is an index-specific metric. + * @return A {@link Timer.Context} instance for timing operations, or {@code null} if the timer could not be created or retrieved. + */ + public static Timer.Context getTimerContext(String metricName, boolean isIndexMetric) { + Timer timer = getOrCreateTimer(metricName, isIndexMetric); return timer != null ? timer.time() : null; } /** - * Stops the timer associated with the given {@link Timer.Context}, effectively recording the elapsed time since the timer was started - * and returning the duration. If the context is {@code null}, this method does nothing and returns {@code null}. + * Stops the timer associated with the given {@link Timer.Context}. * - * @param context The {@link Timer.Context} to stop. May be {@code null}, in which case this method has no effect and returns {@code null}. + * @param context The {@link Timer.Context} to stop. May be {@code null}. * @return The elapsed time in nanoseconds since the timer was started, or {@code null} if the context was {@code null}. */ public static Long stopTimer(Timer.Context context) { @@ -79,13 +109,23 @@ public static Long stopTimer(Timer.Context context) { /** * Registers a gauge metric with the provided name and value. - * The gauge will reflect the current value of the AtomicInteger provided. * * @param metricName The name of the gauge metric to register. - * @param value The AtomicInteger whose current value should be reflected by the gauge. + * @param value The AtomicInteger whose current value should be reflected by the gauge. */ public static void registerGauge(String metricName, final AtomicInteger value) { - MetricRegistry metricRegistry = getMetricRegistry(); + registerGauge(metricName, value, false); + } + + /** + * Registers a gauge metric with the provided name and value. + * + * @param metricName The name of the gauge metric to register. + * @param value The AtomicInteger whose current value should be reflected by the gauge. + * @param isIndexMetric Whether this metric is an index-specific metric. + */ + public static void registerGauge(String metricName, final AtomicInteger value, boolean isIndexMetric) { + MetricRegistry metricRegistry = getMetricRegistry(isIndexMetric); if (metricRegistry == null) { LOG.warning("MetricRegistry not available, cannot register gauge: " + metricName); return; @@ -93,39 +133,37 @@ public static void registerGauge(String metricName, final AtomicInteger value) { metricRegistry.register(metricName, (Gauge) value::get); } - // Retrieves or creates a new counter for the given metric name - private static Counter getOrCreateCounter(String metricName) { - MetricRegistry metricRegistry = getMetricRegistry(); + private static Counter getOrCreateCounter(String metricName, boolean isIndexMetric) { + MetricRegistry metricRegistry = getMetricRegistry(isIndexMetric); return metricRegistry != null ? metricRegistry.counter(metricName) : null; } - // Retrieves or creates a new Timer for the given metric name - private static Timer getOrCreateTimer(String metricName) { - MetricRegistry metricRegistry = getMetricRegistry(); + private static Timer getOrCreateTimer(String metricName, boolean isIndexMetric) { + MetricRegistry metricRegistry = getMetricRegistry(isIndexMetric); return metricRegistry != null ? metricRegistry.timer(metricName) : null; } - // Retrieves the MetricRegistry from the current Spark environment. - private static MetricRegistry getMetricRegistry() { + private static MetricRegistry getMetricRegistry(boolean isIndexMetric) { SparkEnv sparkEnv = SparkEnv.get(); if (sparkEnv == null) { LOG.warning("Spark environment not available, cannot access MetricRegistry."); return null; } - FlintMetricSource flintMetricSource = getOrInitFlintMetricSource(sparkEnv); - return flintMetricSource.metricRegistry(); + Source metricSource = isIndexMetric ? + getOrInitMetricSource(sparkEnv, FlintMetricSource.FLINT_INDEX_METRIC_SOURCE_NAME(), FlintIndexMetricSource::new) : + getOrInitMetricSource(sparkEnv, FlintMetricSource.FLINT_METRIC_SOURCE_NAME(), FlintMetricSource::new); + return metricSource.metricRegistry(); } - // Gets or initializes the FlintMetricSource - private static FlintMetricSource getOrInitFlintMetricSource(SparkEnv sparkEnv) { - Seq metricSourceSeq = sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()); + private static Source getOrInitMetricSource(SparkEnv sparkEnv, String sourceName, java.util.function.Supplier sourceSupplier) { + Seq metricSourceSeq = sparkEnv.metricsSystem().getSourcesByName(sourceName); if (metricSourceSeq == null || metricSourceSeq.isEmpty()) { - FlintMetricSource metricSource = new FlintMetricSource(); + Source metricSource = sourceSupplier.get(); sparkEnv.metricsSystem().registerSource(metricSource); return metricSource; } - return (FlintMetricSource) metricSourceSeq.head(); + return metricSourceSeq.head(); } } diff --git a/flint-core/src/main/scala/apache/spark/metrics/source/FlintMetricSource.scala b/flint-core/src/main/scala/apache/spark/metrics/source/FlintMetricSource.scala index d5f241572..7bdfa11e6 100644 --- a/flint-core/src/main/scala/apache/spark/metrics/source/FlintMetricSource.scala +++ b/flint-core/src/main/scala/apache/spark/metrics/source/FlintMetricSource.scala @@ -7,13 +7,25 @@ package org.apache.spark.metrics.source import com.codahale.metrics.MetricRegistry -class FlintMetricSource() extends Source { +/** + * Metric source for general Flint metrics. + */ +class FlintMetricSource extends Source { // Implementing the Source trait override val sourceName: String = FlintMetricSource.FLINT_METRIC_SOURCE_NAME override val metricRegistry: MetricRegistry = new MetricRegistry } +/** + * Metric source for Flint index-specific metrics. + */ +class FlintIndexMetricSource extends Source { + override val sourceName: String = FlintMetricSource.FLINT_INDEX_METRIC_SOURCE_NAME + override val metricRegistry: MetricRegistry = new MetricRegistry +} + object FlintMetricSource { val FLINT_METRIC_SOURCE_NAME = "Flint" // Default source name + val FLINT_INDEX_METRIC_SOURCE_NAME = "FlintIndex" // Index specific source name } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index e505cf45d..6ddc6ae9c 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -105,8 +105,11 @@ public class FlintOptions implements Serializable { public static final String DEFAULT_SUPPORT_SHARD = "true"; + private static final String UNKNOWN = "UNKNOWN"; + public static final String BULK_REQUEST_RATE_LIMIT_PER_NODE = "bulkRequestRateLimitPerNode"; public static final String DEFAULT_BULK_REQUEST_RATE_LIMIT_PER_NODE = "0"; + public static final String DEFAULT_EXTERNAL_SCHEDULER_INTERVAL = "5 minutes"; public FlintOptions(Map options) { this.options = options; @@ -185,9 +188,9 @@ public String getDataSourceName() { * @return the AWS accountId */ public String getAWSAccountId() { - String clusterName = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", ""); + String clusterName = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", UNKNOWN + ":" + UNKNOWN); String[] parts = clusterName.split(":"); - return parts.length == 2 ? parts[0] : ""; + return parts.length == 2 ? parts[0] : UNKNOWN; } public String getSystemIndexName() { diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java index 3b8940536..b54269ce0 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/MetricsUtilTest.java @@ -5,6 +5,8 @@ import com.codahale.metrics.Timer; import org.apache.spark.SparkEnv; import org.apache.spark.metrics.source.FlintMetricSource; +import org.apache.spark.metrics.source.FlintIndexMetricSource; +import org.apache.spark.metrics.source.Source; import org.junit.Test; import org.junit.jupiter.api.Assertions; import org.mockito.MockedStatic; @@ -26,55 +28,73 @@ public class MetricsUtilTest { @Test public void testIncrementDecrementCounter() { + testIncrementDecrementCounterHelper(false); + } + + @Test + public void testIncrementDecrementCounterForIndexMetrics() { + testIncrementDecrementCounterHelper(true); + } + + private void testIncrementDecrementCounterHelper(boolean isIndexMetric) { try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { // Mock SparkEnv SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); - // Mock FlintMetricSource - FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource()); - when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) - .thenReturn(flintMetricSource); + // Mock appropriate MetricSource + String sourceName = isIndexMetric ? FlintMetricSource.FLINT_INDEX_METRIC_SOURCE_NAME() : FlintMetricSource.FLINT_METRIC_SOURCE_NAME(); + Source metricSource = isIndexMetric ? Mockito.spy(new FlintIndexMetricSource()) : Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(sourceName).head()).thenReturn(metricSource); // Test the methods String testMetric = "testPrefix.2xx.count"; - MetricsUtil.incrementCounter(testMetric); - MetricsUtil.incrementCounter(testMetric); - MetricsUtil.decrementCounter(testMetric); + MetricsUtil.incrementCounter(testMetric, isIndexMetric); + MetricsUtil.incrementCounter(testMetric, isIndexMetric); + MetricsUtil.decrementCounter(testMetric, isIndexMetric); // Verify interactions verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); - verify(flintMetricSource, times(3)).metricRegistry(); - Counter counter = flintMetricSource.metricRegistry().getCounters().get(testMetric); + verify(metricSource, times(3)).metricRegistry(); + Counter counter = metricSource.metricRegistry().getCounters().get(testMetric); Assertions.assertNotNull(counter); - Assertions.assertEquals(counter.getCount(), 1); + Assertions.assertEquals(1, counter.getCount()); } } @Test public void testStartStopTimer() { + testStartStopTimerHelper(false); + } + + @Test + public void testStartStopTimerForIndexMetrics() { + testStartStopTimerHelper(true); + } + + private void testStartStopTimerHelper(boolean isIndexMetric) { try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { // Mock SparkEnv SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); - // Mock FlintMetricSource - FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource()); - when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) - .thenReturn(flintMetricSource); + // Mock appropriate MetricSource + String sourceName = isIndexMetric ? FlintMetricSource.FLINT_INDEX_METRIC_SOURCE_NAME() : FlintMetricSource.FLINT_METRIC_SOURCE_NAME(); + Source metricSource = isIndexMetric ? Mockito.spy(new FlintIndexMetricSource()) : Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(sourceName).head()).thenReturn(metricSource); // Test the methods String testMetric = "testPrefix.processingTime"; - Timer.Context context = MetricsUtil.getTimerContext(testMetric); + Timer.Context context = MetricsUtil.getTimerContext(testMetric, isIndexMetric); TimeUnit.MILLISECONDS.sleep(500); MetricsUtil.stopTimer(context); // Verify interactions verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); - verify(flintMetricSource, times(1)).metricRegistry(); - Timer timer = flintMetricSource.metricRegistry().getTimers().get(testMetric); + verify(metricSource, times(1)).metricRegistry(); + Timer timer = metricSource.metricRegistry().getTimers().get(testMetric); Assertions.assertNotNull(timer); - Assertions.assertEquals(timer.getCount(), 1L); + Assertions.assertEquals(1L, timer.getCount()); assertEquals(1.9, timer.getMeanRate(), 0.1); } catch (InterruptedException e) { throw new RuntimeException(e); @@ -82,33 +102,71 @@ public void testStartStopTimer() { } @Test - public void testRegisterGaugeWhenMetricRegistryIsAvailable() { + public void testRegisterGauge() { + testRegisterGaugeHelper(false); + } + + @Test + public void testRegisterGaugeForIndexMetrics() { + testRegisterGaugeHelper(true); + } + + private void testRegisterGaugeHelper(boolean isIndexMetric) { try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { // Mock SparkEnv SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); - // Mock FlintMetricSource - FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource()); - when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) - .thenReturn(flintMetricSource); + // Mock appropriate MetricSource + String sourceName = isIndexMetric ? FlintMetricSource.FLINT_INDEX_METRIC_SOURCE_NAME() : FlintMetricSource.FLINT_METRIC_SOURCE_NAME(); + Source metricSource = isIndexMetric ? Mockito.spy(new FlintIndexMetricSource()) : Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(sourceName).head()).thenReturn(metricSource); // Setup gauge AtomicInteger testValue = new AtomicInteger(1); String gaugeName = "test.gauge"; - MetricsUtil.registerGauge(gaugeName, testValue); + MetricsUtil.registerGauge(gaugeName, testValue, isIndexMetric); verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); - verify(flintMetricSource, times(1)).metricRegistry(); + verify(metricSource, times(1)).metricRegistry(); - Gauge gauge = flintMetricSource.metricRegistry().getGauges().get(gaugeName); + Gauge gauge = metricSource.metricRegistry().getGauges().get(gaugeName); Assertions.assertNotNull(gauge); - Assertions.assertEquals(gauge.getValue(), 1); + Assertions.assertEquals(1, gauge.getValue()); testValue.incrementAndGet(); testValue.incrementAndGet(); testValue.decrementAndGet(); - Assertions.assertEquals(gauge.getValue(), 2); + Assertions.assertEquals(2, gauge.getValue()); + } + } + + @Test + public void testDefaultBehavior() { + try (MockedStatic sparkEnvMock = mockStatic(SparkEnv.class)) { + // Mock SparkEnv + SparkEnv sparkEnv = mock(SparkEnv.class, RETURNS_DEEP_STUBS); + sparkEnvMock.when(SparkEnv::get).thenReturn(sparkEnv); + + // Mock FlintMetricSource + FlintMetricSource flintMetricSource = Mockito.spy(new FlintMetricSource()); + when(sparkEnv.metricsSystem().getSourcesByName(FlintMetricSource.FLINT_METRIC_SOURCE_NAME()).head()) + .thenReturn(flintMetricSource); + + // Test default behavior (non-index metrics) + String testCountMetric = "testDefault.count"; + String testTimerMetric = "testDefault.time"; + String testGaugeMetric = "testDefault.gauge"; + MetricsUtil.incrementCounter(testCountMetric); + MetricsUtil.getTimerContext(testTimerMetric); + MetricsUtil.registerGauge(testGaugeMetric, new AtomicInteger(0), false); + + // Verify interactions + verify(sparkEnv.metricsSystem(), times(0)).registerSource(any()); + verify(flintMetricSource, times(3)).metricRegistry(); + Assertions.assertNotNull(flintMetricSource.metricRegistry().getCounters().get(testCountMetric)); + Assertions.assertNotNull(flintMetricSource.metricRegistry().getTimers().get(testTimerMetric)); + Assertions.assertNotNull(flintMetricSource.metricRegistry().getGauges().get(testGaugeMetric)); } } } \ No newline at end of file diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 43dc43ad0..68721d235 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -174,7 +174,7 @@ object FlintSparkConf { val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD = FlintConfig("spark.flint.job.externalScheduler.interval") .doc("Interval threshold in minutes for external scheduler to trigger index refresh") - .createWithDefault("5 minutes") + .createWithDefault(FlintOptions.DEFAULT_EXTERNAL_SCHEDULER_INTERVAL) val CHECKPOINT_LOCATION_ROOT_DIR = FlintConfig("spark.flint.index.checkpointLocation.rootDir") .doc("Root directory of a user specified checkpoint location for index refresh") @@ -294,8 +294,10 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable def isExternalSchedulerEnabled: Boolean = EXTERNAL_SCHEDULER_ENABLED.readFrom(reader).toBoolean - def externalSchedulerIntervalThreshold(): String = - EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD.readFrom(reader) + def externalSchedulerIntervalThreshold(): String = { + val value = EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD.readFrom(reader) + if (value.trim.isEmpty) FlintOptions.DEFAULT_EXTERNAL_SCHEDULER_INTERVAL else value + } def checkpointLocationRootDir: Option[String] = CHECKPOINT_LOCATION_ROOT_DIR.readFrom(reader) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index 72039bddf..779b7e013 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -23,7 +23,7 @@ import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode._ -import org.opensearch.flint.spark.scheduler.{AsyncQuerySchedulerBuilder, FlintSparkJobSchedulingService} +import org.opensearch.flint.spark.scheduler.{AsyncQuerySchedulerBuilder, FlintSparkJobExternalSchedulingService, FlintSparkJobInternalSchedulingService, FlintSparkJobSchedulingService} import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder.AsyncQuerySchedulerAction import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer @@ -225,17 +225,22 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w */ def updateIndex(index: FlintSparkIndex): Option[String] = { val indexName = index.name() - validateUpdateAllowed( - describeIndex(indexName) - .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist")) - .options, - index.options) + val originalOptions = describeIndex(indexName) + .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist")) + .options + validateUpdateAllowed(originalOptions, index.options) + val isSchedulerModeChanged = + index.options.isExternalSchedulerEnabled() != originalOptions.isExternalSchedulerEnabled() withTransaction[Option[String]](indexName, "Update Flint index") { tx => - // Relies on validation to forbid auto-to-auto and manual-to-manual updates - index.options.autoRefresh() match { - case true => updateIndexManualToAuto(index, tx) - case false => updateIndexAutoToManual(index, tx) + // Relies on validation to prevent: + // 1. auto-to-auto updates besides scheduler_mode + // 2. any manual-to-manual updates + // 3. both refresh_mode and scheduler_mode updated + (index.options.autoRefresh(), isSchedulerModeChanged) match { + case (true, true) => updateSchedulerMode(index, tx) + case (true, false) => updateIndexManualToAuto(index, tx) + case (false, false) => updateIndexAutoToManual(index, tx) } } } @@ -325,19 +330,29 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w def recoverIndex(indexName: String): Boolean = withTransaction[Boolean](indexName, "Recover Flint index") { tx => val index = describeIndex(indexName) + if (index.exists(_.options.autoRefresh())) { + val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(index.get).get + FlintSparkIndexRefresh + .create(updatedIndex.name(), updatedIndex) + .validate(spark) + val jobSchedulingService = FlintSparkJobSchedulingService.create( + updatedIndex, + spark, + flintAsyncQueryScheduler, + flintSparkConf, + flintIndexMonitor) tx .initialLog(latest => Set(ACTIVE, REFRESHING, FAILED).contains(latest.state)) .transientLog(latest => latest.copy(state = RECOVERING, createTime = System.currentTimeMillis())) .finalLog(latest => { - flintIndexMonitor.startMonitor(indexName) - latest.copy(state = REFRESHING) + latest.copy(state = jobSchedulingService.stateTransitions.finalStateForUpdate) }) .commit(_ => { - FlintSparkIndexRefresh - .create(indexName, index.get) - .start(spark, flintSparkConf) + flintIndexMetadataService.updateIndexMetadata(indexName, updatedIndex.metadata()) + logInfo("Update index options complete") + jobSchedulingService.handleJob(updatedIndex, AsyncQuerySchedulerAction.UPDATE) true }) } else { @@ -430,37 +445,78 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w private def validateUpdateAllowed( originalOptions: FlintSparkIndexOptions, updatedOptions: FlintSparkIndexOptions): Unit = { - // auto_refresh must change - if (updatedOptions.autoRefresh() == originalOptions.autoRefresh()) { - throw new IllegalArgumentException("auto_refresh option must be updated") - } + val isAutoRefreshChanged = updatedOptions.autoRefresh() != originalOptions.autoRefresh() - val refreshMode = (updatedOptions.autoRefresh(), updatedOptions.incrementalRefresh()) match { - case (true, false) => AUTO - case (false, false) => FULL - case (false, true) => INCREMENTAL + val changedOptions = updatedOptions.options.filterNot { case (k, v) => + originalOptions.options.get(k).contains(v) + }.keySet + + if (changedOptions.isEmpty) { + throw new IllegalArgumentException("No index option updated") } - // validate allowed options depending on refresh mode - val allowedOptionNames = refreshMode match { - case FULL => Set(AUTO_REFRESH, INCREMENTAL_REFRESH) - case AUTO | INCREMENTAL => - Set( + // Validate based on auto_refresh state and changes + (isAutoRefreshChanged, updatedOptions.autoRefresh()) match { + case (true, true) => + // Changing from manual to auto refresh + if (updatedOptions.incrementalRefresh()) { + throw new IllegalArgumentException( + "Altering index to auto refresh while incremental refresh remains true") + } + + val allowedOptions = Set( AUTO_REFRESH, INCREMENTAL_REFRESH, SCHEDULER_MODE, REFRESH_INTERVAL, CHECKPOINT_LOCATION, WATERMARK_DELAY) + validateChangedOptions(changedOptions, allowedOptions, s"Altering index to auto refresh") + case (true, false) => + val allowedOptions = if (updatedOptions.incrementalRefresh()) { + // Changing from auto refresh to incremental refresh + Set( + AUTO_REFRESH, + INCREMENTAL_REFRESH, + REFRESH_INTERVAL, + CHECKPOINT_LOCATION, + WATERMARK_DELAY) + } else { + // Changing from auto refresh to full refresh + Set(AUTO_REFRESH) + } + validateChangedOptions( + changedOptions, + allowedOptions, + "Altering index to full/incremental refresh") + + case (false, true) => + // original refresh_mode is auto, only allow changing scheduler_mode + validateChangedOptions( + changedOptions, + Set(SCHEDULER_MODE), + "Altering index when auto_refresh remains true") + + case (false, false) => + // original refresh_mode is full/incremental, not allowed to change any options + if (changedOptions.nonEmpty) { + throw new IllegalArgumentException( + "No options can be updated when auto_refresh remains false") + } } + } - // Get the changed option names - val updateOptionNames = updatedOptions.options.filterNot { case (k, v) => - originalOptions.options.get(k).contains(v) - }.keys - if (!updateOptionNames.forall(allowedOptionNames.map(_.toString).contains)) { + private def validateChangedOptions( + changedOptions: Set[String], + allowedOptions: Set[OptionName], + context: String): Unit = { + + val allowedOptionStrings = allowedOptions.map(_.toString) + + if (!changedOptions.subsetOf(allowedOptionStrings)) { + val invalidOptions = changedOptions -- allowedOptionStrings throw new IllegalArgumentException( - s"Altering index to ${refreshMode} refresh only allows options: ${allowedOptionNames}") + s"$context only allows changing: $allowedOptions. Invalid options: $invalidOptions") } } @@ -477,9 +533,12 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w flintIndexMonitor) tx .initialLog(latest => - latest.state == REFRESHING && latest.entryVersion == indexLogEntry.entryVersion) + // Index in external scheduler mode should be in active or refreshing state + Set(jobSchedulingService.stateTransitions.initialStateForUnschedule).contains( + latest.state) && latest.entryVersion == indexLogEntry.entryVersion) .transientLog(latest => latest.copy(state = UPDATING)) - .finalLog(latest => latest.copy(state = ACTIVE)) + .finalLog(latest => + latest.copy(state = jobSchedulingService.stateTransitions.finalStateForUnschedule)) .commit(_ => { flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) logInfo("Update index options complete") @@ -501,13 +560,11 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w flintIndexMonitor) tx .initialLog(latest => - latest.state == ACTIVE && latest.entryVersion == indexLogEntry.entryVersion) + latest.state == jobSchedulingService.stateTransitions.initialStateForUpdate && latest.entryVersion == indexLogEntry.entryVersion) .transientLog(latest => latest.copy(state = UPDATING, createTime = System.currentTimeMillis())) .finalLog(latest => { - logInfo("Scheduling index state monitor") - flintIndexMonitor.startMonitor(indexName) - latest.copy(state = REFRESHING) + latest.copy(state = jobSchedulingService.stateTransitions.finalStateForUpdate) }) .commit(_ => { flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) @@ -515,4 +572,36 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w jobSchedulingService.handleJob(index, AsyncQuerySchedulerAction.UPDATE) }) } + + private def updateSchedulerMode( + index: FlintSparkIndex, + tx: OptimisticTransaction[Option[String]]): Option[String] = { + val indexName = index.name + val indexLogEntry = index.latestLogEntry.get + val internalSchedulingService = + new FlintSparkJobInternalSchedulingService(spark, flintIndexMonitor) + val externalSchedulingService = + new FlintSparkJobExternalSchedulingService(flintAsyncQueryScheduler, flintSparkConf) + + val isExternal = index.options.isExternalSchedulerEnabled() + val (initialState, finalState, oldService, newService) = + if (isExternal) { + (REFRESHING, ACTIVE, internalSchedulingService, externalSchedulingService) + } else { + (ACTIVE, REFRESHING, externalSchedulingService, internalSchedulingService) + } + + tx + .initialLog(latest => + latest.state == initialState && latest.entryVersion == indexLogEntry.entryVersion) + .transientLog(latest => latest.copy(state = UPDATING)) + .finalLog(latest => latest.copy(state = finalState)) + .commit(_ => { + flintIndexMetadataService.updateIndexMetadata(indexName, index.metadata) + logInfo("Update index options complete") + oldService.handleJob(index, AsyncQuerySchedulerAction.UNSCHEDULE) + logInfo(s"Unscheduled ${if (isExternal) "internal" else "external"} jobs") + newService.handleJob(index, AsyncQuerySchedulerAction.UPDATE) + }) + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index afd53724e..0391741cf 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -5,20 +5,14 @@ package org.opensearch.flint.spark -import java.util.Collections - import scala.collection.JavaConverters.mapAsJavaMapConverter -import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{CHECKPOINT_LOCATION, REFRESH_INTERVAL, SCHEDULER_MODE} -import org.opensearch.flint.spark.FlintSparkIndexOptions.empty +import org.opensearch.flint.spark.FlintSparkIndexOptions.{empty, updateOptionsWithDefaults} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh -import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode -import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser import org.apache.spark.sql.catalog.Column import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.flint.{findField, loadTable, parseTableName, qualifyTableName} -import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{StructField, StructType} /** @@ -156,60 +150,4 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { isPartition = false, // useless for now so just set to false isBucket = false) } - - /** - * Updates the options with a default values for Create and Alter index. - * - * @param indexName - * The index name string - * @param options - * The original FlintSparkIndexOptions - * @return - * Updated FlintSparkIndexOptions - */ - private def updateOptionsWithDefaults( - indexName: String, - options: FlintSparkIndexOptions): FlintSparkIndexOptions = { - val flintSparkConf = new FlintSparkConf(Collections.emptyMap[String, String]) - - val updatedOptions = - new scala.collection.mutable.HashMap[String, String]() ++= options.options - - // Add checkpoint location if not present - options.checkpointLocation(indexName, flintSparkConf).foreach { location => - updatedOptions += (CHECKPOINT_LOCATION.toString -> location) - } - - // Update scheduler mode and refresh interval only if auto refresh is enabled - if (!options.autoRefresh()) { - return FlintSparkIndexOptions(updatedOptions.toMap) - } - - val externalSchedulerEnabled = flintSparkConf.isExternalSchedulerEnabled - val thresholdInterval = - IntervalSchedulerParser.parse(flintSparkConf.externalSchedulerIntervalThreshold()) - val currentInterval = options.refreshInterval().map(IntervalSchedulerParser.parse) - - ( - externalSchedulerEnabled, - currentInterval, - updatedOptions.get(SCHEDULER_MODE.toString)) match { - case (true, Some(interval), _) if interval.getInterval >= thresholdInterval.getInterval => - updatedOptions += (SCHEDULER_MODE.toString -> SchedulerMode.EXTERNAL.toString) - case (true, None, Some("external")) => - updatedOptions += (REFRESH_INTERVAL.toString -> flintSparkConf - .externalSchedulerIntervalThreshold()) - case (true, None, None) => - updatedOptions += (SCHEDULER_MODE.toString -> SchedulerMode.EXTERNAL.toString) - updatedOptions += (REFRESH_INTERVAL.toString -> flintSparkConf - .externalSchedulerIntervalThreshold()) - case (false, _, Some("external")) => - throw new IllegalArgumentException( - "External scheduler mode spark conf is not enabled but refresh interval is set to external scheduler mode") - case _ => - updatedOptions += (SCHEDULER_MODE.toString -> SchedulerMode.INTERNAL.toString) - } - - FlintSparkIndexOptions(updatedOptions.toMap) - } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index 6c34e00e1..78636d992 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.spark import java.util.Collections +import scala.collection.JavaConverters.mapAsJavaMapConverter import scala.collection.JavaConverters.mapAsScalaMapConverter import org.opensearch.flint.common.metadata.FlintMetadata @@ -49,6 +50,26 @@ object FlintSparkIndexFactory extends Logging { } } + /** + * Creates Flint index with default options. + * + * @param index + * Flint index + * @param metadata + * Flint metadata + * @return + * Flint index with default options + */ + def createWithDefaultOptions(index: FlintSparkIndex): Option[FlintSparkIndex] = { + val originalOptions = index.options + val updatedOptions = + FlintSparkIndexOptions.updateOptionsWithDefaults(index.name(), originalOptions) + val updatedMetadata = index + .metadata() + .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) + this.create(updatedMetadata) + } + private def doCreate(metadata: FlintMetadata): FlintSparkIndex = { val indexOptions = FlintSparkIndexOptions( metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala index e73e07d79..4bfc50c55 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala @@ -5,14 +5,16 @@ package org.opensearch.flint.spark -import java.util.UUID +import java.util.{Collections, UUID} import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization +import org.opensearch.flint.core.logging.CustomLogging.logInfo import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY} import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode +import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser import org.apache.spark.sql.flint.config.FlintSparkConf @@ -201,4 +203,64 @@ object FlintSparkIndexOptions { require(invalidOptions.isEmpty, s"option name ${invalidOptions.mkString(",")} is invalid") } + + /** + * Updates the options with default values. + * + * @param indexName + * The index name string + * @param options + * The original FlintSparkIndexOptions + * @return + * Updated FlintSparkIndexOptions + */ + def updateOptionsWithDefaults( + indexName: String, + options: FlintSparkIndexOptions): FlintSparkIndexOptions = { + val flintSparkConf = new FlintSparkConf(Collections.emptyMap[String, String]) + + val updatedOptions = + new scala.collection.mutable.HashMap[String, String]() ++= options.options + + // Add checkpoint location if not present + options.checkpointLocation(indexName, flintSparkConf).foreach { location => + updatedOptions += (CHECKPOINT_LOCATION.toString -> location) + } + + // Update scheduler mode and refresh interval only if auto refresh is enabled + if (!options.autoRefresh()) { + return FlintSparkIndexOptions(updatedOptions.toMap) + } + + val externalSchedulerEnabled = flintSparkConf.isExternalSchedulerEnabled + val thresholdInterval = + IntervalSchedulerParser.parse(flintSparkConf.externalSchedulerIntervalThreshold()) + val currentInterval = options.refreshInterval().map(IntervalSchedulerParser.parse) + ( + externalSchedulerEnabled, + currentInterval.isDefined, + updatedOptions.get(SCHEDULER_MODE.toString)) match { + case (true, true, None | Some("external")) + if currentInterval.get.getInterval >= thresholdInterval.getInterval => + updatedOptions += (SCHEDULER_MODE.toString -> SchedulerMode.EXTERNAL.toString) + case (true, true, Some("external")) + if currentInterval.get.getInterval < thresholdInterval.getInterval => + throw new IllegalArgumentException( + s"Input refresh_interval is ${options.refreshInterval().get}, required above the interval threshold of external scheduler: ${flintSparkConf + .externalSchedulerIntervalThreshold()}") + case (true, false, Some("external")) => + updatedOptions += (REFRESH_INTERVAL.toString -> flintSparkConf + .externalSchedulerIntervalThreshold()) + case (true, false, None) => + updatedOptions += (SCHEDULER_MODE.toString -> SchedulerMode.EXTERNAL.toString) + updatedOptions += (REFRESH_INTERVAL.toString -> flintSparkConf + .externalSchedulerIntervalThreshold()) + case (false, _, Some("external")) => + throw new IllegalArgumentException( + "spark.flint.job.externalScheduler.enabled is false but refresh interval is set to external scheduler mode") + case _ => + updatedOptions += (SCHEDULER_MODE.toString -> SchedulerMode.INTERNAL.toString) + } + FlintSparkIndexOptions(updatedOptions.toMap) + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala index 98f0d838f..f675df75a 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.spark.refresh import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkValidationHelper} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{INCREMENTAL, RefreshMode} +import org.opensearch.flint.spark.refresh.util.RefreshMetricsAspect import org.apache.spark.sql.SparkSession import org.apache.spark.sql.flint.config.FlintSparkConf @@ -19,9 +20,10 @@ import org.apache.spark.sql.flint.config.FlintSparkConf * @param index * Flint index */ -class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex) +class IncrementalIndexRefresh(val indexName: String, index: FlintSparkIndex) extends FlintSparkIndexRefresh - with FlintSparkValidationHelper { + with FlintSparkValidationHelper + with RefreshMetricsAspect { override def refreshMode: RefreshMode = INCREMENTAL @@ -43,15 +45,21 @@ class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex) override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = { logInfo(s"Start refreshing index $indexName in incremental mode") - // Reuse auto refresh which uses AvailableNow trigger and will stop once complete - val jobId = - new AutoIndexRefresh(indexName, index) - .start(spark, flintSparkConf) + val clientId = flintSparkConf.flintOptions().getAWSAccountId() + val dataSource = flintSparkConf.flintOptions().getDataSourceName() - // Blocks the calling thread until the streaming query finishes - spark.streams - .get(jobId.get) - .awaitTermination() - None + withMetrics(clientId, dataSource, indexName, "incrementalRefresh") { + // Reuse auto refresh which uses AvailableNow trigger and will stop once complete + val jobId = + new AutoIndexRefresh(indexName, index) + .start(spark, flintSparkConf) + + // Blocks the calling thread until the streaming query finishes + spark.streams + .get(jobId.get) + .awaitTermination() + + None + } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/util/RefreshMetricsAspect.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/util/RefreshMetricsAspect.scala new file mode 100644 index 000000000..c5832e01c --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/util/RefreshMetricsAspect.scala @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.refresh.util + +import org.opensearch.flint.core.metrics.MetricsUtil + +/** + * A trait that provides aspect-oriented metrics functionality for refresh operations. + * + * This trait can be mixed into classes that need to track metrics for various operations, + * particularly those related to index refreshing. It provides a method to wrap operations with + * metric collection, including timing and success/failure counting. + */ +trait RefreshMetricsAspect { + + /** + * Wraps an operation with metric collection. + * + * @param clientId + * The ID of the client performing the operation + * @param dataSource + * The name of the data source being used + * @param indexName + * The name of the index being operated on + * @param metricPrefix + * The prefix for the metrics (e.g., "incrementalRefresh") + * @param block + * The operation to be performed and measured + * @return + * The result of the operation + * + * This method will: + * 1. Start a timer for the operation 2. Execute the provided operation 3. Increment a success + * or failure counter based on the outcome 4. Stop the timer 5. Return the result of the + * operation or throw any exception that occurred + */ + def withMetrics(clientId: String, dataSource: String, indexName: String, metricPrefix: String)( + block: => Option[String]): Option[String] = { + val refreshMetricsHelper = new RefreshMetricsHelper(clientId, dataSource, indexName) + + val processingTimeMetric = s"$metricPrefix.processingTime" + val successMetric = s"$metricPrefix.success.count" + val failedMetric = s"$metricPrefix.failed.count" + + val timerContext = refreshMetricsHelper.getTimerContext(processingTimeMetric) + + try { + val result = block + refreshMetricsHelper.incrementCounter(successMetric) + result + } catch { + case e: Exception => + refreshMetricsHelper.incrementCounter(failedMetric) + throw e + } finally { + MetricsUtil.stopTimer(timerContext) + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/util/RefreshMetricsHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/util/RefreshMetricsHelper.scala new file mode 100644 index 000000000..4b91b0be2 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/util/RefreshMetricsHelper.scala @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.refresh.util + +import com.codahale.metrics.Timer +import org.opensearch.flint.core.metrics.MetricsUtil + +/** + * Helper class for constructing dimensioned metric names used in refresh operations. + */ +class RefreshMetricsHelper(clientId: String, dataSource: String, indexName: String) { + private val isIndexMetric = true + + /** + * Increments a counter metric with the specified dimensioned name. + * + * @param metricName + * The name of the metric to increment + */ + def incrementCounter(metricName: String): Unit = { + MetricsUtil.incrementCounter( + RefreshMetricsHelper.constructDimensionedMetricName( + metricName, + clientId, + dataSource, + indexName), + isIndexMetric) + } + + /** + * Gets a timer context for the specified metric name. + * + * @param metricName + * The name of the metric + * @return + * A Timer.Context object + */ + def getTimerContext(metricName: String): Timer.Context = { + MetricsUtil.getTimerContext( + RefreshMetricsHelper.constructDimensionedMetricName( + metricName, + clientId, + dataSource, + indexName), + isIndexMetric) + } +} + +object RefreshMetricsHelper { + + /** + * Constructs a dimensioned metric name for external scheduler request count. + * + * @param metricName + * The name of the metric + * @param clientId + * The ID of the client making the request + * @param dataSource + * The data source being used + * @param indexName + * The name of the index being refreshed + * @return + * A formatted string representing the dimensioned metric name + */ + private def constructDimensionedMetricName( + metricName: String, + clientId: String, + dataSource: String, + indexName: String): String = { + s"${metricName}[clientId##${clientId},dataSource##${dataSource},indexName##${indexName}]" + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java index 9865081c8..3620608b0 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java @@ -30,13 +30,13 @@ public enum AsyncQuerySchedulerAction { public static AsyncQueryScheduler build(FlintOptions options) { String className = options.getCustomAsyncQuerySchedulerClass(); - logger.info("Attempting to instantiate AsyncQueryScheduler with class name: {}", className); if (className.isEmpty()) { return new OpenSearchAsyncQueryScheduler(options); } // Attempts to instantiate AsyncQueryScheduler using reflection + logger.info("Attempting to instantiate AsyncQueryScheduler with class name: {}", className); try { Class asyncQuerySchedulerClass = Class.forName(className); Constructor constructor = asyncQuerySchedulerClass.getConstructor(); diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala index 87226d99f..d043746c0 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobExternalSchedulingService.scala @@ -7,9 +7,12 @@ package org.opensearch.flint.spark.scheduler import java.time.Instant +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState import org.opensearch.flint.common.scheduler.AsyncQueryScheduler import org.opensearch.flint.common.scheduler.model.{AsyncQuerySchedulerRequest, LangType} +import org.opensearch.flint.core.storage.OpenSearchClientUtils import org.opensearch.flint.spark.FlintSparkIndex +import org.opensearch.flint.spark.refresh.util.RefreshMetricsAspect import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder.AsyncQuerySchedulerAction import org.opensearch.flint.spark.scheduler.util.RefreshQueryGenerator @@ -32,45 +35,56 @@ class FlintSparkJobExternalSchedulingService( flintAsyncQueryScheduler: AsyncQueryScheduler, flintSparkConf: FlintSparkConf) extends FlintSparkJobSchedulingService + with RefreshMetricsAspect with Logging { + override val stateTransitions: StateTransitions = StateTransitions( + initialStateForUpdate = IndexState.ACTIVE, + finalStateForUpdate = IndexState.ACTIVE, + initialStateForUnschedule = IndexState.ACTIVE, + finalStateForUnschedule = IndexState.ACTIVE) + override def handleJob( index: FlintSparkIndex, action: AsyncQuerySchedulerAction): Option[String] = { val dataSource = flintSparkConf.flintOptions().getDataSourceName() val clientId = flintSparkConf.flintOptions().getAWSAccountId() - val indexName = index.name() + // This is to make sure jobId is consistent with the index name + val indexName = OpenSearchClientUtils.sanitizeIndexName(index.name()) logInfo(s"handleAsyncQueryScheduler invoked: $action") - val baseRequest = AsyncQuerySchedulerRequest - .builder() - .accountId(clientId) - .jobId(indexName) - .dataSource(dataSource) + withMetrics(clientId, dataSource, indexName, "externalScheduler") { + val baseRequest = AsyncQuerySchedulerRequest + .builder() + .accountId(clientId) + .jobId(indexName) + .dataSource(dataSource) - val request = action match { - case AsyncQuerySchedulerAction.SCHEDULE | AsyncQuerySchedulerAction.UPDATE => - val currentTime = Instant.now() - baseRequest - .scheduledQuery(RefreshQueryGenerator.generateRefreshQuery(index)) - .queryLang(LangType.SQL) - .interval(index.options.refreshInterval().get) - .enabled(true) - .enabledTime(currentTime) - .lastUpdateTime(currentTime) - .build() - case _ => baseRequest.build() - } + val request = action match { + case AsyncQuerySchedulerAction.SCHEDULE | AsyncQuerySchedulerAction.UPDATE => + val currentTime = Instant.now() + baseRequest + .scheduledQuery(RefreshQueryGenerator.generateRefreshQuery(index)) + .queryLang(LangType.SQL) + .interval(index.options.refreshInterval().get) + .enabled(true) + .enabledTime(currentTime) + .lastUpdateTime(currentTime) + .build() + case _ => baseRequest.build() + } - action match { - case AsyncQuerySchedulerAction.SCHEDULE => flintAsyncQueryScheduler.scheduleJob(request) - case AsyncQuerySchedulerAction.UPDATE => flintAsyncQueryScheduler.updateJob(request) - case AsyncQuerySchedulerAction.UNSCHEDULE => flintAsyncQueryScheduler.unscheduleJob(request) - case AsyncQuerySchedulerAction.REMOVE => flintAsyncQueryScheduler.removeJob(request) - case _ => throw new IllegalArgumentException(s"Unsupported action: $action") - } + action match { + case AsyncQuerySchedulerAction.SCHEDULE => flintAsyncQueryScheduler.scheduleJob(request) + case AsyncQuerySchedulerAction.UPDATE => flintAsyncQueryScheduler.updateJob(request) + case AsyncQuerySchedulerAction.UNSCHEDULE => + flintAsyncQueryScheduler.unscheduleJob(request) + case AsyncQuerySchedulerAction.REMOVE => flintAsyncQueryScheduler.removeJob(request) + case _ => throw new IllegalArgumentException(s"Unsupported action: $action") + } - None // Return None for all cases + None // Return None for all cases + } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobInternalSchedulingService.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobInternalSchedulingService.scala index ab22941bb..d22eff2c9 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobInternalSchedulingService.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobInternalSchedulingService.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.spark.scheduler import scala.collection.JavaConverters.mapAsJavaMapConverter +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexMonitor} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder.AsyncQuerySchedulerAction @@ -33,6 +34,12 @@ class FlintSparkJobInternalSchedulingService( extends FlintSparkJobSchedulingService with Logging { + override val stateTransitions: StateTransitions = StateTransitions( + initialStateForUpdate = IndexState.ACTIVE, + finalStateForUpdate = IndexState.REFRESHING, + initialStateForUnschedule = IndexState.REFRESHING, + finalStateForUnschedule = IndexState.ACTIVE) + /** * Handles job-related actions for a given Flint Spark index. * @@ -52,7 +59,7 @@ class FlintSparkJobInternalSchedulingService( action match { case AsyncQuerySchedulerAction.SCHEDULE => None // No-op case AsyncQuerySchedulerAction.UPDATE => - logInfo("Updating index state monitor") + logInfo("Scheduling index state monitor") flintIndexMonitor.startMonitor(indexName) startRefreshingJob(index) case AsyncQuerySchedulerAction.UNSCHEDULE => diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobSchedulingService.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobSchedulingService.scala index 40ef9fcbe..6e25d8a8c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobSchedulingService.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/FlintSparkJobSchedulingService.scala @@ -5,6 +5,7 @@ package org.opensearch.flint.spark.scheduler +import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState.IndexState import org.opensearch.flint.common.scheduler.AsyncQueryScheduler import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexMonitor} import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder.AsyncQuerySchedulerAction @@ -17,6 +18,14 @@ import org.apache.spark.sql.flint.config.FlintSparkConf */ trait FlintSparkJobSchedulingService { + case class StateTransitions( + initialStateForUpdate: IndexState, + finalStateForUpdate: IndexState, + initialStateForUnschedule: IndexState, + finalStateForUnschedule: IndexState) + + val stateTransitions: StateTransitions + /** * Handles a job action for a given Flint Spark index. * diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/RefreshQueryGenerator.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/RefreshQueryGenerator.scala index 510e0b9d5..86363b252 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/RefreshQueryGenerator.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/util/RefreshQueryGenerator.scala @@ -6,6 +6,7 @@ package org.opensearch.flint.spark.scheduler.util import org.opensearch.flint.spark.FlintSparkIndex +import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex @@ -25,11 +26,11 @@ object RefreshQueryGenerator { def generateRefreshQuery(index: FlintSparkIndex): String = { index match { case skippingIndex: FlintSparkSkippingIndex => - s"REFRESH SKIPPING INDEX ON ${skippingIndex.tableName}" + s"REFRESH SKIPPING INDEX ON ${quotedTableName(skippingIndex.tableName)}" case coveringIndex: FlintSparkCoveringIndex => - s"REFRESH INDEX ${coveringIndex.indexName} ON ${coveringIndex.tableName}" + s"REFRESH INDEX ${coveringIndex.indexName} ON ${quotedTableName(coveringIndex.tableName)}" case materializedView: FlintSparkMaterializedView => - s"REFRESH MATERIALIZED VIEW ${materializedView.mvName}" + s"REFRESH MATERIALIZED VIEW ${quotedTableName(materializedView.mvName)}" case _ => throw new IllegalArgumentException( s"Unsupported index type: ${index.getClass.getSimpleName}") diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala index ede5379a1..8f3aa9917 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala @@ -10,7 +10,6 @@ import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` import org.antlr.v4.runtime.tree.RuleNode import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.mv.FlintSparkMaterializedView -import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder} import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText, IndexBelongsTo} import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._ diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala index 1a164a9f2..0cde6ab0f 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala @@ -114,6 +114,19 @@ class FlintSparkConfSuite extends FlintSuite { } } + test("externalSchedulerIntervalThreshold should return default value when empty") { + val options = FlintSparkConf(Map("spark.flint.job.externalScheduler.interval" -> "").asJava) + assert(options + .externalSchedulerIntervalThreshold() === FlintOptions.DEFAULT_EXTERNAL_SCHEDULER_INTERVAL) + } + + test("externalSchedulerIntervalThreshold should return configured value when set") { + val configuredValue = "30" + val options = + FlintSparkConf(Map("spark.flint.job.externalScheduler.interval" -> configuredValue).asJava) + assert(options.externalSchedulerIntervalThreshold() === configuredValue) + } + /** * Delete index `indexNames` after calling `f`. */ diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala index a2ec85df9..063c32074 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala @@ -7,13 +7,20 @@ package org.opensearch.flint.spark import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{CHECKPOINT_LOCATION, REFRESH_INTERVAL, SCHEDULER_MODE} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode +import org.scalatest.Inspectors.forAll +import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.scalatest.matchers.should.Matchers.not.include +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.wordspec.AnyWordSpec import org.apache.spark.FlintSuite import org.apache.spark.sql.flint.config.FlintSparkConf -class FlintSparkIndexBuilderSuite extends FlintSuite { +class FlintSparkIndexBuilderSuite + extends FlintSuite + with Matchers + with TableDrivenPropertyChecks { val indexName: String = "test_index" val testCheckpointLocation = "/test/checkpoints/" @@ -143,71 +150,148 @@ class FlintSparkIndexBuilderSuite extends FlintSuite { } } - test( - "updateOptionsWithDefaults should set internal scheduler mode when auto refresh is false") { - val options = FlintSparkIndexOptions(Map("auto_refresh" -> "false")) - val builder = new FakeFlintSparkIndexBuilder - - val updatedOptions = builder.options(options, indexName).testOptions - updatedOptions.options.get(SCHEDULER_MODE.toString) shouldBe None - } - - test( - "updateOptionsWithDefaults should set internal scheduler mode when external scheduler is disabled") { - setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, false) - val options = FlintSparkIndexOptions(Map("auto_refresh" -> "true")) - val builder = new FakeFlintSparkIndexBuilder - - val updatedOptions = builder.options(options, indexName).testOptions - updatedOptions.options(SCHEDULER_MODE.toString) shouldBe SchedulerMode.INTERNAL.toString - } - - test( - "updateOptionsWithDefaults should set external scheduler mode when interval is above threshold") { - setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, true) - setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD, "5 minutes") - val options = - FlintSparkIndexOptions(Map("auto_refresh" -> "true", "refresh_interval" -> "10 minutes")) - val builder = new FakeFlintSparkIndexBuilder - - val updatedOptions = builder.options(options, indexName).testOptions - updatedOptions.options(SCHEDULER_MODE.toString) shouldBe SchedulerMode.EXTERNAL.toString - } - - test( - "updateOptionsWithDefaults should set external scheduler mode and default interval when no interval is provided") { - setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, true) - setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD, "5 minutes") - val options = FlintSparkIndexOptions(Map("auto_refresh" -> "true")) - val builder = new FakeFlintSparkIndexBuilder - - val updatedOptions = builder.options(options, indexName).testOptions - updatedOptions.options(SCHEDULER_MODE.toString) shouldBe SchedulerMode.EXTERNAL.toString - updatedOptions.options(REFRESH_INTERVAL.toString) shouldBe "5 minutes" - } - - test("updateOptionsWithDefaults should set external scheduler mode when explicitly specified") { - setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, true) - val options = - FlintSparkIndexOptions(Map("auto_refresh" -> "true", "scheduler_mode" -> "external")) - val builder = new FakeFlintSparkIndexBuilder - - val updatedOptions = builder.options(options, indexName).testOptions - updatedOptions.options(SCHEDULER_MODE.toString) shouldBe SchedulerMode.EXTERNAL.toString - } - - test( - "updateOptionsWithDefaults should throw exception when external scheduler is disabled but mode is external") { - setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, false) - val options = - FlintSparkIndexOptions(Map("auto_refresh" -> "true", "scheduler_mode" -> "external")) - val builder = new FakeFlintSparkIndexBuilder - - val exception = intercept[IllegalArgumentException] { - builder.options(options, indexName) + test("updateOptionsWithDefaults scenarios") { + val scenarios = Table( + ( + "testName", + "externalSchedulerEnabled", + "thresholdInterval", + "inputOptions", + "expectedMode", + "expectedInterval", + "expectedException"), + ( + "set internal mode when auto refresh is false", + false, + "5 minutes", + Map("auto_refresh" -> "false"), + None, + None, + None), + ( + "set internal mode when external scheduler is disabled", + false, + "5 minutes", + Map("auto_refresh" -> "true"), + Some(SchedulerMode.INTERNAL.toString), + None, + None), + ( + "set external mode when interval is above threshold", + true, + "5 minutes", + Map("auto_refresh" -> "true", "refresh_interval" -> "10 minutes"), + Some(SchedulerMode.EXTERNAL.toString), + Some("10 minutes"), + None), + ( + "set external mode and default interval when no interval provided", + true, + "5 minutes", + Map("auto_refresh" -> "true"), + Some(SchedulerMode.EXTERNAL.toString), + Some("5 minutes"), + None), + ( + "set external mode when explicitly specified", + true, + "5 minutes", + Map("auto_refresh" -> "true", "scheduler_mode" -> "external"), + Some(SchedulerMode.EXTERNAL.toString), + None, + None), + ( + "throw exception when external scheduler disabled but mode is external", + false, + "5 minutes", + Map("auto_refresh" -> "true", "scheduler_mode" -> "external"), + None, + None, + Some( + "spark.flint.job.externalScheduler.enabled is false but refresh interval is set to external scheduler mode")), + ( + "set external mode when interval above threshold and no mode specified", + true, + "5 minutes", + Map("auto_refresh" -> "true", "refresh_interval" -> "10 minutes"), + Some(SchedulerMode.EXTERNAL.toString), + Some("10 minutes"), + None), + ( + "throw exception when interval below threshold but mode is external", + true, + "5 minutes", + Map( + "auto_refresh" -> "true", + "refresh_interval" -> "1 minute", + "scheduler_mode" -> "external"), + None, + None, + Some("Input refresh_interval is 1 minute, required above the interval threshold")), + ( + "set external mode when interval above threshold and mode specified", + true, + "5 minutes", + Map( + "auto_refresh" -> "true", + "refresh_interval" -> "10 minute", + "scheduler_mode" -> "external"), + Some(SchedulerMode.EXTERNAL.toString), + None, + None), + ( + "set default interval when mode is external but no interval provided", + true, + "5 minutes", + Map("auto_refresh" -> "true", "scheduler_mode" -> "external"), + Some(SchedulerMode.EXTERNAL.toString), + Some("5 minutes"), + None), + ( + "set external mode when external scheduler enabled but no mode or interval specified", + true, + "5 minutes", + Map("auto_refresh" -> "true"), + Some(SchedulerMode.EXTERNAL.toString), + None, + None)) + + forAll(scenarios) { + ( + testName, + externalSchedulerEnabled, + thresholdInterval, + inputOptions, + expectedMode, + expectedInterval, + expectedException) => + withClue(s"Scenario: $testName - ") { + setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, externalSchedulerEnabled) + setFlintSparkConf( + FlintSparkConf.EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD, + thresholdInterval) + + val options = FlintSparkIndexOptions(inputOptions) + val builder = new FakeFlintSparkIndexBuilder + + expectedException match { + case Some(exceptionMessage) => + val exception = intercept[IllegalArgumentException] { + builder.options(options, indexName) + } + exception.getMessage should include(exceptionMessage) + + case None => + val updatedOptions = builder.options(options, indexName).testOptions + expectedMode.foreach { mode => + updatedOptions.options(SCHEDULER_MODE.toString) shouldBe mode + } + expectedInterval.foreach { interval => + updatedOptions.options(REFRESH_INTERVAL.toString) shouldBe interval + } + } + } } - exception.getMessage should include( - "External scheduler mode is not enabled in the configuration") } override def afterEach(): Unit = { diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/scheduler/util/RefreshQueryGeneratorSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/scheduler/util/RefreshQueryGeneratorSuite.scala index 1cd83c38d..0d154e407 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/scheduler/util/RefreshQueryGeneratorSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/scheduler/util/RefreshQueryGeneratorSuite.scala @@ -6,6 +6,7 @@ package org.opensearch.flint.spark.scheduler.util; import org.mockito.Mockito._ +import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.spark.FlintSparkIndex import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.mv.FlintSparkMaterializedView @@ -16,33 +17,42 @@ import org.apache.spark.SparkFunSuite class RefreshQueryGeneratorTest extends SparkFunSuite with Matchers { + val testTable = "dummy.default.testTable" + val expectedTableName = "dummy.default.`testTable`" + + val mockMetadata = mock(classOf[FlintMetadata]) + test("generateRefreshQuery should return correct query for FlintSparkSkippingIndex") { val mockIndex = mock(classOf[FlintSparkSkippingIndex]) - when(mockIndex.tableName).thenReturn("testTable") + when(mockIndex.metadata()).thenReturn(mockMetadata) + when(mockIndex.tableName).thenReturn(testTable) val result = RefreshQueryGenerator.generateRefreshQuery(mockIndex) - result shouldBe "REFRESH SKIPPING INDEX ON testTable" + result shouldBe s"REFRESH SKIPPING INDEX ON ${expectedTableName}" } test("generateRefreshQuery should return correct query for FlintSparkCoveringIndex") { val mockIndex = mock(classOf[FlintSparkCoveringIndex]) when(mockIndex.indexName).thenReturn("testIndex") - when(mockIndex.tableName).thenReturn("testTable") + when(mockIndex.tableName).thenReturn(testTable) val result = RefreshQueryGenerator.generateRefreshQuery(mockIndex) - result shouldBe "REFRESH INDEX testIndex ON testTable" + result shouldBe s"REFRESH INDEX testIndex ON ${expectedTableName}" } test("generateRefreshQuery should return correct query for FlintSparkMaterializedView") { val mockIndex = mock(classOf[FlintSparkMaterializedView]) - when(mockIndex.mvName).thenReturn("testMV") + when(mockIndex.metadata()).thenReturn(mockMetadata) + when(mockIndex.mvName).thenReturn(testTable) val result = RefreshQueryGenerator.generateRefreshQuery(mockIndex) - result shouldBe "REFRESH MATERIALIZED VIEW testMV" + result shouldBe s"REFRESH MATERIALIZED VIEW ${expectedTableName}" } test("generateRefreshQuery should throw IllegalArgumentException for unsupported index type") { val mockIndex = mock(classOf[FlintSparkIndex]) + when(mockIndex.metadata()).thenReturn(mockMetadata) + when(mockIndex.metadata().source).thenReturn(testTable) val exception = intercept[IllegalArgumentException] { RefreshQueryGenerator.generateRefreshQuery(mockIndex) diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala index 57277440e..11bc7271c 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.matchers.must.Matchers.{contain, defined} import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf._ import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -209,6 +210,50 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { } } + test("create skipping index with invalid refresh interval") { + setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, "true") + + val query = + s""" + | CREATE SKIPPING INDEX ON $testTable + | ( + | year PARTITION, + | name VALUE_SET, + | age MIN_MAX + | ) + | WITH (auto_refresh = true, refresh_interval = '2 minutes', scheduler_mode = 'external') + | """.stripMargin + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080t" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + + assert(result.status == "FAILED", s"expected status is FAILED, but got ${result.status}") + assert(!result.error.isEmpty, s"we expect error, but got ${result.error}") + + // Check for the specific error message + assert( + result.error.contains( + "Input refresh_interval is 2 minutes, required above the interval threshold of external scheduler: 5 minutes"), + s"Expected error message about invalid refresh interval, but got: ${result.error}") + + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + + // Ensure no streaming job was started + assert(spark.streams.active.isEmpty, "No streaming job should have been started") + conf.unsetConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED.key) + } + test("create skipping index with auto refresh and streaming job early exit") { // Custom listener to force streaming job to fail at the beginning val listener = new StreamingQueryListener { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala index 54d3ba6dc..378131eb0 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala @@ -15,6 +15,7 @@ import org.opensearch.client.RequestOptions import org.opensearch.flint.common.FlintVersion.current import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} +import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler import org.scalatest.matchers.must.Matchers.{contain, defined} @@ -194,7 +195,8 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { val sourceMap = response.getSourceAsMap sourceMap.get("jobId") shouldBe testFlintIndex - sourceMap.get("scheduledQuery") shouldBe s"REFRESH INDEX $testIndex ON $testTable" + sourceMap + .get("scheduledQuery") shouldBe s"REFRESH INDEX $testIndex ON ${quotedTableName(testTable)}" sourceMap.get("enabled") shouldBe true sourceMap.get("queryLang") shouldBe "sql" diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index c00e982e0..c2f0f9101 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -16,6 +16,7 @@ import org.opensearch.client.RequestOptions import org.opensearch.flint.common.FlintVersion.current import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} +import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.CHECKPOINT_LOCATION import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler @@ -365,7 +366,8 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { val sourceMap = response.getSourceAsMap sourceMap.get("jobId") shouldBe testFlintIndex - sourceMap.get("scheduledQuery") shouldBe s"REFRESH MATERIALIZED VIEW $testMvName" + sourceMap + .get("scheduledQuery") shouldBe s"REFRESH MATERIALIZED VIEW ${quotedTableName(testMvName)}" sourceMap.get("enabled") shouldBe true sourceMap.get("queryLang") shouldBe "sql" diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index f569bf123..9e75078d2 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -47,6 +47,8 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { super.afterEach() deleteTestIndex(testFlintIndex) sql(s"DROP TABLE $testTable") + conf.unsetConf(FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key) + conf.unsetConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED.key) } test("create materialized view with auto refresh") { @@ -119,8 +121,6 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { // Drop index with test scheduler sql(s"DROP MATERIALIZED VIEW $testMvName") - conf.unsetConf(FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key) - conf.unsetConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED.key) } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index b535173e3..a2a7c9799 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -16,7 +16,7 @@ import org.opensearch.client.RequestOptions import org.opensearch.flint.common.FlintVersion.current import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} -import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN +import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, ID_COLUMN} import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName @@ -338,7 +338,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { val sourceMap = response.getSourceAsMap sourceMap.get("jobId") shouldBe testIndex - sourceMap.get("scheduledQuery") shouldBe s"REFRESH SKIPPING INDEX ON $testTable" + sourceMap + .get("scheduledQuery") shouldBe s"REFRESH SKIPPING INDEX ON ${quotedTableName(testTable)}" sourceMap.get("enabled") shouldBe true sourceMap.get("queryLang") shouldBe "sql" diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala index 7bbf24567..53889045f 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala @@ -15,6 +15,8 @@ import org.opensearch.index.reindex.DeleteByQueryRequest import org.scalatest.matchers.must.Matchers._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.apache.spark.sql.flint.config.FlintSparkConf + class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { /** Test table and index name */ @@ -32,6 +34,7 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { // Delete all test indices deleteTestIndex(testIndex) sql(s"DROP TABLE $testTable") + conf.unsetConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED.key) } test("update index with index options successfully") { @@ -177,6 +180,121 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { } } + // Test update options validation failure with external scheduler + Seq( + ( + "update index without changing index option", + Seq( + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("auto_refresh" -> "true")), + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("checkpoint_location" -> "s3a://test/")), + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"))), + "No index option updated"), + ( + "update index option when auto_refresh is false", + Seq( + ( + Map.empty[String, String], + Map("auto_refresh" -> "false", "checkpoint_location" -> "s3a://test/")), + ( + Map.empty[String, String], + Map("incremental_refresh" -> "true", "checkpoint_location" -> "s3a://test/")), + (Map.empty[String, String], Map("checkpoint_location" -> "s3a://test/"))), + "No options can be updated when auto_refresh remains false"), + ( + "update other index option besides scheduler_mode when auto_refresh is true", + Seq( + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("watermark_delay" -> "1 Minute"))), + "Altering index when auto_refresh remains true only allows changing: Set(scheduler_mode). Invalid options"), + ( + "convert to full refresh with disallowed options", + Seq( + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("auto_refresh" -> "false", "scheduler_mode" -> "internal")), + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("auto_refresh" -> "false", "refresh_interval" -> "5 Minute")), + ( + Map("auto_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), + Map("auto_refresh" -> "false", "watermark_delay" -> "1 Minute"))), + "Altering index to full/incremental refresh only allows changing"), + ( + "convert to auto refresh with disallowed options", + Seq( + ( + Map.empty[String, String], + Map( + "auto_refresh" -> "true", + "output_mode" -> "complete", + "checkpoint_location" -> "s3a://test/"))), + "Altering index to auto refresh only allows changing: Set(auto_refresh, watermark_delay, scheduler_mode, " + + "refresh_interval, incremental_refresh, checkpoint_location). Invalid options: Set(output_mode)"), + ( + "convert to invalid refresh mode", + Seq( + ( + Map.empty[String, String], + Map( + "auto_refresh" -> "true", + "incremental_refresh" -> "true", + "checkpoint_location" -> "s3a://test/"))), + "Altering index to auto refresh while incremental refresh remains true")) + .foreach { case (testName, testCases, expectedErrorMessage) => + test(s"should fail if $testName and external scheduler enabled") { + setFlintSparkConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED, "true") + testCases.foreach { case (initialOptionsMap, updateOptionsMap) => + logInfo(s"initialOptionsMap: ${initialOptionsMap}") + logInfo(s"updateOptionsMap: ${updateOptionsMap}") + + withTempDir { checkpointDir => + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year", "month") + .options( + FlintSparkIndexOptions( + initialOptionsMap + .get("checkpoint_location") + .map(_ => + initialOptionsMap + .updated("checkpoint_location", checkpointDir.getAbsolutePath)) + .getOrElse(initialOptionsMap)), + testIndex) + .create() + flint.refreshIndex(testIndex) + + val index = flint.describeIndex(testIndex).get + val exception = the[IllegalArgumentException] thrownBy { + val updatedIndex = flint + .skippingIndex() + .copyWithUpdate( + index, + FlintSparkIndexOptions( + updateOptionsMap + .get("checkpoint_location") + .map(_ => + updateOptionsMap + .updated("checkpoint_location", checkpointDir.getAbsolutePath)) + .getOrElse(updateOptionsMap))) + flint.updateIndex(updatedIndex) + } + + exception.getMessage should include(expectedErrorMessage) + + deleteTestIndex(testIndex) + } + } + } + } + // Test update options validation success Seq( ( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index 4c38e1471..cbc4308b0 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.ppl -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, IsNotNull, Literal, Not, SortOrder} @@ -22,12 +22,20 @@ class FlintSparkPPLBasicITSuite /** Test table and index name */ private val testTable = "spark_catalog.default.flint_ppl_test" + private val t1 = "`spark_catalog`.`default`.`flint_ppl_test1`" + private val t2 = "`spark_catalog`.default.`flint_ppl_test2`" + private val t3 = "spark_catalog.`default`.`flint_ppl_test3`" + private val t4 = "`spark_catalog`.`default`.flint_ppl_test4" override def beforeAll(): Unit = { super.beforeAll() // Create test table createPartitionedStateCountryTable(testTable) + createPartitionedStateCountryTable(t1) + createPartitionedStateCountryTable(t2) + createPartitionedStateCountryTable(t3) + createPartitionedStateCountryTable(t4) } protected override def afterEach(): Unit = { @@ -516,4 +524,77 @@ class FlintSparkPPLBasicITSuite // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test backtick table names and name contains '.'") { + Seq(t1, t2, t3, t4).foreach { table => + val frame = sql(s""" + | source = $table| head 2 + | """.stripMargin) + assert(frame.collect().length == 2) + } + // test read table which is unable to create + val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`" + val t6 = "spark_catalog.default.`flint_ppl_test6.log`" + Seq(t5, t6).foreach { table => + val ex = intercept[AnalysisException](sql(s""" + | source = $table| head 2 + | """.stripMargin)) + assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND")) + } + val t7 = "spark_catalog.default.flint_ppl_test7.log" + val ex = intercept[IllegalArgumentException](sql(s""" + | source = $t7| head 2 + | """.stripMargin)) + assert(ex.getMessage().contains("Invalid table name")) + } + + test("test describe backtick table names and name contains '.'") { + Seq(t1, t2, t3, t4).foreach { table => + val frame = sql(s""" + | describe $table + | """.stripMargin) + assert(frame.collect().length > 0) + } + // test read table which is unable to create + val t5 = "`spark_catalog`.default.`flint/ppl/test5.log`" + val t6 = "spark_catalog.default.`flint_ppl_test6.log`" + Seq(t5, t6).foreach { table => + val ex = intercept[AnalysisException](sql(s""" + | describe $table + | """.stripMargin)) + assert(ex.getMessage().contains("TABLE_OR_VIEW_NOT_FOUND")) + } + val t7 = "spark_catalog.default.flint_ppl_test7.log" + val ex = intercept[IllegalArgumentException](sql(s""" + | describe $t7 + | """.stripMargin)) + assert(ex.getMessage().contains("Invalid table name")) + } + + test("test explain backtick table names and name contains '.'") { + Seq(t1, t2, t3, t4).foreach { table => + val frame = sql(s""" + | explain extended | source = $table + | """.stripMargin) + assert(frame.collect().length > 0) + } + // test read table which is unable to create + val table = "`spark_catalog`.default.`flint/ppl/test4.log`" + val frame = sql(s""" + | explain extended | source = $table + | """.stripMargin) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint/ppl/test4.log")) + val expectedPlan: LogicalPlan = + ExplainCommand( + Project(Seq(UnresolvedStar(None)), relation), + ExplainMode.fromString("extended")) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + val t7 = "spark_catalog.default.flint_ppl_test7.log" + val ex = intercept[IllegalArgumentException](sql(s""" + | explain extended | source = $t7 + | """.stripMargin)) + assert(ex.getMessage().contains("Invalid table name")) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala new file mode 100644 index 000000000..81bdd99df --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala @@ -0,0 +1,373 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, Exists, GreaterThan, InSubquery, ListQuery, Literal, Not, Or, ScalarSubquery, SortOrder} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLExistsSubqueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val outerTable = "spark_catalog.default.flint_ppl_test1" + private val innerTable = "spark_catalog.default.flint_ppl_test2" + private val nestedInnerTable = "spark_catalog.default.flint_ppl_test3" + + override def beforeAll(): Unit = { + super.beforeAll() + createPeopleTable(outerTable) + sql(s""" + | INSERT INTO $outerTable + | VALUES (1006, 'Tommy', 'Teacher', 'USA', 30000) + | """.stripMargin) + createWorkInformationTable(innerTable) + createOccupationTable(nestedInnerTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test simple exists subquery") { + val frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1002, "John", 120000), + Row(1003, "David", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = Filter( + Exists(Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner)), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test not exists subquery") { + val frame = sql(s""" + | source = $outerTable + | | where not exists [ + | source = $innerTable | where id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1001, "Hello", 70000), Row(1004, "David", 0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = + Filter( + Not( + Exists(Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test empty exists subquery") { + var frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where uid = 0000 AND id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + var results: Array[Row] = frame.collect() + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + var expectedResults: Array[Row] = Array() + assert(results.sorted.sameElements(expectedResults.sorted)) + + frame = sql(s""" + source = $outerTable + | | where not exists [ + | source = $innerTable | where uid = 0000 AND id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + results = frame.collect() + expectedResults = Array( + Row(1000, "Jake", 100000), + Row(1001, "Hello", 70000), + Row(1002, "John", 120000), + Row(1003, "David", 120000), + Row(1004, "David", 0), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test uncorrelated exists subquery") { + var frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where like(name, 'J%') + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + assert(results.length == 7) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = + Filter( + Exists( + Filter( + UnresolvedFunction( + "like", + Seq(UnresolvedAttribute("name"), Literal("J%")), + isDistinct = false), + inner)), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + frame = sql(s""" + | source = $outerTable + | | where not exists [ + | source = $innerTable | where like(name, 'J%') + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + assert(frame.collect().length == 0) + + frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where like(name, 'X%') + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + assert(frame.collect().length == 0) + } + + test("uncorrelated exists subquery check the return content of inner table is empty or not") { + var frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable + | ] + | | eval constant = "Bala" + | | fields constant + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Bala"), + Row("Bala"), + Row("Bala"), + Row("Bala"), + Row("Bala"), + Row("Bala"), + Row("Bala")) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = Filter(Exists(inner), outer) + val evalProject = + Project(Seq(UnresolvedStar(None), Alias(Literal("Bala"), "constant")()), existsSubquery) + val expectedPlan = Project(Seq(UnresolvedAttribute("constant")), evalProject) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where uid = 999 + | ] + | | eval constant = "Bala" + | | fields constant + | """.stripMargin) + frame.show + assert(frame.collect().length == 0) + } + + test("test nested exists subquery") { + val frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable + | | where exists [ + | source = $nestedInnerTable + | | where $nestedInnerTable.occupation = $innerTable.occupation + | ] + | | where id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val existsSubqueryForOccupation = + Filter( + Exists( + Filter( + EqualTo( + UnresolvedAttribute("spark_catalog.default.flint_ppl_test3.occupation"), + UnresolvedAttribute("spark_catalog.default.flint_ppl_test2.occupation")), + inner2)), + inner1) + val existsSubqueryForId = + Filter( + Exists( + Filter( + EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), + existsSubqueryForOccupation)), + outer) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubqueryForId) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test exists subquery with conjunction of conditions") { + val frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable + | | where id = uid AND + | $outerTable.name = $innerTable.name AND + | $outerTable.occupation = $innerTable.occupation + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1003, "David", 120000), Row(1000, "Jake", 100000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = Filter( + Exists( + Filter( + And( + And( + EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), + EqualTo( + UnresolvedAttribute("spark_catalog.default.flint_ppl_test1.name"), + UnresolvedAttribute("spark_catalog.default.flint_ppl_test2.name"))), + EqualTo( + UnresolvedAttribute("spark_catalog.default.flint_ppl_test1.occupation"), + UnresolvedAttribute("spark_catalog.default.flint_ppl_test2.occupation"))), + inner)), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala index ee08e692a..9d8c2c12d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala @@ -305,8 +305,6 @@ class FlintSparkPPLInSubqueryITSuite | | sort - salary | | fields id, name, salary | """.stripMargin) - frame.show() - frame.explain(true) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array(Row(1003, "David", 120000), Row(1002, "John", 120000), Row(1006, "Tommy", 30000)) @@ -358,7 +356,6 @@ class FlintSparkPPLInSubqueryITSuite | $innerTable | | fields a.id, a.name, a.salary | """.stripMargin) - frame.explain(true) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array(Row(1003, "David", 120000), Row(1002, "John", 120000), Row(1006, "Tommy", 30000)) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index a2b84d960..83bdb185d 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -120,6 +120,7 @@ APPEND: 'APPEND'; CASE: 'CASE'; ELSE: 'ELSE'; IN: 'IN'; +EXISTS: 'EXISTS'; // LOGICAL KEYWORDS NOT: 'NOT'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 266a8a709..a29c68f87 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -362,24 +362,21 @@ percentileAggFunction // expressions expression : logicalExpression - | comparisonExpression | valueExpression ; logicalExpression - : comparisonExpression # comparsion - | NOT logicalExpression # logicalNot - | left = logicalExpression OR right = logicalExpression # logicalOr + : NOT logicalExpression # logicalNot + | comparisonExpression # comparsion | left = logicalExpression (AND)? right = logicalExpression # logicalAnd + | left = logicalExpression OR right = logicalExpression # logicalOr | left = logicalExpression XOR right = logicalExpression # logicalXor | booleanExpression # booleanExpr - | isEmptyExpression # isEmptyExpr ; comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr | valueExpression IN valueList # inExpr - | valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr ; valueExpressionList @@ -408,7 +405,10 @@ positionFunction ; booleanExpression - : booleanFunctionCall + : booleanFunctionCall # booleanFunctionCallExpr + | isEmptyExpression # isEmptyExpr + | valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr + | EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr ; isEmptyExpression diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index eb415876b..5ac54127b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -21,7 +21,8 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; @@ -30,7 +31,7 @@ import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.ScalarSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedAttribute; @@ -313,4 +314,8 @@ public T visitFieldSummary(FieldSummary fieldSummary, C context) { public T visitScalarSubquery(ScalarSubquery node, C context) { return visitChildren(node, context); } + + public T visitExistsSubquery(ExistsSubquery node, C context) { + return visitChildren(node, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java new file mode 100644 index 000000000..bdd1683ee --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression.subquery; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class ExistsSubquery extends UnresolvedExpression { + private final UnresolvedPlan query; + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitExistsSubquery(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java similarity index 87% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java index ed40e4b45..4a15453e5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java @@ -3,16 +3,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.ast.expression; +package org.opensearch.sql.ast.expression.subquery; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.UnresolvedPlan; -import java.util.Arrays; import java.util.List; @Getter diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java similarity index 84% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java index cccadb717..7c3721ffb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java @@ -3,13 +3,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.ast.expression; +package org.opensearch.sql.ast.expression.subquery; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.UnresolvedPlan; @Getter diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index cb9bbd64d..e1732f75f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -49,6 +49,10 @@ public List getTableName() { return tableName.stream().map(Object::toString).collect(Collectors.toList()); } + public List getQualifiedNames() { + return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); + } + /** * Return alias. * diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 4ac7bb01c..76a7a0c79 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -13,6 +13,7 @@ import org.apache.spark.sql.catalyst.expressions.Ascending$; import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.InSubquery$; import org.apache.spark.sql.catalyst.expressions.ListQuery$; @@ -40,7 +41,8 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; @@ -49,7 +51,7 @@ import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.ScalarSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; @@ -88,7 +90,6 @@ import org.opensearch.sql.ppl.utils.ParseStrategy; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; -import scala.Option$; import scala.Tuple2; import scala.collection.IterableLike; import scala.collection.Seq; @@ -113,6 +114,7 @@ import static org.opensearch.sql.ppl.utils.LookupTransformer.buildLookupRelationProjectList; import static org.opensearch.sql.ppl.utils.LookupTransformer.buildOutputProjectList; import static org.opensearch.sql.ppl.utils.LookupTransformer.buildProjectListFromFields; +import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier; import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; @@ -152,22 +154,7 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { @Override public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { if (node instanceof DescribeRelation) { - TableIdentifier identifier; - if (node.getTableQualifiedName().getParts().size() == 1) { - identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0)); - } else if (node.getTableQualifiedName().getParts().size() == 2) { - identifier = new TableIdentifier( - node.getTableQualifiedName().getParts().get(1), - Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0))); - } else if (node.getTableQualifiedName().getParts().size() == 3) { - identifier = new TableIdentifier( - node.getTableQualifiedName().getParts().get(2), - Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)), - Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1))); - } else { - throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName() - + " Syntax: [ database_name. ] table_name"); - } + TableIdentifier identifier = getTableIdentifier(node.getTableQualifiedName()); return context.with( new DescribeTableCommand( identifier, @@ -176,9 +163,9 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { DescribeRelation$.MODULE$.getOutputAttrs())); } //regular sql algebraic relations - node.getTableName().forEach(t -> + node.getQualifiedNames().forEach(q -> // Resolving the qualifiedName which is composed of a datasource.schema.table - context.withRelation(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false)) + context.withRelation(new UnresolvedRelation(getTableIdentifier(q).nameParts(), CaseInsensitiveStringMap.empty(), false)) ); return context.getPlan(); } @@ -327,7 +314,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex seq(new ArrayList()))); context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan)); } - //visit TopAggregation results limit + //visit TopAggregation results limit if ((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) { context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( ((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p)); @@ -836,5 +823,19 @@ public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext c Option.empty()); return context.getNamedParseExpressions().push(scalarSubQuery); } + + @Override + public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); + Expression existsSubQuery = Exists$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty()); + return context.getNamedParseExpressions().push(existsSubQuery); + } } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 432a0092c..47220174f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -23,7 +23,8 @@ import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldList; import org.opensearch.sql.ast.expression.Function; -import org.opensearch.sql.ast.expression.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; @@ -33,7 +34,7 @@ import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.ScalarSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedArgument; @@ -421,6 +422,11 @@ public UnresolvedExpression visitScalarSubqueryExpr(OpenSearchPPLParser.ScalarSu return new ScalarSubquery(astBuilder.visitSubSearch(ctx.subSearch())); } + @Override + public UnresolvedExpression visitExistsSubqueryExpr(OpenSearchPPLParser.ExistsSubqueryExprContext ctx) { + return new ExistsSubquery(astBuilder.visitSubSearch(ctx.subSearch())); + } + private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java index 33cb5611d..7be7f1f45 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java @@ -1,8 +1,10 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.TableIdentifier; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.ast.expression.QualifiedName; +import scala.Option$; import java.util.List; import java.util.Optional; @@ -15,7 +17,7 @@ public interface RelationUtils { * * @param relations * @param node - * @param contextRelations + * @param tables * @return */ static Optional resolveField(List relations, QualifiedName node, List tables) { @@ -29,4 +31,26 @@ static Optional resolveField(List relations, .findFirst() .map(rel -> node); } + + static TableIdentifier getTableIdentifier(QualifiedName qualifiedName) { + TableIdentifier identifier; + if (qualifiedName.getParts().isEmpty()) { + throw new IllegalArgumentException("Empty table name is invalid"); + } else if (qualifiedName.getParts().size() == 1) { + identifier = new TableIdentifier(qualifiedName.getParts().get(0)); + } else if (qualifiedName.getParts().size() == 2) { + identifier = new TableIdentifier( + qualifiedName.getParts().get(1), + Option$.MODULE$.apply(qualifiedName.getParts().get(0))); + } else if (qualifiedName.getParts().size() == 3) { + identifier = new TableIdentifier( + qualifiedName.getParts().get(2), + Option$.MODULE$.apply(qualifiedName.getParts().get(1)), + Option$.MODULE$.apply(qualifiedName.getParts().get(0))); + } else { + throw new IllegalArgumentException("Invalid table name: " + qualifiedName + + " Syntax: [ database_name. ] table_name"); + } + return identifier; + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index cc87e8853..96176982e 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -37,13 +37,26 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite thrown.getMessage === "Invalid table name: t.b.c.d Syntax: [ database_name. ] table_name") } + test("test describe with backticks") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "describe t.b.`c.d`"), context) + + val expectedPlan = DescribeTableCommand( + TableIdentifier("c.d", Option("b"), Option("t")), + Map.empty[String, String].empty, + isExtended = true, + output = DescribeRelation.getOutputAttrs) + comparePlans(expectedPlan, logPlan, false) + } + test("test describe FQN table clause") { val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "describe schema.default.http_logs"), context) + planTransformer.visit(plan(pplParser, "describe catalog.schema.http_logs"), context) val expectedPlan = DescribeTableCommand( - TableIdentifier("http_logs", Option("schema"), Option("default")), + TableIdentifier("http_logs", Option("schema"), Option("catalog")), Map.empty[String, String].empty, isExtended = true, output = DescribeRelation.getOutputAttrs) @@ -64,10 +77,10 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite test("test FQN table describe table clause") { val context = new CatalystPlanContext - val logPlan = planTransformer.visit(plan(pplParser, "describe catalog.t"), context) + val logPlan = planTransformer.visit(plan(pplParser, "describe schema.t"), context) val expectedPlan = DescribeTableCommand( - TableIdentifier("t", Option("catalog")), + TableIdentifier("t", Option("schema")), Map.empty[String, String].empty, isExtended = true, output = DescribeRelation.getOutputAttrs) diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExistsSubqueryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExistsSubqueryTranslatorTestSuite.scala new file mode 100644 index 000000000..02dfe1096 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExistsSubqueryTranslatorTestSuite.scala @@ -0,0 +1,315 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, Exists, GreaterThanOrEqual, LessThan, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ + +class PPLLogicalPlanExistsSubqueryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + // Assume outer table contains fields [a, b] + // and inner table contains fields [c, d] + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test where exists (select * from inner where a = c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where exists [ + | source = spark_catalog.default.inner | where a = c + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val subquery = + Filter( + Exists(Filter(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), inner)), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, subquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where exists (select * from inner where a = c and b = d)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where exists [ + | source = spark_catalog.default.inner | where a = c AND b = d + | ] + | | sort - a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val existsSubquery = + Filter( + Exists( + Filter( + And( + EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), + EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d"))), + inner)), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, existsSubquery) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where not exists (select * from inner where a = c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where not exists [ + | source = spark_catalog.default.inner | where a = c + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val subquery = + Filter( + Not(Exists(Filter(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, subquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where not exists (select * from inner where a = c and b = d)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where not exists [ + | source = spark_catalog.default.inner | where a = c AND b = d + | ] + | | sort - a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val existsSubquery = + Filter( + Not( + Exists( + Filter( + And( + EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), + EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d"))), + inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, existsSubquery) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + // Assume outer table contains fields [a, b] + // and inner1 table contains fields [c, d] + // and inner2 table contains fields [e, f] + test("test nested exists subquery") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where exists [ + | source = spark_catalog.default.inner1 + | | where exists [ + | source = spark_catalog.default.inner2 + | | where c = e + | ] + | | where a = c + | ] + | | sort - a + | | fields a, b + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "inner1")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "inner2")) + val subqueryOuter = + Filter( + Exists(Filter(EqualTo(UnresolvedAttribute("c"), UnresolvedAttribute("e")), inner2)), + inner1) + val subqueryInner = + Filter( + Exists( + Filter(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), subqueryOuter)), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, subqueryInner) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test tpch q4: exists subquery with aggregation") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = orders + | | where o_orderdate >= "1993-07-01" AND o_orderdate < "1993-10-01" + | AND exists [ + | source = lineitem + | | where l_orderkey = o_orderkey + | AND l_commitdate < l_receiptdate + | ] + | | stats count(1) as order_count by o_orderpriority + | | sort o_orderpriority + | | fields o_orderpriority, order_count + | """.stripMargin), + context) + + val outer = UnresolvedRelation(Seq("orders")) + val inner = UnresolvedRelation(Seq("lineitem")) + val inSubquery = + Filter( + And( + And( + GreaterThanOrEqual(UnresolvedAttribute("o_orderdate"), Literal("1993-07-01")), + LessThan(UnresolvedAttribute("o_orderdate"), Literal("1993-10-01"))), + Exists( + Filter( + And( + EqualTo(UnresolvedAttribute("l_orderkey"), UnresolvedAttribute("o_orderkey")), + LessThan( + UnresolvedAttribute("l_commitdate"), + UnresolvedAttribute("l_receiptdate"))), + inner))), + outer) + val o_orderpriorityAlias = Alias(UnresolvedAttribute("o_orderpriority"), "o_orderpriority")() + val groupByAttributes = Seq(o_orderpriorityAlias) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(Literal(1)), isDistinct = false), + "order_count")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, o_orderpriorityAlias), inSubquery) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("o_orderpriority"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project( + Seq(UnresolvedAttribute("o_orderpriority"), UnresolvedAttribute("order_count")), + sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } + + // We can support q21 when the table alias is supported + ignore("test tpch q21 (partial): multiple exists subquery") { + // select + // s_name, + // count(*) as numwait + // from + // supplier, + // lineitem l1, + // where + // s_suppkey = l1.l_suppkey + // and l1.l_receiptdate > l1.l_commitdate + // and exists ( + // select + // * + // from + // lineitem l2 + // where + // l2.l_orderkey = l1.l_orderkey + // and l2.l_suppkey <> l1.l_suppkey + // ) + // and not exists ( + // select + // * + // from + // lineitem l3 + // where + // l3.l_orderkey = l1.l_orderkey + // and l3.l_suppkey <> l1.l_suppkey + // and l3.l_receiptdate > l3.l_commitdate + // ) + // group by + // s_name + // order by + // numwait desc, + // s_name + // limit 100 + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = supplier + | | join left=s right=l1 on s_suppkey = l1.l_suppkey + | lineitem as l1 + | | where l1.l_receiptdate > l1.l_commitdate + | | where exists [ + | source = lineitem as l2 + | | where l2.l_orderkey = l1.l_orderkey and + | l2.l_suppkey <> l1.l_suppkey + | ] + | | where not exists [ + | source = lineitem as l3 + | | where l3.l_orderkey = l1.l_orderkey and + | l3.l_suppkey <> l1.l_suppkey and + | l3.l_receiptdate > l3.l_commitdate + | ] + | | stats count(1) as numwait by s_name + | | sort - numwait, s_name + | | fields s_name, numwait + | | limit 100 + | """.stripMargin), + context) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala index 407b3df84..20809db95 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala @@ -5,23 +5,15 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals -import org.mockito.Mockito.when import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers -import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, TableFunctionRegistry, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Like, Literal, NamedExpression, Not, Or, SortOrder, UnixTimestamp} -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class PPLLogicalPlanFiltersTranslatorTestSuite extends SparkFunSuite @@ -219,4 +211,26 @@ class PPLLogicalPlanFiltersTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + + test("test order of evaluation of predicate expression") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | where department = 'HR' OR job_title = 'Manager' AND salary > 50000"), + context) + + val table = UnresolvedRelation(Seq("employees")) + val filter = + Filter( + Or( + EqualTo(UnresolvedAttribute("department"), Literal("HR")), + And( + EqualTo(UnresolvedAttribute("job_title"), Literal("Manager")), + GreaterThan(UnresolvedAttribute("salary"), Literal(50000)))), + table) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), filter) + comparePlans(expectedPlan, logPlan, false) + } }