Skip to content

Commit

Permalink
Merge branch 'main' into SNOW-1628247
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yuwang authored Aug 20, 2024
2 parents 508c269 + a1babb3 commit 2f6f639
Show file tree
Hide file tree
Showing 4 changed files with 411 additions and 2 deletions.
122 changes: 121 additions & 1 deletion src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -3882,7 +3882,127 @@ public static Column listagg(Column col) {
}

/**
* Returns a Column expression with values sorted in descending order.
* Signature - snowflake.snowpark.functions.regexp_extract (value: Union[Column, str], regexp:
* Union[Column, str], idx: int) Column Extract a specific group matched by a regex, from the
* specified string column. If the regex did not match, or the specified group did not match, an
* empty string is returned. Example:
*
* <pre>{@code
* from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
* ---------
* |"RES" |
* ---------
* |20 |
* |40 |
* ---------
* }</pre>
*
* @since 1.14.0
* @param col Column.
* @param exp String
* @param position Integer.
* @param Occurences Integer.
* @param grpIdx Integer.
* @return Column object.
*/
public static Column regexp_extract(
Column col, String exp, Integer position, Integer Occurences, Integer grpIdx) {
return new Column(
com.snowflake.snowpark.functions.regexp_extract(
col.toScalaColumn(), exp, position, Occurences, grpIdx));
}

/**
* Returns the sign of its argument:
*
* <p>- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0.
*
* <p>Args: col: The column to evaluate its sign Example:: *
*
* <pre>{@code df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* }</pre>
*
* @since 1.14.0
* @param col Column to calculate the sign.
* @return Column object.
*/
public static Column signum(Column col) {
return new Column(com.snowflake.snowpark.functions.signum(col.toScalaColumn()));
}

/**
* Returns the sign of its argument:
*
* <p>- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0.
*
* <p>Args: col: The column to evaluate its sign Example::
*
* <pre>{@code df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* }</pre>
*
* @since 1.14.0
* @param col Column to calculate the sign.
* @return Column object.
*/
public static Column sign(Column col) {
return new Column(com.snowflake.snowpark.functions.sign(col.toScalaColumn()));
}

/**
* Returns the substring from string str before count occurrences of the delimiter delim. If count
* is positive, everything the left of the final delimiter (counting from left) is returned. If
* count is negative, every to the right of the final delimiter (counting from the right) is
* returned. substring_index performs a case-sensitive match when searching for delim.
*
* @param col String.
* @param delim String
* @param count Integer.
* @return Column object.
* @since 1.14.0
*/
public static Column substring_index(String col, String delim, Integer count) {
return new Column(com.snowflake.snowpark.functions.substring_index(col, delim, count));
}

/**
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
* returned.
*
* <p>Example::
*
* <pre>{@code
* df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* df.select(array_agg("a", True).alias("result")).show()
* "RESULT" [ 1, 2, 3 ]
* }</pre>
*
* @since 1.14.0
* @param c Column to be collect.
* @return The array.
*/
public static Column collect_list(Column c) {
return new Column(com.snowflake.snowpark.functions.collect_list(c.toScalaColumn()));
}

/* Returns a Column expression with values sorted in descending order.
*
* <p>Example: order column values in descending
*
Expand Down
194 changes: 193 additions & 1 deletion src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3142,7 +3142,199 @@ object functions {
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
* Returns a Column expression with values sorted in descending order.
* Signature - snowflake.snowpark.functions.regexp_extract
* (value: Union[Column, str], regexp: Union[Column, str], idx: int)
* Column
* Extract a specific group matched by a regex, from the specified string
* column. If the regex did not match, or the specified group did not match,
* an empty string is returned.
* <pr>Example:
* from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]],
* ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
*</pr>
*<pr>
* ---------
* |"RES" |
* ---------
* |20 |
* |40 |
* ---------
*</pr>
* Note: non-greedy tokens such as are not supported
* @since 1.14.0
* @return Column object.
*/
def regexp_extract(
colName: Column,
exp: String,
position: Int,
Occurences: Int,
grpIdx: Int): Column = {
when(colName.is_null, lit(null))
.otherwise(
coalesce(
builtin("REGEXP_SUBSTR")(
colName,
lit(exp),
lit(position),
lit(Occurences),
lit("ce"),
lit(grpIdx)),
lit("")))
}

/**
* Returns the sign of its argument as mentioned :
*
* - -1 if the argument is negative.
* - 1 if it is positive.
* - 0 if it is 0.
*
* Args:
* col: The column to evaluate its sign
*<pr>
* Example::
* >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"])
* >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* </pr>
* @since 1.14.0
* @param e Column to calculate the sign.
* @return Column object.
*/
def sign(colName: Column): Column = {
builtin("SIGN")(colName)
}

/**
* Returns the sign of its argument:
*
* - -1 if the argument is negative.
* - 1 if it is positive.
* - 0 if it is 0.
*
* Args:
* col: The column to evaluate its sign
*<pr>
* Example::
* >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"])
* >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* </pr>
* @since 1.14.0
* @param e Column to calculate the sign.
* @return Column object.
*/
def signum(colName: Column): Column = {
builtin("SIGN")(colName)
}

/**
* Returns the sign of the given column. Returns either 1 for positive,
* 0 for 0 or
* NaN, -1 for negative and null for null.
* NOTE: if string values are provided snowflake will attempts to cast.
* If it casts correctly, returns the calculation,
* if not an error will be thrown
* @since 1.14.0
* @param columnName Name of the column to calculate the sign.
* @return Column object.
*/
def signum(columnName: String): Column = {
signum(col(columnName))
}

/**
* Returns the substring from string str before count occurrences
* of the delimiter delim. If count is positive,
* everything the left of the final delimiter (counting from left)
* is returned. If count is negative, every to the right of the
* final delimiter (counting from the right) is returned.
* substring_index performs a case-sensitive match when searching for delim.
* @since 1.14.0
*/
def substring_index(str: String, delim: String, count: Int): Column = {
when(
lit(count) < lit(0),
callBuiltin(
"substring",
lit(str),
callBuiltin(
"regexp_instr",
sqlExpr(s"reverse('${str}')"),
lit(delim),
1,
abs(lit(count)),
lit(0))))
.otherwise(
callBuiltin(
"substring",
lit(str),
1,
callBuiltin("regexp_instr", lit(str), lit(delim), 1, lit(count), 1)))
}

/**
*
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
* ARRAY is returned.
*<pr>
* Example::
* >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show()
* ------------
* |"RESULT" |
* ------------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* ------------
* </pr>
* @since 1.14.0
* @param c Column to be collect.
* @return The array.
*/
def collect_list(c: Column): Column = array_agg(c)

/**
*
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
* ARRAY is returned.
*
* Example::
* >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show()
* ------------
* |"RESULT" |
* ------------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* ------------
* @since 1.14.0
* @param s Column name to be collected.
* @return The array.
*/
def collect_list(s: String): Column = array_agg(col(s))

/* Returns a Column expression with values sorted in descending order.
* Example:
* {{{
* val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
Expand Down
53 changes: 53 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2766,6 +2766,59 @@ public void any_value() {
}

@Test
public void regexp_extract() {
DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)");
Row[] expected = {Row.create("MAN")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 1, 1)), expected, false);
Row[] expected2 = {Row.create("PLAN")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected2, false);
Row[] expected3 = {Row.create("CANAL")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 3, 1)), expected3, false);
}

@Test
public void signum() {
DataFrame df = getSession().sql("select * from values(1) as T(a)");
checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false);
DataFrame df1 = getSession().sql("select * from values(-2) as T(a)");
checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false);
DataFrame df2 = getSession().sql("select * from values(0) as T(a)");
checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false);
}

@Test
public void sign() {
DataFrame df = getSession().sql("select * from values(1) as T(a)");
checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false);
DataFrame df1 = getSession().sql("select * from values(-2) as T(a)");
checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false);
DataFrame df2 = getSession().sql("select * from values(0) as T(a)");
checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false);
}

@Test
public void collect_list() {
DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)");
df.select(Functions.collect_list(df.col("a"))).show();
}

@Test
public void substring_index() {
DataFrame df =
getSession()
.sql(
"select * from values ('It was the best of times,it was the worst of times') as T(a)");
checkAnswer(
df.select(
Functions.substring_index(
"It was the best of times,it was the worst of times", "was", 1)),
new Row[] {Row.create("It was ")},
false);
}

public void test_asc() {
DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
Row[] expected = {Row.create(1), Row.create(2), Row.create(3)};
Expand Down
Loading

0 comments on commit 2f6f639

Please sign in to comment.