Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-802269 Add regextract signum subindex collectlist functions #142

122 changes: 121 additions & 1 deletion src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -3882,7 +3882,127 @@ public static Column listagg(Column col) {
}

/**
* Returns a Column expression with values sorted in descending order.
* Signature - snowflake.snowpark.functions.regexp_extract (value: Union[Column, str], regexp:
* Union[Column, str], idx: int) Column Extract a specific group matched by a regex, from the
* specified string column. If the regex did not match, or the specified group did not match, an
* empty string is returned. Example:
*
* <pre>{@code
* from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
* ---------
* |"RES" |
* ---------
* |20 |
* |40 |
* ---------
* }</pre>
*
* @since 1.14.0
* @param col Column.
* @param exp String
* @param position Integer.
* @param Occurences Integer.
* @param grpIdx Integer.
* @return Column object.
*/
public static Column regexp_extract(
Column col, String exp, Integer position, Integer Occurences, Integer grpIdx) {
return new Column(
com.snowflake.snowpark.functions.regexp_extract(
col.toScalaColumn(), exp, position, Occurences, grpIdx));
}

/**
* Returns the sign of its argument:
*
* <p>- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0.
*
* <p>Args: col: The column to evaluate its sign Example:: *
*
* <pre>{@code df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* }</pre>
*
* @since 1.14.0
* @param col Column to calculate the sign.
* @return Column object.
*/
public static Column signum(Column col) {
return new Column(com.snowflake.snowpark.functions.signum(col.toScalaColumn()));
}

/**
* Returns the sign of its argument:
*
* <p>- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0.
*
* <p>Args: col: The column to evaluate its sign Example::
*
* <pre>{@code df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* }</pre>
*
* @since 1.14.0
* @param col Column to calculate the sign.
* @return Column object.
*/
public static Column sign(Column col) {
return new Column(com.snowflake.snowpark.functions.sign(col.toScalaColumn()));
}

/**
* Returns the substring from string str before count occurrences of the delimiter delim. If count
* is positive, everything the left of the final delimiter (counting from left) is returned. If
* count is negative, every to the right of the final delimiter (counting from the right) is
* returned. substring_index performs a case-sensitive match when searching for delim.
*
* @param col String.
* @param delim String
* @param count Integer.
* @return Column object.
* @since 1.14.0
*/
public static Column substring_index(String col, String delim, Integer count) {
return new Column(com.snowflake.snowpark.functions.substring_index(col, delim, count));
}

/**
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
* returned.
*
* <p>Example::
*
* <pre>{@code
* df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* df.select(array_agg("a", True).alias("result")).show()
* "RESULT" [ 1, 2, 3 ]
* }</pre>
*
* @since 1.14.0
* @param c Column to be collect.
* @return The array.
*/
public static Column collect_list(Column c) {
return new Column(com.snowflake.snowpark.functions.collect_list(c.toScalaColumn()));
}

/* Returns a Column expression with values sorted in descending order.
*
* <p>Example: order column values in descending
*
Expand Down
194 changes: 193 additions & 1 deletion src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3142,7 +3142,199 @@ object functions {
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
* Returns a Column expression with values sorted in descending order.

* Signature - snowflake.snowpark.functions.regexp_extract
* (value: Union[Column, str], regexp: Union[Column, str], idx: int)
* Column
* Extract a specific group matched by a regex, from the specified string
* column. If the regex did not match, or the specified group did not match,
* an empty string is returned.
* <pr>Example:
* from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]],
* ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
*</pr>
*<pr>
* ---------
* |"RES" |
* ---------
* |20 |
* |40 |
* ---------
*</pr>
* Note: non-greedy tokens such as are not supported
* @since 1.14.0
* @return Column object.
*/
def regexp_extract(
colName: Column,
exp: String,
position: Int,
Occurences: Int,
grpIdx: Int): Column = {
when(colName.is_null, lit(null))
.otherwise(
coalesce(
builtin("REGEXP_SUBSTR")(
colName,
lit(exp),
lit(position),
lit(Occurences),
lit("ce"),
lit(grpIdx)),
lit("")))
}

/**
* Returns the sign of its argument as mentioned :
*
* - -1 if the argument is negative.
* - 1 if it is positive.
* - 0 if it is 0.
*
* Args:
* col: The column to evaluate its sign
*<pr>
* Example::
* >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"])
* >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* </pr>
* @since 1.14.0
* @param e Column to calculate the sign.
* @return Column object.
*/
def sign(colName: Column): Column = {
builtin("SIGN")(colName)
}

/**
* Returns the sign of its argument:
*
* - -1 if the argument is negative.
* - 1 if it is positive.
* - 0 if it is 0.
*
* Args:
* col: The column to evaluate its sign
*<pr>
* Example::
* >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"])
* >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* </pr>
* @since 1.14.0
* @param e Column to calculate the sign.
* @return Column object.
*/
def signum(colName: Column): Column = {
builtin("SIGN")(colName)
}

/**
* Returns the sign of the given column. Returns either 1 for positive,
* 0 for 0 or
* NaN, -1 for negative and null for null.
* NOTE: if string values are provided snowflake will attempts to cast.
* If it casts correctly, returns the calculation,
* if not an error will be thrown
* @since 1.14.0
* @param columnName Name of the column to calculate the sign.
* @return Column object.
*/
def signum(columnName: String): Column = {
signum(col(columnName))
}

/**
* Returns the substring from string str before count occurrences
* of the delimiter delim. If count is positive,
* everything the left of the final delimiter (counting from left)
* is returned. If count is negative, every to the right of the
* final delimiter (counting from the right) is returned.
* substring_index performs a case-sensitive match when searching for delim.
* @since 1.14.0
*/
def substring_index(str: String, delim: String, count: Int): Column = {
when(
lit(count) < lit(0),
callBuiltin(
"substring",
lit(str),
callBuiltin(
"regexp_instr",
sqlExpr(s"reverse('${str}')"),
lit(delim),
1,
abs(lit(count)),
lit(0))))
.otherwise(
callBuiltin(
"substring",
lit(str),
1,
callBuiltin("regexp_instr", lit(str), lit(delim), 1, lit(count), 1)))
}

/**
*
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
* ARRAY is returned.
*<pr>
* Example::
* >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show()
* ------------
* |"RESULT" |
* ------------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* ------------
* </pr>
* @since 1.14.0
* @param c Column to be collect.
* @return The array.
*/
def collect_list(c: Column): Column = array_agg(c)

/**
*
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
* ARRAY is returned.
*
* Example::
* >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show()
* ------------
* |"RESULT" |
* ------------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* ------------
* @since 1.14.0
* @param s Column name to be collected.
* @return The array.
*/
def collect_list(s: String): Column = array_agg(col(s))

/* Returns a Column expression with values sorted in descending order.
* Example:
* {{{
* val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
Expand Down
53 changes: 53 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2766,6 +2766,59 @@ public void any_value() {
}

@Test
public void regexp_extract() {
DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)");
Row[] expected = {Row.create("MAN")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 1, 1)), expected, false);
Row[] expected2 = {Row.create("PLAN")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected2, false);
Row[] expected3 = {Row.create("CANAL")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 3, 1)), expected3, false);
}

@Test
public void signum() {
DataFrame df = getSession().sql("select * from values(1) as T(a)");
checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false);
DataFrame df1 = getSession().sql("select * from values(-2) as T(a)");
checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false);
DataFrame df2 = getSession().sql("select * from values(0) as T(a)");
checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false);
}

@Test
public void sign() {
DataFrame df = getSession().sql("select * from values(1) as T(a)");
checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false);
DataFrame df1 = getSession().sql("select * from values(-2) as T(a)");
checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false);
DataFrame df2 = getSession().sql("select * from values(0) as T(a)");
checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false);
}

@Test
public void collect_list() {
DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)");
df.select(Functions.collect_list(df.col("a"))).show();
}

@Test
public void substring_index() {
DataFrame df =
getSession()
.sql(
"select * from values ('It was the best of times,it was the worst of times') as T(a)");
checkAnswer(
df.select(
Functions.substring_index(
"It was the best of times,it was the worst of times", "was", 1)),
new Row[] {Row.create("It was ")},
false);
}

public void test_asc() {
DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
Row[] expected = {Row.create(1), Row.create(2), Row.create(3)};
Expand Down
Loading
Loading