Skip to content

Commit

Permalink
explode array
Browse files · Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bli committed Nov 27, 2023
1 parent 729d50f commit 246770d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
56 changes: 40 additions & 16 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package com.snowflake.snowpark

import scala.reflect.ClassTag
import scala.util.Random
import com.snowflake.snowpark.internal.analyzer.{TableFunction => TF}
import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.internal.{Logging, Utils}
import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.types._
import com.github.vertical_blank.sqlformatter.SqlFormatter
import com.snowflake.snowpark.functions.lit
import com.snowflake.snowpark.internal.Utils.{
TempObjectType,
getTableFunctionExpression,
Expand Down Expand Up @@ -1969,23 +1971,45 @@ class DataFrame private[snowpark] (
/**
 * Joins this DataFrame with the output of the given table function.
 *
 * `explode` is resolved on the client side rather than being sent to the server
 * as a table function, so it is dispatched to `joinWithExplode`. All other table
 * functions are wrapped in a `TableFunctionJoin` plan node. If the joined result
 * would contain duplicated column names, the input columns are re-aliased with a
 * generated prefix so the result schema has no duplicates.
 *
 * NOTE(review): this block was reconstructed from a diff view that had merged the
 * pre-change body with the new one (duplicated dedup logic, unbalanced braces);
 * the reconstruction keeps only the new match-based implementation.
 *
 * @param func               the table function expression to join with
 * @param partitionByOrderBy optional PARTITION BY / ORDER BY window specification
 * @return a new DataFrame with this DataFrame's columns followed by the table
 *         function's output columns
 */
private def joinTableFunction(
    func: TableFunctionExpression,
    partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
  func match {
    // explode is a client side function
    case TF(funcName, args) if funcName.toLowerCase().trim.equals("explode") =>
      // explode has only one argument
      joinWithExplode(args.head, partitionByOrderBy)
    case _ =>
      val originalResult = withPlan {
        TableFunctionJoin(this.plan, func, partitionByOrderBy)
      }
      val resultSchema = originalResult.schema
      val columnNames = resultSchema.map(_.name)
      // duplicated names
      val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
      // guarantee no duplicated names in the result
      if (dup.nonEmpty) {
        val dfPrefix = DataFrame.generatePrefix('o')
        val renamedDf =
          this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
        withPlan {
          TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
        }
      } else {
        originalResult
      }
  }
}

/**
 * Client-side implementation of the `explode` table function.
 *
 * The input expression's type is resolved by selecting it and inspecting the
 * resulting schema. ARRAY columns are expanded via the server-side FLATTEN table
 * function in "array" mode (one VALUE column per element); MAP columns via
 * FLATTEN in "object" mode (KEY and VALUE columns per entry). Any other type is
 * rejected with an explicit exception instead of the previous `null` returns,
 * which would have surfaced as an opaque NullPointerException at the call site.
 *
 * @param expr               the column expression to explode
 * @param partitionByOrderBy optional PARTITION BY / ORDER BY window specification
 * @return a new DataFrame with the original columns plus the flattened output
 * @throws UnsupportedOperationException if the column is neither ARRAY nor MAP
 */
private def joinWithExplode(
    expr: Expression,
    partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
  val columns: Seq[Column] = this.output.map(attr => col(attr.name))
  // check the column type of input column
  this.select(Column(expr)).schema.head.dataType match {
    case _: ArrayType =>
      joinTableFunction(
        tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("array"))),
        partitionByOrderBy).select(columns :+ Column("VALUE"))
    case _: MapType =>
      // NOTE(review): maps are flattened in "object" mode, producing KEY and VALUE
      // columns per entry — confirm this matches the intended explode semantics.
      joinTableFunction(
        tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("object"))),
        partitionByOrderBy).select(columns ++ Seq(Column("KEY"), Column("VALUE")))
    case otherType =>
      // Fail fast with a descriptive error rather than returning null.
      throw new UnsupportedOperationException(
        s"explode only supports ARRAY and MAP input columns, but got: $otherType")
  }
}

Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/com/snowflake/snowpark/tableFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,7 @@ object tableFunctions {
"outer" -> lit(outer),
"recursive" -> lit(recursive),
"mode" -> lit(mode)))

def explode(expr: Column): Column = TableFunction("explode").apply(expr)

}
Original file line number Diff line number Diff line change
Expand Up @@ -386,4 +386,13 @@ class TableFunctionSuite extends TestData {
checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
}

test("explode with array column") {
  // Build a single-row DataFrame whose "a" column is a typed ARRAY of integers.
  val source = Seq("[1, 2]").toDF("a")
  val arrayDf =
    source.select(parse_json(source("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
  // Each array element becomes its own row; the other selected columns repeat.
  checkAnswer(
    arrayDf.select(lit(1), tableFunctions.explode(arrayDf("a")), arrayDf("a")(1)),
    Seq(Row(1, "1", "2"), Row(1, "2", "2")))
}

}

0 comments on commit 246770d

Please sign in to comment.