Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-821644 add dataframe alias support #79

Merged
merged 7 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/DataFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,25 @@ public Column col(String colName) {
return new Column(df.col(colName));
}

/**
* Returns the current DataFrame aliased as the input alias name.
*
* For example:
*
* {{{
* val df2 = df.alias("A")
* df2.select(df2.col("A.num"))
* }}}
*
* @group basic
* @since 1.10.0
* @param alias The alias name of the dataframe
* @return a [[DataFrame]]
*/
public DataFrame alias(String alias) {
return new DataFrame(this.df.alias(alias));
}

/**
* Executes the query representing this DataFrame and returns the result as an array of Row
* objects.
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/com/snowflake/snowpark/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ private[snowpark] object Column {
def apply(name: String): Column =
new Column(name match {
case "*" => Star(Seq.empty)
case c if c.contains(".") => UnresolvedDFAliasAttribute(name)
case _ => UnresolvedAttribute(quoteName(name))
})

Expand Down
36 changes: 32 additions & 4 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,23 @@ class DataFrame private[snowpark] (
case _ => Column(resolve(colName))
}

/**
* Returns the current DataFrame aliased as the input alias name.
*
* For example:
*
* {{{
* val df2 = df.alias("A")
* df2.select(df2.col("A.num"))
* }}}
*
* @group basic
* @since 1.10.0
* @param alias The alias name of the dataframe
* @return a [[DataFrame]]
*/
def alias(alias: String): DataFrame = withPlan(DataframeAlias(alias, plan))
sfc-gh-zli marked this conversation as resolved.
Show resolved Hide resolved

/**
* Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in
* SQL). Only the Columns specified as arguments will be present in the resulting DataFrame.
Expand Down Expand Up @@ -2791,7 +2808,8 @@ class DataFrame private[snowpark] (

// utils
private[snowpark] def resolve(colName: String): NamedExpression = {
val normalizedColName = quoteName(colName)
val (aliasColName, aliasOutput) = resolveAlias(colName, output)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we directly return result form resolveAlias if we can find that name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to quote the column name and check for name conflicts, so it's better to return the column name and output list.

val normalizedColName = quoteName(aliasColName)
def isDuplicatedName: Boolean = {
if (session.conn.hideInternalAlias) {
this.plan.internalRenamedColumns.values.exists(_ == normalizedColName)
Expand All @@ -2800,13 +2818,23 @@ class DataFrame private[snowpark] (
}
}
val col =
output.filter(attr => attr.name.equals(normalizedColName))
aliasOutput.filter(attr => attr.name.equals(normalizedColName))
if (col.length == 1) {
col.head.withName(normalizedColName).withSourceDF(this)
} else if (isDuplicatedName) {
throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(colName, colName)
throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(aliasColName, aliasColName)
} else {
throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name))
}
}

// Handle dataframe alias by redirecting output and column name resolution
private def resolveAlias(colName: String, output: Seq[Attribute]): (String, Seq[Attribute]) = {
val colNameSplit = colName.split("\\.", 2)
if (colNameSplit.length > 1 && plan.dfAliasMap.contains(colNameSplit(0))) {
(colNameSplit(1), plan.dfAliasMap(colNameSplit(0)))
} else {
throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(colName, output.map(_.name))
(colName, output)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ private[snowpark] object ErrorMessage {
"0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.",
"0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only",
"0131" -> "At most one table function can be called inside select() function",
"0132" -> "Duplicated dataframe alias defined: %s",
// Begin to define UDF related messages
"0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d",
"0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.",
Expand Down Expand Up @@ -252,6 +253,9 @@ private[snowpark] object ErrorMessage {
def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException =
createException("0131")

def DF_ALIAS_DUPLICATES(duplicatedAlias: scala.collection.Set[String]): SnowparkClientException =
createException("0132", duplicatedAlias.mkString(", "))

/*
* 2NN: UDF error code
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,19 @@ private[snowpark] case class UnresolvedAttribute(override val name: String)
this
}

private[snowpark] case class UnresolvedDFAliasAttribute(override val name: String)
extends Expression with NamedExpression {
override def sql: String = ""

override def children: Seq[Expression] = Seq.empty

// can't analyze
override lazy val dependentColumnNames: Option[Set[String]] = None

override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression =
sfc-gh-bli marked this conversation as resolved.
Show resolved Hide resolved
this
}

private[snowpark] case class ListAgg(col: Expression, delimiter: String, isDistinct: Boolean)
extends Expression {
override def children: Seq[Expression] = Seq(col)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.ErrorMessage

import scala.collection.mutable.{Map => MMap}

private[snowpark] object ExpressionAnalyzer {
def apply(aliasMap: Map[ExprId, String]): ExpressionAnalyzer =
new ExpressionAnalyzer(aliasMap)
def apply(aliasMap: Map[ExprId, String],
dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer =
new ExpressionAnalyzer(aliasMap, dfAliasMap)

def apply(): ExpressionAnalyzer =
new ExpressionAnalyzer(Map.empty)
new ExpressionAnalyzer(Map.empty, Map.empty)

// create new analyzer by combining two alias maps
def apply(map1: Map[ExprId, String], map2: Map[ExprId, String]): ExpressionAnalyzer = {
def apply(map1: Map[ExprId, String], map2: Map[ExprId, String],
dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = {
val common = map1.keySet & map2.keySet
val result = (map1 ++ map2).filter {
// remove common column, let (df1.join(df2))
// .join(df2.join(df3)).select(df2) report error
case (id, _) => !common.contains(id)
}
new ExpressionAnalyzer(result)
new ExpressionAnalyzer(result, dfAliasMap)
}

def apply(maps: Seq[Map[ExprId, String]]): ExpressionAnalyzer = {
def apply(maps: Seq[Map[ExprId, String]],
dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = {
maps.foldLeft(ExpressionAnalyzer()) {
case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map)
case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map, dfAliasMap)
}
}
}

private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) {
private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String],
dfAliasMap: Map[String, Seq[Attribute]]) {
private val generatedAliasMap: MMap[ExprId, String] = MMap.empty

def analyze(ex: Expression): Expression = ex match {
Expand All @@ -52,6 +58,25 @@ private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) {
// removed useless alias
case Alias(child: NamedExpression, name, _) if quoteName(child.name) == quoteName(name) =>
child
case UnresolvedDFAliasAttribute(name) =>
val colNameSplit = name.split("\\.", 2)
if (colNameSplit.length > 1 && dfAliasMap.contains(colNameSplit(0))) {
val aliasOutput = dfAliasMap(colNameSplit(0))
val aliasColName = colNameSplit(1)
val normalizedColName = quoteName(aliasColName)
val col = aliasOutput.filter(attr => attr.name.equals(normalizedColName))
if (col.length == 1) {
col.head.withName(normalizedColName)
} else {
throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name))
}
} else {
// if didn't find alias in the map
name match {
case "*" => Star(Seq.empty)
case _ => UnresolvedAttribute(quoteName(name))
}
}
case _ => ex
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ private[snowpark] trait MultiChildrenNode extends LogicalPlan {

protected def updateChildren(newChildren: Seq[LogicalPlan]): MultiChildrenNode

children.foreach(child => addToDataframeAliasMap(child))
override protected def analyze: LogicalPlan =
createFromAnalyzedChildren(children.map(_.analyzed))

protected def createFromAnalyzedChildren: Seq[LogicalPlan] => MultiChildrenNode

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(children.map(_.aliasMap))
ExpressionAnalyzer(children.map(_.aliasMap), dfAliasMap)

lazy override val internalRenamedColumns: Map[String, String] =
children.map(_.internalRenamedColumns).reduce(_ ++ _)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ case class SnowflakeCreateTable(tableName: String, mode: SaveMode, query: Option
SnowflakeCreateTable(tableName, mode, query.map(_.analyzed))

override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newQuery = query.map(func)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class SnowflakePlan(
sourcePlan.map(_.analyzed).getOrElse(this)

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def getSnowflakePlan: Option[SnowflakePlan] = Some(this)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.Row
import scala.collection.mutable.{Map => MMap}

private[snowpark] trait LogicalPlan {
def children: Seq[LogicalPlan] = Seq.empty
Expand All @@ -18,6 +19,18 @@ private[snowpark] trait LogicalPlan {
(analyzedPlan, analyzer.getAliasMap)
}

var dfAliasMap: Map[String, Seq[Attribute]] = Map.empty

protected def addToDataframeAliasMap(child: LogicalPlan): Unit = {
if (child != null) {
val map = child.dfAliasMap
val duplicatedAlias = dfAliasMap.keySet.intersect(map.keySet)
if (duplicatedAlias.nonEmpty) {
throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias)
}
dfAliasMap ++= map
}
}
protected def analyze: LogicalPlan
protected def analyzer: ExpressionAnalyzer

Expand Down Expand Up @@ -138,8 +151,9 @@ private[snowpark] trait UnaryNode extends LogicalPlan {
lazy protected val analyzedChild: LogicalPlan = child.analyzed
// create expression analyzer from child's alias map
lazy override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(child.aliasMap)
ExpressionAnalyzer(child.aliasMap, dfAliasMap)

addToDataframeAliasMap(child)
override protected def analyze: LogicalPlan = createFromAnalyzedChild(analyzedChild)

protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan
Expand Down Expand Up @@ -192,6 +206,17 @@ private[snowpark] case class Sort(order: Seq[SortOrder], child: LogicalPlan) ext
Sort(order, _)
}

private[snowpark] case class DataframeAlias(alias: String, child: LogicalPlan)
extends UnaryNode {
dfAliasMap += (alias -> child.getSnowflakePlan.get.output)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override lazy val dfAliasMap = child.dfAliasMap + (alias -> child.getSnowflakePlan.get.output)

override protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan = child => {
DataframeAlias(alias, child)
sfc-gh-zli marked this conversation as resolved.
Show resolved Hide resolved
}

override protected def updateChild: LogicalPlan => LogicalPlan =
createFromAnalyzedChild
}

private[snowpark] case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ private object SqlGenerator extends Logging {
.transformations(transformations)
.options(options)
.createSnowflakePlan()
case DataframeAlias(_, child) => resolveChild(child)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ case class TableDelete(
TableDelete(tableName, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed))

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newSource = sourceData.map(func)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ case class TableUpdate(
}, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed))

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newSource = sourceData.map(func)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ private[snowpark] abstract class BinaryNode extends LogicalPlan {
lazy protected val analyzedRight: LogicalPlan = right.analyzed

lazy override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(left.aliasMap, right.aliasMap)
ExpressionAnalyzer(left.aliasMap, right.aliasMap, dfAliasMap)

addToDataframeAliasMap(left)
addToDataframeAliasMap(right)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override lazy val dfAliasMap = left.dfAliasMap ++ right.dfAliasMap

override def analyze: LogicalPlan =
createFromAnalyzedChildren(analyzedLeft, analyzedRight)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ private[snowpark] case class Alias(child: Expression, name: String, isInternal:
override protected val createAnalyzedUnary: Expression => Expression = Alias(_, name)
}

private[snowpark] case class DfAlias(child: Expression, name: String)
extends UnaryExpression
with NamedExpression {
override def sqlOperator: String = ""
override def operatorFirst: Boolean = false
override def toString: String = ""

override protected val createAnalyzedUnary: Expression => Expression = DfAlias(_, name)
}

private[snowpark] case class UnresolvedAlias(
child: Expression,
aliasFunc: Option[Expression => String] = None)
Expand Down
8 changes: 8 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,14 @@ class ErrorMessageSuite extends FunSuite {
"At most one table function can be called inside select() function"))
}

test("DF_ALIAS_DUPLICATES") {
val ex = ErrorMessage.DF_ALIAS_DUPLICATES(Set("a", "b"))
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0132")))
assert(
ex.message.startsWith("Error Code: 0132, Error message: " +
"Duplicated dataframe alias defined: a, b"))
}

test("UDF_INCORRECT_ARGS_NUMBER") {
val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2)
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200")))
Expand Down
Loading
Loading