Skip to content

Commit

Permalink
add ntile randn and locate function for java with test case
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-gmahadevan committed Aug 20, 2024
1 parent e6a0da8 commit 41d8e1a
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 0 deletions.
124 changes: 124 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -4180,6 +4180,130 @@ public static Column unbase64(Column c) {
return new Column(functions.unbase64(c.toScalaColumn()));
}

/**
 * Finds where {@code substr} first occurs inside {@code str}, with the search starting at
 * position {@code pos}. The work is delegated to the Scala {@code functions.locate} helper.
 *
 * <p>Positions in the result are 1-based, as the sample below shows ('b' in 'abcd' yields 2).
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values ('scala', 'java scala python'), \n " +
 *   "('b', 'abcd') as T(a,b)");
 * df.select(Functions.locate(Functions.col("a"), Functions.col("b"), 1).as("locate")).show();
 * ------------
 * |"LOCATE"  |
 * ------------
 * |6         |
 * |2         |
 * ------------
 * }</pre>
 *
 * @since 1.14.0
 * @param substr column holding the substring to look for
 * @param str column holding the text that is searched
 * @param pos position at which the search starts (the example above uses 1)
 * @return a column containing the position of the first occurrence
 */
public static Column locate(Column substr, Column str, int pos) {
  // Unwrap both Java columns, delegate to the Scala API, and re-wrap the result.
  return new Column(
      functions.locate(
          substr.toScalaColumn(),
          str.toScalaColumn(),
          pos));
}

/**
 * Finds where the literal {@code substr} first occurs inside the string column {@code str}.
 *
 * <p>This overload takes no position argument; the search start passed to the underlying
 * {@code functions.locate} call is fixed at 1.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values ('abcd') as T(s)");
 * df.select(Functions.locate("b", Functions.col("s")).as("locate")).show();
 * ------------
 * |"LOCATE"  |
 * ------------
 * |2         |
 * ------------
 * }</pre>
 *
 * @since 1.14.0
 * @param substr the substring to look for
 * @param str column holding the text that is searched
 * @return a column containing the position of the first occurrence
 */
public static Column locate(String substr, Column str) {
  // Fixed search start used by this two-argument overload.
  final int startPosition = 1;
  return new Column(functions.locate(substr, str.toScalaColumn(), startPosition));
}

/**
 * Window function that assigns each row of an ordered window partition to one of {@code n}
 * groups and returns the group id, numbered from 1 to {@code n} inclusive. With {@code n = 4},
 * for instance, the first quarter of the rows gets 1, the second quarter 2, the third quarter 3
 * and the final quarter 4.
 *
 * <p>Equivalent to the SQL NTILE function.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1,2),(1,2),(2,1),(2,2),(2,2) as T(x,y)");
 * df.select(Functions.ntile(4).over(Window.partitionBy(df.col("x")).orderBy(df.col("y"))).as("ntile")).show();
 * -----------
 * |"NTILE"  |
 * -----------
 * |1        |
 * |2        |
 * |3        |
 * |1        |
 * |2        |
 * -----------
 * }</pre>
 *
 * @since 1.14.0
 * @param n number of groups to split each partition into
 * @return a column holding the ntile group id (1 to {@code n} inclusive) within an ordered
 *     window partition
 */
public static Column ntile(int n) {
  // Wrap the Scala-side column produced by functions.ntile in the Java Column type.
  return new Column(
      functions.ntile(n));
}

/**
 * Produces a column of independently generated random values, one per row, by delegating to
 * {@code functions.randn()}. The result is a call to the Snowflake RANDOM function.
 *
 * <p>NOTE: Snowflake returns integers of 17-19 digits, as the sample output shows.
 *
 * <p>NOTE(review): the name suggests samples from the standard normal distribution, but the
 * integer sample output does not look like standard-normal draws; confirm the intended
 * distribution before relying on it.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
 * df.withColumn("randn",Functions.randn()).select("randn").show();
 * ------------------------
 * |"RANDN"               |
 * ------------------------
 * |6799378361097866000   |
 * |-7280487148628086605  |
 * |775606662514393461    |
 * ------------------------
 * }</pre>
 *
 * @since 1.14.0
 * @return a column of random numbers
 */
public static Column randn() {
  // No seed is supplied in this overload.
  return new Column(
      functions.randn());
}

/**
 * Produces a column of random values, one per row, generated from the given seed by delegating
 * to {@code functions.randn(seed)}. The result is a call to the Snowflake RANDOM function.
 *
 * <p>NOTE: Snowflake returns integers of 17-19 digits, as the sample output shows.
 *
 * <p>NOTE(review): the name suggests samples from the standard normal distribution, but the
 * integer sample output does not look like standard-normal draws; confirm the intended
 * distribution before relying on it.
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
 * df.withColumn("randn_with_seed",Functions.randn(123L)).select("randn_with_seed").show();
 * ------------------------
 * |"RANDN_WITH_SEED"     |
 * ------------------------
 * |5777523539921853504   |
 * |-8190739547906189845  |
 * |-1138438814981368515  |
 * ------------------------
 * }</pre>
 *
 * @since 1.14.0
 * @param seed seed value forwarded to the random generator
 * @return a column of random numbers
 */
public static Column randn(long seed) {
  // Forward the caller's seed unchanged to the Scala API.
  return new Column(
      functions.randn(seed));
}

/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
54 changes: 54 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2874,4 +2874,58 @@ public void unbase64() {
Row[] expected = {Row.create("test")};
checkAnswer(df.select(Functions.unbase64(Functions.col("a"))), expected, false);
}

@Test
public void locate_int() {
  // Two rows: column a holds the substring, column b holds the text searched.
  String query =
      "select * from values ('scala', 'java scala python'), \n "
          + "('b', 'abcd') as T(a,b)";
  DataFrame df = getSession().sql(query);

  DataFrame located =
      df.select(Functions.locate(Functions.col("a"), Functions.col("b"), 1).as("locate"));

  // 'scala' begins at position 6 of 'java scala python'; 'b' sits at position 2 of 'abcd'.
  Row[] expected = new Row[] {Row.create(6), Row.create(2)};
  checkAnswer(located, expected, false);
}

@Test
public void locate() {
  // Single-row input; exercises the (String, Column) overload of Functions.locate.
  DataFrame df = getSession().sql("select * from values ('abcd') as T(s)");

  DataFrame located = df.select(Functions.locate("b", Functions.col("s")).as("locate"));

  // 'b' is the second character of 'abcd'.
  Row[] expected = new Row[] {Row.create(2)};
  checkAnswer(located, expected, false);
}

@Test
public void ntile_int() {
  // Two partitions on x: x=1 has two rows, x=2 has three rows.
  DataFrame df = getSession().sql("select * from values(1,2),(1,2),(2,1),(2,2),(2,2) as T(x,y)");

  DataFrame tiled =
      df.select(Functions.ntile(4).over(Window.partitionBy(df.col("x")).orderBy(df.col("y"))));

  // NOTE(review): the expected values list the three-row partition first; if the final
  // argument of checkAnswer disables sorting, this depends on result order - confirm.
  Row[] expected =
      new Row[] {Row.create(1), Row.create(2), Row.create(3), Row.create(1), Row.create(2)};
  checkAnswer(tiled, expected, false);
}

@Test
public void randn() {
  // Three input rows; the generated values are random, so only the row count is checked.
  DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");

  Row[] rows = df.withColumn("randn", Functions.randn()).select("randn").collect();

  // Throw explicitly rather than use the `assert` keyword, which is skipped when the JVM runs
  // without -ea. Checking the collected row count also avoids comparing the result of first()
  // to null, which presumably can never fail if first() returns an Optional - TODO confirm.
  if (rows.length != 3) {
    throw new AssertionError("expected 3 rows with a randn column, got " + rows.length);
  }
}

@Test
public void randn_seed() {
  DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");

  // A fixed seed is used so the generated values can be compared against known numbers.
  DataFrame seeded =
      df.withColumn("randn_with_seed", Functions.randn(123L)).select("randn_with_seed");

  Row[] expected =
      new Row[] {
        Row.create(5777523539921853504L),
        Row.create(-8190739547906189845L),
        Row.create(-1138438814981368515L)
      };
  checkAnswer(seeded, expected, false);
}
}

0 comments on commit 41d8e1a

Please sign in to comment.