From 76c26ed9f2cfe37cc12e16410425fd64131d3648 Mon Sep 17 00:00:00 2001 From: Zihan Li Date: Tue, 2 Jan 2024 09:40:39 -0800 Subject: [PATCH 1/7] add alias --- .../scala/com/snowflake/snowpark/Column.scala | 1 + .../com/snowflake/snowpark/DataFrame.scala | 36 +++++++++++++-- .../snowpark/internal/ErrorMessage.scala | 4 ++ .../internal/analyzer/Expression.scala | 13 ++++++ .../analyzer/ExpressionAnalyzer.scala | 45 ++++++++++++++----- .../internal/analyzer/MultiChildrenNode.scala | 3 +- .../analyzer/SnowflakeCreateTable.scala | 2 +- .../internal/analyzer/SnowflakePlan.scala | 2 +- .../internal/analyzer/SnowflakePlanNode.scala | 34 +++++++++++++- .../internal/analyzer/SqlGenerator.scala | 1 + .../internal/analyzer/TableDelete.scala | 2 +- .../internal/analyzer/TableUpdate.scala | 2 +- .../internal/analyzer/binaryPlanNodes.scala | 4 +- .../internal/analyzer/unaryExpressions.scala | 10 +++++ .../snowpark/ErrorMessageSuite.scala | 8 ++++ 15 files changed, 145 insertions(+), 22 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/Column.scala b/src/main/scala/com/snowflake/snowpark/Column.scala index 1e37ed9d..e8972f69 100644 --- a/src/main/scala/com/snowflake/snowpark/Column.scala +++ b/src/main/scala/com/snowflake/snowpark/Column.scala @@ -732,6 +732,7 @@ private[snowpark] object Column { def apply(name: String): Column = new Column(name match { case "*" => Star(Seq.empty) + case c if c.contains(".") => DfAliasAttribute(name) case _ => UnresolvedAttribute(quoteName(name)) }) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 54c43c49..12417787 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -518,6 +518,23 @@ class DataFrame private[snowpark] ( case _ => Column(resolve(colName)) } + /** + * Returns the current DataFrame aliased as the input alias name. + * + * For example: + * + * {{{ + * val df2 = df.alias("A") + * df2.select(df2.col("A.num")) + * }}} + * + * @group basic + * @since 1.10.0 + * @param alias The alias name of the dataframe + * @return a [[DataFrame]] + */ + def alias(alias: String): DataFrame = withPlan(DataframeAlias(alias, plan)) + /** * Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in * SQL). Only the Columns specified as arguments will be present in the resulting DataFrame. 
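A minimal usage sketch of the alias API introduced above, assuming a configured Snowpark session; the properties file path and the table name T1 (with a NUM column) are placeholders for illustration and are not defined anywhere in this patch:

import com.snowflake.snowpark.Session
import com.snowflake.snowpark.functions.col

object AliasUsageSketch {
  def main(args: Array[String]): Unit = {
    // Connection properties file and table are assumed for illustration only.
    val session = Session.builder.configFile("profile.properties").create
    val df = session.table("T1").alias("A") // "T1" is assumed to have a NUM column
    // Per the scaladoc above, alias-qualified names resolve through the alias "A":
    df.select(df.col("A.num")).show()
    df.select(col("A.num")).show()
  }
}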
@@ -2791,7 +2808,8 @@ class DataFrame private[snowpark] ( // utils private[snowpark] def resolve(colName: String): NamedExpression = { - val normalizedColName = quoteName(colName) + val (aliasColName, aliasOutput) = resolveAlias(colName, output) + val normalizedColName = quoteName(aliasColName) def isDuplicatedName: Boolean = { if (session.conn.hideInternalAlias) { this.plan.internalRenamedColumns.values.exists(_ == normalizedColName) @@ -2800,13 +2818,23 @@ class DataFrame private[snowpark] ( } } val col = - output.filter(attr => attr.name.equals(normalizedColName)) + aliasOutput.filter(attr => attr.name.equals(normalizedColName)) if (col.length == 1) { col.head.withName(normalizedColName).withSourceDF(this) } else if (isDuplicatedName) { - throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(colName, colName) + throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(aliasColName, aliasColName) + } else { + throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name)) + } + } + + // Handle dataframe alias by redirecting output and column name resolution + private def resolveAlias(colName: String, output: Seq[Attribute]): (String, Seq[Attribute]) = { + val colNameSplit = colName.split("\\.", 2) + if (colNameSplit.length > 1 && plan.dfAliasMap.contains(colNameSplit(0))) { + (colNameSplit(1), plan.dfAliasMap(colNameSplit(0))) } else { - throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(colName, output.map(_.name)) + (colName, output) } } diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index 21db9dea..fbb7966c 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -69,6 +69,7 @@ private[snowpark] object ErrorMessage { "0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.", "0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only", "0131" -> "At most one table function can be called inside select() function", + "0132" -> "Duplicated dataframe alias defined: %s", // Begin to define UDF related messages "0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d", "0201" -> "Attempted to call an unregistered UDF. 
You must register the UDF before calling it.", @@ -252,6 +253,9 @@ private[snowpark] object ErrorMessage { def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException = createException("0131") + def DF_ALIAS_DUPLICATES(duplicatedAlias: scala.collection.Set[String]): SnowparkClientException = + createException("0132", duplicatedAlias.mkString(", ")) + /* * 2NN: UDF error code */ diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala index 0ea99977..3e8763c6 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala @@ -390,6 +390,19 @@ private[snowpark] case class UnresolvedAttribute(override val name: String) this } +private[snowpark] case class DfAliasAttribute(override val name: String) + extends Expression with NamedExpression { + override def sql: String = "" + + override def children: Seq[Expression] = Seq.empty + + // can't analyze + override lazy val dependentColumnNames: Option[Set[String]] = None + + override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression = + this +} + private[snowpark] case class ListAgg(col: Expression, delimiter: String, isDistinct: Boolean) extends Expression { override def children: Seq[Expression] = Seq(col) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala index bd412766..ae399b68 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala @@ -1,33 +1,39 @@ package com.snowflake.snowpark.internal.analyzer +import com.snowflake.snowpark.internal.ErrorMessage + import scala.collection.mutable.{Map => MMap} private[snowpark] object ExpressionAnalyzer { - def apply(aliasMap: Map[ExprId, String]): ExpressionAnalyzer = - new ExpressionAnalyzer(aliasMap) + def apply(aliasMap: Map[ExprId, String], + dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = + new ExpressionAnalyzer(aliasMap, dfAliasMap) - def apply(): ExpressionAnalyzer = - new ExpressionAnalyzer(Map.empty) + def apply(dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = + new ExpressionAnalyzer(Map.empty, dfAliasMap) // create new analyzer by combining two alias maps - def apply(map1: Map[ExprId, String], map2: Map[ExprId, String]): ExpressionAnalyzer = { + def apply(map1: Map[ExprId, String], map2: Map[ExprId, String], + dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = { val common = map1.keySet & map2.keySet val result = (map1 ++ map2).filter { // remove common column, let (df1.join(df2)) // .join(df2.join(df3)).select(df2) report error case (id, _) => !common.contains(id) } - new ExpressionAnalyzer(result) + new ExpressionAnalyzer(result, dfAliasMap) } - def apply(maps: Seq[Map[ExprId, String]]): ExpressionAnalyzer = { - maps.foldLeft(ExpressionAnalyzer()) { - case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map) + def apply(maps: Seq[Map[ExprId, String]], + dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = { + maps.foldLeft(ExpressionAnalyzer(dfAliasMap)) { + case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map, dfAliasMap) } } } -private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) { 
+private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String], + dfAliasMap: MMap[String, Seq[Attribute]]) { private val generatedAliasMap: MMap[ExprId, String] = MMap.empty def analyze(ex: Expression): Expression = ex match { @@ -52,6 +58,25 @@ private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) { // removed useless alias case Alias(child: NamedExpression, name, _) if quoteName(child.name) == quoteName(name) => child + case DfAliasAttribute(name) => + val colNameSplit = name.split("\\.", 2) + if (colNameSplit.length > 1 && dfAliasMap.contains(colNameSplit(0))) { + val aliasOutput = dfAliasMap(colNameSplit(0)) + val aliasColName = colNameSplit(1) + val normalizedColName = quoteName(aliasColName) + val col = aliasOutput.filter(attr => attr.name.equals(normalizedColName)) + if (col.length == 1) { + col.head.withName(normalizedColName) + } else { + throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name)) + } + } else { + // if didn't find alias in the map + name match { + case "*" => Star(Seq.empty) + case _ => UnresolvedAttribute(quoteName(name)) + } + } case _ => ex } diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala index 7239e1dd..ce83d10c 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala @@ -11,13 +11,14 @@ private[snowpark] trait MultiChildrenNode extends LogicalPlan { protected def updateChildren(newChildren: Seq[LogicalPlan]): MultiChildrenNode + children.foreach(child => addToDataframeAliasMap(child.dfAliasMap)) override protected def analyze: LogicalPlan = createFromAnalyzedChildren(children.map(_.analyzed)) protected def createFromAnalyzedChildren: Seq[LogicalPlan] => MultiChildrenNode override protected def analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(children.map(_.aliasMap)) + ExpressionAnalyzer(children.map(_.aliasMap), dfAliasMap) lazy override val internalRenamedColumns: Map[String, String] = children.map(_.internalRenamedColumns).reduce(_ ++ _) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakeCreateTable.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakeCreateTable.scala index 8b30eb9c..ff44f927 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakeCreateTable.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakeCreateTable.scala @@ -10,7 +10,7 @@ case class SnowflakeCreateTable(tableName: String, mode: SaveMode, query: Option SnowflakeCreateTable(tableName, mode, query.map(_.analyzed)) override protected val analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty)) + ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap) override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = { val newQuery = query.map(func) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala index f90d1dd3..a3218758 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala @@ -106,7 +106,7 @@ class SnowflakePlan( sourcePlan.map(_.analyzed).getOrElse(this) override protected def analyzer: 
ExpressionAnalyzer = - ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty)) + ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap) override def getSnowflakePlan: Option[SnowflakePlan] = Some(this) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index 12b54e41..6406754e 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -2,6 +2,7 @@ package com.snowflake.snowpark.internal.analyzer import com.snowflake.snowpark.internal.ErrorMessage import com.snowflake.snowpark.Row +import scala.collection.mutable.{Map => MMap} private[snowpark] trait LogicalPlan { def children: Seq[LogicalPlan] = Seq.empty @@ -18,6 +19,23 @@ private[snowpark] trait LogicalPlan { (analyzedPlan, analyzer.getAliasMap) } + var dfAliasMap: MMap[String, Seq[Attribute]] = MMap.empty + + // map from df alias string to snowflakePlan.output + // add to map when DataframeAlias node is createdFromChild + // merge map when analyze is called on leafNode, unaryNode, multiChildrenNode + // report conflict if there is merge collision + // New expression dataframeAttribute when input has . + // Expression analizer -> see dataframeAttribute -> split and search map + // if map does not contain the key, then treat as normal column name + // else search for Attribute with the name in the attribute list + protected def addToDataframeAliasMap(map: MMap[String, Seq[Attribute]]): Unit = { + val duplicatedAlias = dfAliasMap.keySet.intersect(map.keySet) + if (duplicatedAlias.nonEmpty) { + throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias) + } + dfAliasMap ++= map + } protected def analyze: LogicalPlan protected def analyzer: ExpressionAnalyzer @@ -69,7 +87,7 @@ private[snowpark] trait LogicalPlan { private[snowpark] trait LeafNode extends LogicalPlan { // create ExpressionAnalyzer with empty alias map - override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer() + override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer(dfAliasMap) // leaf node doesn't have child override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = this @@ -138,8 +156,9 @@ private[snowpark] trait UnaryNode extends LogicalPlan { lazy protected val analyzedChild: LogicalPlan = child.analyzed // create expression analyzer from child's alias map lazy override protected val analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(child.aliasMap) + ExpressionAnalyzer(child.aliasMap, dfAliasMap) + addToDataframeAliasMap(child.dfAliasMap) override protected def analyze: LogicalPlan = createFromAnalyzedChild(analyzedChild) protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan @@ -192,6 +211,17 @@ private[snowpark] case class Sort(order: Seq[SortOrder], child: LogicalPlan) ext Sort(order, _) } +private[snowpark] case class DataframeAlias(alias: String, child: LogicalPlan) + extends UnaryNode { + dfAliasMap += (alias -> child.getSnowflakePlan.get.output) + override protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan = child => { + DataframeAlias(alias, child) + } + + override protected def updateChild: LogicalPlan => LogicalPlan = + createFromAnalyzedChild +} + private[snowpark] case class Aggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], diff --git 
a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala index a7a5f655..7c2add81 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala @@ -157,6 +157,7 @@ private object SqlGenerator extends Logging { .transformations(transformations) .options(options) .createSnowflakePlan() + case DataframeAlias(_, child) => resolveChild(child) } } diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableDelete.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableDelete.scala index f50900d8..9c4922dd 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableDelete.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableDelete.scala @@ -14,7 +14,7 @@ case class TableDelete( TableDelete(tableName, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed)) override protected def analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty)) + ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap) override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = { val newSource = sourceData.map(func) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableUpdate.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableUpdate.scala index 07faa247..ddee926c 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableUpdate.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableUpdate.scala @@ -17,7 +17,7 @@ case class TableUpdate( }, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed)) override protected def analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty)) + ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap) override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = { val newSource = sourceData.map(func) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala index e0b9b4f4..efaa8a0c 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala @@ -14,8 +14,10 @@ private[snowpark] abstract class BinaryNode extends LogicalPlan { lazy protected val analyzedRight: LogicalPlan = right.analyzed lazy override protected val analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(left.aliasMap, right.aliasMap) + ExpressionAnalyzer(left.aliasMap, right.aliasMap, dfAliasMap) + addToDataframeAliasMap(left.dfAliasMap) + addToDataframeAliasMap(right.dfAliasMap) override def analyze: LogicalPlan = createFromAnalyzedChildren(analyzedLeft, analyzedRight) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/unaryExpressions.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/unaryExpressions.scala index 572ea890..5db1bfef 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/unaryExpressions.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/unaryExpressions.scala @@ -75,6 +75,16 @@ private[snowpark] case class Alias(child: Expression, name: String, isInternal: override protected val createAnalyzedUnary: 
Expression => Expression = Alias(_, name) } +private[snowpark] case class DfAlias(child: Expression, name: String) + extends UnaryExpression + with NamedExpression { + override def sqlOperator: String = "" + override def operatorFirst: Boolean = false + override def toString: String = "" + + override protected val createAnalyzedUnary: Expression => Expression = DfAlias(_, name) +} + private[snowpark] case class UnresolvedAlias( child: Expression, aliasFunc: Option[Expression => String] = None) diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index bfeddd72..2a3df7f7 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -304,6 +304,14 @@ class ErrorMessageSuite extends FunSuite { "At most one table function can be called inside select() function")) } + test("DF_ALIAS_DUPLICATES") { + val ex = ErrorMessage.DF_ALIAS_DUPLICATES(Set("a", "b")) + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0132"))) + assert( + ex.message.startsWith("Error Code: 0132, Error message: " + + "Duplicated dataframe alias defined: a, b")) + } + test("UDF_INCORRECT_ARGS_NUMBER") { val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200"))) From b861e52b3d25cfb38f5c0943b041b315b22d0991 Mon Sep 17 00:00:00 2001 From: Zihan Li Date: Tue, 2 Jan 2024 09:47:46 -0800 Subject: [PATCH 2/7] update test --- .../internal/analyzer/SnowflakePlanNode.scala | 8 -- .../snowpark_test/DataFrameAliasSuite.scala | 87 +++++++++++++++++++ 2 files changed, 87 insertions(+), 8 deletions(-) create mode 100644 src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index 6406754e..dcf8f09a 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -21,14 +21,6 @@ private[snowpark] trait LogicalPlan { var dfAliasMap: MMap[String, Seq[Attribute]] = MMap.empty - // map from df alias string to snowflakePlan.output - // add to map when DataframeAlias node is createdFromChild - // merge map when analyze is called on leafNode, unaryNode, multiChildrenNode - // report conflict if there is merge collision - // New expression dataframeAttribute when input has . 
- // Expression analizer -> see dataframeAttribute -> split and search map - // if map does not contain the key, then treat as normal column name - // else search for Attribute with the name in the attribute list protected def addToDataframeAliasMap(map: MMap[String, Seq[Attribute]]): Unit = { val duplicatedAlias = dfAliasMap.keySet.intersect(map.keySet) if (duplicatedAlias.nonEmpty) { diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala new file mode 100644 index 00000000..397a1000 --- /dev/null +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala @@ -0,0 +1,87 @@ +package com.snowflake.snowpark_test + +import com.snowflake.snowpark._ +import com.snowflake.snowpark.functions._ +import com.snowflake.snowpark.internal.analyzer._ +import com.snowflake.snowpark.types._ +import net.snowflake.client.jdbc.SnowflakeSQLException +import org.scalatest.BeforeAndAfterEach +import java.sql.{Date, Time, Timestamp} +import scala.util.Random + +class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSession { + val tableName1: String = randomName() + val tableName2: String = randomName() + import session.implicits._ + + override def afterEach(): Unit = { + dropTable(tableName1) + dropTable(tableName2) + super.afterEach() + } + + test("Test for alias with df.col, col and $") { + createTable(tableName1, "num int") + runQuery(s"insert into $tableName1 values(1),(2),(3)", session) + val df = session.table(tableName1).alias("A") + checkAnswer(df.select(df.col("A.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df.select(col("A.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df.select($"A.num"), Seq(Row(1), Row(2), Row(3))) + + val df1 = df.alias("B") + checkAnswer(df1.select(df1.col("A.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df1.select(col("A.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df1.select($"A.num"), Seq(Row(1), Row(2), Row(3))) + + checkAnswer(df1.select(df1.col("B.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df1.select(col("B.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df1.select($"B.num"), Seq(Row(1), Row(2), Row(3))) + } + + test("Test for alias with join") { + createTable(tableName1, "id1 int, num1 int") + createTable(tableName2, "id2 int, num2 int") + runQuery(s"insert into $tableName1 values(1, 4),(2, 5),(3, 6)", session) + runQuery(s"insert into $tableName2 values(1, 7),(2, 8),(3, 9)", session) + val df1 = session.table(tableName1).alias("A") + val df2 = session.table(tableName2).alias("B") + checkAnswer(df1.join(df2, $"id1" === $"id2") + .select(df1.col("A.num1")), Seq(Row(4), Row(5), Row(6))) + checkAnswer(df1.join(df2, $"id1" === $"id2") + .select(df2.col("B.num2")), Seq(Row(7), Row(8), Row(9))) + + checkAnswer(df1.join(df2, $"id1" === $"id2") + .select($"A.num1"), Seq(Row(4), Row(5), Row(6))) + checkAnswer(df1.join(df2, $"id1" === $"id2") + .select($"B.num2"), Seq(Row(7), Row(8), Row(9))) + } + + test("Test for alias with join with column renaming") { + createTable(tableName1, "id int, num int") + createTable(tableName2, "id int, num int") + runQuery(s"insert into $tableName1 values(1, 4),(2, 5),(3, 6)", session) + runQuery(s"insert into $tableName2 values(1, 7),(2, 8),(3, 9)", session) + val df1 = session.table(tableName1).alias("A") + val df2 = session.table(tableName2).alias("B") + checkAnswer(df1.join(df2, df1.col("id") === df2.col("id")) + .select(df1.col("A.num")), Seq(Row(4), Row(5), Row(6))) + 
checkAnswer(df1.join(df2, df1.col("id") === df2.col("id")) + .select(df2.col("B.num")), Seq(Row(7), Row(8), Row(9))) + + // The following use case is out of the scope of supporting alias + // We still follow the old ambiguity resolving policy and require DF to be used + assertThrows[SnowparkClientException]( + df1.join(df2, df1.col("id") === df2.col("id")) + .select($"A.num")) + } + + test("Test for alias conflict") { + createTable(tableName1, "id int, num int") + createTable(tableName2, "id int, num int") + val df1 = session.table(tableName1).alias("A") + val df2 = session.table(tableName2).alias("A") + assertThrows[SnowparkClientException]( + df1.join(df2, df1.col("id") === df2.col("id")) + .select(df1.col("A.num"))) + } +} From 9163a66bd2f84c7c0be884c8bc5abc99d1cacb0d Mon Sep 17 00:00:00 2001 From: Zihan Li Date: Tue, 2 Jan 2024 15:58:01 -0800 Subject: [PATCH 3/7] address comments --- .../snowflake/snowpark_java/DataFrame.java | 19 +++++++++++++++++++ .../scala/com/snowflake/snowpark/Column.scala | 2 +- .../internal/analyzer/Expression.scala | 2 +- .../analyzer/ExpressionAnalyzer.scala | 16 ++++++++-------- .../internal/analyzer/SnowflakePlanNode.scala | 6 +++--- 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/DataFrame.java b/src/main/java/com/snowflake/snowpark_java/DataFrame.java index 2e6d23ec..56e416e0 100644 --- a/src/main/java/com/snowflake/snowpark_java/DataFrame.java +++ b/src/main/java/com/snowflake/snowpark_java/DataFrame.java @@ -802,6 +802,25 @@ public Column col(String colName) { return new Column(df.col(colName)); } + /** + * Returns the current DataFrame aliased as the input alias name. + * + * <p>For example: + * + * <pre>{@code + * DataFrame df2 = df.alias("A"); + * df2.select(df2.col("A.num")); + * }</pre> + * + * @group basic + * @since 1.10.0 + * @param alias The alias name of the DataFrame + * @return a {@code DataFrame} + */ + public DataFrame alias(String alias) { + return new DataFrame(this.df.alias(alias)); + } + /** * Executes the query representing this DataFrame and returns the result as an array of Row * objects.
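The dotted-name resolution that resolveAlias added in patch 1, and that the ExpressionAnalyzer change below applies to UnresolvedDFAliasAttribute, can be modeled in isolation. A simplified sketch follows, with a stand-in Attribute type and a naive quoteName; both are assumptions, not the library's internals:

object AliasResolutionSketch {
  // Stand-ins for the analyzer's Attribute and quoteName; the real ones live in the internal package.
  final case class Attribute(name: String)
  private def quoteName(name: String): String = "\"" + name.toUpperCase + "\""

  // Mirrors the split-on-first-dot lookup: if the prefix is a registered DataFrame alias,
  // resolve the remainder against that alias's captured output; otherwise fall back.
  def resolve(colName: String,
              dfAliasMap: Map[String, Seq[Attribute]],
              fallback: Seq[Attribute]): (String, Seq[Attribute]) = {
    val parts = colName.split("\\.", 2)
    if (parts.length > 1 && dfAliasMap.contains(parts(0))) (parts(1), dfAliasMap(parts(0)))
    else (colName, fallback)
  }

  def main(args: Array[String]): Unit = {
    val aliasMap = Map("A" -> Seq(Attribute(quoteName("num1"))))
    println(resolve("A.num1", aliasMap, Seq.empty))                        // resolves through alias "A"
    println(resolve("num1", aliasMap, Seq(Attribute(quoteName("num1"))))) // falls back to normal output
  }
}

If the prefix is not a registered alias, the name falls through to the ordinary column-resolution path, which is why plain column names keep working on aliased DataFrames.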
diff --git a/src/main/scala/com/snowflake/snowpark/Column.scala b/src/main/scala/com/snowflake/snowpark/Column.scala index e8972f69..56996aa9 100644 --- a/src/main/scala/com/snowflake/snowpark/Column.scala +++ b/src/main/scala/com/snowflake/snowpark/Column.scala @@ -732,7 +732,7 @@ private[snowpark] object Column { def apply(name: String): Column = new Column(name match { case "*" => Star(Seq.empty) - case c if c.contains(".") => DfAliasAttribute(name) + case c if c.contains(".") => UnresolvedDFAliasAttribute(name) case _ => UnresolvedAttribute(quoteName(name)) }) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala index 3e8763c6..adbddd38 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala @@ -390,7 +390,7 @@ private[snowpark] case class UnresolvedAttribute(override val name: String) this } -private[snowpark] case class DfAliasAttribute(override val name: String) +private[snowpark] case class UnresolvedDFAliasAttribute(override val name: String) extends Expression with NamedExpression { override def sql: String = "" diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala index ae399b68..76bed5ef 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala @@ -6,15 +6,15 @@ import scala.collection.mutable.{Map => MMap} private[snowpark] object ExpressionAnalyzer { def apply(aliasMap: Map[ExprId, String], - dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = new ExpressionAnalyzer(aliasMap, dfAliasMap) - def apply(dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = - new ExpressionAnalyzer(Map.empty, dfAliasMap) + def apply(): ExpressionAnalyzer = + new ExpressionAnalyzer(Map.empty, Map.empty) // create new analyzer by combining two alias maps def apply(map1: Map[ExprId, String], map2: Map[ExprId, String], - dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = { + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = { val common = map1.keySet & map2.keySet val result = (map1 ++ map2).filter { // remove common column, let (df1.join(df2)) @@ -25,15 +25,15 @@ private[snowpark] object ExpressionAnalyzer { } def apply(maps: Seq[Map[ExprId, String]], - dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = { - maps.foldLeft(ExpressionAnalyzer(dfAliasMap)) { + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = { + maps.foldLeft(ExpressionAnalyzer()) { case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map, dfAliasMap) } } } private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String], - dfAliasMap: MMap[String, Seq[Attribute]]) { + dfAliasMap: Map[String, Seq[Attribute]]) { private val generatedAliasMap: MMap[ExprId, String] = MMap.empty def analyze(ex: Expression): Expression = ex match { @@ -58,7 +58,7 @@ private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String], // removed useless alias case Alias(child: NamedExpression, name, _) if quoteName(child.name) == quoteName(name) => child - case DfAliasAttribute(name) => + case UnresolvedDFAliasAttribute(name) => val 
colNameSplit = name.split("\\.", 2) if (colNameSplit.length > 1 && dfAliasMap.contains(colNameSplit(0))) { val aliasOutput = dfAliasMap(colNameSplit(0)) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index dcf8f09a..b49b38d1 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -19,9 +19,9 @@ private[snowpark] trait LogicalPlan { (analyzedPlan, analyzer.getAliasMap) } - var dfAliasMap: MMap[String, Seq[Attribute]] = MMap.empty + var dfAliasMap: Map[String, Seq[Attribute]] = Map.empty - protected def addToDataframeAliasMap(map: MMap[String, Seq[Attribute]]): Unit = { + protected def addToDataframeAliasMap(map: Map[String, Seq[Attribute]]): Unit = { val duplicatedAlias = dfAliasMap.keySet.intersect(map.keySet) if (duplicatedAlias.nonEmpty) { throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias) @@ -79,7 +79,7 @@ private[snowpark] trait LogicalPlan { private[snowpark] trait LeafNode extends LogicalPlan { // create ExpressionAnalyzer with empty alias map - override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer(dfAliasMap) + override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer() // leaf node doesn't have child override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = this From a69116612c6be9f57aedeb52a6e54e99c9cb677b Mon Sep 17 00:00:00 2001 From: Zihan Li Date: Tue, 2 Jan 2024 16:31:20 -0800 Subject: [PATCH 4/7] update code --- .../internal/analyzer/MultiChildrenNode.scala | 2 +- .../internal/analyzer/SnowflakePlanNode.scala | 15 +++++++++------ .../internal/analyzer/binaryPlanNodes.scala | 4 ++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala index ce83d10c..584925b1 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala @@ -11,7 +11,7 @@ private[snowpark] trait MultiChildrenNode extends LogicalPlan { protected def updateChildren(newChildren: Seq[LogicalPlan]): MultiChildrenNode - children.foreach(child => addToDataframeAliasMap(child.dfAliasMap)) + children.foreach(child => addToDataframeAliasMap(child)) override protected def analyze: LogicalPlan = createFromAnalyzedChildren(children.map(_.analyzed)) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index b49b38d1..c3734e91 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -21,12 +21,15 @@ private[snowpark] trait LogicalPlan { var dfAliasMap: Map[String, Seq[Attribute]] = Map.empty - protected def addToDataframeAliasMap(map: Map[String, Seq[Attribute]]): Unit = { - val duplicatedAlias = dfAliasMap.keySet.intersect(map.keySet) - if (duplicatedAlias.nonEmpty) { - throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias) + protected def addToDataframeAliasMap(child: LogicalPlan): Unit = { + if (child != null) { + val map = child.dfAliasMap + val duplicatedAlias = 
dfAliasMap.keySet.intersect(map.keySet) + if (duplicatedAlias.nonEmpty) { + throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias) + } + dfAliasMap ++= map } - dfAliasMap ++= map } protected def analyze: LogicalPlan protected def analyzer: ExpressionAnalyzer @@ -150,7 +153,7 @@ private[snowpark] trait UnaryNode extends LogicalPlan { lazy override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer(child.aliasMap, dfAliasMap) - addToDataframeAliasMap(child.dfAliasMap) + addToDataframeAliasMap(child) override protected def analyze: LogicalPlan = createFromAnalyzedChild(analyzedChild) protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala index efaa8a0c..4c61287a 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala @@ -16,8 +16,8 @@ private[snowpark] abstract class BinaryNode extends LogicalPlan { lazy override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer(left.aliasMap, right.aliasMap, dfAliasMap) - addToDataframeAliasMap(left.dfAliasMap) - addToDataframeAliasMap(right.dfAliasMap) + addToDataframeAliasMap(left) + addToDataframeAliasMap(right) override def analyze: LogicalPlan = createFromAnalyzedChildren(analyzedLeft, analyzedRight) From c5e7c58494277c3e0f59184a1685c03a0202bdbe Mon Sep 17 00:00:00 2001 From: Zihan Li Date: Wed, 3 Jan 2024 13:13:36 -0800 Subject: [PATCH 5/7] address comment --- .../snowflake/snowpark/internal/Utils.scala | 16 ++++++++++++- .../internal/analyzer/MultiChildrenNode.scala | 10 +++++++- .../internal/analyzer/SnowflakePlanNode.scala | 23 +++++++------------ .../internal/analyzer/binaryPlanNodes.scala | 9 ++++---- 4 files changed, 37 insertions(+), 21 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala index 76f06f73..92c8173d 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala @@ -1,7 +1,7 @@ package com.snowflake.snowpark.internal import com.snowflake.snowpark.Column -import com.snowflake.snowpark.internal.analyzer.{Attribute, TableFunctionExpression, singleQuote} +import com.snowflake.snowpark.internal.analyzer.{Attribute, LogicalPlan, TableFunctionExpression, singleQuote} import java.io.{File, FileInputStream} import java.lang.invoke.SerializedLambda @@ -99,6 +99,20 @@ object Utils extends Logging { lastInternalLine + "\n" + stackTrace.take(stackDepth).mkString("\n") } + def addToDataframeAliasMap(result: Map[String, Seq[Attribute]], child: LogicalPlan) + : Map[String, Seq[Attribute]] = { + if (child != null) { + val map = child.dfAliasMap + val duplicatedAlias = result.keySet.intersect(map.keySet) + if (duplicatedAlias.nonEmpty) { + throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias) + } + result ++ map + } else { + result + } + } + def logTime[T](f: => T, funcDescription: String): T = { logInfo(funcDescription) val start = System.currentTimeMillis() diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala index 584925b1..c3ad8f0e 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala +++ 
b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala @@ -1,5 +1,7 @@ package com.snowflake.snowpark.internal.analyzer +import com.snowflake.snowpark.internal.Utils + private[snowpark] trait MultiChildrenNode extends LogicalPlan { override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = { val newChildren: Seq[LogicalPlan] = children.map(func) @@ -11,7 +13,13 @@ private[snowpark] trait MultiChildrenNode extends LogicalPlan { protected def updateChildren(newChildren: Seq[LogicalPlan]): MultiChildrenNode - children.foreach(child => addToDataframeAliasMap(child)) + + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = { + var result: Map[String, Seq[Attribute]] = Map.empty + children.foreach(child => result = Utils.addToDataframeAliasMap(result, child)) + result + } + override protected def analyze: LogicalPlan = createFromAnalyzedChildren(children.map(_.analyzed)) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index c3734e91..90933396 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -1,8 +1,7 @@ package com.snowflake.snowpark.internal.analyzer -import com.snowflake.snowpark.internal.ErrorMessage +import com.snowflake.snowpark.internal.{ErrorMessage, Utils} import com.snowflake.snowpark.Row -import scala.collection.mutable.{Map => MMap} private[snowpark] trait LogicalPlan { def children: Seq[LogicalPlan] = Seq.empty @@ -19,18 +18,8 @@ private[snowpark] trait LogicalPlan { (analyzedPlan, analyzer.getAliasMap) } - var dfAliasMap: Map[String, Seq[Attribute]] = Map.empty + lazy val dfAliasMap: Map[String, Seq[Attribute]] = Map.empty - protected def addToDataframeAliasMap(child: LogicalPlan): Unit = { - if (child != null) { - val map = child.dfAliasMap - val duplicatedAlias = dfAliasMap.keySet.intersect(map.keySet) - if (duplicatedAlias.nonEmpty) { - throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias) - } - dfAliasMap ++= map - } - } protected def analyze: LogicalPlan protected def analyzer: ExpressionAnalyzer @@ -84,6 +73,8 @@ private[snowpark] trait LeafNode extends LogicalPlan { // create ExpressionAnalyzer with empty alias map override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer() + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = Map.empty + // leaf node doesn't have child override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = this @@ -153,7 +144,7 @@ private[snowpark] trait UnaryNode extends LogicalPlan { lazy override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer(child.aliasMap, dfAliasMap) - addToDataframeAliasMap(child) + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = child.dfAliasMap override protected def analyze: LogicalPlan = createFromAnalyzedChild(analyzedChild) protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan @@ -208,7 +199,9 @@ private[snowpark] case class Sort(order: Seq[SortOrder], child: LogicalPlan) ext private[snowpark] case class DataframeAlias(alias: String, child: LogicalPlan) extends UnaryNode { - dfAliasMap += (alias -> child.getSnowflakePlan.get.output) + + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = + Utils.addToDataframeAliasMap(Map(alias -> child.getSnowflakePlan.get.output), child) override protected def createFromAnalyzedChild: 
LogicalPlan => LogicalPlan = child => { DataframeAlias(alias, child) } diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala index 4c61287a..98ff26d5 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala @@ -1,8 +1,7 @@ package com.snowflake.snowpark.internal.analyzer import java.util.Locale - -import com.snowflake.snowpark.internal.ErrorMessage +import com.snowflake.snowpark.internal.{ErrorMessage, Utils} private[snowpark] abstract class BinaryNode extends LogicalPlan { def left: LogicalPlan @@ -16,8 +15,10 @@ private[snowpark] abstract class BinaryNode extends LogicalPlan { lazy override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer(left.aliasMap, right.aliasMap, dfAliasMap) - addToDataframeAliasMap(left) - addToDataframeAliasMap(right) + + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = + Utils.addToDataframeAliasMap(Utils.addToDataframeAliasMap(Map.empty, left), right) + override def analyze: LogicalPlan = createFromAnalyzedChildren(analyzedLeft, analyzedRight) From 69c1cfaeaa360da98c2762506ba74c4c6a629a03 Mon Sep 17 00:00:00 2001 From: Zihan Li Date: Wed, 3 Jan 2024 15:12:47 -0800 Subject: [PATCH 6/7] address comments --- .../snowpark/internal/analyzer/MultiChildrenNode.scala | 9 ++++----- .../snowflake/snowpark_test/DataFrameAliasSuite.scala | 9 +++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala index c3ad8f0e..b5594184 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala @@ -14,11 +14,10 @@ private[snowpark] trait MultiChildrenNode extends LogicalPlan { protected def updateChildren(newChildren: Seq[LogicalPlan]): MultiChildrenNode - override lazy val dfAliasMap: Map[String, Seq[Attribute]] = { - var result: Map[String, Seq[Attribute]] = Map.empty - children.foreach(child => result = Utils.addToDataframeAliasMap(result, child)) - result - } + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = + children.foldLeft(Map.empty[String, Seq[Attribute]]) { + case (map, child) => Utils.addToDataframeAliasMap(map, child) + } override protected def analyze: LogicalPlan = createFromAnalyzedChildren(children.map(_.analyzed)) diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala index 397a1000..5deca2c9 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala @@ -38,6 +38,15 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes checkAnswer(df1.select($"B.num"), Seq(Row(1), Row(2), Row(3))) } + test("Test for alias with dot in column name") { + createTable(tableName1, "\"num.col\" int") + runQuery(s"insert into $tableName1 values(1),(2),(3)", session) + val df = session.table(tableName1).alias("A") + checkAnswer(df.select(df.col("A.num.col")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df.select(col("A.num.col")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df.select($"A.num.col"), Seq(Row(1), 
Row(2), Row(3))) + } + test("Test for alias with join") { createTable(tableName1, "id1 int, num1 int") createTable(tableName2, "id2 int, num2 int") From 29b33fd51cb1457eca5017256ee4d296f920d3ee Mon Sep 17 00:00:00 2001 From: Zihan Li Date: Wed, 3 Jan 2024 15:18:34 -0800 Subject: [PATCH 7/7] address comments --- .../snowpark/internal/analyzer/SnowflakePlanNode.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index 90933396..3cebd228 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -73,8 +73,6 @@ private[snowpark] trait LeafNode extends LogicalPlan { // create ExpressionAnalyzer with empty alias map override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer() - override lazy val dfAliasMap: Map[String, Seq[Attribute]] = Map.empty - // leaf node doesn't have child override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = this @@ -202,9 +200,8 @@ private[snowpark] case class DataframeAlias(alias: String, child: LogicalPlan) override lazy val dfAliasMap: Map[String, Seq[Attribute]] = Utils.addToDataframeAliasMap(Map(alias -> child.getSnowflakePlan.get.output), child) - override protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan = child => { - DataframeAlias(alias, child) - } + override protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan = + DataframeAlias(alias, _) override protected def updateChild: LogicalPlan => LogicalPlan = createFromAnalyzedChild
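By the end of the series, alias bookkeeping is a lazy dfAliasMap built by folding each child's map through Utils.addToDataframeAliasMap, which rejects collisions with error code 0132. A standalone sketch of that merge rule follows; it takes the child's map directly and uses a plain RuntimeException, whereas the real helper takes the child LogicalPlan and throws SnowparkClientException:

object DfAliasMapMergeSketch {
  type Attribute = String // stand-in for the analyzer's Attribute

  // Mirrors Utils.addToDataframeAliasMap: merge a child's alias map into the running result,
  // throwing if any alias is contributed by more than one child.
  def addToDataframeAliasMap(result: Map[String, Seq[Attribute]],
                             childMap: Map[String, Seq[Attribute]]): Map[String, Seq[Attribute]] = {
    val duplicated = result.keySet.intersect(childMap.keySet)
    if (duplicated.nonEmpty) {
      throw new RuntimeException(s"Duplicated dataframe alias defined: ${duplicated.mkString(", ")}")
    }
    result ++ childMap
  }

  def main(args: Array[String]): Unit = {
    val left = Map("A" -> Seq("\"NUM1\""))
    val right = Map("B" -> Seq("\"NUM2\""))
    // Binary nodes fold left then right, as in BinaryNode.dfAliasMap:
    println(Seq(left, right).foldLeft(Map.empty[String, Seq[Attribute]])(addToDataframeAliasMap))
    // Folding Seq(left, left) instead would throw: "Duplicated dataframe alias defined: A"
  }
}

This merge rule is why aliasing two DataFrames with the same name and then joining them fails fast, as exercised by the "Test for alias conflict" case above.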