diff --git a/src/main/java/com/snowflake/snowpark_java/types/CloudProviderToken.java b/src/main/java/com/snowflake/snowpark_java/types/CloudProviderToken.java index f4cb9f4e..c92044d0 100644 --- a/src/main/java/com/snowflake/snowpark_java/types/CloudProviderToken.java +++ b/src/main/java/com/snowflake/snowpark_java/types/CloudProviderToken.java @@ -1,8 +1,9 @@ package com.snowflake.snowpark_java.types; -/** The Snowflake class provides access to the CloudProviderToken secret object with the following properties: accessKeyId, - * secretAccessKey, and token.*/ - +/** + * The Snowflake class provides access to the CloudProviderToken secret object with the following + * properties: accessKeyId, secretAccessKey, and token. + */ public class CloudProviderToken { private final String accessKeyId; private final String secretAccessKey; diff --git a/src/main/java/com/snowflake/snowpark_java/types/SnowflakeSecrets.java b/src/main/java/com/snowflake/snowpark_java/types/SnowflakeSecrets.java index 4438bea4..75fa63d9 100644 --- a/src/main/java/com/snowflake/snowpark_java/types/SnowflakeSecrets.java +++ b/src/main/java/com/snowflake/snowpark_java/types/SnowflakeSecrets.java @@ -32,8 +32,8 @@ public UsernamePassword getUsernamePassword(String secretName) { } /** - * Get the Cloud provider token from the secret. On success, it returns a valid object with access key id, - * secret access key and token. + * Get the Cloud provider token from the secret. On success, it returns a valid object with access + * key id, secret access key and token. * * @param secretName name of the secret object. */ diff --git a/src/main/scala/com/snowflake/snowpark/CopyableDataFrame.scala b/src/main/scala/com/snowflake/snowpark/CopyableDataFrame.scala index f30c24e9..fa8b8e15 100644 --- a/src/main/scala/com/snowflake/snowpark/CopyableDataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/CopyableDataFrame.scala @@ -19,8 +19,9 @@ import com.snowflake.snowpark.internal.analyzer._ class CopyableDataFrame private[snowpark] ( override private[snowpark] val session: Session, override private[snowpark] val plan: SnowflakePlan, + override private[snowpark] val methodChain: Seq[String], private val stagedFileReader: StagedFileReader) - extends DataFrame(session, plan) { + extends DataFrame(session, plan, methodChain) { /** * Executes a `COPY INTO ` command to @@ -238,7 +239,7 @@ class CopyableDataFrame private[snowpark] ( * @group basic */ override def clone: CopyableDataFrame = action("clone", 2) { - new CopyableDataFrame(session, plan, stagedFileReader) + new CopyableDataFrame(session, plan, Seq(), stagedFileReader) } /** @@ -261,11 +262,12 @@ class CopyableDataFrame private[snowpark] ( @inline override protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = this.session.conn.isScalaAPI - OpenTelemetry.action("CopyableDataFrame", funcName, isScala)(func) + OpenTelemetry.action("CopyableDataFrame", funcName, methodChainString, isScala)(func) } @inline protected def action[T](funcName: String, javaOffset: Int)(func: => T): T = { val isScala: Boolean = this.session.conn.isScalaAPI - OpenTelemetry.action("CopyableDataFrame", funcName, isScala, javaOffset)(func) + OpenTelemetry.action("CopyableDataFrame", funcName, methodChainString, isScala, javaOffset)( + func) } } @@ -360,6 +362,10 @@ class CopyableDataFrameAsyncActor private[snowpark] (cdf: CopyableDataFrame) @inline override protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = cdf.session.conn.isScalaAPI - OpenTelemetry.action("CopyableDataFrameAsyncActor", funcName, isScala)(func) + OpenTelemetry.action( + "CopyableDataFrameAsyncActor", + funcName, + cdf.methodChainString + ".async", + isScala)(func) } } diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index ce45b362..8f99693c 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -1,7 +1,7 @@ package com.snowflake.snowpark import scala.reflect.ClassTag -import scala.util.Random +import scala.util.{DynamicVariable, Random} import com.snowflake.snowpark.internal.analyzer.{TableFunction => TF} import com.snowflake.snowpark.internal.{ErrorMessage, Logging, OpenTelemetry, Utils} import com.snowflake.snowpark.internal.analyzer._ @@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer private[snowpark] object DataFrame extends Logging { def apply(session: Session, plan: LogicalPlan): DataFrame = - new DataFrame(session, plan) + new DataFrame(session, plan, methodChainCache.value) def getUnaliased(colName: String): List[String] = { val ColPattern = s"""._[a-zA-Z0-9]{${numPrefixDigits}}_(.*)""".r @@ -37,6 +37,18 @@ private[snowpark] object DataFrame extends Logging { } private val numPrefixDigits = 4 + + // build method chain in the Dataframe transformation functions + // in case of recursion, only record the outer function in the method chain. + val methodChainCache = new DynamicVariable[Seq[String]](Seq.empty[String]) + + def buildMethodChain(current: Seq[String], newMethod: String)( + thunk: => DataFrame): DataFrame = { + methodChainCache.withValue( + if (methodChainCache.value.isEmpty) current :+ newMethod else methodChainCache.value) { + thunk + } + } } /** @@ -190,7 +202,8 @@ private[snowpark] object DataFrame extends Logging { */ class DataFrame private[snowpark] ( private[snowpark] val session: Session, - private[snowpark] val plan: LogicalPlan) + private[snowpark] val plan: LogicalPlan, + private[snowpark] val methodChain: Seq[String]) extends Logging { lazy private[snowpark] val snowflakePlan: SnowflakePlan = session.analyzer.resolve(plan) @@ -202,7 +215,7 @@ class DataFrame private[snowpark] ( * @since 0.4.0 * @return A [[DataFrame]] */ - override def clone: DataFrame = { + override def clone: DataFrame = transformation("clone") { DataFrame(session, snowflakePlan.clone) } @@ -242,7 +255,7 @@ class DataFrame private[snowpark] ( session.conn.execute(createTempTable) val newPlan = session.table(tempTableName).plan session.conn.telemetry.reportActionCacheResult() - new HasCachedResult(session, newPlan) + new HasCachedResult(session, newPlan, Seq()) } /** @@ -326,7 +339,7 @@ class DataFrame private[snowpark] ( * @param remaining A list of the rest of the column names. * @return A [[DataFrame]] */ - def toDF(first: String, remaining: String*): DataFrame = { + def toDF(first: String, remaining: String*): DataFrame = transformation("toDF") { toDF(first +: remaining) } @@ -370,7 +383,7 @@ class DataFrame private[snowpark] ( * @param colNames A list of column names. * @return A [[DataFrame]] */ - def toDF(colNames: Seq[String]): DataFrame = { + def toDF(colNames: Seq[String]): DataFrame = transformation("toDF") { require( output.length == colNames.length, "The number of columns doesn't match. \n" + @@ -431,8 +444,9 @@ class DataFrame private[snowpark] ( * @param colNames An array of column names. * @return A [[DataFrame]] */ - def toDF(colNames: Array[String]): DataFrame = + def toDF(colNames: Array[String]): DataFrame = transformation("toDF") { toDF(colNames.toSeq) + } /** * Sorts a DataFrame by the specified expressions (similar to ORDER BY in SQL). @@ -449,8 +463,9 @@ class DataFrame private[snowpark] ( * @param remaining Additional Column expressions for sorting the DataFrame. * @return A [[DataFrame]] */ - def sort(first: Column, remaining: Column*): DataFrame = + def sort(first: Column, remaining: Column*): DataFrame = transformation("sort") { sort(first +: remaining) + } /** * Sorts a DataFrame by the specified expressions (similar to ORDER BY in SQL). @@ -465,7 +480,7 @@ class DataFrame private[snowpark] ( * @param sortExprs A list of Column expressions for sorting the DataFrame. * @return A [[DataFrame]] */ - def sort(sortExprs: Seq[Column]): DataFrame = + def sort(sortExprs: Seq[Column]): DataFrame = transformation("sort") { if (sortExprs.nonEmpty) { withPlan(Sort(sortExprs.map { col => col.expr match { @@ -476,6 +491,7 @@ class DataFrame private[snowpark] ( } else { throw ErrorMessage.DF_SORT_NEED_AT_LEAST_ONE_EXPR() } + } /** * Sorts a DataFrame by the specified expressions (similar to ORDER BY in SQL). @@ -532,7 +548,9 @@ class DataFrame private[snowpark] ( * @param alias The alias name of the dataframe * @return a [[DataFrame]] */ - def alias(alias: String): DataFrame = withPlan(DataframeAlias(alias, plan, output)) + def alias(alias: String): DataFrame = transformation("alias") { + withPlan(DataframeAlias(alias, plan, output)) + } /** * Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in @@ -552,7 +570,7 @@ class DataFrame private[snowpark] ( * @param remaining A list of expressions for the additional columns to return. * @return A [[DataFrame]] */ - def select(first: Column, remaining: Column*): DataFrame = { + def select(first: Column, remaining: Column*): DataFrame = transformation("select") { select(first +: remaining) } @@ -573,7 +591,7 @@ class DataFrame private[snowpark] ( * @param columns A list of expressions for the columns to return. * @return A [[DataFrame]] */ - def select[T: ClassTag](columns: Seq[Column]): DataFrame = { + def select[T: ClassTag](columns: Seq[Column]): DataFrame = transformation("select") { require( columns.nonEmpty, "Provide at least one column expression for select(). " + @@ -679,7 +697,9 @@ class DataFrame private[snowpark] ( * @param columns An array of expressions for the columns to return. * @return A [[DataFrame]] */ - def select(columns: Array[Column]): DataFrame = select(columns.toSeq) + def select(columns: Array[Column]): DataFrame = transformation("select") { + select(columns.toSeq) + } /** * Returns a new DataFrame with a subset of named columns (similar to SELECT in SQL). @@ -696,7 +716,7 @@ class DataFrame private[snowpark] ( * @param remaining A list of the names of the additional columns to return. * @return A [[DataFrame]] */ - def select(first: String, remaining: String*): DataFrame = { + def select(first: String, remaining: String*): DataFrame = transformation("select") { select(first +: remaining) } @@ -714,7 +734,7 @@ class DataFrame private[snowpark] ( * @param columns A list of the names of columns to return. * @return A [[DataFrame]] */ - def select(columns: Seq[String]): DataFrame = { + def select(columns: Seq[String]): DataFrame = transformation("select") { select(columns.map(Column(_))) } @@ -732,7 +752,9 @@ class DataFrame private[snowpark] ( * @param columns An array of the names of columns to return. * @return A [[DataFrame]] */ - def select(columns: Array[String]): DataFrame = select(columns.toSeq) + def select(columns: Array[String]): DataFrame = transformation("select") { + select(columns.toSeq) + } /** * Returns a new DataFrame that excludes the columns with the specified names from the output. @@ -747,7 +769,7 @@ class DataFrame private[snowpark] ( * @param remaining A list of the names of additional columns to exclude. * @return A [[DataFrame]] */ - def drop(first: String, remaining: String*): DataFrame = { + def drop(first: String, remaining: String*): DataFrame = transformation("drop") { drop(first +: remaining) } @@ -765,7 +787,7 @@ class DataFrame private[snowpark] ( * @param colNames A list of the names of columns to exclude. * @return A [[DataFrame]] */ - def drop(colNames: Seq[String]): DataFrame = { + def drop(colNames: Seq[String]): DataFrame = transformation("drop") { val dropColumns: Seq[Column] = colNames.map(name => functions.col(name)) drop(dropColumns) } @@ -783,7 +805,9 @@ class DataFrame private[snowpark] ( * @param colNames An array of the names of columns to exclude. * @return A [[DataFrame]] */ - def drop(colNames: Array[String]): DataFrame = drop(colNames.toSeq) + def drop(colNames: Array[String]): DataFrame = transformation("drop") { + drop(colNames.toSeq) + } /** * Returns a new DataFrame that excludes the columns specified by the expressions from the @@ -802,7 +826,7 @@ class DataFrame private[snowpark] ( * @param remaining A list of expressions for additional columns to exclude. * @return A [[DataFrame]] */ - def drop(first: Column, remaining: Column*): DataFrame = { + def drop(first: Column, remaining: Column*): DataFrame = transformation("drop") { drop(first +: remaining) } @@ -822,7 +846,7 @@ class DataFrame private[snowpark] ( * @param cols A list of the names of the columns to exclude. * @return A [[DataFrame]] */ - def drop[T: ClassTag](cols: Seq[Column]): DataFrame = { + def drop[T: ClassTag](cols: Seq[Column]): DataFrame = transformation("drop") { val dropColumns: Seq[NamedExpression] = cols.map { case Column(expr: NamedExpression) => expr case c => @@ -847,7 +871,9 @@ class DataFrame private[snowpark] ( * @param cols An array of the names of the columns to exclude. * @return A [[DataFrame]] */ - def drop(cols: Array[Column]): DataFrame = drop(cols.toSeq) + def drop(cols: Array[Column]): DataFrame = transformation("drop") { + drop(cols.toSeq) + } /** * Filters rows based on the specified conditional expression (similar to WHERE in SQL). @@ -863,7 +889,9 @@ class DataFrame private[snowpark] ( * @param condition Filter condition defined as an expression on columns. * @return A filtered [[DataFrame]] */ - def filter(condition: Column): DataFrame = withPlan(Filter(condition.expr, plan)) + def filter(condition: Column): DataFrame = transformation("filter") { + withPlan(Filter(condition.expr, plan)) + } /** * Filters rows based on the specified conditional expression (similar to WHERE in SQL). @@ -882,7 +910,9 @@ class DataFrame private[snowpark] ( * @param condition Filter condition defined as an expression on columns. * @return A filtered [[DataFrame]] */ - def where(condition: Column): DataFrame = filter(condition) + def where(condition: Column): DataFrame = transformation("where") { + filter(condition) + } /** * Aggregate the data in the DataFrame. Use this method if you don't need to @@ -909,8 +939,9 @@ class DataFrame private[snowpark] ( * @param expr A map of column names and aggregate functions. * @return A [[DataFrame]] */ - def agg(expr: (String, String), exprs: (String, String)*): DataFrame = + def agg(expr: (String, String), exprs: (String, String)*): DataFrame = transformation("agg") { agg(expr +: exprs) + } /** * Aggregate the data in the DataFrame. Use this method if you don't need @@ -937,8 +968,9 @@ class DataFrame private[snowpark] ( * @param exprs A map of column names and aggregate functions. * @return A [[DataFrame]] */ - def agg(exprs: Seq[(String, String)]): DataFrame = + def agg(exprs: Seq[(String, String)]): DataFrame = transformation("agg") { groupBy().agg(exprs.map({ case (c, a) => (col(c), a) })) + } /** * Aggregate the data in the DataFrame. Use this method if you don't need to group the data @@ -963,7 +995,9 @@ class DataFrame private[snowpark] ( * @param expr A list of expressions on columns. * @return A [[DataFrame]] */ - def agg(expr: Column, exprs: Column*): DataFrame = agg(expr +: exprs) + def agg(expr: Column, exprs: Column*): DataFrame = transformation("agg") { + agg(expr +: exprs) + } /** * Aggregate the data in the DataFrame. Use this method if you don't need @@ -985,7 +1019,9 @@ class DataFrame private[snowpark] ( * @param exprs A list of expressions on columns. * @return A [[DataFrame]] */ - def agg[T: ClassTag](exprs: Seq[Column]): DataFrame = groupBy().agg(exprs) + def agg[T: ClassTag](exprs: Seq[Column]): DataFrame = transformation("agg") { + groupBy().agg(exprs) + } /** * Aggregate the data in the DataFrame. Use this method if you don't need @@ -1010,7 +1046,9 @@ class DataFrame private[snowpark] ( * @param exprs An array of expressions on columns. * @return A [[DataFrame]] */ - def agg(exprs: Array[Column]): DataFrame = agg(exprs.toSeq) + def agg(exprs: Array[Column]): DataFrame = transformation("agg") { + agg(exprs.toSeq) + } /** * Performs an SQL @@ -1246,7 +1284,7 @@ class DataFrame private[snowpark] ( RelationalGroupedDataFrame( this, groupingSets.map(_.toExpression), - RelationalGroupedDataFrame.GroupByType) + RelationalGroupedDataFrame.GroupByGroupingSetsType) /** * Performs an SQL @@ -1323,8 +1361,9 @@ class DataFrame private[snowpark] ( * @since 0.1.0 * @return A [[DataFrame]] */ - def distinct(): DataFrame = + def distinct(): DataFrame = transformation("distinct") { groupBy(output.map(att => quoteName(att.name)).map(this.col)).agg(Map.empty[Column, String]) + } /** * Creates a new DataFrame by removing duplicated rows on given subset of columns. @@ -1343,7 +1382,7 @@ class DataFrame private[snowpark] ( * @since 0.10.0 * @return A [[DataFrame]] */ - def dropDuplicates(colNames: String*): DataFrame = { + def dropDuplicates(colNames: String*): DataFrame = transformation("dropDuplicates") { if (colNames.isEmpty) { this.distinct() } else { @@ -1421,7 +1460,9 @@ class DataFrame private[snowpark] ( * @param n Number of rows to return. * @return A [[DataFrame]] */ - def limit(n: Int): DataFrame = withPlan(Limit(Literal(n), plan)) + def limit(n: Int): DataFrame = transformation("limit") { + withPlan(Limit(Literal(n), plan)) + } /** * Returns a new DataFrame that contains all the rows in the current DataFrame and another @@ -1439,7 +1480,9 @@ class DataFrame private[snowpark] ( * @param other The other [[DataFrame]] that contains the rows to include. * @return A [[DataFrame]] */ - def union(other: DataFrame): DataFrame = withPlan(Union(plan, other.plan)) + def union(other: DataFrame): DataFrame = transformation("union") { + withPlan(Union(plan, other.plan)) + } /** * Returns a new DataFrame that contains all the rows in the current DataFrame and another @@ -1457,7 +1500,9 @@ class DataFrame private[snowpark] ( * @param other The other [[DataFrame]] that contains the rows to include. * @return A [[DataFrame]] */ - def unionAll(other: DataFrame): DataFrame = withPlan(UnionAll(plan, other.plan)) + def unionAll(other: DataFrame): DataFrame = transformation("unionAll") { + withPlan(UnionAll(plan, other.plan)) + } /** * Returns a new DataFrame that contains all the rows in the current DataFrame and another @@ -1478,7 +1523,9 @@ class DataFrame private[snowpark] ( * @param other The other [[DataFrame]] that contains the rows to include. * @return A [[DataFrame]] */ - def unionByName(other: DataFrame): DataFrame = internalUnionByName(other, isAll = false) + def unionByName(other: DataFrame): DataFrame = transformation("unionByName") { + internalUnionByName(other, isAll = false) + } /** * Returns a new DataFrame that contains all the rows in the current DataFrame and another @@ -1499,7 +1546,9 @@ class DataFrame private[snowpark] ( * @param other The other [[DataFrame]] that contains the rows to include. * @return A [[DataFrame]] */ - def unionAllByName(other: DataFrame): DataFrame = internalUnionByName(other, isAll = true) + def unionAllByName(other: DataFrame): DataFrame = transformation("unionAllByName") { + internalUnionByName(other, isAll = true) + } private def internalUnionByName(other: DataFrame, isAll: Boolean): DataFrame = { val leftOutputAttrs = output @@ -1550,7 +1599,9 @@ class DataFrame private[snowpark] ( * @param other The other [[DataFrame]] that contains the rows to use for the intersection. * @return A [[DataFrame]] */ - def intersect(other: DataFrame): DataFrame = withPlan(Intersect(plan, other.plan)) + def intersect(other: DataFrame): DataFrame = transformation("intersect") { + withPlan(Intersect(plan, other.plan)) + } /** * Returns a new DataFrame that contains all the rows from the current DataFrame except for the @@ -1567,7 +1618,9 @@ class DataFrame private[snowpark] ( * @param other The [[DataFrame]] that contains the rows to exclude. * @return A [[DataFrame]] */ - def except(other: DataFrame): DataFrame = withPlan(Except(plan, other.plan)) + def except(other: DataFrame): DataFrame = transformation("except") { + withPlan(Except(plan, other.plan)) + } /** * Performs a default inner join of the current DataFrame and another DataFrame (`right`). @@ -1591,7 +1644,7 @@ class DataFrame private[snowpark] ( * @param right The other [[DataFrame]] to join. * @return A [[DataFrame]] */ - def join(right: DataFrame): DataFrame = { + def join(right: DataFrame): DataFrame = transformation("join") { join(right, Seq.empty) } @@ -1614,7 +1667,7 @@ class DataFrame private[snowpark] ( * @param usingColumn The name of the column to use for the join. * @return A [[DataFrame]] */ - def join(right: DataFrame, usingColumn: String): DataFrame = { + def join(right: DataFrame, usingColumn: String): DataFrame = transformation("join") { join(right, Seq(usingColumn)) } @@ -1638,7 +1691,7 @@ class DataFrame private[snowpark] ( * @param usingColumns A list of the names of the columns to use for the join. * @return A [[DataFrame]] */ - def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = transformation("join") { join(right, usingColumns, "inner") } @@ -1663,21 +1716,22 @@ class DataFrame private[snowpark] ( * @param joinType The type of join (e.g. {@code "right"}, {@code "outer"}, etc.). * @return A [[DataFrame]] */ - def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { - val jType = JoinType(joinType) - if (jType == LeftSemi || jType == LeftAnti) { - val joinCond = usingColumns - .map(quoteName) - .map(n => this.col(n) === right.col(n)) - .foldLeft(functions.lit(true))(_ && _) - join(right, joinCond, joinType) - } else { - val (lhs, rhs) = disambiguate(this, right, jType, usingColumns) - withPlan { - Join(lhs.plan, rhs.plan, UsingJoin(jType, usingColumns), None) + def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = + transformation("join") { + val jType = JoinType(joinType) + if (jType == LeftSemi || jType == LeftAnti) { + val joinCond = usingColumns + .map(quoteName) + .map(n => this.col(n) === right.col(n)) + .foldLeft(functions.lit(true))(_ && _) + join(right, joinCond, joinType) + } else { + val (lhs, rhs) = disambiguate(this, right, jType, usingColumns) + withPlan { + Join(lhs.plan, rhs.plan, UsingJoin(jType, usingColumns), None) + } } } - } // scalastyle:off line.size.limit /** @@ -1716,7 +1770,7 @@ class DataFrame private[snowpark] ( * @return A [[DataFrame]] */ // scalastyle:on line.size.limit - def join(right: DataFrame, joinExprs: Column): DataFrame = { + def join(right: DataFrame, joinExprs: Column): DataFrame = transformation("join") { join(right, joinExprs, "inner") } @@ -1762,12 +1816,13 @@ class DataFrame private[snowpark] ( * @return A [[DataFrame]] */ // scalastyle:on line.size.limit - def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { - if (this.eq(right) || this.plan.eq(right.plan)) { - throw ErrorMessage.DF_SELF_JOIN_NOT_SUPPORTED() + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = + transformation("join") { + if (this.eq(right) || this.plan.eq(right.plan)) { + throw ErrorMessage.DF_SELF_JOIN_NOT_SUPPORTED() + } + join(right, JoinType(joinType), Some(joinExprs)) } - join(right, JoinType(joinType), Some(joinExprs)) - } /** * Joins the current DataFrame with the output of the specified table function `func`. @@ -1797,7 +1852,9 @@ class DataFrame private[snowpark] ( * @param remaining A list of any additional arguments for the specified table function. */ def join(func: TableFunction, firstArg: Column, remaining: Column*): DataFrame = - join(func, firstArg +: remaining) + transformation("join") { + join(func, firstArg +: remaining) + } /** * Joins the current DataFrame with the output of the specified table function `func`. @@ -1823,8 +1880,9 @@ class DataFrame private[snowpark] ( * object or an object that you create from the [[TableFunction]] class. * @param args A list of arguments to pass to the specified table function. */ - def join(func: TableFunction, args: Seq[Column]): DataFrame = + def join(func: TableFunction, args: Seq[Column]): DataFrame = transformation("join") { joinTableFunction(func.call(args: _*), None) + } /** * Joins the current DataFrame with the output of the specified user-defined table @@ -1856,10 +1914,11 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Seq[Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = + orderBy: Seq[Column]): DataFrame = transformation("join") { joinTableFunction( func.call(args: _*), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) + } /** * Joins the current DataFrame with the output of the specified table function `func` that takes @@ -1892,7 +1951,9 @@ class DataFrame private[snowpark] ( * Use this map to specify the parameter names and their corresponding values. */ def join(func: TableFunction, args: Map[String, Column]): DataFrame = - joinTableFunction(func.call(args), None) + transformation("join") { + joinTableFunction(func.call(args), None) + } /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1931,10 +1992,11 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Map[String, Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = + orderBy: Seq[Column]): DataFrame = transformation("join") { joinTableFunction( func.call(args), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) + } /** * Joins the current DataFrame with the output of the specified table function `func`. @@ -1958,8 +2020,9 @@ class DataFrame private[snowpark] ( * @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]] * object or an object that you create from the [[TableFunction.apply()]]. */ - def join(func: Column): DataFrame = + def join(func: Column): DataFrame = transformation("join") { joinTableFunction(getTableFunctionExpression(func), None) + } /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1980,9 +2043,11 @@ class DataFrame private[snowpark] ( * @param orderBy A list of columns ordered by. */ def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = - joinTableFunction( - getTableFunctionExpression(func), - Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) + transformation("join") { + joinTableFunction( + getTableFunctionExpression(func), + Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) + } private def joinTableFunction( func: TableFunctionExpression, @@ -2053,7 +2118,7 @@ class DataFrame private[snowpark] ( * @param right The other [[DataFrame]] to join. * @return A [[DataFrame]] */ - def crossJoin(right: DataFrame): DataFrame = { + def crossJoin(right: DataFrame): DataFrame = transformation("crossJoin") { join(right, JoinType("cross"), None) } @@ -2084,7 +2149,7 @@ class DataFrame private[snowpark] ( * @param right The other [[DataFrame]] to join. * @return A [[DataFrame]] */ - def naturalJoin(right: DataFrame): DataFrame = { + def naturalJoin(right: DataFrame): DataFrame = transformation("naturalJoin") { naturalJoin(right, "inner") } @@ -2104,7 +2169,7 @@ class DataFrame private[snowpark] ( * @param joinType The type of join (e.g. {@code "right"}, {@code "outer"}, etc.). * @return A [[DataFrame]] */ - def naturalJoin(right: DataFrame, joinType: String): DataFrame = { + def naturalJoin(right: DataFrame, joinType: String): DataFrame = transformation("naturalJoin") { withPlan { Join(this.plan, right.plan, NaturalJoin(JoinType(joinType)), None) } @@ -2129,7 +2194,9 @@ class DataFrame private[snowpark] ( * @param col The [[Column]] to add or replace. * @return A [[DataFrame]] */ - def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) + def withColumn(colName: String, col: Column): DataFrame = transformation("withColumn") { + withColumns(Seq(colName), Seq(col)) + } /** * Returns a DataFrame with additional columns with the specified names (`colNames`). The @@ -2151,19 +2218,22 @@ class DataFrame private[snowpark] ( * @param values A list of the [[Column]] objects to add or replace. * @return A [[DataFrame]] */ - def withColumns(colNames: Seq[String], values: Seq[Column]): DataFrame = { - if (colNames.size != values.size) { - throw ErrorMessage.DF_WITH_COLUMNS_INPUT_NAMES_NOT_MATCH_VALUES(colNames.size, values.size) - } - val qualifiedNames = colNames.map(quoteName) - if (qualifiedNames.toSet.size != colNames.size) { - throw ErrorMessage.DF_WITH_COLUMNS_INPUT_NAMES_CONTAINS_DUPLICATES - } - val newCols = qualifiedNames.zip(values).map { - case (name, col) => col.as(name).expr.asInstanceOf[NamedExpression] + def withColumns(colNames: Seq[String], values: Seq[Column]): DataFrame = + transformation("withColumns") { + if (colNames.size != values.size) { + throw ErrorMessage.DF_WITH_COLUMNS_INPUT_NAMES_NOT_MATCH_VALUES( + colNames.size, + values.size) + } + val qualifiedNames = colNames.map(quoteName) + if (qualifiedNames.toSet.size != colNames.size) { + throw ErrorMessage.DF_WITH_COLUMNS_INPUT_NAMES_CONTAINS_DUPLICATES + } + val newCols = qualifiedNames.zip(values).map { + case (name, col) => col.as(name).expr.asInstanceOf[NamedExpression] + } + withPlan(WithColumns(newCols, plan)) } - withPlan(WithColumns(newCols, plan)) - } /** * Returns a DataFrame with the specified column `col` renamed as `newName`. @@ -2180,7 +2250,7 @@ class DataFrame private[snowpark] ( * @param col The [[Column]] to be renamed * @return A [[DataFrame]] */ - def rename(newName: String, col: Column): DataFrame = { + def rename(newName: String, col: Column): DataFrame = transformation("rename") { // Normalize the new column name val newQuotedName = quoteName(newName) @@ -2196,10 +2266,8 @@ class DataFrame private[snowpark] ( if (toBeRenamed.isEmpty) { throw ErrorMessage.DF_CANNOT_RENAME_COLUMN_BECAUSE_NOT_EXIST(oldName, newQuotedName) } else if (toBeRenamed.size > 1) { - throw ErrorMessage.DF_CANNOT_RENAME_COLUMN_BECAUSE_MULTIPLE_EXIST( - oldName, - newQuotedName, - toBeRenamed.size) + throw ErrorMessage + .DF_CANNOT_RENAME_COLUMN_BECAUSE_MULTIPLE_EXIST(oldName, newQuotedName, toBeRenamed.size) } val newColumns = output.map { @@ -2645,8 +2713,9 @@ class DataFrame private[snowpark] ( * @since 0.2.0 * @return A [[DataFrame]] containing the sample of {@code num} rows. */ - def sample(num: Long): DataFrame = + def sample(num: Long): DataFrame = transformation("sample") { withPlan(SnowflakeSampleNode(None, Some(num), plan)) + } /** * Returns a new DataFrame that contains a sampling of rows from the current DataFrame. @@ -2667,8 +2736,9 @@ class DataFrame private[snowpark] ( * @since 0.2.0 * @return A [[DataFrame]] containing the sample of rows. */ - def sample(probabilityFraction: Double): DataFrame = + def sample(probabilityFraction: Double): DataFrame = transformation("sample") { withPlan(SnowflakeSampleNode(Some(probabilityFraction), None, plan)) + } /** * Randomly splits the current DataFrame into separate DataFrames, using the specified weights. @@ -2753,8 +2823,9 @@ class DataFrame private[snowpark] ( * @return A [[DataFrame]] containing the flattened values. * @since 0.2.0 */ - def flatten(input: Column): DataFrame = + def flatten(input: Column): DataFrame = transformation("flatten") { flatten(input, "", outer = false, recursive = false, "BOTH") + } /** * Flattens (explodes) compound values into multiple rows (similar to the SQL @@ -2806,7 +2877,7 @@ class DataFrame private[snowpark] ( path: String, outer: Boolean, recursive: Boolean, - mode: String): DataFrame = { + mode: String): DataFrame = transformation("flatten") { // scalastyle:off val flattenMode = mode.toUpperCase() match { case m @ ("OBJECT" | "ARRAY" | "BOTH") => m @@ -2950,12 +3021,20 @@ class DataFrame private[snowpark] ( } } + lazy private[snowpark] val methodChainString: String = + methodChain.foldLeft("DataFrame") { + case (str, methodName) => s"$str.$methodName" + } + @inline protected def withPlan(plan: LogicalPlan): DataFrame = DataFrame(session, plan) @inline protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = this.session.conn.isScalaAPI - OpenTelemetry.action("DataFrame", funcName, isScala)(func) + OpenTelemetry.action("DataFrame", funcName, methodChainString, isScala)(func) } + + @inline protected def transformation(funcName: String)(func: => DataFrame): DataFrame = + DataFrame.buildMethodChain(this.methodChain, funcName)(func) } /** @@ -2967,8 +3046,9 @@ class DataFrame private[snowpark] ( */ class HasCachedResult private[snowpark] ( override private[snowpark] val session: Session, - override private[snowpark] val plan: LogicalPlan) - extends DataFrame(session, plan) { + override private[snowpark] val plan: LogicalPlan, + override private[snowpark] val methodChain: Seq[String]) + extends DataFrame(session, plan, methodChain) { /** * Caches the content of this DataFrame to create a new cached DataFrame. @@ -2983,7 +3063,7 @@ class HasCachedResult private[snowpark] ( override def cacheResult(): HasCachedResult = action("cacheResult") { // cacheResult function of HashCachedResult returns a clone of this // HashCachedResult DataFrame instead of to cache this DataFrame again. - new HasCachedResult(session, snowflakePlan.clone) + new HasCachedResult(session, snowflakePlan.clone, Seq()) } } @@ -3029,6 +3109,10 @@ class DataFrameAsyncActor private[snowpark] (df: DataFrame) { @inline protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = df.session.conn.isScalaAPI - OpenTelemetry.action("DataFrameAsyncActor", funcName, isScala)(func) + OpenTelemetry.action( + "DataFrameAsyncActor", + funcName, + df.methodChainString + ".async", + isScala)(func) } } diff --git a/src/main/scala/com/snowflake/snowpark/DataFrameNaFunctions.scala b/src/main/scala/com/snowflake/snowpark/DataFrameNaFunctions.scala index 7a6874b4..5bc3e4e9 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrameNaFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrameNaFunctions.scala @@ -28,7 +28,7 @@ final class DataFrameNaFunctions private[snowpark] (df: DataFrame) extends Loggi * @throws SnowparkClientException if cols contains any unrecognized column name * @since 0.2.0 */ - def drop(minNonNullsPerRow: Int, cols: Seq[String]): DataFrame = { + def drop(minNonNullsPerRow: Int, cols: Seq[String]): DataFrame = transformation("drop") { // translate to // select * from table where // iff(floatCol = 'NaN' or floatCol is null, 0, 1) @@ -95,7 +95,7 @@ final class DataFrameNaFunctions private[snowpark] (df: DataFrame) extends Loggi * * @since 0.2.0 */ - def fill(valueMap: Map[String, Any]): DataFrame = { + def fill(valueMap: Map[String, Any]): DataFrame = transformation("fill") { // translate to // select col, iff(floatCol is null or floatCol == 'NaN', replacement, floatCol), // iff(nonFloatCol is null, replacement, nonFloatCol) from table @@ -167,39 +167,43 @@ final class DataFrameNaFunctions private[snowpark] (df: DataFrame) extends Loggi * @throws SnowparkClientException if colName is an unrecognized column name * @since 0.2.0 */ - def replace(colName: String, replacement: Map[Any, Any]): DataFrame = { - // verify name - val column = df.col(colName) + def replace(colName: String, replacement: Map[Any, Any]): DataFrame = + transformation("replace") { + // verify name + val column = df.col(colName) - if (replacement.isEmpty) { - df - } else { - val columns = df.output.map { field => - if (quoteName(field.name) == quoteName(colName)) { - val conditionReplacement = replacement.toSeq.map { - case (original, replace) => - val cond = if (original == None || original == null) { - column.is_null - } else { - column === lit(original) - } - val replacement = if (replace == None) { - lit(null) - } else { - lit(replace) - } - (cond, replacement) - } - var caseWhen = when(conditionReplacement.head._1, conditionReplacement.head._2) - conditionReplacement.tail.foreach { - case (cond, replace) => caseWhen = caseWhen.when(cond, replace) + if (replacement.isEmpty) { + df + } else { + val columns = df.output.map { field => + if (quoteName(field.name) == quoteName(colName)) { + val conditionReplacement = replacement.toSeq.map { + case (original, replace) => + val cond = if (original == None || original == null) { + column.is_null + } else { + column === lit(original) + } + val replacement = if (replace == None) { + lit(null) + } else { + lit(replace) + } + (cond, replacement) + } + var caseWhen = when(conditionReplacement.head._1, conditionReplacement.head._2) + conditionReplacement.tail.foreach { + case (cond, replace) => caseWhen = caseWhen.when(cond, replace) + } + caseWhen.otherwise(column).cast(field.dataType).as(colName) + } else { + df.col(field.name) } - caseWhen.otherwise(column).cast(field.dataType).as(colName) - } else { - df.col(field.name) } + df.select(columns) } - df.select(columns) } - } + + @inline protected def transformation(funcName: String)(func: => DataFrame): DataFrame = + DataFrame.buildMethodChain(this.df.methodChain :+ "na", funcName)(func) } diff --git a/src/main/scala/com/snowflake/snowpark/DataFrameReader.scala b/src/main/scala/com/snowflake/snowpark/DataFrameReader.scala index 94b32395..0ed9c81a 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrameReader.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrameReader.scala @@ -157,7 +157,11 @@ class DataFrameReader(session: Session) { .path(path) .format("csv") .databaseSchema(session.getFullyQualifiedCurrentSchema) - new CopyableDataFrame(session, stagedFileReader.createSnowflakePlan(), stagedFileReader) + new CopyableDataFrame( + session, + stagedFileReader.createSnowflakePlan(), + Seq(), + stagedFileReader) } /** @@ -427,6 +431,10 @@ class DataFrameReader(session: Session) { .path(path) .format(format) .databaseSchema(session.getFullyQualifiedCurrentSchema) - new CopyableDataFrame(session, stagedFileReader.createSnowflakePlan(), stagedFileReader) + new CopyableDataFrame( + session, + stagedFileReader.createSnowflakePlan(), + Seq(), + stagedFileReader) } } diff --git a/src/main/scala/com/snowflake/snowpark/DataFrameStatFunctions.scala b/src/main/scala/com/snowflake/snowpark/DataFrameStatFunctions.scala index d3f47942..e42ada91 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrameStatFunctions.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrameStatFunctions.scala @@ -263,17 +263,18 @@ final class DataFrameStatFunctions private[snowpark] (df: DataFrame) extends Log * @since 0.2.0 * @return A new DataFrame that contains the stratified sample. */ - def sampleBy[T](col: Column, fractions: Map[T, Double]): DataFrame = { - if (fractions.isEmpty) { - return df.limit(0) - } - val (k, v) = fractions.head - var resDF = df.where(col === k).sample(v) - for ((k, v) <- fractions.tail) { - resDF = resDF.unionAll(df.where(col === k).sample(v)) + def sampleBy[T](col: Column, fractions: Map[T, Double]): DataFrame = + transformation("sampleBy") { + if (fractions.isEmpty) { + return df.limit(0) + } + val (k, v) = fractions.head + var resDF = df.where(col === k).sample(v) + for ((k, v) <- fractions.tail) { + resDF = resDF.unionAll(df.where(col === k).sample(v)) + } + resDF } - resDF - } /** * Returns a DataFrame containing a stratified sample without replacement, based on a Map that @@ -304,12 +305,19 @@ final class DataFrameStatFunctions private[snowpark] (df: DataFrame) extends Log * @since 0.2.0 * @return A new DataFrame that contains the stratified sample. */ - def sampleBy[T](col: String, fractions: Map[T, Double]): DataFrame = { - sampleBy(Col(col), fractions) - } + def sampleBy[T](col: String, fractions: Map[T, Double]): DataFrame = + transformation("sampleBy") { + sampleBy(Col(col), fractions) + } @inline protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = df.session.conn.isScalaAPI - OpenTelemetry.action("DataFrameStatFunctions", funcName, isScala)(func) + OpenTelemetry.action( + "DataFrameStatFunctions", + funcName, + df.methodChainString + ".stat", + isScala)(func) } + @inline protected def transformation(funcName: String)(func: => DataFrame): DataFrame = + DataFrame.buildMethodChain(this.df.methodChain :+ "stat", funcName)(func) } diff --git a/src/main/scala/com/snowflake/snowpark/DataFrameWriter.scala b/src/main/scala/com/snowflake/snowpark/DataFrameWriter.scala index 727f5bc7..a0643f1d 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrameWriter.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrameWriter.scala @@ -391,7 +391,11 @@ class DataFrameWriter(private[snowpark] val dataFrame: DataFrame) { @inline protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = dataFrame.session.conn.isScalaAPI - OpenTelemetry.action("DataFrameWriter", funcName, isScala)(func) + OpenTelemetry.action( + "DataFrameWriter", + funcName, + this.dataFrame.methodChainString + ".writer", + isScala)(func) } } @@ -487,7 +491,11 @@ class DataFrameWriterAsyncActor private[snowpark] (writer: DataFrameWriter) { @inline protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = writer.dataFrame.session.conn.isScalaAPI - OpenTelemetry.action("DataFrameWriterAsyncActor", funcName, isScala)(func) + OpenTelemetry.action( + "DataFrameWriterAsyncActor", + funcName, + writer.dataFrame.methodChainString + ".writer.async", + isScala)(func) } } diff --git a/src/main/scala/com/snowflake/snowpark/MergeBuilder.scala b/src/main/scala/com/snowflake/snowpark/MergeBuilder.scala index 08ca51d6..146e854a 100644 --- a/src/main/scala/com/snowflake/snowpark/MergeBuilder.scala +++ b/src/main/scala/com/snowflake/snowpark/MergeBuilder.scala @@ -205,7 +205,8 @@ class MergeBuilder private[snowpark] ( @inline protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = target.session.conn.isScalaAPI - OpenTelemetry.action("MergeBuilder", funcName, isScala)(func) + OpenTelemetry.action("MergeBuilder", funcName, target.methodChainString + ".merge", isScala)( + func) } } @@ -231,6 +232,10 @@ class MergeBuilderAsyncActor private[snowpark] (mergeBuilder: MergeBuilder) { @inline protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = mergeBuilder.target.session.conn.isScalaAPI - OpenTelemetry.action("MergeBuilderAsyncActor", funcName, isScala)(func) + OpenTelemetry.action( + "MergeBuilderAsyncActor", + funcName, + mergeBuilder.target.methodChainString + ".merge.async", + isScala)(func) } } diff --git a/src/main/scala/com/snowflake/snowpark/RelationalGroupedDataFrame.scala b/src/main/scala/com/snowflake/snowpark/RelationalGroupedDataFrame.scala index 55a7cd73..13d64bbe 100644 --- a/src/main/scala/com/snowflake/snowpark/RelationalGroupedDataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/RelationalGroupedDataFrame.scala @@ -22,6 +22,8 @@ private[snowpark] object RelationalGroupedDataFrame { object GroupByType extends GroupType + object GroupByGroupingSetsType extends GroupType + object CubeType extends GroupType object RollupType extends GroupType @@ -61,7 +63,8 @@ class RelationalGroupedDataFrame private[snowpark] ( } ++ aggExprs).distinct.map(alias) groupType match { - case RelationalGroupedDataFrame.GroupByType => + case RelationalGroupedDataFrame.GroupByType | + RelationalGroupedDataFrame.GroupByGroupingSetsType => DataFrame(dataFrame.session, Aggregate(groupingExprs, aliasedAgg, dataFrame.plan)) case RelationalGroupedDataFrame.RollupType => DataFrame( @@ -125,8 +128,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @return a [[DataFrame]] * @since 0.1.0 */ - def agg(expr: (Column, String), exprs: (Column, String)*): DataFrame = + def agg(expr: (Column, String), exprs: (Column, String)*): DataFrame = transformation("agg") { agg(expr +: exprs) + } /** * Returns a DataFrame with computed aggregates. The first element @@ -148,8 +152,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @return a [[DataFrame]] * @since 0.2.0 */ - def agg(exprs: Seq[(Column, String)]): DataFrame = + def agg(exprs: Seq[(Column, String)]): DataFrame = transformation("agg") { toDF(exprs.map { case (col, expr) => strToExpr(expr)(col.expr) }) + } /** * Returns a DataFrame with aggregated computed according to the supplied @@ -166,7 +171,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @return a [[DataFrame]] * @since 0.1.0 */ - def agg(expr: Column, exprs: Column*): DataFrame = agg(expr +: exprs) + def agg(expr: Column, exprs: Column*): DataFrame = transformation("agg") { + agg(expr +: exprs) + } /** * Returns a DataFrame with aggregated computed according to the supplied @@ -183,7 +190,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @return a [[DataFrame]] * @since 0.2.0 */ - def agg[T: ClassTag](exprs: Seq[Column]): DataFrame = toDF(exprs.map(_.expr)) + def agg[T: ClassTag](exprs: Seq[Column]): DataFrame = transformation("agg") { + toDF(exprs.map(_.expr)) + } /** * Returns a DataFrame with aggregated computed according to the supplied @@ -193,7 +202,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @return a [[DataFrame]] * @since 0.9.0 */ - def agg(exprs: Array[Column]): DataFrame = agg(exprs.toSeq) + def agg(exprs: Array[Column]): DataFrame = transformation("agg") { + agg(exprs.toSeq) + } /** * Returns a DataFrame with computed aggregates. The first element @@ -216,10 +227,11 @@ class RelationalGroupedDataFrame private[snowpark] ( * @return a [[DataFrame]] * @since 0.1.0 */ - def agg(exprs: Map[Column, String]): DataFrame = + def agg(exprs: Map[Column, String]): DataFrame = transformation("agg") { toDF(exprs.map { case (col, expr) => strToExpr(expr)(col.expr) }.toSeq) + } /** * Return the average for the specified numeric columns. @@ -227,7 +239,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @since 0.4.0 * @return a [[DataFrame]] */ - def avg(cols: Column*): DataFrame = nonEmptyArgumentFunction("avg", cols) + def avg(cols: Column*): DataFrame = transformation("avg") { + nonEmptyArgumentFunction("avg", cols) + } /** * Return the average for the specified numeric columns. Alias of avg @@ -235,7 +249,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @since 0.4.0 * @return a [[DataFrame]] */ - def mean(cols: Column*): DataFrame = avg(cols: _*) + def mean(cols: Column*): DataFrame = transformation("mean") { + avg(cols: _*) + } /** * Return the sum for the specified numeric columns. @@ -243,7 +259,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @since 0.1.0 * @return a [[DataFrame]] */ - def sum(cols: Column*): DataFrame = nonEmptyArgumentFunction("sum", cols) + def sum(cols: Column*): DataFrame = transformation("sum") { + nonEmptyArgumentFunction("sum", cols) + } /** * Return the median for the specified numeric columns. @@ -251,7 +269,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @since 0.5.0 * @return A [[DataFrame]] */ - def median(cols: Column*): DataFrame = nonEmptyArgumentFunction("median", cols) + def median(cols: Column*): DataFrame = transformation("median") { + nonEmptyArgumentFunction("median", cols) + } /** * Return the min for the specified numeric columns. @@ -259,7 +279,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @since 0.1.0 * @return A [[DataFrame]] */ - def min(cols: Column*): DataFrame = nonEmptyArgumentFunction("min", cols) + def min(cols: Column*): DataFrame = transformation("min") { + nonEmptyArgumentFunction("min", cols) + } /** * Return the max for the specified numeric columns. @@ -267,7 +289,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @since 0.4.0 * @return A [[DataFrame]] */ - def max(cols: Column*): DataFrame = nonEmptyArgumentFunction("max", cols) + def max(cols: Column*): DataFrame = transformation("max") { + nonEmptyArgumentFunction("max", cols) + } /** * Returns non-deterministic values for the specified columns. @@ -275,7 +299,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @since 0.12.0 * @return A [[DataFrame]] */ - def any_value(cols: Column*): DataFrame = nonEmptyArgumentFunction("any_value", cols) + def any_value(cols: Column*): DataFrame = transformation("any_value") { + nonEmptyArgumentFunction("any_value", cols) + } /** * Return the number of rows for each group. @@ -283,7 +309,9 @@ class RelationalGroupedDataFrame private[snowpark] ( * @since 0.1.0 * @return A [[DataFrame]] */ - def count(): DataFrame = toDF(Seq(Alias(functions.builtin("count")(Literal(1)).expr, "count"))) + def count(): DataFrame = transformation("count") { + toDF(Seq(Alias(functions.builtin("count")(Literal(1)).expr, "count"))) + } /** * Computes the builtin aggregate 'aggName' over the specified columns. @@ -299,7 +327,7 @@ class RelationalGroupedDataFrame private[snowpark] ( * @return A [[DataFrame]] * */ - def builtin(aggName: String)(cols: Column*): DataFrame = { + def builtin(aggName: String)(cols: Column*): DataFrame = transformation("builtin") { toDF(cols.map(_.expr).map(expr => functions.builtin(aggName)(expr).expr)) } @@ -311,4 +339,10 @@ class RelationalGroupedDataFrame private[snowpark] ( } } + @inline private def transformation(funcName: String)(func: => DataFrame): DataFrame = { + val typeName = groupType.toString.head.toLower + groupType.toString.tail + val name = s"$typeName.$funcName" + DataFrame.buildMethodChain(dataFrame.methodChain, name)(func) + } + } diff --git a/src/main/scala/com/snowflake/snowpark/Updatable.scala b/src/main/scala/com/snowflake/snowpark/Updatable.scala index b89d8f2a..ca992166 100644 --- a/src/main/scala/com/snowflake/snowpark/Updatable.scala +++ b/src/main/scala/com/snowflake/snowpark/Updatable.scala @@ -7,7 +7,7 @@ import scala.reflect.ClassTag private[snowpark] object Updatable extends Logging { def apply(tableName: String, session: Session): Updatable = - new Updatable(tableName, session) + new Updatable(tableName, session, DataFrame.methodChainCache.value) private[snowpark] def getUpdateResult(rows: Array[Row]): UpdateResult = UpdateResult(rows.head.getLong(0), rows.head.getLong(1)) @@ -52,8 +52,12 @@ case class DeleteResult(rowsDeleted: Long) */ class Updatable private[snowpark] ( private[snowpark] val tableName: String, - override private[snowpark] val session: Session) - extends DataFrame(session, session.analyzer.resolve(UnresolvedRelation(tableName))) { + override private[snowpark] val session: Session, + override private[snowpark] val methodChain: Seq[String]) + extends DataFrame( + session, + session.analyzer.resolve(UnresolvedRelation(tableName)), + methodChain) { /** * Updates all rows in the Updatable with specified assignments and returns a [[UpdateResult]], @@ -329,7 +333,7 @@ class Updatable private[snowpark] ( * @group basic */ override def clone: Updatable = action("clone", 2) { - new Updatable(tableName, session) + new Updatable(tableName, session, Seq()) } /** @@ -353,12 +357,12 @@ class Updatable private[snowpark] ( @inline override protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = this.session.conn.isScalaAPI - OpenTelemetry.action("Updatable", funcName, isScala)(func) + OpenTelemetry.action("Updatable", funcName, methodChainString, isScala)(func) } @inline protected def action[T](funcName: String, javaOffset: Int)(func: => T): T = { val isScala: Boolean = this.session.conn.isScalaAPI - OpenTelemetry.action("Updatable", funcName, isScala, javaOffset)(func) + OpenTelemetry.action("Updatable", funcName, methodChainString, isScala, javaOffset)(func) } } @@ -496,11 +500,20 @@ class UpdatableAsyncActor private[snowpark] (updatable: Updatable) @inline override protected def action[T](funcName: String)(func: => T): T = { val isScala: Boolean = updatable.session.conn.isScalaAPI - OpenTelemetry.action("UpdatableAsyncActor", funcName, isScala)(func) + OpenTelemetry.action( + "UpdatableAsyncActor", + funcName, + updatable.methodChainString + ".async", + isScala)(func) } @inline protected def action[T](funcName: String, javaOffset: Int)(func: => T): T = { val isScala: Boolean = updatable.session.conn.isScalaAPI - OpenTelemetry.action("UpdatableAsyncActor", funcName, isScala, javaOffset)(func) + OpenTelemetry.action( + "UpdatableAsyncActor", + funcName, + updatable.methodChainString + ".async", + isScala, + javaOffset)(func) } } diff --git a/src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala b/src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala index 34de4dd3..6d9d4248 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala @@ -1,6 +1,5 @@ package com.snowflake.snowpark.internal -import com.snowflake.snowpark.DataFrame import io.opentelemetry.api.GlobalOpenTelemetry import io.opentelemetry.api.trace.{Span, StatusCode} @@ -9,27 +8,30 @@ import scala.util.DynamicVariable object OpenTelemetry extends Logging { // only report the top function info in case of recursion. - val spanInfo = new DynamicVariable[Option[SpanInfo]](None) + private val actionInfo = new DynamicVariable[Option[ActionInfo]](None) // wrapper of all action functions - def action[T](className: String, funcName: String, isScala: Boolean, javaOffSet: Int = 0)( - func: => T): T = { + def action[T]( + className: String, + funcName: String, + methodChain: String, + isScala: Boolean, + javaOffSet: Int = 0)(func: => T): T = { try { - spanInfo.withValue[T](spanInfo.value match { + actionInfo.withValue[T](actionInfo.value match { // empty info means this is the entry of the recursion case None => val stacks = Thread.currentThread().getStackTrace - val methodChain = "" val index = if (isScala) 4 else 5 + javaOffSet val fileName = stacks(index).getFileName val lineNumber = stacks(index).getLineNumber - Some(SpanInfo(className, funcName, fileName, lineNumber, methodChain)) + Some(ActionInfo(className, funcName, fileName, lineNumber, s"$methodChain.$funcName")) // if value is not empty, this function call should be recursion. // do not issue new SpanInfo, use the info inherited from previous. case other => other }) { val result: T = func - OpenTelemetry.emit(spanInfo.value.get) + OpenTelemetry.emit(actionInfo.value.get) result } } catch { @@ -40,28 +42,15 @@ object OpenTelemetry extends Logging { } // class name format: snow.snowpark. // method chain: Dataframe.filter.join.select.collect - def emit( - className: String, - funcName: String, - fileName: String, - lineNumber: Int, - methodChain: String): Unit = - emit(className, funcName) { span => + def emit(spanInfo: ActionInfo): Unit = + emit(spanInfo.className, spanInfo.funcName) { span => { - span.setAttribute("code.filepath", fileName) - span.setAttribute("code.lineno", lineNumber) - span.setAttribute("method.chain", methodChain) + span.setAttribute("code.filepath", spanInfo.fileName) + span.setAttribute("code.lineno", spanInfo.lineNumber) + span.setAttribute("method.chain", spanInfo.methodChain) } } - def emit(spanInfo: SpanInfo): Unit = - emit( - spanInfo.className, - spanInfo.funcName, - spanInfo.fileName, - spanInfo.lineNumber, - spanInfo.methodChain) - def reportError(className: String, funcName: String, error: Throwable): Unit = emit(className, funcName) { span => { @@ -90,13 +79,9 @@ object OpenTelemetry extends Logging { } } - // todo: Snow-1480779 - def buildMethodChain(funcName: String, df: DataFrame): String = { - "" - } } -case class SpanInfo( +case class ActionInfo( className: String, funcName: String, fileName: String, diff --git a/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetryEnabled.java b/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetryEnabled.java index 2a37e739..bd68c8d0 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetryEnabled.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetryEnabled.java @@ -42,7 +42,7 @@ protected void checkSpan( assert Objects.equals( span.getAttributes().get(AttributeKey.longKey("code.lineno")), (long) lineNumber); assert Objects.equals( - span.getAttributes().get(AttributeKey.stringKey("code.chain")), methodChain); + span.getAttributes().get(AttributeKey.stringKey("method.chain")), methodChain); testSpanExporter.reset(); } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetrySuite.java b/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetrySuite.java index 997b194b..0acefee3 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetrySuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaOpenTelemetrySuite.java @@ -14,39 +14,39 @@ public class JavaOpenTelemetrySuite extends JavaOpenTelemetryEnabled { public void cacheResult() { DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.cacheResult(); - checkSpan("snow.snowpark.DataFrame", "cacheResult", null); + checkSpan("snow.snowpark.DataFrame", "cacheResult"); } @Test public void count() { DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.count(); - checkSpan("snow.snowpark.DataFrame", "count", null); + checkSpan("snow.snowpark.DataFrame", "count"); } @Test public void collect() { DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.collect(); - checkSpan("snow.snowpark.DataFrame", "collect", null); + checkSpan("snow.snowpark.DataFrame", "collect"); } @Test public void toLocalIterator() { DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.toLocalIterator(); - checkSpan("snow.snowpark.DataFrame", "toLocalIterator", null); + checkSpan("snow.snowpark.DataFrame", "toLocalIterator"); } @Test public void show() { DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.show(); - checkSpan("snow.snowpark.DataFrame", "show", null); + checkSpan("snow.snowpark.DataFrame", "show"); df.show(1); - checkSpan("snow.snowpark.DataFrame", "show", null); + checkSpan("snow.snowpark.DataFrame", "show"); df.show(1, 100); - checkSpan("snow.snowpark.DataFrame", "show", null); + checkSpan("snow.snowpark.DataFrame", "show"); } @Test @@ -55,10 +55,10 @@ public void createOrReplaceView() { String name = randomName(); try { df.createOrReplaceView(name); - checkSpan("snow.snowpark.DataFrame", "createOrReplaceView", null); + checkSpan("snow.snowpark.DataFrame", "createOrReplaceView"); String[] names = {name}; df.createOrReplaceView(names); - checkSpan("snow.snowpark.DataFrame", "createOrReplaceView", null); + checkSpan("snow.snowpark.DataFrame", "createOrReplaceView"); } finally { dropView(name); } @@ -70,10 +70,10 @@ public void createOrReplaceTempView() { String name = randomName(); try { df.createOrReplaceTempView(name); - checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView", null); + checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView"); String[] names = {name}; df.createOrReplaceTempView(names); - checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView", null); + checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView"); } finally { dropView(name); } @@ -83,9 +83,9 @@ public void createOrReplaceTempView() { public void first() { DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.first(); - checkSpan("snow.snowpark.DataFrame", "first", null); + checkSpan("snow.snowpark.DataFrame", "first"); df.first(1); - checkSpan("snow.snowpark.DataFrame", "first", null); + checkSpan("snow.snowpark.DataFrame", "first"); } @Test @@ -93,40 +93,44 @@ public void randomSplit() { DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); double[] weight = {0.5, 0.5}; df.randomSplit(weight); - checkSpan("snow.snowpark.DataFrame", "randomSplit", null); + checkSpan("snow.snowpark.DataFrame", "randomSplit"); } @Test public void DataFrameAsyncActor() { + String className = "snow.snowpark.DataFrameAsyncActor"; DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.async().collect(); - checkSpan("snow.snowpark.DataFrameAsyncActor", "collect", null); + checkSpan(className, "collect", "DataFrame.async.collect"); df.async().toLocalIterator(); - checkSpan("snow.snowpark.DataFrameAsyncActor", "toLocalIterator", null); + checkSpan(className, "toLocalIterator", "DataFrame.async.toLocalIterator"); df.async().count(); - checkSpan("snow.snowpark.DataFrameAsyncActor", "count", null); + checkSpan(className, "count", "DataFrame.async.count"); } @Test public void dataFrameStatFunctionsCorr() { DataFrame df = getSession().sql("select * from values(0.1, 0.5) as t(a, b)"); + String className = "snow.snowpark.DataFrameStatFunctions"; df.stat().corr("a", "b"); - checkSpan("snow.snowpark.DataFrameStatFunctions", "corr", null); + checkSpan(className, "corr", "DataFrame.stat.corr"); } @Test public void dataFrameStatFunctionsCov() { DataFrame df = getSession().sql("select * from values(0.1, 0.5) as t(a, b)"); + String className = "snow.snowpark.DataFrameStatFunctions"; df.stat().cov("a", "b"); - checkSpan("snow.snowpark.DataFrameStatFunctions", "cov", null); + checkSpan(className, "cov", "DataFrame.stat.cov"); } @Test public void dataFrameStatFunctionsApproxQuantile() { DataFrame df = getSession().sql("select * from values(1), (2) as t(a)"); double[] values = {0, 0.1, 0.4, 0.6, 1}; + String className = "snow.snowpark.DataFrameStatFunctions"; df.stat().approxQuantile("a", values); - checkSpan("snow.snowpark.DataFrameStatFunctions", "approxQuantile", null); + checkSpan(className, "approxQuantile", "DataFrame.stat.approxQuantile"); } @Test @@ -134,15 +138,17 @@ public void dataFrameStatFunctionsApproxQuantile2() { DataFrame df = getSession().sql("select * from values(0.1, 0.5) as t(a, b)"); double[] values = {0, 0.1, 0.6}; String[] cols = {"a", "b"}; + String className = "snow.snowpark.DataFrameStatFunctions"; df.stat().approxQuantile(cols, values); - checkSpan("snow.snowpark.DataFrameStatFunctions", "approxQuantile", null); + checkSpan(className, "approxQuantile", "DataFrame.stat.approxQuantile"); } @Test public void dataFrameStatFunctionsCrosstab() { DataFrame df = getSession().sql("select * from values(0.1, 0.5) as t(a, b)"); + String className = "snow.snowpark.DataFrameStatFunctions"; df.stat().crosstab("a", "b"); - checkSpan("snow.snowpark.DataFrameStatFunctions", "crosstab", null); + checkSpan(className, "crosstab", "DataFrame.stat.crosstab"); } @Test @@ -152,8 +158,9 @@ public void dataFrameWriterCsv() { createTempStage(name); testSpanExporter.reset(); DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); + String className = "snow.snowpark.DataFrameWriter"; df.write().csv("@" + name + "/csv"); - checkSpan("snow.snowpark.DataFrameWriter", "csv", null); + checkSpan(className, "csv", "DataFrame.writer.csv"); } finally { dropStage(name); } @@ -165,12 +172,13 @@ public void dataFrameWriterJson() { try { createTempStage(name); testSpanExporter.reset(); + String className = "snow.snowpark.DataFrameWriter"; DataFrame df = getSession().sql("select * from values(1, 2) as t(a, b)"); DataFrame df2 = df.select( com.snowflake.snowpark_java.Functions.array_construct(df.col("a"), df.col("b"))); df2.write().json("@" + name + "/json"); - checkSpan("snow.snowpark.DataFrameWriter", "json", null); + checkSpan(className, "json", "DataFrame.select.writer.json"); } finally { dropStage(name); } @@ -182,9 +190,10 @@ public void dataFrameWriterParquet() { try { createTempStage(name); testSpanExporter.reset(); + String className = "snow.snowpark.DataFrameWriter"; DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.write().parquet("@" + name + "/parquet"); - checkSpan("snow.snowpark.DataFrameWriter", "parquet", null); + checkSpan(className, "parquet", "DataFrame.writer.parquet"); } finally { dropStage(name); } @@ -194,9 +203,10 @@ public void dataFrameWriterParquet() { public void dataFrameWriterSaveAsTable() { String name = randomName(); DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); + String className = "snow.snowpark.DataFrameWriter"; try { df.write().saveAsTable(name); - checkSpan("snow.snowpark.DataFrameWriter", "saveAsTable", null); + checkSpan(className, "saveAsTable", "DataFrame.writer.saveAsTable"); } finally { dropTable(name); } @@ -204,7 +214,7 @@ public void dataFrameWriterSaveAsTable() { String[] names = {name}; testSpanExporter.reset(); df.write().saveAsTable(names); - checkSpan("snow.snowpark.DataFrameWriter", "saveAsTable", null); + checkSpan(className, "saveAsTable", "DataFrame.writer.saveAsTable"); } finally { dropTable(name); } @@ -214,9 +224,10 @@ public void dataFrameWriterSaveAsTable() { public void dataFrameWriterAsyncActorSaveAsTable() { String name = randomName(); DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); + String className = "snow.snowpark.DataFrameWriterAsyncActor"; try { df.write().async().saveAsTable(name).getResult(); - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "saveAsTable", null); + checkSpan(className, "saveAsTable", "DataFrame.writer.async.saveAsTable"); } finally { dropTable(name); } @@ -224,7 +235,7 @@ public void dataFrameWriterAsyncActorSaveAsTable() { String[] names = {name}; testSpanExporter.reset(); df.write().async().saveAsTable(names).getResult(); - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "saveAsTable", null); + checkSpan(className, "saveAsTable", "DataFrame.writer.async.saveAsTable"); } finally { dropTable(name); } @@ -236,9 +247,10 @@ public void dataFrameWriterAsyncActorCsv() { try { createTempStage(name); testSpanExporter.reset(); + String className = "snow.snowpark.DataFrameWriterAsyncActor"; DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.write().async().csv("@" + name + "/csv").getResult(); - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "csv", null); + checkSpan(className, "csv", "DataFrame.writer.async.csv"); } finally { dropStage(name); } @@ -250,12 +262,13 @@ public void dataFrameWriterAsyncActorJson() { try { createTempStage(name); testSpanExporter.reset(); + String className = "snow.snowpark.DataFrameWriterAsyncActor"; DataFrame df = getSession().sql("select * from values(1, 2) as t(a, b)"); DataFrame df2 = df.select( com.snowflake.snowpark_java.Functions.array_construct(df.col("a"), df.col("b"))); df2.write().async().json("@" + name + "/json").getResult(); - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "json", null); + checkSpan(className, "json", "DataFrame.select.writer.async.json"); } finally { dropStage(name); } @@ -267,9 +280,10 @@ public void dataFrameWriterAsyncActorParquet() { try { createTempStage(name); testSpanExporter.reset(); + String className = "snow.snowpark.DataFrameWriterAsyncActor"; DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(num)"); df.write().async().parquet("@" + name + "/parquet").getResult(); - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "parquet", null); + checkSpan(className, "parquet", "DataFrame.writer.async.parquet"); } finally { dropStage(name); } @@ -295,14 +309,14 @@ public void copyableDataFrame() { .schema(schema) .csv("@" + stageName + "/" + TestFiles.testFileCsv) .copyInto(tableName); - checkSpan(className, "copyInto", null); + checkSpan(className, "copyInto"); Column[] transformation = {Functions.col("$1"), Functions.col("$2"), Functions.col("$3")}; getSession() .read() .schema(schema) .csv("@" + stageName + "/" + TestFiles.testFileCsv) .copyInto(tableName, transformation); - checkSpan(className, "copyInto", null); + checkSpan(className, "copyInto"); Map options = new HashMap<>(); options.put("skip_header", 1); options.put("FORCE", "true"); @@ -311,16 +325,16 @@ public void copyableDataFrame() { .schema(schema) .csv("@" + stageName + "/" + TestFiles.testFileCsv) .copyInto(tableName, transformation, options); - checkSpan(className, "copyInto", null); + checkSpan(className, "copyInto"); String[] columns = {"a", "b", "c"}; getSession() .read() .schema(schema) .csv("@" + stageName + "/" + TestFiles.testFileCsv) .copyInto(tableName, columns, transformation, options); - checkSpan(className, "copyInto", null); + checkSpan(className, "copyInto"); getSession().read().schema(schema).csv("@" + stageName + "/" + TestFiles.testFileCsv).clone(); - checkSpan(className, "clone", null); + checkSpan(className, "clone"); } finally { dropTable(tableName); dropStage(stageName); @@ -349,7 +363,7 @@ public void copyableDataFrameAsyncActor() { .csv("@" + stageName + "/" + TestFiles.testFileCsv) .async(); df1.copyInto(tableName).getResult(); - checkSpan(className, "copyInto", null); + checkSpan(className, "copyInto", "DataFrame.async.copyInto"); Column[] transformation = {Functions.col("$1"), Functions.col("$2"), Functions.col("$3")}; CopyableDataFrameAsyncActor df2 = getSession() @@ -358,7 +372,7 @@ public void copyableDataFrameAsyncActor() { .csv("@" + stageName + "/" + TestFiles.testFileCsv) .async(); df2.copyInto(tableName, transformation).getResult(); - checkSpan(className, "copyInto", null); + checkSpan(className, "copyInto", "DataFrame.async.copyInto"); Map options = new HashMap<>(); options.put("skip_header", 1); options.put("FORCE", "true"); @@ -369,7 +383,7 @@ public void copyableDataFrameAsyncActor() { .csv("@" + stageName + "/" + TestFiles.testFileCsv) .async(); df3.copyInto(tableName, transformation, options).getResult(); - checkSpan(className, "copyInto", null); + checkSpan(className, "copyInto", "DataFrame.async.copyInto"); String[] columns = {"a", "b", "c"}; CopyableDataFrameAsyncActor df4 = getSession() @@ -378,7 +392,7 @@ public void copyableDataFrameAsyncActor() { .csv("@" + stageName + "/" + TestFiles.testFileCsv) .async(); df4.copyInto(tableName, columns, transformation, options).getResult(); - checkSpan(className, "copyInto", null); + checkSpan(className, "copyInto", "DataFrame.async.copyInto"); } finally { dropTable(tableName); dropStage(stageName); @@ -404,31 +418,31 @@ public void updatable() { Map map1 = new HashMap<>(); map1.put("col1", Functions.lit(3)); getSession().table(tableName).update(map); - checkSpan(className, "update", null); + checkSpan(className, "update"); getSession().table(tableName).updateColumn(map1); - checkSpan(className, "update", null); + checkSpan(className, "update"); getSession() .table(tableName) .update(map, Functions.col("col3").equal_to(Functions.lit(true))); - checkSpan(className, "update", null); + checkSpan(className, "update"); getSession() .table(tableName) .updateColumn(map1, Functions.col("col3").equal_to(Functions.lit(true))); - checkSpan(className, "update", null); + checkSpan(className, "update"); getSession().table(tableName).update(map, Functions.col("col1").equal_to(df.col("a")), df); - checkSpan(className, "update", null); + checkSpan(className, "update"); getSession() .table(tableName) .updateColumn(map1, Functions.col("col1").equal_to(df.col("a")), df); - checkSpan(className, "update", null); + checkSpan(className, "update"); getSession().table(tableName).delete(); - checkSpan(className, "delete", null); + checkSpan(className, "delete"); getSession().table(tableName).delete(Functions.col("col1").equal_to(Functions.lit(1))); - checkSpan(className, "delete", null); + checkSpan(className, "delete"); getSession().table(tableName).delete(Functions.col("col1").equal_to(df.col("a")), df); - checkSpan(className, "delete", null); + checkSpan(className, "delete"); getSession().table(tableName).clone(); - checkSpan(className, "clone", null); + checkSpan(className, "clone"); } finally { dropTable(tableName); } @@ -454,23 +468,23 @@ public void updatableAsyncActor() { map1.put("col1", Functions.lit(3)); UpdatableAsyncActor df1 = getSession().table(tableName).async(); df1.update(map).getResult(); - checkSpan(className, "update", null); + checkSpan(className, "update", "DataFrame.async.update"); df1.updateColumn(map1).getResult(); - checkSpan(className, "update", null); + checkSpan(className, "update", "DataFrame.async.update"); df1.update(map, Functions.col("col3").equal_to(Functions.lit(true))).getResult(); - checkSpan(className, "update", null); + checkSpan(className, "update", "DataFrame.async.update"); df1.updateColumn(map1, Functions.col("col3").equal_to(Functions.lit(true))).getResult(); - checkSpan(className, "update", null); + checkSpan(className, "update", "DataFrame.async.update"); df1.update(map, Functions.col("col1").equal_to(df.col("a")), df); - checkSpan(className, "update", null); + checkSpan(className, "update", "DataFrame.async.update"); df1.updateColumn(map1, Functions.col("col1").equal_to(df.col("a")), df); - checkSpan(className, "update", null); + checkSpan(className, "update", "DataFrame.async.update"); df1.delete().getResult(); - checkSpan(className, "delete", null); + checkSpan(className, "delete", "DataFrame.async.delete"); df1.delete(Functions.col("col1").equal_to(Functions.lit(1))).getResult(); - checkSpan(className, "delete", null); + checkSpan(className, "delete", "DataFrame.async.delete"); df1.delete(Functions.col("col1").equal_to(df.col("a")), df).getResult(); - checkSpan(className, "delete", null); + checkSpan(className, "delete", "DataFrame.async.delete"); } finally { dropTable(tableName); } @@ -491,13 +505,14 @@ public void mergeBuilder() { testSpanExporter.reset(); Map assignments = new HashMap<>(); assignments.put(Functions.col("col1"), df.col("b")); + String className = "snow.snowpark.MergeBuilder"; getSession() .table(tableName) .merge(df, Functions.col("col1").equal_to(df.col("a"))) .whenMatched() .update(assignments) .collect(); - checkSpan("snow.snowpark.MergeBuilder", "collect", null); + checkSpan(className, "collect", "DataFrame.merge.collect"); } finally { dropTable(tableName); } @@ -525,8 +540,9 @@ public void mergeBuilderAsyncActor() { .whenMatched() .update(assignments) .async(); + String className = "snow.snowpark.MergeBuilderAsyncActor"; builderAsyncActor.collect().getResult(); - checkSpan("snow.snowpark.MergeBuilderAsyncActor", "collect", null); + checkSpan(className, "collect", "DataFrame.merge.async.collect"); } finally { dropTable(tableName); } @@ -538,4 +554,15 @@ private void checkSpan(String className, String funcName, String methodChain) { checkSpan( className, funcName, "JavaOpenTelemetrySuite.java", file.getLineNumber() - 1, methodChain); } + + private void checkSpan(String className, String funcName) { + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + StackTraceElement file = stack[2]; + checkSpan( + className, + funcName, + "JavaOpenTelemetrySuite.java", + file.getLineNumber() - 1, + "DataFrame." + funcName); + } } diff --git a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala index c0f1ed3e..dfb5f0f0 100644 --- a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala +++ b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala @@ -91,6 +91,8 @@ class JavaScalaAPISuite extends FunSuite { class2Only = Set( // package private functions "getUnaliased", + "methodChainCache", + "buildMethodChain", "generatePrefix") ++ scalaCaseClassFunctions)) } diff --git a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala index a28d167a..c19353d1 100644 --- a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala @@ -649,7 +649,7 @@ class APIInternalSuite extends TestData { None, supportAsyncMode = true) - new DataFrame(session, session.analyzer.resolve(plan)) + new DataFrame(session, session.analyzer.resolve(plan), Seq()) } // This test DataFrame can't be defined in TestData, @@ -673,7 +673,7 @@ class APIInternalSuite extends TestData { None, supportAsyncMode = true) - new DataFrame(session, session.analyzer.resolve(plan2)) + new DataFrame(session, session.analyzer.resolve(plan2), Seq()) } test("test col(DataFrame) with multiple queries") { @@ -730,7 +730,7 @@ class APIInternalSuite extends TestData { session, None, supportAsyncMode = true) - val df = new DataFrame(session, plan) + val df = new DataFrame(session, plan, Seq()) df.explain() val explainString1 = df.explainString diff --git a/src/test/scala/com/snowflake/snowpark/MethodChainSuite.scala b/src/test/scala/com/snowflake/snowpark/MethodChainSuite.scala new file mode 100644 index 00000000..480f9c1a --- /dev/null +++ b/src/test/scala/com/snowflake/snowpark/MethodChainSuite.scala @@ -0,0 +1,260 @@ +package com.snowflake.snowpark + +import com.snowflake.snowpark.functions._ +import com.snowflake.snowpark.types.{StringType, StructField, StructType} +import com.snowflake.snowpark.udtf.UDTF1 + +import scala.collection.mutable + +class MethodChainSuite extends TestData { + private val df1 = session.sql("select * from values(1,2,3) as T(a, b, c)") + + private def checkMethodChain(df: DataFrame, methodNames: String*): Unit = { + val methodChain = df.methodChain + assert(methodChain == methodNames) + } + + test("new dataframe") { + checkMethodChain(df1) + } + + test("clone") { + checkMethodChain(df1.clone, "clone") + } + + test("toDF") { + checkMethodChain( + df1 + .toDF("a1", "b1", "c1") + .toDF(Seq("a2", "b2", "c2")) + .toDF(Array("a3", "b3", "c3")), + "toDF", + "toDF", + "toDF") + } + + test("sort") { + checkMethodChain( + df1 + .sort(col("a")) + .sort(Seq(col("a"))) + .sort(Array(col("a"))), + "sort", + "sort", + "sort") + } + + test("alias") { + checkMethodChain(df1.alias("a"), "alias") + } + + test("select") { + checkMethodChain( + df1 + .select(col("a")) + .select(Seq(col("a"))) + .select(Array(col("a"))) + .select("a") + .select(Seq("a")) + .select(Array("a")), + "select", + "select", + "select", + "select", + "select", + "select") + } + + test("drop") { + checkMethodChain(df1.drop("a").drop(Seq("b")), "drop", "drop") + checkMethodChain(df1.drop(Array("a")).drop(col("b")), "drop", "drop") + checkMethodChain(df1.drop(Seq(col("a"))).drop(Array(col("b"))), "drop", "drop") + } + + test("filter") { + checkMethodChain(df1.filter(col("a") < 0), "filter") + } + + test("where") { + checkMethodChain(df1.where(col("a") < 0), "where") + } + + test("agg") { + checkMethodChain(df1.agg("a" -> "max"), "agg") + checkMethodChain(df1.agg(Seq("a" -> "max")), "agg") + checkMethodChain(df1.agg(max(col("a"))), "agg") + checkMethodChain(df1.agg(Seq(max(col("a")))), "agg") + checkMethodChain(df1.agg(Array(max(col("a")))), "agg") + } + + test("rollup.agg") { + checkMethodChain(df1.rollup(col("a")).agg(col("a") -> "max"), "rollup.agg") + checkMethodChain(df1.rollup(Seq(col("a"))).agg(Seq(col("a") -> "max")), "rollup.agg") + checkMethodChain(df1.rollup(Array(col("a"))).agg(max(col("a"))), "rollup.agg") + checkMethodChain(df1.rollup("a").agg(Seq(max(col("a")))), "rollup.agg") + checkMethodChain(df1.rollup(Seq("a")).agg(Array(max(col("a")))), "rollup.agg") + checkMethodChain(df1.rollup(Array("a")).agg(Map(col("a") -> "max")), "rollup.agg") + } + + test("groupBy") { + checkMethodChain(df1.groupBy(col("a")).avg(col("a")), "groupBy.avg") + checkMethodChain(df1.groupBy().mean(col("a")), "groupBy.mean") + checkMethodChain(df1.groupBy(Seq(col("a"))).sum(col("a")), "groupBy.sum") + checkMethodChain(df1.groupBy(Array(col("a"))).median(col("a")), "groupBy.median") + checkMethodChain(df1.groupBy("a").min(col("a")), "groupBy.min") + checkMethodChain(df1.groupBy(Seq("a")).max(col("a")), "groupBy.max") + checkMethodChain(df1.groupBy(Array("a")).any_value(col("a")), "groupBy.any_value") + } + + test("groupByGroupingSets") { + checkMethodChain( + df1.groupByGroupingSets(GroupingSets(Set(col("a")))).count(), + "groupByGroupingSets.count") + checkMethodChain( + df1 + .groupByGroupingSets(Seq(GroupingSets(Set(col("a"))))) + .builtin("count")(col("a")), + "groupByGroupingSets.builtin") + } + + test("cube") { + checkMethodChain(df1.cube(col("a")).count(), "cube.count") + checkMethodChain(df1.cube(Seq(col("a"))).count(), "cube.count") + checkMethodChain(df1.cube(Array(col("a"))).count(), "cube.count") + checkMethodChain(df1.cube("a").count(), "cube.count") + checkMethodChain(df1.cube(Seq("a")).count(), "cube.count") + } + + test("distinct") { + checkMethodChain(df1.distinct(), "distinct") + } + + test("dropDuplicates") { + checkMethodChain(df1.dropDuplicates(), "dropDuplicates") + } + + test("pivot") { + checkMethodChain(df1.pivot(col("a"), Seq(1, 2, 3)).count(), "pivot.count") + checkMethodChain(df1.pivot("a", Seq(1, 2, 3)).count(), "pivot.count") + } + + test("limit") { + checkMethodChain(df1.limit(1), "limit") + } + + test("union") { + checkMethodChain(df1.union(df1.clone), "union") + checkMethodChain(df1.unionAll(df1.clone), "unionAll") + checkMethodChain(df1.unionByName(df1.clone), "unionByName") + checkMethodChain(df1.unionAllByName(df1.clone), "unionAllByName") + } + + test("intersect") { + checkMethodChain(df1.intersect(df1.clone), "intersect") + } + + test("except") { + checkMethodChain(df1.except(df1.clone), "except") + } + + test("join") { + val df2 = df1.clone + checkMethodChain(df1.join(df2), "join") + checkMethodChain(df1.join(df2, "a"), "join") + checkMethodChain(df1.join(df2, Seq("a", "b")), "join") + checkMethodChain(df1.join(df2, Seq("a", "b"), "inner"), "join") + checkMethodChain(df1.join(df2, df1("a") === df2("a")), "join") + checkMethodChain(df1.join(df2, df1("a") === df2("a"), "inner"), "join") + } + + test("join 2") { + import com.snowflake.snowpark.tableFunctions._ + val df2 = session.sql("select * from values('1,2,3') as T(a)") + checkMethodChain(df2.join(split_to_table, df2("a"), lit(",")), "join") + checkMethodChain(df2.join(split_to_table, Seq(df2("a"), lit(","))), "join") + checkMethodChain(df2.join(split_to_table, Seq(df2("a"), lit(","))), "join") + + val TableFunc1 = new UDTF1[String] { + + private val freq = new mutable.HashMap[String, Int] + + override def process(colValue: String): Iterable[Row] = { + val curValue = freq.getOrElse(colValue, 0) + freq.put(colValue, curValue + 1) + mutable.Iterable.empty + } + + override def outputSchema(): StructType = + StructType(StructField("FREQUENCIES", StringType)) + + override def endPartition(): Iterable[Row] = { + Seq(Row(freq.toString())) + } + } + import session.implicits._ + val df = Seq(("a", "b"), ("a", "c"), ("a", "b"), ("d", "e")).toDF("a", "b") + val tf = session.udtf.registerTemporary(TableFunc1) + checkMethodChain( + df.join(tf, Seq(df("b")), Seq(df("a")), Seq(df("b"))), + "select", + "toDF", + "join") + checkMethodChain( + df.join(tf, Map("arg1" -> df("b")), Seq(df("a")), Seq(df("b"))), + "select", + "toDF", + "join") + checkMethodChain( + df.join(tf(Map("arg1" -> df("b"))), Seq(df("a")), Seq(df("b"))), + "select", + "toDF", + "join") + + val df3 = session.sql("select * from values('[1,2,3]') as T(a)") + checkMethodChain(df3.join(flatten, Map("input" -> parse_json(df("a")))), "join") + + checkMethodChain(df3.join(flatten(parse_json(df("a")))), "join") + } + + test("join 3") { + checkMethodChain(df1.crossJoin(df1.clone), "crossJoin") + checkMethodChain(df1.naturalJoin(df1.clone), "naturalJoin") + checkMethodChain(df1.naturalJoin(df1.clone, "left"), "naturalJoin") + } + + test("withColumn") { + checkMethodChain(df1.withColumn("a1", lit(1)), "withColumn") + checkMethodChain(df1.withColumns(Seq("a1"), Seq(lit(1))), "withColumns") + } + + test("rename") { + checkMethodChain(df1.rename("a1", col("a")), "rename") + } + + test("sample") { + checkMethodChain(df1.sample(1), "sample") + checkMethodChain(df1.sample(0.5), "sample") + } + + test("na") { + checkMethodChain(double3.na.drop(1, Seq("a")), "na", "drop") + checkMethodChain( + nullData3.na.fill(Map("flo" -> 12.3, "int" -> 11, "boo" -> false, "str" -> "f")), + "na", + "fill") + checkMethodChain(nullData3.na.replace("flo", Map(2 -> 300, 1 -> 200)), "na", "replace") + } + + test("stat") { + checkMethodChain(df1.stat.sampleBy(col("a"), Map(1 -> 0.0, 2 -> 1.0)), "stat", "sampleBy") + checkMethodChain(df1.stat.sampleBy("a", Map(1 -> 0.0, 2 -> 1.0)), "stat", "sampleBy") + } + + test("flatten") { + val table1 = session.sql("select parse_json(value) as value from values('[1,2]') as T(value)") + checkMethodChain(table1.flatten(table1("value")), "flatten") + checkMethodChain( + table1.flatten(table1("value"), "", outer = false, recursive = false, "both"), + "flatten") + } +} diff --git a/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala b/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala index 81c2161e..4e8f6d58 100644 --- a/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala @@ -1,7 +1,7 @@ package com.snowflake.snowpark_test import com.snowflake.snowpark.{MergeResult, OpenTelemetryEnabled, SaveMode, UpdateResult} -import com.snowflake.snowpark.internal.OpenTelemetry +import com.snowflake.snowpark.internal.{OpenTelemetry, ActionInfo} import com.snowflake.snowpark.functions._ import com.snowflake.snowpark.types.{DoubleType, IntegerType, StringType, StructField, StructType} @@ -10,48 +10,48 @@ import java.util class OpenTelemetrySuite extends OpenTelemetryEnabled { test("line number - collect") { session.sql("select 1").collect() - checkSpan("snow.snowpark.DataFrame", "collect", "") + checkSpan("snow.snowpark.DataFrame", "collect") } test("line number - randomSplit") { session.sql("select * from values(1),(2),(3) as t(num)").randomSplit(Array(0.5, 0.5)) - checkSpan("snow.snowpark.DataFrame", "randomSplit", "") + checkSpan("snow.snowpark.DataFrame", "randomSplit") } test("line number - first") { val df = session.sql("select * from values(1),(2),(3) as t(num)") df.first() - checkSpan("snow.snowpark.DataFrame", "first", "") + checkSpan("snow.snowpark.DataFrame", "first") df.first(2) - checkSpan("snow.snowpark.DataFrame", "first", "") + checkSpan("snow.snowpark.DataFrame", "first") } test("line number - cacheResult") { val df = session.sql("select * from values(1),(2),(3) as t(num)") df.cacheResult() - checkSpan("snow.snowpark.DataFrame", "cacheResult", "") + checkSpan("snow.snowpark.DataFrame", "cacheResult") } test("line number - toLocalIterator") { val df = session.sql("select * from values(1),(2),(3) as t(num)") df.toLocalIterator - checkSpan("snow.snowpark.DataFrame", "toLocalIterator", "") + checkSpan("snow.snowpark.DataFrame", "toLocalIterator") } test("line number - count") { val df = session.sql("select * from values(1),(2),(3) as t(num)") df.count() - checkSpan("snow.snowpark.DataFrame", "count", "") + checkSpan("snow.snowpark.DataFrame", "count") } test("line number - show") { val df = session.sql("select * from values(1),(2),(3) as t(num)") df.show() - checkSpan("snow.snowpark.DataFrame", "show", "") + checkSpan("snow.snowpark.DataFrame", "show") df.show(1) - checkSpan("snow.snowpark.DataFrame", "show", "") + checkSpan("snow.snowpark.DataFrame", "show") df.show(1, 10) - checkSpan("snow.snowpark.DataFrame", "show", "") + checkSpan("snow.snowpark.DataFrame", "show") } test("line number - createOrReplaceView") { @@ -59,13 +59,13 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val name = randomName() try { df.createOrReplaceView(name) - checkSpan("snow.snowpark.DataFrame", "createOrReplaceView", "") + checkSpan("snow.snowpark.DataFrame", "createOrReplaceView") } finally { dropView(name) } try { df.createOrReplaceView(Seq(name)) - checkSpan("snow.snowpark.DataFrame", "createOrReplaceView", "") + checkSpan("snow.snowpark.DataFrame", "createOrReplaceView") } finally { dropView(name) } @@ -74,7 +74,7 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val list: java.util.List[String] = new util.ArrayList[String](1) list.add(name) df.createOrReplaceView(list) - checkSpan("snow.snowpark.DataFrame", "createOrReplaceView", "") + checkSpan("snow.snowpark.DataFrame", "createOrReplaceView") } finally { dropView(name) } @@ -85,13 +85,13 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val name = randomName() try { df.createOrReplaceTempView(name) - checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView", "") + checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView") } finally { dropView(name) } try { df.createOrReplaceTempView(Seq(name)) - checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView", "") + checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView") } finally { dropView(name) } @@ -100,7 +100,7 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val list: java.util.List[String] = new util.ArrayList[String](1) list.add(name) df.createOrReplaceTempView(list) - checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView", "") + checkSpan("snow.snowpark.DataFrame", "createOrReplaceTempView") } finally { dropView(name) } @@ -109,60 +109,64 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { test("line number - HasCachedResult") { val df = session.sql("select * from values(1),(2),(3) as t(num)") val cached = df.cacheResult() - checkSpan("snow.snowpark.DataFrame", "cacheResult", "") + checkSpan("snow.snowpark.DataFrame", "cacheResult") cached.cacheResult() - checkSpan("snow.snowpark.DataFrame", "cacheResult", "") + checkSpan("snow.snowpark.DataFrame", "cacheResult") } test("line number - DataFrameAsyncActor") { + val className = "snow.snowpark.DataFrameAsyncActor" val df = session.sql("select * from values(1),(2),(3) as t(num)") df.async.count() - checkSpan("snow.snowpark.DataFrameAsyncActor", "count", "") + checkSpan(className, "count", "DataFrame.async.count") df.async.collect() - checkSpan("snow.snowpark.DataFrameAsyncActor", "collect", "") + checkSpan(className, "collect", "DataFrame.async.collect") df.async.toLocalIterator() - checkSpan("snow.snowpark.DataFrameAsyncActor", "toLocalIterator", "") + checkSpan(className, "toLocalIterator", "DataFrame.async.toLocalIterator") } test("line number - DataFrameStatFunctions - corr") { import session.implicits._ val df = Seq((0.1, 0.5), (0.2, 0.6), (0.3, 0.7)).toDF("a", "b") df.stat.corr("a", "b") - checkSpan("snow.snowpark.DataFrameStatFunctions", "corr", "") + checkSpan("snow.snowpark.DataFrameStatFunctions", "corr", "DataFrame.select.toDF.stat.corr") } test("line number - DataFrameStatFunctions - cov") { import session.implicits._ val df = Seq((0.1, 0.5), (0.2, 0.6), (0.3, 0.7)).toDF("a", "b") df.stat.cov("a", "b") - checkSpan("snow.snowpark.DataFrameStatFunctions", "cov", "") + checkSpan("snow.snowpark.DataFrameStatFunctions", "cov", "DataFrame.select.toDF.stat.cov") } test("line number - DataFrameStatFunctions - approxQuantile") { import session.implicits._ + val className = "snow.snowpark.DataFrameStatFunctions" val df = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 0).toDF("a") df.stat.approxQuantile("a", Array(0, 0.1, 0.4, 0.6, 1)) - checkSpan("snow.snowpark.DataFrameStatFunctions", "approxQuantile", "") + checkSpan(className, "approxQuantile", "DataFrame.select.toDF.stat.approxQuantile") } test("line number - DataFrameStatFunctions - approxQuantile 2") { import session.implicits._ + val className = "snow.snowpark.DataFrameStatFunctions" val df = Seq((0.1, 0.5), (0.2, 0.6), (0.3, 0.7)).toDF("a", "b") df.stat.approxQuantile(Array("a", "b"), Array(0, 0.1, 0.6)) - checkSpan("snow.snowpark.DataFrameStatFunctions", "approxQuantile", "") + checkSpan(className, "approxQuantile", "DataFrame.select.toDF.stat.approxQuantile") } test("line number - DataFrameStatFunctions - crosstab") { import session.implicits._ + val className = "snow.snowpark.DataFrameStatFunctions" val df = Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3)).toDF("key", "value") df.stat.crosstab("key", "value") - checkSpan("snow.snowpark.DataFrameStatFunctions", "crosstab", "") + checkSpan(className, "crosstab", "DataFrame.select.toDF.stat.crosstab") } test("line number - DataFrameWriter - csv") { val df = session.sql("select * from values(1),(2),(3) as t(num)") df.write.csv(s"@$stageName1/csv1") - checkSpan("snow.snowpark.DataFrameWriter", "csv", "") + checkSpan("snow.snowpark.DataFrameWriter", "csv", "DataFrame.writer.csv") } test("line number - DataFrameWriter - json") { @@ -170,13 +174,13 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val df = Seq((1, 1.1, "a"), (2, 2.2, "b")).toDF("a", "b", "c") val df2 = df.select(array_construct(df.schema.names.map(df(_)): _*)) df2.write.option("compression", "none").json(s"@$stageName1/json1") - checkSpan("snow.snowpark.DataFrameWriter", "json", "") + checkSpan("snow.snowpark.DataFrameWriter", "json", "DataFrame.select.toDF.select.writer.json") } test("line number - DataFrameWriter - parquet") { val df = session.sql("select * from values(1),(2),(3) as t(num)") df.write.parquet(s"@$stageName1/parquet1") - checkSpan("snow.snowpark.DataFrameWriter", "parquet", "") + checkSpan("snow.snowpark.DataFrameWriter", "parquet", "DataFrame.writer.parquet") } test("line number - DataFrameWriter - saveAsTable") { @@ -184,13 +188,13 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val tableName = randomName() try { df.write.saveAsTable(tableName) - checkSpan("snow.snowpark.DataFrameWriter", "saveAsTable", "") + checkSpan("snow.snowpark.DataFrameWriter", "saveAsTable", "DataFrame.writer.saveAsTable") } finally { dropTable(tableName) } try { df.write.saveAsTable(Seq(tableName)) - checkSpan("snow.snowpark.DataFrameWriter", "saveAsTable", "") + checkSpan("snow.snowpark.DataFrameWriter", "saveAsTable", "DataFrame.writer.saveAsTable") } finally { dropTable(tableName) } @@ -198,7 +202,7 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val list = new util.ArrayList[String](1) list.add(tableName) df.write.saveAsTable(tableName) - checkSpan("snow.snowpark.DataFrameWriter", "saveAsTable", "") + checkSpan("snow.snowpark.DataFrameWriter", "saveAsTable", "DataFrame.writer.saveAsTable") } finally { dropTable(tableName) } @@ -206,16 +210,17 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { test("line number - DataFrameWriterAsyncActor - saveAsTable") { val df = session.sql("select * from values(1),(2),(3) as t(num)") + val className = "snow.snowpark.DataFrameWriterAsyncActor" val tableName = randomName() try { df.write.async.saveAsTable(tableName).getResult() - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "saveAsTable", "") + checkSpan(className, "saveAsTable", "DataFrame.writer.async.saveAsTable") } finally { dropTable(tableName) } try { df.write.async.saveAsTable(Seq(tableName)).getResult() - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "saveAsTable", "") + checkSpan(className, "saveAsTable", "DataFrame.writer.async.saveAsTable") } finally { dropTable(tableName) } @@ -223,30 +228,33 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val list = new util.ArrayList[String](1) list.add(tableName) df.write.async.saveAsTable(tableName).getResult() - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "saveAsTable", "") + checkSpan(className, "saveAsTable", "DataFrame.writer.async.saveAsTable") } finally { dropTable(tableName) } } test("line number - DataFrameWriterAsyncActor - csv") { + val className = "snow.snowpark.DataFrameWriterAsyncActor" val df = session.sql("select * from values(1),(2),(3) as t(num)") df.write.async.csv(s"@$stageName1/csv2").getResult() - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "csv", "") + checkSpan(className, "csv", "DataFrame.writer.async.csv") } test("line number - DataFrameWriterAsyncActor - json") { import session.implicits._ + val className = "snow.snowpark.DataFrameWriterAsyncActor" val df = Seq((1, 1.1, "a"), (2, 2.2, "b")).toDF("a", "b", "c") val df2 = df.select(array_construct(df.schema.names.map(df(_)): _*)) df2.write.option("compression", "none").async.json(s"@$stageName1/json2") - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "json", "") + checkSpan(className, "json", "DataFrame.select.toDF.select.writer.async.json") } test("line number - DataFrameWriterAsyncActor - parquet") { + val className = "snow.snowpark.DataFrameWriterAsyncActor" val df = session.sql("select * from values(1),(2),(3) as t(num)") df.write.async.parquet(s"@$stageName1/parquet2") - checkSpan("snow.snowpark.DataFrameWriterAsyncActor", "parquet", "") + checkSpan(className, "parquet", "DataFrame.writer.async.parquet") } test("line number - CopyableDataFrame") { @@ -265,15 +273,15 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { testSpanExporter.reset() val df = session.read.schema(userSchema).csv(testFileOnStage) df.copyInto(tableName) - checkSpan("snow.snowpark.CopyableDataFrame", "copyInto", "") + checkSpan("snow.snowpark.CopyableDataFrame", "copyInto") df.copyInto(tableName, Seq(col("$1"), col("$2"), col("$3"))) - checkSpan("snow.snowpark.CopyableDataFrame", "copyInto", "") + checkSpan("snow.snowpark.CopyableDataFrame", "copyInto") df.copyInto(tableName, Seq(col("$1"), col("$2"), col("$3")), Map("FORCE" -> "TRUE")) - checkSpan("snow.snowpark.CopyableDataFrame", "copyInto", "") + checkSpan("snow.snowpark.CopyableDataFrame", "copyInto") df.copyInto(tableName, Seq("a", "b", "c"), Seq(col("$1"), col("$2"), col("$3")), Map.empty) - checkSpan("snow.snowpark.CopyableDataFrame", "copyInto", "") + checkSpan("snow.snowpark.CopyableDataFrame", "copyInto") df.clone() - checkSpan("snow.snowpark.CopyableDataFrame", "clone", "") + checkSpan("snow.snowpark.CopyableDataFrame", "clone") } finally { dropStage(stageName) dropTable(tableName) @@ -283,6 +291,7 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { test("line number - CopyableDataFrameAsyncActor") { val stageName = randomName() val tableName = randomName() + val className = "snow.snowpark.CopyableDataFrameAsyncActor" val userSchema: StructType = StructType( Seq( StructField("a", IntegerType), @@ -296,14 +305,14 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { testSpanExporter.reset() val df = session.read.schema(userSchema).csv(testFileOnStage) df.async.copyInto(tableName).getResult() - checkSpan("snow.snowpark.CopyableDataFrameAsyncActor", "copyInto", "") + checkSpan(className, "copyInto", "DataFrame.async.copyInto") df.async.copyInto(tableName, Seq(col("$1"), col("$2"), col("$3"))).getResult() - checkSpan("snow.snowpark.CopyableDataFrameAsyncActor", "copyInto", "") + checkSpan(className, "copyInto", "DataFrame.async.copyInto") val seq1 = Seq(col("$1"), col("$2"), col("$3")) df.async.copyInto(tableName, seq1, Map("FORCE" -> "TRUE")).getResult() - checkSpan("snow.snowpark.CopyableDataFrameAsyncActor", "copyInto", "") + checkSpan(className, "copyInto", "DataFrame.async.copyInto") df.async.copyInto(tableName, Seq("a", "b", "c"), seq1, Map.empty).getResult() - checkSpan("snow.snowpark.CopyableDataFrameAsyncActor", "copyInto", "") + checkSpan(className, "copyInto", "DataFrame.async.copyInto") } finally { dropStage(stageName) dropTable(tableName) @@ -320,25 +329,25 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val t2 = session.table(tableName2) testSpanExporter.reset() updatable.update(Map(col("a") -> lit(1), col("b") -> lit(0))) - checkSpan("snow.snowpark.Updatable", "update", "") + checkSpan("snow.snowpark.Updatable", "update") updatable.update(Map("b" -> (col("a") + col("b")))) - checkSpan("snow.snowpark.Updatable", "update", "") + checkSpan("snow.snowpark.Updatable", "update") updatable.update(Map(col("b") -> lit(0)), col("a") === 1) - checkSpan("snow.snowpark.Updatable", "update", "") + checkSpan("snow.snowpark.Updatable", "update") updatable.update(Map("b" -> lit(0)), col("a") === 1) - checkSpan("snow.snowpark.Updatable", "update", "") + checkSpan("snow.snowpark.Updatable", "update") t2.update(Map(col("n") -> lit(0)), updatable("a") === t2("n"), updatable) - checkSpan("snow.snowpark.Updatable", "update", "") + checkSpan("snow.snowpark.Updatable", "update") t2.update(Map("n" -> lit(0)), updatable("a") === t2("n"), updatable) - checkSpan("snow.snowpark.Updatable", "update", "") + checkSpan("snow.snowpark.Updatable", "update") updatable.delete() - checkSpan("snow.snowpark.Updatable", "delete", "") + checkSpan("snow.snowpark.Updatable", "delete") updatable.delete(col("a") === 1 && col("b") === 2) - checkSpan("snow.snowpark.Updatable", "delete", "") + checkSpan("snow.snowpark.Updatable", "delete") t2.delete(updatable("a") === t2("n"), updatable) - checkSpan("snow.snowpark.Updatable", "delete", "") + checkSpan("snow.snowpark.Updatable", "delete") updatable.clone - checkSpan("snow.snowpark.Updatable", "clone", "") + checkSpan("snow.snowpark.Updatable", "clone") } finally { dropTable(tableName) dropTable(tableName2) @@ -348,6 +357,7 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { test("line number - UpdatableAsyncActor") { val tableName = randomName() val tableName2 = randomName() + val className = "snow.snowpark.UpdatableAsyncActor" try { testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) val updatable = session.table(tableName) @@ -355,23 +365,23 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val t2 = session.table(tableName2) testSpanExporter.reset() updatable.async.update(Map(col("a") -> lit(1), col("b") -> lit(0))).getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "update", "") + checkSpan(className, "update", "DataFrame.async.update") updatable.async.update(Map("b" -> (col("a") + col("b")))).getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "update", "") + checkSpan(className, "update", "DataFrame.async.update") updatable.async.update(Map(col("b") -> lit(0)), col("a") === 1).getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "update", "") + checkSpan(className, "update", "DataFrame.async.update") updatable.async.update(Map("b" -> lit(0)), col("a") === 1).getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "update", "") + checkSpan(className, "update", "DataFrame.async.update") t2.async.update(Map(col("n") -> lit(0)), updatable("a") === t2("n"), updatable).getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "update", "") + checkSpan(className, "update", "DataFrame.async.update") t2.async.update(Map("n" -> lit(0)), updatable("a") === t2("n"), updatable).getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "update", "") + checkSpan(className, "update", "DataFrame.async.update") updatable.async.delete().getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "delete", "") + checkSpan(className, "delete", "DataFrame.async.delete") updatable.async.delete(col("a") === 1 && col("b") === 2).getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "delete", "") + checkSpan(className, "delete", "DataFrame.async.delete") t2.async.delete(updatable("a") === t2("n"), updatable).getResult() - checkSpan("snow.snowpark.UpdatableAsyncActor", "delete", "") + checkSpan(className, "delete", "DataFrame.async.delete") } finally { dropTable(tableName) dropTable(tableName2) @@ -392,7 +402,7 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { .whenMatched .update(Map(target("desc") -> source("desc"))) builder.collect() - checkSpan("snow.snowpark.MergeBuilder", "collect", "") + checkSpan("snow.snowpark.MergeBuilder", "collect", "DataFrame.merge.collect") } finally { dropTable(tableName) } @@ -407,19 +417,20 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val target = session.table(tableName) val source = Seq((10, "new")).toDF("id", "desc") testSpanExporter.reset() + val className = "snow.snowpark.MergeBuilderAsyncActor" val builder = target .merge(source, target("id") === source("id")) .whenMatched .update(Map(target("desc") -> source("desc"))) builder.async.collect().getResult() - checkSpan("snow.snowpark.MergeBuilderAsyncActor", "collect", "") + checkSpan(className, "collect", "DataFrame.merge.async.collect") } finally { dropTable(tableName) } } test("OpenTelemetry.emit") { - OpenTelemetry.emit("ClassA", "functionB", "fileC", 123, "chainD") + OpenTelemetry.emit(ActionInfo("ClassA", "functionB", "fileC", 123, "chainD")) checkSpan("snow.snowpark.ClassA", "functionB", "fileC", 123, "chainD") } @@ -451,4 +462,15 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { file.getLineNumber - 1, methodChain) } + + def checkSpan(className: String, funcName: String): Unit = { + val stack = Thread.currentThread().getStackTrace + val file = stack(2) // this file + checkSpan( + className, + funcName, + "OpenTelemetrySuite.scala", + file.getLineNumber - 1, + s"DataFrame.$funcName") + } }