From 9163a66bd2f84c7c0be884c8bc5abc99d1cacb0d Mon Sep 17 00:00:00 2001 From: Zihan Li Date: Tue, 2 Jan 2024 15:58:01 -0800 Subject: [PATCH] address comments --- .../snowflake/snowpark_java/DataFrame.java | 19 +++++++++++++++++++ .../scala/com/snowflake/snowpark/Column.scala | 2 +- .../internal/analyzer/Expression.scala | 2 +- .../analyzer/ExpressionAnalyzer.scala | 16 ++++++++-------- .../internal/analyzer/SnowflakePlanNode.scala | 6 +++--- 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/DataFrame.java b/src/main/java/com/snowflake/snowpark_java/DataFrame.java index 2e6d23ec..56e416e0 100644 --- a/src/main/java/com/snowflake/snowpark_java/DataFrame.java +++ b/src/main/java/com/snowflake/snowpark_java/DataFrame.java @@ -802,6 +802,25 @@ public Column col(String colName) { return new Column(df.col(colName)); } + /** + * Returns the current DataFrame aliased as the input alias name. + * + * For example: + * + * {{{ + * val df2 = df.alias("A") + * df2.select(df2.col("A.num")) + * }}} + * + * @group basic + * @since 1.10.0 + * @param alias The alias name of the dataframe + * @return a [[DataFrame]] + */ + public DataFrame alias(String alias) { + return new DataFrame(this.df.alias(alias)); + } + /** * Executes the query representing this DataFrame and returns the result as an array of Row * objects. diff --git a/src/main/scala/com/snowflake/snowpark/Column.scala b/src/main/scala/com/snowflake/snowpark/Column.scala index e8972f69..56996aa9 100644 --- a/src/main/scala/com/snowflake/snowpark/Column.scala +++ b/src/main/scala/com/snowflake/snowpark/Column.scala @@ -732,7 +732,7 @@ private[snowpark] object Column { def apply(name: String): Column = new Column(name match { case "*" => Star(Seq.empty) - case c if c.contains(".") => DfAliasAttribute(name) + case c if c.contains(".") => UnresolvedDFAliasAttribute(name) case _ => UnresolvedAttribute(quoteName(name)) }) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala index 3e8763c6..adbddd38 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala @@ -390,7 +390,7 @@ private[snowpark] case class UnresolvedAttribute(override val name: String) this } -private[snowpark] case class DfAliasAttribute(override val name: String) +private[snowpark] case class UnresolvedDFAliasAttribute(override val name: String) extends Expression with NamedExpression { override def sql: String = "" diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala index ae399b68..76bed5ef 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala @@ -6,15 +6,15 @@ import scala.collection.mutable.{Map => MMap} private[snowpark] object ExpressionAnalyzer { def apply(aliasMap: Map[ExprId, String], - dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = new ExpressionAnalyzer(aliasMap, dfAliasMap) - def apply(dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = - new ExpressionAnalyzer(Map.empty, dfAliasMap) + def apply(): ExpressionAnalyzer = + new ExpressionAnalyzer(Map.empty, Map.empty) // create new analyzer by combining two alias maps def apply(map1: Map[ExprId, String], map2: Map[ExprId, String], - dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = { + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = { val common = map1.keySet & map2.keySet val result = (map1 ++ map2).filter { // remove common column, let (df1.join(df2)) @@ -25,15 +25,15 @@ private[snowpark] object ExpressionAnalyzer { } def apply(maps: Seq[Map[ExprId, String]], - dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = { - maps.foldLeft(ExpressionAnalyzer(dfAliasMap)) { + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = { + maps.foldLeft(ExpressionAnalyzer()) { case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map, dfAliasMap) } } } private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String], - dfAliasMap: MMap[String, Seq[Attribute]]) { + dfAliasMap: Map[String, Seq[Attribute]]) { private val generatedAliasMap: MMap[ExprId, String] = MMap.empty def analyze(ex: Expression): Expression = ex match { @@ -58,7 +58,7 @@ private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String], // removed useless alias case Alias(child: NamedExpression, name, _) if quoteName(child.name) == quoteName(name) => child - case DfAliasAttribute(name) => + case UnresolvedDFAliasAttribute(name) => val colNameSplit = name.split("\\.", 2) if (colNameSplit.length > 1 && dfAliasMap.contains(colNameSplit(0))) { val aliasOutput = dfAliasMap(colNameSplit(0)) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index dcf8f09a..b49b38d1 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -19,9 +19,9 @@ private[snowpark] trait LogicalPlan { (analyzedPlan, analyzer.getAliasMap) } - var dfAliasMap: MMap[String, Seq[Attribute]] = MMap.empty + var dfAliasMap: Map[String, Seq[Attribute]] = Map.empty - protected def addToDataframeAliasMap(map: MMap[String, Seq[Attribute]]): Unit = { + protected def addToDataframeAliasMap(map: Map[String, Seq[Attribute]]): Unit = { val duplicatedAlias = dfAliasMap.keySet.intersect(map.keySet) if (duplicatedAlias.nonEmpty) { throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias) @@ -79,7 +79,7 @@ private[snowpark] trait LogicalPlan { private[snowpark] trait LeafNode extends LogicalPlan { // create ExpressionAnalyzer with empty alias map - override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer(dfAliasMap) + override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer() // leaf node doesn't have child override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = this