Skip to content

Commit

Permalink
add log functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-gmahadevan committed Aug 14, 2024
1 parent 14770b7 commit eb1b2ce
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
56 changes: 56 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3312,6 +3312,62 @@ object functions {
def last(c: Column): Column =
builtin("LAST_VALUE")(c)

/**
* Computes the logarithm of the given value in base 10.
* Example
* {{{
* val df = session.createDataFrame(Seq(100)).toDF("a")
* df.select(log10(col("a"))).show()
*
* -----------
* |"LOG10" |
* -----------
* |2.0 |
* -----------
* }}}
*
* @since 1.14.0
* @param c Column to apply logarithm operation
* @return log10 of the given column
*/
def log10(c: Column): Column = builtin("LOG")(10, c)

/**
* Computes the logarithm of the given column in base 10.
* Example
* {{{
* val df = session.createDataFrame(Seq(100)).toDF("a")
* df.select(log10("a"))).show()
* -----------
* |"LOG10" |
* -----------
* |2.0 |
* -----------
*
* }}}
*
* @since 1.14.0
* @param columnName Column to apply logarithm operation
* @return log10 of the given column
*/
def log10(columnName: String): Column = builtin("LOG")(10, col(columnName))

/**
* Computes the natural logarithm of the given value plus one.
* @since 1.14.0
* @param c the value to use
* @return the natural logarithm of the given value plus one.
*/
def log1p(c: Column): Column = callBuiltin("ln", lit(1) + c)

/**
* Computes the natural logarithm of the given value plus one.
* @since 1.14.0
* @param columnName the value to use
* @return the natural logarithm of the given value plus one.
*/
def log1p(columnName: String): Column = callBuiltin("ln", lit(1) + col(columnName))

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
43 changes: 42 additions & 1 deletion src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2233,7 +2233,6 @@ trait FunctionSuite extends TestData {
}

test("last function") {

val input =
Seq((5, "a", 10), (5, "b", 20), (3, "d", 15), (3, "e", 40)).toDF("grade", "name", "score")
val window = Window.partitionBy(col("grade")).orderBy(col("score").desc)
Expand All @@ -2245,6 +2244,48 @@ trait FunctionSuite extends TestData {
sort = false)
}

test("log10 Column function") {
val input = session.createDataFrame(Seq(100)).toDF("a")
val expected = Seq(2.0).toDF("log10")

checkAnswer(
input.select(log10(col("a")).as("log10")),
expected,
sort = false)
}

test("log10 String function") {
val input = session.createDataFrame(Seq("100")).toDF("a")
val expected = Seq(2.0).toDF("log10")

checkAnswer(
input.select(log10("a").as("log10")),
expected,
sort = false)
}

test("log1p Column function") {
val input = session.createDataFrame(Seq(100)).toDF("a")

input.select(log1p(col("a"))).show()
// val expected = Seq(2.0).toDF("log10")
//
// checkAnswer(
// input.select(log10(col("a")).as("log10")),
// expected,
// sort = false)
}
//
// test("log1p String function") {
// val input = session.createDataFrame(Seq("100")).toDF("a")
// val expected = Seq(2.0).toDF("log10")
//
// checkAnswer(
// input.select(log10("a").as("log10")),
// expected,
// sort = false)
// }

}

class EagerFunctionSuite extends FunctionSuite with EagerSession
Expand Down

0 comments on commit eb1b2ce

Please sign in to comment.