Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-966358 Support TableFunction in DataFrame Select() #65

Merged
merged 5 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 63 additions & 32 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -561,18 +561,35 @@ class DataFrame private[snowpark] (
"Provide at least one column expression for select(). " +
s"This DataFrame has column names (${output.length}): " +
s"${output.map(_.name).mkString(", ")}\n")

val resultDF = withPlan { Project(columns.map(_.named), plan) }
// do not rename back if this project contains internal alias.
// because no named duplicated if just renamed.
val hasInternalAlias: Boolean = columns.map(_.expr).exists {
case Alias(_, _, true) => true
case _ => false
}
if (hasInternalAlias) {
resultDF
} else {
renameBackIfDeduped(resultDF)
// todo: error message
val tf = columns.filter(_.expr.isInstanceOf[TableFunctionExpression])
tf.size match {
case 0 => // no table function
val resultDF = withPlan {
Project(columns.map(_.named), plan)
}
// do not rename back if this project contains internal alias.
// because no named duplicated if just renamed.
val hasInternalAlias: Boolean = columns.map(_.expr).exists {
case Alias(_, _, true) => true
case _ => false
}
if (hasInternalAlias) {
resultDF
} else {
renameBackIfDeduped(resultDF)
}
case 1 => // 1 table function
val base = this.join(tf.head)
val baseColumns = base.schema.map(field => base(field.name))
val inputDFColumnSize = this.schema.size
val tfColumns = baseColumns.splitAt(inputDFColumnSize)._2
val (beforeTf, afterTf) = columns.span(_ != tf.head)
val resultColumns = beforeTf ++ tfColumns ++ afterTf.tail
base.select(resultColumns)
case _ =>
// more than 1 TF
throw ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT()
}
}

Expand Down Expand Up @@ -1788,9 +1805,8 @@ class DataFrame private[snowpark] (
* object or an object that you create from the [[TableFunction]] class.
* @param args A list of arguments to pass to the specified table function.
*/
def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, func.call(args: _*), None)
}
def join(func: TableFunction, args: Seq[Column]): DataFrame =
joinTableFunction(func.call(args: _*), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table
Expand Down Expand Up @@ -1822,12 +1838,10 @@ class DataFrame private[snowpark] (
func: TableFunction,
args: Seq[Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
orderBy: Seq[Column]): DataFrame =
joinTableFunction(
func.call(args: _*),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
* Joins the current DataFrame with the output of the specified table function `func` that takes
Expand Down Expand Up @@ -1859,9 +1873,8 @@ class DataFrame private[snowpark] (
* Some functions, like `flatten`, have named parameters.
* Use this map to specify the parameter names and their corresponding values.
*/
def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, func.call(args), None)
}
def join(func: TableFunction, args: Map[String, Column]): DataFrame =
joinTableFunction(func.call(args), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table function
Expand Down Expand Up @@ -1900,12 +1913,10 @@ class DataFrame private[snowpark] (
func: TableFunction,
args: Map[String, Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
orderBy: Seq[Column]): DataFrame =
joinTableFunction(
func.call(args),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
* Joins the current DataFrame with the output of the specified table function `func`.
Expand All @@ -1929,9 +1940,8 @@ class DataFrame private[snowpark] (
* @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]]
* object or an object that you create from the [[TableFunction.apply()]].
*/
def join(func: Column): DataFrame = withPlan {
TableFunctionJoin(this.plan, getTableFunctionExpression(func), None)
}
def join(func: Column): DataFrame =
joinTableFunction(getTableFunctionExpression(func), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table function
Expand All @@ -1951,11 +1961,32 @@ class DataFrame private[snowpark] (
* @param partitionBy A list of columns partitioned by.
* @param orderBy A list of columns ordered by.
*/
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame =
joinTableFunction(
getTableFunctionExpression(func),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))

/**
 * Shared implementation behind the public `join(TableFunction, ...)` overloads.
 *
 * Builds a `TableFunctionJoin` of this DataFrame's plan with the output of `func`,
 * optionally applying a PARTITION BY / ORDER BY clause via `partitionByOrderBy`.
 * If the joined result would contain duplicated column names (e.g. the table
 * function emits a column with the same name as an input column), the input
 * DataFrame's clashing columns are first renamed with a generated prefix so the
 * result schema has no duplicates.
 *
 * @param func The table function expression to join with.
 * @param partitionByOrderBy Optional window specification (partition/order clause)
 *                           applied to the table function call.
 * @return A new DataFrame representing the join result, with duplicate-free column names.
 */
private def joinTableFunction(
func: TableFunctionExpression,
partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
// First attempt: join without any renaming.
val originalResult = withPlan {
TableFunctionJoin(this.plan, func, partitionByOrderBy)
}
val resultSchema = originalResult.schema
val columnNames = resultSchema.map(_.name)
// Column names that appear more than once in the joined schema, quoted for comparison.
val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
// guarantee no duplicated names in the result
if (dup.nonEmpty) {
// Prefix the input DataFrame's clashing columns, then redo the join on the
// renamed plan. NOTE(review): aliasIfNeeded presumably renames only columns
// in the `dup` set using `dfPrefix` — confirm against its definition.
val dfPrefix = DataFrame.generatePrefix('o')
val renamedDf =
this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
withPlan {
TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
}
} else {
originalResult
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ private[snowpark] object ErrorMessage {
"0128" -> "DataFrameWriter doesn't support to set option '%s' as '%s' in '%s' mode when writing to a %s.",
"0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.",
"0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only",
"0131" -> "At most one table function can be called inside select() function",
// Begin to define UDF related messages
"0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d",
"0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.",
Expand Down Expand Up @@ -244,6 +245,9 @@ private[snowpark] object ErrorMessage {
// Error 0130: DataFrame.join was called with an argument that is neither
// another DataFrame nor a table function.
def DF_JOIN_WITH_WRONG_ARGUMENT(): SnowparkClientException =
createException("0130")

// Error 0131: DataFrame.select() was given more than one table function column;
// at most one table function call is supported per select().
def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException =
createException("0131")

/*
* 2NN: UDF error code
*/
Expand Down
6 changes: 3 additions & 3 deletions src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void basicTypes() {
"create or replace temp table "
+ tableName
+ "(i1 smallint, i2 int, l1 bigint, f1 float, d1 double, "
+ "decimal number(38, 18), b boolean, s string, bi binary)";
+ "de number(38, 18), b boolean, s string, bi binary)";
runQuery(crt);
String insert =
"insert into "
Expand All @@ -68,7 +68,7 @@ public void basicTypes() {
col("l1"),
col("f1"),
col("d1"),
col("decimal"),
col("de"),
col("b"),
col("s"),
col("bi"));
Expand All @@ -82,7 +82,7 @@ public void basicTypes() {
.append("|--L1: Long (nullable = true)")
.append("|--F1: Double (nullable = true)")
.append("|--D1: Double (nullable = true)")
.append("|--DECIMAL: Decimal(38, 18) (nullable = true)")
.append("|--DE: Decimal(38, 18) (nullable = true)")
.append("|--B: Boolean (nullable = true)")
.append("|--S: String (nullable = true)")
.append("|--BI: Binary (nullable = true)")
Expand Down
8 changes: 8 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,14 @@ class ErrorMessageSuite extends FunSuite {
" or TableFunctions only"))
}

// Verifies that error 0131 carries the expected telemetry message and the
// expected user-facing message prefix.
test("DF_MORE_THAN_ONE_TF_IN_SELECT") {
val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT()
// Telemetry message must match the registered template for code 0131.
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131")))
assert(
ex.message.startsWith("Error Code: 0131, Error message: " +
"At most one table function can be called inside select() function"))
}

test("UDF_INCORRECT_ARGS_NUMBER") {
val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2)
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,53 @@ class TableFunctionSuite extends TestData {
.select("value"),
Seq(Row("77"), Row("88")))
}

// Covers the new select() support for table functions: a table function may
// appear alone, before/after regular columns, or alongside derived expressions.
test("table function in select") {
val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "data")
// only tf: selecting just the table function yields only its output columns.
val result1 = df.select(tableFunctions.split_to_table(df("data"), ","))
assert(result1.schema.map(_.name) == Seq("SEQ", "INDEX", "VALUE"))
checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4")))

// columns + tf: tf output columns are expanded in place of the tf column.
val result2 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","))
assert(result2.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE"))
checkAnswer(
result2,
Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4")))

// columns + tf + columns: columns after the tf keep their position.
val result3 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","), df("idx"))
assert(result3.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE", "IDX"))
checkAnswer(
result3,
Seq(Row(1, 1, 1, "1", 1), Row(1, 1, 2, "2", 1), Row(2, 2, 1, "3", 2), Row(2, 2, 2, "4", 2)))

// tf + other expression: derived expressions also work next to a tf.
val result4 = df.select(tableFunctions.split_to_table(df("data"), ","), df("idx") + 100)
checkAnswer(
result4,
Seq(Row(1, 1, "1", 101), Row(1, 2, "2", 101), Row(2, 1, "3", 102), Row(2, 2, "4", 102)))
}

// When a table function's output column clashes with an input column name
// ("value" here), the join result must keep only one resolvable VALUE column;
// the input DataFrame's original column is still reachable via df("value").
test("table function join with duplicated column name") {
val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "value")
val result = df.join(tableFunctions.split_to_table(df("value"), lit(",")))
// only one VALUE in the result
checkAnswer(result.select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4")))
checkAnswer(result.select(result("value")), Seq(Row("1"), Row("2"), Row("3"), Row("4")))
// The pre-join DataFrame's column still resolves to the original data.
checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
}

// Same duplicate-name scenario as the join test, but driven through select():
// selecting an input column plus a tf whose output reuses that column's name.
test("table function select with duplicated column name") {
val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "value")
val result1 = df.select(tableFunctions.split_to_table(df("value"), lit(",")))
checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4")))
val result = df.select(df("value"), tableFunctions.split_to_table(df("value"), lit(",")))
// only one VALUE in the result
checkAnswer(result.select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4")))
checkAnswer(result.select(result("value")), Seq(Row("1"), Row("2"), Row("3"), Row("4")))
// The original DataFrame's column reference still yields the pre-split data.
checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
}

}
Loading