diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java
index 1d18a91c..78c64418 100644
--- a/src/main/java/com/snowflake/snowpark_java/Functions.java
+++ b/src/main/java/com/snowflake/snowpark_java/Functions.java
@@ -3882,7 +3882,127 @@ public static Column listagg(Column col) {
}
/**
- * Returns a Column expression with values sorted in descending order.
+   * Extracts a specific group matched by a regular expression from the specified string column.
+   * If the regex did not match, or the specified group did not match, an empty string is
+   * returned. A null input value yields null.
+   *
+   * <p>The arguments are forwarded to the Snowflake REGEXP_SUBSTR function.
+   *
+   * <pre>{@code
+   * DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)");
+   * df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 1, 1).as("RES")).show();
+   * ---------
+   * |"RES"  |
+   * ---------
+   * |MAN    |
+   * ---------
+   * }</pre>
+   *
+   * @since 1.14.0
+   * @param col Column holding the strings to search.
+   * @param exp Regular expression to match.
+   * @param position Position argument passed to REGEXP_SUBSTR (where the search starts).
+   * @param Occurences Which occurrence of the pattern to use (1 for the first match, 2 for the
+   *     second, and so on).
+   * @param grpIdx Index of the capture group to return.
+   * @return Column object.
+   */
+  public static Column regexp_extract(
+      Column col, String exp, Integer position, Integer Occurences, Integer grpIdx) {
+    return new Column(
+        functions.regexp_extract(col.toScalaColumn(), exp, position, Occurences, grpIdx));
+  }
+
+  /**
+   * Returns the sign of its argument:
+   *
+   * <ul>
+   *   <li>-1 if the argument is negative.
+   *   <li>1 if it is positive.
+   *   <li>0 if it is 0.
+   * </ul>
+   *
+   * <pre>{@code
+   * DataFrame df = getSession().sql("select * from values(-2, 2, 0) as T(a, b, c)");
+   * df.select(
+   *     Functions.signum(df.col("a")).as("a_sign"),
+   *     Functions.signum(df.col("b")).as("b_sign"),
+   *     Functions.signum(df.col("c")).as("c_sign")).show();
+   * ----------------------------------
+   * |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+   * ----------------------------------
+   * |-1        |1         |0         |
+   * ----------------------------------
+   * }</pre>
+   *
+   * @since 1.14.0
+   * @param col Column to calculate the sign.
+   * @return Column object.
+   */
+  public static Column signum(Column col) {
+    return new Column(functions.signum(col.toScalaColumn()));
+  }
+
+  /**
+   * Returns the sign of its argument:
+   *
+   * <ul>
+   *   <li>-1 if the argument is negative.
+   *   <li>1 if it is positive.
+   *   <li>0 if it is 0.
+   * </ul>
+   *
+   * <pre>{@code
+   * DataFrame df = getSession().sql("select * from values(-2, 2, 0) as T(a, b, c)");
+   * df.select(
+   *     Functions.sign(df.col("a")).as("a_sign"),
+   *     Functions.sign(df.col("b")).as("b_sign"),
+   *     Functions.sign(df.col("c")).as("c_sign")).show();
+   * ----------------------------------
+   * |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+   * ----------------------------------
+   * |-1        |1         |0         |
+   * ----------------------------------
+   * }</pre>
+   *
+   * @since 1.14.0
+   * @param col Column to calculate the sign.
+   * @return Column object.
+   */
+  public static Column sign(Column col) {
+    return new Column(functions.sign(col.toScalaColumn()));
+  }
+
+ /**
+ * Returns the substring from string str before count occurrences of the delimiter delim. If count
+ * is positive, everything the left of the final delimiter (counting from left) is returned. If
+ * count is negative, every to the right of the final delimiter (counting from the right) is
+ * returned. substring_index performs a case-sensitive match when searching for delim.
+ *
+ * @param col String.
+ * @param delim String
+ * @param count Integer.
+ * @return Column object.
+ * @since 1.14.0
+ */
+ public static Column substring_index(String col, String delim, Integer count) {
+ return new Column(com.snowflake.snowpark.functions.substring_index(col, delim, count));
+ }
+
+  /**
+   * Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
+   * returned.
+   *
+   * <p>Example (the single result row holds an array with the values 1, 2 and 3):
+   *
+   * <pre>{@code
+   * DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)");
+   * df.select(Functions.collect_list(df.col("a")).as("result")).show();
+   * }</pre>
+   *
+   * @since 1.14.0
+   * @param c Column whose values are collected.
+   * @return The array.
+   */
+  public static Column collect_list(Column c) {
+    return new Column(functions.collect_list(c.toScalaColumn()));
+  }
+
+  /**
+   * Returns a Column expression with values sorted in descending order.
*
* Example: order column values in descending
*
@@ -4180,6 +4300,131 @@ public static Column unbase64(Column c) {
return new Column(functions.unbase64(c.toScalaColumn()));
}
+ /**
+ * Locate the position of the first occurrence of substr in a string column, after position pos.
+ *
+ *
{@code
+ * DataFrame df = getSession().sql("select * from values ('scala', 'java scala python'), \n " +
+ * "('b', 'abcd') as T(a,b)");
+ * df.select(Functions.locate(Functions.col("a"), Functions.col("b"), 1).as("locate")).show();
+ * ------------
+ * |"LOCATE" |
+ * ------------
+ * |6 |
+ * |2 |
+ * ------------
+ * }
+ *
+ * @since 1.14.0
+ * @param substr string to search
+ * @param str value where string will be searched
+ * @param pos index for starting the search
+ * @return returns the position of the first occurrence.
+ */
+ public static Column locate(Column substr, Column str, int pos) {
+ return new Column(functions.locate(substr.toScalaColumn(), str.toScalaColumn(), pos));
+ }
+
+ /**
+ * Locate the position of the first occurrence of substr in a string column, after position pos.
+ * default to 1.
+ *
+ * {@code
+ * DataFrame df = getSession().sql("select * from values ('abcd') as T(s)");
+ * df.select(Functions.locate("b", Functions.col("s")).as("locate")).show();
+ * ------------
+ * |"LOCATE" |
+ * ------------
+ * |2 |
+ * ------------
+ * }
+ *
+ * @since 1.14.0
+ * @param substr string to search
+ * @param str value where string will be searched
+ * @return returns the position of the first occurrence.
+ */
+ public static Column locate(String substr, Column str) {
+ return new Column(functions.locate(substr, str.toScalaColumn(), 1));
+ }
+
+ /**
+ * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window
+ * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second
+ * quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
+ *
+ * This is equivalent to the NTILE function in SQL.
+ *
+ *
{@code
+ * DataFrame df = getSession().sql("select * from values(1,2),(1,2),(2,1),(2,2),(2,2) as T(x,y)");
+ * df.select(Functions.ntile(4).over(Window.partitionBy(df.col("x")).orderBy(df.col("y"))).as("ntile")).show();
+ * -----------
+ * |"NTILE" |
+ * -----------
+ * |1 |
+ * |2 |
+ * |3 |
+ * |1 |
+ * |2 |
+ * -----------
+ * }
+ *
+ * @since 1.14.0
+ * @param n number of groups
+ * @return returns the ntile group id (from 1 to n inclusive) in an ordered window partition.
+ */
+ public static Column ntile(int n) {
+ return new Column(functions.ntile(n));
+ }
+
+ /**
+ * Generate a column with independent and identically distributed (i.i.d.) samples from the
+ * standard normal distribution. Return a call to the Snowflake RANDOM function. NOTE: Snowflake
+ * returns integers of 17-19 digits.
+ *
+ * {@code
+ * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+ * df.withColumn("randn",Functions.randn()).select("randn").show();
+ * ------------------------
+ * |"RANDN" |
+ * ------------------------
+ * |6799378361097866000 |
+ * |-7280487148628086605 |
+ * |775606662514393461 |
+ * ------------------------
+ * }
+ *
+ * @since 1.14.0
+ * @return Random number.
+ */
+ public static Column randn() {
+ return new Column(functions.randn());
+ }
+
+  /**
+   * Generate a column with independent and identically distributed (i.i.d.) samples by returning
+   * a call to the Snowflake RANDOM function with the given seed. NOTE: Snowflake returns integers
+   * of 17-19 digits (see the example), not floating-point samples from the standard normal
+   * distribution.
+   *
+   * <pre>{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+   * df.withColumn("randn_with_seed",Functions.randn(123L)).select("randn_with_seed").show();
+   * ------------------------
+   * |"RANDN_WITH_SEED"     |
+   * ------------------------
+   * |5777523539921853504   |
+   * |-8190739547906189845  |
+   * |-1138438814981368515  |
+   * ------------------------
+   * }</pre>
+   *
+   * @since 1.14.0
+   * @param seed Seed to use in the random function.
+   * @return Random number.
+   */
+  public static Column randn(long seed) {
+    return new Column(functions.randn(seed));
+  }
+
/**
* Calls a user-defined function (UDF) by name.
*
diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala
index c20475d7..0090fb96 100644
--- a/src/main/scala/com/snowflake/snowpark/functions.scala
+++ b/src/main/scala/com/snowflake/snowpark/functions.scala
@@ -2853,26 +2853,183 @@ object functions {
*/
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)
- /** Returns a Column expression with values sorted in descending order. Example:
- * {{{
- * val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
- * df.sort(desc("id")).show()
+  /** Extracts a specific group matched by a regex from the specified string column. If the regex
+   * did not match, or the specified group did not match, an empty string is returned. A null
+   * input value yields null.
+   *
+   * Example:
+   * {{{
+   * val df = Seq("A MAN A PLAN A CANAL").toDF("a")
+   * df.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 1, 1).as("RES")).show()
+   * ---------
+   * |"RES"  |
+   * ---------
+   * |MAN    |
+   * ---------
+   * }}}
+   *
+   * Note: non-greedy tokens are not supported.
+   *
+   * @since 1.14.0
+   * @param colName
+   *   Column holding the strings to search.
+   * @param exp
+   *   Regular expression to match.
+   * @param position
+   *   Position argument passed to REGEXP_SUBSTR (where the search starts).
+   * @param Occurences
+   *   Which occurrence of the pattern to use.
+   * @param grpIdx
+   *   Index of the capture group to return.
+   * @return
+   *   Column object.
+   */
+  def regexp_extract(
+      colName: Column,
+      exp: String,
+      position: Int,
+      Occurences: Int,
+      grpIdx: Int): Column = {
+    // Null input stays null; a failed match (REGEXP_SUBSTR yields NULL) becomes "".
+    when(colName.is_null, lit(null))
+      .otherwise(
+        coalesce(
+          builtin("REGEXP_SUBSTR")(
+            colName,
+            lit(exp),
+            lit(position),
+            lit(Occurences),
+            // Regex parameters handed to REGEXP_SUBSTR, followed by the group number.
+            lit("ce"),
+            lit(grpIdx)),
+          lit("")))
+  }
+
+ /** Returns the sign of its argument by calling the Snowflake SIGN function:
*
- * --------
- * |"ID" |
- * --------
- * |3 |
- * |2 |
- * |1 |
- * --------
- * }}}
+ * - -1 if the argument is negative.
+ * - 1 if it is positive.
+ * - 0 if it is 0.
*
+ * Example:
+ * {{{
+ * val df = Seq((-2, 2, 0)).toDF("a", "b", "c")
+ * df.select(sign(col("a")).as("a_sign"), sign(col("b")).as("b_sign"),
+ *   sign(col("c")).as("c_sign")).show()
+ * ----------------------------------
+ * |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+ * ----------------------------------
+ * |-1        |1         |0         |
+ * ----------------------------------
+ * }}}
+ *
* @since 1.14.0
- * @param colName
- * Column name.
+ * @param colName
+ * Column to calculate the sign.
+ * @return
+ * Column object.
+ */
+ def sign(colName: Column): Column = {
+ builtin("SIGN")(colName)
+ }
+
+  /** Computes the sign of its argument by calling the Snowflake SIGN function:
+   *
+   *   - -1 if the argument is negative.
+   *   - 1 if it is positive.
+   *   - 0 if it is 0.
+   *
+   * Example:
+   * {{{
+   * val df = Seq((-2, 2, 0)).toDF("a", "b", "c")
+   * df.select(signum(col("a")).as("a_sign"), signum(col("b")).as("b_sign"),
+   *   signum(col("c")).as("c_sign")).show()
+   * ----------------------------------
+   * |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+   * ----------------------------------
+   * |-1        |1         |0         |
+   * ----------------------------------
+   * }}}
+   *
+   * @since 1.14.0
+   * @param colName
+   *   Column to calculate the sign.
+   * @return
+   *   Column object.
+   */
+  def signum(colName: Column): Column = builtin("SIGN")(colName)
+
+ /** Returns the sign of the given column. Returns either 1 for positive, 0 for 0 or NaN, -1 for
+ * negative and null for null. NOTE: if string values are provided, Snowflake will attempt to
+ * cast them. If the cast succeeds, the calculation is returned; if not, an error is thrown.
+ * This overload looks the column up by name and delegates to signum(Column).
+ * @since 1.14.0
+ * @param columnName
+ * Name of the column to calculate the sign.
* @return
- * Column object ordered in a descending manner.
+ * Column object.
+ */
+ def signum(columnName: String): Column = {
+ signum(col(columnName))
+ }
+
+  /** Returns the substring from string str before count occurrences of the delimiter delim. If
+   * count is positive, everything to the left of the final delimiter (counting from the left) is
+   * returned. If count is negative, everything to the right of the final delimiter (counting
+   * from the right) is returned. substring_index performs a case-sensitive match when searching
+   * for delim.
+   *
+   * Example:
+   * {{{
+   * substring_index("It was the best of times, it was the worst of times", "was", 1)
+   * // evaluates to "It was "
+   * }}}
+   *
+   * NOTE(review): delim is handed to REGEXP_INSTR, so it is presumably interpreted as a regular
+   * expression rather than a literal; the negative-count branch should also be confirmed against
+   * the description above.
+   * @since 1.14.0
+   * @param str
+   *   String value to search; it is passed on as a literal, not as a column name.
+   * @param delim
+   *   Delimiter to look for.
+   * @param count
+   *   Number of delimiter occurrences to count.
+   * @return
+   *   Column object.
*/
+  def substring_index(str: String, delim: String, count: Int): Column = {
+    when(
+      lit(count) < lit(0),
+      callBuiltin(
+        "substring",
+        lit(str),
+        callBuiltin(
+          "regexp_instr",
+          // Pass the value through lit() instead of splicing it into SQL text, so a single
+          // quote inside str can neither break the statement nor inject SQL.
+          callBuiltin("reverse", lit(str)),
+          lit(delim),
+          1,
+          abs(lit(count)),
+          lit(0))))
+      .otherwise(
+        callBuiltin(
+          "substring",
+          lit(str),
+          1,
+          callBuiltin("regexp_instr", lit(str), lit(delim), 1, lit(count), 1)))
+  }
+
+  /** Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
+   * returned. Delegates to array_agg.
+   *
+   * Example:
+   * {{{
+   * val df = Seq(1, 2, 3).toDF("a")
+   * df.select(collect_list(col("a")).as("result")).show()
+   * ------------
+   * |"RESULT"  |
+   * ------------
+   * |[         |
+   * |  1,      |
+   * |  2,      |
+   * |  3       |
+   * |]         |
+   * ------------
+   * }}}
+   *
+   * @since 1.14.0
+   * @param c
+   *   Column whose values are collected.
+   * @return
+   *   The array.
+   */
+  def collect_list(c: Column): Column = {
+    array_agg(c)
+  }
+
+  /** Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
+   * returned. The column is looked up by name and then handed to array_agg.
+   *
+   * Example:
+   * {{{
+   * val df = Seq(1, 2, 3).toDF("a")
+   * df.select(collect_list("a").as("result")).show()
+   * ------------
+   * |"RESULT"  |
+   * ------------
+   * |[         |
+   * |  1,      |
+   * |  2,      |
+   * |  3       |
+   * |]         |
+   * ------------
+   * }}}
+   *
+   * @since 1.14.0
+   * @param s
+   *   Column name to be collected.
+   * @return
+   *   The array.
+   */
+  def collect_list(s: String): Column = {
+    val column = col(s)
+    array_agg(column)
+  }
+
+ /** Returns a Column expression with values sorted in descending order.
+ * Example:
+ * {{{
+ * val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
+ * df.sort(desc("id")).show()
+ *
+ * --------
+ * |"ID" |
+ * --------
+ * |3 |
+ * |2 |
+ * |1 |
+ * --------
+ * }}}
+ *
+ * @since 1.14.0
+ * @param colName Column name.
+ * @return Column object ordered in a descending manner.
+ */
def desc(colName: String): Column = col(colName).desc
/** Returns a Column expression with values sorted in ascending order. Example:
@@ -3150,6 +3307,135 @@ object functions {
*/
def unbase64(col: Column): Column = callBuiltin("BASE64_DECODE_STRING", col)
+  /** Finds the position of the first occurrence of substr in a string column, starting the search
+   * at position pos. A pos of 0 short-circuits to the literal 0; otherwise the Snowflake POSITION
+   * builtin is called.
+   *
+   * Example
+   * {{{
+   * val df = session.createDataFrame(Seq(("b", "abcd"))).toDF("a", "b")
+   * df.select(locate(col("a"), col("b"), 1).as("locate")).show()
+   * ------------
+   * |"LOCATE"  |
+   * ------------
+   * |2         |
+   * ------------
+   *
+   * }}}
+   * @note
+   *   The position is not zero based, but 1 based index. Returns 0 if substr could not be found
+   *   in str.
+   * @since 1.14.0
+   * @param substr
+   *   string to search
+   * @param str
+   *   value where string will be searched
+   * @param pos
+   *   index for starting the search
+   * @return
+   *   returns the position of the first occurrence.
+   */
+  def locate(substr: Column, str: Column, pos: Int): Column = pos match {
+    case 0 => lit(0)
+    case _ => callBuiltin("POSITION", substr, str, pos)
+  }
+
+  /** Finds the position of the first occurrence of the string substr in a string column, starting
+   * the search at position pos. A pos of 0 short-circuits to the literal 0; otherwise the
+   * Snowflake POSITION builtin is called.
+   *
+   * Example
+   * {{{
+   * val df = session.createDataFrame(Seq("java scala python")).toDF("a")
+   * df.select(locate("scala", col("a")).as("locate")).show()
+   * ------------
+   * |"LOCATE"  |
+   * ------------
+   * |6         |
+   * ------------
+   *
+   * }}}
+   * @note
+   *   The position is not zero based, but 1 based index. Returns 0 if substr could not be found
+   *   in str.
+   * @since 1.14.0
+   * @param substr
+   *   string to search
+   * @param str
+   *   value where string will be searched
+   * @param pos
+   *   index for starting the search. default to 1.
+   * @return
+   *   Returns the position of the first occurrence
+   */
+  def locate(substr: String, str: Column, pos: Int = 1): Column = {
+    if (pos != 0) {
+      callBuiltin("POSITION", lit(substr), str, lit(pos))
+    } else {
+      lit(0)
+    }
+  }
+
+  /** Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window
+   * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the
+   * second quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
+   *
+   * This is equivalent to the NTILE function in SQL.
+   *
+   * Example
+   * {{{
+   * val df = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score")
+   * val window = Window.partitionBy(col("grade")).orderBy(col("score"))
+   * df.select(ntile(2).over(window).as("ntile")).show()
+   * -----------
+   * |"NTILE"  |
+   * -----------
+   * |1        |
+   * |1        |
+   * |2        |
+   * |2        |
+   * -----------
+   * }}}
+   *
+   * @since 1.14.0
+   * @param n
+   *   number of groups
+   * @return
+   *   returns the ntile group id (from 1 to n inclusive) in an ordered window partition.
+   */
+  def ntile(n: Int): Column = {
+    val groupCount = lit(n)
+    callBuiltin("ntile", groupCount)
+  }
+
+  /** Generate a column with independent and identically distributed (i.i.d.) samples by returning
+   * a call to the Snowflake RANDOM function. NOTE: Snowflake returns integers of 17-19 digits
+   * (see the example), not floating-point samples from the standard normal distribution.
+   *
+   * Example
+   * {{{
+   * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
+   * df.withColumn("randn", randn()).select("randn").show()
+   * ------------------------
+   * |"RANDN"               |
+   * ------------------------
+   * |-2093909082984812541  |
+   * |-1379817492278593383  |
+   * |-1231198046297539927  |
+   * ------------------------
+   * }}}
+   *
+   * @since 1.14.0
+   * @return
+   *   Random number.
+   */
+  def randn(): Column = {
+    builtin("RANDOM")()
+  }
+
+  /** Generate a column with independent and identically distributed (i.i.d.) samples by calling
+   * the Snowflake RANDOM function with the given seed. NOTE: Snowflake returns integers of 17-19
+   * digits (see the example), not floating-point samples from the standard normal distribution.
+   *
+   * Example
+   * {{{
+   * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
+   * df.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed").show()
+   * ------------------------
+   * |"RANDN_WITH_SEED"     |
+   * ------------------------
+   * |5777523539921853504   |
+   * |-8190739547906189845  |
+   * |-1138438814981368515  |
+   * ------------------------
+   * }}}
+   *
+   * @since 1.14.0
+   * @param seed
+   *   Seed to use in the random function.
+   * @return
+   *   Random number.
+   */
+  def randn(seed: Long): Column = {
+    builtin("RANDOM")(seed)
+  }
+
/** Invokes a built-in snowflake function with the specified name and arguments. Arguments can be
* of two types
*
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
index 05e38211..bdd14337 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
@@ -2766,6 +2766,59 @@ public void any_value() {
}
@Test
+  public void regexp_extract() {
+    DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)");
+    String pattern = "A\\W+(\\w+)";
+    // The n-th occurrence of the pattern captures the n-th word that follows an "A".
+    String[] capturedWords = {"MAN", "PLAN", "CANAL"};
+    for (int occurrence = 1; occurrence <= capturedWords.length; occurrence++) {
+      Row[] expected = {Row.create(capturedWords[occurrence - 1])};
+      checkAnswer(
+          df.select(Functions.regexp_extract(df.col("a"), pattern, 1, occurrence, 1)),
+          expected,
+          false);
+    }
+  }
+
+ @Test
+ public void signum() {
+ DataFrame df = getSession().sql("select * from values(1) as T(a)");
+ checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false);
+ DataFrame df1 = getSession().sql("select * from values(-2) as T(a)");
+ checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false);
+ DataFrame df2 = getSession().sql("select * from values(0) as T(a)");
+ checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false);
+ }
+
+  @Test
+  public void sign() {
+    // Exercise Functions.sign itself; Functions.signum is covered by its own test.
+    DataFrame df = getSession().sql("select * from values(1) as T(a)");
+    checkAnswer(df.select(Functions.sign(df.col("a"))), new Row[] {Row.create(1)}, false);
+    DataFrame df1 = getSession().sql("select * from values(-2) as T(a)");
+    checkAnswer(df1.select(Functions.sign(df1.col("a"))), new Row[] {Row.create(-1)}, false);
+    DataFrame df2 = getSession().sql("select * from values(0) as T(a)");
+    checkAnswer(df2.select(Functions.sign(df2.col("a"))), new Row[] {Row.create(0)}, false);
+  }
+
+  @Test
+  public void collect_list() {
+    DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)");
+    // Smoke test only: the query must run and print; the array content is not asserted.
+    DataFrame collected = df.select(Functions.collect_list(df.col("a")));
+    collected.show();
+  }
+
+ @Test
+ public void substring_index() {
+ DataFrame df =
+ getSession()
+ .sql(
+ "select * from values ('It was the best of times,it was the worst of times') as T(a)");
+ checkAnswer(
+ df.select(
+ Functions.substring_index(
+ "It was the best of times,it was the worst of times", "was", 1)),
+ new Row[] {Row.create("It was ")},
+ false);
+ }
+
public void test_asc() {
DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
Row[] expected = {Row.create(1), Row.create(2), Row.create(3)};
@@ -2874,4 +2927,58 @@ public void unbase64() {
Row[] expected = {Row.create("test")};
checkAnswer(df.select(Functions.unbase64(Functions.col("a"))), expected, false);
}
+
+  @Test
+  public void locate_int() {
+    String query =
+        "select * from values ('scala', 'java scala python'), \n " + "('b', 'abcd') as T(a,b)";
+    DataFrame df = getSession().sql(query);
+    // Column "a" is the substring searched for inside column "b", starting at position 1.
+    DataFrame located =
+        df.select(Functions.locate(Functions.col("a"), Functions.col("b"), 1).as("locate"));
+    Row[] expected = {Row.create(6), Row.create(2)};
+    checkAnswer(located, expected, false);
+  }
+
+  @Test
+  public void locate() {
+    DataFrame df = getSession().sql("select * from values ('abcd') as T(s)");
+    DataFrame located = df.select(Functions.locate("b", Functions.col("s")).as("locate"));
+    checkAnswer(located, new Row[] {Row.create(2)}, false);
+  }
+
+  @Test
+  public void ntile_int() {
+    DataFrame df = getSession().sql("select * from values(1,2),(1,2),(2,1),(2,2),(2,2) as T(x,y)");
+    // Rows are bucketed per partition of x, ordered by y.
+    DataFrame bucketed =
+        df.select(Functions.ntile(4).over(Window.partitionBy(df.col("x")).orderBy(df.col("y"))));
+    Row[] expected =
+        new Row[] {Row.create(1), Row.create(2), Row.create(3), Row.create(1), Row.create(2)};
+    checkAnswer(bucketed, expected, false);
+  }
+
+  @Test
+  public void randn() {
+    DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+
+    // A bare `assert` is skipped when the JVM runs without -ea, so fail explicitly instead.
+    if (df.withColumn("randn", Functions.randn()).select("randn").first() == null) {
+      throw new AssertionError("randn() produced no result");
+    }
+  }
+
+  @Test
+  public void randn_seed() {
+    DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+    DataFrame seeded =
+        df.withColumn("randn_with_seed", Functions.randn(123L)).select("randn_with_seed");
+
+    // With a fixed seed the three generated values are expected to be reproducible.
+    Row[] expected =
+        new Row[] {
+          Row.create(5777523539921853504L),
+          Row.create(-8190739547906189845L),
+          Row.create(-1138438814981368515L)
+        };
+    checkAnswer(seeded, expected, false);
+  }
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
index 47dc225d..1260ef0e 100644
--- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
@@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.node.ArrayNode
import com.snowflake.snowpark._
import com.snowflake.snowpark.functions.{repeat, _}
import com.snowflake.snowpark.types._
+import com.snowflake.snowpark_java.types.LongType
import net.snowflake.client.jdbc.SnowflakeSQLException
import java.sql.{Date, Time, Timestamp}
@@ -1245,8 +1246,7 @@ trait FunctionSuite extends TestData {
)
.collect()(0)
.getTimestamp(0)
- .toString == "2020-10-28 13:35:47.001234567"
- )
+ .toString == "2020-10-28 13:35:47.001234567")
}
test("timestamp_ltz_from_parts") {
@@ -2183,8 +2183,8 @@ trait FunctionSuite extends TestData {
"4.000000000000000e+00,\n 1.000000000000000e+00,\n 5.000000000000000e+00,\n " +
"1.000000000000000e+00,\n 6.000000000000000e+00,\n 1.000000000000000e+00,\n " +
"7.000000000000000e+00,\n 1.000000000000000e+00,\n 8.000000000000000e+00,\n " +
- "1.000000000000000e+00,\n 9.000000000000000e+00,\n 1.000000000000000e+00\n ],\n " +
- "\"type\": \"tdigest\",\n \"version\": 1\n}"
+ "1.000000000000000e+00,\n 9.000000000000000e+00,\n 1.000000000000000e+00\n ]," +
+ "\n \"type\": \"tdigest\",\n \"version\": 1\n}"
)
),
sort = false
@@ -2222,8 +2222,8 @@ trait FunctionSuite extends TestData {
"1.000000000000000e+00,\n 7.000000000000000e+00,\n 1.000000000000000e+00,\n " +
"8.000000000000000e+00,\n 1.000000000000000e+00,\n 8.000000000000000e+00,\n " +
"1.000000000000000e+00,\n 9.000000000000000e+00,\n 1.000000000000000e+00,\n " +
- "9.000000000000000e+00,\n 1.000000000000000e+00\n ],\n \"type\": \"tdigest\",\n " +
- "\"version\": 1\n}"
+ "9.000000000000000e+00,\n 1.000000000000000e+00\n ],\n \"type\": \"tdigest\"," +
+ "\n \"version\": 1\n}"
)
),
sort = false
@@ -2495,6 +2495,49 @@ trait FunctionSuite extends TestData {
sort = false
)
}
+  test("regexp_extract") {
+    val data = Seq("A MAN A PLAN A CANAL").toDF("a")
+    val pattern = "A\\W+(\\w+)"
+    // The n-th occurrence of the pattern captures the n-th word that follows an "A".
+    Seq(1 -> "MAN", 2 -> "PLAN", 3 -> "CANAL").foreach { case (occurrence, word) =>
+      checkAnswer(
+        data.select(regexp_extract(col("a"), pattern, 1, occurrence, 1)),
+        Seq(Row(word)),
+        sort = false)
+    }
+  }
+  test("signum") {
+    // Exercise signum itself; sign is covered by its own test.
+    val df = Seq(1).toDF("a")
+    checkAnswer(df.select(signum(col("a"))), Seq(Row(1)), sort = false)
+    val df1 = Seq(-2).toDF("a")
+    checkAnswer(df1.select(signum(col("a"))), Seq(Row(-1)), sort = false)
+    val df2 = Seq(0).toDF("a")
+    checkAnswer(df2.select(signum(col("a"))), Seq(Row(0)), sort = false)
+  }
+  test("sign") {
+    // Each input value is paired with the sign expected for it.
+    Seq(1 -> 1, -2 -> -1, 0 -> 0).foreach { case (value, expectedSign) =>
+      val df = Seq(value).toDF("a")
+      checkAnswer(df.select(sign(col("a"))), Seq(Row(expectedSign)), sort = false)
+    }
+  }
+
+  test("substring_index") {
+    val sentence = "It was the best of times, it was the worst of times"
+    val df = Seq(sentence).toDF("a")
+    val prefix = df.select(substring_index(sentence, "was", 1))
+    checkAnswer(prefix, Seq(Row("It was ")), sort = false)
+  }
test("desc column order") {
val input = Seq(1, 2, 3).toDF("data")
@@ -2600,6 +2643,42 @@ trait FunctionSuite extends TestData {
checkAnswer(input.select(unbase64(col("a")).as("unbase64")), expected, sort = false)
}
+  test("locate Column function") {
+    val rows = Seq(("scala", "java scala python"), ("b", "abcd"))
+    val input = session.createDataFrame(rows).toDF("a", "b")
+    // Column "a" is the substring searched for inside column "b", starting at position 1.
+    val located = input.select(locate(col("a"), col("b"), 1).as("locate"))
+    checkAnswer(located, Seq(6, 2).toDF("locate"), sort = false)
+  }
+
+  test("locate String function") {
+    val input = session.createDataFrame(Seq("java scala python")).toDF("a")
+    // pos is left at its default here.
+    val located = input.select(locate("scala", col("a")).as("locate"))
+    checkAnswer(located, Seq(6).toDF("locate"), sort = false)
+  }
+
+  test("ntile function") {
+    val window = Window.partitionBy(col("grade")).orderBy(col("score"))
+    val input = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score")
+    // Four rows in one partition split into two buckets of two rows each.
+    val bucketed = input.select(ntile(2).over(window).as("ntile"))
+    checkAnswer(bucketed, Seq(1, 1, 2, 2).toDF("ntile"), sort = false)
+  }
+
+  test("randn seed function") {
+    val input = session.createDataFrame(Seq(1, 2, 3)).toDF("a")
+    val seeded = input.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed")
+
+    // With a fixed seed the three generated values are expected to be reproducible.
+    val expected =
+      Seq(5777523539921853504L, -8190739547906189845L, -1138438814981368515L)
+        .toDF("randn_with_seed")
+    checkAnswer(seeded, expected, sort = false)
+  }
+
+  test("randn function") {
+    val input = session.createDataFrame(Seq(1, 2, 3)).toDF("a")
+
+    // Only checks that the query runs and hands back a result; values are not asserted.
+    val firstRow = input.withColumn("randn", randn()).select("randn").first()
+    assert(firstRow != null)
+  }
+
}
class EagerFunctionSuite extends FunctionSuite with EagerSession