diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 1d18a91c..78c64418 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -3882,7 +3882,127 @@ public static Column listagg(Column col) { } /** - * Returns a Column expression with values sorted in descending order. + * Signature - snowflake.snowpark.functions.regexp_extract (value: Union[Column, str], regexp: + * Union[Column, str], idx: int) Column Extract a specific group matched by a regex, from the + * specified string column. If the regex did not match, or the specified group did not match, an + * empty string is returned. Example: + * + *
{@code
+   * from snowflake.snowpark.functions import regexp_extract
+   * df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"])
+   * df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
+   *    ---------
+   *     |"RES"  |
+   *     ---------
+   *     |20     |
+   *     |40     |
+   *     ---------
+   * }
+ * + * @since 1.14.0 + * @param col Column. + * @param exp String + * @param position Integer. + * @param Occurences Integer. + * @param grpIdx Integer. + * @return Column object. + */ + public static Column regexp_extract( + Column col, String exp, Integer position, Integer Occurences, Integer grpIdx) { + return new Column( + com.snowflake.snowpark.functions.regexp_extract( + col.toScalaColumn(), exp, position, Occurences, grpIdx)); + } + + /** + * Returns the sign of its argument: + * + *

- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0. + * + *

Args: col: The column to evaluate its sign Example:: * + * + *

{@code df =
+   * session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
+   * df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
+   * sign("c").alias("c_sign")).show()
+   *   ----------------------------------
+   *     |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+   *     ----------------------------------
+   *     |-1        |1         |0         |
+   *     ----------------------------------
+   * }
+ * + * @since 1.14.0 + * @param col Column to calculate the sign. + * @return Column object. + */ + public static Column signum(Column col) { + return new Column(com.snowflake.snowpark.functions.signum(col.toScalaColumn())); + } + + /** + * Returns the sign of its argument: + * + *

- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0. + * + *

Args: col: The column to evaluate its sign Example:: + * + *

{@code df =
+   * session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
+   * df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
+   * sign("c").alias("c_sign")).show()
+   *   ----------------------------------
+   *     |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+   *     ----------------------------------
+   *     |-1        |1         |0         |
+   *     ----------------------------------
+   * }
+ * + * @since 1.14.0 + * @param col Column to calculate the sign. + * @return Column object. + */ + public static Column sign(Column col) { + return new Column(com.snowflake.snowpark.functions.sign(col.toScalaColumn())); + } + + /** + * Returns the substring from string str before count occurrences of the delimiter delim. If count + * is positive, everything the left of the final delimiter (counting from left) is returned. If + * count is negative, every to the right of the final delimiter (counting from the right) is + * returned. substring_index performs a case-sensitive match when searching for delim. + * + * @param col String. + * @param delim String + * @param count Integer. + * @return Column object. + * @since 1.14.0 + */ + public static Column substring_index(String col, String delim, Integer count) { + return new Column(com.snowflake.snowpark.functions.substring_index(col, delim, count)); + } + + /** + * Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is + * returned. + * + *

Example:: + * + *

{@code
+   * df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
+   * df.select(array_agg("a", True).alias("result")).show()
+   * "RESULT" [ 1, 2, 3 ]
+   * }
+ * + * @since 1.14.0 + * @param c Column to be collect. + * @return The array. + */ + public static Column collect_list(Column c) { + return new Column(com.snowflake.snowpark.functions.collect_list(c.toScalaColumn())); + } + + /* Returns a Column expression with values sorted in descending order. * *

Example: order column values in descending * @@ -4180,6 +4300,131 @@ public static Column unbase64(Column c) { return new Column(functions.unbase64(c.toScalaColumn())); } + /** + * Locate the position of the first occurrence of substr in a string column, after position pos. + * + *

{@code
+   * DataFrame df = getSession().sql("select * from values ('scala', 'java scala python'), \n " +
+   *             "('b', 'abcd') as T(a,b)");
+   * df.select(Functions.locate(Functions.col("a"), Functions.col("b"), 1).as("locate")).show();
+   * ------------
+   * |"LOCATE"  |
+   * ------------
+   * |6         |
+   * |2         |
+   * ------------
+   * }
+ * + * @since 1.14.0 + * @param substr string to search + * @param str value where string will be searched + * @param pos index for starting the search + * @return returns the position of the first occurrence. + */ + public static Column locate(Column substr, Column str, int pos) { + return new Column(functions.locate(substr.toScalaColumn(), str.toScalaColumn(), pos)); + } + + /** + * Locate the position of the first occurrence of substr in a string column, after position pos. + * default to 1. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values ('abcd') as T(s)");
+   * df.select(Functions.locate("b", Functions.col("s")).as("locate")).show();
+   * ------------
+   * |"LOCATE"  |
+   * ------------
+   * |2         |
+   * ------------
+   * }
+ * + * @since 1.14.0 + * @param substr string to search + * @param str value where string will be searched + * @return returns the position of the first occurrence. + */ + public static Column locate(String substr, Column str) { + return new Column(functions.locate(substr, str.toScalaColumn(), 1)); + } + + /** + * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second + * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + *

This is equivalent to the NTILE function in SQL. + * + *

{@code
+   * DataFrame df = getSession().sql("select * from values(1,2),(1,2),(2,1),(2,2),(2,2) as T(x,y)");
+   * df.select(Functions.ntile(4).over(Window.partitionBy(df.col("x")).orderBy(df.col("y"))).as("ntile")).show();
+   * -----------
+   * |"NTILE"  |
+   * -----------
+   * |1        |
+   * |2        |
+   * |3        |
+   * |1        |
+   * |2        |
+   * -----------
+   * }
+ * + * @since 1.14.0 + * @param n number of groups + * @return returns the ntile group id (from 1 to n inclusive) in an ordered window partition. + */ + public static Column ntile(int n) { + return new Column(functions.ntile(n)); + } + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples from the + * standard normal distribution. Return a call to the Snowflake RANDOM function. NOTE: Snowflake + * returns integers of 17-19 digits. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+   * df.withColumn("randn",Functions.randn()).select("randn").show();
+   * ------------------------
+   * |"RANDN"               |
+   * ------------------------
+   * |6799378361097866000   |
+   * |-7280487148628086605  |
+   * |775606662514393461    |
+   * ------------------------
+   * }
+ * + * @since 1.14.0 + * @return Random number. + */ + public static Column randn() { + return new Column(functions.randn()); + } + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples from the + * standard normal distribution. Return a call to the Snowflake RANDOM function. NOTE: Snowflake + * returns integers of 17-19 digits. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+   * df.withColumn("randn_with_seed",Functions.randn(123l)).select("randn_with_seed").show();
+   * ------------------------
+   * |"RANDN_WITH_SEED"     |
+   * ------------------------
+   * |5777523539921853504   |
+   * |-8190739547906189845  |
+   * |-1138438814981368515  |
+   * ------------------------
+   * }
+ * + * @since 1.14.0 + * @return Random number. + */ + public static Column randn(long seed) { + return new Column(functions.randn(seed)); + } + /** * Calls a user-defined function (UDF) by name. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index c20475d7..0090fb96 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -2853,26 +2853,183 @@ object functions { */ def listagg(col: Column): Column = listagg(col, "", isDistinct = false) - /** Returns a Column expression with values sorted in descending order. Example: - * {{{ - * val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id") - * df.sort(desc("id")).show() + /** Signature - snowflake.snowpark.functions.regexp_extract (value: Union[Column, str], regexp: + * Union[Column, str], idx: int) Column Extract a specific group matched by a regex, from the + * specified string column. If the regex did not match, or the specified group did not match, an + * empty string is returned. Example: from snowflake.snowpark.functions import regexp_extract + * df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"]) + * df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show() --------- + * \|"RES" | --------- + * | 20 | + * |:---| + * | 40 | + * --------- Note: non-greedy tokens such as are not supported + * @since 1.14.0 + * @return + * Column object. + */ + def regexp_extract( + colName: Column, + exp: String, + position: Int, + Occurences: Int, + grpIdx: Int): Column = { + when(colName.is_null, lit(null)) + .otherwise( + coalesce( + builtin("REGEXP_SUBSTR")( + colName, + lit(exp), + lit(position), + lit(Occurences), + lit("ce"), + lit(grpIdx)), + lit(""))) + } + + /** Returns the sign of its argument as mentioned : * - * -------- - * |"ID" | - * -------- - * |3 | - * |2 | - * |1 | - * -------- - * }}} + * - -1 if the argument is negative. + * - 1 if it is positive. + * - 0 if it is 0. * + * Args: col: The column to evaluate its sign Example:: >>> df = + * session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>> + * df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"), + * sign("c").alias("c_sign")).show() ---------------------------------- + * \|"A_SIGN" |"B_SIGN" |"C_SIGN" | ---------------------------------- + * \|-1 |1 |0 | ---------------------------------- * @since 1.14.0 - * @param colName - * Column name. + * @param e + * Column to calculate the sign. + * @return + * Column object. + */ + def sign(colName: Column): Column = { + builtin("SIGN")(colName) + } + + /** Returns the sign of its argument: + * + * - -1 if the argument is negative. + * - 1 if it is positive. + * - 0 if it is 0. + * + * Args: col: The column to evaluate its sign Example:: >>> df = + * session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>> + * df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"), + * sign("c").alias("c_sign")).show() ---------------------------------- + * \|"A_SIGN" |"B_SIGN" |"C_SIGN" | ---------------------------------- + * \|-1 |1 |0 | ---------------------------------- + * @since 1.14.0 + * @param e + * Column to calculate the sign. + * @return + * Column object. + */ + def signum(colName: Column): Column = { + builtin("SIGN")(colName) + } + + /** Returns the sign of the given column. Returns either 1 for positive, 0 for 0 or NaN, -1 for + * negative and null for null. NOTE: if string values are provided snowflake will attempts to + * cast. If it casts correctly, returns the calculation, if not an error will be thrown + * @since 1.14.0 + * @param columnName + * Name of the column to calculate the sign. * @return - * Column object ordered in a descending manner. + * Column object. + */ + def signum(columnName: String): Column = { + signum(col(columnName)) + } + + /** Returns the substring from string str before count occurrences of the delimiter delim. If + * count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. + * @since 1.14.0 */ + def substring_index(str: String, delim: String, count: Int): Column = { + when( + lit(count) < lit(0), + callBuiltin( + "substring", + lit(str), + callBuiltin( + "regexp_instr", + sqlExpr(s"reverse('${str}')"), + lit(delim), + 1, + abs(lit(count)), + lit(0)))) + .otherwise( + callBuiltin( + "substring", + lit(str), + 1, + callBuiltin("regexp_instr", lit(str), lit(delim), 1, lit(count), 1))) + } + + /** Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is + * returned. Example:: >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"]) + * >>> df.select(array_agg("a", True).alias("result")).show() ------------ + * \|"RESULT" | ------------ + * | [ | + * |:---| + * | 1, | + * | 2, | + * | 3 | + * | ] | + * ------------ + * @since 1.14.0 + * @param c + * Column to be collect. + * @return + * The array. + */ + def collect_list(c: Column): Column = array_agg(c) + + /** Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is + * returned. + * + * Example:: >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"]) >>> + * df.select(array_agg("a", True).alias("result")).show() ------------ + * \|"RESULT" | ------------ + * | [ | + * |:---| + * | 1, | + * | 2, | + * | 3 | + * | ] | + * ------------ + * @since 1.14.0 + * @param s + * Column name to be collected. + * @return + * The array. + */ + def collect_list(s: String): Column = array_agg(col(s)) + + /* Returns a Column expression with values sorted in descending order. + * Example: + * {{{ + * val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id") + * df.sort(desc("id")).show() + * + * -------- + * |"ID" | + * -------- + * |3 | + * |2 | + * |1 | + * -------- + * }}} + * + * @since 1.14.0 + * @param colName Column name. + * @return Column object ordered in a descending manner. + */ def desc(colName: String): Column = col(colName).desc /** Returns a Column expression with values sorted in ascending order. Example: @@ -3150,6 +3307,135 @@ object functions { */ def unbase64(col: Column): Column = callBuiltin("BASE64_DECODE_STRING", col) + /** Locate the position of the first occurrence of substr in a string column, after position pos. + * + * @note + * The position is not zero based, but 1 based index. returns 0 if substr could not be found in + * str. This function is just leverages the SF POSITION builtin Example + * {{{ + * val df = session.createDataFrame(Seq(("b", "abcd"))).toDF("a", "b") + * df.select(locate(col("a"), col("b"), 1).as("locate")).show() + * ------------ + * |"LOCATE" | + * ------------ + * |2 | + * ------------ + * + * }}} + * @since 1.14.0 + * @param substr + * string to search + * @param str + * value where string will be searched + * @param pos + * index for starting the search + * @return + * returns the position of the first occurrence. + */ + def locate(substr: Column, str: Column, pos: Int): Column = + if (pos == 0) lit(0) else callBuiltin("POSITION", substr, str, pos) + + /** Locate the position of the first occurrence of substr in a string column, after position pos. + * + * @note + * The position is not zero based, but 1 based index. returns 0 if substr could not be found in + * str. This function is just leverages the SF POSITION builtin Example + * {{{ + * val df = session.createDataFrame(Seq("java scala python")).toDF("a") + * df.select(locate("scala", col("a")).as("locate")).show() + * ------------ + * |"LOCATE" | + * ------------ + * |6 | + * ------------ + * + * }}} + * @since 1.14.0 + * @param substr + * string to search + * @param str + * value where string will be searched + * @param pos + * index for starting the search. default to 1. + * @return + * Returns the position of the first occurrence + */ + def locate(substr: String, str: Column, pos: Int = 1): Column = + if (pos == 0) lit(0) else callBuiltin("POSITION", lit(substr), str, lit(pos)) + + /** Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the + * second quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + * This is equivalent to the NTILE function in SQL. Example + * {{{ + * val df = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score") + * val window = Window.partitionBy(col("grade")).orderBy(col("score")) + * df.select(ntile(2).over(window).as("ntile")).show() + * ----------- + * |"NTILE" | + * ----------- + * |1 | + * |1 | + * |2 | + * |2 | + * ----------- + * }}} + * + * @since 1.14.0 + * @param n + * number of groups + * @return + * returns the ntile group id (from 1 to n inclusive) in an ordered window partition. + */ + def ntile(n: Int): Column = callBuiltin("ntile", lit(n)) + + /** Generate a column with independent and identically distributed (i.i.d.) samples from the + * standard normal distribution. Return a call to the Snowflake RANDOM function. NOTE: Snowflake + * returns integers of 17-19 digits. Example + * {{{ + * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + * df.withColumn("randn", randn()).select("randn").show() + * ------------------------ + * |"RANDN" | + * ------------------------ + * |-2093909082984812541 | + * |-1379817492278593383 | + * |-1231198046297539927 | + * ------------------------ + * }}} + * + * @since 1.14.0 + * @return + * Random number. + */ + def randn(): Column = + builtin("RANDOM")() + + /** Generate a column with independent and identically distributed (i.i.d.) samples from the + * standard normal distribution. Calls to the Snowflake RANDOM function. NOTE: Snowflake returns + * integers of 17-19 digits. Example + * {{{ + * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + * df.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed").show() + * ------------------------ + * |"RANDN_WITH_SEED" | + * ------------------------ + * |5777523539921853504 | + * |-8190739547906189845 | + * |-1138438814981368515 | + * ------------------------ + * }}} + * + * @since 1.14.0 + * @param seed + * Seed to use in the random function. + * @return + * Random number. + */ + def randn(seed: Long): Column = + builtin("RANDOM")(seed) + /** Invokes a built-in snowflake function with the specified name and arguments. Arguments can be * of two types * diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 05e38211..bdd14337 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -2766,6 +2766,59 @@ public void any_value() { } @Test + public void regexp_extract() { + DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)"); + Row[] expected = {Row.create("MAN")}; + checkAnswer( + df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 1, 1)), expected, false); + Row[] expected2 = {Row.create("PLAN")}; + checkAnswer( + df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected2, false); + Row[] expected3 = {Row.create("CANAL")}; + checkAnswer( + df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 3, 1)), expected3, false); + } + + @Test + public void signum() { + DataFrame df = getSession().sql("select * from values(1) as T(a)"); + checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false); + DataFrame df1 = getSession().sql("select * from values(-2) as T(a)"); + checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false); + DataFrame df2 = getSession().sql("select * from values(0) as T(a)"); + checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false); + } + + @Test + public void sign() { + DataFrame df = getSession().sql("select * from values(1) as T(a)"); + checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false); + DataFrame df1 = getSession().sql("select * from values(-2) as T(a)"); + checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false); + DataFrame df2 = getSession().sql("select * from values(0) as T(a)"); + checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false); + } + + @Test + public void collect_list() { + DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)"); + df.select(Functions.collect_list(df.col("a"))).show(); + } + + @Test + public void substring_index() { + DataFrame df = + getSession() + .sql( + "select * from values ('It was the best of times,it was the worst of times') as T(a)"); + checkAnswer( + df.select( + Functions.substring_index( + "It was the best of times,it was the worst of times", "was", 1)), + new Row[] {Row.create("It was ")}, + false); + } + public void test_asc() { DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)"); Row[] expected = {Row.create(1), Row.create(2), Row.create(3)}; @@ -2874,4 +2927,58 @@ public void unbase64() { Row[] expected = {Row.create("test")}; checkAnswer(df.select(Functions.unbase64(Functions.col("a"))), expected, false); } + + @Test + public void locate_int() { + DataFrame df = + getSession() + .sql( + "select * from values ('scala', 'java scala python'), \n " + + "('b', 'abcd') as T(a,b)"); + Row[] expected = {Row.create(6), Row.create(2)}; + checkAnswer( + df.select(Functions.locate(Functions.col("a"), Functions.col("b"), 1).as("locate")), + expected, + false); + } + + @Test + public void locate() { + DataFrame df = getSession().sql("select * from values ('abcd') as T(s)"); + Row[] expected = {Row.create(2)}; + checkAnswer(df.select(Functions.locate("b", Functions.col("s")).as("locate")), expected, false); + } + + @Test + public void ntile_int() { + DataFrame df = getSession().sql("select * from values(1,2),(1,2),(2,1),(2,2),(2,2) as T(x,y)"); + Row[] expected = {Row.create(1), Row.create(2), Row.create(3), Row.create(1), Row.create(2)}; + + checkAnswer( + df.select(Functions.ntile(4).over(Window.partitionBy(df.col("x")).orderBy(df.col("y")))), + expected, + false); + } + + @Test + public void randn() { + DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)"); + + assert (df.withColumn("randn", Functions.randn()).select("randn").first() != null); + } + + @Test + public void randn_seed() { + DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)"); + Row[] expected = { + Row.create(5777523539921853504L), + Row.create(-8190739547906189845L), + Row.create(-1138438814981368515L) + }; + + checkAnswer( + df.withColumn("randn_with_seed", Functions.randn(123l)).select("randn_with_seed"), + expected, + false); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 47dc225d..1260ef0e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.node.ArrayNode import com.snowflake.snowpark._ import com.snowflake.snowpark.functions.{repeat, _} import com.snowflake.snowpark.types._ +import com.snowflake.snowpark_java.types.LongType import net.snowflake.client.jdbc.SnowflakeSQLException import java.sql.{Date, Time, Timestamp} @@ -1245,8 +1246,7 @@ trait FunctionSuite extends TestData { ) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.001234567" - ) + .toString == "2020-10-28 13:35:47.001234567") } test("timestamp_ltz_from_parts") { @@ -2183,8 +2183,8 @@ trait FunctionSuite extends TestData { "4.000000000000000e+00,\n 1.000000000000000e+00,\n 5.000000000000000e+00,\n " + "1.000000000000000e+00,\n 6.000000000000000e+00,\n 1.000000000000000e+00,\n " + "7.000000000000000e+00,\n 1.000000000000000e+00,\n 8.000000000000000e+00,\n " + - "1.000000000000000e+00,\n 9.000000000000000e+00,\n 1.000000000000000e+00\n ],\n " + - "\"type\": \"tdigest\",\n \"version\": 1\n}" + "1.000000000000000e+00,\n 9.000000000000000e+00,\n 1.000000000000000e+00\n ]," + + "\n \"type\": \"tdigest\",\n \"version\": 1\n}" ) ), sort = false @@ -2222,8 +2222,8 @@ trait FunctionSuite extends TestData { "1.000000000000000e+00,\n 7.000000000000000e+00,\n 1.000000000000000e+00,\n " + "8.000000000000000e+00,\n 1.000000000000000e+00,\n 8.000000000000000e+00,\n " + "1.000000000000000e+00,\n 9.000000000000000e+00,\n 1.000000000000000e+00,\n " + - "9.000000000000000e+00,\n 1.000000000000000e+00\n ],\n \"type\": \"tdigest\",\n " + - "\"version\": 1\n}" + "9.000000000000000e+00,\n 1.000000000000000e+00\n ],\n \"type\": \"tdigest\"," + + "\n \"version\": 1\n}" ) ), sort = false @@ -2495,6 +2495,49 @@ trait FunctionSuite extends TestData { sort = false ) } + test("regexp_extract") { + val data = Seq("A MAN A PLAN A CANAL").toDF("a") + var expected = Seq(Row("MAN")) + checkAnswer( + data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 1, 1)), + expected, + sort = false) + expected = Seq(Row("PLAN")) + checkAnswer( + data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 2, 1)), + expected, + sort = false) + expected = Seq(Row("CANAL")) + checkAnswer( + data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 3, 1)), + expected, + sort = false) + + } + test("signum") { + val df = Seq(1).toDF("a") + checkAnswer(df.select(sign(col("a"))), Seq(Row(1)), sort = false) + val df1 = Seq(-2).toDF("a") + checkAnswer(df1.select(sign(col("a"))), Seq(Row(-1)), sort = false) + val df2 = Seq(0).toDF("a") + checkAnswer(df2.select(sign(col("a"))), Seq(Row(0)), sort = false) + } + test("sign") { + val df = Seq(1).toDF("a") + checkAnswer(df.select(sign(col("a"))), Seq(Row(1)), sort = false) + val df1 = Seq(-2).toDF("a") + checkAnswer(df1.select(sign(col("a"))), Seq(Row(-1)), sort = false) + val df2 = Seq(0).toDF("a") + checkAnswer(df2.select(sign(col("a"))), Seq(Row(0)), sort = false) + } + + test("substring_index") { + val df = Seq("It was the best of times, it was the worst of times").toDF("a") + checkAnswer( + df.select(substring_index("It was the best of times, it was the worst of times", "was", 1)), + Seq(Row("It was ")), + sort = false) + } test("desc column order") { val input = Seq(1, 2, 3).toDF("data") @@ -2600,6 +2643,42 @@ trait FunctionSuite extends TestData { checkAnswer(input.select(unbase64(col("a")).as("unbase64")), expected, sort = false) } + test("locate Column function") { + val input = + session.createDataFrame(Seq(("scala", "java scala python"), ("b", "abcd"))).toDF("a", "b") + val expected = Seq((6), (2)).toDF("locate") + checkAnswer(input.select(locate(col("a"), col("b"), 1).as("locate")), expected, sort = false) + } + + test("locate String function") { + + val input = session.createDataFrame(Seq("java scala python")).toDF("a") + val expected = Seq(6).toDF("locate") + checkAnswer(input.select(locate("scala", col("a")).as("locate")), expected, sort = false) + } + + test("ntile function") { + val input = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score") + val window = Window.partitionBy(col("grade")).orderBy(col("score")) + val expected = Seq((1), (1), (2), (2)).toDF("ntile") + checkAnswer(input.select(ntile(2).over(window).as("ntile")), expected, sort = false) + } + + test("randn seed function") { + val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + val expected = Seq((5777523539921853504L), (-8190739547906189845L), (-1138438814981368515L)) + .toDF("randn_with_seed") + val df = input.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed") + + checkAnswer(df, expected, sort = false) + } + + test("randn function") { + val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + + assert(input.withColumn("randn", randn()).select("randn").first() != null) + } + } class EagerFunctionSuite extends FunctionSuite with EagerSession