From 3c8a49041a69334f712c53b223a794d48acab084 Mon Sep 17 00:00:00 2001
From: Louis Chu
Date: Thu, 25 Jul 2024 17:01:04 -0700
Subject: [PATCH 1/8] [Bugfix] Insights on query execution error (#475)

* BugFix: Add error logs

Signed-off-by: Louis Chu

* Add IT

Signed-off-by: Louis Chu

* Fix IT

Signed-off-by: Louis Chu

* Log stacktrace

Signed-off-by: Louis Chu

* Use full msg instead of prefix

Signed-off-by: Louis Chu

---------

Signed-off-by: Louis Chu
---
 .../apache/spark/sql/FlintREPLITSuite.scala   | 63 +++++++++++++++++++
 .../apache/spark/sql/FlintJobExecutor.scala   | 22 ++++---
 2 files changed, 75 insertions(+), 10 deletions(-)

diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala
index 1c0b27674..921db792a 100644
--- a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala
+++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala
@@ -422,6 +422,69 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
     }
   }

+  test("create table with dummy location should fail with expected error message") {
+    try {
+      createSession(jobRunId, "")
+      threadLocalFuture.set(startREPL())
+
+      val dummyLocation = "s3://path/to/dummy/location"
+      val testQueryId = "110"
+      val createTableStatement =
+        s"""
+           | CREATE TABLE $testTable
+           | (
+           |   name STRING,
+           |   age INT
+           | )
+           | USING CSV
+           | LOCATION '$dummyLocation'
+           | OPTIONS (
+           |  header 'false',
+           |  delimiter '\\t'
+           | )
+           |""".stripMargin
+      val createTableStatementId =
+        submitQuery(s"${makeJsonCompliant(createTableStatement)}", testQueryId)
+
+      val createTableStatementValidation: REPLResult => Boolean = result => {
+        assert(
+          result.results.size == 0,
+          s"expected result size is 0, but got ${result.results.size}")
+        assert(
+          result.schemas.size == 0,
+          s"expected schema size is 0, but got ${result.schemas.size}")
+        failureValidation(result)
+        true
+      }
+      pollForResultAndAssert(createTableStatementValidation, testQueryId)
+      assert(
+        !awaitConditionForStatementOrTimeout(
+          statement => {
+            statement.error match {
+              case Some(error)
+                  if error == """{"Message":"Fail to run query. Cause: No FileSystem for scheme \"s3\""}""" =>
+              // Assertion passed
+              case _ =>
+                fail(s"Statement error is: ${statement.error}")
+            }
+            statement.state == "failed"
+          },
+          createTableStatementId),
+        s"Failed to verify statement $createTableStatementId.")
+      // clean up
+      val dropStatement =
+        s"""DROP TABLE $testTable""".stripMargin
+      submitQuery(s"${makeJsonCompliant(dropStatement)}", "999")
+    } catch {
+      case e: Exception =>
+        logError("Unexpected exception", e)
+        assert(false, "Unexpected exception")
+    } finally {
+      waitREPLStop(threadLocalFuture.get())
+      threadLocalFuture.remove()
+    }
+  }
+
   /**
    * JSON does not support raw newlines (\n) in string values. All newlines must be escaped or
    * removed when inside a JSON string.
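   * For example, the multi-line CREATE TABLE statement above is passed through
   * makeJsonCompliant before submission so that it becomes a single JSON-safe line.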
The same goes for tab characters, which should be diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index f38a27ef4..00f023694 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -436,18 +436,22 @@ trait FlintJobExecutor { private def handleQueryException( e: Exception, - message: String, + messagePrefix: String, errorSource: Option[String] = None, statusCode: Option[Int] = None): String = { - - val errorDetails = Map("Message" -> s"$message: ${e.getMessage}") ++ + val errorMessage = s"$messagePrefix: ${e.getMessage}" + val errorDetails = Map("Message" -> errorMessage) ++ errorSource.map("ErrorSource" -> _) ++ statusCode.map(code => "StatusCode" -> code.toString) val errorJson = mapper.writeValueAsString(errorDetails) - statusCode.foreach { code => - CustomLogging.logError(new OperationMessage("", code), e) + // CustomLogging will call log4j logger.error() underneath + statusCode match { + case Some(code) => + CustomLogging.logError(new OperationMessage(errorMessage, code), e) + case None => + CustomLogging.logError(errorMessage, e) } errorJson @@ -491,16 +495,14 @@ trait FlintJobExecutor { case r: SparkException => handleQueryException(r, ExceptionMessages.SparkExceptionErrorPrefix) case r: Exception => - val rootCauseClassName = ex.getClass.getName - val errMsg = ex.getMessage - logDebug(s"Root cause class name: $rootCauseClassName") - logDebug(s"Root cause error message: $errMsg") + val rootCauseClassName = r.getClass.getName + val errMsg = r.getMessage if (rootCauseClassName == "org.apache.hadoop.hive.metastore.api.MetaException" && errMsg.contains("com.amazonaws.services.glue.model.AccessDeniedException")) { val e = new SecurityException(ExceptionMessages.GlueAccessDeniedMessage) handleQueryException(e, ExceptionMessages.QueryRunErrorPrefix) } else { - handleQueryException(ex, ExceptionMessages.QueryRunErrorPrefix) + handleQueryException(r, ExceptionMessages.QueryRunErrorPrefix) } } } From 98bd79a6ea1d9587d67ae761d417b0db4242a815 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Thu, 25 Jul 2024 18:08:25 -0700 Subject: [PATCH 2/8] Disable unsupported PPL function expressions (#478) Signed-off-by: Tomoyuki Morita --- ppl-spark-integration/README.md | 2 +- .../src/main/antlr4/OpenSearchPPLParser.g4 | 9 +------- .../sql/ppl/parser/AstExpressionBuilder.java | 23 ------------------- 3 files changed, 2 insertions(+), 32 deletions(-) diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index ecd043acd..61ef5b670 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -262,7 +262,7 @@ The next samples of PPL queries are currently supported: - `where` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/where.rst) - `fields` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/fields.rst) - `head` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/head.rst) - - `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) (supports AVG, COUNT, MAX, MIN and SUM aggregation functions) + - `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) (supports AVG, COUNT, DISTINCT_COUNT, MAX, MIN and 
SUM aggregation functions) - `sort` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) - `correlation` - [See details](../docs/PPL-Correlation-command.md) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index f6cd0d4ee..086413ca4 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -33,7 +33,6 @@ commands : whereCommand | correlateCommand | fieldsCommand - | renameCommand | statsCommand | sortCommand | headCommand @@ -224,8 +223,6 @@ statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall - | percentileAggFunction # percentileAggFunctionCall - | takeAggFunction # takeAggFunctionCall ; statsFunctionName @@ -257,8 +254,6 @@ logicalExpression | left = logicalExpression OR right = logicalExpression # logicalOr | left = logicalExpression (AND)? right = logicalExpression # logicalAnd | left = logicalExpression XOR right = logicalExpression # logicalXor - | booleanExpression # booleanExpr - | relevanceExpression # relevanceExpr ; comparisonExpression @@ -266,9 +261,7 @@ comparisonExpression ; valueExpression - : left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic - | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic - | primaryExpression # valueExpressionDefault + : primaryExpression # valueExpressionDefault | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 047d9af44..6f1129b04 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -112,15 +112,6 @@ public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprCont return new Compare(ctx.comparisonOperator().getText(), visit(ctx.left), visit(ctx.right)); } - /** - * Value Expression. - */ - @Override - public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) { - return new Function( - ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); - } - @Override public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) { return visit(ctx.valueExpression()); // Discard parenthesis around @@ -172,20 +163,6 @@ public UnresolvedExpression visitPercentileAggFunction(OpenSearchPPLParser.Perce Collections.singletonList(new Argument("rank", (Literal) visit(ctx.value)))); } - @Override - public UnresolvedExpression visitTakeAggFunctionCall( - OpenSearchPPLParser.TakeAggFunctionCallContext ctx) { - ImmutableList.Builder builder = ImmutableList.builder(); - builder.add( - new UnresolvedArgument( - "size", - ctx.takeAggFunction().size != null - ? visit(ctx.takeAggFunction().size) - : new Literal(DEFAULT_TAKE_FUNCTION_SIZE_VALUE, DataType.INTEGER))); - return new AggregateFunction( - "take", visit(ctx.takeAggFunction().fieldExpression()), builder.build()); - } - /** * Eval function. 
*/ From 3e4df0a577c9e365ce82a02b195a8ae797a08ae1 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 26 Jul 2024 14:32:27 -0700 Subject: [PATCH 3/8] Add error output column to show Flint index statement (#436) * Add extended in show flint index and refactor AST builder Signed-off-by: Chen Dai * Update user manual Signed-off-by: Chen Dai * Add IT Signed-off-by: Chen Dai * Fix scalafmt issue for global execution context Signed-off-by: Chen Dai * Split ANTLR grammar rule Signed-off-by: Chen Dai * Refactor comments and test code Signed-off-by: Chen Dai * Remove thread pool in IT Signed-off-by: Chen Dai --------- Signed-off-by: Chen Dai --- docs/index.md | 13 +- .../main/antlr4/FlintSparkSqlExtensions.g4 | 5 +- .../src/main/antlr4/SparkSqlBase.g4 | 1 + .../sql/index/FlintSparkIndexAstBuilder.scala | 166 +++++++++++++----- .../spark/FlintSparkIndexSqlITSuite.scala | 46 ++++- 5 files changed, 181 insertions(+), 50 deletions(-) diff --git a/docs/index.md b/docs/index.md index af6e54a3e..249e7a770 100644 --- a/docs/index.md +++ b/docs/index.md @@ -328,9 +328,11 @@ VACUUM MATERIALIZED VIEW alb_logs_metrics - index_name: user defined name for covering index and materialized view - auto_refresh: auto refresh option of the index (true / false) - status: status of the index +- **Extended Usage**: Display additional information, including the following output columns: + - error: error message if the index is in failed status ```sql -SHOW FLINT [INDEX|INDEXES] IN catalog[.database] +SHOW FLINT [INDEX|INDEXES] [EXTENDED] IN catalog[.database] ``` Example: @@ -344,6 +346,15 @@ fetched rows / total rows = 3/3 | flint_spark_catalog_default_http_logs_skipping_index | skipping | default | http_logs | NULL | true | refreshing | | flint_spark_catalog_default_http_logs_status_clientip_index | covering | default | http_logs | status_clientip | false | active | +-------------------------------------------------------------+----------+----------+-----------+-----------------+--------------+------------+ + +sql> SHOW FLINT INDEXES EXTENDED IN spark_catalog.default; +fetched rows / total rows = 2/2 ++-------------------------------------------------------------+----------+----------+-----------+-----------------+--------------+------------+-------------------------------+ +| flint_index_name | kind | database | table | index_name | auto_refresh | status | error | +|-------------------------------------------------------------+----------+----------+-----------+-----------------+--------------+------------+-------------------------------| +| flint_spark_catalog_default_http_count_view | mv | default | NULL | http_count_view | false | active | NULL | +| flint_spark_catalog_default_http_logs_skipping_index | skipping | default | http_logs | NULL | true | failed | failure in bulk execution:... | ++-------------------------------------------------------------+----------+----------+-----------+-----------------+--------------+------------+-------------------------------+ ``` - **Analyze Skipping Index**: Provides recommendation for creating skipping index. 
It outputs the following columns: diff --git a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 index 2e8d634da..46e814e9f 100644 --- a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 +++ b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 @@ -156,7 +156,10 @@ indexManagementStatement ; showFlintIndexStatement - : SHOW FLINT (INDEX | INDEXES) IN catalogDb=multipartIdentifier + : SHOW FLINT (INDEX | INDEXES) + IN catalogDb=multipartIdentifier #showFlintIndex + | SHOW FLINT (INDEX | INDEXES) EXTENDED + IN catalogDb=multipartIdentifier #showFlintIndexExtended ; indexJobManagementStatement diff --git a/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 b/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 index 283981e47..c53c61adf 100644 --- a/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 +++ b/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 @@ -163,6 +163,7 @@ DESC: 'DESC'; DESCRIBE: 'DESCRIBE'; DROP: 'DROP'; EXISTS: 'EXISTS'; +EXTENDED: 'EXTENDED'; FALSE: 'FALSE'; FLINT: 'FLINT'; IF: 'IF'; diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/index/FlintSparkIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/index/FlintSparkIndexAstBuilder.scala index 62c98b023..e6cccbc4a 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/index/FlintSparkIndexAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/index/FlintSparkIndexAstBuilder.scala @@ -7,12 +7,13 @@ package org.opensearch.flint.spark.sql.index import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import org.opensearch.flint.spark.FlintSparkIndex import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder} import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.IndexBelongsTo -import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.ShowFlintIndexStatementContext +import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{MultipartIdentifierContext, ShowFlintIndexContext, ShowFlintIndexExtendedContext} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.AttributeReference @@ -25,52 +26,123 @@ import org.apache.spark.sql.types.{BooleanType, StringType} trait FlintSparkIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[AnyRef] { self: SparkSqlAstBuilder => - override def visitShowFlintIndexStatement(ctx: ShowFlintIndexStatementContext): Command = { - val outputSchema = Seq( - AttributeReference("flint_index_name", StringType, nullable = false)(), - AttributeReference("kind", StringType, nullable = false)(), - AttributeReference("database", StringType, nullable = false)(), - AttributeReference("table", StringType, nullable = true)(), - AttributeReference("index_name", StringType, nullable = true)(), - AttributeReference("auto_refresh", BooleanType, nullable = false)(), - AttributeReference("status", StringType, nullable = false)()) - - FlintSparkSqlCommand(outputSchema) { flint => - val catalogDbName = - ctx.catalogDb.parts - .map(part => part.getText) - .mkString("_") - val indexNamePattern = s"flint_${catalogDbName}_*" - 
flint - .describeIndexes(indexNamePattern) - .filter(index => index belongsTo ctx.catalogDb) - .map { index => - val (databaseName, tableName, indexName) = index match { - case skipping: FlintSparkSkippingIndex => - val parts = skipping.tableName.split('.') - (parts(1), parts.drop(2).mkString("."), null) - case covering: FlintSparkCoveringIndex => - val parts = covering.tableName.split('.') - (parts(1), parts.drop(2).mkString("."), covering.indexName) - case mv: FlintSparkMaterializedView => - val parts = mv.mvName.split('.') - (parts(1), null, parts.drop(2).mkString(".")) - } - - val status = index.latestLogEntry match { - case Some(entry) => entry.state.toString - case None => "unavailable" - } - - Row( - index.name, - index.kind, - databaseName, - tableName, - indexName, - index.options.autoRefresh(), - status) - } + /** + * Represents the basic output schema for the FlintSparkSqlCommand. This schema includes + * essential information about each index. + */ + private val baseOutputSchema = Seq( + AttributeReference("flint_index_name", StringType, nullable = false)(), + AttributeReference("kind", StringType, nullable = false)(), + AttributeReference("database", StringType, nullable = false)(), + AttributeReference("table", StringType, nullable = true)(), + AttributeReference("index_name", StringType, nullable = true)(), + AttributeReference("auto_refresh", BooleanType, nullable = false)(), + AttributeReference("status", StringType, nullable = false)()) + + /** + * Extends the base output schema with additional information. This schema is used when the + * EXTENDED keyword is present. + */ + private val extendedOutputSchema = Seq( + AttributeReference("error", StringType, nullable = true)()) + + override def visitShowFlintIndex(ctx: ShowFlintIndexContext): Command = { + new ShowFlintIndexCommandBuilder() + .withSchema(baseOutputSchema) + .forCatalog(ctx.catalogDb) + .constructRows(baseRowData) + .build() + } + + override def visitShowFlintIndexExtended(ctx: ShowFlintIndexExtendedContext): Command = { + new ShowFlintIndexCommandBuilder() + .withSchema(baseOutputSchema ++ extendedOutputSchema) + .forCatalog(ctx.catalogDb) + .constructRows(index => baseRowData(index) ++ extendedRowData(index)) + .build() + } + + /** + * Builder class for constructing FlintSparkSqlCommand objects. + */ + private class ShowFlintIndexCommandBuilder { + private var schema: Seq[AttributeReference] = _ + private var catalogDb: MultipartIdentifierContext = _ + private var rowDataBuilder: FlintSparkIndex => Seq[Any] = _ + + /** Specify the output schema for the command. */ + def withSchema(schema: Seq[AttributeReference]): ShowFlintIndexCommandBuilder = { + this.schema = schema + this + } + + /** Specify the catalog database context for the command. */ + def forCatalog(catalogDb: MultipartIdentifierContext): ShowFlintIndexCommandBuilder = { + this.catalogDb = catalogDb + this + } + + /** Configures a function to construct row data for each index. */ + def constructRows( + rowDataBuilder: FlintSparkIndex => Seq[Any]): ShowFlintIndexCommandBuilder = { + this.rowDataBuilder = rowDataBuilder + this + } + + /** Builds the command using the configured parameters. 
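+   * A minimal usage sketch (mirroring visitShowFlintIndex above):
+   * {{{
+   *   new ShowFlintIndexCommandBuilder()
+   *     .withSchema(baseOutputSchema)
+   *     .forCatalog(ctx.catalogDb)
+   *     .constructRows(baseRowData)
+   *     .build()
+   * }}}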
*/ + def build(): FlintSparkSqlCommand = { + require(schema != null, "Schema must be set before building the command") + require(catalogDb != null, "Catalog database must be set before building the command") + require(rowDataBuilder != null, "Row data builder must be set before building the command") + + FlintSparkSqlCommand(schema) { flint => + val catalogDbName = + catalogDb.parts + .map(part => part.getText) + .mkString("_") + val indexNamePattern = s"flint_${catalogDbName}_*" + + flint + .describeIndexes(indexNamePattern) + .filter(index => index belongsTo catalogDb) + .map { index => Row.fromSeq(rowDataBuilder(index)) } + } + } + } + + private def baseRowData(index: FlintSparkIndex): Seq[Any] = { + val (databaseName, tableName, indexName) = index match { + case skipping: FlintSparkSkippingIndex => + val parts = skipping.tableName.split('.') + (parts(1), parts.drop(2).mkString("."), null) + case covering: FlintSparkCoveringIndex => + val parts = covering.tableName.split('.') + (parts(1), parts.drop(2).mkString("."), covering.indexName) + case mv: FlintSparkMaterializedView => + val parts = mv.mvName.split('.') + (parts(1), null, parts.drop(2).mkString(".")) + } + + val status = index.latestLogEntry match { + case Some(entry) => entry.state.toString + case None => "unavailable" + } + + Seq( + index.name, + index.kind, + databaseName, + tableName, + indexName, + index.options.autoRefresh(), + status) + } + + private def extendedRowData(index: FlintSparkIndex): Seq[Any] = { + val error = index.latestLogEntry match { + case Some(entry) => entry.error + case None => null } + Seq(error) } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSqlITSuite.scala index e312ba6de..a5744271f 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSqlITSuite.scala @@ -5,6 +5,9 @@ package org.opensearch.flint.spark +import scala.collection.JavaConverters._ + +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest import org.opensearch.client.RequestOptions import org.opensearch.client.indices.CreateIndexRequest import org.opensearch.common.xcontent.XContentType @@ -12,10 +15,11 @@ import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.AUTO_REFRESH import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex +import org.scalatest.matchers.should.Matchers import org.apache.spark.sql.Row -class FlintSparkIndexSqlITSuite extends FlintSparkSuite { +class FlintSparkIndexSqlITSuite extends FlintSparkSuite with Matchers { private val testTableName = "index_test" private val testTableQualifiedName = s"spark_catalog.default.$testTableName" @@ -99,6 +103,46 @@ class FlintSparkIndexSqlITSuite extends FlintSparkSuite { FlintSparkMaterializedView.getFlintIndexName("spark_catalog.other.mv2")) } + test("show flint indexes with extended information") { + // Create and refresh with all existing data + flint + .skippingIndex() + .onTable(testTableQualifiedName) + .addValueSet("name") + .options(FlintSparkIndexOptions(Map(AUTO_REFRESH.toString -> "true"))) + .create() + flint.refreshIndex(testSkippingFlintIndex) + val activeJob = spark.streams.active.find(_.name == testSkippingFlintIndex) + 
// Wait for the initial micro batch to finish before reading the error column
awaitStreamingComplete(activeJob.get.id.toString)
+
+    // Assert that the output error message is empty
+    def outputError: String = {
+      val df = sql("SHOW FLINT INDEX EXTENDED IN spark_catalog")
+      df.columns should contain("error")
+      df.collect().head.getAs[String]("error")
+    }
+    outputError shouldBe empty
+
+    // Trigger the next micro batch after 5 seconds with the index set to read-only
+    new Thread(() => {
+      Thread.sleep(5000)
+      openSearchClient
+        .indices()
+        .putSettings(
+          new UpdateSettingsRequest(testSkippingFlintIndex).settings(
+            Map("index.blocks.write" -> true).asJava),
+          RequestOptions.DEFAULT)
+      sql(
+        s"INSERT INTO $testTableQualifiedName VALUES (TIMESTAMP '2023-10-01 04:00:00', 'F', 25, 'Vancouver')")
+    }).start()
+
+    // Wait for the exception to be stored, then verify it is as expected
+    flint.flintIndexMonitor.awaitMonitor(Some(testSkippingFlintIndex))
+    outputError should include("OpenSearchException")
+
+    deleteTestIndex(testSkippingFlintIndex)
+  }
+
   test("should return empty when show flint index in empty database") {
     checkAnswer(sql(s"SHOW FLINT INDEX IN spark_catalog.default"), Seq.empty)
   }

From c6ab291f45f0443b7a06a503fcb7aeacf68c33cb Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Mon, 29 Jul 2024 23:54:37 +0800
Subject: [PATCH 4/8] [DOC] Checklist to fix "could not find Docker environment" on macOS (#477)

Signed-off-by: Lantao Jin
---
 DEVELOPER_GUIDE.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md
index 619a33e24..58aa5df79 100644
--- a/DEVELOPER_GUIDE.md
+++ b/DEVELOPER_GUIDE.md
@@ -12,6 +12,12 @@ can do so by running the following command:
 ```
 sbt integtest/test
 ```
+If integration tests fail on macOS with the error message "Previous attempts to find a Docker environment failed", work through this checklist:
+1. Check that Docker is installed on your dev host. If not, install Docker first.
+2. Check whether the file /var/run/docker.sock exists. If it does not, go to step 3.
+3. Run `sudo ln -s $HOME/.docker/desktop/docker.sock /var/run/docker.sock` or `sudo ln -s $HOME/.docker/run/docker.sock /var/run/docker.sock`
+4. If you use Docker Desktop, as an alternative to step 3, enable "Allow the default Docker socket to be used (requires password)" in the advanced settings of Docker Desktop.
+
 ### AWS Integration Test
 The integration folder contains tests for cloud service providers. For instance, to test against an AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain.
``` From 1f81ddf88ac96cd0130c9bd8f32abc4454e90699 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 30 Jul 2024 23:41:51 +0800 Subject: [PATCH 5/8] Translate PPL-builtin functions to Spark-builtin functions (#448) * Add string functions and math functions Signed-off-by: Lantao Jin * add from_unixtime and unix_timestamp test Signed-off-by: Lantao Jin * Add IT Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- .../FlintSparkPPLBuiltinFunctionITSuite.scala | 554 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLParser.g4 | 10 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 17 +- .../sql/ppl/parser/AstExpressionBuilder.java | 9 + .../ppl/utils/BuiltinFunctionTranslator.java | 30 + .../sql/ppl/utils/DataTypeTransformer.java | 19 +- .../spark/ppl/LogicalPlanTestUtils.scala | 1 - ...PlanMathFunctionsTranslatorTestSuite.scala | 163 ++++++ ...anStringFunctionsTranslatorTestSuite.scala | 224 +++++++ ...PlanTimeFunctionsTranslatorTestSuite.scala | 56 ++ 10 files changed, 1077 insertions(+), 6 deletions(-) create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala new file mode 100644 index 000000000..127b29295 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala @@ -0,0 +1,554 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.DoubleType + +class FlintSparkPPLBuiltinFunctionITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test string functions - concat") { + val frame = sql(s""" + | source = $testTable name=concat('He', 'llo') | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = 
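+      // the shared test table is assumed to contain a row ("Hello", 30), the only
+      // name equal to concat('He', 'llo')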
Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("name"), + UnresolvedFunction("concat", seq(Literal("He"), Literal("llo")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - concat with field") { + val frame = sql(s""" + | source = $testTable name=concat('Hello', state) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array.empty + assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("name"), + UnresolvedFunction( + "concat", + seq(Literal("Hello"), UnresolvedAttribute("state")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - length") { + val frame = sql(s""" + | source = $testTable |where length(name) = 5 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal(5)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test function name should be insensitive") { + val frame = sql(s""" + | source = $testTable |where leNgTh(name) = 5 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: 
LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal(5)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - lower") { + val frame = sql(s""" + | source = $testTable |where lower(name) = "hello" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("lower", seq(UnresolvedAttribute("name")), isDistinct = false), + Literal("hello")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - upper") { + val frame = sql(s""" + | source = $testTable |where upper(name) = upper("hello") | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), + UnresolvedFunction("upper", seq(Literal("hello")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - substring") { + val frame = sql(s""" + | source = $testTable |where substring(name, 2, 2) = "el" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = 
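+      // expected shape: unresolved relation, filtered by the translated function call, then projected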
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "substring", + seq(UnresolvedAttribute("name"), Literal(2), Literal(2)), + isDistinct = false), + Literal("el")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - like") { + val frame = sql(s""" + | source = $testTable | where like(name, '_ello%') | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val likeFunction = UnresolvedFunction( + "like", + seq(UnresolvedAttribute("name"), Literal("_ello%")), + isDistinct = false) + + val filterPlan = Filter(likeFunction, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - replace") { + val frame = sql(s""" + | source = $testTable |where replace(name, 'o', ' ') = "Hell " | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "replace", + seq(UnresolvedAttribute("name"), Literal("o"), Literal(" ")), + isDistinct = false), + Literal("Hell ")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test string functions - replace and trim") { + val frame = sql(s""" + | source = $testTable |where trim(replace(name, 'o', ' ')) = "Hell" | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", 
"flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "trim", + seq( + UnresolvedFunction( + "replace", + seq(UnresolvedAttribute("name"), Literal("o"), Literal(" ")), + isDistinct = false)), + isDistinct = false), + Literal("Hell")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - abs") { + val frame = sql(s""" + | source = $testTable |where age = abs(-30) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("abs", seq(Literal(-30)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - abs with field") { + val frame = sql(s""" + | source = $testTable |where abs(age) = 30 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("abs", seq(UnresolvedAttribute("age")), isDistinct = false), + Literal(30)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - ceil") { + val frame = sql(s""" + | source = $testTable |where age = ceil(29.7) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("ceil", 
seq(Literal(29.7d, DoubleType)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - floor") { + val frame = sql(s""" + | source = $testTable |where age = floor(30.4) | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedAttribute("age"), + UnresolvedFunction("floor", seq(Literal(30.4d, DoubleType)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - ln") { + val frame = sql(s""" + | source = $testTable |where ln(age) > 4 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = GreaterThan( + UnresolvedFunction("ln", seq(UnresolvedAttribute("age")), isDistinct = false), + Literal(4)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - mod") { + val frame = sql(s""" + | source = $testTable |where mod(age, 10) = 0 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction("mod", seq(UnresolvedAttribute("age"), Literal(10)), isDistinct = false), + Literal(0)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), 
UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test math functions - pow and sqrt") { + val frame = sql(s""" + | source = $testTable |where sqrt(pow(age, 2)) = 30.0 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = EqualTo( + UnresolvedFunction( + "sqrt", + seq( + UnresolvedFunction( + "pow", + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false)), + isDistinct = false), + Literal(30.0)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test time functions - from_unixtime and unix_timestamp") { + val frame = sql(s""" + | source = $testTable |where unix_timestamp(from_unixtime(1700000001)) > 1700000000 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val filterExpr = GreaterThan( + UnresolvedFunction( + "unix_timestamp", + seq(UnresolvedFunction("from_unixtime", seq(Literal(1700000001)), isDistinct = false)), + isDistinct = false), + Literal(1700000000)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 086413ca4..aac3c3f36 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -254,19 +254,25 @@ logicalExpression | left = logicalExpression OR right = logicalExpression # logicalOr | left = logicalExpression (AND)? 
right = logicalExpression # logicalAnd | left = logicalExpression XOR right = logicalExpression # logicalXor + | booleanExpression # booleanExpr ; comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr + | valueExpression IN valueList # inExpr ; valueExpression - : primaryExpression # valueExpressionDefault + : left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic + | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic + | primaryExpression # valueExpressionDefault + | positionFunction # positionFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr ; primaryExpression - : fieldExpression + : evalFunctionCall + | fieldExpression | literalValue ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 6d14db328..04f4320c1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -54,6 +54,7 @@ import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ppl.utils.AggregatorTranslator; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.SortUtils; import scala.Option; @@ -397,7 +398,21 @@ public Expression visitEval(Eval node, CatalystPlanContext context) { @Override public Expression visitFunction(Function node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Function"); + List arguments = + node.getFuncArgs().stream() + .map( + unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression)); + } else { + return context.popNamedParseExpressions().get(); + } + }) + .collect(Collectors.toList()); + Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); + return context.getNamedParseExpressions().push(function); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 6f1129b04..92e9dd458 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -112,6 +112,15 @@ public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprCont return new Compare(ctx.comparisonOperator().getText(), visit(ctx.left), visit(ctx.right)); } + /** + * Value Expression. 
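+   * Maps the operator text directly onto a generic Function node, e.g. `a + b`
+   * becomes Function("+", [a, b]); BuiltinFunctionTranslator later resolves the
+   * function when building the Catalyst plan.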
+ */ + @Override + public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) { + return new Function( + ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); + } + @Override public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) { return visit(ctx.valueExpression()); // Discard parenthesis around diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java new file mode 100644 index 000000000..0d57fea20 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +public interface BuiltinFunctionTranslator { + + static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List args) { + if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { + // TODO change it when UDF is supported + // TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions + throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); + } else { + String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT); + return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); + } + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index 0c7269a07..4345b0897 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -6,10 +6,15 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.types.BooleanType$; import org.apache.spark.sql.types.ByteType$; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.ShortType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; import org.opensearch.sql.ast.expression.SpanUnit; @@ -33,8 +38,8 @@ * translate the PPL ast expressions data-types into catalyst data-types */ public interface DataTypeTransformer { - static Seq seq(T element) { - return seq(List.of(element)); + static Seq seq(T... 
elements) { + return seq(List.of(elements)); } static Seq seq(List list) { return asScalaBufferConverter(list).asScala().seq(); @@ -46,6 +51,16 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return DateType$.MODULE$; case INTEGER: return IntegerType$.MODULE$; + case LONG: + return LongType$.MODULE$; + case DOUBLE: + return DoubleType$.MODULE$; + case FLOAT: + return FloatType$.MODULE$; + case BOOLEAN: + return BooleanType$.MODULE$; + case SHORT: + return ShortType$.MODULE$; case BYTE: return ByteType$.MODULE$; default: diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala index a36b34ef4..0c116a728 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala @@ -52,5 +52,4 @@ trait LogicalPlanTestUtils { // Return the string representation of the transformed plan transformedPlan.toString } - } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..24336b098 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanMathFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test abs") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = abs(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("abs", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ceil") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ceil(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ceil", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = 
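+    // no `fields` clause in the query, so the translator projects everything (UnresolvedStar)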
Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test floor") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = floor(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("floor", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ln") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ln(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ln", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test mod") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = mod(10, 4)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("mod", seq(Literal(10), Literal(4)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test pow") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = pow(2, 3)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("pow", seq(Literal(2), Literal(3)), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test sqrt") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = sqrt(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("sqrt", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test arithmetic: + - * / %") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | where sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 - sqrt(pow(a, 2)) * 1 = sqrt(pow(a, 2)) % 1", + false), + context) + val table = UnresolvedRelation(Seq("t")) + // sqrt(pow(a, 2)) + val sqrtPow = + UnresolvedFunction( + "sqrt", + seq( + UnresolvedFunction( + "pow", + seq(UnresolvedAttribute("a"), Literal(2)), + isDistinct = false)), + isDistinct = false) + // sqrt(pow(a, 2)) / 1 + val sqrtPowDivide = UnresolvedFunction("divide", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) * 1 + val sqrtPowMultiply = + UnresolvedFunction("multiply", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) % 1 + val 
sqrtPowMod = UnresolvedFunction("modulus", seq(sqrtPow, Literal(1)), isDistinct = false) + // sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 + val add = UnresolvedFunction("add", seq(sqrtPow, sqrtPowDivide), isDistinct = false) + val sub = UnresolvedFunction("subtract", seq(add, sqrtPowMultiply), isDistinct = false) + val filterExpr = EqualTo(sub, sqrtPowMod) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..36a31862b --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala @@ -0,0 +1,224 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Like, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanStringFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test unknown function") { + val context = new CatalystPlanContext + intercept[SyntaxCheckException] { + planTransformer.visit(plan(pplParser, "source=t a = unknown(b)", false), context) + } + } + + test("test concat") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a = CONCAT('hello', 'world')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("concat", seq(Literal("hello"), Literal("world")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test concat with field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = CONCAT('hello', b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "concat", + seq(Literal("hello"), UnresolvedAttribute("b")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test length") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = LENGTH(b)", 
false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("length", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test lower") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = LOWER(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("lower", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test upper - case insensitive") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = uPPer(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("upper", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test trim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = trim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("trim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test ltrim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ltrim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("ltrim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test rtrim") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = rtrim(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("rtrim", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test substring") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = substring(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("substring", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, 
filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test like") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a=like(b, 'Hatti_')", false), context) + + val table = UnresolvedRelation(Seq("t")) + val likeExpr = new Like(UnresolvedAttribute("a"), Literal("Hatti_")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "like", + seq(UnresolvedAttribute("b"), Literal("Hatti_")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test position") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a=position('world' IN 'helloworld')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "position", + seq(Literal("world"), Literal("helloworld")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test replace") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t a=replace('hello', 'l', 'x')", false), + context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction( + "replace", + seq(Literal("hello"), Literal("l"), Literal("x")), + isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala new file mode 100644 index 000000000..7cfcc33d5 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.EqualTo +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} + +class PPLLogicalPlanTimeFunctionsTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test from_unixtime") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = from_unixtime(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + 
UnresolvedFunction("from_unixtime", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } + + test("test unix_timestamp") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t a = unix_timestamp(b)", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("unix_timestamp", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + assertEquals(expectedPlan, logPlan) + } +} From 24d3b81356a5dbbbfb3a22e43375c8dbc0aaf428 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 1 Aug 2024 02:27:02 +0800 Subject: [PATCH 6/8] Translate Eval Command (#499) * Support Eval Command Signed-off-by: Lantao Jin * add more ITs and documentation Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- .../spark/ppl/FlintSparkPPLEvalITSuite.scala | 528 ++++++++++++++++++ ppl-spark-integration/README.md | 18 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 1 + .../sql/ast/expression/Argument.java | 7 +- .../opensearch/sql/ast/expression/Field.java | 10 +- .../opensearch/sql/ast/tree/Correlation.java | 11 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 19 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 2 + .../sql/ppl/parser/AstExpressionBuilder.java | 2 +- .../sql/common/utils/StringUtilsTest.java | 5 + ...PLLogicalPlanEvalTranslatorTestSuite.scala | 217 +++++++ 11 files changed, 803 insertions(+), 17 deletions(-) create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala new file mode 100644 index 000000000..407c2cb3b --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala @@ -0,0 +1,528 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLEvalITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + 
job.awaitTermination() + } + } + + test("test single eval expression with new field") { + val frame = sql(s""" + | source = $testTable | eval col = 1 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val evalProjectList = Seq(UnresolvedStar(None), Alias(Literal(1), "col")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eval expressions with new fields") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1, col2 = 2 | fields name, age + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val evalProjectList = + Seq(UnresolvedStar(None), Alias(Literal(1), "col1")(), Alias(Literal(2), "col2")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eval expressions in fields command") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1, col2 = 2 | fields name, age, col1, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, 1, 2), + Row("Hello", 30, 1, 2), + Row("John", 25, 1, 2), + Row("Jane", 20, 1, 2)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("col1"), + UnresolvedAttribute("col2")) + val evalProjectList = + Seq(UnresolvedStar(None), Alias(Literal(1), "col1")(), Alias(Literal(2), "col2")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eval expression without fields command") { + val frame = sql(s""" + | source = $testTable | eval col1 = "New Field1", col2 = "New Field2" + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", 70, 
"California", "USA", 2023, 4, "New Field1", "New Field2"), + Row("Hello", 30, "New York", "USA", 2023, 4, "New Field1", "New Field2"), + Row("John", 25, "Ontario", "Canada", 2023, 4, "New Field1", "New Field2"), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "New Field1", "New Field2")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val projectList = Seq( + UnresolvedStar(None), + Alias(Literal("New Field1"), "col1")(), + Alias(Literal("New Field2"), "col2")()) + val expectedPlan = Project(seq(UnresolvedStar(None)), Project(projectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test reusing existing fields in eval expressions") { + val frame = sql(s""" + | source = $testTable | eval col1 = state, col2 = country | fields name, age, col1, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA"), + Row("Hello", 30, "New York", "USA"), + Row("John", 25, "Ontario", "Canada"), + Row("Jane", 20, "Quebec", "Canada")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("col1"), + UnresolvedAttribute("col2")) + val evalProjectList = Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("state"), "col1")(), + Alias(UnresolvedAttribute("country"), "col2")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test( + "test overriding existing fields: throw exception when specify the new field in fields command") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval age = 40 | eval name = upper(name) | sort name | fields name, age, state + | """.stripMargin)) + assert(ex.getMessage().contains("Reference 'name' is ambiguous")) + } + + test("test overriding existing fields: throw exception when specify the new field in where") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval age = abs(age) | where age < 50 + | """.stripMargin)) + assert(ex.getMessage().contains("Reference 'age' is ambiguous")) + } + + test( + "test overriding existing fields: throw exception when specify the new field in aggregate expression") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval age = abs(age) | stats avg(age) + | """.stripMargin)) + assert(ex.getMessage().contains("Reference 'age' is ambiguous")) + } + + test( + "test overriding existing fields: throw exception when specify the new field in grouping list") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | eval country = upper(country) | stats avg(age) by country + | """.stripMargin)) + assert(ex.getMessage().contains("Reference 'country' is ambiguous")) + } + + test("test override existing fields: the eval field doesn't appear in fields command") { + val frame = sql(s""" + | 
source = $testTable | eval age = 40, name = upper(name) | sort name | fields state, country + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("New York", "USA"), + Row("California", "USA"), + Row("Quebec", "Canada"), + Row("Ontario", "Canada")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val projectList = Seq( + UnresolvedStar(None), + Alias(Literal(40), "age")(), + Alias( + UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), + "name")()) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), + global = true, + Project(projectList, table)) + val expectedPlan = + Project(seq(UnresolvedAttribute("state"), UnresolvedAttribute("country")), sortedPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test override existing fields: the new fields not appear in fields command") { + val frame = sql(s""" + | source = $testTable | eval age = 40 | eval name = upper(name) | sort name + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + // In Spark, `name` in eval (as an alias) will be treated as a new column (exprIds are different). + // So if `name` appears in fields command, it will throw ambiguous reference exception. + val expectedResults: Array[Row] = Array( + Row("Hello", 30, "New York", "USA", 2023, 4, 40, "HELLO"), + Row("Jake", 70, "California", "USA", 2023, 4, 40, "JAKE"), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 40, "JANE"), + Row("John", 25, "Ontario", "Canada", 2023, 4, 40, "JOHN")) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val evalProjectList1 = Seq(UnresolvedStar(None), Alias(Literal(40), "age")()) + val evalProjectList2 = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), + "name")()) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), + global = true, + Project(evalProjectList2, Project(evalProjectList1, table))) + val expectedPlan = Project(seq(UnresolvedStar(None)), sortedPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eval commands in fields list") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1 | eval col2 = 2 | fields name, age, col1, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, 1, 2), + Row("Hello", 30, 1, 2), + Row("John", 25, 1, 2), + Row("Jane", 20, 1, 2)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("col1"), + UnresolvedAttribute("col2")) + val 
evalProjectList1 = Seq(UnresolvedStar(None), Alias(Literal(1), "col1")()) + val evalProjectList2 = Seq(UnresolvedStar(None), Alias(Literal(2), "col2")()) + val expectedPlan = + Project(fieldsProjectList, Project(evalProjectList2, Project(evalProjectList1, table))) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eval commands without fields command") { + val frame = sql(s""" + | source = $testTable | eval col1 = ln(age) | eval col2 = unix_timestamp('2020-09-16 17:30:00') | sort - col1 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4, 4.248495242049359, 1600302600), + Row("Hello", 30, "New York", "USA", 2023, 4, 3.4011973816621555, 1600302600), + Row("John", 25, "Ontario", "Canada", 2023, 4, 3.2188758248682006, 1600302600), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 2.995732273553991, 1600302600)) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val evalProjectList1 = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("ln", seq(UnresolvedAttribute("age")), isDistinct = false), + "col1")()) + val evalProjectList2 = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "unix_timestamp", + seq(Literal("2020-09-16 17:30:00")), + isDistinct = false), + "col2")()) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("col1"), Descending)), + global = true, + Project(evalProjectList2, Project(evalProjectList1, table))) + val expectedPlan = Project(seq(UnresolvedStar(None)), sortedPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test complex eval commands - case 1") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1 | sort col1 | head 4 | eval col2 = 2 | sort - col2 | sort age | head 2 | fields name, age, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array(Row("Jane", 20, 2), Row("John", 25, 2)) + assert(results.sameElements(expectedResults)) + } + + test("test complex eval commands - case 2") { + val frame = sql(s""" + | source = $testTable | eval col1 = age | sort - col1 | head 3 | eval col2 = age | sort + col2 | head 2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30)) + assert(results.sameElements(expectedResults)) + } + + test("test complex eval commands - case 3") { + val frame = sql(s""" + | source = $testTable | eval col1 = age | sort - col1 | head 3 | eval col2 = age | fields name, age | sort + col2 | head 2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Hello", 30)) + assert(results.sameElements(expectedResults)) + } + + test("test complex eval commands - case 4: execute 1, 2 and 3 together") { + val frame1 = sql(s""" + | source = $testTable | eval col1 = 1 | sort col1 | head 4 | eval col2 = 2 | sort - col2 | sort age | head 2 | fields name, age, col2 + | """.stripMargin) + val results1: Array[Row] = frame1.collect() + // results1.foreach(println(_)) + val 
expectedResults1: Array[Row] = Array(Row("Jane", 20, 2), Row("John", 25, 2)) + assert(results1.sameElements(expectedResults1)) + + val frame2 = sql(s""" + | source = $testTable | eval col1 = age | sort - col1 | head 3 | eval col2 = age | sort + col2 | head 2 + | """.stripMargin) + val results2: Array[Row] = frame2.collect() + // results2.foreach(println(_)) + val expectedResults2: Array[Row] = Array( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30)) + assert(results2.sameElements(expectedResults2)) + + val frame3 = sql(s""" + | source = $testTable | eval col1 = age | sort - col1 | head 3 | eval col2 = age | fields name, age | sort + col2 | head 2 + | """.stripMargin) + val results3: Array[Row] = frame3.collect() + // results3.foreach(println(_)) + val expectedResults3: Array[Row] = Array(Row("John", 25), Row("Hello", 30)) + assert(results3.sameElements(expectedResults3)) + } + + test("test eval expression used in aggregation") { + val frame = sql(s""" + | source = $testTable | eval col1 = age, col2 = country | stats avg(col1) by col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val evalProjectList = Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("age"), "col1")(), + Alias(UnresolvedAttribute("country"), "col2")()) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregateExpressions = + Seq( + Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("col1")), isDistinct = false), + "avg(col1)")(), + Alias(UnresolvedAttribute("col2"), "col2")()) + val aggregatePlan = Aggregate( + Seq(Alias(UnresolvedAttribute("col2"), "col2")()), + aggregateExpressions, + Project(evalProjectList, table)) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test complex eval expressions with fields command") { + val frame = sql(s""" + | source = $testTable | eval new_name = upper(name) | eval compound_field = concat('Hello ', if(like(new_name, 'HEL%'), 'World', name)) | fields new_name, compound_field + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("JAKE", "Hello Jake"), + Row("HELLO", "Hello World"), + Row("JOHN", "Hello John"), + Row("JANE", "Hello Jane")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test complex eval expressions without fields command") { + val frame = sql(s""" + | source = $testTable | eval col1 = "New Field" | eval col2 = upper(lower(col1)) + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4, "New Field", "NEW FIELD"), + Row("Hello", 30, "New York", "USA", 2023, 4, "New Field", "NEW FIELD"), + Row("John", 25, "Ontario", "Canada", 2023, 4, "New Field", "NEW FIELD"), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "New Field", "NEW FIELD")) + implicit val rowOrdering: 
Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test dependent eval expressions in individual eval command") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1 | eval col2 = col1 | fields name, age, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = + Array(Row("Jake", 70, 1), Row("Hello", 30, 1), Row("John", 25, 1), Row("Jane", 20, 1)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"), UnresolvedAttribute("col2")) + val evalProjectList1 = Seq(UnresolvedStar(None), Alias(Literal(1), "col1")()) + val evalProjectList2 = Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("col1"), "col2")()) + val expectedPlan = + Project(fieldsProjectList, Project(evalProjectList2, Project(evalProjectList1, table))) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + // +--------------------------------+ + // | Below tests are not supported | + // +--------------------------------+ + // Todo: Upgrading spark version to 3.4.0 and above could fix this test. + // https://issues.apache.org/jira/browse/SPARK-27561 + ignore("test lateral eval expressions references - SPARK-27561 required") { + val frame = sql(s""" + | source = $testTable | eval col1 = 1, col2 = col1 | fields name, age, col2 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = + Array(Row("Jake", 70, 1), Row("Hello", 30, 1), Row("John", 25, 1), Row("Jane", 20, 1)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age"), UnresolvedAttribute("col2")) + val evalProjectList = Seq( + UnresolvedStar(None), + Alias(Literal(1), "col1")(), + Alias(UnresolvedAttribute("col1"), "col2")()) + val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + // Todo excluded fields not supported yet + ignore("test single eval expression with excluded fields") { + val frame = sql(s""" + | source = $testTable | eval new_field = "New Field" | fields - age + | """.stripMargin) + + val results: Array[Row] = frame.collect() + // results.foreach(println(_)) + val expectedResults: Array[Row] = Array( + Row("Jake", "California", "USA", 2023, 4, "New Field"), + Row("Hello", "New York", "USA", 2023, 4, "New Field"), + Row("John", "Ontario", "Canada", 2023, 4, "New Field"), + Row("Jane", "Quebec", "Canada", 2023, 4, "New Field")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +}
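The two ignored tests above hinge on features outside this patch. The lateral case needs SPARK-27561 (lateral column alias resolution, shipped in Spark 3.4): within a single projection, a later expression may reference an alias defined earlier in the same eval. A minimal Scala illustration of the underlying Spark behavior (not part of the patch; assumes a Spark 3.4+ session named `spark`):

    // Works on Spark 3.4+; older versions cannot resolve `col1` inside the same SELECT,
    // which is why `eval col1 = 1, col2 = col1` must be split into two eval commands today.
    spark.sql("SELECT 1 AS col1, col1 AS col2 FROM VALUES (0) AS t(x)").show()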
diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 61ef5b670..1538f43be 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -240,6 +240,23 @@ The next samples of PPL queries are currently supported: - `source = table | where c != 'test' OR a > 1 | fields a,b,c | head 1` - `source = table | where c = 'test' NOT a > 1 | fields a,b,c` + +**Eval** + +Assumptions: `a`, `b`, `c` are existing fields in `table` + - `source = table | eval f = 1 | fields a,b,c,f` + - `source = table | eval f = 1` (outputs the fields a,b,c,f) + - `source = table | eval n = now() | eval t = unix_timestamp(a) | fields n,t` + - `source = table | eval f = a | where f > 1 | sort f | fields a,b,c | head 5` + - `source = table | eval f = a * 2 | eval h = f * 2 | fields a,f,h` + - `source = table | eval f = a * 2, h = f * 2 | fields a,f,h` (Spark 3.4.0+ required) + - `source = table | eval f = a * 2, h = b | stats avg(f) by h` + +Limitation: overriding an existing field is unsupported; the following queries throw exceptions with "Reference 'a' is ambiguous" (see the sketch after this diff) + - `source = table | eval a = 10 | fields a,b,c` + - `source = table | eval a = a * 2 | stats avg(a)` + - `source = table | eval a = abs(a) | where a > 0` + **Aggregations** - `source = table | stats avg(a) ` - `source = table | where a < 50 | stats avg(c) ` @@ -261,6 +278,7 @@ The next samples of PPL queries are currently supported: - `search` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/search.rst) - `where` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/where.rst) - `fields` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/fields.rst) + - `eval` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/eval.rst) - `head` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/head.rst) - `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) (supports AVG, COUNT, DISTINCT_COUNT, MAX, MIN and SUM aggregation functions) - `sort` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index aac3c3f36..2d0986890 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -36,6 +36,7 @@ commands | statsCommand | sortCommand | headCommand + | evalCommand ; searchCommand
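Why the override limitation above surfaces as an ambiguous reference: eval is translated into a Project that keeps every input column (via UnresolvedStar) and appends the new alias on top, so overriding `a` leaves two attributes named `a` in scope. A minimal Spark sketch of the same situation (not part of the patch; assumes a local SparkSession named `spark`):

    import org.apache.spark.sql.functions.col
    import spark.implicits._

    val df = Seq((1, "x")).toDF("a", "b")
    // Equivalent of `eval a = a * 2`: keep all columns (star) and append another `a`
    val evaled = df.select(col("*"), (col("a") * 2).as("a"))
    evaled.select("a") // throws AnalysisException: Reference 'a' is ambiguous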
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java index 3f51b595e..35ded8d8b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java @@ -13,8 +13,7 @@ /** Argument. */ public class Argument extends UnresolvedExpression { private final String name; - private String argName; - private Literal value; + private final Literal value; public Argument(String name, Literal value) { this.name = name; @@ -27,8 +26,8 @@ public List<UnresolvedExpression> getChild() { return Arrays.asList(value); } - public String getArgName() { - return argName; + public String getName() { + return name; } public Literal getValue() { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java index 39b42dfe4..a8ec28d0e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java @@ -8,25 +8,25 @@ import com.google.common.collect.ImmutableList; import org.opensearch.sql.ast.AbstractNodeVisitor; -import java.util.ArrayList; import java.util.Collections; import java.util.List; + public class Field extends UnresolvedExpression { - private final UnresolvedExpression field; + private final QualifiedName field; private final List<Argument> fieldArgs; /** Constructor of Field. */ - public Field(UnresolvedExpression field) { + public Field(QualifiedName field) { this(field, Collections.emptyList()); } /** Constructor of Field. */ - public Field(UnresolvedExpression field, List<Argument> fieldArgs) { + public Field(QualifiedName field, List<Argument> fieldArgs) { this.field = field; this.fieldArgs = fieldArgs; } - public UnresolvedExpression getField() { + public QualifiedName getField() { return field; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java index 6cc2b66ff..0a49bbb6c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java @@ -4,21 +4,20 @@ import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.FieldsMapping; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.Scope; -import org.opensearch.sql.ast.expression.SpanUnit; -import org.opensearch.sql.ast.expression.UnresolvedExpression; import java.util.List; /** Logical plan node of correlation, the interface for building the searching sources. 
*/ public class Correlation extends UnresolvedPlan { - private final CorrelationType correlationType; - private final List<UnresolvedExpression> fieldsList; + private final CorrelationType correlationType; + private final List<QualifiedName> fieldsList; private final Scope scope; private final FieldsMapping mappingListContext; private UnresolvedPlan child ; - public Correlation(String correlationType, List<UnresolvedExpression> fieldsList, Scope scope, FieldsMapping mappingListContext) { + public Correlation(String correlationType, List<QualifiedName> fieldsList, Scope scope, FieldsMapping mappingListContext) { this.correlationType = CorrelationType.valueOf(correlationType); this.fieldsList = fieldsList; this.scope = scope; @@ -45,7 +44,7 @@ public CorrelationType getCorrelationType() { return correlationType; } - public List<UnresolvedExpression> getFieldsList() { + public List<QualifiedName> getFieldsList() { return fieldsList; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 04f4320c1..fd8d81e5c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -31,6 +31,7 @@ import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; @@ -60,6 +61,7 @@ import scala.Option; import scala.collection.Seq; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -224,7 +226,22 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan @Override public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Eval"); + LogicalPlan child = node.getChild().get(0).accept(this, context); + List<UnresolvedExpression> aliases = new ArrayList<>(); + List<Let> letExpressions = node.getExpressionList(); + for (Let let : letExpressions) { + Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); + aliases.add(alias); + } + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<UnresolvedAttribute>>empty())); + } + List<Expression> expressionList = visitExpressionList(aliases, context); + Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + return child; } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ce9eea769..9973f4676 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -110,7 +110,9 @@ public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) { return new Correlation(ctx.correlationType().getText(), ctx.fieldList().fieldExpression().stream() 
+ .map(OpenSearchPPLParser.FieldExpressionContext::qualifiedName) .map(this::internalVisitExpression) + .map(u -> (QualifiedName) u) .collect(Collectors.toList()), Objects.isNull(ctx.scopeClause()) ? null : new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()), expressionBuilder.visit(ctx.scopeClause().value), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 92e9dd458..71abb329f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -141,7 +141,7 @@ public UnresolvedExpression visitWcFieldExpression(OpenSearchPPLParser.WcFieldEx @Override public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext ctx) { - return new Field( + return new Field((QualifiedName) visit(ctx.sortFieldExpression().fieldExpression().qualifiedName()), ArgumentFactory.getArgumentList(ctx)); } diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/common/utils/StringUtilsTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/common/utils/StringUtilsTest.java index 4a942d067..5a20de2d4 100644 --- a/ppl-spark-integration/src/test/java/org/opensearch/sql/common/utils/StringUtilsTest.java +++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/common/utils/StringUtilsTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.common.utils; import static org.junit.Assert.assertEquals; diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala new file mode 100644 index 000000000..772eb050a --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala @@ -0,0 +1,217 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} + +class PPLLogicalPlanEvalTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test eval expressions not included in fields expressions") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, b = 1 | fields c", false), + context) + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val 
expectedPlan = Project( + seq(UnresolvedAttribute("c")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eval expressions included in fields expression") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, c = 1 | fields a, b, c", false), + context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "c")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eval expressions without fields command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=t | eval a = 1, b = 1", false), context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val expectedPlan = + Project(seq(UnresolvedStar(None)), Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eval expressions with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, b = 1 | sort - a | fields b", false), + context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t"))) + val sortOrder = SortOrder(UnresolvedAttribute("a"), Descending, Seq.empty) + val sort = Sort(seq(sortOrder), global = true, evalProject) + val expectedPlan = Project(seq(UnresolvedAttribute("b")), sort) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eval expressions with multiple recursive sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, a = a | sort - a | fields b", false), + context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(UnresolvedAttribute("a"), "a")()) + val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t"))) + val sortOrder = SortOrder(UnresolvedAttribute("a"), Descending, Seq.empty) + val sort = Sort(seq(sortOrder), global = true, evalProject) + val expectedPlan = Project(seq(UnresolvedAttribute("b")), sort) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test multiple eval expressions") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = 1, b = 'hello' | eval b = a | sort - b | fields b", + false), + context) + + val evalProjectList1: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal("hello"), "b")()) + val evalProjectList2: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("a"), "b")(exprId = ExprId(2), qualifier = Seq.empty)) + val evalProject1 = Project(evalProjectList1, UnresolvedRelation(Seq("t"))) + val evalProject2 = Project(evalProjectList2, evalProject1) + val sortOrder = SortOrder(UnresolvedAttribute("b"), Descending, Seq.empty) + val sort = Sort(seq(sortOrder), global = true, evalProject2) + val expectedPlan = 
Project(seq(UnresolvedAttribute("b")), sort) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex eval expressions - date function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = TIMESTAMP('2020-09-16 17:30:00') | fields a", false), + context) + + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction("timestamp", seq(Literal("2020-09-16 17:30:00")), isDistinct = false), + "a")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("a")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex eval expressions - math function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = RAND() | fields a", false), + context) + + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "a")( + exprId = ExprId(0), + qualifier = Seq.empty)) + val expectedPlan = Project( + seq(UnresolvedAttribute("a")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test complex eval expressions - compound function") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = if(like(b, '%Hello%'), 'World', 'Hi') | fields a", + false), + context) + + val evalProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "if", + seq( + UnresolvedFunction( + "like", + seq(UnresolvedAttribute("b"), Literal("%Hello%")), + isDistinct = false), + Literal("World"), + Literal("Hi")), + isDistinct = false), + "a")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("a")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + // Todo fields-excluded command not supported + ignore("test eval expressions with fields-excluded command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, b = 2 | fields - b", false), + context) + + val projectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(2), "b")()) + val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t"))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + // Todo fields-included command not supported + ignore("test eval expressions with fields-included command") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = 1, b = 2 | fields + b", false), + context) + + val projectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(2), "b")()) + val expectedPlan = Project(projectList, UnresolvedRelation(Seq("t"))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} From 06ff8be989805267472eb0228d74f02fe2322b11 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Thu, 1 Aug 2024 08:49:23 -0700 Subject: [PATCH 7/8] Fix SigV4 signature when connecting to OpenSearchServerless (#473) --------- Signed-off-by: Tomoyuki Morita --- .../opensearch/flint/core/FlintOptions.java | 18 ++- .../AWSRequestSigningApacheInterceptor.java | 22 ++- 
...sedAWSRequestSigningApacheInterceptor.java | 5 +- .../core/storage/OpenSearchClientUtils.java | 4 +- ...WSRequestSigningApacheInterceptorTest.java | 126 ++++++++++++++++++ .../sql/flint/config/FlintSparkConf.scala | 7 + 6 files changed, 166 insertions(+), 16 deletions(-) create mode 100644 flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index c49247f37..9be01737c 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -33,16 +33,20 @@ public class FlintOptions implements Serializable { public static final String SCHEME = "scheme"; - public static final String AUTH = "auth"; + /** + * Service name used for SigV4 signature. + * `es`: Amazon OpenSearch Service + * `aoss`: Amazon OpenSearch Serverless + */ + public static final String SERVICE_NAME = "auth.servicename"; + public static final String SERVICE_NAME_ES = "es"; + public static final String SERVICE_NAME_AOSS = "aoss"; + public static final String AUTH = "auth"; public static final String NONE_AUTH = "noauth"; - public static final String SIGV4_AUTH = "sigv4"; - public static final String BASIC_AUTH = "basic"; - public static final String USERNAME = "auth.username"; - public static final String PASSWORD = "auth.password"; public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider"; @@ -131,6 +135,10 @@ public String getAuth() { return options.getOrDefault(AUTH, NONE_AUTH); } + public String getServiceName() { + return options.getOrDefault(SERVICE_NAME, SERVICE_NAME_ES); + } + public String getCustomAwsCredentialsProvider() { return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java index a3925999e..172ac5ceb 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java @@ -5,7 +5,9 @@ package org.opensearch.flint.core.auth; +import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_CONTENT_SHA256; import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; +import static org.opensearch.flint.core.FlintOptions.SERVICE_NAME_AOSS; import com.amazonaws.DefaultRequest; import com.amazonaws.auth.AWSCredentialsProvider; @@ -31,6 +33,7 @@ import org.apache.http.entity.BasicHttpEntity; import org.apache.http.message.BasicHeader; import org.apache.http.protocol.HttpContext; +import org.opensearch.flint.core.storage.OpenSearchClientUtils; /** * From https://github.com/opensearch-project/sql-jdbc/blob/main/src/main/java/org/opensearch/jdbc/transport/http/auth/aws/AWSRequestSigningApacheInterceptor.java @@ -74,13 +77,6 @@ public AWSRequestSigningApacheInterceptor(final String service, @Override public void process(final HttpRequest request, final HttpContext context) throws HttpException, IOException { - URIBuilder uriBuilder; - try { - uriBuilder = new URIBuilder(request.getRequestLine().getUri()); - } catch (URISyntaxException e) { - throw new IOException("Invalid URI" , e); - } - // Copy 
Apache HttpRequest to AWS DefaultRequest DefaultRequest<?> signableRequest = new DefaultRequest<>(service); @@ -91,7 +87,10 @@ public void process(final HttpRequest request, final HttpContext context) final HttpMethodName httpMethod = HttpMethodName.fromValue(request.getRequestLine().getMethod()); signableRequest.setHttpMethod(httpMethod); + + URIBuilder uriBuilder; try { + uriBuilder = new URIBuilder(request.getRequestLine().getUri()); signableRequest.setResourcePath(uriBuilder.build().getRawPath()); } catch (URISyntaxException e) { throw new IOException("Invalid URI" , e); @@ -110,6 +109,10 @@ public void process(final HttpRequest request, final HttpContext context) signableRequest.setParameters(nvpToMapParams(uriBuilder.getQueryParams())); signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders())); + if (SERVICE_NAME_AOSS.equals(service)) { + enableContentBodySignature(signableRequest); + } + // Sign it signer.sign(signableRequest, awsCredentialsProvider.getCredentials()); @@ -126,6 +129,11 @@ public void process(final HttpRequest request, final HttpContext context) } } + private void enableContentBodySignature(DefaultRequest<?> signableRequest) { + // AWS4Signer will add `x-amz-content-sha256` header when this header is set + signableRequest.addHeader(X_AMZ_CONTENT_SHA256, "required"); + } + /** * * @param params list of HTTP query params as NameValuePairs diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java index b69343730..05b83d658 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java @@ -5,6 +5,8 @@ package org.opensearch.flint.core.auth; +import static org.opensearch.flint.core.FlintOptions.SERVICE_NAME_ES; + import com.amazonaws.auth.AWS4Signer; import com.amazonaws.auth.AWSCredentialsProvider; import org.apache.http.HttpException; @@ -14,6 +16,7 @@ import org.apache.http.protocol.HttpContext; import org.jetbrains.annotations.TestOnly; import org.opensearch.common.Strings; +import org.opensearch.flint.core.storage.OpenSearchClientUtils; import software.amazon.awssdk.authcrt.signer.AwsCrtV4aSigner; import java.io.IOException; @@ -84,7 +87,7 @@ public ResourceBasedAWSRequestSigningApacheInterceptor(final String service, @Override public void process(HttpRequest request, HttpContext context) throws HttpException, IOException { String resourcePath = parseUriToPath(request); - if ("es".equals(this.service) && isMetadataAccess(resourcePath)) { + if (SERVICE_NAME_ES.equals(this.service) && isMetadataAccess(resourcePath)) { metadataAccessInterceptor.process(request, context); } else { primaryInterceptor.process(request, context); diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java index 9277a17df..21241d7ab 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java @@ -30,8 +30,6 @@ */ public class OpenSearchClientUtils { - private static final String SERVICE_NAME = "es"; - /** * Metadata log index name prefix */ @@ -90,7 +88,7 @@ private static 
RestClientBuilder configureSigV4Auth(RestClientBuilder restClient restClientBuilder.setHttpClientConfigCallback(builder -> { HttpAsyncClientBuilder delegate = builder.addInterceptorLast( new ResourceBasedAWSRequestSigningApacheInterceptor( - SERVICE_NAME, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName)); + options.getServiceName(), options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName)); return RetryableHttpAsyncClient.builder(delegate, options); } );
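For reference, the `aoss` branch above leans on a documented AWS SDK v1 behavior: when a request already carries the `x-amz-content-sha256` header with the sentinel value "required", AWS4Signer replaces it with the SHA-256 hash of the request payload, which OpenSearch Serverless expects on every signed request. A minimal sketch of that contract in isolation, not part of this patch (endpoint, region, and credentials are illustrative):

    import java.net.URI
    import com.amazonaws.DefaultRequest
    import com.amazonaws.auth.{AWS4Signer, BasicAWSCredentials}
    import com.amazonaws.http.HttpMethodName

    val signer = new AWS4Signer()
    signer.setServiceName("aoss")    // SigV4 service name for OpenSearch Serverless
    signer.setRegionName("us-west-2") // illustrative region

    val request = new DefaultRequest[Void]("aoss")
    request.setEndpoint(URI.create("https://example.us-west-2.aoss.amazonaws.com")) // illustrative endpoint
    request.setHttpMethod(HttpMethodName.GET)
    request.setResourcePath("/")
    // The sentinel value asks AWS4Signer to compute and attach the payload hash.
    request.addHeader("x-amz-content-sha256", "required")

    signer.sign(request, new BasicAWSCredentials("accessKey", "secretKey")) // dummy credentials

On the Spark side the same switch is surfaced through the new `auth.servicename` datasource option (see the FlintSparkConf change below), e.g. setting `spark.datasource.flint.auth.servicename` to `aoss` together with `spark.datasource.flint.auth` set to `sigv4`.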
diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java new file mode 100644 index 000000000..ae8fdfa9a --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java @@ -0,0 +1,126 @@ +package org.opensearch.flint.core.auth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.awssdk.auth.signer.internal.SignerConstant.X_AMZ_CONTENT_SHA256; + +import com.amazonaws.DefaultRequest; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.Signer; +import com.amazonaws.http.HttpMethodName; +import com.amazonaws.util.IOUtils; +import java.io.IOException; +import java.net.URI; +import org.apache.http.HttpHost; +import org.apache.http.entity.BasicHttpEntity; +import org.apache.http.message.BasicHttpEntityEnclosingRequest; +import org.apache.http.protocol.BasicHttpContext; +import org.apache.http.protocol.HttpCoreContext; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.utils.StringInputStream; + +@ExtendWith(MockitoExtension.class) +class AWSRequestSigningApacheInterceptorTest { + + @Mock + AWSCredentialsProvider awsCredentialsProvider; + @Mock Signer signer; + @Mock + AWSCredentials awsCredentials; + + @Captor + ArgumentCaptor<DefaultRequest<?>> signableRequestCaptor; + + @Test + public void testProcessWithServiceIsEs() throws Exception { + AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("es", signer, awsCredentialsProvider); + final BasicHttpEntityEnclosingRequest request = getRequestWithEntity(); + final BasicHttpContext context = getContext(); + when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials); + + awsRequestSigningApacheInterceptor.process(request, context); + + verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials)); + DefaultRequest<?> signableRequest = signableRequestCaptor.getValue(); + assertEquals(new URI("http://hello.world"), signableRequest.getEndpoint()); + assertEquals(HttpMethodName.POST, signableRequest.getHttpMethod()); + assertEquals("/path", signableRequest.getResourcePath()); + assertEquals("ENTITY", IOUtils.toString(signableRequest.getContent())); + assertEquals("HeaderValue", signableRequest.getHeaders().get("Test-Header")); + assertEquals("value0", signableRequest.getParameters().get("param0").get(0)); + } + + @Test + public void testProcessWithoutEntity() throws Exception { + AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("es", signer, awsCredentialsProvider); + final BasicHttpEntityEnclosingRequest request = getRequest(); + final BasicHttpContext context = getContext(); + when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials); + + awsRequestSigningApacheInterceptor.process(request, context); + + verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials)); + DefaultRequest<?> signableRequest = signableRequestCaptor.getValue(); + assertEquals("", IOUtils.toString(signableRequest.getContent())); + } + + @NotNull + private static BasicHttpContext getContext() { + BasicHttpContext context = new BasicHttpContext(); + context.setAttribute(HttpCoreContext.HTTP_TARGET_HOST, new HttpHost("hello.world")); + return context; + } + + @Test + public void testProcessWithServiceIsAoss() throws Exception { + AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("aoss", signer, awsCredentialsProvider); + final BasicHttpEntityEnclosingRequest request = getRequest(); + final BasicHttpContext context = getContext(); + when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials); + + awsRequestSigningApacheInterceptor.process(request, context); + + verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials)); + DefaultRequest<?> signableRequest = signableRequestCaptor.getValue(); + assertEquals("required", signableRequest.getHeaders().get(X_AMZ_CONTENT_SHA256)); + } + + @Test + public void testInvalidURI() throws Exception { + AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("aoss", signer, awsCredentialsProvider); + final BasicHttpEntityEnclosingRequest request = new BasicHttpEntityEnclosingRequest("POST", "::INVALID_URI::"); + final BasicHttpContext context = getContext(); + + assertThrows(IOException.class, () -> { + awsRequestSigningApacheInterceptor.process(request, context); + }); + } + + @NotNull + private static BasicHttpEntityEnclosingRequest getRequestWithEntity() { + BasicHttpEntityEnclosingRequest request = getRequest(); + BasicHttpEntity basicHttpEntity = new BasicHttpEntity(); + basicHttpEntity.setContent(new StringInputStream("ENTITY")); + request.setEntity(basicHttpEntity); + request.setHeader("content-length", "6"); + return request; + } + + @NotNull + private static BasicHttpEntityEnclosingRequest getRequest() { + BasicHttpEntityEnclosingRequest request = new BasicHttpEntityEnclosingRequest("POST", "https://hello.world/path?param0=value0"); + request.setHeader("Test-Header", "HeaderValue"); + request.setHeader("content-length", "0"); + return request; + } +} \ No newline at end of file diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index f2f680281..7ea284959 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -57,6 +57,12 @@ object FlintSparkConf { "noauth(no auth), sigv4(sigv4 auth), basic(basic auth)") .createWithDefault(FlintOptions.NONE_AUTH) + val SERVICE_NAME = 
FlintConfig("spark.datasource.flint.auth.servicename") + .datasourceOption() + .doc("service name used for SigV4 signature. " + + "es (AWS OpenSearch Service), aoss (Amazon OpenSearch Serverless)") + .createWithDefault(FlintOptions.SERVICE_NAME_ES) + val USERNAME = FlintConfig("spark.datasource.flint.auth.username") .datasourceOption() .doc("basic auth username") @@ -267,6 +273,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable RETRYABLE_HTTP_STATUS_CODES, REGION, CUSTOM_AWS_CREDENTIALS_PROVIDER, + SERVICE_NAME, USERNAME, PASSWORD, SOCKET_TIMEOUT_MILLIS, From d3e54d48c578680ccaa38ddc96a74372bf64eb84 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Sat, 3 Aug 2024 03:48:05 +0800 Subject: [PATCH 8/8] Support more PPL builtin functions by adding a name mapping (#504) * Support more builtin functions by adding a name mapping Signed-off-by: Lantao Jin * shorten the map declaration Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- .../flint/spark/FlintSparkSuite.scala | 22 ++ .../FlintSparkPPLBuiltinFunctionITSuite.scala | 290 +++++++++++------- .../ppl/utils/BuiltinFunctionTranslator.java | 57 +++- ...ggregationQueriesTranslatorTestSuite.scala | 31 +- ...lPlanBasicQueriesTranslatorTestSuite.scala | 27 +- ...ogicalPlanFiltersTranslatorTestSuite.scala | 29 +- ...PlanMathFunctionsTranslatorTestSuite.scala | 65 +++- ...anStringFunctionsTranslatorTestSuite.scala | 29 +- ...PlanTimeFunctionsTranslatorTestSuite.scala | 177 ++++++++++- 9 files changed, 540 insertions(+), 187 deletions(-) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index fbb2f89bd..7e0b68376 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -188,6 +188,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | """.stripMargin) } + protected def createNullableStateCountryTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada'), + | (null, 10, null, 'Canada') + | """.stripMargin) + } + protected def createOccupationTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala index 127b29295..c9bf8a926 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBuiltinFunctionITSuite.scala @@ -5,9 +5,11 @@ package org.opensearch.flint.spark.ppl +import java.sql.{Date, Time, Timestamp} + import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{EqualTo, GreaterThan, 
Literal} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -22,12 +24,14 @@ class FlintSparkPPLBuiltinFunctionITSuite /** Test table and index name */ private val testTable = "spark_catalog.default.flint_ppl_test" + private val testNullTable = "spark_catalog.default.flint_ppl_test_null" override def beforeAll(): Unit = { super.beforeAll() // Create test table createPartitionedStateCountryTable(testTable) + createNullableStateCountryTable(testNullTable) } protected override def afterEach(): Unit = { @@ -44,17 +48,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable name=concat('He', 'llo') | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("name"), @@ -62,7 +61,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -71,15 +69,11 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable name=concat('Hello', state) | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array.empty assert(results.sameElements(expectedResults)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("name"), @@ -90,7 +84,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -99,17 +92,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where length(name) = 5 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), @@ -117,7 +105,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = 
Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -126,17 +113,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where leNgTh(name) = 5 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("length", seq(UnresolvedAttribute("name")), isDistinct = false), @@ -144,7 +126,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -153,17 +134,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where lower(name) = "hello" | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("lower", seq(UnresolvedAttribute("name")), isDistinct = false), @@ -171,7 +147,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -180,17 +155,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where upper(name) = upper("hello") | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("upper", seq(UnresolvedAttribute("name")), isDistinct = false), @@ -198,7 +168,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -207,17 +176,12 @@ class 
FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where substring(name, 2, 2) = "el" | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction( @@ -228,7 +192,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -237,17 +200,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable | where like(name, '_ello%') | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val likeFunction = UnresolvedFunction( "like", @@ -257,7 +215,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(likeFunction, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -266,17 +223,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where replace(name, 'o', ' ') = "Hell " | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction( @@ -287,7 +239,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -296,17 +247,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where trim(replace(name, 'o', ' ')) = "Hell" | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results 
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction( @@ -321,7 +267,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -330,17 +275,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where age = abs(-30) | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("age"), @@ -348,7 +288,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -357,17 +296,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where abs(age) = 30 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("abs", seq(UnresolvedAttribute("age")), isDistinct = false), @@ -375,7 +309,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -384,17 +317,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where age = ceil(29.7) | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("age"), @@ -402,7 +330,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -411,17 +338,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where age = floor(30.4) | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedAttribute("age"), @@ -429,7 +351,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -438,17 +359,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where ln(age) > 4 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Jake", 70)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = GreaterThan( UnresolvedFunction("ln", seq(UnresolvedAttribute("age")), isDistinct = false), @@ -456,7 +372,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -465,17 +380,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where mod(age, 10) = 0 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30), Row("Jane", 20)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction("mod", seq(UnresolvedAttribute("age"), Literal(10)), isDistinct = false), @@ -483,7 +393,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val 
filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -492,17 +401,12 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where sqrt(pow(age, 2)) = 30.0 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Hello", 30)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = EqualTo( UnresolvedFunction( @@ -517,7 +421,6 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -526,18 +429,13 @@ class FlintSparkPPLBuiltinFunctionITSuite | source = $testTable |where unix_timestamp(from_unixtime(1700000001)) > 1700000000 | fields name, age | """.stripMargin) - // Retrieve the results val results: Array[Row] = frame.collect() - // Define the expected results val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) val filterExpr = GreaterThan( UnresolvedFunction( @@ -548,7 +446,183 @@ class FlintSparkPPLBuiltinFunctionITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test arithmetic operators (+ - * / %)") { + val frame = sql(s""" + | source = $testTable | where (sqrt(pow(age, 2)) + sqrt(pow(age, 2)) / 1 - sqrt(pow(age, 2)) * 1) % 25.0 = 0 | fields name, age + | """.stripMargin) // equals age + age - age + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("John", 25)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test boolean operators (= != < <= > >=)") { + val frame = sql(s""" + | source = $testTable | eval a = age = 30, b = age != 70, c = 30 < age, d = 30 <= age, e = 30 > age, f = 30 >= age | fields age, a, b, c, d, e, f + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(70, false, false, true, true, false, false), + Row(30, true, true, false, true, false, true), + Row(25, false, true, false, false, true, true), + Row(20, false, true, false, false, 
true, true)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test boolean condition functions - isnull isnotnull ifnull nullif") { + val frameIsNull = sql(s""" + | source = $testNullTable | where isnull(name) | fields age + | """.stripMargin) + + val results1: Array[Row] = frameIsNull.collect() + val expectedResults1: Array[Row] = Array(Row(10)) + assert(results1.sameElements(expectedResults1)) + + val frameIsNotNull = sql(s""" + | source = $testNullTable | where isnotnull(name) | fields name + | """.stripMargin) + + val results2: Array[Row] = frameIsNotNull.collect() + val expectedResults2: Array[Row] = Array(Row("John"), Row("Jane"), Row("Jake"), Row("Hello")) + assert(results2.sameElements(expectedResults2)) + + val frameIfNull = sql(s""" + | source = $testNullTable | eval new_name = ifnull(name, "Unknown") | fields new_name, age + | """.stripMargin) + + val results3: Array[Row] = frameIfNull.collect() + val expectedResults3: Array[Row] = Array( + Row("John", 25), + Row("Jane", 20), + Row("Unknown", 10), + Row("Jake", 70), + Row("Hello", 30)) + assert(results3.sameElements(expectedResults3)) + + val frameNullIf = sql(s""" + | source = $testNullTable | eval new_age = nullif(age, 20) | fields name, new_age + | """.stripMargin) + + val results4: Array[Row] = frameNullIf.collect() + val expectedResults4: Array[Row] = + Array(Row("John", 25), Row("Jane", null), Row(null, 10), Row("Jake", 70), Row("Hello", 30)) + assert(results4.sameElements(expectedResults4)) + } + + test("test typeof function") { + val frame = sql(s""" + | source = $testNullTable | eval tdate = typeof(DATE('2008-04-14')), tint = typeof(1), tnow = typeof(now()), tcol = typeof(age) | fields tdate, tint, tnow, tcol | head 1 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("date", "int", "timestamp", "int")) + assert(results.sameElements(expectedResults)) + } + + test("test the builtin functions which required additional name mapping") { + val frame = sql(s""" + | source = $testNullTable + | | eval a = DAY_OF_WEEK(DATE('2020-08-26')) + | | eval b = DAY_OF_MONTH(DATE('2020-08-26')) + | | eval c = DAY_OF_YEAR(DATE('2020-08-26')) + | | eval d = WEEK_OF_YEAR(DATE('2020-08-26')) + | | eval e = WEEK(DATE('2020-08-26')) + | | eval f = MONTH_OF_YEAR(DATE('2020-08-26')) + | | eval g = HOUR_OF_DAY(DATE('2020-08-26')) + | | eval h = MINUTE_OF_HOUR(DATE('2020-08-26')) + | | eval i = SECOND_OF_MINUTE(DATE('2020-08-26')) + | | eval j = SUBDATE(DATE('2020-08-26'), 1) + | | eval k = ADDDATE(DATE('2020-08-26'), 1) + | | eval l = DATEDIFF(TIMESTAMP('2000-01-02 00:00:00'), TIMESTAMP('2000-01-01 23:59:59')) + | | eval m = DATEDIFF(ADDDATE(LOCALTIME(), 1), LOCALTIME()) + | | fields a, b, c, d, e, f, g, h, i, j, k, l, m + | | head 1 + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = { + Array( + Row( + 4, + 26, + 239, + 35, + 35, + 8, + 0, + 0, + 0, + Date.valueOf("2020-08-25"), + Date.valueOf("2020-08-27"), + 1, + 1)) + } + assert(results.sameElements(expectedResults)) + } + + test("not all arguments could work in builtin functions") { + intercept[AnalysisException](sql(s""" + | source = $testTable | eval a = WEEK(DATE('2008-02-20'), 1) + | """.stripMargin)) + intercept[AnalysisException](sql(s""" + | source = $testTable | eval a = SUBDATE(DATE('2020-08-26'), INTERVAL 31 DAY) + | """.stripMargin)) + 
intercept[AnalysisException](sql(s""" + | source = $testTable | eval a = ADDDATE(DATE('2020-08-26'), INTERVAL 1 HOUR) + | """.stripMargin)) + } + + // Todo + // +---------------------------------------+ + // | Below tests are not supported (cast) | + // +---------------------------------------+ + ignore("test cast to string") { + val frame = sql(s""" + | source = $testNullTable | eval cbool = CAST(true as string), cint = CAST(1 as string), cdate = CAST(CAST('2012-08-07' as date) as string) | fields cbool, cint, cdate + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(true, 1, "2012-08-07")) + assert(results.sameElements(expectedResults)) + } + + ignore("test cast to number") { + val frame = sql(s""" + | source = $testNullTable | eval cbool = CAST(true as int), cstring = CAST('1' as int) | fields cbool, cstring + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1, 1)) + assert(results.sameElements(expectedResults)) + } + + ignore("test cast to date") { + val frame = sql(s""" + | source = $testNullTable | eval cdate = CAST('2012-08-07' as date), ctime = CAST('01:01:01' as time), ctimestamp = CAST('2012-08-07 01:01:01' as timestamp) | fields cdate, ctime, ctimestamp + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row( + Date.valueOf("2012-08-07"), + Time.valueOf("01:01:01"), + Timestamp.valueOf("2012-08-07 01:01:01"))) + assert(results.sameElements(expectedResults)) + } + + ignore("test can be chained") { + val frame = sql(s""" + | source = $testNullTable | eval cbool = CAST(CAST(true as string) as boolean) | fields cbool + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(true)) + assert(results.sameElements(expectedResults)) + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java index 0d57fea20..53c6673a8 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/BuiltinFunctionTranslator.java @@ -5,25 +5,78 @@ package org.opensearch.sql.ppl.utils; +import com.google.common.collect.ImmutableMap; import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.expressions.Expression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import java.util.List; -import java.util.Locale; +import java.util.Map; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MODULUS; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_WEEK; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_MONTH; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK_OF_YEAR; +import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MONTH_OF_YEAR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.HOUR_OF_DAY; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MINUTE_OF_HOUR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SECOND_OF_MINUTE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DATEDIFF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOCALTIME; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; public interface BuiltinFunctionTranslator { + /** + * The name mapping from PPL builtin functions to Spark builtin functions. + */ + static final Map<BuiltinFunctionName, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING + = new ImmutableMap.Builder<BuiltinFunctionName, String>() + // arithmetic operators + .put(ADD, "+") + .put(SUBTRACT, "-") + .put(MULTIPLY, "*") + .put(DIVIDE, "/") + .put(MODULUS, "%") + // time functions + .put(DAY_OF_WEEK, "dayofweek") + .put(DAY_OF_MONTH, "dayofmonth") + .put(DAY_OF_YEAR, "dayofyear") + .put(WEEK_OF_YEAR, "weekofyear") + .put(WEEK, "weekofyear") + .put(MONTH_OF_YEAR, "month") + .put(HOUR_OF_DAY, "hour") + .put(MINUTE_OF_HOUR, "minute") + .put(SECOND_OF_MINUTE, "second") + .put(SUBDATE, "date_sub") // only maps subdate(date, days) + .put(ADDDATE, "date_add") // only maps adddate(date, days) + .put(DATEDIFF, "datediff") + .put(LOCALTIME, "localtimestamp") + // condition functions + .put(IS_NULL, "isnull") + .put(IS_NOT_NULL, "isnotnull") + .build(); + static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) { if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) { // TODO change it when UDF is supported // TODO should we support more functions which are not PPL builtin functions. E.g. Spark builtin functions throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL"); } else { - String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT); + BuiltinFunctionName builtin = BuiltinFunctionName.of(function.getFuncName()).get(); + String name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING + .getOrDefault(builtin, builtin.getName().getFunctionName()); return new UnresolvedFunction(seq(name), seq(args), false, empty(),false); } }
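The effect of the `getOrDefault` lookup above: a PPL builtin whose Spark counterpart has a different name is rewritten, and every other builtin falls through under its own PPL function name. A small sketch of both paths, not part of this patch (it assumes the `BuiltinFunctionName` constants resolve as in the imports above and that `AVG` is absent from the map):

    import org.opensearch.sql.expression.function.BuiltinFunctionName.{AVG, DAY_OF_WEEK}
    import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator.SPARK_BUILTIN_FUNCTION_NAME_MAPPING

    // Mapped: PPL's day_of_week(...) is planned as Spark's dayofweek(...)
    val mapped = SPARK_BUILTIN_FUNCTION_NAME_MAPPING
      .getOrDefault(DAY_OF_WEEK, DAY_OF_WEEK.getName.getFunctionName) // "dayofweek"

    // Unmapped: avg is not in the map, so the PPL name itself is used,
    // which already matches the Spark builtin of the same name.
    val unmapped = SPARK_BUILTIN_FUNCTION_NAME_MAPPING
      .getOrDefault(AVG, AVG.getName.getFunctionName) // "avg"

The translator test suites that follow assert these rewritten names in their expected UnresolvedFunction plans.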
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 1fdd20c74..ba634cc1c 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers @@ -13,17 +12,19 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { private val planTransformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() - test("test average price ") { + test("test average price") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = @@ -38,10 +39,10 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation) val expectedPlan = Project(star, aggregatePlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } - ignore("test average price with Alias") { + test("test average price with Alias") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -57,7 +58,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation) val expectedPlan = Project(star, aggregatePlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("test average price group by product ") { @@ -81,7 +82,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) val expectedPlan = Project(star, aggregatePlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("test average 
price group by product and filter") { @@ -109,7 +110,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) val expectedPlan = Project(star, aggregatePlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("test average price group by product and filter sorted") { @@ -144,7 +145,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite global = true, aggregatePlan) val expectedPlan = Project(star, sortedPlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl simple avg age by span of interval of 10 years query test ") { val context = new CatalystPlanContext @@ -164,7 +165,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) val expectedPlan = Project(star, aggregatePlan) - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl simple avg age by span of interval of 10 years query with sort test ") { @@ -190,7 +191,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) val expectedPlan = Project(star, sortedPlan) - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl simple avg age by span of interval of 10 years by country query test ") { @@ -219,7 +220,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite tableRelation) val expectedPlan = Project(star, aggregatePlan) - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl query count sales by weeks window and productId with sorting test") { val context = new CatalystPlanContext @@ -257,7 +258,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val expectedPlan = Project(star, sortedPlan) // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl query count sales by days window and productId with sorting test") { @@ -296,7 +297,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite aggregatePlan) val expectedPlan = Project(star, sortedPlan) // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test("create ppl query count status amount by day window and group by status test") { val context = new CatalystPlanContext @@ -331,7 +332,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) val expectedPlan = Project(star, planWithLimit) // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test( "create ppl query count only error (status >= 400) status amount by day window and group by status test") { @@ -368,7 +369,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) val expectedPlan = Project(star, planWithLimit) // Compare the two plans - assert(compareByString(expectedPlan) === 
compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index bc31691d0..5b94ca092 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers @@ -13,11 +12,12 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.IntegerType class PPLLogicalPlanBasicQueriesTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -31,7 +31,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with escaped table name") { @@ -41,7 +41,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with schema.table and no explicit fields (defaults to all fields)") { @@ -51,7 +51,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } @@ -62,7 +62,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A")) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with one field projected") { @@ -72,7 +72,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A")) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with two fields projected") { @@ -82,7 +82,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val table = UnresolvedRelation(Seq("t")) val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) val expectedPlan = Project(projectList, table) - 
assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with one table with two fields projected sorted by one field") { @@ -97,7 +97,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val sorted = Sort(sortOrder, true, table) val expectedPlan = Project(projectList, sorted) - assert(compareByString(expectedPlan) === compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test( @@ -111,7 +111,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), Project(projectList, table))) val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -129,8 +129,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projectAB)) val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) - - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } test( @@ -152,7 +151,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("Search multiple tables - translated into union call with fields") { @@ -172,6 +171,6 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala index 27dd972fc..fd7957106 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import org.apache.hadoop.conf.Configuration import org.junit.Assert.assertEquals import org.mockito.Mockito.when import org.opensearch.flint.spark.ppl.PlaneUtils.plan @@ -20,11 +19,13 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Like, Literal, NamedExpression, Not, Or, SortOrder, UnixTimestamp} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class PPLLogicalPlanFiltersTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -40,7 +41,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with two field with 
'and' filtered ") { @@ -54,7 +55,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(And(filterAExpr, filterBExpr), table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with two field with 'or' filtered ") { @@ -68,7 +69,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(Or(filterAExpr, filterBExpr), table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test simple search with only one table with two field with 'not' filtered ") { @@ -82,7 +83,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(Or(filterAExpr, filterBExpr), table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -96,7 +97,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -111,7 +112,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -127,7 +128,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -141,7 +142,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -155,7 +156,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -169,7 +170,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -183,7 +184,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -197,7 +198,7 @@ class PPLLogicalPlanFiltersTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test( @@ -218,6 +219,6 @@ class PPLLogicalPlanFiltersTranslatorTestSuite Project(projectList, filterPlan)) val expectedPlan 
= Project(Seq(UnresolvedStar(None)), sortedPlan) - assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala index 24336b098..73fa2a999 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala @@ -5,20 +5,20 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan -import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} class PPLLogicalPlanMathFunctionsTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -36,7 +36,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test ceil") { @@ -50,7 +50,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test floor") { @@ -64,7 +64,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test ln") { @@ -78,7 +78,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test mod") { @@ -93,7 +93,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test pow") { @@ -107,7 +107,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, 
logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test sqrt") { @@ -121,7 +121,7 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test arithmetic: + - * / %") { @@ -145,19 +145,52 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite isDistinct = false)), isDistinct = false) // sqrt(pow(a, 2)) / 1 - val sqrtPowDivide = UnresolvedFunction("divide", seq(sqrtPow, Literal(1)), isDistinct = false) + val sqrtPowDivide = UnresolvedFunction("/", seq(sqrtPow, Literal(1)), isDistinct = false) // sqrt(pow(a, 2)) * 1 val sqrtPowMultiply = - UnresolvedFunction("multiply", seq(sqrtPow, Literal(1)), isDistinct = false) + UnresolvedFunction("*", seq(sqrtPow, Literal(1)), isDistinct = false) // sqrt(pow(a, 2)) % 1 - val sqrtPowMod = UnresolvedFunction("modulus", seq(sqrtPow, Literal(1)), isDistinct = false) + val sqrtPowMod = UnresolvedFunction("%", seq(sqrtPow, Literal(1)), isDistinct = false) // sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 - val add = UnresolvedFunction("add", seq(sqrtPow, sqrtPowDivide), isDistinct = false) - val sub = UnresolvedFunction("subtract", seq(add, sqrtPowMultiply), isDistinct = false) + val add = UnresolvedFunction("+", seq(sqrtPow, sqrtPowDivide), isDistinct = false) + val sub = UnresolvedFunction("-", seq(add, sqrtPowMultiply), isDistinct = false) val filterExpr = EqualTo(sub, sqrtPowMod) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test boolean operators: = != < <= > >=") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = age = 30, b = age != 70, c = 30 < age, d = 30 <= age, e = 30 > age, f = 30 >= age | fields age, a, b, c, d, e, f", + false), + context) + + val table = UnresolvedRelation(Seq("t")) + val evalProject = Project( + Seq( + UnresolvedStar(None), + Alias(EqualTo(UnresolvedAttribute("age"), Literal(30)), "a")(), + Alias(Not(EqualTo(UnresolvedAttribute("age"), Literal(70))), "b")(), + Alias(LessThan(Literal(30), UnresolvedAttribute("age")), "c")(), + Alias(LessThanOrEqual(Literal(30), UnresolvedAttribute("age")), "d")(), + Alias(GreaterThan(Literal(30), UnresolvedAttribute("age")), "e")(), + Alias(GreaterThanOrEqual(Literal(30), UnresolvedAttribute("age")), "f")()), + table) + val projectList = Seq( + UnresolvedAttribute("age"), + UnresolvedAttribute("a"), + UnresolvedAttribute("b"), + UnresolvedAttribute("c"), + UnresolvedAttribute("d"), + UnresolvedAttribute("e"), + UnresolvedAttribute("f")) + val expectedPlan = Project(projectList, evalProject) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala index 36a31862b..0d3c12b82 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanStringFunctionsTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package 
org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} @@ -14,11 +13,13 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{EqualTo, Like, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Like, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} class PPLLogicalPlanStringFunctionsTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -45,7 +46,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test concat with field") { @@ -63,7 +64,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test length") { @@ -77,7 +78,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test lower") { @@ -91,7 +92,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test upper - case insensitive") { @@ -105,7 +106,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test trim") { @@ -119,7 +120,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test ltrim") { @@ -133,7 +134,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test rtrim") { @@ -147,7 +148,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test substring") { @@ -162,7 +163,7 @@ class 
PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test like") { @@ -181,7 +182,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test position") { @@ -200,7 +201,7 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test replace") { @@ -219,6 +220,6 @@ class PPLLogicalPlanStringFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala index 7cfcc33d5..cd857fc08 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTimeFunctionsTranslatorTestSuite.scala @@ -5,7 +5,6 @@ package org.opensearch.flint.spark.ppl -import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq @@ -13,11 +12,13 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.EqualTo +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} class PPLLogicalPlanTimeFunctionsTranslatorTestSuite extends SparkFunSuite + with PlanTest with LogicalPlanTestUtils with Matchers { @@ -36,7 +37,7 @@ class PPLLogicalPlanTimeFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) } test("test unix_timestamp") { @@ -51,6 +52,174 @@ class PPLLogicalPlanTimeFunctionsTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(expectedPlan, logPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test builtin time functions with name mapping") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = t + | | eval a = DAY_OF_WEEK(DATE('2020-08-26')) + | | eval b 
= DAY_OF_MONTH(DATE('2020-08-26')) + | | eval c = DAY_OF_YEAR(DATE('2020-08-26')) + | | eval d = WEEK_OF_YEAR(DATE('2020-08-26')) + | | eval e = WEEK(DATE('2020-08-26')) + | | eval f = MONTH_OF_YEAR(DATE('2020-08-26')) + | | eval g = HOUR_OF_DAY(DATE('2020-08-26')) + | | eval h = MINUTE_OF_HOUR(DATE('2020-08-26')) + | | eval i = SECOND_OF_MINUTE(DATE('2020-08-26')) + | | eval j = SUBDATE(DATE('2020-08-26'), 1) + | | eval k = ADDDATE(DATE('2020-08-26'), 1) + | | eval l = DATEDIFF(TIMESTAMP('2000-01-02 00:00:00'), TIMESTAMP('2000-01-01 23:59:59')) + | | eval m = LOCALTIME() + | """.stripMargin, + false), + context) + + val table = UnresolvedRelation(Seq("t")) + val projectA = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "dayofweek", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "a")()), + table) + val projectB = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "dayofmonth", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "b")()), + projectA) + val projectC = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "dayofyear", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "c")()), + projectB) + val projectD = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "weekofyear", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "d")()), + projectC) + val projectE = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "weekofyear", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "e")()), + projectD) + val projectF = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "month", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "f")()), + projectE) + val projectG = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "hour", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "g")()), + projectF) + val projectH = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "minute", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "h")()), + projectG) + val projectI = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "second", + Seq(UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false)), + isDistinct = false), + "i")()), + projectH) + val projectJ = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "date_sub", + Seq( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + Literal(1)), + isDistinct = false), + "j")()), + projectI) + val projectK = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "date_add", + Seq( + UnresolvedFunction("date", Seq(Literal("2020-08-26")), isDistinct = false), + Literal(1)), + isDistinct = false), + "k")()), + projectJ) + val projectL = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "datediff", + Seq( + UnresolvedFunction( + "timestamp", + Seq(Literal("2000-01-02 00:00:00")), + isDistinct = false), + UnresolvedFunction( + "timestamp", + Seq(Literal("2000-01-01 23:59:59")), + isDistinct = false)), + 
isDistinct = false), + "l")()), + projectK) + val projectM = Project( + Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("localtimestamp", Seq.empty, isDistinct = false), "m")()), + projectL) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, projectM) + comparePlans(expectedPlan, logPlan, false) } }