Commit 09d50b5

add error code
tf in select

refactor dataframe join

rename df join
sfc-gh-bli committed Nov 20, 2023
1 parent 84e12fe commit 09d50b5
Showing 4 changed files with 101 additions and 32 deletions.
93 changes: 61 additions & 32 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -561,18 +561,35 @@ class DataFrame private[snowpark] (
"Provide at least one column expression for select(). " +
s"This DataFrame has column names (${output.length}): " +
s"${output.map(_.name).mkString(", ")}\n")

val resultDF = withPlan { Project(columns.map(_.named), plan) }
// Do not rename back if this projection contains an internal alias,
// because columns that were just renamed cannot have duplicated names.
val hasInternalAlias: Boolean = columns.map(_.expr).exists {
case Alias(_, _, true) => true
case _ => false
}
if (hasInternalAlias) {
resultDF
} else {
renameBackIfDeduped(resultDF)
// todo: error message
val tf = columns.filter(_.expr.isInstanceOf[TableFunctionExpression])
tf.size match {
case 0 =>
val resultDF = withPlan {
Project(columns.map(_.named), plan)
}
// Do not rename back if this projection contains an internal alias,
// because columns that were just renamed cannot have duplicated names.
val hasInternalAlias: Boolean = columns.map(_.expr).exists {
case Alias(_, _, true) => true
case _ => false
}
if (hasInternalAlias) {
resultDF
} else {
renameBackIfDeduped(resultDF)
}
case 1 =>
val base = this.join(tf.head)
val baseColumns = base.schema.map(field => base(field.name))
val inputDFColumnSize = this.schema.size
val tfColumns = baseColumns.splitAt(inputDFColumnSize)._2
val (beforeTf, afterTf) = columns.span(_ != tf.head)
val resultColumns = beforeTf ++ tfColumns ++ afterTf.tail
base.select(resultColumns)
case _ =>
// more than 1 TF
throw ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT()
}
}
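
In effect, the new case 1 branch rewrites a select containing exactly one table function into a join followed by a reordering select, so the function's output columns land where the function appeared in the select list. A minimal sketch of that equivalence, assuming an active Snowpark session with its implicits imported and split_to_table's usual output columns (SEQ, INDEX, VALUE):

    import session.implicits._

    val df = Seq((1, "a,b")).toDF("idx", "data")

    // What the caller writes:
    val viaSelect = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","))

    // Roughly what the branch builds: join first, then place the table function's
    // output columns where the function sat in the original select list.
    val base = df.join(tableFunctions.split_to_table(df("data"), ","))
    val viaJoin = base.select(df("idx"), base("SEQ"), base("INDEX"), base("VALUE"))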

@@ -1788,9 +1805,8 @@ class DataFrame private[snowpark] (
* object or an object that you create from the [[TableFunction]] class.
* @param args A list of arguments to pass to the specified table function.
*/
def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, func.call(args: _*), None)
}
def join(func: TableFunction, args: Seq[Column]): DataFrame =
joinTableFunction(func.call(args: _*), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table
@@ -1822,12 +1838,10 @@
func: TableFunction,
args: Seq[Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
orderBy: Seq[Column]): DataFrame =
joinTableFunction(
func.call(args: _*),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
* Joins the current DataFrame with the output of the specified table function `func` that takes
@@ -1859,9 +1873,8 @@
* Some functions, like `flatten`, have named parameters.
* Use this map to specify the parameter names and their corresponding values.
*/
def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, func.call(args), None)
}
def join(func: TableFunction, args: Map[String, Column]): DataFrame =
joinTableFunction(func.call(args), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table function
@@ -1900,12 +1913,10 @@
func: TableFunction,
args: Map[String, Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
orderBy: Seq[Column]): DataFrame =
joinTableFunction(
func.call(args),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
* Joins the current DataFrame with the output of the specified table function `func`.
@@ -1929,9 +1940,8 @@
* @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]]
* object or an object that you create from the [[TableFunction.apply()]].
*/
def join(func: Column): DataFrame = withPlan {
TableFunctionJoin(this.plan, getTableFunctionExpression(func), None)
}
def join(func: Column): DataFrame =
joinTableFunction(getTableFunctionExpression(func), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table function
Expand All @@ -1951,11 +1961,30 @@ class DataFrame private[snowpark] (
* @param partitionBy A list of columns partitioned by.
* @param orderBy A list of columns ordered by.
*/
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame =
joinTableFunction(
getTableFunctionExpression(func),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))

private def joinTableFunction(
func: TableFunctionExpression,
partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
val originalResult = withPlan {
TableFunctionJoin(this.plan, func, partitionByOrderBy)
}
val resultSchema = originalResult.schema
val columnNames = resultSchema.map(_.name)
// find duplicated column names
val dup = columnNames.diff(columnNames.distinct).distinct
// guarantee there are no duplicated names in the result
if (dup.nonEmpty) {
val dfPrefix = DataFrame.generatePrefix('o')
val renamedDf =
this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
renamedDf.joinTableFunction(func, partitionByOrderBy)
} else {
originalResult
}
}
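
The rename-and-retry step above is what keeps the joined schema unambiguous when the function's output names collide with the input's. A hedged example of the collision it guards against, under the same assumptions as the sketch above (split_to_table emits its own VALUE column):

    // The input already has a column named VALUE, which split_to_table also emits.
    val df2 = Seq(("x", "a,b")).toDF("value", "data")

    // Without the guard the result schema would contain VALUE twice; with it, the
    // input's columns are first re-aliased with an internal prefix (generatePrefix('o'))
    // and the join is retried, so every name in the result is unique.
    val joined = df2.join(tableFunctions.split_to_table(df2("data"), ","))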

/**
4 changes: 4 additions & 0 deletions src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala
@@ -68,6 +68,7 @@ private[snowpark] object ErrorMessage {
"0128" -> "DataFrameWriter doesn't support to set option '%s' as '%s' in '%s' mode when writing to a %s.",
"0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.",
"0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only",
"0131" -> "At most one table function can be called inside select() function",
// Begin to define UDF related messages
"0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d",
"0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.",
@@ -244,6 +245,9 @@ private[snowpark] object ErrorMessage {
def DF_JOIN_WITH_WRONG_ARGUMENT(): SnowparkClientException =
createException("0130")

def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException =
createException("0131")

/*
* 2NN: UDF error code
*/
8 changes: 8 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -296,6 +296,14 @@ class ErrorMessageSuite extends FunSuite {
" or TableFunctions only"))
}

test("DF_MORE_THAN_ONE_TF_IN_SELECT") {
val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT()
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131")))
assert(
ex.message.startsWith("Error Code: 0131, Error message: " +
"At most one table function can be called inside select() function"))
}

test("UDF_INCORRECT_ARGS_NUMBER") {
val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2)
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200")))
28 changes: 28 additions & 0 deletions src/test/scala/com/snowflake/snowpark/TableFunctionSuite.scala
@@ -337,4 +337,32 @@ class TableFunctionSuite extends TestData {
.select("value"),
Seq(Row("77"), Row("88")))
}

test("table function in select") {
val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "data")
// only tf
val result1 = df.select(tableFunctions.split_to_table(df("data"), ","))
assert(result1.schema.map(_.name) == Seq("SEQ", "INDEX", "VALUE"))
checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4")))

// columns + tf
val result2 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","))
assert(result2.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE"))
checkAnswer(
result2,
Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4")))

// columns + tf + columns
val result3 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","), df("idx"))
assert(result3.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE", "IDX"))
checkAnswer(
result3,
Seq(Row(1, 1, 1, "1", 1), Row(1, 1, 2, "2", 1), Row(2, 2, 1, "3", 2), Row(2, 2, 2, "4", 2)))

// tf + other expressions
val result4 = df.select(tableFunctions.split_to_table(df("data"), ","), df("idx") + 100)
checkAnswer(
result4,
Seq(Row(1, 1, "1", 101), Row(1, 2, "2", 101), Row(2, 1, "3", 102), Row(2, 2, "4", 102)))
}
}
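
One case the suite above does not exercise: passing more than one table function to select() now fails fast with the new 0131 error instead of building an invalid plan. A hedged sketch of that failure mode, reusing the df from the test:

    // More than one table function in a single select() is rejected.
    val ex = intercept[SnowparkClientException] {
      df.select(
        tableFunctions.split_to_table(df("data"), ","),
        tableFunctions.split_to_table(df("data"), ","))
    }
    assert(ex.message.contains("0131"))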
