From f57baef3f06b49939db90760b414a4fc6d558e52 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 15 Nov 2023 15:43:36 -0800 Subject: [PATCH] refactor dataframe join --- .../com/snowflake/snowpark/DataFrame.scala | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 25b6fe8a..e7347da7 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1805,9 +1805,8 @@ class DataFrame private[snowpark] ( * object or an object that you create from the [[TableFunction]] class. * @param args A list of arguments to pass to the specified table function. */ - def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func.call(args: _*), None) - } + def join(func: TableFunction, args: Seq[Column]): DataFrame = + joinTableFunction(func.call(args: _*), None) /** * Joins the current DataFrame with the output of the specified user-defined table @@ -1839,12 +1838,10 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Seq[Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + orderBy: Seq[Column]): DataFrame = + joinTableFunction( func.call(args: _*), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } /** * Joins the current DataFrame with the output of the specified table function `func` that takes @@ -1876,9 +1873,8 @@ class DataFrame private[snowpark] ( * Some functions, like `flatten`, have named parameters. * Use this map to specify the parameter names and their corresponding values. */ - def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func.call(args), None) - } + def join(func: TableFunction, args: Map[String, Column]): DataFrame = + joinTableFunction(func.call(args), None) /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1917,12 +1913,10 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Map[String, Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + orderBy: Seq[Column]): DataFrame = + joinTableFunction( func.call(args), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } /** * Joins the current DataFrame with the output of the specified table function `func`. @@ -1946,9 +1940,8 @@ class DataFrame private[snowpark] ( * @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]] * object or an object that you create from the [[TableFunction.apply()]]. */ - def join(func: Column): DataFrame = withPlan { - TableFunctionJoin(this.plan, getTableFunctionExpression(func), None) - } + def join(func: Column): DataFrame = + joinTableFunction(getTableFunctionExpression(func), None) /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1968,12 +1961,17 @@ class DataFrame private[snowpark] ( * @param partitionBy A list of columns partitioned by. * @param orderBy A list of columns ordered by. */ - def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = + joinTableFunction( getTableFunctionExpression(func), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } + + private def joinTableFunction( + func: TableFunctionExpression, + partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = + withPlan { + TableFunctionJoin(this.plan, func, partitionByOrderBy) + } /** * Performs a cross join, which returns the cartesian product of the current DataFrame and