diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 7666d9e8..babe590c 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -561,18 +561,35 @@ class DataFrame private[snowpark] ( "Provide at least one column expression for select(). " + s"This DataFrame has column names (${output.length}): " + s"${output.map(_.name).mkString(", ")}\n") - - val resultDF = withPlan { Project(columns.map(_.named), plan) } - // do not rename back if this project contains internal alias. - // because no named duplicated if just renamed. - val hasInternalAlias: Boolean = columns.map(_.expr).exists { - case Alias(_, _, true) => true - case _ => false - } - if (hasInternalAlias) { - resultDF - } else { - renameBackIfDeduped(resultDF) + // todo: error message + val tf = columns.filter(_.expr.isInstanceOf[TableFunctionExpression]) + tf.size match { + case 0 => + val resultDF = withPlan { + Project(columns.map(_.named), plan) + } + // do not rename back if this project contains internal alias. + // because no named duplicated if just renamed. + val hasInternalAlias: Boolean = columns.map(_.expr).exists { + case Alias(_, _, true) => true + case _ => false + } + if (hasInternalAlias) { + resultDF + } else { + renameBackIfDeduped(resultDF) + } + case 1 => + val base = this.join(tf.head) + val baseColumns = base.schema.map(field => base(field.name)) + val inputDFColumnSize = this.schema.size + val tfColumns = baseColumns.splitAt(inputDFColumnSize)._2 + val (beforeTf, afterTf) = columns.span(_ != tf.head) + val resultColumns = beforeTf ++ tfColumns ++ afterTf.tail + base.select(resultColumns) + case _ => + // more than 1 TF + throw ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT() } } @@ -1788,9 +1805,8 @@ class DataFrame private[snowpark] ( * object or an object that you create from the [[TableFunction]] class. * @param args A list of arguments to pass to the specified table function. */ - def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func.call(args: _*), None) - } + def join(func: TableFunction, args: Seq[Column]): DataFrame = + joinTableFunction(func.call(args: _*), None) /** * Joins the current DataFrame with the output of the specified user-defined table @@ -1822,12 +1838,10 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Seq[Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + orderBy: Seq[Column]): DataFrame = + joinTableFunction( func.call(args: _*), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } /** * Joins the current DataFrame with the output of the specified table function `func` that takes @@ -1859,9 +1873,8 @@ class DataFrame private[snowpark] ( * Some functions, like `flatten`, have named parameters. * Use this map to specify the parameter names and their corresponding values. */ - def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func.call(args), None) - } + def join(func: TableFunction, args: Map[String, Column]): DataFrame = + joinTableFunction(func.call(args), None) /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1900,12 +1913,10 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Map[String, Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + orderBy: Seq[Column]): DataFrame = + joinTableFunction( func.call(args), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } /** * Joins the current DataFrame with the output of the specified table function `func`. @@ -1929,9 +1940,8 @@ class DataFrame private[snowpark] ( * @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]] * object or an object that you create from the [[TableFunction.apply()]]. */ - def join(func: Column): DataFrame = withPlan { - TableFunctionJoin(this.plan, getTableFunctionExpression(func), None) - } + def join(func: Column): DataFrame = + joinTableFunction(getTableFunctionExpression(func), None) /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1951,11 +1961,30 @@ class DataFrame private[snowpark] ( * @param partitionBy A list of columns partitioned by. * @param orderBy A list of columns ordered by. */ - def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = + joinTableFunction( getTableFunctionExpression(func), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) + + private def joinTableFunction( + func: TableFunctionExpression, + partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = { + val originalResult = withPlan { + TableFunctionJoin(this.plan, func, partitionByOrderBy) + } + val resultSchema = originalResult.schema + val columnNames = resultSchema.map(_.name) + // duplicated names + val dup = columnNames.diff(columnNames.distinct).distinct + // guarantee no duplicated names in the result + if (dup.nonEmpty) { + val dfPrefix = DataFrame.generatePrefix('o') + val renamedDf = + this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet))) + renamedDf.joinTableFunction(func, partitionByOrderBy) + } else { + originalResult + } } /** diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index 3dbd5b89..505b2b6d 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -68,6 +68,7 @@ private[snowpark] object ErrorMessage { "0128" -> "DataFrameWriter doesn't support to set option '%s' as '%s' in '%s' mode when writing to a %s.", "0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.", "0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only", + "0131" -> "At most one table function can be called inside select() function", // Begin to define UDF related messages "0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d", "0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.", @@ -244,6 +245,9 @@ private[snowpark] object ErrorMessage { def DF_JOIN_WITH_WRONG_ARGUMENT(): SnowparkClientException = createException("0130") + def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException = + createException("0131") + /* * 2NN: UDF error code */ diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index d5ede212..072c73e8 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -296,6 +296,14 @@ class ErrorMessageSuite extends FunSuite { " or TableFunctions only")) } + test("DF_MORE_THAN_ONE_TF_IN_SELECT") { + val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131"))) + assert( + ex.message.startsWith("Error Code: 0131, Error message: " + + "At most one table function can be called inside select() function")) + } + test("UDF_INCORRECT_ARGS_NUMBER") { val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200"))) diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 46f0028f..c87a8f2c 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -337,4 +337,32 @@ class TableFunctionSuite extends TestData { .select("value"), Seq(Row("77"), Row("88"))) } + + test("table function in select") { + val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "data") + // only tf + val result1 = df.select(tableFunctions.split_to_table(df("data"), ",")) + assert(result1.schema.map(_.name) == Seq("SEQ", "INDEX", "VALUE")) + checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4"))) + + // columns + tf + val result2 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ",")) + assert(result2.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE")) + checkAnswer( + result2, + Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4"))) + + // columns + tf + columns + val result3 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","), df("idx")) + assert(result3.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE", "IDX")) + checkAnswer( + result3, + Seq(Row(1, 1, 1, "1", 1), Row(1, 1, 2, "2", 1), Row(2, 2, 1, "3", 2), Row(2, 2, 2, "4", 2))) + + // tf + other express + val result4 = df.select(tableFunctions.split_to_table(df("data"), ","), df("idx") + 100) + checkAnswer( + result4, + Seq(Row(1, 1, "1", 101), Row(1, 2, "2", 101), Row(2, 1, "3", 102), Row(2, 2, "4", 102))) + } }