diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 56d8d08b..19a90f03 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -2,6 +2,7 @@ import static com.snowflake.snowpark.internal.OpenTelemetry.javaUDF; +import com.snowflake.snowpark.functions; import com.snowflake.snowpark.internal.JavaUtils; import com.snowflake.snowpark_java.types.DataType; import com.snowflake.snowpark_java.udf.*; @@ -3880,6 +3881,105 @@ public static Column listagg(Column col) { return new Column(com.snowflake.snowpark.functions.listagg(col.toScalaColumn())); } + /** + * Returns a Column expression with values sorted in descending order. + * + *
Example: order column values in descending + * + *
{@code + * DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)"); + * df.sort(Functions.desc("a")).show(); + * ------- + * |"A" | + * ------- + * |3 | + * |2 | + * |1 | + * ------- + * }+ * + * @since 1.14.0 + * @param name The input column name + * @return Column object ordered in descending manner. + */ + public static Column desc(String name) { + return new Column(functions.desc(name)); + } + + /** + * Returns a Column expression with values sorted in ascending order. + * + *
Example: order column values in ascending + * + *
{@code + * DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)"); + * df.sort(Functions.asc("a")).show(); + * ------- + * |"A" | + * ------- + * |1 | + * |2 | + * |3 | + * ------- + * }+ * + * @since 1.14.0 + * @param name The input column name + * @return Column object ordered in ascending manner. + */ + public static Column asc(String name) { + return new Column(functions.asc(name)); + } + + /** + * Returns the size of the input ARRAY. + * + *
If the specified column contains a VARIANT value that contains an ARRAY, the size of the + * ARRAY is returned; otherwise, NULL is returned if the value is not an ARRAY. + * + *
Example: calculate size of the array in a column + * + *
{@code + * DataFrame df = getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)"); + * df.select(Functions.size(Functions.col("arr"))).show(); + * ------------------------- + * |"ARRAY_SIZE(""ARR"")" | + * ------------------------- + * |3 | + * ------------------------- + * }+ * + * @since 1.14.0 + * @param col The input column name + * @return size of the input ARRAY. + */ + public static Column size(Column col) { + return array_size(col); + } + + /** + * Creates a Column expression from row SQL text. + * + *
Note that the function does not interpret or check the SQL text. + * + *
{@code + * DataFrame df = getSession().sql("select a from values(1), (2), (3) as T(a)"); + * df.filter(Functions.expr("a > 2")).show(); + * ------- + * |"A" | + * ------- + * |3 | + * ------- + * }+ * + * @since 1.14.0 + * @param s The input column name + * @return column expression from input statement. + */ + public static Column expr(String s) { + return sqlExpr(s); + } + /** * Calls a user-defined function (UDF) by name. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index a7fd9ff0..a13b9450 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -2,12 +2,8 @@ package com.snowflake.snowpark import com.snowflake.snowpark.internal.analyzer._ import com.snowflake.snowpark.internal.ScalaFunctions._ -import com.snowflake.snowpark.internal.{ - ErrorMessage, - OpenTelemetry, - UDXRegistrationHandler, - Utils -} +import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils} +import com.snowflake.snowpark.types.TimestampType import scala.reflect.runtime.universe.TypeTag import scala.util.Random @@ -3140,6 +3136,181 @@ object functions { */ def listagg(col: Column): Column = listagg(col, "", isDistinct = false) + /** + * Returns a Column expression with values sorted in descending order. + * Example: + * {{{ + * val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id") + * df.sort(desc("id")).show() + * + * -------- + * |"ID" | + * -------- + * |3 | + * |2 | + * |1 | + * -------- + * }}} + * + * @since 1.14.0 + * @param colName Column name. + * @return Column object ordered in a descending manner. + */ + def desc(colName: String): Column = col(colName).desc + + /** + * Returns a Column expression with values sorted in ascending order. + * Example: + * {{{ + * val df = session.createDataFrame(Seq(3, 2, 1)).toDF("id") + * df.sort(asc("id")).show() + * + * -------- + * |"ID" | + * -------- + * |1 | + * |2 | + * |3 | + * -------- + * }}} + * @since 1.14.0 + * @param colName Column name. + * @return Column object ordered in an ascending manner. + */ + def asc(colName: String): Column = col(colName).asc + + /** + * Returns the size of the input ARRAY. + * + * If the specified column contains a VARIANT value that contains an ARRAY, the size of the ARRAY + * is returned; otherwise, NULL is returned if the value is not an ARRAY. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id") + * df.select(size(col("id"))).show() + * + * ------------------------ + * |"ARRAY_SIZE(""ID"")" | + * ------------------------ + * |3 | + * ------------------------ + * }}} + * + * @since 1.14.0 + * @param c Column to get the size. + * @return Size of array column. + */ + def size(c: Column): Column = array_size(c) + + /** + * Creates a [[Column]] expression from raw SQL text. + * + * Note that the function does not interpret or check the SQL text. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id") + * df.filter(expr("id > 2")).show() + * + * -------- + * |"ID" | + * -------- + * |3 | + * -------- + * }}} + * + * @since 1.14.0 + * @param s SQL Expression as text. + * @return Converted SQL Expression. + */ + def expr(s: String): Column = sqlExpr(s) + + /** + * Wrapper for Snowflake built-in array function. Create array from columns. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq((1, 2, 3), (4, 5, 6))).toDF("id") + * df.select(array(col("a"), col("b")).as("id")).show() + * + * -------- + * |"ID" | + * -------- + * |[ | + * | 1, | + * | 2 | + * |] | + * |[ | + * | 4, | + * | 5 | + * |] | + * -------- + * }}} + * + * @since 1.14.0 + * @param c Columns to build the array. + * @return The array. + */ + def array(c: Column*): Column = array_construct(c: _*) + + /** + * Wrapper for Snowflake built-in date_format function. + * Converts a date into the specified format. + * Example: + * {{{ + * val df = Seq("2023-10-10", "2022-05-15", null.asInstanceOf[String]).toDF("date") + * df.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")).show() + * + * -------------------- + * |"FORMATTED_DATE" | + * -------------------- + * |2023/10/10 | + * |2022/05/15 | + * |NULL | + * -------------------- + * + * }}} + * + * @since 1.14.0 + * @param c Column to format to date. + * @param s Date format. + * @return Column object. + */ + def date_format(c: Column, s: String): Column = + builtin("to_varchar")(c.cast(TimestampType), s.replace("mm", "mi")) + + /** + * Wrapper for Snowflake built-in last function. + * Gets the last value of a column according to its grouping. + * Functional difference with windows, In Snowpark is needed the order by. + * SQL doesn't guarantee the order. + * Example + * {{{ + * val df = session.createDataFrame(Seq((5, "a", 10), + * (5, "b", 20), + * (3, "d", 15), + * (3, "e", 40))).toDF("grade", "name", "score") + * val window = Window.partitionBy(col("grade")).orderBy(col("score").desc) + * df.select(last(col("name")).over(window)).show() + * + * --------------------- + * |"LAST_SCORE_NAME" | + * --------------------- + * |a | + * |a | + * |d | + * |d | + * --------------------- + * }}} + * + * @since 1.14.0 + * @param c Column to obtain last value. + * @return Column object. + */ + def last(c: Column): Column = + builtin("LAST_VALUE")(c) + /** * Invokes a built-in snowflake function with the specified name and arguments. * Arguments can be of two types diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 6ee298d3..3fbe1ad1 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -2764,4 +2764,38 @@ public void any_value() { assert result.length == 1; assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3; } + + @Test + public void test_asc() { + DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)"); + Row[] expected = {Row.create(1), Row.create(2), Row.create(3)}; + + checkAnswer(df.sort(Functions.asc("a")), expected, false); + } + + @Test + public void test_desc() { + DataFrame df = getSession().sql("select * from values(2),(1),(3) as t(a)"); + Row[] expected = {Row.create(3), Row.create(2), Row.create(1)}; + + checkAnswer(df.sort(Functions.desc("a")), expected, false); + } + + @Test + public void test_size() { + DataFrame df = getSession() + .sql( + "select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)"); + Row[] expected = {Row.create(3)}; + + checkAnswer(df.select(Functions.size(Functions.col("arr"))), expected, false); + } + + @Test + public void test_expr() { + DataFrame df = getSession().sql("select a from values(1), (2), (3) as T(a)"); + Row[] expected = {Row.create(3)}; + checkAnswer(df.filter(Functions.expr("a > 2")), expected, false); + } + } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index e473de12..841413d6 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -2178,6 +2178,70 @@ trait FunctionSuite extends TestData { sort = false) } + test("desc column order") { + val input = Seq(1, 2, 3).toDF("data") + val expected = Seq(3, 2, 1).toDF("data") + + val inputStr = Seq("a", "b", "c").toDF("dataStr") + val expectedStr = Seq("c", "b", "a").toDF("dataStr") + + checkAnswer(input.sort(desc("data")), expected, sort = false) + checkAnswer(inputStr.sort(desc("dataStr")), expectedStr, sort = false) + } + + test("asc column order") { + + val input = Seq(3, 2, 1).toDF("data") + val expected = Seq(1, 2, 3).toDF("data") + + val inputStr = Seq("c", "b", "a").toDF("dataStr") + val expectedStr = Seq("a", "b", "c").toDF("dataStr") + + checkAnswer(input.sort(asc("data")), expected, sort = false) + checkAnswer(inputStr.sort(asc("dataStr")), expectedStr, sort = false) + } + + test("column array size") { + + val input = Seq(Array(1, 2, 3)).toDF("size") + val expected = Seq((3)).toDF("size") + checkAnswer(input.select(size(col("size"))), expected, sort = false) + } + + test("expr function") { + + val input = Seq(1, 2, 3).toDF("id") + val expected = Seq((3)).toDF("id") + checkAnswer(input.filter(expr("id > 2")), expected, sort = false) + } + + test("array function") { + + val input = Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c") + val expected = Seq(Array(1, 2), Array(4, 5)).toDF("id") + checkAnswer(input.select(array(col("a"), col("b")).as("id")), expected, sort = false) + } + + test("date-format function") { + + val input = Seq("2023-10-10", "2022-05-15").toDF("date") + val expected = Seq("2023/10/10", "2022/05/15").toDF("formatted_date") + + checkAnswer(input.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")), + expected, sort = false) + } + + test("last value function") { + + val input = Seq((5, "a", 10), (5, "b", 20), + (3, "d", 15), (3, "e", 40)).toDF("grade", "name", "score") + val window = Window.partitionBy(col("grade")).orderBy(col("score").desc) + val expected = Seq("a", "a", "d", "d").toDF("last_score_name") + + checkAnswer(input.select(last(col("name")).over(window).as("last_score_name")), + expected, sort = false) + } + } class EagerFunctionSuite extends FunctionSuite with EagerSession