Skip to content

Commit

Permalink
Snow 802269 add locate, randn and ntile functions (#146)
Browse files Browse the repository at this point in the history
* add ntile randn and locate function for scala with test case

* reformat

* add ntile randn and locate function for java with test case

* organize imports

* rebuild pipeline
  • Loading branch information
sfc-gh-gmahadevan authored Aug 21, 2024
1 parent a1babb3 commit 172caca
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 0 deletions.
125 changes: 125 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -4300,6 +4300,131 @@ public static Column unbase64(Column c) {
return new Column(functions.unbase64(c.toScalaColumn()));
}

/**
 * Finds the position of the first occurrence of {@code substr} inside {@code str}, starting
 * the search at position {@code pos}.
 *
 * <p>This is a thin Java wrapper: both columns are converted to their Scala counterparts and
 * the call is forwarded to the Scala {@code functions.locate}.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values ('scala', 'java scala python'), \n " +
 *   "('b', 'abcd') as T(a,b)");
 * df.select(Functions.locate(Functions.col("a"), Functions.col("b"), 1).as("locate")).show();
 * ------------
 * |"LOCATE"  |
 * ------------
 * |6         |
 * |2         |
 * ------------
 * }</pre>
 *
 * @since 1.14.0
 * @param substr column holding the string to search for
 * @param str column holding the value in which the string is searched
 * @param pos position at which the search starts
 * @return a column containing the position of the first occurrence
 */
public static Column locate(Column substr, Column str, int pos) {
  com.snowflake.snowpark.Column needle = substr.toScalaColumn();
  com.snowflake.snowpark.Column haystack = str.toScalaColumn();
  return new Column(functions.locate(needle, haystack, pos));
}

/**
 * Finds the position of the first occurrence of the string {@code substr} inside {@code str}.
 * The search always starts at position 1.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values ('abcd') as T(s)");
 * df.select(Functions.locate("b", Functions.col("s")).as("locate")).show();
 * ------------
 * |"LOCATE"  |
 * ------------
 * |2         |
 * ------------
 * }</pre>
 *
 * @since 1.14.0
 * @param substr string to search for
 * @param str column holding the value in which the string is searched
 * @return a column containing the position of the first occurrence
 */
public static Column locate(String substr, Column str) {
  // This overload has no position argument, so the search starts at the first position.
  final int startPosition = 1;
  return new Column(functions.locate(substr, str.toScalaColumn(), startPosition));
}

/**
 * Window function that assigns each row of an ordered window partition to one of {@code n}
 * groups and returns the group id (from 1 to {@code n} inclusive). For example, with {@code n}
 * equal to 4, the first quarter of the rows gets 1, the second quarter 2, the third quarter 3
 * and the last quarter 4.
 *
 * <p>This is equivalent to the NTILE function in SQL.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1,2),(1,2),(2,1),(2,2),(2,2) as T(x,y)");
 * df.select(Functions.ntile(4).over(Window.partitionBy(df.col("x")).orderBy(df.col("y"))).as("ntile")).show();
 * -----------
 * |"NTILE"  |
 * -----------
 * |1        |
 * |2        |
 * |3        |
 * |1        |
 * |2        |
 * -----------
 * }</pre>
 *
 * @since 1.14.0
 * @param n number of groups
 * @return a column with the ntile group id (from 1 to n inclusive) within an ordered window
 *     partition
 */
public static Column ntile(int n) {
  com.snowflake.snowpark.Column groupId = functions.ntile(n);
  return new Column(groupId);
}

/**
 * Generates a column of random values by delegating to the Scala {@code functions.randn()},
 * which is documented as a call to the Snowflake RANDOM function.
 *
 * <p>NOTE: as the example output below shows, the values are large integers (17-19 digits).
 * Do not rely on them being floating-point samples from a standard normal distribution.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
 * df.withColumn("randn",Functions.randn()).select("randn").show();
 * ------------------------
 * |"RANDN"               |
 * ------------------------
 * |6799378361097866000   |
 * |-7280487148628086605  |
 * |775606662514393461    |
 * ------------------------
 * }</pre>
 *
 * @since 1.14.0
 * @return a column of random numbers
 */
public static Column randn() {
  com.snowflake.snowpark.Column random = functions.randn();
  return new Column(random);
}

/**
 * Generates a column of random values from the given seed by delegating to the Scala {@code
 * functions.randn(seed)}, which is documented as a call to the Snowflake RANDOM function.
 *
 * <p>NOTE: as the example output below shows, the values are large integers (17-19 digits).
 * Do not rely on them being floating-point samples from a standard normal distribution.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
 * df.withColumn("randn_with_seed",Functions.randn(123L)).select("randn_with_seed").show();
 * ------------------------
 * |"RANDN_WITH_SEED"     |
 * ------------------------
 * |5777523539921853504   |
 * |-8190739547906189845  |
 * |-1138438814981368515  |
 * ------------------------
 * }</pre>
 *
 * @since 1.14.0
 * @param seed seed to use in the random function
 * @return a column of random numbers
 */
public static Column randn(long seed) {
  com.snowflake.snowpark.Column seededRandom = functions.randn(seed);
  return new Column(seededRandom);
}

/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
127 changes: 127 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3625,6 +3625,133 @@ object functions {
*/
def unbase64(col: Column): Column = callBuiltin("BASE64_DECODE_STRING", col)

/**
 * Returns the position of the first occurrence of `substr` in the string column `str`,
 * with the search starting at position `pos`.
 *
 * A `pos` of 0 yields a literal 0 column without calling Snowflake; any other value is
 * forwarded to the Snowflake POSITION builtin.
 *
 * @note Positions are 1-based, not 0-based. The result is 0 if `substr`
 *       could not be found in `str`.
 *
 * Example
 * {{{
 * val df = session.createDataFrame(Seq(("b", "abcd"))).toDF("a", "b")
 * df.select(locate(col("a"), col("b"), 1).as("locate")).show()
 * ------------
 * |"LOCATE"  |
 * ------------
 * |2         |
 * ------------
 *
 * }}}
 * @since 1.14.0
 * @param substr string to search for
 * @param str value in which the string is searched
 * @param pos index at which the search starts
 * @return the position of the first occurrence
 */
def locate(substr: Column, str: Column, pos: Int): Column =
  pos match {
    case 0 => lit(0)
    case _ => callBuiltin("POSITION", substr, str, pos)
  }

/**
 * Returns the position of the first occurrence of the string `substr` in the string column
 * `str`, with the search starting at position `pos`.
 *
 * A `pos` of 0 yields a literal 0 column without calling Snowflake; any other value is
 * forwarded, together with `substr`, as a literal to the Snowflake POSITION builtin.
 *
 * @note Positions are 1-based, not 0-based. The result is 0 if `substr`
 *       could not be found in `str`.
 *
 * Example
 * {{{
 * val df = session.createDataFrame(Seq("java scala python")).toDF("a")
 * df.select(locate("scala", col("a")).as("locate")).show()
 * ------------
 * |"LOCATE"  |
 * ------------
 * |6         |
 * ------------
 *
 * }}}
 * @since 1.14.0
 * @param substr string to search for
 * @param str value in which the string is searched
 * @param pos index at which the search starts; defaults to 1
 * @return the position of the first occurrence
 */
def locate(substr: String, str: Column, pos: Int = 1): Column =
  pos match {
    case 0 => lit(0)
    case _ => callBuiltin("POSITION", lit(substr), str, lit(pos))
  }

/**
 * Window function that assigns each row of an ordered window partition to one of `n` groups
 * and returns the group id (from 1 to `n` inclusive). For example, if `n` is 4, the first
 * quarter of the rows gets 1, the second quarter 2, the third quarter 3, and the last
 * quarter 4.
 *
 * This is equivalent to the NTILE function in SQL.
 *
 * Example
 * {{{
 * val df = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score")
 * val window = Window.partitionBy(col("grade")).orderBy(col("score"))
 * df.select(ntile(2).over(window).as("ntile")).show()
 * -----------
 * |"NTILE"  |
 * -----------
 * |1        |
 * |1        |
 * |2        |
 * |2        |
 * -----------
 * }}}
 *
 * @since 1.14.0
 * @param n number of groups
 * @return the ntile group id (from 1 to n inclusive) in an ordered window partition
 */
def ntile(n: Int): Column = {
  val groupCount = lit(n)
  callBuiltin("ntile", groupCount)
}

/**
 * Generates a column of random values by calling the Snowflake RANDOM function with no
 * arguments.
 *
 * NOTE: as the example output below shows, the values are large integers (17-19 digits).
 * Do not rely on them being floating-point samples from a standard normal distribution.
 *
 * Example
 * {{{
 * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
 * df.withColumn("randn", randn()).select("randn").show()
 * ------------------------
 * |"RANDN"               |
 * ------------------------
 * |-2093909082984812541  |
 * |-1379817492278593383  |
 * |-1231198046297539927  |
 * ------------------------
 * }}}
 *
 * @since 1.14.0
 * @return a column of random numbers
 */
def randn(): Column = {
  val random = builtin("RANDOM")
  random()
}

/**
 * Generates a column of random values by calling the Snowflake RANDOM function with the
 * given seed.
 *
 * NOTE: as the example output below shows, the values are large integers (17-19 digits).
 * Do not rely on them being floating-point samples from a standard normal distribution.
 *
 * Example
 * {{{
 * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
 * df.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed").show()
 * ------------------------
 * |"RANDN_WITH_SEED"     |
 * ------------------------
 * |5777523539921853504   |
 * |-8190739547906189845  |
 * |-1138438814981368515  |
 * ------------------------
 * }}}
 *
 * @since 1.14.0
 * @param seed seed to use in the random function
 * @return a column of random numbers
 */
def randn(seed: Long): Column = {
  val random = builtin("RANDOM")
  random(seed)
}

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
54 changes: 54 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2927,4 +2927,58 @@ public void unbase64() {
Row[] expected = {Row.create("test")};
checkAnswer(df.select(Functions.unbase64(Functions.col("a"))), expected, false);
}

// locate(Column, Column, int): searches column "a" inside column "b" starting at position 1.
@Test
public void locate_int() {
  String query =
      "select * from values ('scala', 'java scala python'), \n " + "('b', 'abcd') as T(a,b)";
  DataFrame source = getSession().sql(query);
  DataFrame located =
      source.select(Functions.locate(Functions.col("a"), Functions.col("b"), 1).as("locate"));

  checkAnswer(located, new Row[] {Row.create(6), Row.create(2)}, false);
}

// locate(String, Column): "b" is expected at position 2 of "abcd".
@Test
public void locate() {
  DataFrame source = getSession().sql("select * from values ('abcd') as T(s)");
  DataFrame located = source.select(Functions.locate("b", Functions.col("s")).as("locate"));

  checkAnswer(located, new Row[] {Row.create(2)}, false);
}

// ntile(4) over a window partitioned by x and ordered by y.
@Test
public void ntile_int() {
  DataFrame source =
      getSession().sql("select * from values(1,2),(1,2),(2,1),(2,2),(2,2) as T(x,y)");
  DataFrame tiled =
      source.select(
          Functions.ntile(4)
              .over(Window.partitionBy(source.col("x")).orderBy(source.col("y"))));

  Row[] expected =
      new Row[] {Row.create(1), Row.create(2), Row.create(3), Row.create(1), Row.create(2)};
  checkAnswer(tiled, expected, false);
}

// Unseeded randn() is not reproducible, so this only checks that the query yields a row.
@Test
public void randn() {
  DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");

  // first() returns an Optional, which is never null itself; a "!= null" check would pass
  // even when no row comes back. Assert that a row is actually present instead.
  assert (df.withColumn("randn", Functions.randn()).select("randn").first().isPresent());
}

// A fixed seed is expected to give the same three values on every run.
@Test
public void randn_seed() {
  DataFrame source = getSession().sql("select * from values(1),(2),(3) as T(a)");
  DataFrame seeded =
      source.withColumn("randn_with_seed", Functions.randn(123L)).select("randn_with_seed");

  Row[] expected =
      new Row[] {
        Row.create(5777523539921853504L),
        Row.create(-8190739547906189845L),
        Row.create(-1138438814981368515L)
      };
  checkAnswer(seeded, expected, false);
}
}
37 changes: 37 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.node.ArrayNode
import com.snowflake.snowpark._
import com.snowflake.snowpark.functions.{repeat, _}
import com.snowflake.snowpark.types._
import com.snowflake.snowpark_java.types.LongType
import net.snowflake.client.jdbc.SnowflakeSQLException

import java.sql.{Date, Time, Timestamp}
Expand Down Expand Up @@ -2324,6 +2325,42 @@ trait FunctionSuite extends TestData {
checkAnswer(input.select(unbase64(col("a")).as("unbase64")), expected, sort = false)
}

test("locate Column function") {
  // Each row pairs the substring (column a) with the string to search (column b).
  val rows = Seq(("scala", "java scala python"), ("b", "abcd"))
  val input = session.createDataFrame(rows).toDF("a", "b")
  val actual = input.select(locate(col("a"), col("b"), 1).as("locate"))
  val expected = Seq(6, 2).toDF("locate")
  checkAnswer(actual, expected, sort = false)
}

test("locate String function") {
  // Uses the String overload with its default start position.
  val input = session.createDataFrame(Seq("java scala python")).toDF("a")
  val actual = input.select(locate("scala", col("a")).as("locate"))
  checkAnswer(actual, Seq(6).toDF("locate"), sort = false)
}

test("ntile function") {
  // Four rows in a single "grade" partition, split into two groups by ascending score.
  val input = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score")
  val byGradeOrderedByScore = Window.partitionBy(col("grade")).orderBy(col("score"))
  val actual = input.select(ntile(2).over(byGradeOrderedByScore).as("ntile"))
  checkAnswer(actual, Seq(1, 1, 2, 2).toDF("ntile"), sort = false)
}

test("randn seed function") {
  // A fixed seed is expected to give the same three values on every run.
  val input = session.createDataFrame(Seq(1, 2, 3)).toDF("a")
  val actual = input.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed")
  val expectedValues =
    Seq(5777523539921853504L, -8190739547906189845L, -1138438814981368515L)

  checkAnswer(actual, expectedValues.toDF("randn_with_seed"), sort = false)
}

test("randn function") {
  // Unseeded randn() is not reproducible, so this only checks that the query yields a row.
  val input = session.createDataFrame(Seq(1, 2, 3)).toDF("a")

  // first() returns an Option, so comparing it with null always passes, even when no row
  // comes back. Assert that a row is actually defined instead.
  assert(input.withColumn("randn", randn()).select("randn").first().isDefined)
}

}

// Concrete suite: all FunctionSuite tests with the EagerSession trait mixed in.
class EagerFunctionSuite extends FunctionSuite with EagerSession
Expand Down

0 comments on commit 172caca

Please sign in to comment.