From f7647e427315ae5829348b5daaa201db52195e16 Mon Sep 17 00:00:00 2001 From: Ganesh Mahadevan Date: Thu, 8 Aug 2024 12:09:57 -0500 Subject: [PATCH] SNOW-802269 - Add missing scala and java functions (#139) * Merge changes from fork to feature branch (#138) * add java and scala size and ordering functions * add scala unit test for ordering and size function * update comments and add example * add java test cases * fix comments * add expr function for java and scala * add formatting functions scala * remove format_string func --------- Co-authored-by: sfc-gh-mrojas * add java function and test case * fix test case * fix test file import * fix test file import * fix docs --------- Co-authored-by: sfc-gh-mrojas --- .../snowflake/snowpark_java/Functions.java | 95 +++++++++++++++ .../com/snowflake/snowpark/functions.scala | 112 +++++++++++++++++- .../snowpark_test/JavaFunctionSuite.java | 31 +++++ .../snowpark_test/FunctionSuite.scala | 34 ++++++ 4 files changed, 266 insertions(+), 6 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index dbadd87b..8daaf9fc 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -3957,6 +3957,101 @@ public static Column size(Column col) { return array_size(col); } + /** + * Creates a Column expression from row SQL text. + * + *

Note that the function does not interpret or check the SQL text. + * + *

{@code
+   * DataFrame df = getSession().sql("select a from values(1), (2), (3) as T(a)");
+   * df.filter(Functions.expr("a > 2")).show();
+   * -------
+   * |"A"  |
+   * -------
+   * |3    |
+   * -------
+   * }
+ * + * @since 1.14.0 + * @param s The SQL text + * @return column expression from input statement. + */ + public static Column expr(String s) { + return sqlExpr(s); + } + + /** + * Returns an ARRAY constructed from zero, one, or more inputs. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values(1,2,3) as T(a,b,c)");
+   * df.select(Functions.array(df.col("a"), df.col("b"), df.col("c")).as("array")).show();
+   *-----------
+   * |"ARRAY"  |
+   * -----------
+   * |[        |
+   * |  1,     |
+   * |  2,     |
+   * |  3      |
+   * |]        |
+   * -----------
+   * }
+ * + * @since 1.14.0 + * @param cols The input column names + * @return Column object as array. + */ + public static Column array(Column... cols) { return array_construct(cols); } + + /** + * + * Converts an input expression into the corresponding date in the specified date format. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values ('2023-10-10'), ('2022-05-15') as T(a)");
+   * df.select(Functions.date_format(df.col("a"), "YYYY/MM/DD").as("formatted_date")).show();
+   * --------------------
+   * |"FORMATTED_DATE"  |
+   * --------------------
+   * |2023/10/10        |
+   * |2022/05/15        |
+   * --------------------
+   * }
+ * + * @since 1.14.0 + * @param col The input date column name + * @param s string format + * @return formatted column object. + */ + public static Column date_format(Column col, String s) { + return new Column(functions.date_format(col.toScalaColumn(), s)); + } + + /** + * Returns the last value of the column in a group. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values (5, 'a', 10), (5, 'b', 20),\n" +
+   *             "    (3, 'd', 15), (3, 'e', 40) as T(grade,name,score)");
+   * df.select(Functions.last(df.col("name")).over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))).show();
+   * ----------------
+   * |"LAST_VALUE"  |
+   * ----------------
+   * |a             |
+   * |a             |
+   * |d             |
+   * |d             |
+   * ----------------
+   * }
+ * + * @since 1.14.0 + * @param col The input column to get last value + * @return column object from last function. + */ + public static Column last(Column col) { + return new Column(functions.last(col.toScalaColumn())); + } + /** * Calls a user-defined function (UDF) by name. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 1cd3eff0..662a00c4 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -2,12 +2,8 @@ package com.snowflake.snowpark import com.snowflake.snowpark.internal.analyzer._ import com.snowflake.snowpark.internal.ScalaFunctions._ -import com.snowflake.snowpark.internal.{ - ErrorMessage, - OpenTelemetry, - UDXRegistrationHandler, - Utils -} +import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils} +import com.snowflake.snowpark.types.TimestampType import scala.reflect.runtime.universe.TypeTag import scala.util.Random @@ -3207,6 +3203,110 @@ object functions { */ def size(c: Column): Column = array_size(c) + /** + * Creates a [[Column]] expression from raw SQL text. + * + * Note that the function does not interpret or check the SQL text. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id") + * df.filter(expr("id > 2")).show() + * + * -------- + * |"ID" | + * -------- + * |3 | + * -------- + * }}} + * + * @since 1.14.0 + * @param s SQL Expression as text. + * @return Converted SQL Expression. + */ + def expr(s: String): Column = sqlExpr(s) + + /** + * Returns an ARRAY constructed from zero, one, or more inputs. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq((1, 2, 3), (4, 5, 6))).toDF("id") + * df.select(array(col("a"), col("b")).as("id")).show() + * + * -------- + * |"ID" | + * -------- + * |[ | + * | 1, | + * | 2 | + * |] | + * |[ | + * | 4, | + * | 5 | + * |] | + * -------- + * }}} + * + * @since 1.14.0 + * @param c Columns to build the array. + * @return The array. + */ + def array(c: Column*): Column = array_construct(c: _*) + + /** + * Converts an input expression into the corresponding date in the specified date format. + * Example: + * {{{ + * val df = Seq("2023-10-10", "2022-05-15", null.asInstanceOf[String]).toDF("date") + * df.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")).show() + * + * -------------------- + * |"FORMATTED_DATE" | + * -------------------- + * |2023/10/10 | + * |2022/05/15 | + * |NULL | + * -------------------- + * + * }}} + * + * @since 1.14.0 + * @param c Column to format to date. + * @param s Date format. + * @return Column object. + */ + def date_format(c: Column, s: String): Column = + builtin("to_varchar")(c.cast(TimestampType), s.replace("mm", "mi")) + + /** + * Returns the last value of the column in a group. + * Example + * {{{ + * val df = session.createDataFrame(Seq((5, "a", 10), + * (5, "b", 20), + * (3, "d", 15), + * (3, "e", 40))).toDF("grade", "name", "score") + * val window = Window.partitionBy(col("grade")).orderBy(col("score").desc) + * df.select(last(col("name")).over(window)).show() + * + * --------------------- + * |"LAST_SCORE_NAME" | + * --------------------- + * |a | + * |a | + * |d | + * |d | + * --------------------- + * }}} + * + * @since 1.14.0 + * @param c Column to obtain last value. + * @return Column object. + */ + def last(c: Column): Column = + builtin("LAST_VALUE")(c) + /** * Invokes a built-in snowflake function with the specified name and arguments. * Arguments can be of two types diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 2b3b4fc9..edddec2e 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -2790,4 +2790,35 @@ public void test_size() { checkAnswer(df.select(Functions.size(Functions.col("arr"))), expected, false); } + + @Test + public void test_expr() { + DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)"); + Row[] expected = {Row.create(3)}; + checkAnswer(df.filter(Functions.expr("a > 2")), expected, false); + } + + @Test + public void test_array() { + DataFrame df = getSession().sql("select * from values(1,2,3) as T(a,b,c)"); + Row[] expected = {Row.create("[\n 1,\n 2,\n 3\n]")}; + checkAnswer(df.select(Functions.array(df.col("a"), df.col("b"), df.col("c"))), expected, false); + } + + @Test + public void date_format() { + DataFrame df = getSession().sql("select * from values ('2023-10-10'), ('2022-05-15') as T(a)"); + Row[] expected = {Row.create("2023/10/10"), Row.create("2022/05/15")}; + + checkAnswer(df.select(Functions.date_format(df.col("a"), "YYYY/MM/DD")), expected, false); + } + + @Test + public void last() { + DataFrame df = getSession().sql("select * from values (5, 'a', 10), (5, 'b', 20),\n" + + " (3, 'd', 15), (3, 'e', 40) as T(grade,name,score)"); + + Row[] expected = {Row.create("a"), Row.create("a"), Row.create("d"), Row.create("d")}; + checkAnswer(df.select(Functions.last(df.col("name")).over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))), expected, false); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 770e7c7d..806a6ff8 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -2206,6 +2206,40 @@ trait FunctionSuite extends TestData { val expected = Seq((3)).toDF("size") checkAnswer(input.select(size(col("size"))), expected, sort = false) } + + test("expr function") { + + val input = Seq(1, 2, 3).toDF("id") + val expected = Seq((3)).toDF("id") + checkAnswer(input.filter(expr("id > 2")), expected, sort = false) + } + + test("array function") { + + val input = Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c") + val expected = Seq(Array(1, 2), Array(4, 5)).toDF("id") + checkAnswer(input.select(array(col("a"), col("b")).as("id")), expected, sort = false) + } + + test("date format function") { + + val input = Seq("2023-10-10", "2022-05-15").toDF("date") + val expected = Seq("2023/10/10", "2022/05/15").toDF("formatted_date") + + checkAnswer(input.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")), + expected, sort = false) + } + + test("last function") { + + val input = Seq((5, "a", 10), (5, "b", 20), + (3, "d", 15), (3, "e", 40)).toDF("grade", "name", "score") + val window = Window.partitionBy(col("grade")).orderBy(col("score").desc) + val expected = Seq("a", "a", "d", "d").toDF("last_score_name") + + checkAnswer(input.select(last(col("name")).over(window).as("last_score_name")), + expected, sort = false) + } }