Skip to content

Commit

Permalink
add ntile randn and locate function for scala with test case
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-gmahadevan committed Aug 20, 2024
1 parent 833ef6d commit de56820
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 7 deletions.
136 changes: 129 additions & 7 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,8 @@ package com.snowflake.snowpark

import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.internal.ScalaFunctions._
import com.snowflake.snowpark.internal.{
ErrorMessage,
OpenTelemetry,
UDXRegistrationHandler,
Utils
}
import com.snowflake.snowpark.types.TimestampType
import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils}
import com.snowflake.snowpark.types.{FloatType, TimestampType}

import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
Expand Down Expand Up @@ -3433,6 +3428,133 @@ object functions {
*/
def unbase64(col: Column): Column = callBuiltin("BASE64_DECODE_STRING", col)

/**
*
* Locate the position of the first occurrence of substr in a string column, after position pos.
*
* @note The position is not zero based, but 1 based index. returns 0 if substr
* could not be found in str. This function is just leverages the SF POSITION builtin
*Example
* {{{
* val df = session.createDataFrame(Seq(("b", "abcd"))).toDF("a", "b")
* df.select(locate(col("a"), col("b"), 1).as("locate")).show()
* ------------
* |"LOCATE" |
* ------------
* |2 |
* ------------
*
* }}}
* @since 1.14.0
* @param substr string to search
* @param str value where string will be searched
* @param pos index for starting the search
* @return returns the position of the first occurrence.
*/
def locate(substr: Column, str: Column, pos: Int): Column =
if (pos == 0) lit(0) else callBuiltin("POSITION", substr, str, pos)

/**
* Locate the position of the first occurrence of substr in a string column, after position pos.
*
* @note The position is not zero based, but 1 based index. returns 0 if substr
* could not be found in str. This function is just leverages the SF POSITION builtin
* Example
* {{{
* val df = session.createDataFrame(Seq("java scala python")).toDF("a")
* df.select(locate("scala", col("a")).as("locate")).show()
* ------------
* |"LOCATE" |
* ------------
* |6 |
* ------------
*
* }}}
* @since 1.14.0
* @param substr string to search
* @param str value where string will be searched
* @param pos index for starting the search
* @return Returns the position of the first occurrence
*/
def locate(substr: String, str: Column, pos: Int = 1): Column =
if (pos == 0) lit(0) else callBuiltin("POSITION", lit(substr), str, lit(pos))

/**
* Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window
* partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second
* quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
*
* This is equivalent to the NTILE function in SQL.
* Example
* {{{
* val df = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score")
* val window = Window.partitionBy(col("grade")).orderBy(col("score"))
* df.select(ntile(2).over(window).as("ntile")).show()
* -----------
* |"NTILE" |
* -----------
* |1 |
* |1 |
* |2 |
* |2 |
* -----------
* }}}
*
* @since 1.14.0
* @param n number of groups
* @return returns the ntile group id (from 1 to n inclusive) in an ordered window partition.
*/
def ntile(n: Int): Column = callBuiltin("ntile", lit(n))

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the standard normal distribution.
* Return a call to the Snowflake RANDOM function.
* NOTE: Snowflake returns integers of 17-19 digits.
* Example
* {{{
* val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
* df.withColumn("randn", randn()).select("randn").show()
* ------------------------
* |"RANDN" |
* ------------------------
* |-2093909082984812541 |
* |-1379817492278593383 |
* |-1231198046297539927 |
* ------------------------
* }}}
*
* @since 1.14.0
* @return Random number.
*/
def randn(): Column =
builtin("RANDOM")()

/**
* Generate a column with independent and identically distributed (i.i.d.) samples
* from the standard normal distribution.
* Calls to the Snowflake RANDOM function.
* NOTE: Snowflake returns integers of 17-19 digits.
*Example
* {{{
* val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
* df.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed").show()
* ------------------------
* |"RANDN_WITH_SEED" |
* ------------------------
* |5777523539921853504 |
* |-8190739547906189845 |
* |-1138438814981368515 |
* ------------------------
* }}}
*
* @since 1.14.0
* @param seed Seed to use in the random function.
* @return Random number.
*/
def randn(seed: Long): Column =
builtin("RANDOM")(seed)

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
37 changes: 37 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.node.ArrayNode
import com.snowflake.snowpark._
import com.snowflake.snowpark.functions.{repeat, _}
import com.snowflake.snowpark.types._
import com.snowflake.snowpark_java.types.LongType
import net.snowflake.client.jdbc.SnowflakeSQLException

import java.sql.{Date, Time, Timestamp}
Expand Down Expand Up @@ -2280,6 +2281,42 @@ trait FunctionSuite extends TestData {
checkAnswer(input.select(unbase64(col("a")).as("unbase64")), expected, sort = false)
}

test("locate Column function") {
val input = session.createDataFrame(Seq
(("scala", "java scala python"), ("b", "abcd"))).toDF("a", "b")
val expected = Seq((6), (2)).toDF("locate")
checkAnswer(input.select(locate(col("a"), col("b"), 1).as("locate")), expected, sort = false)
}

test("locate String function") {

val input = session.createDataFrame(Seq("java scala python")).toDF("a")
val expected = Seq(6).toDF("locate")
checkAnswer(input.select(locate("scala", col("a")).as("locate")), expected, sort = false)
}

test("ntile function") {
val input = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score")
val window = Window.partitionBy(col("grade")).orderBy(col("score"))
val expected = Seq((1), (1), (2), (2)).toDF("ntile")
checkAnswer(input.select(ntile(2).over(window).as("ntile")), expected, sort = false)
}

test("randn seed function") {
val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
val expected = Seq((5777523539921853504L), (-8190739547906189845L), (-1138438814981368515L))
.toDF("randn_with_seed")
val df = input.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed")

checkAnswer(df, expected, sort = false)
}

test("randn function") {
val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a")

assert(input.withColumn("randn", randn()).select("randn").first() != null)
}

}

class EagerFunctionSuite extends FunctionSuite with EagerSession
Expand Down

0 comments on commit de56820

Please sign in to comment.