Skip to content

Commit

Permalink
refactor dataframe join
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bli committed Nov 15, 2023
1 parent 276d10a commit f57baef
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1805,9 +1805,8 @@ class DataFrame private[snowpark] (
* object or an object that you create from the [[TableFunction]] class.
* @param args A list of arguments to pass to the specified table function.
*/
def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, func.call(args: _*), None)
}
def join(func: TableFunction, args: Seq[Column]): DataFrame = {
  // Expand the positional arguments into the table function call; this
  // overload supplies no window specification to the join.
  val tableFunctionCall = func.call(args: _*)
  joinTableFunction(tableFunctionCall, None)
}

/**
* Joins the current DataFrame with the output of the specified user-defined table
Expand Down Expand Up @@ -1839,12 +1838,10 @@ class DataFrame private[snowpark] (
func: TableFunction,
args: Seq[Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
orderBy: Seq[Column]): DataFrame =
joinTableFunction(
func.call(args: _*),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
* Joins the current DataFrame with the output of the specified table function `func` that takes
Expand Down Expand Up @@ -1876,9 +1873,8 @@ class DataFrame private[snowpark] (
* Some functions, like `flatten`, have named parameters.
* Use this map to specify the parameter names and their corresponding values.
*/
def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, func.call(args), None)
}
def join(func: TableFunction, args: Map[String, Column]): DataFrame = {
  // The map of named arguments is handed to the table function as-is; this
  // overload supplies no window specification to the join.
  val tableFunctionCall = func.call(args)
  joinTableFunction(tableFunctionCall, None)
}

/**
* Joins the current DataFrame with the output of the specified user-defined table function
Expand Down Expand Up @@ -1917,12 +1913,10 @@ class DataFrame private[snowpark] (
func: TableFunction,
args: Map[String, Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
orderBy: Seq[Column]): DataFrame =
joinTableFunction(
func.call(args),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
* Joins the current DataFrame with the output of the specified table function `func`.
Expand All @@ -1946,9 +1940,8 @@ class DataFrame private[snowpark] (
* @param func [[Column]] object that represents a table function call, for example one built from
* a [[TableFunction]] in the [[tableFunctions]] object or created with [[TableFunction.apply()]].
*/
def join(func: Column): DataFrame = withPlan {
TableFunctionJoin(this.plan, getTableFunctionExpression(func), None)
}
def join(func: Column): DataFrame = {
  // Resolve the column through getTableFunctionExpression; this overload
  // supplies no window specification to the join.
  val tableFunctionCall = getTableFunctionExpression(func)
  joinTableFunction(tableFunctionCall, None)
}

/**
* Joins the current DataFrame with the output of the specified user-defined table function
Expand All @@ -1968,12 +1961,17 @@ class DataFrame private[snowpark] (
* @param partitionBy A list of columns to partition by.
* @param orderBy A list of columns to order by.
*/
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame =
joinTableFunction(
getTableFunctionExpression(func),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
 * Shared implementation behind the table-function `join` overloads.
 *
 * @param func The table function expression to join with this DataFrame's plan.
 * @param partitionByOrderBy Optional window specification forwarded to the
 *                           `TableFunctionJoin` node; `None` when the caller
 *                           supplies no partitioning or ordering.
 * @return The DataFrame that `withPlan` yields for the resulting
 *         `TableFunctionJoin` node.
 */
private def joinTableFunction(
    func: TableFunctionExpression,
    partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame =
  withPlan(TableFunctionJoin(this.plan, func, partitionByOrderBy))

/**
* Performs a cross join, which returns the cartesian product of the current DataFrame and
Expand Down

0 comments on commit f57baef

Please sign in to comment.