Skip to content

Commit

Permalink
SNOW-802269 - Add missing scala and java functions (#139)
Browse files Browse the repository at this point in the history
* Merge changes from fork to feature branch (#138)

* add java and scala size and ordering functions

* add scala unit test for ordering and size function

* update comments and add example

* add java test cases

* fix comments

* add expr function for java and scala

* add formatting functions scala

* remove format_string func

---------

Co-authored-by: sfc-gh-mrojas <[email protected]>

* add java function and test case

* fix test case

* fix test file import

* fix test file import

* fix docs

---------

Co-authored-by: sfc-gh-mrojas <[email protected]>
  • Loading branch information
sfc-gh-gmahadevan and sfc-gh-mrojas authored Aug 8, 2024
1 parent d26eb08 commit f7647e4
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 6 deletions.
95 changes: 95 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -3957,6 +3957,101 @@ public static Column size(Column col) {
return array_size(col);
}

/**
* Creates a Column expression from row SQL text.
*
* <p>Note that the function does not interpret or check the SQL text.
*
* <pre>{@code
* DataFrame df = getSession().sql("select a from values(1), (2), (3) as T(a)");
* df.filter(Functions.expr("a > 2")).show();
* -------
* |"A" |
* -------
* |3 |
* -------
* }</pre>
*
* @since 1.14.0
* @param s The SQL text
* @return column expression from input statement.
*/
public static Column expr(String s) {
return sqlExpr(s);
}

/**
* Returns an ARRAY constructed from zero, one, or more inputs.
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values(1,2,3) as T(a,b,c)");
* df.select(Functions.array(df.col("a"), df.col("b"), df.col("c")).as("array")).show();
*-----------
* |"ARRAY" |
* -----------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* -----------
* }</pre>
*
* @since 1.14.0
* @param cols The input column names
* @return Column object as array.
*/
public static Column array(Column... cols) { return array_construct(cols); }

/**
*
* Converts an input expression into the corresponding date in the specified date format.
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values ('2023-10-10'), ('2022-05-15') as T(a)");
* df.select(Functions.date_format(df.col("a"), "YYYY/MM/DD").as("formatted_date")).show();
* --------------------
* |"FORMATTED_DATE" |
* --------------------
* |2023/10/10 |
* |2022/05/15 |
* --------------------
* }</pre>
*
* @since 1.14.0
* @param col The input date column name
* @param s string format
* @return formatted column object.
*/
public static Column date_format(Column col, String s) {
return new Column(functions.date_format(col.toScalaColumn(), s));
}

/**
* Returns the last value of the column in a group.
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values (5, 'a', 10), (5, 'b', 20),\n" +
* " (3, 'd', 15), (3, 'e', 40) as T(grade,name,score)");
* df.select(Functions.last(df.col("name")).over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))).show();
* ----------------
* |"LAST_VALUE" |
* ----------------
* |a |
* |a |
* |d |
* |d |
* ----------------
* }</pre>
*
* @since 1.14.0
* @param col The input column to get last value
* @return column object from last function.
*/
public static Column last(Column col) {
return new Column(functions.last(col.toScalaColumn()));
}

/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
112 changes: 106 additions & 6 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@ package com.snowflake.snowpark

import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.internal.ScalaFunctions._
import com.snowflake.snowpark.internal.{
ErrorMessage,
OpenTelemetry,
UDXRegistrationHandler,
Utils
}
import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils}
import com.snowflake.snowpark.types.TimestampType

import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
Expand Down Expand Up @@ -3207,6 +3203,110 @@ object functions {
*/
def size(c: Column): Column = array_size(c)

/**
* Creates a [[Column]] expression from raw SQL text.
*
* Note that the function does not interpret or check the SQL text.
*
* Example:
* {{{
* val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id")
* df.filter(expr("id > 2")).show()
*
* --------
* |"ID" |
* --------
* |3 |
* --------
* }}}
*
* @since 1.14.0
* @param s SQL Expression as text.
* @return Converted SQL Expression.
*/
def expr(s: String): Column = sqlExpr(s)

/**
* Returns an ARRAY constructed from zero, one, or more inputs.
*
* Example:
* {{{
* val df = session.createDataFrame(Seq((1, 2, 3), (4, 5, 6))).toDF("id")
* df.select(array(col("a"), col("b")).as("id")).show()
*
* --------
* |"ID" |
* --------
* |[ |
* | 1, |
* | 2 |
* |] |
* |[ |
* | 4, |
* | 5 |
* |] |
* --------
* }}}
*
* @since 1.14.0
* @param c Columns to build the array.
* @return The array.
*/
def array(c: Column*): Column = array_construct(c: _*)

/**
* Converts an input expression into the corresponding date in the specified date format.
* Example:
* {{{
* val df = Seq("2023-10-10", "2022-05-15", null.asInstanceOf[String]).toDF("date")
* df.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")).show()
*
* --------------------
* |"FORMATTED_DATE" |
* --------------------
* |2023/10/10 |
* |2022/05/15 |
* |NULL |
* --------------------
*
* }}}
*
* @since 1.14.0
* @param c Column to format to date.
* @param s Date format.
* @return Column object.
*/
def date_format(c: Column, s: String): Column =
builtin("to_varchar")(c.cast(TimestampType), s.replace("mm", "mi"))

/**
* Returns the last value of the column in a group.
* Example
* {{{
* val df = session.createDataFrame(Seq((5, "a", 10),
* (5, "b", 20),
* (3, "d", 15),
* (3, "e", 40))).toDF("grade", "name", "score")
* val window = Window.partitionBy(col("grade")).orderBy(col("score").desc)
* df.select(last(col("name")).over(window)).show()
*
* ---------------------
* |"LAST_SCORE_NAME" |
* ---------------------
* |a |
* |a |
* |d |
* |d |
* ---------------------
* }}}
*
* @since 1.14.0
* @param c Column to obtain last value.
* @return Column object.
*/
def last(c: Column): Column =
builtin("LAST_VALUE")(c)

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
31 changes: 31 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2790,4 +2790,35 @@ public void test_size() {

checkAnswer(df.select(Functions.size(Functions.col("arr"))), expected, false);
}

@Test
public void test_expr() {
DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)");
Row[] expected = {Row.create(3)};
checkAnswer(df.filter(Functions.expr("a > 2")), expected, false);
}

@Test
public void test_array() {
DataFrame df = getSession().sql("select * from values(1,2,3) as T(a,b,c)");
Row[] expected = {Row.create("[\n 1,\n 2,\n 3\n]")};
checkAnswer(df.select(Functions.array(df.col("a"), df.col("b"), df.col("c"))), expected, false);
}

@Test
public void date_format() {
DataFrame df = getSession().sql("select * from values ('2023-10-10'), ('2022-05-15') as T(a)");
Row[] expected = {Row.create("2023/10/10"), Row.create("2022/05/15")};

checkAnswer(df.select(Functions.date_format(df.col("a"), "YYYY/MM/DD")), expected, false);
}

@Test
public void last() {
DataFrame df = getSession().sql("select * from values (5, 'a', 10), (5, 'b', 20),\n" +
" (3, 'd', 15), (3, 'e', 40) as T(grade,name,score)");

Row[] expected = {Row.create("a"), Row.create("a"), Row.create("d"), Row.create("d")};
checkAnswer(df.select(Functions.last(df.col("name")).over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))), expected, false);
}
}
34 changes: 34 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2206,6 +2206,40 @@ trait FunctionSuite extends TestData {
val expected = Seq((3)).toDF("size")
checkAnswer(input.select(size(col("size"))), expected, sort = false)
}

test("expr function") {

val input = Seq(1, 2, 3).toDF("id")
val expected = Seq((3)).toDF("id")
checkAnswer(input.filter(expr("id > 2")), expected, sort = false)
}

test("array function") {

val input = Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c")
val expected = Seq(Array(1, 2), Array(4, 5)).toDF("id")
checkAnswer(input.select(array(col("a"), col("b")).as("id")), expected, sort = false)
}

test("date format function") {

val input = Seq("2023-10-10", "2022-05-15").toDF("date")
val expected = Seq("2023/10/10", "2022/05/15").toDF("formatted_date")

checkAnswer(input.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")),
expected, sort = false)
}

test("last function") {

val input = Seq((5, "a", 10), (5, "b", 20),
(3, "d", 15), (3, "e", 40)).toDF("grade", "name", "score")
val window = Window.partitionBy(col("grade")).orderBy(col("score").desc)
val expected = Seq("a", "a", "d", "d").toDF("last_score_name")

checkAnswer(input.select(last(col("name")).over(window).as("last_score_name")),
expected, sort = false)
}

}

Expand Down

0 comments on commit f7647e4

Please sign in to comment.