From d26eb08a2105e9ba051589c7535dce896227c818 Mon Sep 17 00:00:00 2001 From: Ganesh Mahadevan Date: Wed, 7 Aug 2024 12:15:08 -0500 Subject: [PATCH] =?UTF-8?q?SNOW-802269=20Add=20ordering=20and=20size=20fun?= =?UTF-8?q?ction=20for=20scala=20and=20java=20modules=E2=80=A6=20(#137)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SNOW-802269 Add ordering and size function for scala and java modules (#133) * add java and scala size and ordering functions * add scala unit test for ordering and size function * update comments and add example * add java test cases * fix comments --------- Co-authored-by: sfc-gh-mrojas --- .../snowflake/snowpark_java/Functions.java | 77 +++++++++++++++++++ .../com/snowflake/snowpark/functions.scala | 67 ++++++++++++++++ .../snowpark_test/JavaFunctionSuite.java | 26 +++++++ .../snowpark_test/FunctionSuite.scala | 29 +++++++ 4 files changed, 199 insertions(+) diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 56d8d08b..dbadd87b 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -2,6 +2,7 @@ import static com.snowflake.snowpark.internal.OpenTelemetry.javaUDF; +import com.snowflake.snowpark.functions; import com.snowflake.snowpark.internal.JavaUtils; import com.snowflake.snowpark_java.types.DataType; import com.snowflake.snowpark_java.udf.*; @@ -3880,6 +3881,82 @@ public static Column listagg(Column col) { return new Column(com.snowflake.snowpark.functions.listagg(col.toScalaColumn())); } + /** + * Returns a Column expression with values sorted in descending order. + * + *

Example: order column values in descending + * + *

{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)");
+   * df.sort(Functions.desc("a")).show();
+   * -------
+   * |"A"  |
+   * -------
+   * |3    |
+   * |2    |
+   * |1    |
+   * -------
+   * }
+ * + * @since 1.14.0 + * @param name The input column name + * @return Column object ordered in descending manner. + */ + public static Column desc(String name) { + return new Column(functions.desc(name)); + } + + /** + * Returns a Column expression with values sorted in ascending order. + * + *

Example: order column values in ascending + * + *

{@code
+   * DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
+   * df.sort(Functions.asc("a")).show();
+   * -------
+   * |"A"  |
+   * -------
+   * |1    |
+   * |2    |
+   * |3    |
+   * -------
+   * }
+ * + * @since 1.14.0 + * @param name The input column name + * @return Column object ordered in ascending manner. + */ + public static Column asc(String name) { + return new Column(functions.asc(name)); + } + + /** + * Returns the size of the input ARRAY. + * + *

If the specified column contains a VARIANT value that contains an ARRAY, the size of the + * ARRAY is returned; otherwise, NULL is returned if the value is not an ARRAY. + * + *

Example: calculate size of the array in a column + * + *

{@code
+   * DataFrame df = getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
+   * df.select(Functions.size(Functions.col("arr"))).show();
+   * -------------------------
+   * |"ARRAY_SIZE(""ARR"")"  |
+   * -------------------------
+   * |3                      |
+   * -------------------------
+   * }
+ * + * @since 1.14.0 + * @param col The input column name + * @return size of the input ARRAY. + */ + public static Column size(Column col) { + return array_size(col); + } + /** * Calls a user-defined function (UDF) by name. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index a7fd9ff0..1cd3eff0 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -3140,6 +3140,73 @@ object functions { */ def listagg(col: Column): Column = listagg(col, "", isDistinct = false) + /** + * Returns a Column expression with values sorted in descending order. + * Example: + * {{{ + * val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id") + * df.sort(desc("id")).show() + * + * -------- + * |"ID" | + * -------- + * |3 | + * |2 | + * |1 | + * -------- + * }}} + * + * @since 1.14.0 + * @param colName Column name. + * @return Column object ordered in a descending manner. + */ + def desc(colName: String): Column = col(colName).desc + + /** + * Returns a Column expression with values sorted in ascending order. + * Example: + * {{{ + * val df = session.createDataFrame(Seq(3, 2, 1)).toDF("id") + * df.sort(asc("id")).show() + * + * -------- + * |"ID" | + * -------- + * |1 | + * |2 | + * |3 | + * -------- + * }}} + * @since 1.14.0 + * @param colName Column name. + * @return Column object ordered in an ascending manner. + */ + def asc(colName: String): Column = col(colName).asc + + /** + * Returns the size of the input ARRAY. + * + * If the specified column contains a VARIANT value that contains an ARRAY, the size of the ARRAY + * is returned; otherwise, NULL is returned if the value is not an ARRAY. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id") + * df.select(size(col("id"))).show() + * + * ------------------------ + * |"ARRAY_SIZE(""ID"")" | + * ------------------------ + * |3 | + * ------------------------ + * }}} + * + * @since 1.14.0 + * @param c Column to get the size. + * @return Size of array column. + */ + def size(c: Column): Column = array_size(c) + /** * Invokes a built-in snowflake function with the specified name and arguments. * Arguments can be of two types diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 6ee298d3..2b3b4fc9 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -2764,4 +2764,30 @@ public void any_value() { assert result.length == 1; assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3; } + + @Test + public void test_asc() { + DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)"); + Row[] expected = {Row.create(1), Row.create(2), Row.create(3)}; + + checkAnswer(df.sort(Functions.asc("a")), expected, false); + } + + @Test + public void test_desc() { + DataFrame df = getSession().sql("select * from values(2),(1),(3) as t(a)"); + Row[] expected = {Row.create(3), Row.create(2), Row.create(1)}; + + checkAnswer(df.sort(Functions.desc("a")), expected, false); + } + + @Test + public void test_size() { + DataFrame df = getSession() + .sql( + "select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)"); + Row[] expected = {Row.create(3)}; + + checkAnswer(df.select(Functions.size(Functions.col("arr"))), expected, false); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index e473de12..770e7c7d 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -2178,6 +2178,35 @@ trait FunctionSuite extends TestData { sort = false) } + test("desc column order") { + val input = Seq(1, 2, 3).toDF("data") + val expected = Seq(3, 2, 1).toDF("data") + + val inputStr = Seq("a", "b", "c").toDF("dataStr") + val expectedStr = Seq("c", "b", "a").toDF("dataStr") + + checkAnswer(input.sort(desc("data")), expected, sort = false) + checkAnswer(inputStr.sort(desc("dataStr")), expectedStr, sort = false) + } + + test("asc column order") { + val input = Seq(3, 2, 1).toDF("data") + val expected = Seq(1, 2, 3).toDF("data") + + val inputStr = Seq("c", "b", "a").toDF("dataStr") + val expectedStr = Seq("a", "b", "c").toDF("dataStr") + + checkAnswer(input.sort(asc("data")), expected, sort = false) + checkAnswer(inputStr.sort(asc("dataStr")), expectedStr, sort = false) + } + + test("column array size") { + + val input = Seq(Array(1, 2, 3)).toDF("size") + val expected = Seq((3)).toDF("size") + checkAnswer(input.select(size(col("size"))), expected, sort = false) + } + } class EagerFunctionSuite extends FunctionSuite with EagerSession