From de5682024cadf15595d1d3824f16911e4aa7cfd0 Mon Sep 17 00:00:00 2001 From: Ganesh Mahadevan Date: Tue, 20 Aug 2024 11:26:15 -0500 Subject: [PATCH] add ntile randn and locate function for scala with test case --- .../com/snowflake/snowpark/functions.scala | 136 +++++++++++++++++- .../snowpark_test/FunctionSuite.scala | 37 +++++ 2 files changed, 166 insertions(+), 7 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 160c3112..65df29d9 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -2,13 +2,8 @@ package com.snowflake.snowpark import com.snowflake.snowpark.internal.analyzer._ import com.snowflake.snowpark.internal.ScalaFunctions._ -import com.snowflake.snowpark.internal.{ - ErrorMessage, - OpenTelemetry, - UDXRegistrationHandler, - Utils -} -import com.snowflake.snowpark.types.TimestampType +import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils} +import com.snowflake.snowpark.types.{FloatType, TimestampType} import scala.reflect.runtime.universe.TypeTag import scala.util.Random @@ -3433,6 +3428,133 @@ object functions { */ def unbase64(col: Column): Column = callBuiltin("BASE64_DECODE_STRING", col) + /** + * + * Locate the position of the first occurrence of substr in a string column, after position pos. + * + * @note The position is not zero based, but 1 based index. returns 0 if substr + * could not be found in str. This function is just leverages the SF POSITION builtin + *Example + * {{{ + * val df = session.createDataFrame(Seq(("b", "abcd"))).toDF("a", "b") + * df.select(locate(col("a"), col("b"), 1).as("locate")).show() + * ------------ + * |"LOCATE" | + * ------------ + * |2 | + * ------------ + * + * }}} + * @since 1.14.0 + * @param substr string to search + * @param str value where string will be searched + * @param pos index for starting the search + * @return returns the position of the first occurrence. + */ + def locate(substr: Column, str: Column, pos: Int): Column = + if (pos == 0) lit(0) else callBuiltin("POSITION", substr, str, pos) + + /** + * Locate the position of the first occurrence of substr in a string column, after position pos. + * + * @note The position is not zero based, but 1 based index. returns 0 if substr + * could not be found in str. This function is just leverages the SF POSITION builtin + * Example + * {{{ + * val df = session.createDataFrame(Seq("java scala python")).toDF("a") + * df.select(locate("scala", col("a")).as("locate")).show() + * ------------ + * |"LOCATE" | + * ------------ + * |6 | + * ------------ + * + * }}} + * @since 1.14.0 + * @param substr string to search + * @param str value where string will be searched + * @param pos index for starting the search + * @return Returns the position of the first occurrence + */ + def locate(substr: String, str: Column, pos: Int = 1): Column = + if (pos == 0) lit(0) else callBuiltin("POSITION", lit(substr), str, lit(pos)) + + /** + * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second + * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + * This is equivalent to the NTILE function in SQL. + * Example + * {{{ + * val df = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score") + * val window = Window.partitionBy(col("grade")).orderBy(col("score")) + * df.select(ntile(2).over(window).as("ntile")).show() + * ----------- + * |"NTILE" | + * ----------- + * |1 | + * |1 | + * |2 | + * |2 | + * ----------- + * }}} + * + * @since 1.14.0 + * @param n number of groups + * @return returns the ntile group id (from 1 to n inclusive) in an ordered window partition. + */ + def ntile(n: Int): Column = callBuiltin("ntile", lit(n)) + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples + * from the standard normal distribution. + * Return a call to the Snowflake RANDOM function. + * NOTE: Snowflake returns integers of 17-19 digits. + * Example + * {{{ + * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + * df.withColumn("randn", randn()).select("randn").show() + * ------------------------ + * |"RANDN" | + * ------------------------ + * |-2093909082984812541 | + * |-1379817492278593383 | + * |-1231198046297539927 | + * ------------------------ + * }}} + * + * @since 1.14.0 + * @return Random number. + */ + def randn(): Column = + builtin("RANDOM")() + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples + * from the standard normal distribution. + * Calls to the Snowflake RANDOM function. + * NOTE: Snowflake returns integers of 17-19 digits. + *Example + * {{{ + * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + * df.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed").show() + * ------------------------ + * |"RANDN_WITH_SEED" | + * ------------------------ + * |5777523539921853504 | + * |-8190739547906189845 | + * |-1138438814981368515 | + * ------------------------ + * }}} + * + * @since 1.14.0 + * @param seed Seed to use in the random function. + * @return Random number. + */ + def randn(seed: Long): Column = + builtin("RANDOM")(seed) + /** * Invokes a built-in snowflake function with the specified name and arguments. * Arguments can be of two types diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 3db8fd02..574b5bd1 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.node.ArrayNode import com.snowflake.snowpark._ import com.snowflake.snowpark.functions.{repeat, _} import com.snowflake.snowpark.types._ +import com.snowflake.snowpark_java.types.LongType import net.snowflake.client.jdbc.SnowflakeSQLException import java.sql.{Date, Time, Timestamp} @@ -2280,6 +2281,42 @@ trait FunctionSuite extends TestData { checkAnswer(input.select(unbase64(col("a")).as("unbase64")), expected, sort = false) } + test("locate Column function") { + val input = session.createDataFrame(Seq + (("scala", "java scala python"), ("b", "abcd"))).toDF("a", "b") + val expected = Seq((6), (2)).toDF("locate") + checkAnswer(input.select(locate(col("a"), col("b"), 1).as("locate")), expected, sort = false) + } + + test("locate String function") { + + val input = session.createDataFrame(Seq("java scala python")).toDF("a") + val expected = Seq(6).toDF("locate") + checkAnswer(input.select(locate("scala", col("a")).as("locate")), expected, sort = false) + } + + test("ntile function") { + val input = Seq((5, 15), (5, 15), (5, 15), (5, 20)).toDF("grade", "score") + val window = Window.partitionBy(col("grade")).orderBy(col("score")) + val expected = Seq((1), (1), (2), (2)).toDF("ntile") + checkAnswer(input.select(ntile(2).over(window).as("ntile")), expected, sort = false) + } + + test("randn seed function") { + val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + val expected = Seq((5777523539921853504L), (-8190739547906189845L), (-1138438814981368515L)) + .toDF("randn_with_seed") + val df = input.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed") + + checkAnswer(df, expected, sort = false) + } + + test("randn function") { + val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + + assert(input.withColumn("randn", randn()).select("randn").first() != null) + } + } class EagerFunctionSuite extends FunctionSuite with EagerSession