diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 8fd81ab2..9d08f924 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -2008,7 +2008,10 @@ class DataFrame private[snowpark] ( joinTableFunction( tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("array"))), partitionByOrderBy).select(columns :+ Column("VALUE")) - case _: MapType => null + case _: MapType => + joinTableFunction( + tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("object"))), + partitionByOrderBy).select(columns ++ Seq(Column("KEY"), Column("VALUE"))) case _ => null } } diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index cc7adb91..4d110639 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -392,7 +392,18 @@ class TableFunctionSuite extends TestData { checkAnswer( df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")(1)), Seq(Row(1, "1", "2"), Row(1, "2", "2"))) + } + test("explode with map column") { + val df = Seq("""{"a":1, "b": 2}""").toDF("a") + val df1 = df.select( + parse_json(df("a")) + .cast(types.MapType(types.StringType, types.IntegerType)) + .as("a")) + df1.select(tableFunctions.explode(df1("a"))).show() + checkAnswer( + df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")("a")), + Seq(Row(1, "a", "1", "1"), Row(1, "b", "2", "1"))) } }