Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SCT-10981 Add an overload method for functions.max, functions.min and functions.mean #94

Merged
merged 3 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,34 @@ public static Column kurtosis(Column col) {
return new Column(com.snowflake.snowpark.functions.kurtosis(col.toScalaColumn()));
}

/**
 * Aggregate function that yields the largest value among the records of a group, looking the
 * column up by name. NULL values are skipped; the result is NULL only when every record in the
 * group is NULL.
 *
 * <p>Example:
 *
 * <pre>{@code
 * DataFrame df = session.createDataFrame(
 *     new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)},
 *     StructType.create(new StructField("x", DataTypes.IntegerType))
 * );
 * df.select(max("x")).show();
 *
 * ----------------
 * |"MAX(""X"")"  |
 * ----------------
 * |10            |
 * ----------------
 * }</pre>
 *
 * @param colName The name of the column
 * @return A column holding the maximum value of the given column
 * @since 1.13.0
 */
public static Column max(String colName) {
  // Delegate to the Scala implementation, then wrap its result in the Java-facing Column.
  com.snowflake.snowpark.Column scalaResult = com.snowflake.snowpark.functions.max(colName);
  return new Column(scalaResult);
}

/**
* Returns the maximum value for the records in a group. NULL values are ignored unless all the
* records are NULL, in which case a NULL value is returned.
Expand All @@ -275,6 +303,34 @@ public static Column max(Column col) {
return new Column(com.snowflake.snowpark.functions.max(col.toScalaColumn()));
}

/**
 * Aggregate function that yields the smallest value among the records of a group, looking the
 * column up by name. NULL values are skipped; the result is NULL only when every record in the
 * group is NULL.
 *
 * <p>Example:
 *
 * <pre>{@code
 * DataFrame df = session.createDataFrame(
 *     new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)},
 *     StructType.create(new StructField("x", DataTypes.IntegerType))
 * );
 * df.select(min("x")).show();
 *
 * ----------------
 * |"MIN(""X"")"  |
 * ----------------
 * |1             |
 * ----------------
 * }</pre>
 *
 * @param colName The name of the column
 * @return A column holding the minimum value of the given column
 * @since 1.13.0
 */
public static Column min(String colName) {
  // Delegate to the Scala implementation, then wrap its result in the Java-facing Column.
  com.snowflake.snowpark.Column scalaResult = com.snowflake.snowpark.functions.min(colName);
  return new Column(scalaResult);
}

/**
* Returns the minimum value for the records in a group. NULL values are ignored unless all the
* records are NULL, in which case a NULL value is returned.
Expand All @@ -287,6 +343,34 @@ public static Column min(Column col) {
return new Column(com.snowflake.snowpark.functions.min(col.toScalaColumn()));
}

/**
 * Aggregate function that yields the average of the non-NULL records of a group, looking the
 * column up by name. When every record in the group is NULL the result is NULL. Alias of avg.
 *
 * <p>Example:
 *
 * <pre>{@code
 * DataFrame df = session.createDataFrame(
 *     new Row[] {Row.create(1), Row.create(3), Row.create(10), Row.create(1), Row.create(3)},
 *     StructType.create(new StructField("x", DataTypes.IntegerType))
 * );
 * df.select(mean("x")).show();
 *
 * ----------------
 * |"AVG(""X"")"  |
 * ----------------
 * |3.600000      |
 * ----------------
 * }</pre>
 *
 * @param colName The name of the column
 * @return A column holding the average value of the given column
 * @since 1.13.0
 */
public static Column mean(String colName) {
  // Delegate to the Scala implementation, then wrap its result in the Java-facing Column.
  com.snowflake.snowpark.Column scalaResult = com.snowflake.snowpark.functions.mean(colName);
  return new Column(scalaResult);
}

/**
* Returns the average of non-NULL records. If all records inside a group are NULL, the function
* returns NULL. Alias of avg.
Expand Down
69 changes: 69 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,29 @@ object functions {
*/
def kurtosis(e: Column): Column = builtin("kurtosis")(e)

/**
 * Aggregate function that yields the largest value among the records of a group, looking
 * the column up by name. NULL values are skipped; the result is NULL only when every record
 * in the group is NULL.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x")
 *   df.select(max("x")).show()
 *
 *   ----------------
 *   |"MAX(""X"")"  |
 *   ----------------
 *   |10            |
 *   ----------------
 * }}}
 *
 * @param colName The name of the column
 * @return A column holding the maximum value of the given column
 * @group agg_func
 * @since 1.13.0
 */
def max(colName: String): Column = {
  // Resolve the name to a Column first, then reuse the Column-based overload.
  val target: Column = col(colName)
  max(target)
}

/**
* Returns the maximum value for the records in a group. NULL values are ignored unless all
* the records are NULL, in which case a NULL value is returned.
Expand All @@ -282,6 +305,29 @@ object functions {
*/
def any_value(e: Column): Column = builtin("any_value")(e)

/**
 * Aggregate function that yields the average of the non-NULL records of a group, looking
 * the column up by name. When every record in the group is NULL the result is NULL.
 * Alias of avg.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x")
 *   df.select(mean("x")).show()
 *
 *   ----------------
 *   |"AVG(""X"")"  |
 *   ----------------
 *   |3.600000      |
 *   ----------------
 * }}}
 *
 * @param colName The name of the column
 * @return A column holding the average value of the given column
 * @group agg_func
 * @since 1.13.0
 */
def mean(colName: String): Column = {
  // Resolve the name to a Column first, then reuse the Column-based overload.
  val target: Column = col(colName)
  mean(target)
}

/**
* Returns the average of non-NULL records. If all records inside a group are NULL,
* the function returns NULL. Alias of avg.
Expand All @@ -302,6 +348,29 @@ object functions {
builtin("median")(e)
}

/**
 * Aggregate function that yields the smallest value among the records of a group, looking
 * the column up by name. NULL values are skipped; the result is NULL only when every record
 * in the group is NULL.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(1, 3, 10, 1, 3)).toDF("x")
 *   df.select(min("x")).show()
 *
 *   ----------------
 *   |"MIN(""X"")"  |
 *   ----------------
 *   |1             |
 *   ----------------
 * }}}
 *
 * @param colName The name of the column
 * @return A column holding the minimum value of the given column
 * @group agg_func
 * @since 1.13.0
 */
def min(colName: String): Column = {
  // Resolve the name to a Column first, then reuse the Column-based overload.
  val target: Column = col(colName)
  min(target)
}

/**
* Returns the minimum value for the records in a group. NULL values are ignored unless all
* the records are NULL, in which case a NULL value is returned.
Expand Down
42 changes: 37 additions & 5 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,48 @@ public void kurtosis() {

@Test
public void max_min_mean() {
DataFrame df =
// Case 01: Non-null values
DataFrame df1 =
getSession()
.sql("select * from values(1,2,1),(1,2,3),(2,1,10),(2,2,1),(2,2,3) as T(x,y,z)");
Row[] expected = {Row.create(2, 1, 3.600000)};
Row[] expected1 = {Row.create(2, 1, 3.600000)};

checkAnswer(
df.select(
Functions.max(df.col("x")), Functions.min(df.col("y")), Functions.mean(df.col("z"))),
expected,
df1.select(
Functions.max(df1.col("x")), Functions.min(df1.col("y")), Functions.mean(df1.col("z"))),
expected1,
false);
checkAnswer(
df1.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected1, false);

// Case 02: Some null values
DataFrame df2 =
getSession()
.sql("select * from values(1,5,8),(null,8,7),(3,null,9),(4,6,null) as T(x,y,z)");
Row[] expected2 = {Row.create(4, 5, 8.000000)};

checkAnswer(
df2.select(
Functions.max(df2.col("x")), Functions.min(df2.col("y")), Functions.mean(df2.col("z"))),
expected2,
false);
checkAnswer(
df2.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected2, false);

// Case 03: All null values
DataFrame df3 =
getSession()
.sql(
"select * from values(null,null,null),(null,null,null),(null,null,null) as T(x,y,z)");
Row[] expected3 = {Row.create(null, null, null)};

checkAnswer(
df3.select(
Functions.max(df3.col("x")), Functions.min(df3.col("y")), Functions.mean(df3.col("z"))),
expected3,
false);
checkAnswer(
df3.select(Functions.max("x"), Functions.min("y"), Functions.mean("z")), expected3, false);
}

@Test
Expand Down
17 changes: 14 additions & 3 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,20 @@ trait FunctionSuite extends TestData {
}

// Exercises max, min and mean through both overloads (Column and column name); each case
// runs the same aggregation twice and expects identical results from the two overloads.
// NOTE(review): xyz, nullInts and allNulls are fixtures defined outside this view
// (presumably in TestData) -- the expected rows below depend on their contents.
test("max, min, mean") {
  // Case 01: Non-null values
  val expected1 = Seq(Row(2, 1, 3.600000))
  checkAnswer(xyz.select(max(col("X")), min(col("Y")), mean(col("Z"))), expected1)
  checkAnswer(xyz.select(max("X"), min("Y"), mean("Z")), expected1)

  // Case 02: Some null values -- NULLs must be ignored by all three aggregates
  val expected2 = Seq(Row(3, 1, 2.000000))
  checkAnswer(nullInts.select(max(col("A")), min(col("A")), mean(col("A"))), expected2)
  checkAnswer(nullInts.select(max("A"), min("A"), mean("A")), expected2)

  // Case 03: All null values -- every aggregate is expected to return NULL
  val expected3 = Seq(Row(null, null, null))
  checkAnswer(allNulls.select(max(col("A")), min(col("A")), mean(col("A"))), expected3)
  checkAnswer(allNulls.select(max("A"), min("A"), mean("A")), expected3)
}

test("skew") {
Expand Down
Loading