Skip to content

Commit

Permalink
SNOW-821644 add dataframe alias support (#79)
Browse files Browse the repository at this point in the history
* add alias

* update test

* address comments

* update code

* address comment

* address comments

* address comments
  • Loading branch information
sfc-gh-zli authored Jan 4, 2024
1 parent 71f54ae commit e8a35b3
Show file tree
Hide file tree
Showing 18 changed files with 268 additions and 23 deletions.
19 changes: 19 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/DataFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,25 @@ public Column col(String colName) {
return new Column(df.col(colName));
}

/**
 * Returns the current DataFrame aliased as the input alias name.
 *
 * <p>For example:
 *
 * <pre>{@code
 * DataFrame df2 = df.alias("A");
 * df2.select(df2.col("A.num"));
 * }</pre>
 *
 * @group basic
 * @since 1.10.0
 * @param alias The alias name of the dataframe
 * @return a {@link DataFrame}
 */
public DataFrame alias(String alias) {
  // Delegate to the underlying Scala DataFrame and re-wrap the result.
  return new DataFrame(this.df.alias(alias));
}

/**
* Executes the query representing this DataFrame and returns the result as an array of Row
* objects.
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/com/snowflake/snowpark/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ private[snowpark] object Column {
/**
 * Builds a Column from a raw name.
 *
 * "*" expands to all columns; a name containing "." may refer to a
 * DataFrame alias (e.g. "A.num") and is left unresolved until the
 * analyzer can consult the alias map; any other name becomes a plain
 * (quoted) unresolved attribute.
 */
def apply(name: String): Column = {
  val expr =
    if (name == "*") {
      Star(Seq.empty)
    } else if (name.contains(".")) {
      UnresolvedDFAliasAttribute(name)
    } else {
      UnresolvedAttribute(quoteName(name))
    }
  new Column(expr)
}

Expand Down
36 changes: 32 additions & 4 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,23 @@ class DataFrame private[snowpark] (
case _ => Column(resolve(colName))
}

/**
 * Returns the current DataFrame aliased as the input alias name.
 *
 * For example:
 *
 * {{{
 * val df2 = df.alias("A")
 * df2.select(df2.col("A.num"))
 * }}}
 *
 * @group basic
 * @since 1.10.0
 * @param alias The alias name of the dataframe
 * @return a [[DataFrame]]
 */
def alias(alias: String): DataFrame = {
  // Wrap the current plan in a DataframeAlias node so that column
  // references of the form "<alias>.<column>" can later be resolved
  // against this DataFrame's output.
  val aliasedPlan = DataframeAlias(alias, plan)
  withPlan(aliasedPlan)
}

/**
* Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in
* SQL). Only the Columns specified as arguments will be present in the resulting DataFrame.
Expand Down Expand Up @@ -2791,7 +2808,8 @@ class DataFrame private[snowpark] (

// utils
private[snowpark] def resolve(colName: String): NamedExpression = {
val normalizedColName = quoteName(colName)
val (aliasColName, aliasOutput) = resolveAlias(colName, output)
val normalizedColName = quoteName(aliasColName)
def isDuplicatedName: Boolean = {
if (session.conn.hideInternalAlias) {
this.plan.internalRenamedColumns.values.exists(_ == normalizedColName)
Expand All @@ -2800,13 +2818,23 @@ class DataFrame private[snowpark] (
}
}
val col =
output.filter(attr => attr.name.equals(normalizedColName))
aliasOutput.filter(attr => attr.name.equals(normalizedColName))
if (col.length == 1) {
col.head.withName(normalizedColName).withSourceDF(this)
} else if (isDuplicatedName) {
throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(colName, colName)
throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(aliasColName, aliasColName)
} else {
throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name))
}
}

// Handle dataframe alias by redirecting output and column name resolution.
// If `colName` has the form "<dfAlias>.<column>" and <dfAlias> is registered in
// this plan's alias map, resolve against that alias's recorded output;
// otherwise fall back to the unmodified name and this DataFrame's own output.
// NOTE: the scraped diff interleaved a removed `throw` line here, which made
// the fallback tuple unreachable; the intended (added) code returns the tuple.
private def resolveAlias(colName: String, output: Seq[Attribute]): (String, Seq[Attribute]) = {
  // Split only on the first "." so column names containing dots survive.
  val colNameSplit = colName.split("\\.", 2)
  if (colNameSplit.length > 1 && plan.dfAliasMap.contains(colNameSplit(0))) {
    (colNameSplit(1), plan.dfAliasMap(colNameSplit(0)))
  } else {
    (colName, output)
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ private[snowpark] object ErrorMessage {
"0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.",
"0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only",
"0131" -> "At most one table function can be called inside select() function",
"0132" -> "Duplicated dataframe alias defined: %s",
// Begin to define UDF related messages
"0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d",
"0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.",
Expand Down Expand Up @@ -252,6 +253,9 @@ private[snowpark] object ErrorMessage {
def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException =
createException("0131")

/** Builds error 0132: the same dataframe alias name was defined more than once. */
def DF_ALIAS_DUPLICATES(duplicatedAlias: scala.collection.Set[String]): SnowparkClientException = {
  val aliasList = duplicatedAlias.mkString(", ")
  createException("0132", aliasList)
}

/*
* 2NN: UDF error code
*/
Expand Down
16 changes: 15 additions & 1 deletion src/main/scala/com/snowflake/snowpark/internal/Utils.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.snowflake.snowpark.internal

import com.snowflake.snowpark.Column
import com.snowflake.snowpark.internal.analyzer.{Attribute, TableFunctionExpression, singleQuote}
import com.snowflake.snowpark.internal.analyzer.{Attribute, LogicalPlan, TableFunctionExpression, singleQuote}

import java.io.{File, FileInputStream}
import java.lang.invoke.SerializedLambda
Expand Down Expand Up @@ -99,6 +99,20 @@ object Utils extends Logging {
lastInternalLine + "\n" + stackTrace.take(stackDepth).mkString("\n")
}

/**
 * Merges the dataframe-alias map of `child` into `result`.
 *
 * @param result the alias map accumulated so far
 * @param child  the child plan contributing aliases; tolerated to be null
 *               (treated as contributing nothing)
 * @return the combined alias map
 * @throws SnowparkClientException (error 0132) if any alias is defined in both maps
 */
def addToDataframeAliasMap(result: Map[String, Seq[Attribute]], child: LogicalPlan)
    : Map[String, Seq[Attribute]] = {
  if (child == null) {
    result
  } else {
    val childAliases = child.dfAliasMap
    // An alias defined on both sides is ambiguous — reject it eagerly.
    val overlapping = result.keySet.intersect(childAliases.keySet)
    if (overlapping.nonEmpty) {
      throw ErrorMessage.DF_ALIAS_DUPLICATES(overlapping)
    }
    result ++ childAliases
  }
}

def logTime[T](f: => T, funcDescription: String): T = {
logInfo(funcDescription)
val start = System.currentTimeMillis()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,19 @@ private[snowpark] case class UnresolvedAttribute(override val name: String)
this
}

/**
 * A column reference that may name a DataFrame alias, e.g. "A.num".
 * It is created by Column.apply for any dotted name and stays unresolved
 * until the expression analyzer can consult the plan's dfAliasMap.
 */
private[snowpark] case class UnresolvedDFAliasAttribute(override val name: String)
    extends Expression with NamedExpression {
  // Produces no SQL on its own: it must be resolved (or rewritten to a
  // plain UnresolvedAttribute) before SQL generation.
  override def sql: String = ""

  // Leaf expression — nothing to recurse into.
  override def children: Seq[Expression] = Seq.empty

  // can't analyze: dependent columns are unknown before alias resolution
  override lazy val dependentColumnNames: Option[Set[String]] = None

  // Resolution is handled elsewhere, so analysis leaves this node unchanged.
  override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression =
    this
}

private[snowpark] case class ListAgg(col: Expression, delimiter: String, isDistinct: Boolean)
extends Expression {
override def children: Seq[Expression] = Seq(col)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.ErrorMessage

import scala.collection.mutable.{Map => MMap}

private[snowpark] object ExpressionAnalyzer {
  // NOTE: the scraped diff left both the old single-argument overloads and the
  // new dfAliasMap-aware ones in place; only the post-commit (two-argument)
  // overloads are kept here — the old forms cannot coexist with them.

  /** Creates an analyzer from a column alias map and a dataframe alias map. */
  def apply(aliasMap: Map[ExprId, String],
      dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer =
    new ExpressionAnalyzer(aliasMap, dfAliasMap)

  /** Creates an analyzer with no alias information. */
  def apply(): ExpressionAnalyzer =
    new ExpressionAnalyzer(Map.empty, Map.empty)

  // create new analyzer by combining two alias maps
  def apply(map1: Map[ExprId, String], map2: Map[ExprId, String],
      dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = {
    val common = map1.keySet & map2.keySet
    val result = (map1 ++ map2).filter {
      // remove common column, let (df1.join(df2))
      // .join(df2.join(df3)).select(df2) report error
      case (id, _) => !common.contains(id)
    }
    new ExpressionAnalyzer(result, dfAliasMap)
  }

  /** Folds a sequence of column alias maps into a single analyzer. */
  def apply(maps: Seq[Map[ExprId, String]],
      dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = {
    maps.foldLeft(ExpressionAnalyzer()) {
      case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map, dfAliasMap)
    }
  }
}

private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) {
private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String],
dfAliasMap: Map[String, Seq[Attribute]]) {
private val generatedAliasMap: MMap[ExprId, String] = MMap.empty

def analyze(ex: Expression): Expression = ex match {
Expand All @@ -52,6 +58,25 @@ private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) {
// removed useless alias
case Alias(child: NamedExpression, name, _) if quoteName(child.name) == quoteName(name) =>
child
case UnresolvedDFAliasAttribute(name) =>
val colNameSplit = name.split("\\.", 2)
if (colNameSplit.length > 1 && dfAliasMap.contains(colNameSplit(0))) {
val aliasOutput = dfAliasMap(colNameSplit(0))
val aliasColName = colNameSplit(1)
val normalizedColName = quoteName(aliasColName)
val col = aliasOutput.filter(attr => attr.name.equals(normalizedColName))
if (col.length == 1) {
col.head.withName(normalizedColName)
} else {
throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name))
}
} else {
// if didn't find alias in the map
name match {
case "*" => Star(Seq.empty)
case _ => UnresolvedAttribute(quoteName(name))
}
}
case _ => ex
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.Utils

private[snowpark] trait MultiChildrenNode extends LogicalPlan {
override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newChildren: Seq[LogicalPlan] = children.map(func)
Expand All @@ -11,13 +13,19 @@ private[snowpark] trait MultiChildrenNode extends LogicalPlan {

protected def updateChildren(newChildren: Seq[LogicalPlan]): MultiChildrenNode


override lazy val dfAliasMap: Map[String, Seq[Attribute]] =
children.foldLeft(Map.empty[String, Seq[Attribute]]) {
case (map, child) => Utils.addToDataframeAliasMap(map, child)
}

override protected def analyze: LogicalPlan =
createFromAnalyzedChildren(children.map(_.analyzed))

protected def createFromAnalyzedChildren: Seq[LogicalPlan] => MultiChildrenNode

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(children.map(_.aliasMap))
ExpressionAnalyzer(children.map(_.aliasMap), dfAliasMap)

lazy override val internalRenamedColumns: Map[String, String] =
children.map(_.internalRenamedColumns).reduce(_ ++ _)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ case class SnowflakeCreateTable(tableName: String, mode: SaveMode, query: Option
SnowflakeCreateTable(tableName, mode, query.map(_.analyzed))

override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newQuery = query.map(func)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class SnowflakePlan(
sourcePlan.map(_.analyzed).getOrElse(this)

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def getSnowflakePlan: Option[SnowflakePlan] = Some(this)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.internal.{ErrorMessage, Utils}
import com.snowflake.snowpark.Row

private[snowpark] trait LogicalPlan {
Expand All @@ -18,6 +18,8 @@ private[snowpark] trait LogicalPlan {
(analyzedPlan, analyzer.getAliasMap)
}

lazy val dfAliasMap: Map[String, Seq[Attribute]] = Map.empty

protected def analyze: LogicalPlan
protected def analyzer: ExpressionAnalyzer

Expand Down Expand Up @@ -138,8 +140,9 @@ private[snowpark] trait UnaryNode extends LogicalPlan {
lazy protected val analyzedChild: LogicalPlan = child.analyzed
// create expression analyzer from child's alias map
lazy override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(child.aliasMap)
ExpressionAnalyzer(child.aliasMap, dfAliasMap)

override lazy val dfAliasMap: Map[String, Seq[Attribute]] = child.dfAliasMap
override protected def analyze: LogicalPlan = createFromAnalyzedChild(analyzedChild)

protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan
Expand Down Expand Up @@ -192,6 +195,18 @@ private[snowpark] case class Sort(order: Seq[SortOrder], child: LogicalPlan) ext
Sort(order, _)
}

/**
 * Logical plan node recording that `child` was aliased as `alias`
 * (produced by DataFrame.alias). The node itself generates no SQL;
 * it only contributes to dfAliasMap for column-name resolution.
 */
private[snowpark] case class DataframeAlias(alias: String, child: LogicalPlan)
    extends UnaryNode {

  // Register alias -> the child's output attributes, merged with any aliases
  // the child already carries (duplicates raise error 0132 in the helper).
  // NOTE(review): `.get` assumes the child always has a SnowflakePlan here —
  // confirm this invariant holds for every child we can be constructed with.
  override lazy val dfAliasMap: Map[String, Seq[Attribute]] = {
    val childOutput = child.getSnowflakePlan.get.output
    Utils.addToDataframeAliasMap(Map(alias -> childOutput), child)
  }

  override protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan =
    DataframeAlias(alias, _)

  // Rebuilding with an updated child is identical to rebuilding with an
  // analyzed child, so reuse the same constructor function.
  override protected def updateChild: LogicalPlan => LogicalPlan =
    createFromAnalyzedChild
}

private[snowpark] case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ private object SqlGenerator extends Logging {
.transformations(transformations)
.options(options)
.createSnowflakePlan()
case DataframeAlias(_, child) => resolveChild(child)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ case class TableDelete(
TableDelete(tableName, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed))

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newSource = sourceData.map(func)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ case class TableUpdate(
}, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed))

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newSource = sourceData.map(func)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package com.snowflake.snowpark.internal.analyzer

import java.util.Locale

import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.internal.{ErrorMessage, Utils}

private[snowpark] abstract class BinaryNode extends LogicalPlan {
def left: LogicalPlan
Expand All @@ -14,7 +13,11 @@ private[snowpark] abstract class BinaryNode extends LogicalPlan {
lazy protected val analyzedRight: LogicalPlan = right.analyzed

lazy override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(left.aliasMap, right.aliasMap)
ExpressionAnalyzer(left.aliasMap, right.aliasMap, dfAliasMap)


override lazy val dfAliasMap: Map[String, Seq[Attribute]] =
Utils.addToDataframeAliasMap(Utils.addToDataframeAliasMap(Map.empty, left), right)

override def analyze: LogicalPlan =
createFromAnalyzedChildren(analyzedLeft, analyzedRight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ private[snowpark] case class Alias(child: Expression, name: String, isInternal:
override protected val createAnalyzedUnary: Expression => Expression = Alias(_, name)
}

/**
 * Expression-level dataframe alias wrapper around `child`.
 * Contributes nothing to the generated SQL text (empty operator and
 * toString); it exists only to carry the alias `name` through analysis.
 */
private[snowpark] case class DfAlias(child: Expression, name: String)
    extends UnaryExpression
    with NamedExpression {
  // No SQL operator is emitted for an alias wrapper.
  override def sqlOperator: String = ""
  override def operatorFirst: Boolean = false
  override def toString: String = ""

  // Analysis preserves the alias name while replacing the analyzed child.
  override protected val createAnalyzedUnary: Expression => Expression = DfAlias(_, name)
}

private[snowpark] case class UnresolvedAlias(
child: Expression,
aliasFunc: Option[Expression => String] = None)
Expand Down
Loading

0 comments on commit e8a35b3

Please sign in to comment.