Skip to content

Commit

Permalink
SNOW-802269 Add ordering and size function for scala and java modules (
Browse files Browse the repository at this point in the history
…#133)

* add java and scala size and ordering functions

* add scala unit test for ordering and size function

* update comments and add example

* add java test cases

* fix comments

---------

Co-authored-by: sfc-gh-mrojas <[email protected]>
  • Loading branch information
sfc-gh-gmahadevan and sfc-gh-mrojas authored Aug 7, 2024
1 parent 71c3d1b commit d2339f4
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 0 deletions.
77 changes: 77 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static com.snowflake.snowpark.internal.OpenTelemetry.javaUDF;

import com.snowflake.snowpark.functions;
import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark_java.types.DataType;
import com.snowflake.snowpark_java.udf.*;
Expand Down Expand Up @@ -3880,6 +3881,82 @@ public static Column listagg(Column col) {
return new Column(com.snowflake.snowpark.functions.listagg(col.toScalaColumn()));
}

/**
 * Builds a sort expression that orders the named column in descending order.
 *
 * <p>Example: sort a DataFrame by column {@code a} in descending order
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)");
 * df.sort(Functions.desc("a")).show();
 * -------
 * |"A"  |
 * -------
 * |3    |
 * |2    |
 * |1    |
 * -------
 * }</pre>
 *
 * @since 1.14.0
 * @param name The name of the column to sort by
 * @return A Column expression that sorts the named column in descending order.
 */
public static Column desc(String name) {
  // Delegate to the Scala implementation, then wrap its result in the Java Column type.
  final com.snowflake.snowpark.Column descending = functions.desc(name);
  return new Column(descending);
}

/**
 * Builds a sort expression that orders the named column in ascending order.
 *
 * <p>Example: sort a DataFrame by column {@code a} in ascending order
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
 * df.sort(Functions.asc("a")).show();
 * -------
 * |"A"  |
 * -------
 * |1    |
 * |2    |
 * |3    |
 * -------
 * }</pre>
 *
 * @since 1.14.0
 * @param name The name of the column to sort by
 * @return A Column expression that sorts the named column in ascending order.
 */
public static Column asc(String name) {
  // Delegate to the Scala implementation, then wrap its result in the Java Column type.
  final com.snowflake.snowpark.Column ascending = functions.asc(name);
  return new Column(ascending);
}

/**
 * Returns the number of elements in the input ARRAY.
 *
 * <p>When the column holds a VARIANT value that contains an ARRAY, the size of that ARRAY is
 * returned. When the VARIANT value does not contain an ARRAY, the result is NULL.
 *
 * <p>Example: calculate the size of the array stored in a column
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
 * df.select(Functions.size(Functions.col("arr"))).show();
 * -------------------------
 * |"ARRAY_SIZE(""ARR"")"  |
 * -------------------------
 * |3                      |
 * -------------------------
 * }</pre>
 *
 * @since 1.14.0
 * @param col The column containing the ARRAY whose size is requested
 * @return A Column expression holding the size of the input ARRAY.
 */
public static Column size(Column col) {
  // Thin alias: the work is done entirely by array_size.
  final Column elementCount = array_size(col);
  return elementCount;
}

/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
67 changes: 67 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3140,6 +3140,73 @@ object functions {
*/
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
 * Creates a sort expression that orders the named column in descending order.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
 *   df.sort(desc("id")).show()
 *
 *   --------
 *   |"ID"  |
 *   --------
 *   |3     |
 *   |2     |
 *   |1     |
 *   --------
 * }}}
 *
 * @since 1.14.0
 * @param colName Name of the column to sort by.
 * @return Column expression that sorts `colName` in descending order.
 */
def desc(colName: String): Column = {
  // Resolve the name to a Column first, then request its descending sort order.
  val target = col(colName)
  target.desc
}

/**
 * Creates a sort expression that orders the named column in ascending order.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(3, 2, 1)).toDF("id")
 *   df.sort(asc("id")).show()
 *
 *   --------
 *   |"ID"  |
 *   --------
 *   |1     |
 *   |2     |
 *   |3     |
 *   --------
 * }}}
 *
 * @since 1.14.0
 * @param colName Name of the column to sort by.
 * @return Column expression that sorts `colName` in ascending order.
 */
def asc(colName: String): Column = {
  // Resolve the name to a Column first, then request its ascending sort order.
  val target = col(colName)
  target.asc
}

/**
 * Returns the number of elements in the input ARRAY.
 *
 * When the column holds a VARIANT value that contains an ARRAY, the size of that ARRAY is
 * returned. When the VARIANT value does not contain an ARRAY, the result is NULL.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id")
 *   df.select(size(col("id"))).show()
 *
 *   ------------------------
 *   |"ARRAY_SIZE(""ID"")"  |
 *   ------------------------
 *   |3                     |
 *   ------------------------
 * }}}
 *
 * @since 1.14.0
 * @param c Column containing the ARRAY whose size is requested.
 * @return Column expression holding the size of the input ARRAY.
 */
def size(c: Column): Column = {
  // Thin alias: the work is done entirely by array_size.
  val elementCount = array_size(c)
  elementCount
}

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
26 changes: 26 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2764,4 +2764,30 @@ public void any_value() {
assert result.length == 1;
assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3;
}

@Test
public void test_asc() {
  // Start from deliberately unsorted input so the ascending sort has work to do.
  DataFrame unsorted = getSession().sql("select * from values(3),(1),(2) as t(a)");
  DataFrame ascending = unsorted.sort(Functions.asc("a"));

  Row[] sortedRows = new Row[] {Row.create(1), Row.create(2), Row.create(3)};

  // NOTE(review): the trailing `false` is presumably the "sort before comparing" flag, so the
  // rows are compared in the order returned - confirm against checkAnswer's signature.
  checkAnswer(ascending, sortedRows, false);
}

@Test
public void test_desc() {
  // Start from deliberately unsorted input so the descending sort has work to do.
  DataFrame unsorted = getSession().sql("select * from values(2),(1),(3) as t(a)");
  DataFrame descending = unsorted.sort(Functions.desc("a"));

  Row[] sortedRows = new Row[] {Row.create(3), Row.create(2), Row.create(1)};

  // NOTE(review): the trailing `false` is presumably the "sort before comparing" flag, so the
  // rows are compared in the order returned - confirm against checkAnswer's signature.
  checkAnswer(descending, sortedRows, false);
}

@Test
public void test_size() {
  // One row holding a three-element array built from the columns a, b and c.
  String query = "select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)";
  DataFrame arrays = getSession().sql(query);
  DataFrame sizes = arrays.select(Functions.size(Functions.col("arr")));

  // The array has three elements, so a single row containing 3 is expected.
  Row[] expectedSizes = new Row[] {Row.create(3)};
  checkAnswer(sizes, expectedSizes, false);
}
}
29 changes: 29 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2178,6 +2178,35 @@ trait FunctionSuite extends TestData {
sort = false)
}

test("desc column order") {
  // Integer column: sorting with desc must return 3, 2, 1.
  // sort = false: checkAnswer presumably compares rows in the order returned, which is
  // what makes this an ordering test.
  val numbers = Seq(1, 2, 3).toDF("data")
  val numbersDescending = Seq(3, 2, 1).toDF("data")
  checkAnswer(numbers.sort(desc("data")), numbersDescending, sort = false)

  // String column: sorting with desc must return c, b, a.
  val letters = Seq("a", "b", "c").toDF("dataStr")
  val lettersDescending = Seq("c", "b", "a").toDF("dataStr")
  checkAnswer(letters.sort(desc("dataStr")), lettersDescending, sort = false)
}

test("asc column order") {
  // Integer column: sorting with asc must return 1, 2, 3.
  // sort = false: checkAnswer presumably compares rows in the order returned, which is
  // what makes this an ordering test.
  val numbers = Seq(3, 2, 1).toDF("data")
  val numbersAscending = Seq(1, 2, 3).toDF("data")
  checkAnswer(numbers.sort(asc("data")), numbersAscending, sort = false)

  // String column: sorting with asc must return a, b, c.
  val letters = Seq("c", "b", "a").toDF("dataStr")
  val lettersAscending = Seq("a", "b", "c").toDF("dataStr")
  checkAnswer(letters.sort(asc("dataStr")), lettersAscending, sort = false)
}

test("column array size") {
  // A single row whose "size" column holds a three-element array.
  val arrays = Seq(Array(1, 2, 3)).toDF("size")
  val arraySizes = arrays.select(size(col("size")))

  // The array has three elements, so one row containing 3 is expected.
  val expectedSizes = Seq(3).toDF("size")
  checkAnswer(arraySizes, expectedSizes, sort = false)
}

}

/**
 * Concrete test suite that runs every test declared in [[FunctionSuite]] with the
 * [[EagerSession]] trait mixed in.
 *
 * NOTE(review): presumably EagerSession configures the test session for eager execution;
 * confirm against its definition.
 */
class EagerFunctionSuite extends FunctionSuite with EagerSession
Expand Down

0 comments on commit d2339f4

Please sign in to comment.