From dd37e31ab720e435d48c5ae6632b22e5952e95f6 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Thu, 9 Nov 2023 14:08:03 -0800 Subject: [PATCH 01/23] convert lazy val in tableFunctions to def --- src/main/scala/com/snowflake/snowpark/tableFunctions.scala | 4 ++-- .../com/snowflake/snowpark_test/TableFunctionSuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index ae9dc32f..9abd80b4 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -55,7 +55,7 @@ object tableFunctions { * * @since 0.4.0 */ - lazy val split_to_table: TableFunction = TableFunction("split_to_table") + def split_to_table(): TableFunction = TableFunction("split_to_table") /** * Flattens (explodes) compound values into multiple rows. @@ -105,5 +105,5 @@ object tableFunctions { * * @since 0.4.0 */ - lazy val flatten: TableFunction = TableFunction("flatten") + def flatten(): TableFunction = TableFunction("flatten") } diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 9f72e7a1..617d6477 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -182,4 +182,11 @@ class TableFunctionSuite extends TestData { |---------------------------------------- |""".stripMargin) } + + test("Argument in table function") { + val df = Seq((1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")), + (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map") + + df.join(tableFunctions.flatten, df("arr")).show() + } } From 6200dbacb3e73159f4bbe79e1114a6598662c9ad Mon Sep 17 00:00:00 2001 From: Bing Li Date: Thu, 9 Nov 2023 14:42:40 -0800 Subject: [PATCH 02/23] add join exception --- src/main/scala/com/snowflake/snowpark/DataFrame.scala | 7 +++++++ .../com/snowflake/snowpark/internal/ErrorMessage.scala | 4 ++++ .../scala/com/snowflake/snowpark/tableFunctions.scala | 4 ++++ .../scala/com/snowflake/snowpark/ErrorMessageSuite.scala | 9 +++++++++ 4 files changed, 24 insertions(+) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 56307cb6..b6fd21ba 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1903,6 +1903,13 @@ class DataFrame private[snowpark] ( Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) } + def join(func: Column): DataFrame = withPlan { + func.expr match { + case tf: TableFunction => null + case _ => null + } + } + /** * Performs a cross join, which returns the cartesian product of the current DataFrame and * another DataFrame (`right`). diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index 7b8ade4e..cc32b72f 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -67,6 +67,7 @@ private[snowpark] object ErrorMessage { "0127" -> "DataFrameWriter doesn't support to set option '%s' as '%s' when writing to a %s.", "0128" -> "DataFrameWriter doesn't support to set option '%s' as '%s' in '%s' mode when writing to a %s.", "0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.", + "0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only", // Begin to define UDF related messages "0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d", "0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.", @@ -239,6 +240,9 @@ private[snowpark] object ErrorMessage { def DF_WRITER_INVALID_MODE(mode: String, target: String): SnowparkClientException = createException("0129", mode, target) + def DF_JOIN_WITH_WRONG_ARGUMENT(): SnowparkClientException = + createException("0130") + /* * 2NN: UDF error code */ diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index 9abd80b4..0e59c76a 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -106,4 +106,8 @@ object tableFunctions { * @since 0.4.0 */ def flatten(): TableFunction = TableFunction("flatten") + + def flatten(input: Column): Column = Column(flatten().apply(input)) + +// def flatten(input: Column, path: String, outer: Boolean, recursive: Boolean): Column = null } diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index e89e6382..31f59aed 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -286,6 +286,15 @@ class ErrorMessageSuite extends FunSuite { "DataFrameWriter doesn't support mode 'Append' when writing to a file.")) } + test("DF_JOIN_WITH_WRONG_ARGUMENT") { + val ex = ErrorMessage.DF_JOIN_WITH_WRONG_ARGUMENT() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0130"))) + assert( + ex.message.startsWith("Error Code: 0130, Error message: " + + "Unsupported join operations, Dataframes can join with other Dataframes" + + " or TableFunctions only")) + } + test("UDF_INCORRECT_ARGS_NUMBER") { val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200"))) From af76e69a0553e273565c0502c4c08e7cc80d93f3 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Thu, 9 Nov 2023 15:06:54 -0800 Subject: [PATCH 03/23] scala flatten1 --- src/main/scala/com/snowflake/snowpark/DataFrame.scala | 9 +++++++-- .../scala/com/snowflake/snowpark/tableFunctions.scala | 6 +++--- .../com/snowflake/snowpark_test/TableFunctionSuite.scala | 7 +++++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index b6fd21ba..9ae544c4 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1905,8 +1905,13 @@ class DataFrame private[snowpark] ( def join(func: Column): DataFrame = withPlan { func.expr match { - case tf: TableFunction => null - case _ => null + case tf: TableFunctionExpression => + TableFunctionJoin( + this.plan, + tf, + None + ) + case _ => throw ErrorMessage.DF_JOIN_WITH_WRONG_ARGUMENT() } } diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index 0e59c76a..8fff2a7e 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -55,7 +55,7 @@ object tableFunctions { * * @since 0.4.0 */ - def split_to_table(): TableFunction = TableFunction("split_to_table") + lazy val split_to_table: TableFunction = TableFunction("split_to_table") /** * Flattens (explodes) compound values into multiple rows. @@ -105,9 +105,9 @@ object tableFunctions { * * @since 0.4.0 */ - def flatten(): TableFunction = TableFunction("flatten") + lazy val flatten: TableFunction = TableFunction("flatten") - def flatten(input: Column): Column = Column(flatten().apply(input)) + def flatten(input: Column): Column = Column(flatten.apply(input)) // def flatten(input: Column, path: String, outer: Boolean, recursive: Boolean): Column = null } diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 617d6477..8d3530ee 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -3,6 +3,7 @@ package com.snowflake.snowpark_test import com.snowflake.snowpark.functions._ import com.snowflake.snowpark._ + class TableFunctionSuite extends TestData { import session.implicits._ @@ -186,7 +187,9 @@ class TableFunctionSuite extends TestData { test("Argument in table function") { val df = Seq((1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")), (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map") - - df.join(tableFunctions.flatten, df("arr")).show() + checkAnswer( + df.join(tableFunctions.flatten(df("arr"))) + .select("value"), + Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33"))) } } From 908fd04b368deb2783a204d30e3acaab6fad1dd0 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 10:11:44 -0800 Subject: [PATCH 04/23] flatten --- .../com/snowflake/snowpark/DataFrame.scala | 6 +--- .../snowflake/snowpark/tableFunctions.scala | 11 ++++++- .../snowpark/ErrorMessageSuite.scala | 7 +++-- .../snowpark_test/TableFunctionSuite.scala | 31 +++++++++++++++++-- 4 files changed, 43 insertions(+), 12 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 9ae544c4..367e015f 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1906,11 +1906,7 @@ class DataFrame private[snowpark] ( def join(func: Column): DataFrame = withPlan { func.expr match { case tf: TableFunctionExpression => - TableFunctionJoin( - this.plan, - tf, - None - ) + TableFunctionJoin(this.plan, tf, None) case _ => throw ErrorMessage.DF_JOIN_WITH_WRONG_ARGUMENT() } } diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index 8fff2a7e..d67d06db 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -1,5 +1,7 @@ package com.snowflake.snowpark +import com.snowflake.snowpark.functions.lit + // scalastyle:off /** * Provides utility functions that generate table function expressions that can be @@ -109,5 +111,12 @@ object tableFunctions { def flatten(input: Column): Column = Column(flatten.apply(input)) -// def flatten(input: Column, path: String, outer: Boolean, recursive: Boolean): Column = null + def flatten(input: Column, path: String, outer: Boolean, recursive: Boolean): Column = + Column( + flatten.apply( + Map( + "input" -> input, + "path" -> lit(path), + "outer" -> lit(outer), + "recursive" -> lit(recursive)))) } diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index 31f59aed..0fa30d8b 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -290,9 +290,10 @@ class ErrorMessageSuite extends FunSuite { val ex = ErrorMessage.DF_JOIN_WITH_WRONG_ARGUMENT() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0130"))) assert( - ex.message.startsWith("Error Code: 0130, Error message: " + - "Unsupported join operations, Dataframes can join with other Dataframes" + - " or TableFunctions only")) + ex.message.startsWith( + "Error Code: 0130, Error message: " + + "Unsupported join operations, Dataframes can join with other Dataframes" + + " or TableFunctions only")) } test("UDF_INCORRECT_ARGS_NUMBER") { diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 8d3530ee..84e85810 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -1,8 +1,9 @@ package com.snowflake.snowpark_test import com.snowflake.snowpark.functions._ -import com.snowflake.snowpark._ +import com.snowflake.snowpark.{Row, _} +import scala.collection.Seq class TableFunctionSuite extends TestData { import session.implicits._ @@ -184,12 +185,36 @@ class TableFunctionSuite extends TestData { |""".stripMargin) } - test("Argument in table function") { - val df = Seq((1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")), + test("Argument in table function: flatten") { + val df = Seq( + (1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")), (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map") checkAnswer( df.join(tableFunctions.flatten(df("arr"))) .select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33"))) + // error if it is not a table function + val error1 = intercept[SnowparkClientException] { df.join(lit("dummy")) } + assert( + error1.message.contains("Unsupported join operations, Dataframes can join " + + "with other Dataframes or TableFunctions only")) + } + + test("Argument in table function: flatten2") { + val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col") + checkAnswer( + df1.join(tableFunctions.flatten(input = parse_json(df1("col")), + path = "b", outer = true, recursive = true)).select("value"), + Seq(Row("77"), Row("88"))) + + val df2 = Seq("[]").toDF("col") + checkAnswer(df2.join(tableFunctions.flatten(input = parse_json(df1("col")), + path = "", outer = true, recursive = true)).select("value"), + Seq(Row(null))) + + assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), + path = "", outer = true, recursive = true)).count() == 4) + assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), + path = "", outer = true, recursive = false)).count() == 2) } } From 0773ec16a57a14c3dd4b48370ec483d4aa9adf18 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 11:19:06 -0800 Subject: [PATCH 05/23] change join --- src/main/scala/com/snowflake/snowpark/DataFrame.scala | 8 ++------ .../scala/com/snowflake/snowpark/internal/Utils.scala | 10 +++++++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 367e015f..f8f22fc5 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -7,7 +7,7 @@ import com.snowflake.snowpark.internal.{Logging, Utils} import com.snowflake.snowpark.internal.analyzer._ import com.snowflake.snowpark.types._ import com.github.vertical_blank.sqlformatter.SqlFormatter -import com.snowflake.snowpark.internal.Utils.{TempObjectType, randomNameForTempObject} +import com.snowflake.snowpark.internal.Utils.{TempObjectType, getTableFunctionExpression, randomNameForTempObject} import javax.xml.bind.DatatypeConverter import scala.collection.JavaConverters._ @@ -1904,11 +1904,7 @@ class DataFrame private[snowpark] ( } def join(func: Column): DataFrame = withPlan { - func.expr match { - case tf: TableFunctionExpression => - TableFunctionJoin(this.plan, tf, None) - case _ => throw ErrorMessage.DF_JOIN_WITH_WRONG_ARGUMENT() - } + TableFunctionJoin(this.plan, getTableFunctionExpression(func), None) } /** diff --git a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala index 26973709..76f06f73 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala @@ -1,6 +1,7 @@ package com.snowflake.snowpark.internal -import com.snowflake.snowpark.internal.analyzer.{Attribute, singleQuote} +import com.snowflake.snowpark.Column +import com.snowflake.snowpark.internal.analyzer.{Attribute, TableFunctionExpression, singleQuote} import java.io.{File, FileInputStream} import java.lang.invoke.SerializedLambda @@ -419,4 +420,11 @@ object Utils extends Logging { .map(newName => Attribute(newName, att.dataType, att.nullable, att.exprId)) .getOrElse(att)) } + + private[snowpark] def getTableFunctionExpression(col: Column): TableFunctionExpression = { + col.expr match { + case tf: TableFunctionExpression => tf + case _ => throw ErrorMessage.DF_JOIN_WITH_WRONG_ARGUMENT() + } + } } From 4c0519db334531557c597620f6e302c616c46c3e Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 11:20:50 -0800 Subject: [PATCH 06/23] join orderby --- src/main/scala/com/snowflake/snowpark/DataFrame.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index f8f22fc5..ee33762e 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1907,6 +1907,12 @@ class DataFrame private[snowpark] ( TableFunctionJoin(this.plan, getTableFunctionExpression(func), None) } + // todo: add test with UDTF + def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan { + TableFunctionJoin(this.plan, getTableFunctionExpression(func), + Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) + } + /** * Performs a cross join, which returns the cartesian product of the current DataFrame and * another DataFrame (`right`). From dee1584fa8f007520e61a96db95172e43c050beb Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 11:47:46 -0800 Subject: [PATCH 07/23] session table function --- .../com/snowflake/snowpark/Session.scala | 24 ++++++++++++------- .../snowpark/internal/ErrorMessage.scala | 6 ++++- .../snowflake/snowpark/tableFunctions.scala | 6 +++-- .../snowpark/ErrorMessageSuite.scala | 9 +++++++ .../snowpark_test/TableFunctionSuite.scala | 12 ++++++---- 5 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index 41cc6fa2..c5a75cbb 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -7,18 +7,12 @@ import java.util.{Properties, Map => JMap, Set => JSet} import java.util.concurrent.{ConcurrentHashMap, ForkJoinPool, ForkJoinWorkerThread} import com.snowflake.snowpark.internal.analyzer._ import com.snowflake.snowpark.internal._ +import com.snowflake.snowpark.internal.analyzer.{TableFunction => TFunction} import com.snowflake.snowpark.types._ import com.snowflake.snowpark.functions._ -import com.snowflake.snowpark.internal.ErrorMessage.{ - UDF_CANNOT_ACCEPT_MANY_DF_COLS, - UDF_UNEXPECTED_COLUMN_ORDER -} +import com.snowflake.snowpark.internal.ErrorMessage.{UDF_CANNOT_ACCEPT_MANY_DF_COLS, UDF_UNEXPECTED_COLUMN_ORDER} import com.snowflake.snowpark.internal.ParameterUtils.ClosureCleanerMode -import com.snowflake.snowpark.internal.Utils.{ - TempObjectNamePattern, - TempObjectType, - randomNameForTempObject -} +import com.snowflake.snowpark.internal.Utils.{TempObjectNamePattern, TempObjectType, getTableFunctionExpression, randomNameForTempObject} import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, SnowflakeSQLException} import scala.concurrent.{ExecutionContext, Future} @@ -578,6 +572,18 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log } } + def tableFunction(func: Column): DataFrame = { + func.expr match { + case TFunction(funcName, args) => + tableFunction(TableFunction(funcName), args.map(Column(_))) + case NamedArgumentsTableFunction(funcName, argMap) => + tableFunction(TableFunction(funcName), argMap.map { + case (key, value) => key -> Column(value) + }) + case _ => throw ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT() + } + } + private def createFromStoredProc(spName: String, args: Seq[Any]): DataFrame = DataFrame(this, StoredProcedureRelation(spName, args.map(functions.lit).map(_.expr))) diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index cc32b72f..3dbd5b89 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -151,7 +151,8 @@ private[snowpark] object ErrorMessage { "0419" -> "%s exceeds the maximum allowed time: %d second(s).", "0420" -> "Invalid RSA private key. The error is: %s", "0421" -> "Invalid stage location: %s. Reason: %s.", - "0422" -> "Internal error: Server fetching is disabled for the parameter %s and there is no default value for it.") + "0422" -> "Internal error: Server fetching is disabled for the parameter %s and there is no default value for it.", + "0423" -> "Invalid input argument, Session.tableFunction only supports table function arguments") // scalastyle:on /* @@ -385,6 +386,9 @@ private[snowpark] object ErrorMessage { parameterName: String): SnowparkClientException = createException("0422", parameterName) + def MISC_INVALID_TABLE_FUNCTION_INPUT(): SnowparkClientException = + createException("0423") + /** * Create Snowpark client Exception. * diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index d67d06db..31ebbff5 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -111,12 +111,14 @@ object tableFunctions { def flatten(input: Column): Column = Column(flatten.apply(input)) - def flatten(input: Column, path: String, outer: Boolean, recursive: Boolean): Column = + def flatten(input: Column, + path: String, outer: Boolean, recursive: Boolean, mode: String): Column = Column( flatten.apply( Map( "input" -> input, "path" -> lit(path), "outer" -> lit(outer), - "recursive" -> lit(recursive)))) + "recursive" -> lit(recursive), + "mode" -> lit(mode)))) } diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index 0fa30d8b..30124185 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -814,4 +814,13 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0422, Error message: Internal error: Server fetching is disabled" + " for the parameter someParameter and there is no default value for it.")) } + + test("MISC_INVALID_TABLE_FUNCTION_INPUT") { + val ex = ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0423"))) + assert( + ex.message.startsWith( + "Error Code: 0423, Error message: Invalid input argument, " + + "Session.tableFunction only supports table function arguments")) + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 84e85810..6cdb85e2 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -204,17 +204,21 @@ class TableFunctionSuite extends TestData { val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col") checkAnswer( df1.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "b", outer = true, recursive = true)).select("value"), + path = "b", outer = true, recursive = true, mode = "both")).select("value"), Seq(Row("77"), Row("88"))) val df2 = Seq("[]").toDF("col") checkAnswer(df2.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "", outer = true, recursive = true)).select("value"), + path = "", outer = true, recursive = true, mode = "both")).select("value"), Seq(Row(null))) assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "", outer = true, recursive = true)).count() == 4) + path = "", outer = true, recursive = true, mode = "both")).count() == 4) assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "", outer = true, recursive = false)).count() == 2) + path = "", outer = true, recursive = false, mode = "both")).count() == 2) + assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), + path = "", outer = true, recursive = true, mode = "array")).count() == 1) + assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), + path = "", outer = true, recursive = true, mode = "object")).count() == 2) } } From 11f0266087c37a3a590468a48269d054b4f9e038 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 11:52:28 -0800 Subject: [PATCH 08/23] session table function --- .../snowpark_test/TableFunctionSuite.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 6cdb85e2..2de216ec 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -4,6 +4,7 @@ import com.snowflake.snowpark.functions._ import com.snowflake.snowpark.{Row, _} import scala.collection.Seq +import scala.collection.immutable.Map class TableFunctionSuite extends TestData { import session.implicits._ @@ -221,4 +222,19 @@ class TableFunctionSuite extends TestData { assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), path = "", outer = true, recursive = true, mode = "object")).count() == 2) } + + test("Argument in table function: flatten - session") { + val df = Seq( + (1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")), + (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map") + checkAnswer(session.tableFunction(tableFunctions.flatten(df("arr"))).select("value"), + Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33"))) + // error if it is not a table function + val error1 = intercept[SnowparkClientException] { + session.tableFunction(lit("dummy")) + } + assert( + error1.message.contains("Invalid input argument, " + + "Session.tableFunction only supports table function arguments")) + } } From d925e1bb4e1413e1425f8463984e5e7a90e10198 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 11:55:21 -0800 Subject: [PATCH 09/23] add test --- .../com/snowflake/snowpark/DataFrame.scala | 10 ++- .../com/snowflake/snowpark/Session.scala | 12 ++- .../snowflake/snowpark/tableFunctions.scala | 8 +- .../snowpark/ErrorMessageSuite.scala | 5 +- .../snowpark_test/TableFunctionSuite.scala | 89 ++++++++++++++++--- 5 files changed, 102 insertions(+), 22 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index ee33762e..49ee2bfd 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -7,7 +7,11 @@ import com.snowflake.snowpark.internal.{Logging, Utils} import com.snowflake.snowpark.internal.analyzer._ import com.snowflake.snowpark.types._ import com.github.vertical_blank.sqlformatter.SqlFormatter -import com.snowflake.snowpark.internal.Utils.{TempObjectType, getTableFunctionExpression, randomNameForTempObject} +import com.snowflake.snowpark.internal.Utils.{ + TempObjectType, + getTableFunctionExpression, + randomNameForTempObject +} import javax.xml.bind.DatatypeConverter import scala.collection.JavaConverters._ @@ -1909,7 +1913,9 @@ class DataFrame private[snowpark] ( // todo: add test with UDTF def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, getTableFunctionExpression(func), + TableFunctionJoin( + this.plan, + getTableFunctionExpression(func), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) } diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index c5a75cbb..dc0489b9 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -10,9 +10,17 @@ import com.snowflake.snowpark.internal._ import com.snowflake.snowpark.internal.analyzer.{TableFunction => TFunction} import com.snowflake.snowpark.types._ import com.snowflake.snowpark.functions._ -import com.snowflake.snowpark.internal.ErrorMessage.{UDF_CANNOT_ACCEPT_MANY_DF_COLS, UDF_UNEXPECTED_COLUMN_ORDER} +import com.snowflake.snowpark.internal.ErrorMessage.{ + UDF_CANNOT_ACCEPT_MANY_DF_COLS, + UDF_UNEXPECTED_COLUMN_ORDER +} import com.snowflake.snowpark.internal.ParameterUtils.ClosureCleanerMode -import com.snowflake.snowpark.internal.Utils.{TempObjectNamePattern, TempObjectType, getTableFunctionExpression, randomNameForTempObject} +import com.snowflake.snowpark.internal.Utils.{ + TempObjectNamePattern, + TempObjectType, + getTableFunctionExpression, + randomNameForTempObject +} import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, SnowflakeSQLException} import scala.concurrent.{ExecutionContext, Future} diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index 31ebbff5..5f765e64 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -111,8 +111,12 @@ object tableFunctions { def flatten(input: Column): Column = Column(flatten.apply(input)) - def flatten(input: Column, - path: String, outer: Boolean, recursive: Boolean, mode: String): Column = + def flatten( + input: Column, + path: String, + outer: Boolean, + recursive: Boolean, + mode: String): Column = Column( flatten.apply( Map( diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index 30124185..d5ede212 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -819,8 +819,7 @@ class ErrorMessageSuite extends FunSuite { val ex = ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0423"))) assert( - ex.message.startsWith( - "Error Code: 0423, Error message: Invalid input argument, " + - "Session.tableFunction only supports table function arguments")) + ex.message.startsWith("Error Code: 0423, Error message: Invalid input argument, " + + "Session.tableFunction only supports table function arguments")) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 2de216ec..a28e4913 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -204,30 +204,78 @@ class TableFunctionSuite extends TestData { test("Argument in table function: flatten2") { val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col") checkAnswer( - df1.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "b", outer = true, recursive = true, mode = "both")).select("value"), + df1 + .join( + tableFunctions.flatten( + input = parse_json(df1("col")), + path = "b", + outer = true, + recursive = true, + mode = "both")) + .select("value"), Seq(Row("77"), Row("88"))) val df2 = Seq("[]").toDF("col") - checkAnswer(df2.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "", outer = true, recursive = true, mode = "both")).select("value"), + checkAnswer( + df2 + .join( + tableFunctions.flatten( + input = parse_json(df1("col")), + path = "", + outer = true, + recursive = true, + mode = "both")) + .select("value"), Seq(Row(null))) - assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "", outer = true, recursive = true, mode = "both")).count() == 4) - assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "", outer = true, recursive = false, mode = "both")).count() == 2) - assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "", outer = true, recursive = true, mode = "array")).count() == 1) - assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")), - path = "", outer = true, recursive = true, mode = "object")).count() == 2) + assert( + df1 + .join( + tableFunctions.flatten( + input = parse_json(df1("col")), + path = "", + outer = true, + recursive = true, + mode = "both")) + .count() == 4) + assert( + df1 + .join( + tableFunctions.flatten( + input = parse_json(df1("col")), + path = "", + outer = true, + recursive = false, + mode = "both")) + .count() == 2) + assert( + df1 + .join( + tableFunctions.flatten( + input = parse_json(df1("col")), + path = "", + outer = true, + recursive = true, + mode = "array")) + .count() == 1) + assert( + df1 + .join( + tableFunctions.flatten( + input = parse_json(df1("col")), + path = "", + outer = true, + recursive = true, + mode = "object")) + .count() == 2) } test("Argument in table function: flatten - session") { val df = Seq( (1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")), (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map") - checkAnswer(session.tableFunction(tableFunctions.flatten(df("arr"))).select("value"), + checkAnswer( + session.tableFunction(tableFunctions.flatten(df("arr"))).select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33"))) // error if it is not a table function val error1 = intercept[SnowparkClientException] { @@ -237,4 +285,19 @@ class TableFunctionSuite extends TestData { error1.message.contains("Invalid input argument, " + "Session.tableFunction only supports table function arguments")) } + + test("Argument in table function: flatten - session 2") { + val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col") + checkAnswer( + session + .tableFunction( + tableFunctions.flatten( + input = parse_json(df1("col")), + path = "b", + outer = true, + recursive = true, + mode = "both")) + .select("value"), + Seq(Row("77"), Row("88"))) + } } From 5d0b4e8f99793e7eac796fd6d25bcfc4316d577f Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 13:27:17 -0800 Subject: [PATCH 10/23] split_to_table --- .../com/snowflake/snowpark/tableFunctions.scala | 3 +++ .../snowpark_test/TableFunctionSuite.scala | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index 5f765e64..11975fa2 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -59,6 +59,9 @@ object tableFunctions { */ lazy val split_to_table: TableFunction = TableFunction("split_to_table") + def split_to_table(str: Column, delimiter: String): Column = + Column(split_to_table.apply(str, lit(delimiter))) + /** * Flattens (explodes) compound values into multiple rows. * diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index a28e4913..b4533310 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -300,4 +300,18 @@ class TableFunctionSuite extends TestData { .select("value"), Seq(Row("77"), Row("88"))) } + + test("Argument in table function: split_to_table") { + val df = Seq("1,2", "3,4").toDF("data") + + checkAnswer( + df.join(tableFunctions.split_to_table(df("data"), ",")).select("value"), + Seq(Row("1"), Row("2"), Row("3"), Row("4"))) + + checkAnswer( + session + .tableFunction(tableFunctions.split_to_table(df("data"), ",")) + .select("value"), + Seq(Row("1"), Row("2"), Row("3"), Row("4"))) + } } From 5e4b5e33929cf07864e1ae61f446bc07ce987421 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 14:02:33 -0800 Subject: [PATCH 11/23] support UDTF --- .../com/snowflake/snowpark/DataFrame.scala | 8 ++++---- .../com/snowflake/snowpark/Session.scala | 4 ++-- .../snowflake/snowpark/TableFunction.scala | 8 ++++++-- .../snowflake/snowpark/tableFunctions.scala | 19 +++++++++---------- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 49ee2bfd..875d5db3 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1789,7 +1789,7 @@ class DataFrame private[snowpark] ( * @param args A list of arguments to pass to the specified table function. */ def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func(args: _*), None) + TableFunctionJoin(this.plan, func.call(args: _*), None) } /** @@ -1825,7 +1825,7 @@ class DataFrame private[snowpark] ( orderBy: Seq[Column]): DataFrame = withPlan { TableFunctionJoin( this.plan, - func(args: _*), + func.call(args: _*), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) } @@ -1860,7 +1860,7 @@ class DataFrame private[snowpark] ( * Use this map to specify the parameter names and their corresponding values. */ def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func(args), None) + TableFunctionJoin(this.plan, func.call(args), None) } /** @@ -1903,7 +1903,7 @@ class DataFrame private[snowpark] ( orderBy: Seq[Column]): DataFrame = withPlan { TableFunctionJoin( this.plan, - func(args), + func.call(args), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) } diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index dc0489b9..f0537db6 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -529,7 +529,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log // Use df.join to apply function result if args contains a DF column val sourceDFs = args.flatMap(_.expr.sourceDFs) if (sourceDFs.isEmpty) { - DataFrame(this, TableFunctionRelation(func(args: _*))) + DataFrame(this, TableFunctionRelation(func.call(args: _*))) } else if (sourceDFs.toSet.size > 1) { throw UDF_CANNOT_ACCEPT_MANY_DF_COLS() } else { @@ -570,7 +570,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log // Use df.join to apply function result if args contains a DF column val sourceDFs = args.values.flatMap(_.expr.sourceDFs) if (sourceDFs.isEmpty) { - DataFrame(this, TableFunctionRelation(func(args))) + DataFrame(this, TableFunctionRelation(func.call(args))) } else if (sourceDFs.toSet.size > 1) { throw UDF_CANNOT_ACCEPT_MANY_DF_COLS() } else { diff --git a/src/main/scala/com/snowflake/snowpark/TableFunction.scala b/src/main/scala/com/snowflake/snowpark/TableFunction.scala index ac251411..7c37751e 100644 --- a/src/main/scala/com/snowflake/snowpark/TableFunction.scala +++ b/src/main/scala/com/snowflake/snowpark/TableFunction.scala @@ -32,11 +32,15 @@ import com.snowflake.snowpark.internal.analyzer.{ * @since 0.4.0 */ case class TableFunction(funcName: String) { - private[snowpark] def apply(args: Column*): TableFunctionExpression = + private[snowpark] def call(args: Column*): TableFunctionExpression = analyzer.TableFunction(funcName, args.map(_.expr)) - private[snowpark] def apply(args: Map[String, Column]): TableFunctionExpression = + private[snowpark] def call(args: Map[String, Column]): TableFunctionExpression = NamedArgumentsTableFunction(funcName, args.map { case (key, value) => key -> value.expr }) + + def apply(args: Column*): Column = Column(this.call(args: _*)) + + def apply(args: Map[String, Column]): Column = Column(this.call(args)) } diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index 11975fa2..a6b290b5 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -60,7 +60,7 @@ object tableFunctions { lazy val split_to_table: TableFunction = TableFunction("split_to_table") def split_to_table(str: Column, delimiter: String): Column = - Column(split_to_table.apply(str, lit(delimiter))) + split_to_table.apply(str, lit(delimiter)) /** * Flattens (explodes) compound values into multiple rows. @@ -112,7 +112,7 @@ object tableFunctions { */ lazy val flatten: TableFunction = TableFunction("flatten") - def flatten(input: Column): Column = Column(flatten.apply(input)) + def flatten(input: Column): Column = flatten.apply(input) def flatten( input: Column, @@ -120,12 +120,11 @@ object tableFunctions { outer: Boolean, recursive: Boolean, mode: String): Column = - Column( - flatten.apply( - Map( - "input" -> input, - "path" -> lit(path), - "outer" -> lit(outer), - "recursive" -> lit(recursive), - "mode" -> lit(mode)))) + flatten.apply( + Map( + "input" -> input, + "path" -> lit(path), + "outer" -> lit(outer), + "recursive" -> lit(recursive), + "mode" -> lit(mode))) } From 513d1bbe92998895912c3042aec094951abbcd5d Mon Sep 17 00:00:00 2001 From: Bing Li Date: Mon, 13 Nov 2023 14:11:18 -0800 Subject: [PATCH 12/23] test table function --- .../snowpark_test/TableFunctionSuite.scala | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index b4533310..70a3561d 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -314,4 +314,27 @@ class TableFunctionSuite extends TestData { .select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4"))) } + + test("Argument in table function: table function") { + val df = Seq("1,2", "3,4").toDF("data") + + checkAnswer( + df.join(TableFunction("split_to_table")(df("data"), lit(","))) + .select("value"), + Seq(Row("1"), Row("2"), Row("3"), Row("4"))) + + val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col") + checkAnswer( + session + .tableFunction( + TableFunction("flatten")(Map( + "input" -> parse_json(df1("col")), + "path" -> lit("b"), + "outer" -> lit(true), + "recursive" -> lit(true), + "mode" -> lit("both") + ))) + .select("value"), + Seq(Row("77"), Row("88"))) + } } From 4e1e6cb0ef0fd6528808c3e8ed8843796023c41a Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 14 Nov 2023 11:06:17 -0800 Subject: [PATCH 13/23] Java TableFunction --- .../com/snowflake/snowpark_java/Session.java | 4 +++ .../snowpark_java/TableFunction.java | 16 ++++++++++++ .../snowpark_test/JavaTableFunctionSuite.java | 25 +++++++++++++++++++ .../snowpark_test/TableFunctionSuite.scala | 14 +++++------ 4 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java b/src/main/java/com/snowflake/snowpark_java/Session.java index 886804d0..6e879d8e 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -566,6 +566,10 @@ public DataFrame tableFunction(TableFunction func, Map args) { func.getScalaTableFunction(), JavaUtils.javaStringColumnMapToScala(scalaArgs))); } + public DataFrame tableFunction(Column func) { + return new DataFrame(session.tableFunction(func.toScalaColumn())); + } + /** * Returns a SProcRegistration object that you can use to register Stored Procedures. * diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunction.java b/src/main/java/com/snowflake/snowpark_java/TableFunction.java index 62c75cd0..a6321bdd 100644 --- a/src/main/java/com/snowflake/snowpark_java/TableFunction.java +++ b/src/main/java/com/snowflake/snowpark_java/TableFunction.java @@ -1,5 +1,9 @@ package com.snowflake.snowpark_java; +import com.snowflake.snowpark.internal.JavaUtils; +import java.util.HashMap; +import java.util.Map; + /** * Looks up table functions by funcName and returns tableFunction object which can be used in {@code * DataFrame.join} and {@code Session.tableFunction} methods. @@ -38,4 +42,16 @@ com.snowflake.snowpark.TableFunction getScalaTableFunction() { public String funcName() { return func.funcName(); } + + public Column call(Column... args) { + return new Column(this.func.apply(JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(args)))); + } + + public Column call(Map args) { + Map scalaArgs = new HashMap<>(); + for (Map.Entry entry : args.entrySet()) { + scalaArgs.put(entry.getKey(), entry.getValue().toScalaColumn()); + } + return new Column(this.func.apply(JavaUtils.javaStringColumnMapToScala(scalaArgs))); + } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java index af8c8aaa..225f0d5b 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java @@ -63,4 +63,29 @@ public void tableFunctionName() { TableFunction tableFunction = new TableFunction("flatten"); assert tableFunction.funcName().equals("flatten"); } + + @Test + public void argumentInTableFunction() { + checkAnswer( + getSession() + .tableFunction( + new TableFunction("split_to_table") + .call(Functions.lit("split by space"), Functions.lit(" "))), + new Row[] {Row.create(1, 1, "split"), Row.create(1, 2, "by"), Row.create(1, 3, "space")}, + true); + DataFrame df = + getSession() + .createDataFrame( + new Row[] {Row.create("{\"a\":1, \"b\":[77, 88]}")}, + StructType.create(new StructField("col", DataTypes.StringType))); + Map args = new HashMap<>(); + args.put("input", Functions.parse_json(df.col("col"))); + args.put("path", Functions.lit("b")); + args.put("outer", Functions.lit(true)); + args.put("recursive", Functions.lit(true)); + args.put("mode", Functions.lit("both")); + checkAnswer( + getSession().tableFunction(new TableFunction("flatten").call(args)).select("value"), + new Row[] {Row.create("77"), Row.create("88")}); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 70a3561d..46f0028f 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -327,13 +327,13 @@ class TableFunctionSuite extends TestData { checkAnswer( session .tableFunction( - TableFunction("flatten")(Map( - "input" -> parse_json(df1("col")), - "path" -> lit("b"), - "outer" -> lit(true), - "recursive" -> lit(true), - "mode" -> lit("both") - ))) + TableFunction("flatten")( + Map( + "input" -> parse_json(df1("col")), + "path" -> lit("b"), + "outer" -> lit(true), + "recursive" -> lit(true), + "mode" -> lit("both")))) .select("value"), Seq(Row("77"), Row("88"))) } From 2554b38059c501dcadc1fdcfdc012bbc3d7d7789 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 14 Nov 2023 11:14:52 -0800 Subject: [PATCH 14/23] udf test --- .../com/snowflake/snowpark_java/TableFunctions.java | 4 ++++ src/main/scala/com/snowflake/snowpark/DataFrame.scala | 1 - .../scala/com/snowflake/snowpark_test/UDTFSuite.scala | 11 +++++++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java index ddf344df..267dc38d 100644 --- a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java +++ b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java @@ -40,6 +40,10 @@ public static TableFunction split_to_table() { return new TableFunction(com.snowflake.snowpark.tableFunctions.split_to_table()); } + public static Column split_to_table(Column str, String delimiter) { + return new Column(com.snowflake.snowpark.tableFunctions.split_to_table(str.toScalaColumn(), delimiter)); + } + /** * Flattens (explodes) compound values into multiple rows. * diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 875d5db3..56b7247d 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1911,7 +1911,6 @@ class DataFrame private[snowpark] ( TableFunctionJoin(this.plan, getTableFunctionExpression(func), None) } - // todo: add test with UDTF def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan { TableFunctionJoin( this.plan, diff --git a/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala index de6464be..d6845bde 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala @@ -1,19 +1,18 @@ package com.snowflake.snowpark_test import java.math.RoundingMode - import com.snowflake.snowpark.TestUtils._ import com.snowflake.snowpark.functions._ + import java.nio.file._ import java.sql.{Date, Time, Timestamp} import java.util.TimeZone - -import com.snowflake.snowpark._ +import com.snowflake.snowpark.{Row, _} import com.snowflake.snowpark.internal._ import com.snowflake.snowpark.types._ import com.snowflake.snowpark.udtf._ -import scala.collection.mutable +import scala.collection.{Seq, mutable} @UDFTest class UDTFSuite extends TestData { @@ -2130,6 +2129,10 @@ class UDTFSuite extends TestData { df.join(tf, Map("arg1" -> df("b")), Seq(df("a")), Seq(df("b"))), Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)"))) + checkAnswer( + df.join(tf(Map("arg1" -> df("b"))), Seq(df("a")), Seq(df("b"))), + Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)"))) + checkAnswer( df.join(tf, Map("arg1" -> df("b")), Seq(df("a")), Seq.empty), Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)"))) From 02e497c0c839894bd23db79dbe99343177517c42 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 14 Nov 2023 11:33:44 -0800 Subject: [PATCH 15/23] java flatten --- .../snowflake/snowpark_java/DataFrame.java | 12 +++++++++ .../snowpark_java/TableFunctions.java | 10 ++++++- .../snowpark_test/JavaTableFunctionSuite.java | 27 +++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/snowflake/snowpark_java/DataFrame.java b/src/main/java/com/snowflake/snowpark_java/DataFrame.java index ec26efc0..806b2537 100644 --- a/src/main/java/com/snowflake/snowpark_java/DataFrame.java +++ b/src/main/java/com/snowflake/snowpark_java/DataFrame.java @@ -1312,6 +1312,18 @@ public DataFrame join( JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(orderBy)))); } + public DataFrame join(Column func) { + return new DataFrame(this.df.join(func.toScalaColumn())); + } + + public DataFrame join(Column func, Column[] partitionBy, Column[] orderBy) { + return new DataFrame( + this.df.join( + func.toScalaColumn(), + JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(partitionBy)), + JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(orderBy)))); + } + com.snowflake.snowpark.DataFrame getScalaDataFrame() { return this.df; } diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java index 267dc38d..f60e4cb1 100644 --- a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java +++ b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java @@ -41,7 +41,8 @@ public static TableFunction split_to_table() { } public static Column split_to_table(Column str, String delimiter) { - return new Column(com.snowflake.snowpark.tableFunctions.split_to_table(str.toScalaColumn(), delimiter)); + return new Column( + com.snowflake.snowpark.tableFunctions.split_to_table(str.toScalaColumn(), delimiter)); } /** @@ -81,4 +82,11 @@ public static Column split_to_table(Column str, String delimiter) { public static TableFunction flatten() { return new TableFunction(com.snowflake.snowpark.tableFunctions.flatten()); } + + public static Column flatten( + Column input, String path, boolean outer, boolean recursive, String mode) { + return new Column( + com.snowflake.snowpark.tableFunctions.flatten( + input.toScalaColumn(), path, outer, recursive, mode)); + } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java index 225f0d5b..1d8c16f9 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java @@ -88,4 +88,31 @@ public void argumentInTableFunction() { getSession().tableFunction(new TableFunction("flatten").call(args)).select("value"), new Row[] {Row.create("77"), Row.create("88")}); } + + @Test + public void argumentInSplitToTable() { + DataFrame df = + getSession() + .createDataFrame( + new Row[] {Row.create("split by space")}, + StructType.create(new StructField("col", DataTypes.StringType))); + checkAnswer( + df.join(TableFunctions.split_to_table(df.col("col"), " ")).select("value"), + new Row[] {Row.create("split"), Row.create("by"), Row.create("space")}); + } + + @Test + public void argumentInFlatten() { + DataFrame df = + getSession() + .createDataFrame( + new Row[] {Row.create("{\"a\":1, \"b\":[77, 88]}")}, + StructType.create(new StructField("col", DataTypes.StringType))); + checkAnswer( + df.join( + TableFunctions.flatten( + Functions.parse_json(df.col("col")), "b", true, true, "both")) + .select("value"), + new Row[] {Row.create("77"), Row.create("88")}); + } } From c05b0c5aa5d76f170ecb77e292f1b1c048c11188 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 14 Nov 2023 11:36:12 -0800 Subject: [PATCH 16/23] add test --- .../java/com/snowflake/snowpark_test/JavaUDTFSuite.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java index 3c80daaf..686b4093 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java @@ -345,6 +345,12 @@ public void partitionBy() { DataFrame result4 = df.join(tf, map, new Column[] {}, new Column[] {}); result4.show(); + + DataFrame result5 = df.join(tf.call(map), new Column[] {df.col("a")}, new Column[] {df.col("b")}); + checkAnswer( + result5, + new Row[] {Row.create("a", null, "{b=2, c=1}"), Row.create("d", null, "{e=1}")}, + true); } } From aa976e1f1893d036202dbc5191bde4b4a007c539 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 14 Nov 2023 13:33:23 -0800 Subject: [PATCH 17/23] fix java scala checker --- .../com/snowflake/code_verification/JavaScalaAPISuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala index b8b77638..4d2aa737 100644 --- a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala +++ b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala @@ -354,7 +354,7 @@ class JavaScalaAPISuite extends FunSuite { ClassUtils.containsSameFunctionNames( classOf[JavaTableFunction], classOf[ScalaTableFunction], - class1Only = Set(), + class1Only = Set("call"), // `call` in Scala is `apply` class2Only = Set("funcName") ++ scalaCaseClassFunctions)) } From 283e17bbd0da2b676e2dc7d99e3b1969103890fa Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 14 Nov 2023 15:07:19 -0800 Subject: [PATCH 18/23] add java doc --- .../snowflake/snowpark_java/DataFrame.java | 63 +++++++++++++++++++ .../com/snowflake/snowpark_java/Session.java | 17 +++++ .../snowpark_java/TableFunction.java | 15 +++++ 3 files changed, 95 insertions(+) diff --git a/src/main/java/com/snowflake/snowpark_java/DataFrame.java b/src/main/java/com/snowflake/snowpark_java/DataFrame.java index 806b2537..86b46823 100644 --- a/src/main/java/com/snowflake/snowpark_java/DataFrame.java +++ b/src/main/java/com/snowflake/snowpark_java/DataFrame.java @@ -1312,10 +1312,73 @@ public DataFrame join( JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(orderBy)))); } + /** + * Joins the current DataFrame with the output of the specified table function `func`. + * + *

Pre-defined table functions can be found in `TableFunctions` class. + * + *

For example: + * + *

{@code
+   * df.join(TableFunctions.flatten(
+   *   Functions.parse_json(df.col("col")),
+   *   "path", true, true, "both"
+   * ));
+   * }
+ * + *

Or load any Snowflake builtin table function via TableFunction Class. + * + *

{@code
+   * Map args = new HashMap<>();
+   * args.put("input", Functions.parse_json(df.col("a")));
+   * df.join(new TableFunction("flatten").call(args));
+   * }
+ * + * @since 1.10.0 + * @param func Column object, which can be one of the values in the TableFunctions class or + * an object that you create from the `new TableFunction("name").call()`. + * @return The result DataFrame + */ public DataFrame join(Column func) { return new DataFrame(this.df.join(func.toScalaColumn())); } + /** + * Joins the current DataFrame with the output of the specified table function `func`. + * + *

To specify a PARTITION BY or ORDER BY clause, use the `partitionBy` and `orderBy` arguments. + * + *

Pre-defined table functions can be found in `TableFunctions` class. + * + *

For example: + * + *

{@code
+   * df.join(TableFunctions.flatten(
+   *     Functions.parse_json(df.col("col1")),
+   *     "path", true, true, "both"
+   *   ),
+   *   new Column[] {df.col("col2")},
+   *   new Column[] {df.col("col1")}
+   * );
+   * }
+ * + *

Or load any Snowflake builtin table function via TableFunction Class. + * + *

{@code
+   * Map args = new HashMap<>();
+   * args.put("input", Functions.parse_json(df.col("col1")));
+   * df.join(new TableFunction("flatten").call(args),
+   * new Column[] {df.col("col2")},
+   * new Column[] {df.col("col1")});
+   * }
+ * + * @since 1.10.0 + * @param func Column object, which can be one of the values in the TableFunctions class or + * an object that you create from the `new TableFunction("name").call()`. + * @param partitionBy An array of columns partitioned by. + * @param orderBy An array of columns ordered by. + * @return The result DataFrame + */ public DataFrame join(Column func, Column[] partitionBy, Column[] orderBy) { return new DataFrame( this.df.join( diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java b/src/main/java/com/snowflake/snowpark_java/Session.java index 6e879d8e..1a72ff85 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -566,6 +566,23 @@ public DataFrame tableFunction(TableFunction func, Map args) { func.getScalaTableFunction(), JavaUtils.javaStringColumnMapToScala(scalaArgs))); } + /** + * Creates a new DataFrame from the given table function and arguments. + * + *

Example + * + *

{@code
+   * session.tableFunction(TableFunctions.flatten(
+   *   Functions.parse_json(df.col("col")),
+   *   "path", true, true, "both"
+   * ));
+   * }
+ * + * @since 1.10.0 + * @param func Column object, which can be one of the values in the TableFunctions class or + * an object that you create from the `new TableFunction("name").call()`. + * @return The result DataFrame + */ public DataFrame tableFunction(Column func) { return new DataFrame(session.tableFunction(func.toScalaColumn())); } diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunction.java b/src/main/java/com/snowflake/snowpark_java/TableFunction.java index a6321bdd..50d1d4fa 100644 --- a/src/main/java/com/snowflake/snowpark_java/TableFunction.java +++ b/src/main/java/com/snowflake/snowpark_java/TableFunction.java @@ -43,10 +43,25 @@ public String funcName() { return func.funcName(); } + /** + * Create a Column reference by passing arguments in the TableFunction object. + * + * @param args A list of Column objects representing the arguments of the given table function + * @return A Column reference + * @since 1.10.0 + */ public Column call(Column... args) { return new Column(this.func.apply(JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(args)))); } + /** + * Create a Column reference by passing arguments in the TableFunction object. + * + * @param args function arguments map of the given table function. Some functions, like flatten, + * have named parameters. use this map to assign values to the corresponding parameters. + * @return A Column reference + * @since 1.10.0 + */ public Column call(Map args) { Map scalaArgs = new HashMap<>(); for (Map.Entry entry : args.entrySet()) { From 7d8103a3122cd5730fc440b17c41e6e7db056d13 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 14 Nov 2023 15:19:45 -0800 Subject: [PATCH 19/23] Java doc --- .../snowflake/snowpark_java/DataFrame.java | 20 +++--- .../com/snowflake/snowpark_java/Session.java | 4 +- .../snowpark_java/TableFunction.java | 2 +- .../snowpark_java/TableFunctions.java | 61 +++++++++++++++++++ .../snowpark_test/JavaTableFunctionSuite.java | 6 ++ .../snowpark_test/JavaUDTFSuite.java | 9 +-- 6 files changed, 85 insertions(+), 17 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/DataFrame.java b/src/main/java/com/snowflake/snowpark_java/DataFrame.java index 86b46823..2e6d23ec 100644 --- a/src/main/java/com/snowflake/snowpark_java/DataFrame.java +++ b/src/main/java/com/snowflake/snowpark_java/DataFrame.java @@ -1315,9 +1315,9 @@ public DataFrame join( /** * Joins the current DataFrame with the output of the specified table function `func`. * - *

Pre-defined table functions can be found in `TableFunctions` class. + *

Pre-defined table functions can be found in `TableFunctions` class. * - *

For example: + *

For example: * *

{@code
    * df.join(TableFunctions.flatten(
@@ -1326,7 +1326,7 @@ public DataFrame join(
    * ));
    * }
* - *

Or load any Snowflake builtin table function via TableFunction Class. + *

Or load any Snowflake builtin table function via TableFunction Class. * *

{@code
    * Map args = new HashMap<>();
@@ -1335,8 +1335,8 @@ public DataFrame join(
    * }
* * @since 1.10.0 - * @param func Column object, which can be one of the values in the TableFunctions class or - * an object that you create from the `new TableFunction("name").call()`. + * @param func Column object, which can be one of the values in the TableFunctions class or an + * object that you create from the `new TableFunction("name").call()`. * @return The result DataFrame */ public DataFrame join(Column func) { @@ -1348,9 +1348,9 @@ public DataFrame join(Column func) { * *

To specify a PARTITION BY or ORDER BY clause, use the `partitionBy` and `orderBy` arguments. * - *

Pre-defined table functions can be found in `TableFunctions` class. + *

Pre-defined table functions can be found in `TableFunctions` class. * - *

For example: + *

For example: * *

{@code
    * df.join(TableFunctions.flatten(
@@ -1362,7 +1362,7 @@ public DataFrame join(Column func) {
    * );
    * }
* - *

Or load any Snowflake builtin table function via TableFunction Class. + *

Or load any Snowflake builtin table function via TableFunction Class. * *

{@code
    * Map args = new HashMap<>();
@@ -1373,8 +1373,8 @@ public DataFrame join(Column func) {
    * }
* * @since 1.10.0 - * @param func Column object, which can be one of the values in the TableFunctions class or - * an object that you create from the `new TableFunction("name").call()`. + * @param func Column object, which can be one of the values in the TableFunctions class or an + * object that you create from the `new TableFunction("name").call()`. * @param partitionBy An array of columns partitioned by. * @param orderBy An array of columns ordered by. * @return The result DataFrame diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java b/src/main/java/com/snowflake/snowpark_java/Session.java index 1a72ff85..b22f327a 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -579,8 +579,8 @@ public DataFrame tableFunction(TableFunction func, Map args) { * } * * @since 1.10.0 - * @param func Column object, which can be one of the values in the TableFunctions class or - * an object that you create from the `new TableFunction("name").call()`. + * @param func Column object, which can be one of the values in the TableFunctions class or an + * object that you create from the `new TableFunction("name").call()`. * @return The result DataFrame */ public DataFrame tableFunction(Column func) { diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunction.java b/src/main/java/com/snowflake/snowpark_java/TableFunction.java index 50d1d4fa..5e0b4f95 100644 --- a/src/main/java/com/snowflake/snowpark_java/TableFunction.java +++ b/src/main/java/com/snowflake/snowpark_java/TableFunction.java @@ -58,7 +58,7 @@ public Column call(Column... args) { * Create a Column reference by passing arguments in the TableFunction object. * * @param args function arguments map of the given table function. Some functions, like flatten, - * have named parameters. use this map to assign values to the corresponding parameters. + * have named parameters. use this map to assign values to the corresponding parameters. * @return A Column reference * @since 1.10.0 */ diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java index f60e4cb1..8a202656 100644 --- a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java +++ b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java @@ -40,6 +40,22 @@ public static TableFunction split_to_table() { return new TableFunction(com.snowflake.snowpark.tableFunctions.split_to_table()); } + /** + * This table function splits a string (based on a specified delimiter) and flattens the results + * into rows. + * + *

Example + * + *

{@code
+   * session.tableFunction(TableFunctions.split_to_table(,
+   *   Functions.lit("split by space"), Functions.lit(" ")));
+   * }
+ * + * @since 1.10.0 + * @param str Text to be split. + * @param delimiter Text to split string by. + * @return The result TableFunction reference + */ public static Column split_to_table(Column str, String delimiter) { return new Column( com.snowflake.snowpark.tableFunctions.split_to_table(str.toScalaColumn(), delimiter)); @@ -83,10 +99,55 @@ public static TableFunction flatten() { return new TableFunction(com.snowflake.snowpark.tableFunctions.flatten()); } + /** + * Flattens (explodes) compound values into multiple rows. + * + *

Example + * + *

{@code
+   * df.join(TableFunctions.flatten(
+   *   Functions.parse_json(df.col("col")), "path", true, true, "both"));
+   * }
+ * + * @since 1.10.0 + * @param input The expression that will be unseated into rows. The expression must be of data + * type VariantType, MapType or ArrayType. + * @param path The path to the element within a VariantType data structure which needs to be + * flattened. Can be a zero-length string (i.e. empty path) if the outermost element is to be + * flattened. Default: Zero-length string (i.e. empty path) + * @param outer If FALSE, any input rows that cannot be expanded, either because they cannot be + * accessed in the path or because they have zero fields or entries, are completely omitted + * from the output. If TRUE, exactly one row is generated for zero-row expansions (with NULL + * in the KEY, INDEX, and VALUE columns). + * @param recursive If FALSE, only the element referenced by PATH is expanded. If TRUE, the + * expansion is performed for all sub-elements recursively. Default: FALSE + * @param mode ("object", "array", or "both") Specifies whether only objects, arrays, or both + * should be flattened. + * @return The result TableFunction reference + */ public static Column flatten( Column input, String path, boolean outer, boolean recursive, String mode) { return new Column( com.snowflake.snowpark.tableFunctions.flatten( input.toScalaColumn(), path, outer, recursive, mode)); } + + /** + * Flattens (explodes) compound values into multiple rows. + * + *

Example + * + *

{@code
+   * df.join(TableFunctions.flatten(
+   *   Functions.parse_json(df.col("col"))));
+   * }
+ * + * @since 1.10.0 + * @param input The expression that will be unseated into rows. The expression must be of data + * type VariantType, MapType or ArrayType. + * @return The result TableFunction reference + */ + public static Column flatten(Column input) { + return new Column(com.snowflake.snowpark.tableFunctions.flatten(input.toScalaColumn())); + } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java index 1d8c16f9..12ca8391 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java @@ -114,5 +114,11 @@ public void argumentInFlatten() { Functions.parse_json(df.col("col")), "b", true, true, "both")) .select("value"), new Row[] {Row.create("77"), Row.create("88")}); + + checkAnswer( + getSession() + .tableFunction(TableFunctions.flatten(Functions.parse_json(Functions.lit("[1,2]")))) + .select("value"), + new Row[] {Row.create("1"), Row.create("2")}); } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java index 686b4093..4c2d3aae 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java @@ -346,11 +346,12 @@ public void partitionBy() { DataFrame result4 = df.join(tf, map, new Column[] {}, new Column[] {}); result4.show(); - DataFrame result5 = df.join(tf.call(map), new Column[] {df.col("a")}, new Column[] {df.col("b")}); + DataFrame result5 = + df.join(tf.call(map), new Column[] {df.col("a")}, new Column[] {df.col("b")}); checkAnswer( - result5, - new Row[] {Row.create("a", null, "{b=2, c=1}"), Row.create("d", null, "{e=1}")}, - true); + result5, + new Row[] {Row.create("a", null, "{b=2, c=1}"), Row.create("d", null, "{e=1}")}, + true); } } From 202919e6d982dbedeaac0467911d11e8f5835257 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 14 Nov 2023 15:43:58 -0800 Subject: [PATCH 20/23] scala doc --- .../snowpark_java/TableFunctions.java | 6 +- .../com/snowflake/snowpark/DataFrame.scala | 40 +++++++++++ .../com/snowflake/snowpark/Session.scala | 17 +++++ .../snowflake/snowpark/TableFunction.scala | 16 +++++ .../snowflake/snowpark/tableFunctions.scala | 70 +++++++++++++++++++ 5 files changed, 146 insertions(+), 3 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java index 8a202656..022ec6df 100644 --- a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java +++ b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java @@ -54,7 +54,7 @@ public static TableFunction split_to_table() { * @since 1.10.0 * @param str Text to be split. * @param delimiter Text to split string by. - * @return The result TableFunction reference + * @return The result Column reference */ public static Column split_to_table(Column str, String delimiter) { return new Column( @@ -123,7 +123,7 @@ public static TableFunction flatten() { * expansion is performed for all sub-elements recursively. Default: FALSE * @param mode ("object", "array", or "both") Specifies whether only objects, arrays, or both * should be flattened. - * @return The result TableFunction reference + * @return The result Column reference */ public static Column flatten( Column input, String path, boolean outer, boolean recursive, String mode) { @@ -145,7 +145,7 @@ public static Column flatten( * @since 1.10.0 * @param input The expression that will be unseated into rows. The expression must be of data * type VariantType, MapType or ArrayType. - * @return The result TableFunction reference + * @return The result Column reference */ public static Column flatten(Column input) { return new Column(com.snowflake.snowpark.tableFunctions.flatten(input.toScalaColumn())); diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 56b7247d..7666d9e8 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1907,10 +1907,50 @@ class DataFrame private[snowpark] ( Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) } + /** + * Joins the current DataFrame with the output of the specified table function `func`. + * + * + * For example: + * {{{ + * // The following example uses the flatten function to explode compound values from + * // column 'a' in this DataFrame into multiple columns. + * + * import com.snowflake.snowpark.functions._ + * import com.snowflake.snowpark.tableFunctions._ + * + * df.join( + * tableFunctions.flatten(parse_json(df("a"))) + * ) + * }}} + * + * @group transform + * @since 1.10.0 + * @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]] + * object or an object that you create from the [[TableFunction.apply()]]. + */ def join(func: Column): DataFrame = withPlan { TableFunctionJoin(this.plan, getTableFunctionExpression(func), None) } + /** + * Joins the current DataFrame with the output of the specified user-defined table function + * (UDTF) `func`. + * + * To specify a PARTITION BY or ORDER BY clause, use the `partitionBy` and `orderBy` arguments. + * + * For example: + * {{{ + * val tf = session.udtf.registerTemporary(TableFunc1) + * df.join(tf(Map("arg1" -> df("col1")),Seq(df("col2")), Seq(df("col1")))) + * }}} + * + * @group transform + * @since 1.10.0 + * @param func [[TableFunction]] object that represents a user-defined table function. + * @param partitionBy A list of columns partitioned by. + * @param orderBy A list of columns ordered by. + */ def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan { TableFunctionJoin( this.plan, diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index f0537db6..ea1e5913 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -580,6 +580,23 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log } } + /** + * Creates a new DataFrame from the given table function. + * + * Example + * {{{ + * import com.snowflake.snowpark.functions._ + * import com.snowflake.snowpark.tableFunctions._ + * + * session.tableFunction( + * flatten(parse_json(lit("[1,2]"))) + * ) + * }}} + * + * @since 1.10.0 + * @param func Table function object, can be created from TableFunction class or + * referred from the built-in list from tableFunctions. + */ def tableFunction(func: Column): DataFrame = { func.expr match { case TFunction(funcName, args) => diff --git a/src/main/scala/com/snowflake/snowpark/TableFunction.scala b/src/main/scala/com/snowflake/snowpark/TableFunction.scala index 7c37751e..39d4261b 100644 --- a/src/main/scala/com/snowflake/snowpark/TableFunction.scala +++ b/src/main/scala/com/snowflake/snowpark/TableFunction.scala @@ -40,7 +40,23 @@ case class TableFunction(funcName: String) { case (key, value) => key -> value.expr }) + /** + * Create a Column reference by passing arguments in the TableFunction object. + * + * @param args A list of Column objects representing the arguments of the given table function + * @return A Column reference + * @since 1.10.0 + */ def apply(args: Column*): Column = Column(this.call(args: _*)) + /** + * Create a Column reference by passing arguments in the TableFunction object. + * + * @param args function arguments map of the given table function. Some functions, like flatten, + * have named parameters. use this map to assign values to the corresponding + * parameters. + * @return A Column reference + * @since 1.10.0 + */ def apply(args: Map[String, Column]): Column = Column(this.call(args)) } diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala index a6b290b5..212a000c 100644 --- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala @@ -59,6 +59,23 @@ object tableFunctions { */ lazy val split_to_table: TableFunction = TableFunction("split_to_table") + /** + * This table function splits a string (based on a specified delimiter) + * and flattens the results into rows. + * + * Example + * {{{ + * import com.snowflake.snowpark.functions._ + * import com.snowflake.snowpark.tableFunctions._ + * + * df.join(tableFunctions.split_to_table(df("a"), lit(","))) + * }}} + * + * @since 1.10.0 + * @param str Text to be split. + * @param delimiter Text to split string by. + * @return The result Column reference + */ def split_to_table(str: Column, delimiter: String): Column = split_to_table.apply(str, lit(delimiter)) @@ -112,8 +129,61 @@ object tableFunctions { */ lazy val flatten: TableFunction = TableFunction("flatten") + /** + * Flattens (explodes) compound values into multiple rows. + * + * + * Example + * {{{ + * import com.snowflake.snowpark.functions._ + * import com.snowflake.snowpark.tableFunctions._ + * + * df.join( + * tableFunctions.flatten(parse_json(df("a"))) + * ) + * + * }}} + * + * @since 1.10.0 + * @param input The expression that will be unseated into rows. + * The expression must be of data type VariantType, MapType or ArrayType. + * @return The result Column reference + */ def flatten(input: Column): Column = flatten.apply(input) + /** + * Flattens (explodes) compound values into multiple rows. + * + * + * Example + * {{{ + * import com.snowflake.snowpark.functions._ + * import com.snowflake.snowpark.tableFunctions._ + * + * df.join( + * tableFunctions.flatten(parse_json(df("a")), "path", true, true, "both") + * ) + * + * }}} + * + * @since 1.10.0 + * @param input The expression that will be unseated into rows. + * The expression must be of data type VariantType, MapType or ArrayType. + * @param path The path to the element within a VariantType data structure + * which needs to be flattened. Can be a zero-length string (i.e. empty path) + * if the outermost element is to be flattened. + * @param outer Optional boolean value. + * If FALSE, any input rows that cannot be expanded, + * either because they cannot be accessed in the path or because they have + * zero fields or entries, are completely omitted from the output. + * If TRUE, exactly one row is generated for zero-row expansions + * (with NULL in the KEY, INDEX, and VALUE columns). + * @param recursive If FALSE, only the element referenced by PATH is expanded. + * If TRUE, the expansion is performed for all sub-elements recursively. + * @param mode ("object", "array", or "both") + * Specifies whether only objects, arrays, or both should be flattened. + * @return The result Column reference + */ def flatten( input: Column, path: String, From 6bc8d628abcb213c222bd2bfa488ecf344c3e002 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 15 Nov 2023 13:47:44 -0800 Subject: [PATCH 21/23] add error code --- src/main/scala/com/snowflake/snowpark/DataFrame.scala | 2 ++ .../com/snowflake/snowpark/internal/ErrorMessage.scala | 4 ++++ .../scala/com/snowflake/snowpark/ErrorMessageSuite.scala | 9 +++++++++ 3 files changed, 15 insertions(+) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 7666d9e8..c1179ba1 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -561,6 +561,8 @@ class DataFrame private[snowpark] ( "Provide at least one column expression for select(). " + s"This DataFrame has column names (${output.length}): " + s"${output.map(_.name).mkString(", ")}\n") + // todo: error message + require(columns.count(_.expr.isInstanceOf[TableFunctionExpression]) <= 1, "error") val resultDF = withPlan { Project(columns.map(_.named), plan) } // do not rename back if this project contains internal alias. diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index 3dbd5b89..505b2b6d 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -68,6 +68,7 @@ private[snowpark] object ErrorMessage { "0128" -> "DataFrameWriter doesn't support to set option '%s' as '%s' in '%s' mode when writing to a %s.", "0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.", "0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only", + "0131" -> "At most one table function can be called inside select() function", // Begin to define UDF related messages "0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d", "0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.", @@ -244,6 +245,9 @@ private[snowpark] object ErrorMessage { def DF_JOIN_WITH_WRONG_ARGUMENT(): SnowparkClientException = createException("0130") + def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException = + createException("0131") + /* * 2NN: UDF error code */ diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index d5ede212..a2056f78 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -296,6 +296,15 @@ class ErrorMessageSuite extends FunSuite { " or TableFunctions only")) } + test("DF_MORE_THAN_ONE_TF_IN_SELECT") { + val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131"))) + assert( + ex.message.startsWith( + "Error Code: 0131, Error message: " + + "At most one table function can be called inside select() function")) + } + test("UDF_INCORRECT_ARGS_NUMBER") { val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200"))) From 276d10afca851238bf2b578e07349dfcb7c30316 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 15 Nov 2023 15:07:18 -0800 Subject: [PATCH 22/23] tf in select --- .../com/snowflake/snowpark/DataFrame.scala | 41 +++++++++++++------ .../snowpark/ErrorMessageSuite.scala | 5 +-- .../snowpark_test/TableFunctionSuite.scala | 28 +++++++++++++ 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index c1179ba1..25b6fe8a 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -562,19 +562,34 @@ class DataFrame private[snowpark] ( s"This DataFrame has column names (${output.length}): " + s"${output.map(_.name).mkString(", ")}\n") // todo: error message - require(columns.count(_.expr.isInstanceOf[TableFunctionExpression]) <= 1, "error") - - val resultDF = withPlan { Project(columns.map(_.named), plan) } - // do not rename back if this project contains internal alias. - // because no named duplicated if just renamed. - val hasInternalAlias: Boolean = columns.map(_.expr).exists { - case Alias(_, _, true) => true - case _ => false - } - if (hasInternalAlias) { - resultDF - } else { - renameBackIfDeduped(resultDF) + val tf = columns.filter(_.expr.isInstanceOf[TableFunctionExpression]) + tf.size match { + case 0 => + val resultDF = withPlan { + Project(columns.map(_.named), plan) + } + // do not rename back if this project contains internal alias. + // because no named duplicated if just renamed. + val hasInternalAlias: Boolean = columns.map(_.expr).exists { + case Alias(_, _, true) => true + case _ => false + } + if (hasInternalAlias) { + resultDF + } else { + renameBackIfDeduped(resultDF) + } + case 1 => + val base = this.join(tf.head) + val baseColumns = base.schema.map(field => base(field.name)) + val inputDFColumnSize = this.schema.size + val tfColumns = baseColumns.splitAt(inputDFColumnSize)._2 + val (beforeTf, afterTf) = columns.span(_ != tf.head) + val resultColumns = beforeTf ++ tfColumns ++ afterTf.tail + base.select(resultColumns) + case _ => + // more than 1 TF + throw ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT() } } diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index a2056f78..072c73e8 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -300,9 +300,8 @@ class ErrorMessageSuite extends FunSuite { val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131"))) assert( - ex.message.startsWith( - "Error Code: 0131, Error message: " + - "At most one table function can be called inside select() function")) + ex.message.startsWith("Error Code: 0131, Error message: " + + "At most one table function can be called inside select() function")) } test("UDF_INCORRECT_ARGS_NUMBER") { diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 46f0028f..c87a8f2c 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -337,4 +337,32 @@ class TableFunctionSuite extends TestData { .select("value"), Seq(Row("77"), Row("88"))) } + + test("table function in select") { + val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "data") + // only tf + val result1 = df.select(tableFunctions.split_to_table(df("data"), ",")) + assert(result1.schema.map(_.name) == Seq("SEQ", "INDEX", "VALUE")) + checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4"))) + + // columns + tf + val result2 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ",")) + assert(result2.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE")) + checkAnswer( + result2, + Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4"))) + + // columns + tf + columns + val result3 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","), df("idx")) + assert(result3.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE", "IDX")) + checkAnswer( + result3, + Seq(Row(1, 1, 1, "1", 1), Row(1, 1, 2, "2", 1), Row(2, 2, 1, "3", 2), Row(2, 2, 2, "4", 2))) + + // tf + other express + val result4 = df.select(tableFunctions.split_to_table(df("data"), ","), df("idx") + 100) + checkAnswer( + result4, + Seq(Row(1, 1, "1", 101), Row(1, 2, "2", 101), Row(2, 1, "3", 102), Row(2, 2, "4", 102))) + } } From f57baef3f06b49939db90760b414a4fc6d558e52 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 15 Nov 2023 15:43:36 -0800 Subject: [PATCH 23/23] refactor dataframe join --- .../com/snowflake/snowpark/DataFrame.scala | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 25b6fe8a..e7347da7 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1805,9 +1805,8 @@ class DataFrame private[snowpark] ( * object or an object that you create from the [[TableFunction]] class. * @param args A list of arguments to pass to the specified table function. */ - def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func.call(args: _*), None) - } + def join(func: TableFunction, args: Seq[Column]): DataFrame = + joinTableFunction(func.call(args: _*), None) /** * Joins the current DataFrame with the output of the specified user-defined table @@ -1839,12 +1838,10 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Seq[Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + orderBy: Seq[Column]): DataFrame = + joinTableFunction( func.call(args: _*), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } /** * Joins the current DataFrame with the output of the specified table function `func` that takes @@ -1876,9 +1873,8 @@ class DataFrame private[snowpark] ( * Some functions, like `flatten`, have named parameters. * Use this map to specify the parameter names and their corresponding values. */ - def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func.call(args), None) - } + def join(func: TableFunction, args: Map[String, Column]): DataFrame = + joinTableFunction(func.call(args), None) /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1917,12 +1913,10 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Map[String, Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + orderBy: Seq[Column]): DataFrame = + joinTableFunction( func.call(args), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } /** * Joins the current DataFrame with the output of the specified table function `func`. @@ -1946,9 +1940,8 @@ class DataFrame private[snowpark] ( * @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]] * object or an object that you create from the [[TableFunction.apply()]]. */ - def join(func: Column): DataFrame = withPlan { - TableFunctionJoin(this.plan, getTableFunctionExpression(func), None) - } + def join(func: Column): DataFrame = + joinTableFunction(getTableFunctionExpression(func), None) /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1968,12 +1961,17 @@ class DataFrame private[snowpark] ( * @param partitionBy A list of columns partitioned by. * @param orderBy A list of columns ordered by. */ - def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = + joinTableFunction( getTableFunctionExpression(func), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } + + private def joinTableFunction( + func: TableFunctionExpression, + partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = + withPlan { + TableFunctionJoin(this.plan, func, partitionByOrderBy) + } /** * Performs a cross join, which returns the cartesian product of the current DataFrame and