diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala index 55fda3c10..b33e7cc45 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala @@ -5,6 +5,7 @@ package org.opensearch.flint.spark +import org.opensearch.flint.spark.function.TumbleFunction import org.opensearch.flint.spark.sql.FlintSparkSqlParser import org.apache.spark.sql.SparkSessionExtensions @@ -18,6 +19,9 @@ class FlintSparkExtensions extends (SparkSessionExtensions => Unit) { extensions.injectParser { (spark, parser) => new FlintSparkSqlParser(parser) } + + extensions.injectFunction(TumbleFunction.description) + extensions.injectOptimizerRule { spark => new FlintSparkOptimizer(spark) } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/function/TumbleFunction.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/function/TumbleFunction.scala new file mode 100644 index 000000000..8ab27c9ec --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/function/TumbleFunction.scala @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.function + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} +import org.apache.spark.sql.functions.window + +/** + * Tumble windowing function that groups rows into fixed-interval, non-overlapping windows. 
+ */ +object TumbleFunction { + + val identifier: FunctionIdentifier = FunctionIdentifier("tumble") + + val exprInfo: ExpressionInfo = new ExpressionInfo(classOf[Column].getCanonicalName, "window") + + val functionBuilder: Seq[Expression] => Expression = + (children: Seq[Expression]) => { + // Delegate actual implementation to Spark's built-in window() function + val timeColumn = children.head + val windowDuration = children(1) + window(new Column(timeColumn), windowDuration.toString()).expr + } + + /** + * Function description used to register this function with the Spark session extensions. + */ + val description: (FunctionIdentifier, ExpressionInfo, FunctionBuilder) = + (identifier, exprInfo, functionBuilder) +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/function/TumbleFunctionSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/function/TumbleFunctionSuite.scala new file mode 100644 index 000000000..a6c8fe8be --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/function/TumbleFunctionSuite.scala @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.function + +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.FlintSuite + +class TumbleFunctionSuite extends FlintSuite with Matchers { + + test("should require both column name and window expression as arguments") { + // TODO(review): empty test body — add an assertion that TumbleFunction.functionBuilder fails when given fewer than two argument expressions + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkWindowingFunctionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkWindowingFunctionITSuite.scala new file mode 100644 index 000000000..8a32b8112 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkWindowingFunctionITSuite.scala @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + 
+package org.opensearch.flint.spark + +import java.sql.Timestamp + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.{QueryTest, Row} + +class FlintSparkWindowingFunctionITSuite extends QueryTest with FlintSuite { + + test("tumble windowing function") { + val inputDF = spark + .createDataFrame( + Seq( + (1L, "2023-01-01 00:00:00"), + (2L, "2023-01-01 00:09:00"), + (3L, "2023-01-01 00:15:00"))) + .toDF("id", "timestamp") + + val resultDF = inputDF.selectExpr("TUMBLE(timestamp, '10 minutes') AS window") + val expectedData = Seq( + Row(Row(timestamp("2023-01-01 00:00:00"), timestamp("2023-01-01 00:10:00"))), + Row(Row(timestamp("2023-01-01 00:00:00"), timestamp("2023-01-01 00:10:00"))), + Row(Row(timestamp("2023-01-01 00:10:00"), timestamp("2023-01-01 00:20:00")))) + + checkAnswer(resultDF, expectedData) + } + + private def timestamp(ts: String): Timestamp = Timestamp.valueOf(ts) +}