Skip to content

Commit

Permalink
SCT-10981 Add an overload method for functions.max, functions.min and…
Browse files Browse the repository at this point in the history
… functions.mean (#94)
  • Loading branch information
sfc-gh-fgonzalezmendez authored Apr 10, 2024
1 parent 29da4bc commit c01d321
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 8 deletions.
84 changes: 84 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,34 @@ public static Column kurtosis(Column col) {
return new Column(com.snowflake.snowpark.functions.kurtosis(col.toScalaColumn()));
}

/**
* Returns the maximum value for the records in a group. NULL values are ignored unless all the
* records are NULL, in which case a NULL value is returned.
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.createDataFrame(
* new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)},
* StructType.create(new StructField("x", DataTypes.IntegerType))
* );
* df.select(max("x")).show();
*
* ----------------
* |"MAX(""X"")" |
* ----------------
* |10 |
* ----------------
* }</pre>
*
* @param colName The name of the column
* @return The maximum value of the given column
* @since 1.13.0
*/
public static Column max(String colName) {
return new Column(com.snowflake.snowpark.functions.max(colName));
}

/**
* Returns the maximum value for the records in a group. NULL values are ignored unless all the
* records are NULL, in which case a NULL value is returned.
Expand All @@ -275,6 +303,34 @@ public static Column max(Column col) {
return new Column(com.snowflake.snowpark.functions.max(col.toScalaColumn()));
}

/**
* Returns the minimum value for the records in a group. NULL values are ignored unless all the
* records are NULL, in which case a NULL value is returned.
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.createDataFrame(
* new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)},
* StructType.create(new StructField("x", DataTypes.IntegerType))
* );
* df.select(min("x")).show();
*
* ----------------
* |"MIN(""X"")" |
* ----------------
* |1 |
* ----------------
* }</pre>
*
* @param colName The name of the column
* @return The minimum value of the given column
* @since 1.13.0
*/
public static Column min(String colName) {
return new Column(com.snowflake.snowpark.functions.min(colName));
}

/**
* Returns the minimum value for the records in a group. NULL values are ignored unless all the
* records are NULL, in which case a NULL value is returned.
Expand All @@ -287,6 +343,34 @@ public static Column min(Column col) {
return new Column(com.snowflake.snowpark.functions.min(col.toScalaColumn()));
}

/**
* Returns the average of non-NULL records. If all records inside a group are NULL, the function
* returns NULL. Alias of avg.
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.createDataFrame(
* new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)},
* StructType.create(new StructField("x", DataTypes.IntegerType))
* );
* df.select(mean("x")).show();
*
* ----------------
* |"AVG(""X"")" |
* ----------------
* |3.600000 |
* ----------------
* }</pre>
*
* @param colName The name of the column
* @return The average value of the given column
* @since 1.13.0
*/
public static Column mean(String colName) {
return new Column(com.snowflake.snowpark.functions.mean(colName));
}

/**
* Returns the average of non-NULL records. If all records inside a group are NULL, the function
* returns NULL. Alias of avg
Expand Down
69 changes: 69 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,29 @@ object functions {
*/
def kurtosis(e: Column): Column = builtin("kurtosis")(e)

/**
* Returns the maximum value for the records in a group. NULL values are ignored unless all
* the records are NULL, in which case a NULL value is returned.
*
* Example:
* {{{
* val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x")
* df.select(max("x")).show()
*
* ----------------
* |"MAX(""X"")" |
* ----------------
* |10 |
* ----------------
* }}}
*
* @param colName The name of the column
* @return The maximum value of the given column
* @group agg_func
* @since 1.13.0
*/
def max(colName: String): Column = max(col(colName))

/**
* Returns the maximum value for the records in a group. NULL values are ignored unless all
* the records are NULL, in which case a NULL value is returned.
Expand All @@ -282,6 +305,29 @@ object functions {
*/
def any_value(e: Column): Column = builtin("any_value")(e)

/**
* Returns the average of non-NULL records. If all records inside a group are NULL,
* the function returns NULL. Alias of avg.
*
* Example:
* {{{
* val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x")
* df.select(mean("x")).show()
*
* ----------------
* |"AVG(""X"")" |
* ----------------
* |3.600000 |
* ----------------
* }}}
*
* @param colName The name of the column
* @return The average value of the given column
* @group agg_func
* @since 1.13.0
*/
def mean(colName: String): Column = mean(col(colName))

/**
* Returns the average of non-NULL records. If all records inside a group are NULL,
* the function returns NULL. Alias of avg
Expand All @@ -302,6 +348,29 @@ object functions {
builtin("median")(e)
}

/**
* Returns the minimum value for the records in a group. NULL values are ignored unless all
* the records are NULL, in which case a NULL value is returned.
*
* Example:
* {{{
* val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x")
* df.select(min("x")).show()
*
* ----------------
* |"MIN(""X"")" |
* ----------------
* |1 |
* ----------------
* }}}
*
* @param colName The name of the column
* @return The minimum value of the given column
* @group agg_func
* @since 1.13.0
*/
def min(colName: String): Column = min(col(colName))

/**
* Returns the minimum value for the records in a group. NULL values are ignored unless all
* the records are NULL, in which case a NULL value is returned.
Expand Down
42 changes: 37 additions & 5 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,48 @@ public void kurtosis() {

@Test
public void max_min_mean() {
DataFrame df =
// Case 01: Non-null values
DataFrame df1 =
getSession()
.sql("select * from values(1,2,1),(1,2,3),(2,1,10),(2,2,1),(2,2,3) as T(x,y,z)");
Row[] expected = {Row.create(2, 1, 3.600000)};
Row[] expected1 = {Row.create(2, 1, 3.600000)};

checkAnswer(
df.select(
Functions.max(df.col("x")), Functions.min(df.col("y")), Functions.mean(df.col("z"))),
expected,
df1.select(
Functions.max(df1.col("x")), Functions.min(df1.col("y")), Functions.mean(df1.col("z"))),
expected1,
false);
checkAnswer(
df1.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected1, false);

// Case 02: Some null values
DataFrame df2 =
getSession()
.sql("select * from values(1,5,8),(null,8,7),(3,null,9),(4,6,null) as T(x,y,z)");
Row[] expected2 = {Row.create(4, 5, 8.000000)};

checkAnswer(
df2.select(
Functions.max(df2.col("x")), Functions.min(df2.col("y")), Functions.mean(df2.col("z"))),
expected2,
false);
checkAnswer(
df2.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected2, false);

// Case 03: All null values
DataFrame df3 =
getSession()
.sql(
"select * from values(null,null,null),(null,null,null),(null,null,null) as T(x,y,z)");
Row[] expected3 = {Row.create(null, null, null)};

checkAnswer(
df3.select(
Functions.max(df3.col("x")), Functions.min(df3.col("y")), Functions.mean(df3.col("z"))),
expected3,
false);
checkAnswer(
df3.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected3, false);
}

@Test
Expand Down
17 changes: 14 additions & 3 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,20 @@ trait FunctionSuite extends TestData {
}

test("max, min, mean") {
checkAnswer(
xyz.select(max(col("X")), min(col("Y")), mean(col("Z"))),
Seq(Row(2, 1, 3.600000)))
// Case 01: Non-null values
val expected1 = Seq(Row(2, 1, 3.600000))
checkAnswer(xyz.select(max(col("X")), min(col("Y")), mean(col("Z"))), expected1)
checkAnswer(xyz.select(max("X"), min("Y"), mean("Z")), expected1)

// Case 02: Some null values
val expected2 = Seq(Row(3, 1, 2.000000))
checkAnswer(nullInts.select(max(col("A")), min(col("A")), mean(col("A"))), expected2)
checkAnswer(nullInts.select(max("A"), min("A"), mean("A")), expected2)

// Case 03: All null values
val expected3 = Seq(Row(null, null, null))
checkAnswer(allNulls.select(max(col("A")), min(col("A")), mean(col("A"))), expected3)
checkAnswer(allNulls.select(max("A"), min("A"), mean("A")), expected3)
}

test("skew") {
Expand Down

0 comments on commit c01d321

Please sign in to comment.