From 182689c5f5d3dda78fe1cd25af3ada06c6df545a Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 7 Nov 2024 10:47:34 -0800 Subject: [PATCH] Add validation for time column in tumble function (#858) * Validate tumble function argument and add IT Signed-off-by: Chen Dai * Add IT for verifying correctness of subquery workaround Signed-off-by: Chen Dai * Modify error message wording Signed-off-by: Chen Dai --------- Signed-off-by: Chen Dai --- .../spark/mv/FlintSparkMaterializedView.scala | 10 ++- ...FlintSparkMaterializedViewSqlITSuite.scala | 75 +++++++++++++++++++ 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index aecfc99df..e2a64d183 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -133,8 +133,14 @@ case class FlintSparkMaterializedView( // Assume first aggregate item must be time column val winFunc = winFuncs.head - val timeCol = winFunc.arguments.head.asInstanceOf[Attribute] - Some(agg, timeCol) + val timeCol = winFunc.arguments.head + timeCol match { + case attr: Attribute => + Some(agg, attr) + case _ => + throw new IllegalArgumentException( + s"Tumble function only supports simple timestamp column, but found: $timeCol") + } } private def isWindowingFunction(func: UnresolvedFunction): Boolean = { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index 9e75078d2..ae2e53090 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -448,5 +448,80 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { } } + test("tumble function should raise error for non-simple time column") { + val httpLogs = s"$catalogName.default.mv_test_tumble" + withTable(httpLogs) { + createTableHttpLog(httpLogs) + + withTempDir { checkpointDir => + val ex = the[IllegalStateException] thrownBy { + sql(s""" + | CREATE MATERIALIZED VIEW `$catalogName`.`default`.`mv_test_metrics` + | AS + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $httpLogs + | GROUP BY + | TUMBLE(CAST(timestamp AS TIMESTAMP), '10 Minute') + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}', + | watermark_delay = '1 Second' + | ) + |""".stripMargin) + } + ex.getCause should have message + "Tumble function only supports simple timestamp column, but found: cast('timestamp as timestamp)" + } + } + } + + test("tumble function should succeed with casted time column within subquery") { + val httpLogs = s"$catalogName.default.mv_test_tumble" + withTable(httpLogs) { + createTableHttpLog(httpLogs) + + withTempDir { checkpointDir => + sql(s""" + | CREATE MATERIALIZED VIEW `$catalogName`.`default`.`mv_test_metrics` + | AS + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM ( + | SELECT CAST(timestamp AS TIMESTAMP) AS time + | FROM $httpLogs + | ) + | GROUP BY + | TUMBLE(time, '10 Minute') + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}', + | watermark_delay = '1 Second' + | ) + |""".stripMargin) + + // Wait for streaming job complete current micro batch + val job = spark.streams.active.find(_.name == testFlintIndex) + job shouldBe defined + failAfter(streamingTimeout) { + job.get.processAllAvailable() + } + + checkAnswer( + flint.queryIndex(testFlintIndex).select("startTime", "count"), + Seq( + Row(timestamp("2023-10-01 10:00:00"), 2), + Row(timestamp("2023-10-01 10:10:00"), 2) + /* + * The last row is pending to fire upon watermark + * Row(timestamp("2023-10-01 10:20:00"), 2) + */ + )) + } + } + } + private def timestamp(ts: String): Timestamp = Timestamp.valueOf(ts) }