
#634 Retain metadata on schema flattening.
yruslan committed Aug 3, 2023
1 parent 56e3df3 commit ac1e7f3
Showing 2 changed files with 35 additions and 7 deletions.
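The gist of the change: every flattened leaf column is now aliased with Spark's two-argument Column.as(alias, metadata) overload instead of the one-argument .as(alias), which silently drops field metadata. A minimal, self-contained sketch of that Spark behavior (illustrative code, not part of this commit):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.MetadataBuilder

object MetadataAliasDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    import spark.implicits._

    val meta = new MetadataBuilder().putLong("source_max_length", 10).build()
    val df   = Seq(1L, 2L).toDF("id")

    // One-argument alias: metadata on the resulting field is empty.
    val dropped = df.select(col("id").as("id_flat"))
    // Two-argument alias: the given metadata is attached to the resulting field.
    val kept    = df.select(col("id").as("id_flat", meta))

    println(dropped.schema.fields(0).metadata) // {}
    println(kept.schema.fields(0).metadata)    // {"source_max_length":10}
  }
}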
File: …/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
@@ -97,14 +97,14 @@ object SparkUtils extends Logging {
        case _ =>
          val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
          val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
-         fields += expr(s"$path`${structField.name}`[$i]").as(newFieldName)
+         fields += expr(s"$path`${structField.name}`[$i]").as(newFieldName, structField.metadata)
          stringFields += s"""expr("$path`${structField.name}`[$i] AS `$newFieldName`")"""
      }
      i += 1
    }
  }

- def flattenNestedArrays(path: String, fieldNamePrefix: String, arrayType: ArrayType): Unit = {
+ def flattenNestedArrays(path: String, fieldNamePrefix: String, arrayType: ArrayType, metadata: Metadata): Unit = {
    val maxInd = getMaxArraySize(path)
    var i = 0
    while (i < maxInd) {
@@ -114,12 +114,12 @@ object SparkUtils extends Logging {
          flattenGroup(s"$path[$i]", newFieldNamePrefix, st)
        case ar: ArrayType =>
          val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
-         flattenNestedArrays(s"$path[$i]", newFieldNamePrefix, ar)
+         flattenNestedArrays(s"$path[$i]", newFieldNamePrefix, ar, metadata)
        // AtomicType is protected on package 'sql' level so have to enumerate all subtypes :(
        case _ =>
          val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
          val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
-         fields += expr(s"$path[$i]").as(newFieldName)
+         fields += expr(s"$path[$i]").as(newFieldName, metadata)
          stringFields += s"""expr("$path[$i] AS `$newFieldName`")"""
      }
      i += 1
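A note on the new metadata parameter: the element of a nested array is a bare DataType, not a StructField, so it has no metadata of its own; the enclosing field's metadata must be threaded down the recursion. A contrived sketch of the pattern (helper name is hypothetical, not from the commit):

import org.apache.spark.sql.types.{ArrayType, DataType, Metadata}

// Nested arrays (ArrayType(ArrayType(...))) contain bare element types, not
// StructFields, so the only metadata available for a leaf is whatever the
// outermost field owned -- it is passed down unchanged.
def inheritedLeafMetadata(dt: DataType, inherited: Metadata): Metadata = dt match {
  case ArrayType(elementType, _) => inheritedLeafMetadata(elementType, inherited)
  case _                         => inherited
}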
@@ -144,7 +144,7 @@ object SparkUtils extends Logging {
  def flattenArray(path: String, fieldNamePrefix: String, structField: StructField, arrayType: ArrayType): Unit = {
    arrayType.elementType match {
      case _: ArrayType =>
-       flattenNestedArrays(s"$path${structField.name}", fieldNamePrefix, arrayType)
+       flattenNestedArrays(s"$path${structField.name}", fieldNamePrefix, arrayType, structField.metadata)
      case _ =>
        flattenStructArray(path, fieldNamePrefix, structField, arrayType)
    }
@@ -164,7 +164,7 @@ object SparkUtils extends Logging {
          flattenArray(path, newFieldNamePrefix, field, arr)
        case _ =>
          val newFieldName = getNewFieldName(s"$fieldNamePrefix${field.name}")
-         fields += expr(s"$path`${field.name}`").as(newFieldName)
+         fields += expr(s"$path`${field.name}`").as(newFieldName, field.metadata)
          if (path.contains('['))
            stringFields += s"""expr("$path`${field.name}` AS `$newFieldName`")"""
          else
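All four call sites above follow the same rule: whoever holds the StructField passes its metadata into the alias. A heavily simplified, struct-only flattener illustrating the rule end to end (hypothetical sketch; the real implementation above also expands arrays element by element):

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType

// Struct-only sketch: recursively collect leaf columns, aliasing each with a
// flattened name AND the source field's metadata so the metadata survives.
def flattenStructsOnly(df: DataFrame): DataFrame = {
  def leafColumns(prefix: String, schema: StructType): Seq[Column] =
    schema.fields.toSeq.flatMap { field =>
      val path = if (prefix.isEmpty) field.name else s"$prefix.${field.name}"
      field.dataType match {
        case st: StructType => leafColumns(path, st)
        case _              => Seq(col(path).as(path.replace('.', '_'), field.metadata))
      }
    }
  df.select(leafColumns("", df.schema): _*)
}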
File: …/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
@@ -16,7 +16,7 @@

package za.co.absa.cobrix.spark.cobol.utils

-import org.apache.spark.sql.types.{ArrayType, StructType}
+import org.apache.spark.sql.types.{ArrayType, LongType, MetadataBuilder, StringType, StructField, StructType}
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
import org.slf4j.LoggerFactory
@@ -102,6 +102,34 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
    assertResults(flatData, expectedFlatData)
  }

test("Test metadata is retained") {
val metadata1 = new MetadataBuilder().putLong("test_metadata1", 123).build()
val metadata2 = new MetadataBuilder().putLong("test_metadata2", 456).build()
val metadata3 = new MetadataBuilder().putLong("test_metadata3", 789).build()

val schema = StructType(Array(
StructField("id", LongType, nullable = true, metadata = metadata1),
StructField("legs", ArrayType(StructType(List(
StructField("conditions", ArrayType(StructType(List(
StructField("amount", LongType, nullable = true),
StructField("checks", ArrayType(StructType(List(
StructField("checkNums", ArrayType(StringType, containsNull = true), nullable = true, metadata = metadata3)
)), containsNull = true), nullable = true))), containsNull = true), nullable = true),
StructField("legid", LongType, nullable = true, metadata = metadata2))), containsNull = true), nullable = true)))

val df = spark.read.schema(schema).json(nestedSampleData.toDS)
val dfFlattened = SparkUtils.flattenSchema(df)

assert(dfFlattened.schema.fields(0).metadata.getLong("test_metadata1") == 123)
assert(dfFlattened.schema.fields.find(_.name == "id").get.metadata.getLong("test_metadata1") == 123)
assert(dfFlattened.schema.fields.find(_.name == "legs_0_legid").get.metadata.getLong("test_metadata2") == 456)
assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_1").get.metadata.getLong("test_metadata3") == 789)
assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_2").get.metadata.getLong("test_metadata3") == 789)
assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_3").get.metadata.getLong("test_metadata3") == 789)
assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_4").get.metadata.getLong("test_metadata3") == 789)
assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_5").get.metadata.getLong("test_metadata3") == 789)
}

test("Test schema flattening when short names are used") {
val expectedFlatSchema =
"""root
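For reference, a sketch of how retained metadata can be inspected after flattening (assumes spark-cobol on the classpath; the object name and sample JSON are illustrative, and flattenSchema is the public entry point exercised by the test above):

import org.apache.spark.sql.SparkSession
import za.co.absa.cobrix.spark.cobol.utils.SparkUtils

object FlattenMetadataCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Any nested DataFrame will do; a tiny JSON document stands in here.
    val df = spark.read.json(Seq("""{"id":1,"legs":[{"legid":100}]}""").toDS)

    // After flattening, each field keeps the metadata of its source field.
    val dfFlattened = SparkUtils.flattenSchema(df)
    dfFlattened.schema.fields.foreach(f => println(s"${f.name}: ${f.metadata}"))
  }
}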
