diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 1c2e12c7..977205a6 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -372,6 +372,17 @@ public static Column sum(Column col) { return new Column(com.snowflake.snowpark.functions.sum(col.toScalaColumn())); } + /** + * Returns the sum of non-NULL records in a group. You can use the DISTINCT keyword to compute the + * sum of unique non-null values. If all records inside a group are NULL, the function returns + * NULL. + * + * @since 0.9.0 + * @param str The input string + * @return The result column + */ + public static Column sum(String str) { return sum(col(str)); } + /** * Returns the sum of non-NULL distinct records in a group. You can use the DISTINCT keyword to * compute the sum of unique non-null values. If all records inside a group are NULL, the function diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 8361b78d..83f55400 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -357,6 +357,18 @@ object functions { */ def sum(e: Column): Column = builtin("sum")(e) + /** + * Returns the sum of non-NULL records in a group. You can use the DISTINCT keyword to compute + * the sum of unique non-null values. If all records inside a group are NULL, + * the function returns NULL. + * + * @group agg_func + * @since 0.1.0 + * @param e The input string + * @return The result column + */ + def sum(e: String): Column = sum(col(e)) + /** * Returns the sum of non-NULL distinct records in a group. You can use the DISTINCT keyword to * compute the sum of unique non-null values. If all records inside a group are NULL, diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 6e93454e..3de8cdab 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -195,6 +195,8 @@ public void sum() { checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum(df.col("a"))), expected, false); + checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum("a")), expected, false); + Row[] expected1 = {Row.create(3, 3), Row.create(2, 2), Row.create(1, 1)}; checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum_distinct(df.col("a"))), expected1, false); } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 274e88f5..4b6dfdca 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -99,6 +99,11 @@ trait FunctionSuite extends TestData { Seq(Row(3, 6), Row(2, 4), Row(1, 1)), sort = false) + checkAnswer( + duplicatedNumbers.groupBy("A").agg(sum("A")), + Seq(Row(3, 6), Row(2, 4), Row(1, 1)), + sort = false) + checkAnswer( duplicatedNumbers.groupBy("A").agg(sum_distinct(col("A"))), Seq(Row(3, 3), Row(2, 2), Row(1, 1)),