diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 3efca3205..409b128c9 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -60,6 +60,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` +- `source = table | trendline sma(2, temperature) as temp_trend` ```sql source = table | eval status_category = @@ -122,6 +123,15 @@ Assumptions: `a`, `b`, `c`, `d`, `e` are existing fields in `table` - `source = table | fillnull using a = 101, b = 102` - `source = table | fillnull using a = concat(b, c), d = 2 * pi() * e` +### Flatten +[See additional command details](ppl-flatten-command.md) +Assumptions: `bridges`, `coor` are existing fields in `table`, and the fields' types are `struct<?,?>` or `array<struct<?,?>>` +- `source = table | flatten bridges` +- `source = table | flatten coor` +- `source = table | flatten bridges | flatten coor` +- `source = table | fields bridges | flatten bridges` +- `source = table | fields country, bridges | flatten bridges | fields country, length | stats avg(length) as avg by country` + ```sql source = table | eval e = eval status_category = case(a >= 200 AND a < 300, 'Success', diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 8d9b86eda..6ba49b031 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -31,6 +31,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). 
- [`describe command`](PPL-Example-Commands.md/#describe) - [`fillnull command`](ppl-fillnull-command.md) + + - [`flatten command`](ppl-flatten-command.md) - [`eval command`](ppl-eval-command.md) @@ -67,7 +69,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`subquery commands`](ppl-subquery-command.md) - [`correlation commands`](ppl-correlation-command.md) - + + - [`trendline commands`](ppl-trendline-command.md) * **Functions** diff --git a/docs/ppl-lang/ppl-flatten-command.md b/docs/ppl-lang/ppl-flatten-command.md new file mode 100644 index 000000000..4c1ae5d0d --- /dev/null +++ b/docs/ppl-lang/ppl-flatten-command.md @@ -0,0 +1,90 @@ +## PPL `flatten` command + +### Description +Using `flatten` command to flatten a field of type: +- `struct<?,?>` +- `array<struct<?,?>>` + + +### Syntax +`flatten <field>` + +* field: to be flattened. The field must be of supported type. + +### Test table +#### Schema +| col\_name | data\_type | +|-----------|-------------------------------------------------| +| \_time | string | +| bridges | array\<struct\<length:bigint,name:string\>\> | +| city | string | +| coor | struct\<alt:bigint,lat:double,long:double\> | +| country | string | +#### Data +| \_time | bridges | city | coor | country | +|---------------------|----------------------------------------------|---------|------------------------|---------------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | {35, 51.5074, -0.1278} | England | +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | {35, 48.8566, 2.3522} | France | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | {2, 45.4408, 12.3155} | Italy | +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | {200, 50.0755, 14.4378}| Czech Republic| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| {96, 47.4979, 19.0402} | Hungary | +| 1990-09-13T12:00:00 | NULL | Warsaw | NULL | Poland | + + + +### Example 1: flatten struct +This example 
shows how to flatten a struct field. +PPL query: + - `source=table | flatten coor` + +| \_time | bridges | city | country | alt | lat | long | +|---------------------|----------------------------------------------|---------|---------------|-----|--------|--------| +| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | England | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | France | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | Italy | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | Czech Republic| 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| Hungary | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | NULL | Warsaw | Poland | NULL| NULL | NULL | + + + +### Example 2: flatten array + +The example shows how to flatten an array of struct fields. 
+ +PPL query: + - `source=table | flatten bridges` + +| \_time | city | coor | country | length | name | +|---------------------|---------|------------------------|---------------|--------|-------------------| +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 801 | Tower Bridge | +| 2024-09-13T12:00:00 | London | {35, 51.5074, -0.1278} | England | 928 | London Bridge | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 232 | Pont Neuf | +| 2024-09-13T12:00:00 | Paris | {35, 48.8566, 2.3522} | France | 160 | Pont Alexandre III| +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 48 | Rialto Bridge | +| 2024-09-13T12:00:00 | Venice | {2, 45.4408, 12.3155} | Italy | 11 | Bridge of Sighs | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 516 | Charles Bridge | +| 2024-09-13T12:00:00 | Prague | {200, 50.0755, 14.4378}| Czech Republic| 343 | Legion Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 375 | Chain Bridge | +| 2024-09-13T12:00:00 | Budapest| {96, 47.4979, 19.0402} | Hungary | 333 | Liberty Bridge | +| 1990-09-13T12:00:00 | Warsaw | NULL | Poland | NULL | NULL | + + +### Example 3: flatten array and struct +This example shows how to flatten multiple fields. 
+PPL query: + - `source=table | flatten bridges | flatten coor` + +| \_time | city | country | length | name | alt | lat | long | +|---------------------|---------|---------------|--------|-------------------|------|--------|--------| +| 2024-09-13T12:00:00 | London | England | 801 | Tower Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | London | England | 928 | London Bridge | 35 | 51.5074| -0.1278| +| 2024-09-13T12:00:00 | Paris | France | 232 | Pont Neuf | 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Paris | France | 160 | Pont Alexandre III| 35 | 48.8566| 2.3522 | +| 2024-09-13T12:00:00 | Venice | Italy | 48 | Rialto Bridge | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Venice | Italy | 11 | Bridge of Sighs | 2 | 45.4408| 12.3155| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 516 | Charles Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Prague | Czech Republic| 343 | Legion Bridge | 200 | 50.0755| 14.4378| +| 2024-09-13T12:00:00 | Budapest| Hungary | 375 | Chain Bridge | 96 | 47.4979| 19.0402| +| 2024-09-13T12:00:00 | Budapest| Hungary | 333 | Liberty Bridge | 96 | 47.4979| 19.0402| +| 1990-09-13T12:00:00 | Warsaw | Poland | NULL | NULL | NULL | NULL | NULL | \ No newline at end of file diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md new file mode 100644 index 000000000..393a9dd59 --- /dev/null +++ b/docs/ppl-lang/ppl-trendline-command.md @@ -0,0 +1,60 @@ +## PPL trendline Command + +**Description** +Using ``trendline`` command to calculate moving averages of fields. + + +### Syntax +`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory when sorting is used. The field used to sort. 
+* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. the name of the field the moving average should be calculated for. +* alias: optional. the name of the resulting column containing the moving average. + +At the moment only the Simple Moving Average (SMA) type is supported. + +It is calculated like + + f[i]: The value of field 'f' in the i-th data-point + n: The number of data-points in the moving window (period) + t: The current time index + + SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t + +### Example 1: Calculate simple moving average for a timeseries of temperatures + +The example calculates the simple moving average over temperatures using two datapoints. + +PPL query: + + os> source=t | trendline sma(2, temperature) as temp_trend; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+----------+ + |temperature|device-id| timestamp|temp_trend| + +-----------+---------+--------------------+----------+ + | 12| 1492|2023-04-06 17:07:...| NULL| + | 12| 1492|2023-04-06 17:07:...| 12.0| + | 13| 256|2023-04-06 17:07:...| 12.5| + | 14| 257|2023-04-06 17:07:...| 13.5| + | 15| 258|2023-04-06 17:07:...| 14.5| + +-----------+---------+--------------------+----------+ + +### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting + +The example calculates two simple moving averages over temperatures using two and three datapoints sorted descending by device-id. 
+ +PPL query: + + os> source=t | trendline sort - device-id sma(2, temperature) as temp_trend_2 sma(3, temperature) as temp_trend_3; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+------------+------------------+ + |temperature|device-id| timestamp|temp_trend_2| temp_trend_3| + +-----------+---------+--------------------+------------+------------------+ + | 15| 258|2023-04-06 17:07:...| NULL| NULL| + | 14| 257|2023-04-06 17:07:...| 14.5| NULL| + | 13| 256|2023-04-06 17:07:...| 13.5| 14.0| + | 12| 1492|2023-04-06 17:07:...| 12.5| 13.0| + | 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334| + +-----------+---------+--------------------+------------+------------------+ diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index b4412a3d4..68d2409ee 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -60,7 +60,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w private val flintMetadataCacheWriter = FlintMetadataCacheWriterBuilder.build(flintSparkConf) private val flintAsyncQueryScheduler: AsyncQueryScheduler = { - AsyncQuerySchedulerBuilder.build(flintSparkConf.flintOptions()) + AsyncQuerySchedulerBuilder.build(spark, flintSparkConf.flintOptions()) } override protected val flintMetadataLogService: FlintMetadataLogService = { @@ -183,7 +183,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w attachLatestLogEntry(indexName, metadata) } .toList - .flatMap(FlintSparkIndexFactory.create) + .flatMap(metadata => FlintSparkIndexFactory.create(spark, metadata)) } else { Seq.empty } @@ -202,7 +202,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w if (flintClient.exists(indexName)) 
{ val metadata = flintIndexMetadataService.getIndexMetadata(indexName) val metadataWithEntry = attachLatestLogEntry(indexName, metadata) - FlintSparkIndexFactory.create(metadataWithEntry) + FlintSparkIndexFactory.create(spark, metadataWithEntry) } else { Option.empty } @@ -327,7 +327,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w val index = describeIndex(indexName) if (index.exists(_.options.autoRefresh())) { - val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(index.get).get + val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(spark, index.get).get FlintSparkIndexRefresh .create(updatedIndex.name(), updatedIndex) .validate(spark) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index 0391741cf..2ff2883a9 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -92,7 +92,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get) + validateIndex(FlintSparkIndexFactory.create(flint.spark, updatedMetadata).get) } /** diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index 78636d992..ca659550d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -25,6 +25,7 @@ import 
org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession /** * Flint Spark index factory that encapsulates specific Flint index instance creation. This is for @@ -35,14 +36,16 @@ object FlintSparkIndexFactory extends Logging { /** * Creates Flint index from generic Flint metadata. * + * @param spark + * Spark session * @param metadata * Flint metadata * @return * Flint index instance, or None if any error during creation */ - def create(metadata: FlintMetadata): Option[FlintSparkIndex] = { + def create(spark: SparkSession, metadata: FlintMetadata): Option[FlintSparkIndex] = { try { - Some(doCreate(metadata)) + Some(doCreate(spark, metadata)) } catch { case e: Exception => logWarning(s"Failed to create Flint index from metadata $metadata", e) @@ -53,24 +56,26 @@ object FlintSparkIndexFactory extends Logging { /** * Creates Flint index with default options. 
* + * @param spark + * Spark session * @param index * Flint index - * @param metadata - * Flint metadata * @return * Flint index with default options */ - def createWithDefaultOptions(index: FlintSparkIndex): Option[FlintSparkIndex] = { + def createWithDefaultOptions( + spark: SparkSession, + index: FlintSparkIndex): Option[FlintSparkIndex] = { val originalOptions = index.options val updatedOptions = FlintSparkIndexOptions.updateOptionsWithDefaults(index.name(), originalOptions) val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - this.create(updatedMetadata) + this.create(spark, updatedMetadata) } - private def doCreate(metadata: FlintMetadata): FlintSparkIndex = { + private def doCreate(spark: SparkSession, metadata: FlintMetadata): FlintSparkIndex = { val indexOptions = FlintSparkIndexOptions( metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) val latestLogEntry = metadata.latestLogEntry @@ -118,6 +123,7 @@ object FlintSparkIndexFactory extends Logging { FlintSparkMaterializedView( metadata.name, metadata.source, + getMvSourceTables(spark, metadata), metadata.indexedColumns.map { colInfo => getString(colInfo, "columnName") -> getString(colInfo, "columnType") }.toMap, @@ -134,6 +140,15 @@ object FlintSparkIndexFactory extends Logging { .toMap } + private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = { + val sourceTables = getArrayString(metadata.properties, "sourceTables") + if (sourceTables.isEmpty) { + FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source) + } else { + sourceTables + } + } + private def getString(map: java.util.Map[String, AnyRef], key: String): String = { map.get(key).asInstanceOf[String] } @@ -146,4 +161,12 @@ object FlintSparkIndexFactory extends Logging { Some(value.asInstanceOf[String]) } } + + private def getArrayString(map: java.util.Map[String, AnyRef], key: String): Array[String] = { + 
map.get(key) match { + case list: java.util.ArrayList[_] => + list.toArray.map(_.toString) + case _ => Array.empty[String] + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala index 1aaa85075..7e9922655 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala @@ -11,9 +11,8 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName} +import org.apache.spark.sql.flint.{loadTable, parseTableName} /** * Flint Spark validation helper. 
@@ -31,16 +30,10 @@ trait FlintSparkValidationHelper extends Logging { * true if all non Hive, otherwise false */ def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = { - // Extract source table name (possibly more than one for MV query) val tableNames = index match { case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName) case covering: FlintSparkCoveringIndex => Seq(covering.tableName) - case mv: FlintSparkMaterializedView => - spark.sessionState.sqlParser - .parsePlan(mv.query) - .collect { case relation: UnresolvedRelation => - qualifyTableName(spark, relation.tableName) - } + case mv: FlintSparkMaterializedView => mv.sourceTables.toSeq } // Validate if any source table is not supported (currently Hive only) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala index e8a91e1be..e1c0f318c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCache.scala @@ -10,6 +10,7 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark.FlintSparkIndexOptions +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser /** @@ -46,9 +47,7 @@ case class FlintMetadataCache( object FlintMetadataCache { - // TODO: constant for version - val mockTableName = - "dataSourceName.default.logGroups(logGroupIdentifier:['arn:aws:logs:us-east-1:123456:test-llt-xa', 'arn:aws:logs:us-east-1:123456:sample-lg-1'])" + val metadataCacheVersion = "1.0" def apply(metadata: 
FlintMetadata): FlintMetadataCache = { val indexOptions = FlintSparkIndexOptions( @@ -61,6 +60,15 @@ object FlintMetadataCache { } else { None } + val sourceTables = metadata.kind match { + case MV_INDEX_TYPE => + metadata.properties.get("sourceTables") match { + case list: java.util.ArrayList[_] => + list.toArray.map(_.toString) + case _ => Array.empty[String] + } + case _ => Array(metadata.source) + } val lastRefreshTime: Option[Long] = metadata.latestLogEntry.flatMap { entry => entry.lastRefreshCompleteTime match { case FlintMetadataLogEntry.EMPTY_TIMESTAMP => None @@ -68,7 +76,6 @@ object FlintMetadataCache { } } - // TODO: get source tables from metadata - FlintMetadataCache("1.0", refreshInterval, Array(mockTableName), lastRefreshTime) + FlintMetadataCache(metadataCacheVersion, refreshInterval, sourceTables, lastRefreshTime) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index caa75be75..aecfc99df 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * MV name * @param query * source query that generates MV data + * @param sourceTables + * source table names * @param outputSchema * output schema * @param options @@ -44,6 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class FlintSparkMaterializedView( mvName: String, query: String, + sourceTables: Array[String], outputSchema: Map[String, String], override val options: FlintSparkIndexOptions = empty, override val latestLogEntry: Option[FlintMetadataLogEntry] = None) @@ -64,6 +67,7 @@ case class FlintSparkMaterializedView( metadataBuilder(this) .name(mvName) .source(query) + 
.addProperty("sourceTables", sourceTables) .indexedColumns(indexColumnMaps) .schema(schema) .build() @@ -165,10 +169,30 @@ object FlintSparkMaterializedView { flintIndexNamePrefix(mvName) } + /** + * Extract source table names (possibly more than one) from the query. + * + * @param spark + * Spark session + * @param query + * source query that generates MV data + * @return + * source table names + */ + def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = { + spark.sessionState.sqlParser + .parsePlan(query) + .collect { case relation: UnresolvedRelation => + qualifyTableName(spark, relation.tableName) + } + .toArray + } + /** Builder class for MV build */ class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { private var mvName: String = "" private var query: String = "" + private var sourceTables: Array[String] = Array.empty[String] /** * Set MV name. @@ -193,6 +217,7 @@ object FlintSparkMaterializedView { */ def query(query: String): Builder = { this.query = query + this.sourceTables = extractSourceTableNames(flint.spark, query) this } @@ -221,7 +246,7 @@ object FlintSparkMaterializedView { field.name -> field.dataType.simpleString } .toMap - FlintSparkMaterializedView(mvName, query, outputSchema, indexOptions) + FlintSparkMaterializedView(mvName, query, sourceTables, outputSchema, indexOptions) } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java index 3620608b0..330b38f02 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/AsyncQuerySchedulerBuilder.java @@ -7,9 +7,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import 
org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.flint.config.FlintSparkConf; import org.opensearch.flint.common.scheduler.AsyncQueryScheduler; import org.opensearch.flint.core.FlintOptions; +import java.io.IOException; import java.lang.reflect.Constructor; /** @@ -28,11 +31,27 @@ public enum AsyncQuerySchedulerAction { REMOVE } - public static AsyncQueryScheduler build(FlintOptions options) { + public static AsyncQueryScheduler build(SparkSession sparkSession, FlintOptions options) throws IOException { + return new AsyncQuerySchedulerBuilder().doBuild(sparkSession, options); + } + + /** + * Builds an AsyncQueryScheduler based on the provided options. + * + * @param sparkSession The SparkSession to be used. + * @param options The FlintOptions containing configuration details. + * @return An instance of AsyncQueryScheduler. + */ + protected AsyncQueryScheduler doBuild(SparkSession sparkSession, FlintOptions options) throws IOException { String className = options.getCustomAsyncQuerySchedulerClass(); if (className.isEmpty()) { - return new OpenSearchAsyncQueryScheduler(options); + OpenSearchAsyncQueryScheduler scheduler = createOpenSearchAsyncQueryScheduler(options); + // Check if the scheduler has access to the required index. Disable the external scheduler otherwise. 
+ if (!hasAccessToSchedulerIndex(scheduler)){ + setExternalSchedulerEnabled(sparkSession, false); + } + return scheduler; } // Attempts to instantiate AsyncQueryScheduler using reflection @@ -45,4 +64,16 @@ public static AsyncQueryScheduler build(FlintOptions options) { throw new RuntimeException("Failed to instantiate AsyncQueryScheduler: " + className, e); } } + + protected OpenSearchAsyncQueryScheduler createOpenSearchAsyncQueryScheduler(FlintOptions options) { + return new OpenSearchAsyncQueryScheduler(options); + } + + protected boolean hasAccessToSchedulerIndex(OpenSearchAsyncQueryScheduler scheduler) throws IOException { + return scheduler.hasAccessToSchedulerIndex(); + } + + protected void setExternalSchedulerEnabled(SparkSession sparkSession, boolean enabled) { + sparkSession.sqlContext().setConf(FlintSparkConf.EXTERNAL_SCHEDULER_ENABLED().key(), String.valueOf(enabled)); + } } \ No newline at end of file diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java index 19532254b..a1ef45825 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/scheduler/OpenSearchAsyncQueryScheduler.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -37,6 +38,7 @@ import org.opensearch.jobscheduler.spi.schedule.Schedule; import org.opensearch.rest.RestStatus; +import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; 
import java.time.Instant; @@ -55,6 +57,11 @@ public class OpenSearchAsyncQueryScheduler implements AsyncQueryScheduler { private static final ObjectMapper mapper = new ObjectMapper(); private final FlintOptions flintOptions; + @VisibleForTesting + public OpenSearchAsyncQueryScheduler() { + this.flintOptions = new FlintOptions(ImmutableMap.of()); + } + public OpenSearchAsyncQueryScheduler(FlintOptions options) { this.flintOptions = options; } @@ -124,6 +131,28 @@ void createAsyncQuerySchedulerIndex(IRestHighLevelClient client) { } } + /** + * Checks if the current setup has access to the scheduler index. + * + * This method attempts to create a client and ensure that the scheduler index exists. + * If these operations succeed, it indicates that the user has the necessary permissions + * to access and potentially modify the scheduler index. + * + * @see #createClient() + * @see #ensureIndexExists(IRestHighLevelClient) + */ + public boolean hasAccessToSchedulerIndex() throws IOException { + IRestHighLevelClient client = createClient(); + try { + ensureIndexExists(client); + return true; + } catch (Throwable e) { + LOG.error("Failed to ensure index exists", e); + return false; + } finally { + client.close(); + } + } private void ensureIndexExists(IRestHighLevelClient client) { try { if (!client.doesIndexExist(new GetIndexRequest(SCHEDULER_INDEX_NAME), RequestOptions.DEFAULT)) { diff --git a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java index 67b5afee5..3c65a96a5 100644 --- a/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java +++ b/flint-spark-integration/src/test/java/org/opensearch/flint/core/scheduler/AsyncQuerySchedulerBuilderTest.java @@ -5,43 +5,80 @@ package org.opensearch.flint.core.scheduler; +import 
org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.SQLContext; +import org.junit.Before; import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.flint.common.scheduler.AsyncQueryScheduler; import org.opensearch.flint.common.scheduler.model.AsyncQuerySchedulerRequest; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.spark.scheduler.AsyncQuerySchedulerBuilder; import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler; +import java.io.IOException; + import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AsyncQuerySchedulerBuilderTest { + @Mock + private SparkSession sparkSession; + + @Mock + private SQLContext sqlContext; + + private AsyncQuerySchedulerBuilderForLocalTest testBuilder; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + when(sparkSession.sqlContext()).thenReturn(sqlContext); + } + + @Test + public void testBuildWithEmptyClassNameAndAccessibleIndex() throws IOException { + FlintOptions options = mock(FlintOptions.class); + when(options.getCustomAsyncQuerySchedulerClass()).thenReturn(""); + OpenSearchAsyncQueryScheduler mockScheduler = mock(OpenSearchAsyncQueryScheduler.class); + + AsyncQueryScheduler scheduler = testBuilder.build(mockScheduler, true, sparkSession, options); + assertTrue(scheduler instanceof OpenSearchAsyncQueryScheduler); + verify(sqlContext, never()).setConf(anyString(), anyString()); + } @Test - public void testBuildWithEmptyClassName() { + public void testBuildWithEmptyClassNameAndInaccessibleIndex() throws IOException { FlintOptions options = mock(FlintOptions.class); when(options.getCustomAsyncQuerySchedulerClass()).thenReturn(""); + OpenSearchAsyncQueryScheduler mockScheduler = 
mock(OpenSearchAsyncQueryScheduler.class); - AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(options); + AsyncQueryScheduler scheduler = testBuilder.build(mockScheduler, false, sparkSession, options); assertTrue(scheduler instanceof OpenSearchAsyncQueryScheduler); + verify(sqlContext).setConf("spark.flint.job.externalScheduler.enabled", "false"); } @Test - public void testBuildWithCustomClassName() { + public void testBuildWithCustomClassName() throws IOException { FlintOptions options = mock(FlintOptions.class); - when(options.getCustomAsyncQuerySchedulerClass()).thenReturn("org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest"); + when(options.getCustomAsyncQuerySchedulerClass()) + .thenReturn("org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest"); - AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(options); + AsyncQueryScheduler scheduler = AsyncQuerySchedulerBuilder.build(sparkSession, options); assertTrue(scheduler instanceof AsyncQuerySchedulerForLocalTest); } @Test(expected = RuntimeException.class) - public void testBuildWithInvalidClassName() { + public void testBuildWithInvalidClassName() throws IOException { FlintOptions options = mock(FlintOptions.class); when(options.getCustomAsyncQuerySchedulerClass()).thenReturn("invalid.ClassName"); - AsyncQuerySchedulerBuilder.build(options); + AsyncQuerySchedulerBuilder.build(sparkSession, options); } public static class AsyncQuerySchedulerForLocalTest implements AsyncQueryScheduler { @@ -65,4 +102,35 @@ public void removeJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest) { // Custom implementation } } + + public static class OpenSearchAsyncQuerySchedulerForLocalTest extends OpenSearchAsyncQueryScheduler { + @Override + public boolean hasAccessToSchedulerIndex() { + return true; + } + } + + public static class AsyncQuerySchedulerBuilderForLocalTest extends AsyncQuerySchedulerBuilder 
{ + private OpenSearchAsyncQueryScheduler mockScheduler; + private Boolean mockHasAccess; + + public AsyncQuerySchedulerBuilderForLocalTest(OpenSearchAsyncQueryScheduler mockScheduler, Boolean mockHasAccess) { + this.mockScheduler = mockScheduler; + this.mockHasAccess = mockHasAccess; + } + + @Override + protected OpenSearchAsyncQueryScheduler createOpenSearchAsyncQueryScheduler(FlintOptions options) { + return mockScheduler != null ? mockScheduler : super.createOpenSearchAsyncQueryScheduler(options); + } + + @Override + protected boolean hasAccessToSchedulerIndex(OpenSearchAsyncQueryScheduler scheduler) throws IOException { + return mockHasAccess != null ? mockHasAccess : super.hasAccessToSchedulerIndex(scheduler); + } + + public static AsyncQueryScheduler build(OpenSearchAsyncQueryScheduler asyncQueryScheduler, Boolean hasAccess, SparkSession sparkSession, FlintOptions options) throws IOException { + return new AsyncQuerySchedulerBuilderForLocalTest(asyncQueryScheduler, hasAccess).doBuild(sparkSession, options); + } + } } \ No newline at end of file diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala index e43b0c52c..b675265b7 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala @@ -9,7 +9,7 @@ import org.opensearch.flint.spark.FlintSparkExtensions import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.flint.config.FlintConfigEntry +import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf} import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ 
-26,6 +26,10 @@ trait FlintSuite extends SharedSparkSession { // ConstantPropagation etc. .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) .set("spark.sql.extensions", classOf[FlintSparkExtensions].getName) + // Override scheduler class for unit testing + .set( + FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key, + "org.opensearch.flint.core.scheduler.AsyncQuerySchedulerBuilderTest$AsyncQuerySchedulerForLocalTest") conf } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala new file mode 100644 index 000000000..07720ff24 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexFactorySuite.scala @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.scalatest.matchers.should.Matchers._ + +import org.apache.spark.FlintSuite + +class FlintSparkIndexFactorySuite extends FlintSuite { + + test("create mv should generate source tables if missing in metadata") { + val testTable = "spark_catalog.default.mv_build_test" + val testMvName = "spark_catalog.default.mv" + val testQuery = s"SELECT * FROM $testTable" + + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "indexedColumns": [ + | { + | "columnType": "int", + | "columnName": "age" + | } + | ], + | "name": "$testMvName", + | "source": "$testQuery" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + + val metadata = FlintOpenSearchIndexMetadataService.deserialize(content) + val index = 
FlintSparkIndexFactory.create(spark, metadata) + index shouldBe defined + index.get + .asInstanceOf[FlintSparkMaterializedView] + .sourceTables should contain theSameElementsAs Array(testTable) + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala index c6d2cf12a..6ec6cf696 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/metadatacache/FlintMetadataCacheSuite.scala @@ -7,6 +7,9 @@ package org.opensearch.flint.spark.metadatacache import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -21,11 +24,12 @@ class FlintMetadataCacheSuite extends AnyFlatSpec with Matchers { "", Map.empty[String, Any]) - it should "construct from FlintMetadata" in { + it should "construct from skipping index FlintMetadata" in { val content = - """ { + s""" { | "_meta": { - | "kind": "test_kind", + | "kind": "$SKIPPING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", | "options": { | "auto_refresh": "true", | "refresh_interval": "10 Minutes" @@ -43,18 +47,85 @@ class FlintMetadataCacheSuite extends AnyFlatSpec with Matchers { .copy(latestLogEntry = Some(flintMetadataLogEntry)) val metadataCache = FlintMetadataCache(metadata) - metadataCache.metadataCacheVersion shouldBe "1.0" + metadataCache.metadataCacheVersion shouldBe 
FlintMetadataCache.metadataCacheVersion metadataCache.refreshInterval.get shouldBe 600 - metadataCache.sourceTables shouldBe Array(FlintMetadataCache.mockTableName) + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from covering index FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$COVERING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") + metadataCache.lastRefreshTime.get shouldBe 1234567890123L + } + + it should "construct from materialized view FlintMetadata" in { + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "source": "spark_catalog.default.wrong_table", + | "options": { + | "auto_refresh": "true", + | "refresh_interval": "10 Minutes" + | }, + | "properties": { + | "sourceTables": [ + | "spark_catalog.default.test_table", + | "spark_catalog.default.another_table" + | ] + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + + val metadataCache = FlintMetadataCache(metadata) + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion + metadataCache.refreshInterval.get shouldBe 600 + 
metadataCache.sourceTables shouldBe Array( + "spark_catalog.default.test_table", + "spark_catalog.default.another_table") metadataCache.lastRefreshTime.get shouldBe 1234567890123L } it should "construct from FlintMetadata excluding invalid fields" in { // Set auto_refresh = false and lastRefreshCompleteTime = 0 val content = - """ { + s""" { | "_meta": { - | "kind": "test_kind", + | "kind": "$SKIPPING_INDEX_TYPE", + | "source": "spark_catalog.default.test_table", | "options": { | "refresh_interval": "10 Minutes" | } @@ -71,9 +142,9 @@ class FlintMetadataCacheSuite extends AnyFlatSpec with Matchers { .copy(latestLogEntry = Some(flintMetadataLogEntry.copy(lastRefreshCompleteTime = 0L))) val metadataCache = FlintMetadataCache(metadata) - metadataCache.metadataCacheVersion shouldBe "1.0" + metadataCache.metadataCacheVersion shouldBe FlintMetadataCache.metadataCacheVersion metadataCache.refreshInterval shouldBe empty - metadataCache.sourceTables shouldBe Array(FlintMetadataCache.mockTableName) + metadataCache.sourceTables shouldBe Array("spark_catalog.default.test_table") metadataCache.lastRefreshTime shouldBe empty } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index c1df42883..1c9a9e83c 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConv import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} -import 
org.scalatest.matchers.should.Matchers.{contain, convertToAnyShouldWrapper, the} +import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite @@ -37,31 +37,34 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testQuery = "SELECT 1" test("get mv name") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv" } test("get mv name with dots") { val testMvNameDots = "spark_catalog.default.mv.2023.10" - val mv = FlintSparkMaterializedView(testMvNameDots, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvNameDots, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv.2023.10" } test("should fail if get name with unqualified MV name") { the[IllegalArgumentException] thrownBy - FlintSparkMaterializedView("mv", testQuery, Map.empty).name() + FlintSparkMaterializedView("mv", testQuery, Array.empty, Map.empty).name() the[IllegalArgumentException] thrownBy - FlintSparkMaterializedView("default.mv", testQuery, Map.empty).name() + FlintSparkMaterializedView("default.mv", testQuery, Array.empty, Map.empty).name() } test("get metadata") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map("test_col" -> "integer")) + val mv = + FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map("test_col" -> "integer")) val metadata = mv.metadata() metadata.name shouldBe mv.mvName metadata.kind shouldBe MV_INDEX_TYPE metadata.source shouldBe "SELECT 1" + metadata.properties should contain key "sourceTables" + metadata.properties.get("sourceTables").asInstanceOf[Array[String]] should have size 0 metadata.indexedColumns shouldBe Array( Map("columnName" -> "test_col", "columnType" -> "integer").asJava) metadata.schema shouldBe Map("test_col" -> Map("type" -> "integer").asJava).asJava @@ 
-74,6 +77,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView( testMvName, testQuery, + Array.empty, Map("test_col" -> "integer"), indexOptions) @@ -83,12 +87,12 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } test("build batch data frame") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.build(spark, None).collect() shouldBe Array(Row(1)) } test("should fail if build given other source data frame") { - val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) the[IllegalArgumentException] thrownBy mv.build(spark, Some(mock[DataFrame])) } @@ -103,7 +107,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { |""".stripMargin val options = Map("watermark_delay" -> "30 Seconds") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -128,7 +132,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { |""".stripMargin val options = Map("watermark_delay" -> "30 Seconds") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -144,7 +148,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { test("build stream with non-aggregate query") { val testQuery = s"SELECT name, age FROM $testTable WHERE age > 30" - withAggregateMaterializedView(testQuery, Map.empty) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), Map.empty) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable) @@ -158,7 +162,7 @@ class FlintSparkMaterializedViewSuite 
extends FlintSuite { val testQuery = s"SELECT name, age FROM $testTable" val options = Map("extra_options" -> s"""{"$testTable": {"maxFilesPerTrigger": "1"}}""") - withAggregateMaterializedView(testQuery, options) { actualPlan => + withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( actualPlan, streamingRelation(testTable, Map("maxFilesPerTrigger" -> "1")) @@ -175,6 +179,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView( testMvName, s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Array(testTable), Map.empty) the[IllegalStateException] thrownBy @@ -182,14 +187,20 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } } - private def withAggregateMaterializedView(query: String, options: Map[String, String])( - codeBlock: LogicalPlan => Unit): Unit = { + private def withAggregateMaterializedView( + query: String, + sourceTables: Array[String], + options: Map[String, String])(codeBlock: LogicalPlan => Unit): Unit = { withTable(testTable) { sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = - FlintSparkMaterializedView(testMvName, query, Map.empty, FlintSparkIndexOptions(options)) + FlintSparkMaterializedView( + testMvName, + query, + sourceTables, + Map.empty, + FlintSparkIndexOptions(options)) val actualPlan = mv.buildStream(spark).queryExecution.logical codeBlock(actualPlan) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index 14d41c2bb..fc77faaea 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -17,9 +17,10 @@ import org.opensearch.flint.common.FlintVersion.current 
import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils} import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName -import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTableNames, getFlintIndexName} import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler -import org.scalatest.matchers.must.Matchers.defined +import org.scalatest.matchers.must.Matchers._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.sql.{DataFrame, Row} @@ -51,6 +52,29 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { deleteTestIndex(testFlintIndex) } + test("extract source table names from materialized view source query successfully") { + val testComplexQuery = s""" + | SELECT * + | FROM ( + | SELECT 1 + | FROM table1 + | LEFT JOIN `table2` + | ) + | UNION ALL + | SELECT 1 + | FROM spark_catalog.default.`table/3` + | INNER JOIN spark_catalog.default.`table.4` + |""".stripMargin + extractSourceTableNames(flint.spark, testComplexQuery) should contain theSameElementsAs + Array( + "spark_catalog.default.table1", + "spark_catalog.default.table2", + "spark_catalog.default.`table/3`", + "spark_catalog.default.`table.4`") + + extractSourceTableNames(flint.spark, "SELECT 1") should have size 0 + } + test("create materialized view with metadata successfully") { withTempDir { checkpointDir => val indexOptions = @@ -91,7 +115,9 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { | "scheduler_mode":"internal" | }, | "latestId": "$testLatestId", - | "properties": {} + | "properties": { + | "sourceTables": ["$testTable"] + | } | }, | "properties": { | "startTime": { @@ -107,6 +133,22 @@ class FlintSparkMaterializedViewITSuite extends 
FlintSparkSuite { } } + test("create materialized view should parse source tables successfully") { + val indexOptions = FlintSparkIndexOptions(Map.empty) + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .options(indexOptions, testFlintIndex) + .create() + + val index = flint.describeIndex(testFlintIndex) + index shouldBe defined + index.get + .asInstanceOf[FlintSparkMaterializedView] + .sourceTables should contain theSameElementsAs Array(testTable) + } + test("create materialized view with default checkpoint location successfully") { withTempDir { checkpointDir => setFlintSparkConf( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index c8c902294..c53eee548 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} import java.util.Comparator import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} @@ -23,6 +23,7 @@ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.{FlintSuite, SparkConf} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.{CHECKPOINT_MANDATORY, HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY} import org.apache.spark.sql.streaming.StreamTest @@ -49,6 +50,8 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit override def beforeAll(): Unit = { super.beforeAll() + // Revoke override in FlintSuite on IT + conf.unsetConf(FlintSparkConf.CUSTOM_FLINT_SCHEDULER_CLASS.key) // Replace executor to avoid impact on IT. // TODO: Currently no IT test scheduler so no need to restore it back. 
@@ -534,6 +537,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createMultiValueStructTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_value Array> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | SELECT /*+ COALESCE(1) */ * + | FROM VALUES + | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)) ), + | ( 2, array(STRUCT("2_Monday", 2), null) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 4, null ) + |""".stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable ( @@ -695,4 +720,100 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | (9, '2001:db8::ff00:12:', true, false) | """.stripMargin) } + + protected def createNestedJsonContentTable(tempFile: Path, testTable: String): Unit = { + val json = + """ + |[ + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Tower Bridge", "length": 801}, + | {"name": "London Bridge", "length": 928} + | ], + | "city": "London", + | "country": "England", + | "coor": { + | "lat": 51.5074, + | "long": -0.1278, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Pont Neuf", "length": 232}, + | {"name": "Pont Alexandre III", "length": 160} + | ], + | "city": "Paris", + | "country": "France", + | "coor": { + | "lat": 48.8566, + | "long": 2.3522, + | "alt": 35 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Rialto Bridge", "length": 48}, + | {"name": "Bridge of Sighs", "length": 11} + | ], + | "city": "Venice", + | "country": "Italy", + | "coor": { + | "lat": 45.4408, + | "long": 12.3155, + | "alt": 2 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Charles 
Bridge", "length": 516}, + | {"name": "Legion Bridge", "length": 343} + | ], + | "city": "Prague", + | "country": "Czech Republic", + | "coor": { + | "lat": 50.0755, + | "long": 14.4378, + | "alt": 200 + | } + | }, + | { + | "_time": "2024-09-13T12:00:00", + | "bridges": [ + | {"name": "Chain Bridge", "length": 375}, + | {"name": "Liberty Bridge", "length": 333} + | ], + | "city": "Budapest", + | "country": "Hungary", + | "coor": { + | "lat": 47.4979, + | "long": 19.0402, + | "alt": 96 + | } + | }, + | { + | "_time": "1990-09-13T12:00:00", + | "bridges": null, + | "city": "Warsaw", + | "country": "Poland", + | "coor": null + | } + |] + |""".stripMargin + val tempFile = Files.createTempFile("jsonTestData", ".json") + val absolutPath = tempFile.toAbsolutePath.toString; + Files.write(tempFile, json.getBytes) + sql(s""" + | CREATE TEMPORARY VIEW $testTable + | USING org.apache.spark.sql.json + | OPTIONS ( + | path "$absolutPath", + | multiLine true + | ); + |""".stripMargin) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala index a6f7e0ed0..c9f6c47f7 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala @@ -9,7 +9,6 @@ import scala.jdk.CollectionConverters.mapAsJavaMapConverter import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.native.JsonMethods._ -import org.opensearch.OpenSearchException import org.opensearch.action.get.GetRequest import org.opensearch.client.RequestOptions import org.opensearch.flint.core.FlintOptions @@ -207,13 +206,7 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { val indexInitial = flint.describeIndex(testIndex).get indexInitial.options.refreshInterval() shouldBe Some("4 Minute") - 
the[OpenSearchException] thrownBy { - val client = - OpenSearchClientUtils.createClient(new FlintOptions(openSearchOptions.asJava)) - client.get( - new GetRequest(OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME, testIndex), - RequestOptions.DEFAULT) - } + indexInitial.options.isExternalSchedulerEnabled() shouldBe false // Update Flint index to change refresh interval val updatedIndex = flint @@ -228,6 +221,7 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { val indexFinal = flint.describeIndex(testIndex).get indexFinal.options.autoRefresh() shouldBe true indexFinal.options.refreshInterval() shouldBe Some("5 Minutes") + indexFinal.options.isExternalSchedulerEnabled() shouldBe true indexFinal.options.checkpointLocation() shouldBe Some(checkpointDir.getAbsolutePath) // Verify scheduler index is updated diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala index c04209f06..c0d253fd3 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/metadatacache/FlintOpenSearchMetadataCacheWriterITSuite.scala @@ -17,7 +17,9 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintOpenSearchIndexMetadataService} import org.opensearch.flint.spark.{FlintSparkIndexOptions, FlintSparkSuite} -import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import 
org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE} import org.scalatest.Entry import org.scalatest.matchers.should.Matchers @@ -78,9 +80,9 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat | { | "_meta": { | "version": "${current()}", - | "name": "${testFlintIndex}", + | "name": "$testFlintIndex", | "kind": "test_kind", - | "source": "test_source_table", + | "source": "$testTable", | "indexedColumns": [ | { | "test_field": "spark_type" @@ -90,12 +92,12 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat | "refresh_interval": "10 Minutes" | }, | "properties": { - | "metadataCacheVersion": "1.0", + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", | "refreshInterval": 600, - | "sourceTables": ["${FlintMetadataCache.mockTableName}"], - | "lastRefreshTime": ${testLastRefreshCompleteTime} + | "sourceTables": ["$testTable"], + | "lastRefreshTime": $testLastRefreshCompleteTime | }, - | "latestId": "${testLatestId}" + | "latestId": "$testLatestId" | }, | "properties": { | "test_field": { @@ -107,7 +109,7 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val builder = new FlintMetadata.Builder builder.name(testFlintIndex) builder.kind("test_kind") - builder.source("test_source_table") + builder.source(testTable) builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava) builder.options( Map("auto_refresh" -> "true", "refresh_interval" -> "10 Minutes") @@ -129,12 +131,71 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 3 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("lastRefreshTime", 
testLastRefreshCompleteTime)) + } + + Seq(SKIPPING_INDEX_TYPE, COVERING_INDEX_TYPE).foreach { case kind => + test(s"write metadata cache to $kind index mappings with source tables") { + val content = + s""" { + | "_meta": { + | "kind": "$kind", + | "source": "$testTable" + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties + properties + .get("sourceTables") + .asInstanceOf[List[String]] + .toArray should contain theSameElementsAs Array(testTable) + } + } + + test(s"write metadata cache to materialized view index mappings with source tables") { + val testTable2 = "spark_catalog.default.metadatacache_test2" + val content = + s""" { + | "_meta": { + | "kind": "$MV_INDEX_TYPE", + | "properties": { + | "sourceTables": [ + | "$testTable", "$testTable2" + | ] + | } + | }, + | "properties": { + | "age": { + | "type": "integer" + | } + | } + | } + |""".stripMargin + val metadata = FlintOpenSearchIndexMetadataService + .deserialize(content) + .copy(latestLogEntry = Some(flintMetadataLogEntry)) + flintClient.createIndex(testFlintIndex, metadata) + flintMetadataCacheWriter.updateMetadataCache(testFlintIndex, metadata) + + val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties .get("sourceTables") .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) + .toArray should contain theSameElementsAs Array(testTable, testTable2) } test("write metadata cache to index mappings with refresh interval") { @@ -162,13 +223,11 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with 
Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 4 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("refreshInterval", 600), Entry("lastRefreshTime", testLastRefreshCompleteTime)) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) } test("exclude refresh interval in metadata cache when auto refresh is false") { @@ -195,12 +254,10 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 3 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("lastRefreshTime", testLastRefreshCompleteTime)) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) } test("exclude last refresh time in metadata cache when index has not been refreshed") { @@ -212,11 +269,8 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat val properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 2 - properties should contain(Entry("metadataCacheVersion", "1.0")) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) + properties should contain( + Entry("metadataCacheVersion", FlintMetadataCache.metadataCacheVersion)) } test("write metadata cache to index mappings and preserve other index metadata") { @@ -246,12 +300,10 @@ class 
FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat flintIndexMetadataService.getIndexMetadata(testFlintIndex).schema should have size 1 var properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 3 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("lastRefreshTime", testLastRefreshCompleteTime)) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) val newContent = """ { @@ -278,12 +330,10 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat flintIndexMetadataService.getIndexMetadata(testFlintIndex).schema should have size 1 properties = flintIndexMetadataService.getIndexMetadata(testFlintIndex).properties properties should have size 3 - properties should contain allOf (Entry("metadataCacheVersion", "1.0"), + properties should contain allOf (Entry( + "metadataCacheVersion", + FlintMetadataCache.metadataCacheVersion), Entry("lastRefreshTime", testLastRefreshCompleteTime)) - properties - .get("sourceTables") - .asInstanceOf[List[String]] - .toArray should contain theSameElementsAs Array(FlintMetadataCache.mockTableName) } Seq( @@ -296,9 +346,9 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat "checkpoint_location" -> "s3a://test/"), s""" | { - | "metadataCacheVersion": "1.0", + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", | "refreshInterval": 600, - | "sourceTables": ["${FlintMetadataCache.mockTableName}"] + | "sourceTables": ["$testTable"] | } |""".stripMargin), ( @@ -306,8 +356,8 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat Map.empty[String, String], s""" | { - | "metadataCacheVersion": "1.0", - | "sourceTables": 
["${FlintMetadataCache.mockTableName}"] + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "sourceTables": ["$testTable"] | } |""".stripMargin), ( @@ -315,8 +365,8 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat Map("incremental_refresh" -> "true", "checkpoint_location" -> "s3a://test/"), s""" | { - | "metadataCacheVersion": "1.0", - | "sourceTables": ["${FlintMetadataCache.mockTableName}"] + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", + | "sourceTables": ["$testTable"] | } |""".stripMargin)).foreach { case (refreshMode, optionsMap, expectedJson) => test(s"write metadata cache for $refreshMode") { @@ -389,9 +439,9 @@ class FlintOpenSearchMetadataCacheWriterITSuite extends FlintSparkSuite with Mat flintMetadataCacheWriter.serialize(index.get.metadata())) \ "_meta" \ "properties")) propertiesJson should matchJson(s""" | { - | "metadataCacheVersion": "1.0", + | "metadataCacheVersion": "${FlintMetadataCache.metadataCacheVersion}", | "refreshInterval": 600, - | "sourceTables": ["${FlintMetadataCache.mockTableName}"] + | "sourceTables": ["$testTable"] | } |""".stripMargin) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala new file mode 100644 index 000000000..e714a5f7e --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFlattenITSuite.scala @@ -0,0 +1,350 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, 
UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFlattenITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNestedJsonContentTable(tempFile, testTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + createMultiValueStructTable(multiValueTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("flatten for structs") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' or country = 'Poland' + | | fields coor + | | flatten coor + | """.stripMargin) + + assert(frame.columns.sameElements(Array("alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(35, 51.5074, -0.1278), Row(null, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + 
EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val projectCoor = Project(Seq(UnresolvedAttribute("coor")), filter) + val flattenCoor = flattenPlanFor("coor", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + private def flattenPlanFor(flattenedColumn: String, parentPlan: LogicalPlan): LogicalPlan = { + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute(flattenedColumn)) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), outer = true, None, seq(), parentPlan) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute(flattenedColumn)), generate) + dropSourceColumn + } + + test("flatten for arrays") { + val frame = sql(s""" + | source = $testTable + | | fields bridges + | | flatten bridges + | """.stripMargin) + + assert(frame.columns.sameElements(Array("length", "name"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, null), + Row(11L, "Bridge of Sighs"), + Row(48L, "Rialto Bridge"), + Row(160L, "Pont Alexandre III"), + Row(232L, "Pont Neuf"), + Row(801L, "Tower Bridge"), + Row(928L, "London Bridge"), + Row(343L, "Legion Bridge"), + Row(516L, "Charles Bridge"), + Row(333L, "Liberty Bridge"), + Row(375L, "Chain Bridge")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCoor = Project(Seq(UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCoor) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenBridges) + comparePlans(logicalPlan, 
expectedPlan, checkAnalysis = false) + } + + test("flatten for structs and arrays") { + val frame = sql(s""" + | source = $testTable | flatten bridges | flatten coor + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("_time", "city", "country", "length", "name", "alt", "lat", "long"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("1990-09-13T12:00:00", "Warsaw", "Poland", null, null, null, null, null), + Row( + "2024-09-13T12:00:00", + "Venice", + "Italy", + 11L, + "Bridge of Sighs", + 2, + 45.4408, + 12.3155), + Row("2024-09-13T12:00:00", "Venice", "Italy", 48L, "Rialto Bridge", 2, 45.4408, 12.3155), + Row( + "2024-09-13T12:00:00", + "Paris", + "France", + 160L, + "Pont Alexandre III", + 35, + 48.8566, + 2.3522), + Row("2024-09-13T12:00:00", "Paris", "France", 232L, "Pont Neuf", 35, 48.8566, 2.3522), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 801L, + "Tower Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "London", + "England", + 928L, + "London Bridge", + 35, + 51.5074, + -0.1278), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 343L, + "Legion Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Prague", + "Czech Republic", + 516L, + "Charles Bridge", + 200, + 50.0755, + 14.4378), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 333L, + "Liberty Bridge", + 96, + 47.4979, + 19.0402), + Row( + "2024-09-13T12:00:00", + "Budapest", + "Hungary", + 375L, + "Chain Bridge", + 96, + 47.4979, + 19.0402)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](3)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val flattenBridges = flattenPlanFor("bridges", table) + val flattenCoor = flattenPlanFor("coor", flattenBridges) + val 
expectedPlan = Project(Seq(UnresolvedStar(None)), flattenCoor) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val frame = sql(s""" + | source = $testTable + | | fields country, bridges + | | flatten bridges + | | fields country, length + | | stats avg(length) as avg by country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(null, "Poland"), + Row(196d, "France"), + Row(429.5, "Czech Republic"), + Row(864.5, "England"), + Row(29.5, "Italy"), + Row(354.0, "Hungary")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val projectCountryBridges = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("bridges")), table) + val flattenBridges = flattenPlanFor("bridges", projectCountryBridges) + val projectCountryLength = + Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("length")), flattenBridges) + val average = Alias( + UnresolvedFunction( + seq("AVG"), + seq(UnresolvedAttribute("length")), + isDistinct = false, + None, + ignoreNulls = false), + "avg")() + val country = Alias(UnresolvedAttribute("country"), "country")() + val grouping = Alias(UnresolvedAttribute("country"), "country")() + val aggregate = Aggregate(Seq(grouping), Seq(average, country), projectCountryLength) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct table") { + val frame = sql(s""" + | source = $structTable + | | flatten struct_col + | | flatten field1 + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "field2", "subfield"))) + val results: Array[Row] = 
frame.collect() + val expectedResults: Array[Row] = + Array(Row(30, 123, "value1"), Row(40, 456, "value2"), Row(50, 789, "value3")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenField1) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten struct nested table") { + val frame = sql(s""" + | source = $structNestedTable + | | flatten struct_col + | | flatten field1 + | | flatten struct_col2 + | | flatten field1 + | """.stripMargin) + + assert( + frame.columns.sameElements(Array("int_col", "field2", "subfield", "field2", "subfield"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(30, 123, "value1", 23, "valueA"), + Row(40, 123, "value5", 33, "valueB"), + Row(30, 823, "value4", 83, "valueC"), + Row(40, 456, "value2", 46, "valueD"), + Row(50, 789, "value3", 89, "valueE")) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_nested_test")) + val flattenStructCol = flattenPlanFor("struct_col", table) + val flattenField1 = flattenPlanFor("field1", flattenStructCol) + val flattenStructCol2 = flattenPlanFor("struct_col2", flattenField1) + val flattenField1Again = flattenPlanFor("field1", flattenStructCol2) + val expectedPlan = Project(Seq(UnresolvedStar(None)), 
flattenField1Again) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("flatten multi value nullable") { + val frame = sql(s""" + | source = $multiValueTable + | | flatten multi_value + | """.stripMargin) + + assert(frame.columns.sameElements(Array("int_col", "name", "value"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "1_one", 1), + Row(1, null, 11), + Row(1, "1_three", null), + Row(2, "2_Monday", 2), + Row(2, null, null), + Row(3, "3_third", 3), + Row(3, "3_4th", 4), + Row(4, null, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val flattenMultiValue = flattenPlanFor("multi_value", table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), flattenMultiValue) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala new file mode 100644 index 000000000..bc4463537 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, 
WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLTrendlineITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test trendline sma command without fields command and without alias") { + val frame = sql(s""" + | source = $testTable | sort - age | trendline sma(2, age) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, "California", "USA", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 50.0), + Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val ageField = UnresolvedAttribute("age") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, 
Literal(-1), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")()) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command with fields command") { + val frame = sql(s""" + | source = $testTable | trendline sort - age sma(3, age) as age_sma | fields name, age, age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, null), + Row("Hello", 30, null), + Row("John", 25, 41.666666666666664), + Row("Jane", 20, 25)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageSmaField = UnresolvedAttribute("age_sma") + val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val smaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), 
SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")()) + val expectedPlan = + Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline sma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age sma(2, age) as two_points_sma sma(3, age) as three_points_sma | fields name, age, two_points_sma, three_points_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "age", "two_points_sma", "three_points_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, null, null), + Row("John", 25, 22.5, null), + Row("Hello", 30, 27.5, 25.0), + Row("Jake", 70, 50.0, 41.666666666666664)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma") + val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma") + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table) + val twoPointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val twoPointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = 
false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val threePointsCountWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val threePointsSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(ageField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow))) + val twoPointsCaseWhen = CaseWhen( + Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))), + twoPointsSmaWindow) + val threePointsCaseWhen = CaseWhen( + Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))), + threePointsSmaWindow) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsCaseWhen, "two_points_sma")(), + Alias(threePointsCaseWhen, "three_points_sma")()) + val expectedPlan = Project( + Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval doubled_age = age * 2 | trendline sort + age sma(2, doubled_age) as doubled_age_sma | fields name, doubled_age, doubled_age_sma + | """.stripMargin) + + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_sma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 45.0), + Row("Hello", 60, 55.0), + Row("Jake", 140, 100.0)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val 
table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameField = UnresolvedAttribute("name") + val ageField = UnresolvedAttribute("age") + val doubledAgeField = UnresolvedAttribute("doubled_age") + val doubledAgeSmaField = UnresolvedAttribute("doubled_age_sma") + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("*", Seq(ageField, Literal(2)), isDistinct = false), + "doubled_age")()), + table) + val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject) + val countWindow = new WindowExpression( + UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val doubleAgeSmaWindow = WindowExpression( + UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false), + WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow))) + val caseWhen = + CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")()) + val expectedPlan = Project( + Seq(nameField, doubledAgeField, doubledAgeSmaField), + Project(trendlineProjectList, sort)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline sma command chaining") { + val frame = sql(s""" + | source = $testTable | eval age_1 = age, age_2 = age | trendline sort - age_1 sma(3, age_1) | trendline sort + age_2 sma(3, age_2) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "age_1", + "age_2", + "age_1_trendline", + "age_2_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, null, 25.0), + Row("Jake", 70, "California", "USA", 2023, 4, 
70, 70, null, 41.666666666666664), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 25.0, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 41.666666666666664, null)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index bf6989b7c..38bd1f9d2 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -37,6 +37,9 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +FLATTEN: 'FLATTEN'; +TRENDLINE: 'TRENDLINE'; +EXPAND: 'EXPAND'; //Native JOIN KEYWORDS JOIN: 'JOIN'; @@ -89,6 +92,10 @@ FIELDSUMMARY: 'FIELDSUMMARY'; INCLUDEFIELDS: 'INCLUDEFIELDS'; NULLS: 'NULLS'; +//TRENDLINE KEYWORDS +SMA: 'SMA'; +WMA: 'WMA'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; @@ -396,6 +403,9 @@ ISPRESENT: 'ISPRESENT'; BETWEEN: 'BETWEEN'; CIDRMATCH: 'CIDRMATCH'; +// Geo Location +GEOIP: 'GEOIP'; + // FLOWCONTROL FUNCTIONS IFNULL: 'IFNULL'; NULLIF: 'NULLIF'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index aaf807a7b..a55d4fe14 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -53,6 +53,9 @@ commands | renameCommand | fillnullCommand | fieldsummaryCommand + | flattenCommand + | expandCommand + | trendlineCommand ; commandName @@ -80,8 +83,11 @@ commandName | PATTERNS | LOOKUP | RENAME + | EXPAND | FILLNULL | FIELDSUMMARY + | FLATTEN + | TRENDLINE ; searchCommand @@ -89,7 +95,7 @@ searchCommand | (SEARCH)? fromClause logicalExpression # searchFromFilter | (SEARCH)? 
logicalExpression fromClause # searchFilterFrom ; - + fieldsummaryCommand : FIELDSUMMARY (fieldsummaryParameter)* ; @@ -246,6 +252,26 @@ fillnullCommand : expression ; +flattenCommand + : FLATTEN fieldExpression + ; + +expandCommand + : EXPAND fieldExpression + ; + +trendlineCommand + : TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)* + ; + +trendlineClause + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + ; + +trendlineType + : SMA + | WMA + ; kmeansCommand : KMEANS (kmeansParameter)* @@ -422,6 +448,7 @@ valueExpression | positionFunction # positionFunctionCall | caseFunction # caseExpr | timestampFunction # timestampFunctionCall + | geoipFunction # geoFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr ; @@ -516,6 +543,11 @@ dataTypeFunctionCall : CAST LT_PRTHS expression AS convertedDataType RT_PRTHS ; +// geoip function +geoipFunction + : GEOIP LT_PRTHS (datasource = functionArg COMMA)? ipAddress = functionArg (COMMA properties = stringLiteral)? 
RT_PRTHS + ; + // boolean functions booleanFunctionCall : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS @@ -549,6 +581,7 @@ evalFunctionName | cryptographicFunctionName | jsonFunctionName | collectionFunctionName + | geoipFunctionName ; functionArgs @@ -856,6 +889,10 @@ collectionFunctionName : ARRAY ; +geoipFunctionName + : GEOIP + ; + positionFunctionName : POSITION ; @@ -1125,4 +1162,5 @@ keywordsCanBeId | ANTI | BETWEEN | CIDRMATCH + | trendlineType ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 03c40fcd2..525a0954c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -111,6 +111,10 @@ public T visitLookup(Lookup node, C context) { return visitChildren(node, context); } + public T visitTrendline(Trendline node, C context) { + return visitChildren(node, context); + } + public T visitCorrelation(Correlation node, C context) { return visitChildren(node, context); } @@ -326,4 +330,8 @@ public T visitWindow(Window node, C context) { public T visitCidr(Cidr node, C context) { return visitChildren(node, context); } + + public T visitFlatten(Flatten flatten, C context) { + return visitChildren(flatten, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java new file mode 100644 index 000000000..e31fbb6e3 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java @@ -0,0 +1,34 @@ +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; + 
+import java.util.List; + +@RequiredArgsConstructor +public class Flatten extends UnresolvedPlan { + + private UnresolvedPlan child; + + @Getter + private final Field fieldToBeFlattened; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFlatten(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java new file mode 100644 index 000000000..9fa1ae81d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Optional; + +@ToString +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Trendline extends UnresolvedPlan { + + private UnresolvedPlan child; + private final Optional sortByField; + private final List computations; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Getter + 
public static class TrendlineComputation { + + private final Integer numberOfDataPoints; + private final UnresolvedExpression dataField; + private final String alias; + private final TrendlineType computationType; + + public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, Trendline.TrendlineType computationType) { + this.numberOfDataPoints = numberOfDataPoints; + this.dataField = dataField; + this.alias = alias; + this.computationType = computationType; + } + + } + + public enum TrendlineType { + SMA + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java index 571905f8a..69a89b83a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java @@ -9,18 +9,24 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.CaseWhen; +import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.RowFrame$; import 
org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; +import org.apache.spark.sql.catalyst.expressions.WindowExpression; +import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.types.DataTypes; import org.opensearch.sql.ast.AbstractNodeVisitor; @@ -32,6 +38,7 @@ import org.opensearch.sql.ast.expression.BinaryExpression; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; @@ -54,7 +61,9 @@ import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.SerializableUdf; import org.opensearch.sql.ppl.utils.AggregatorTransformer; import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 5d2fe986b..669459fba 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -6,63 +6,47 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.TableIdentifier; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import 
org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Ascending$; -import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Descending$; -import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GeneratorOuter; import org.apache.spark.sql.catalyst.expressions.In$; import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.MakeInterval$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.Predicate; -import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; -import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; -import org.apache.spark.sql.catalyst.plans.logical.*; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$; +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; +import org.apache.spark.sql.catalyst.plans.logical.Generate; +import org.apache.spark.sql.catalyst.plans.logical.Limit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project$; import org.apache.spark.sql.execution.ExplainMode; import 
org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.execution.command.ExplainCommand; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.opensearch.flint.spark.FlattenGenerator; import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.AllFields; -import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Between; -import org.opensearch.sql.ast.expression.BinaryExpression; -import org.opensearch.sql.ast.expression.Case; -import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; -import org.opensearch.sql.ast.expression.subquery.InSubquery; -import org.opensearch.sql.ast.expression.Interval; -import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.Not; -import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; -import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.WindowFunction; -import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import 
org.opensearch.sql.ast.statement.Statement; @@ -74,6 +58,7 @@ import org.opensearch.sql.ast.tree.FieldSummary; import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Flatten; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; @@ -87,19 +72,16 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; -import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.expression.function.SerializableUdf; -import org.opensearch.sql.ppl.utils.AggregatorTransformer; -import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer; -import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; +import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils; import org.opensearch.sql.ppl.utils.WindowSpecTransformer; +import scala.None$; import scala.Option; -import scala.Tuple2; import scala.collection.IterableLike; import scala.collection.Seq; @@ -107,16 +89,11 @@ import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Stack; -import java.util.function.BiFunction; import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.List.of; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; -import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; import static 
org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty; import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEvent; @@ -128,8 +105,6 @@ import static org.opensearch.sql.ppl.utils.LookupTransformer.buildOutputProjectList; import static org.opensearch.sql.ppl.utils.LookupTransformer.buildProjectListFromFields; import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier; -import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; -import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; import static scala.collection.JavaConverters.seqAsJavaList; /** @@ -251,6 +226,30 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { }); } + @Override + public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + + node.getSortByField() + .ifPresent(sortField -> { + Expression sortFieldExpression = visitExpression(sortField, context); + Seq sortOrder = context + .retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(sortFieldExpression, SortUtils.isSortedAscending(sortField))); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortOrder, true, p)); + }); + + List trendlineProjectExpressions = new ArrayList<>(); + + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); + } + + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); + } + @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { 
node.getChild().get(0).accept(this, context); @@ -453,6 +452,20 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) return Objects.requireNonNull(resultWithoutDuplicatedColumns, "FillNull operation failed"); } + @Override + public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { + flatten.getChild().get(0).accept(this, context); + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + Expression field = visitExpression(flatten.getFieldToBeFlattened(), context); + context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + FlattenGenerator flattenGenerator = new FlattenGenerator(field); + context.apply(p -> new Generate(new GeneratorOuter(flattenGenerator), seq(), true, (Option) None$.MODULE$, seq(), p)); + return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); + } + private void visitFieldList(List fieldList, CatalystPlanContext context) { fieldList.forEach(field -> visitExpression(field, context)); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index c69e9541e..4e6b1f131 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -386,6 +386,30 @@ private java.util.Map buildLookupPair(List (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new)); } + @Override + public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { + List trendlineComputations = ctx.trendlineClause() + .stream() + .map(this::toTrendlineComputation) + .collect(Collectors.toList()); + return 
Optional.ofNullable(ctx.sortField()) + .map(this::internalVisitExpression) + .map(Field.class::cast) + .map(sort -> new Trendline(Optional.of(sort), trendlineComputations)) + .orElse(new Trendline(Optional.empty(), trendlineComputations)); + } + + private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParser.TrendlineClauseContext ctx) { + int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + if (numberOfDataPoints < 1) { + throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1"); + } + Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field); + String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText(); + String computationType = ctx.trendlineType().getText(); + return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase())); + } + /** Top command. */ @Override public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { @@ -562,6 +586,12 @@ public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandCo } } + @Override + public UnresolvedPlan visitFlattenCommand(OpenSearchPPLParser.FlattenCommandContext ctx) { + Field unresolvedExpression = (Field) internalVisitExpression(ctx.fieldExpression()); + return new Flatten(unresolvedExpression); + } + /** AD command. 
*/ @Override public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index b6dfd0447..5e0f0775d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -43,6 +43,8 @@ import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 62eef90ed..e4defad52 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -14,16 +14,14 @@ import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.NullType$; import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; -import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.opensearch.sql.ast.expression.SpanUnit.DAY; import static 
org.opensearch.sql.ast.expression.SpanUnit.HOUR; @@ -67,6 +65,8 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; + case UNDEFINED: + return NullType$.MODULE$; default: return StringType$.MODULE$; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java index 83603b031..803daea8b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java @@ -38,7 +38,7 @@ static SortOrder getSortDirection(Sort node, NamedExpression expression) { .findAny(); return field.map(value -> sortOrder((Expression) expression, - (Boolean) value.getFieldArgs().get(0).getValue().getValue())) + isSortedAscending(value))) .orElse(null); } @@ -51,4 +51,8 @@ static SortOrder sortOrder(Expression expression, boolean ascending) { seq(new ArrayList()) ); } + + static boolean isSortedAscending(Field field) { + return (Boolean) field.getFieldArgs().get(0).getValue().getValue(); + } } \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java new file mode 100644 index 000000000..67603ccc7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.*; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Trendline; 
+import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.ppl.CatalystExpressionVisitor; +import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.Option; +import scala.Tuple2; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; + +public interface TrendlineCatalystUtils { + + static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, CatalystPlanContext context) { + return computations.stream() + .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context)) + .collect(Collectors.toList()); + } + + static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //window lower boundary + expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); + Expression windowLowerBoundary = context.popNamedParseExpressions().get(); + + //window definition + WindowSpecDefinition windowDefinition = new WindowSpecDefinition( + seq(), + seq(), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); + + if (node.getComputationType() == Trendline.TrendlineType.SMA) { + //calculate avg value of the data field + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + + //sma window + WindowExpression sma = new WindowExpression( + avgFunction, + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); + + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, + node.getAlias(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + 
seq(new java.util.ArrayList())); + } else { + throw new IllegalArgumentException(node.getComputationType()+" is not supported"); + } + } + + private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpressionVisitor expressionVisitor, WindowExpression trendlineWindow, Trendline.TrendlineComputation node, CatalystPlanContext context) { + //required number of data points + expressionVisitor.visitLiteral(new Literal(node.getNumberOfDataPoints(), DataType.INTEGER), context); + Expression requiredNumberOfDataPoints = context.popNamedParseExpressions().get(); + + //count data points function + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.COUNT.name(), new Literal(1, DataType.INTEGER)), context); + Expression countDataPointsFunction = context.popNamedParseExpressions().get(); + //count data points window + WindowExpression countDataPointsWindow = new WindowExpression( + countDataPointsFunction, + trendlineWindow.windowSpec()); + + expressionVisitor.visitLiteral(new Literal(null, DataType.NULL), context); + Expression nullLiteral = context.popNamedParseExpressions().get(); + Tuple2 nullWhenNumberOfDataPointsLessThenRequired = new Tuple2<>( + new LessThan(countDataPointsWindow, requiredNumberOfDataPoints), + nullLiteral + ); + return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); + } +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala new file mode 100644 index 000000000..23b545826 --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlattenGenerator.scala @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import 
org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, CreateArray, Expression, GenericInternalRow, Inline, UnaryExpression} +import org.apache.spark.sql.types.{ArrayType, StructType} + +class FlattenGenerator(override val child: Expression) + extends Inline(child) + with CollectionGenerator { + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case st: StructType => TypeCheckResult.TypeCheckSuccess + case _ => super.checkInputDataTypes() + } + + override def elementSchema: StructType = child.dataType match { + case st: StructType => st + case _ => super.elementSchema + } + + override protected def withNewChildInternal(newChild: Expression): FlattenGenerator = { + newChild.dataType match { + case ArrayType(st: StructType, _) => new FlattenGenerator(newChild) + case st: StructType => withNewChildInternal(CreateArray(Seq(newChild), false)) + case _ => + throw new IllegalArgumentException(s"Unexpected input type ${newChild.dataType}") + } + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..58a6c04b3 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFlattenCommandTranslatorTestSuite.scala @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, 
UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, GeneratorOuter, Literal, NullsLast, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, GlobalLimit, LocalLimit, Project, Sort} +import org.apache.spark.sql.types.IntegerType + +class PPLLogicalPlanFlattenCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test flatten only field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | flatten field_with_array"), + context) + + val relation = UnresolvedRelation(Seq("relation")) + val flattenGenerator = new FlattenGenerator(UnresolvedAttribute("field_with_array")) + val outerGenerator = GeneratorOuter(flattenGenerator) + val generate = Generate(outerGenerator, seq(), true, None, seq(), relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test flatten and stats") { + val context = new CatalystPlanContext + val query = + "source = relation | fields state, company, employee | flatten employee | fields state, company, salary | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("relation")) + val projectStateCompanyEmployee = + Project( + Seq( + UnresolvedAttribute("state"), + UnresolvedAttribute("company"), + UnresolvedAttribute("employee")), + table) + val generate = Generate( + GeneratorOuter(new 
FlattenGenerator(UnresolvedAttribute("employee"))),
+      seq(),
+      true,
+      None,
+      seq(),
+      projectStateCompanyEmployee)
+    val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate)
+    val projectStateCompanySalary = Project(
+      Seq(
+        UnresolvedAttribute("state"),
+        UnresolvedAttribute("company"),
+        UnresolvedAttribute("salary")),
+      dropSourceColumn)
+    val average = Alias(
+      UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false),
+      "max")()
+    val state = Alias(UnresolvedAttribute("state"), "state")()
+    val company = Alias(UnresolvedAttribute("company"), "company")()
+    val groupingState = Alias(UnresolvedAttribute("state"), "state")()
+    val groupingCompany = Alias(UnresolvedAttribute("company"), "company")()
+    val aggregate = Aggregate(
+      Seq(groupingState, groupingCompany),
+      Seq(average, state, company),
+      projectStateCompanySalary)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate)
+
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  // `flatten` followed by `eval`: the expected plan is an outer Generate over the relation,
+  // a drop of the flattened source column, and a projection that adds `bonus` next to `*`.
+  test("test flatten and eval") {
+    val context = new CatalystPlanContext
+    val query = "source = relation | flatten employee | eval bonus = salary * 3"
+    val logPlan = planTransformer.visit(plan(pplParser, query), context)
+    val table = UnresolvedRelation(Seq("relation"))
+    val generate = Generate(
+      GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))),
+      seq(),
+      true,
+      None,
+      seq(),
+      table)
+    val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate)
+    val bonusProject = Project(
+      Seq(
+        UnresolvedStar(None),
+        Alias(
+          UnresolvedFunction(
+            "*",
+            Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)),
+            isDistinct = false),
+          "bonus")()),
+      dropSourceColumn)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+  // `flatten` -> `parse` -> `flatten`: each flatten contributes a Generate plus a drop of its
+  // source column; the parse step in between contributes a projection with the extracted field.
+  test("test flatten and parse and flatten") {
+    val context = new CatalystPlanContext
+    // The pattern must carry the named capture group `email`: it is the name the expected
+    // `email` alias below corresponds to, and a bare `(?` is not valid regex syntax.
+    val logPlan =
+      planTransformer.visit(
+        plan(
+          pplParser,
+          "source=relation | flatten employee | parse description '(?<email>.+@.+)' | flatten roles"),
+        context)
+    val table = UnresolvedRelation(Seq("relation"))
+    val generateEmployee = Generate(
+      GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("employee"))),
+      seq(),
+      true,
+      None,
+      seq(),
+      table)
+    val dropSourceColumnEmployee =
+      DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generateEmployee)
+    // Same pattern text as in the query above; keep the two in sync.
+    val emailAlias =
+      Alias(
+        RegExpExtract(UnresolvedAttribute("description"), Literal("(?<email>.+@.+)"), Literal(1)),
+        "email")()
+    val parseProject = Project(
+      Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)),
+      dropSourceColumnEmployee)
+    val generateRoles = Generate(
+      GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))),
+      seq(),
+      true,
+      None,
+      seq(),
+      parseProject)
+    val dropSourceColumnRoles =
+      DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles)
+    comparePlans(expectedPlan, logPlan, checkAnalysis = false)
+  }
+
+}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala
new file mode 100644
index 000000000..d22750ee0
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala
@@ -0,0 +1,135 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+import
org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort}
+
+/**
+ * Compares the logical plan produced for PPL `trendline` queries against hand-built
+ * expected plans. Plans are compared unanalyzed (`checkAnalysis = false`).
+ */
+class PPLLogicalPlanTrendlineCommandTranslatorTestSuite
+    extends SparkFunSuite
+    with PlanTest
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  /** Parses `query` and runs it through the plan visitor with a fresh context. */
+  private def translate(query: String) = {
+    val context = new CatalystPlanContext
+    planTransformer.visit(plan(pplParser, query), context)
+  }
+
+  /** The source relation every query in this suite reads from. */
+  private def relation = UnresolvedRelation(Seq("relation"))
+
+  /** The field every query in this suite sorts on and averages. */
+  private def age = UnresolvedAttribute("age")
+
+  /**
+   * Expected projection column for `sma(points, field)` named `name`: over a row frame from
+   * `1 - points` to the current row, NULL while COUNT(1) is below `points`, otherwise AVG(field).
+   */
+  private def smaColumn(points: Int, field: UnresolvedAttribute, name: String): Alias = {
+    val frame = SpecifiedWindowFrame(RowFrame, Literal(1 - points), CurrentRow)
+    val window = WindowSpecDefinition(Seq(), Seq(), frame)
+    val rowsInFrame =
+      WindowExpression(UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false), window)
+    val movingAverage =
+      WindowExpression(UnresolvedFunction("AVG", Seq(field), isDistinct = false), window)
+    val guarded =
+      CaseWhen(Seq((LessThan(rowsInFrame, Literal(points)), Literal(null))), movingAverage)
+    Alias(guarded, name)()
+  }
+
+  test("test trendline") {
+    val actualPlan = translate("source=relation | trendline sma(3, age)")
+
+    // No sort clause: the trendline projection sits directly on the relation.
+    val withTrend =
+      Project(Seq(UnresolvedStar(None), smaColumn(3, age, "age_trendline")), relation)
+    val wanted = Project(Seq(UnresolvedStar(None)), withTrend)
+    comparePlans(actualPlan, wanted, checkAnalysis = false)
+  }
+
+  test("test trendline with sort") {
+    val actualPlan = translate("source=relation | trendline sort age sma(3, age)")
+
+    // A bare sort field is expected as a global ascending Sort below the projection.
+    val sorted = Sort(Seq(SortOrder(age, Ascending)), global = true, relation)
+    val withTrend =
+      Project(Seq(UnresolvedStar(None), smaColumn(3, age, "age_trendline")), sorted)
+    val wanted = Project(Seq(UnresolvedStar(None)), withTrend)
+    comparePlans(actualPlan, wanted, checkAnalysis = false)
+  }
+
+  test("test trendline with sort and alias") {
+    val actualPlan = translate("source=relation | trendline sort - age sma(3, age) as age_sma")
+
+    // `- age` is expected as a descending sort; `as age_sma` replaces the default column name.
+    val sorted = Sort(Seq(SortOrder(age, Descending)), global = true, relation)
+    val withTrend = Project(Seq(UnresolvedStar(None), smaColumn(3, age, "age_sma")), sorted)
+    val wanted = Project(Seq(UnresolvedStar(None)), withTrend)
+    comparePlans(actualPlan, wanted, checkAnalysis = false)
+  }
+
+  test("test trendline with multiple trendline sma commands") {
+    val actualPlan = translate(
+      "source=relation | trendline sort + age sma(2, age) as two_points_sma sma(3, age) | fields name, age, two_points_sma, age_trendline")
+
+    // Both averages share one projection over the ascending sort; the trailing `fields`
+    // command becomes the outermost projection.
+    val sorted = Sort(Seq(SortOrder(age, Ascending)), global = true, relation)
+    val withTrends = Project(
+      Seq(
+        UnresolvedStar(None),
+        smaColumn(2, age, "two_points_sma"),
+        smaColumn(3, age, "age_trendline")),
+      sorted)
+    val selectedFields = Seq(
+      UnresolvedAttribute("name"),
+      age,
+      UnresolvedAttribute("two_points_sma"),
+      UnresolvedAttribute("age_trendline"))
+    val wanted = Project(selectedFields, withTrends)
+    comparePlans(actualPlan, wanted, checkAnalysis = false)
+  }
+}