Merge changes from fork to feature branch #138

Merged
100 changes: 100 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
@@ -2,6 +2,7 @@

import static com.snowflake.snowpark.internal.OpenTelemetry.javaUDF;

import com.snowflake.snowpark.functions;
import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark_java.types.DataType;
import com.snowflake.snowpark_java.udf.*;
@@ -3880,6 +3881,105 @@ public static Column listagg(Column col) {
return new Column(com.snowflake.snowpark.functions.listagg(col.toScalaColumn()));
}

/**
* Returns a Column expression with values sorted in descending order.
*
* <p>Example: order column values in descending order
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)");
* df.sort(Functions.desc("a")).show();
* -------
* |"A" |
* -------
* |3 |
* |2 |
* |1 |
* -------
* }</pre>
*
* @since 1.14.0
* @param name The input column name
* @return Column object ordered in a descending manner.
*/
public static Column desc(String name) {
return new Column(functions.desc(name));
}

/**
* Returns a Column expression with values sorted in ascending order.
*
* <p>Example: order column values in ascending order
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
* df.sort(Functions.asc("a")).show();
* -------
* |"A" |
* -------
* |1 |
* |2 |
* |3 |
* -------
* }</pre>
*
* @since 1.14.0
* @param name The input column name
* @return Column object ordered in an ascending manner.
*/
public static Column asc(String name) {
return new Column(functions.asc(name));
}

/**
* Returns the size of the input ARRAY.
*
* <p>If the specified column contains a VARIANT value that contains an ARRAY, the size of the
* ARRAY is returned; otherwise, NULL is returned if the value is not an ARRAY.
*
* <p>Example: calculate size of the array in a column
*
* <pre>{@code
* DataFrame df = getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
* df.select(Functions.size(Functions.col("arr"))).show();
* -------------------------
* |"ARRAY_SIZE(""ARR"")" |
* -------------------------
* |3 |
* -------------------------
* }</pre>
*
* @since 1.14.0
* @param col The input ARRAY column.
* @return size of the input ARRAY.
*/
public static Column size(Column col) {
return array_size(col);
}

/**
* Creates a Column expression from raw SQL text.
*
* <p>Note that the function does not interpret or check the SQL text.
*
* <pre>{@code
* DataFrame df = getSession().sql("select a from values(1), (2), (3) as T(a)");
* df.filter(Functions.expr("a > 2")).show();
* -------
* |"A" |
* -------
* |3 |
* -------
* }</pre>
*
* @since 1.14.0
* @param s The SQL expression text.
* @return Column expression created from the input SQL text.
*/
public static Column expr(String s) {
return sqlExpr(s);
}

/**
* Calls a user-defined function (UDF) by name.
*
183 changes: 177 additions & 6 deletions src/main/scala/com/snowflake/snowpark/functions.scala
@@ -2,12 +2,8 @@ package com.snowflake.snowpark

import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.internal.ScalaFunctions._
import com.snowflake.snowpark.internal.{
ErrorMessage,
OpenTelemetry,
UDXRegistrationHandler,
Utils
}
import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils}
import com.snowflake.snowpark.types.TimestampType

import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
@@ -3140,6 +3136,181 @@ object functions {
*/
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
* Returns a Column expression with values sorted in descending order.
* Example:
* {{{
* val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
* df.sort(desc("id")).show()
*
* --------
* |"ID" |
* --------
* |3 |
* |2 |
* |1 |
* --------
* }}}
*
* @since 1.14.0
* @param colName Column name.
* @return Column object ordered in a descending manner.
*/
def desc(colName: String): Column = col(colName).desc

/**
* Returns a Column expression with values sorted in ascending order.
* Example:
* {{{
* val df = session.createDataFrame(Seq(3, 2, 1)).toDF("id")
* df.sort(asc("id")).show()
*
* --------
* |"ID" |
* --------
* |1 |
* |2 |
* |3 |
* --------
* }}}
* @since 1.14.0
* @param colName Column name.
* @return Column object ordered in an ascending manner.
*/
def asc(colName: String): Column = col(colName).asc

/**
* Returns the size of the input ARRAY.
*
* If the specified column contains a VARIANT value that contains an ARRAY, the size of the ARRAY
* is returned; otherwise, NULL is returned if the value is not an ARRAY.
*
* Example:
* {{{
* val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id")
* df.select(size(col("id"))).show()
*
* ------------------------
* |"ARRAY_SIZE(""ID"")" |
* ------------------------
* |3 |
* ------------------------
* }}}
*
* @since 1.14.0
* @param c Column to get the size.
* @return Size of array column.
*/
def size(c: Column): Column = array_size(c)

/**
* Creates a [[Column]] expression from raw SQL text.
*
* Note that the function does not interpret or check the SQL text.
*
* Example:
* {{{
* val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id")
* df.filter(expr("id > 2")).show()
*
* --------
* |"ID" |
* --------
* |3 |
* --------
* }}}
*
* @since 1.14.0
* @param s SQL Expression as text.
* @return Converted SQL Expression.
*/
def expr(s: String): Column = sqlExpr(s)

/**
* Wrapper for the Snowflake built-in ARRAY_CONSTRUCT function. Creates an array from the given columns.
*
* Example:
* {{{
* val df = session.createDataFrame(Seq((1, 2, 3), (4, 5, 6))).toDF("a", "b", "c")
* df.select(array(col("a"), col("b")).as("id")).show()
*
* --------
* |"ID" |
* --------
* |[ |
* | 1, |
* | 2 |
* |] |
* |[ |
* | 4, |
* | 5 |
* |] |
* --------
* }}}
*
* @since 1.14.0
* @param c Columns to build the array.
* @return The array.
*/
def array(c: Column*): Column = array_construct(c: _*)

/**
* Formats a date or timestamp into the specified string format.
* This is a wrapper around Snowflake's built-in TO_VARCHAR function.
* Example:
* {{{
* val df = Seq("2023-10-10", "2022-05-15", null.asInstanceOf[String]).toDF("date")
* df.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")).show()
*
* --------------------
* |"FORMATTED_DATE" |
* --------------------
* |2023/10/10 |
* |2022/05/15 |
* |NULL |
* --------------------
*
* }}}
*
* @since 1.14.0
* @param c Date or timestamp column to format.
* @param s Date format string.
* @return Column object.
*/
def date_format(c: Column, s: String): Column =
builtin("to_varchar")(c.cast(TimestampType), s.replace("mm", "mi"))

/**
* Wrapper for the Snowflake built-in LAST_VALUE function.
* Returns the last value of a column within its window partition.
* Note that, unlike plain SQL, Snowpark requires an explicit ORDER BY in the
* window specification, since SQL does not guarantee row order otherwise.
* Example:
* {{{
* val df = session.createDataFrame(Seq((5, "a", 10),
* (5, "b", 20),
* (3, "d", 15),
* (3, "e", 40))).toDF("grade", "name", "score")
* val window = Window.partitionBy(col("grade")).orderBy(col("score").desc)
* df.select(last(col("name")).over(window)).show()
*
* ---------------------
* |"LAST_SCORE_NAME" |
* ---------------------
* |a |
* |a |
* |d |
* |d |
* ---------------------
* }}}
*
* @since 1.14.0
* @param c Column from which to obtain the last value.
* @return Column object.
*/
def last(c: Column): Column =
builtin("LAST_VALUE")(c)

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
34 changes: 34 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
@@ -2764,4 +2764,38 @@ public void any_value() {
assert result.length == 1;
assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3;
}

@Test
public void test_asc() {
DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
Row[] expected = {Row.create(1), Row.create(2), Row.create(3)};

checkAnswer(df.sort(Functions.asc("a")), expected, false);
}

@Test
public void test_desc() {
DataFrame df = getSession().sql("select * from values(2),(1),(3) as t(a)");
Row[] expected = {Row.create(3), Row.create(2), Row.create(1)};

checkAnswer(df.sort(Functions.desc("a")), expected, false);
}

@Test
public void test_size() {
DataFrame df = getSession()
.sql(
"select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
Row[] expected = {Row.create(3)};

checkAnswer(df.select(Functions.size(Functions.col("arr"))), expected, false);
}

@Test
public void test_expr() {
DataFrame df = getSession().sql("select a from values(1), (2), (3) as T(a)");
Row[] expected = {Row.create(3)};
checkAnswer(df.filter(Functions.expr("a > 2")), expected, false);
}

}