From 246770d4f1428220479d2b58731cd08130589934 Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Mon, 27 Nov 2023 14:45:01 -0800
Subject: [PATCH] explode array

---
Review amendments: explode without an argument and explode on a non-array
input now raise an exception instead of hitting `args.head` / returning a null
DataFrame. The blob ids on the `index` lines predate these amendments.

 .../com/snowflake/snowpark/DataFrame.scala    | 68 ++++++++++++++-----
 .../snowflake/snowpark/tableFunctions.scala   |  9 +++
 .../snowpark_test/TableFunctionSuite.scala    |  9 +++
 3 files changed, 70 insertions(+), 16 deletions(-)

diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
index f7d6909c..8fd81ab2 100644
--- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala
+++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -2,11 +2,13 @@ package com.snowflake.snowpark
 
 import scala.reflect.ClassTag
 import scala.util.Random
+import com.snowflake.snowpark.internal.analyzer.{TableFunction => TF}
 import com.snowflake.snowpark.internal.ErrorMessage
 import com.snowflake.snowpark.internal.{Logging, Utils}
 import com.snowflake.snowpark.internal.analyzer._
 import com.snowflake.snowpark.types._
 import com.github.vertical_blank.sqlformatter.SqlFormatter
+import com.snowflake.snowpark.functions.lit
 import com.snowflake.snowpark.internal.Utils.{
   TempObjectType,
   getTableFunctionExpression,
@@ -1969,23 +1971,57 @@ class DataFrame private[snowpark] (
   private def joinTableFunction(
       func: TableFunctionExpression,
       partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
-    val originalResult = withPlan {
-      TableFunctionJoin(this.plan, func, partitionByOrderBy)
+    func match {
+      // explode is a client side function
+      case TF(funcName, args) if funcName.toLowerCase().trim.equals("explode") =>
+        // explode has only one argument; fail with a clear message if it is missing
+        val input = args.headOption.getOrElse(
+          throw new IllegalArgumentException("explode requires exactly one argument"))
+        joinWithExplode(input, partitionByOrderBy)
+      case _ =>
+        val originalResult = withPlan {
+          TableFunctionJoin(this.plan, func, partitionByOrderBy)
+        }
+        val resultSchema = originalResult.schema
+        val columnNames = resultSchema.map(_.name)
+        // duplicated names
+        val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
+        // guarantee no duplicated names in the result
+        if (dup.nonEmpty) {
+          val dfPrefix = DataFrame.generatePrefix('o')
+          val renamedDf =
+            this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
+          withPlan {
+            TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
+          }
+        } else {
+          originalResult
+        }
     }
-    val resultSchema = originalResult.schema
-    val columnNames = resultSchema.map(_.name)
-    // duplicated names
-    val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
-    // guarantee no duplicated names in the result
-    if (dup.nonEmpty) {
-      val dfPrefix = DataFrame.generatePrefix('o')
-      val renamedDf =
-        this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
-      withPlan {
-        TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
-      }
-    } else {
-      originalResult
+  }
+
+  /**
+   * Joins this DataFrame with the exploded elements of `expr`.
+   *
+   * The data type of `expr` is resolved against this DataFrame. An array is expanded with
+   * `flatten` in "array" mode, and the result keeps the current columns plus `VALUE`.
+   * Every other data type is rejected with an exception.
+   */
+  private def joinWithExplode(
+      expr: Expression,
+      partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
+    val columns: Seq[Column] = this.output.map(attr => col(attr.name))
+    // check the column type of input column
+    this.select(Column(expr)).schema.head.dataType match {
+      case _: ArrayType =>
+        joinTableFunction(
+          tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("array"))),
+          partitionByOrderBy).select(columns :+ Column("VALUE"))
+      // fail fast: a null DataFrame would only surface later as a NullPointerException
+      case _: MapType =>
+        throw new UnsupportedOperationException("explode does not support map input yet")
+      case other =>
+        throw new IllegalArgumentException(s"explode does not support input of type $other")
     }
   }
 
diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
index 212a000c..2cf5060f 100644
--- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
+++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
@@ -197,4 +197,13 @@ object tableFunctions {
         "outer" -> lit(outer),
         "recursive" -> lit(recursive),
         "mode" -> lit(mode)))
+
+  /**
+   * Creates a call to the client-side "explode" table function for the given column.
+   *
+   * @param expr The column to explode.
+   * @return The resulting Column.
+   */
+  def explode(expr: Column): Column = TableFunction("explode").apply(expr)
+
 }
diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
index fe330255..cc7adb91 100644
--- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
@@ -386,4 +386,13 @@ class TableFunctionSuite extends TestData {
     checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
   }
 
+  test("explode with array column") {
+    val df = Seq("[1, 2]").toDF("a")
+    val df1 = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
+    checkAnswer(
+      df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")(1)),
+      Seq(Row(1, "1", "2"), Row(1, "2", "2")))
+
+  }
+
 }