Skip to content

Commit

Permalink
tf in select
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bli committed Nov 15, 2023
1 parent 6bc8d62 commit 276d10a
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
41 changes: 28 additions & 13 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -562,19 +562,34 @@ class DataFrame private[snowpark] (
s"This DataFrame has column names (${output.length}): " +
s"${output.map(_.name).mkString(", ")}\n")
// todo: error message
require(columns.count(_.expr.isInstanceOf[TableFunctionExpression]) <= 1, "error")

val resultDF = withPlan { Project(columns.map(_.named), plan) }
// do not rename back if this project contains internal alias.
// because no named duplicated if just renamed.
val hasInternalAlias: Boolean = columns.map(_.expr).exists {
case Alias(_, _, true) => true
case _ => false
}
if (hasInternalAlias) {
resultDF
} else {
renameBackIfDeduped(resultDF)
val tf = columns.filter(_.expr.isInstanceOf[TableFunctionExpression])
tf.size match {
case 0 =>
val resultDF = withPlan {
Project(columns.map(_.named), plan)
}
// do not rename back if this project contains internal alias.
// because no named duplicated if just renamed.
val hasInternalAlias: Boolean = columns.map(_.expr).exists {
case Alias(_, _, true) => true
case _ => false
}
if (hasInternalAlias) {
resultDF
} else {
renameBackIfDeduped(resultDF)
}
case 1 =>
val base = this.join(tf.head)
val baseColumns = base.schema.map(field => base(field.name))
val inputDFColumnSize = this.schema.size
val tfColumns = baseColumns.splitAt(inputDFColumnSize)._2
val (beforeTf, afterTf) = columns.span(_ != tf.head)
val resultColumns = beforeTf ++ tfColumns ++ afterTf.tail
base.select(resultColumns)
case _ =>
// more than 1 TF
throw ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT()
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,8 @@ class ErrorMessageSuite extends FunSuite {
val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT()
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131")))
assert(
ex.message.startsWith(
"Error Code: 0131, Error message: " +
"At most one table function can be called inside select() function"))
ex.message.startsWith("Error Code: 0131, Error message: " +
"At most one table function can be called inside select() function"))
}

test("UDF_INCORRECT_ARGS_NUMBER") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,32 @@ class TableFunctionSuite extends TestData {
.select("value"),
Seq(Row("77"), Row("88")))
}

test("table function in select") {
  // Two input rows; each "data" value is a comma-separated pair.
  val input = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "data")

  // Builds a fresh split_to_table call over the "data" column each time it is
  // referenced, mirroring a hand-written call at every select site.
  def splitData = tableFunctions.split_to_table(input("data"), ",")

  // Case 1: the table function is the only item in the select list.
  val tfOnly = input.select(splitData)
  val tfOnlyRows = Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4"))
  assert(tfOnly.schema.map(_.name) == Seq("SEQ", "INDEX", "VALUE"))
  checkAnswer(tfOnly, tfOnlyRows)

  // Case 2: a regular column precedes the table function.
  val columnThenTf = input.select(input("idx"), splitData)
  val columnThenTfRows =
    Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4"))
  assert(columnThenTf.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE"))
  checkAnswer(columnThenTf, columnThenTfRows)

  // Case 3: the table function sits between two regular columns.
  val tfInMiddle = input.select(input("idx"), splitData, input("idx"))
  val tfInMiddleRows = Seq(
    Row(1, 1, 1, "1", 1),
    Row(1, 1, 2, "2", 1),
    Row(2, 2, 1, "3", 2),
    Row(2, 2, 2, "4", 2))
  assert(tfInMiddle.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE", "IDX"))
  checkAnswer(tfInMiddle, tfInMiddleRows)

  // Case 4: the table function is followed by a computed expression.
  // Only the row contents are checked here, not the column names.
  val tfThenExpression = input.select(splitData, input("idx") + 100)
  val tfThenExpressionRows = Seq(
    Row(1, 1, "1", 101),
    Row(1, 2, "2", 101),
    Row(2, 1, "3", 102),
    Row(2, 2, "4", 102))
  checkAnswer(tfThenExpression, tfThenExpressionRows)
}
}

0 comments on commit 276d10a

Please sign in to comment.