diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 74e0190a..416d8c1b 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -263,6 +263,34 @@ public static Column kurtosis(Column col) { return new Column(com.snowflake.snowpark.functions.kurtosis(col.toScalaColumn())); } + /** + * Returns the maximum value for the records in a group. NULL values are ignored unless all the + * records are NULL, in which case a NULL value is returned. + * + *
<p>Example: + * + *
{@code + * DataFrame df = session.createDataFrame( + * new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)}, + * StructType.create(new StructField("x", DataTypes.IntegerType)) + * ); + * df.select(max("x")).show(); + * + * ---------------- + * |"MAX(""X"")" | + * ---------------- + * |10 | + * ---------------- + * }+ * + * @param colName The name of the column + * @return The maximum value of the given column + * @since 1.13.0 + */ + public static Column max(String colName) { + return new Column(com.snowflake.snowpark.functions.max(colName)); + } + /** * Returns the maximum value for the records in a group. NULL values are ignored unless all the * records are NULL, in which case a NULL value is returned. @@ -275,6 +303,34 @@ public static Column max(Column col) { return new Column(com.snowflake.snowpark.functions.max(col.toScalaColumn())); } + /** + * Returns the minimum value for the records in a group. NULL values are ignored unless all the + * records are NULL, in which case a NULL value is returned. + * + *
<p>Example: + * + *
{@code + * DataFrame df = session.createDataFrame( + * new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)}, + * StructType.create(new StructField("x", DataTypes.IntegerType)) + * ); + * df.select(min("x")).show(); + * + * ---------------- + * |"MIN(""X"")" | + * ---------------- + * |1 | + * ---------------- + * }+ * + * @param colName The name of the column + * @return The minimum value of the given column + * @since 1.13.0 + */ + public static Column min(String colName) { + return new Column(com.snowflake.snowpark.functions.min(colName)); + } + /** * Returns the minimum value for the records in a group. NULL values are ignored unless all the * records are NULL, in which case a NULL value is returned. @@ -287,6 +343,34 @@ public static Column min(Column col) { return new Column(com.snowflake.snowpark.functions.min(col.toScalaColumn())); } + /** + * Returns the average of non-NULL records. If all records inside a group are NULL, the function + * returns NULL. Alias of avg. + * + *
<p>Example: + * + *
{@code + * DataFrame df = session.createDataFrame( + * new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)}, + * StructType.create(new StructField("x", DataTypes.IntegerType)) + * ); + * df.select(mean("x")).show(); + * + * ---------------- + * |"AVG(""X"")" | + * ---------------- + * |3.600000 | + * ---------------- + * }+ * + * @param colName The name of the column + * @return The average value of the given column + * @since 1.13.0 + */ + public static Column mean(String colName) { + return new Column(com.snowflake.snowpark.functions.mean(colName)); + } + /** * Returns the average of non-NULL records. If all records inside a group are NULL, the function * returns NULL. Alias of avg diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 445fdab0..ee2676da 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -265,6 +265,29 @@ object functions { */ def kurtosis(e: Column): Column = builtin("kurtosis")(e) + /** + * Returns the maximum value for the records in a group. NULL values are ignored unless all + * the records are NULL, in which case a NULL value is returned. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x") + * df.select(max("x")).show() + * + * ---------------- + * |"MAX(""X"")" | + * ---------------- + * |10 | + * ---------------- + * }}} + * + * @param colName The name of the column + * @return The maximum value of the given column + * @group agg_func + * @since 1.13.0 + */ + def max(colName: String): Column = max(col(colName)) + /** * Returns the maximum value for the records in a group. NULL values are ignored unless all * the records are NULL, in which case a NULL value is returned. 
@@ -282,6 +305,29 @@ object functions { */ def any_value(e: Column): Column = builtin("any_value")(e) + /** + * Returns the average of non-NULL records. If all records inside a group are NULL, + * the function returns NULL. Alias of avg. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x") + * df.select(mean("x")).show() + * + * ---------------- + * |"AVG(""X"")" | + * ---------------- + * |3.600000 | + * ---------------- + * }}} + * + * @param colName The name of the column + * @return The average value of the given column + * @group agg_func + * @since 1.13.0 + */ + def mean(colName: String): Column = mean(col(colName)) + /** * Returns the average of non-NULL records. If all records inside a group are NULL, * the function returns NULL. Alias of avg @@ -302,6 +348,29 @@ object functions { builtin("median")(e) } + /** + * Returns the minimum value for the records in a group. NULL values are ignored unless all + * the records are NULL, in which case a NULL value is returned. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x") + * df.select(min("x")).show() + * + * ---------------- + * |"MIN(""X"")" | + * ---------------- + * |1 | + * ---------------- + * }}} + * + * @param colName The name of the column + * @return The minimum value of the given column + * @group agg_func + * @since 1.13.0 + */ + def min(colName: String): Column = min(col(colName)) + /** * Returns the minimum value for the records in a group. NULL values are ignored unless all * the records are NULL, in which case a NULL value is returned. 
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 3de8cdab..42253362 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -143,16 +143,48 @@ public void kurtosis() { @Test public void max_min_mean() { - DataFrame df = + // Case 01: Non-null values + DataFrame df1 = getSession() .sql("select * from values(1,2,1),(1,2,3),(2,1,10),(2,2,1),(2,2,3) as T(x,y,z)"); - Row[] expected = {Row.create(2, 1, 3.600000)}; + Row[] expected1 = {Row.create(2, 1, 3.600000)}; checkAnswer( - df.select( - Functions.max(df.col("x")), Functions.min(df.col("y")), Functions.mean(df.col("z"))), - expected, + df1.select( + Functions.max(df1.col("x")), Functions.min(df1.col("y")), Functions.mean(df1.col("z"))), + expected1, + false); + checkAnswer( + df1.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected1, false); + + // Case 02: Some null values + DataFrame df2 = + getSession() + .sql("select * from values(1,5,8),(null,8,7),(3,null,9),(4,6,null) as T(x,y,z)"); + Row[] expected2 = {Row.create(4, 5, 8.000000)}; + + checkAnswer( + df2.select( + Functions.max(df2.col("x")), Functions.min(df2.col("y")), Functions.mean(df2.col("z"))), + expected2, false); + checkAnswer( + df2.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected2, false); + + // Case 03: All null values + DataFrame df3 = + getSession() + .sql( + "select * from values(null,null,null),(null,null,null),(null,null,null) as T(x,y,z)"); + Row[] expected3 = {Row.create(null, null, null)}; + + checkAnswer( + df3.select( + Functions.max(df3.col("x")), Functions.min(df3.col("y")), Functions.mean(df3.col("z"))), + expected3, + false); + checkAnswer( + df3.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected3, false); } @Test diff --git 
a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 4b6dfdca..38cad6c5 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -76,9 +76,20 @@ trait FunctionSuite extends TestData { } test("max, min, mean") { - checkAnswer( - xyz.select(max(col("X")), min(col("Y")), mean(col("Z"))), - Seq(Row(2, 1, 3.600000))) + // Case 01: Non-null values + val expected1 = Seq(Row(2, 1, 3.600000)) + checkAnswer(xyz.select(max(col("X")), min(col("Y")), mean(col("Z"))), expected1) + checkAnswer(xyz.select(max("X"), min("Y"), mean("Z")), expected1) + + // Case 02: Some null values + val expected2 = Seq(Row(3, 1, 2.000000)) + checkAnswer(nullInts.select(max(col("A")), min(col("A")), mean(col("A"))), expected2) + checkAnswer(nullInts.select(max("A"), min("A"), mean("A")), expected2) + + // Case 03: All null values + val expected3 = Seq(Row(null, null, null)) + checkAnswer(allNulls.select(max(col("A")), min(col("A")), mean(col("A"))), expected3) + checkAnswer(allNulls.select(max("A"), min("A"), mean("A")), expected3) } test("skew") {