From ac1e7f39f207d7491e1dc6b62097f00a086ee7d3 Mon Sep 17 00:00:00 2001
From: Ruslan Iushchenko
Date: Thu, 3 Aug 2023 15:30:25 +0200
Subject: [PATCH] #634 Retain metadata on schema flattening.

---
 .../cobrix/spark/cobol/utils/SparkUtils.scala | 12 ++++----
 .../spark/cobol/utils/SparkUtilsSuite.scala   | 30 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
index 7c030967..800ec866 100644
--- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
+++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
@@ -97,14 +97,14 @@ object SparkUtils extends Logging {
          case _ =>
            val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
            val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
-           fields += expr(s"$path`${structField.name}`[$i]").as(newFieldName)
+           fields += expr(s"$path`${structField.name}`[$i]").as(newFieldName, structField.metadata)
            stringFields += s"""expr("$path`${structField.name}`[$i] AS `$newFieldName`")"""
        }
        i += 1
      }
    }
 
-    def flattenNestedArrays(path: String, fieldNamePrefix: String, arrayType: ArrayType): Unit = {
+    def flattenNestedArrays(path: String, fieldNamePrefix: String, arrayType: ArrayType, metadata: Metadata): Unit = {
      val maxInd = getMaxArraySize(path)
      var i = 0
      while (i < maxInd) {
@@ -114,12 +114,12 @@ object SparkUtils extends Logging {
            flattenGroup(s"$path[$i]", newFieldNamePrefix, st)
          case ar: ArrayType =>
            val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
-           flattenNestedArrays(s"$path[$i]", newFieldNamePrefix, ar)
+           flattenNestedArrays(s"$path[$i]", newFieldNamePrefix, ar, metadata)
          // AtomicType is protected on package 'sql' level so have to enumerate all subtypes :(
          case _ =>
            val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
            val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
-           fields += expr(s"$path[$i]").as(newFieldName)
+           fields += expr(s"$path[$i]").as(newFieldName, metadata)
            stringFields += s"""expr("$path`[$i] AS `$newFieldName`")"""
        }
        i += 1
@@ -144,7 +144,7 @@ object SparkUtils extends Logging {
    def flattenArray(path: String, fieldNamePrefix: String, structField: StructField, arrayType: ArrayType): Unit = {
      arrayType.elementType match {
        case _: ArrayType =>
-         flattenNestedArrays(s"$path${structField.name}", fieldNamePrefix, arrayType)
+         flattenNestedArrays(s"$path${structField.name}", fieldNamePrefix, arrayType, structField.metadata)
        case _ =>
          flattenStructArray(path, fieldNamePrefix, structField, arrayType)
      }
@@ -164,7 +164,7 @@ object SparkUtils extends Logging {
          flattenArray(path, newFieldNamePrefix, field, arr)
        case _ =>
          val newFieldName = getNewFieldName(s"$fieldNamePrefix${field.name}")
-         fields += expr(s"$path`${field.name}`").as(newFieldName)
+         fields += expr(s"$path`${field.name}`").as(newFieldName, field.metadata)
          if (path.contains('['))
            stringFields += s"""expr("$path`${field.name}` AS `$newFieldName`")"""
          else
diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
index 4f7ecc96..2739f307 100644
--- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
@@ -16,7 +16,7 @@
 
 package za.co.absa.cobrix.spark.cobol.utils
 
-import org.apache.spark.sql.types.{ArrayType, StructType}
+import org.apache.spark.sql.types.{ArrayType, LongType, MetadataBuilder, StringType, StructField, StructType}
 import org.scalatest.funsuite.AnyFunSuite
 import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
 import org.slf4j.LoggerFactory
@@ -102,6 +102,34 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
     assertResults(flatData, expectedFlatData)
   }
 
+  test("Test metadata is retained") {
+    val metadata1 = new MetadataBuilder().putLong("test_metadata1", 123).build()
+    val metadata2 = new MetadataBuilder().putLong("test_metadata2", 456).build()
+    val metadata3 = new MetadataBuilder().putLong("test_metadata3", 789).build()
+
+    val schema = StructType(Array(
+      StructField("id", LongType, nullable = true, metadata = metadata1),
+      StructField("legs", ArrayType(StructType(List(
+        StructField("conditions", ArrayType(StructType(List(
+          StructField("amount", LongType, nullable = true),
+          StructField("checks", ArrayType(StructType(List(
+            StructField("checkNums", ArrayType(StringType, containsNull = true), nullable = true, metadata = metadata3)
+          )), containsNull = true), nullable = true))), containsNull = true), nullable = true),
+        StructField("legid", LongType, nullable = true, metadata = metadata2))), containsNull = true), nullable = true)))
+
+    val df = spark.read.schema(schema).json(nestedSampleData.toDS)
+    val dfFlattened = SparkUtils.flattenSchema(df)
+
+    assert(dfFlattened.schema.fields(0).metadata.getLong("test_metadata1") == 123)
+    assert(dfFlattened.schema.fields.find(_.name == "id").get.metadata.getLong("test_metadata1") == 123)
+    assert(dfFlattened.schema.fields.find(_.name == "legs_0_legid").get.metadata.getLong("test_metadata2") == 456)
+    assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_1").get.metadata.getLong("test_metadata3") == 789)
+    assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_2").get.metadata.getLong("test_metadata3") == 789)
+    assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_3").get.metadata.getLong("test_metadata3") == 789)
+    assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_4").get.metadata.getLong("test_metadata3") == 789)
+    assert(dfFlattened.schema.fields.find(_.name == "legs_0_conditions_0_checks_0_checkNums_5").get.metadata.getLong("test_metadata3") == 789)
+  }
+
   test("Test schema flattening when short names are used") {
     val expectedFlatSchema =
       """root
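
For context, the mechanism this patch relies on: Column.as has a two-argument
overload, as(alias: String, metadata: Metadata), which attaches the given
Metadata to the aliased column instead of dropping it. The patch threads each
source StructField's metadata into that overload at every point where a
flattened column is produced. A minimal, self-contained sketch of the
overload's behaviour follows; the column name "amount", the metadata key
"comment", and the local SparkSession setup are illustrative and not part of
the patch:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.MetadataBuilder

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // Metadata to carry across the rename, as the patch does per flattened field.
    val meta = new MetadataBuilder().putString("comment", "copybook field").build()
    val df = Seq(1L, 2L).toDF("amount")

    // as(alias) alone yields a column with empty metadata;
    // as(alias, metadata) propagates it into the resulting schema.
    val renamed = df.select(col("amount").as("amount_flat", meta))
    assert(renamed.schema("amount_flat").metadata.getString("comment") == "copybook field")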