Skip to content

Commit

Permalink
Merge changes from fork to feature branch (#138)
Browse files Browse the repository at this point in the history
* add java and scala size and ordering functions

* add scala unit test for ordering and size function

* update comments and add example

* add java test cases

* fix comments

* add expr function for java and scala

* add formatting functions scala

* remove format_string func

---------

Co-authored-by: sfc-gh-mrojas <[email protected]>
  • Loading branch information
sfc-gh-gmahadevan and sfc-gh-mrojas authored Aug 7, 2024
1 parent 71c3d1b commit 9fdd791
Show file tree
Hide file tree
Showing 4 changed files with 375 additions and 6 deletions.
100 changes: 100 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static com.snowflake.snowpark.internal.OpenTelemetry.javaUDF;

import com.snowflake.snowpark.functions;
import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark_java.types.DataType;
import com.snowflake.snowpark_java.udf.*;
Expand Down Expand Up @@ -3880,6 +3881,105 @@ public static Column listagg(Column col) {
return new Column(com.snowflake.snowpark.functions.listagg(col.toScalaColumn()));
}

/**
 * Sorts a column's values in descending order and returns the resulting Column expression.
 *
 * <p>Example: sort the rows of a DataFrame by column {@code a}, largest first
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)");
 * df.sort(Functions.desc("a")).show();
 * -------
 * |"A" |
 * -------
 * |3 |
 * |2 |
 * |1 |
 * -------
 * }</pre>
 *
 * @since 1.14.0
 * @param name The input column name
 * @return Column object ordered in descending manner.
 */
public static Column desc(String name) {
  // Delegate to the Scala implementation, fully qualified like listagg above.
  return new Column(com.snowflake.snowpark.functions.desc(name));
}

/**
 * Sorts a column's values in ascending order and returns the resulting Column expression.
 *
 * <p>Example: sort the rows of a DataFrame by column {@code a}, smallest first
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
 * df.sort(Functions.asc("a")).show();
 * -------
 * |"A" |
 * -------
 * |1 |
 * |2 |
 * |3 |
 * -------
 * }</pre>
 *
 * @since 1.14.0
 * @param name The input column name
 * @return Column object ordered in ascending manner.
 */
public static Column asc(String name) {
  // Delegate to the Scala implementation, fully qualified like listagg above.
  return new Column(com.snowflake.snowpark.functions.asc(name));
}

/**
 * Returns the size of the input ARRAY.
 *
 * <p>If the specified column contains a VARIANT value that contains an ARRAY, the size of the
 * ARRAY is returned; otherwise, NULL is returned if the value is not an ARRAY.
 *
 * <p>Example: compute the number of elements of an ARRAY column
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
 * df.select(Functions.size(Functions.col("arr"))).show();
 * -------------------------
 * |"ARRAY_SIZE(""ARR"")" |
 * -------------------------
 * |3 |
 * -------------------------
 * }</pre>
 *
 * @since 1.14.0
 * @param col The input ARRAY column
 * @return size of the input ARRAY.
 */
public static Column size(Column col) {
  // size is an alias of ARRAY_SIZE; reuse the existing wrapper.
  Column arraySize = array_size(col);
  return arraySize;
}

/**
 * Creates a Column expression from raw SQL text.
 *
 * <p>Note that the function does not interpret or check the SQL text.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select a from values(1), (2), (3) as T(a)");
 * df.filter(Functions.expr("a > 2")).show();
 * -------
 * |"A" |
 * -------
 * |3 |
 * -------
 * }</pre>
 *
 * @since 1.14.0
 * @param s The raw SQL text of the expression
 * @return column expression from input statement.
 */
public static Column expr(String s) {
  // expr is an alias of sqlExpr; the text is passed through unvalidated.
  Column expression = sqlExpr(s);
  return expression;
}

/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
183 changes: 177 additions & 6 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@ package com.snowflake.snowpark

import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.internal.ScalaFunctions._
import com.snowflake.snowpark.internal.{
ErrorMessage,
OpenTelemetry,
UDXRegistrationHandler,
Utils
}
import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils}
import com.snowflake.snowpark.types.TimestampType

import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
Expand Down Expand Up @@ -3140,6 +3136,181 @@ object functions {
*/
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
 * Sorts a column's values in descending order and returns the resulting Column expression.
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
 *   df.sort(desc("id")).show()
 *
 * --------
 * |"ID"  |
 * --------
 * |3     |
 * |2     |
 * |1     |
 * --------
 * }}}
 *
 * @since 1.14.0
 * @param colName Column name.
 * @return Column object ordered in a descending manner.
 */
def desc(colName: String): Column = {
  // Resolve the name to a Column, then flip it to descending sort order.
  val resolved = col(colName)
  resolved.desc
}

/**
 * Sorts a column's values in ascending order and returns the resulting Column expression.
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(3, 2, 1)).toDF("id")
 *   df.sort(asc("id")).show()
 *
 * --------
 * |"ID"  |
 * --------
 * |1     |
 * |2     |
 * |3     |
 * --------
 * }}}
 * @since 1.14.0
 * @param colName Column name.
 * @return Column object ordered in an ascending manner.
 */
def asc(colName: String): Column = {
  // Resolve the name to a Column, then mark it for ascending sort order.
  val resolved = col(colName)
  resolved.asc
}

/**
 * Returns the size of the input ARRAY.
 *
 * If the specified column contains a VARIANT value that contains an ARRAY, the size of the ARRAY
 * is returned; otherwise, NULL is returned if the value is not an ARRAY.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id")
 *   df.select(size(col("id"))).show()
 *
 * ------------------------
 * |"ARRAY_SIZE(""ID"")"  |
 * ------------------------
 * |3                     |
 * ------------------------
 * }}}
 *
 * @since 1.14.0
 * @param c Column to get the size.
 * @return Size of array column.
 */
def size(c: Column): Column = {
  // size is an alias of ARRAY_SIZE; reuse the existing wrapper.
  array_size(c)
}

/**
 * Creates a [[Column]] expression from raw SQL text.
 *
 * Note that the function does not interpret or check the SQL text.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id")
 *   df.filter(expr("id > 2")).show()
 *
 * --------
 * |"ID"  |
 * --------
 * |3     |
 * --------
 * }}}
 *
 * @since 1.14.0
 * @param s SQL Expression as text.
 * @return Converted SQL Expression.
 */
def expr(s: String): Column = {
  // expr is an alias of sqlExpr; the text is passed through unvalidated.
  sqlExpr(s)
}

/**
 * Wrapper for the Snowflake built-in array function. Creates an ARRAY from the given columns.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq((1, 2), (4, 5))).toDF("a", "b")
 *   df.select(array(col("a"), col("b")).as("id")).show()
 *
 * --------
 * |"ID"  |
 * --------
 * |[     |
 * |  1,  |
 * |  2   |
 * |]     |
 * |[     |
 * |  4,  |
 * |  5   |
 * |]     |
 * --------
 * }}}
 *
 * @since 1.14.0
 * @param c Columns to build the array.
 * @return The array.
 */
def array(c: Column*): Column = {
  // array is an alias of ARRAY_CONSTRUCT; forward the varargs unchanged.
  val elements = c
  array_construct(elements: _*)
}

/**
 * Wrapper for the Snowflake built-in date-formatting function.
 * Converts a date/timestamp value into a string with the specified format.
 * Example:
 * {{{
 *   val df = Seq("2023-10-10", "2022-05-15", null.asInstanceOf[String]).toDF("date")
 *   df.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")).show()
 *
 * --------------------
 * |"FORMATTED_DATE"  |
 * --------------------
 * |2023/10/10        |
 * |2022/05/15        |
 * |NULL              |
 * --------------------
 *
 * }}}
 *
 * @since 1.14.0
 * @param c Column to format to date.
 * @param s Date format.
 * @return Column object.
 */
def date_format(c: Column, s: String): Column = {
  // Rewrite lowercase "mm" (Spark-style minutes) to "mi" before handing the
  // format to TO_VARCHAR. The replacement is case-sensitive, so "MM" (month)
  // is left untouched. NOTE(review): a literal lowercase "mm" meant as month
  // would also be rewritten — presumably callers follow the Spark convention.
  val snowflakeFormat = s.replace("mm", "mi")
  builtin("to_varchar")(c.cast(TimestampType), snowflakeFormat)
}

/**
 * Wrapper for the Snowflake built-in LAST_VALUE function.
 * Returns the last value of a column within its window partition.
 * Functional difference from other engines: in Snowpark an ORDER BY on the
 * window is required, because SQL does not guarantee row order otherwise.
 * Example
 * {{{
 *   val df = session.createDataFrame(Seq((5, "a", 10),
 *                                        (5, "b", 20),
 *                                        (3, "d", 15),
 *                                        (3, "e", 40))).toDF("grade", "name", "score")
 *   val window = Window.partitionBy(col("grade")).orderBy(col("score").desc)
 *   df.select(last(col("name")).over(window)).show()
 *
 * ---------------------
 * |"LAST_SCORE_NAME"  |
 * ---------------------
 * |a                  |
 * |a                  |
 * |d                  |
 * |d                  |
 * ---------------------
 * }}}
 *
 * @since 1.14.0
 * @param c Column to obtain last value.
 * @return Column object.
 */
def last(c: Column): Column = {
  // LAST_VALUE is the Snowflake window function for the final value in a frame.
  builtin("LAST_VALUE")(c)
}

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
34 changes: 34 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2764,4 +2764,38 @@ public void any_value() {
assert result.length == 1;
assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3;
}

@Test
public void test_asc() {
  // asc("a") must order rows smallest-first regardless of insertion order.
  Row[] expected = {Row.create(1), Row.create(2), Row.create(3)};
  DataFrame sorted =
      getSession().sql("select * from values(3),(1),(2) as t(a)").sort(Functions.asc("a"));
  checkAnswer(sorted, expected, false);
}

@Test
public void test_desc() {
  // desc("a") must order rows largest-first regardless of insertion order.
  Row[] expected = {Row.create(3), Row.create(2), Row.create(1)};
  DataFrame sorted =
      getSession().sql("select * from values(2),(1),(3) as t(a)").sort(Functions.desc("a"));
  checkAnswer(sorted, expected, false);
}

@Test
public void test_size() {
  // size(...) on a 3-element ARRAY column must yield 3.
  Row[] expected = {Row.create(3)};
  DataFrame df =
      getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
  DataFrame sized = df.select(Functions.size(Functions.col("arr")));
  checkAnswer(sized, expected, false);
}

@Test
public void test_expr() {
  // expr("a > 2") must behave as a raw SQL predicate when used in filter.
  Row[] expected = {Row.create(3)};
  DataFrame filtered =
      getSession().sql("select a from values(1), (2), (3) as T(a)").filter(Functions.expr("a > 2"));
  checkAnswer(filtered, expected, false);
}

}
Loading

0 comments on commit 9fdd791

Please sign in to comment.