From a9e9aaf8523d3f72e19b1a8c3200afaea286400f Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 13 Nov 2024 16:21:42 -0800 Subject: [PATCH] Address review comments Signed-off-by: Andy Kwok --- .../flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala | 9 +++++++++ ...LLogicalPlanTrendlineCommandTranslatorTestSuite.scala | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index 589cad33b..9a8379288 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -7,10 +7,12 @@ package org.opensearch.flint.spark.ppl import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.opensearch.sql.ppl.utils.SortUtils +import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper} import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -483,6 +485,13 @@ class FlintSparkPPLTrendlineITSuite } + test("test invalid wma command with negative dataPoint value") { + val exception = intercept[ParseException](sql(s""" + | source = $testTable | trendline sort + age wma(-3, age) + | """.stripMargin)) + assert(exception.getMessage contains "[PARSE_SYNTAX_ERROR] Syntax error") + } + private def getNthValueAggregation( dataField: String, sortField: String, diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index baf472a08..ec1775631 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -6,6 +6,7 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.opensearch.sql.ppl.utils.SortUtils @@ -251,6 +252,14 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite } + test("WMA - with negative dataPoint value") { + val context = new CatalystPlanContext + val exception = intercept[SyntaxCheckException]( + planTransformer + .visit(plan(pplParser, "source=relation | trendline sort age wma(-3, age)"), context)) + assert(exception.getMessage startsWith "Failed to parse query due to offending symbol [-]") + } + private def getNthValueAggregation( dataField: String, sortField: String,