Skip to content

Commit

Permalink
Sfc gh sjayabalan sma regextract signum subindex collectlist (#141)
Browse files Browse the repository at this point in the history
* Added regexp_extract,signum,substring_index,collect_list

1) Added regexp_extract,signum,substring_index,collect_list to functions.scala .
2) Added test cases for the same

* Added examples and updated the description

* Fixed format

* formatted the comments

* Added java functions and unit test cases for java

* Added sign function

* Modified the alignment

* Added examples

* adjusted comments

* Update Functions.java

---------

Co-authored-by: sfc-gh-mrojas <[email protected]>
  • Loading branch information
sfc-gh-sjayabalan and sfc-gh-mrojas authored Aug 13, 2024
1 parent 14770b7 commit d14ff9f
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 0 deletions.
114 changes: 114 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -3882,6 +3882,119 @@ public static Column listagg(Column col) {
}

/**
* Signature - snowflake.snowpark.functions.regexp_extract (value: Union[Column, str], regexp:
* Union[Column, str], idx: int) Column Extract a specific group matched by a regex, from the
* specified string column. If the regex did not match, or the specified group did not match, an
* empty string is returned.
* Example:
* <pre>{@code
* from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
* ---------
* |"RES" |
* ---------
* |20 |
* |40 |
* ---------
* }</pr>
*
* @since 1.12.1
* @return Column object.
*/
public static Column regexp_extract(
Column col, String exp, Integer position, Integer Occurences, Integer grpIdx) {
return new Column(
com.snowflake.snowpark.functions.regexp_extract(
col.toScalaColumn(), exp, position, Occurences, grpIdx));
}

/**
* Returns the sign of its argument:
*
* <p>- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0.
*
* <p>Args: col: The column to evaluate its sign
* Example::
* * <pre>{@code df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* }</pr>
*
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
public static Column signum(Column col) {
return new Column(com.snowflake.snowpark.functions.signum(col.toScalaColumn()));
}

/**
* Returns the sign of its argument:
*
* <p>- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0.
*
* <p>Args: col: The column to evaluate its sign
* Example::
* <pre>{@code df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* }</pr>
*
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
public static Column sign(Column col) {
return new Column(com.snowflake.snowpark.functions.sign(col.toScalaColumn()));
}

/**
* Returns the substring from string str before count occurrences of the delimiter delim. If count
* is positive, everything the left of the final delimiter (counting from left) is returned. If
* count is negative, every to the right of the final delimiter (counting from the right) is
* returned. substring_index performs a case-sensitive match when searching for delim.
*
* @since 1.12.1
*/
public static Column substring_index(Column col, String delim, Integer count) {
return new Column(
com.snowflake.snowpark.functions.substring_index(col.toScalaColumn(), delim, count));
}

/**
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
* returned.
*
* <p>Example::
*
* <pre>{@code
* df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* df.select(array_agg("a", True).alias("result")).show()
* "RESULT" [ 1, 2, 3 ]
* }</pre>
*
* @since 1.10.0
* @param c Column to be collect.
* @return The array.
*/
public static Column collect_list(Column col) {
return new Column(com.snowflake.snowpark.functions.collect_list(col.toScalaColumn()));
}

* Returns a Column expression with values sorted in descending order.
*
* <p>Example: order column values in descending
Expand Down Expand Up @@ -4053,6 +4166,7 @@ public static Column last(Column col) {
return new Column(functions.last(col.toScalaColumn()));
}


/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
187 changes: 187 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3142,6 +3142,192 @@ object functions {
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
* Signature - snowflake.snowpark.functions.regexp_extract
* (value: Union[Column, str], regexp: Union[Column, str], idx: int)
* Column
* Extract a specific group matched by a regex, from the specified string
* column. If the regex did not match, or the specified group did not match,
* an empty string is returned.
* <pr>Example:
* from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]],
* ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
*</pr>
*<pr>
* ---------
* |"RES" |
* ---------
* |20 |
* |40 |
* ---------
*</pr>
* Note: non-greedy tokens such as are not supported
* @since 1.12.1
* @return Column object.
*/
def regexp_extract(
colName: Column,
exp: String,
position: Int,
Occurences: Int,
grpIdx: Int): Column = {
when(colName.is_null, lit(null))
.otherwise(
coalesce(
builtin("REGEX_SUBSTR")(
colName,
lit(exp),
lit(position),
lit(Occurences),
lit("ce"),
lit(grpIdx)),
lit("")))
}

/**
* Returns the sign of its argument:
*
* - -1 if the argument is negative.
* - 1 if it is positive.
* - 0 if it is 0.
*
* Args:
* col: The column to evaluate its sign
*<pr>
* Example::
* >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"])
* >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* </pr>
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
def sign(colName: Column): Column = {
builtin("SIGN")(colName)
}

/**
* Returns the sign of its argument:
*
* - -1 if the argument is negative.
* - 1 if it is positive.
* - 0 if it is 0.
*
* Args:
* col: The column to evaluate its sign
*<pr>
* Example::
* >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"])
* >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* </pr>
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
def signum(colName: Column): Column = {
builtin("SIGN")(colName)
}

/**
* Returns the sign of the given column. Returns either 1 for positive,
* 0 for 0 or
* NaN, -1 for negative and null for null.
* NOTE: if string values are provided snowflake will attempts to cast.
* If it casts correctly, returns the calculation,
* if not an error will be thrown
* @since 1.12.1
* @param columnName Name of the column to calculate the sign.
* @return Column object.
*/
def signum(columnName: String): Column = {
signum(col(columnName))
}

/**
* Returns the substring from string str before count occurrences
* of the delimiter delim. If count is positive,
* everything the left of the final delimiter (counting from left)
* is returned. If count is negative, every to the right of the
* final delimiter (counting from the right) is returned.
* substring_index performs a case-sensitive match when searching for delim.
* @since 1.12.1
*/
def substring_index(str: Column, delim: String, count: Int): Column = {
when(
lit(count) < lit(0),
callBuiltin(
"substring",
lit(str),
callBuiltin("regexp_instr", sqlExpr(s"reverse(${str}, ${delim}, 1, abs(${count}), 0"))))
.otherwise(
callBuiltin(
"substring",
lit(str),
1,
callBuiltin("regexp_instr", col("str"), lit(delim), 1, lit(count), 1)))
}

/**
*
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
* ARRAY is returned.
*<pr>
* Example::
* >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show()
* ------------
* |"RESULT" |
* ------------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* ------------
* </pr>
* @since 1.10.0
* @param c Column to be collect.
* @return The array.
*/
def collect_list(c: Column): Column = array_agg(c)

/**
*
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
* ARRAY is returned.
*
* Example::
* >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show()
* ------------
* |"RESULT" |
* ------------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* ------------
* @since 1.10.0
* @param s Column name to be collected.
* @return The array.
*/
def collect_list(s: String): Column = array_agg(col(s))

* Returns a Column expression with values sorted in descending order.
* Example:
* {{{
Expand Down Expand Up @@ -3312,6 +3498,7 @@ object functions {
def last(c: Column): Column =
builtin("LAST_VALUE")(c)


/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
49 changes: 49 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2766,6 +2766,54 @@ public void any_value() {
}

@Test

public void regexp_extract() {
DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)");
Row[] expected = {Row.create("MAN")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 1, 1)), expected, false);
Row[] expected2 = {Row.create("PLAN")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected2, false);
Row[] expected3 = {Row.create("CANAL")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected3, false);
Row[] expected4 = {Row.create(null)};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 3, 1)), expected4, false);
}

@Test
public void signum() {
DataFrame df = getSession().sql("select * from values(1,-2,0) as T(a)");
checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1, -1, 0)}, false);
}

@Test
public void sign() {
DataFrame df = getSession().sql("select * from values(1,-2,0) as T(a)");
checkAnswer(df.select(Functions.sign(df.col("a"))), new Row[] {Row.create(1, -1, 0)}, false);
}

@Test
public void collect_list() {
DataFrame df = getSession().sql("select * from values(10000,400,450) as T(a)");
checkAnswer(
df.select(Functions.collect_list(df.col("a"))),
new Row[] {Row.create("[\n \"10000,400,450\"\n]")},
false);
}

@Test
public void substring_index() {
DataFrame df =
getSession()
.sql(
"select * from values ('It was the best of times,it was the worst of times') as T(a)");
checkAnswer(
df.select(Functions.substring_index(df.col("a"), "was", 1)),
new Row[] {Row.create(7)},

public void test_asc() {
DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
Row[] expected = {Row.create(1), Row.create(2), Row.create(3)};
Expand Down Expand Up @@ -2826,6 +2874,7 @@ public void last() {
Functions.last(df.col("name"))
.over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))),
expected,

false);
}
}
Loading

0 comments on commit d14ff9f

Please sign in to comment.