From d6c54bb219eee64e626294750a9958743736c9ab Mon Sep 17 00:00:00 2001
From: Ruslan Iushchenko
Date: Wed, 31 Jul 2024 07:13:29 +0200
Subject: [PATCH] #697 Improve metadata merging method in Spark Utils.

---
 .../cobrix/spark/cobol/utils/SparkUtils.scala | 24 ++++++--
 .../spark/cobol/utils/SparkUtilsSuite.scala   | 55 ++++++++++++++++++-
 2 files changed, 73 insertions(+), 6 deletions(-)

diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
index 160c91c1..23352791 100644
--- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
+++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
@@ -245,9 +245,19 @@ object SparkUtils extends Logging {
    *
    * @param schemaFrom Schema to copy metadata from.
    * @param schemaTo   Schema to copy metadata to.
+   * @param overwrite  If true, the metadata of schemaTo is not retained and the metadata of schemaFrom replaces it; if false, the two are merged.
    * @return Same schema as schemaTo with metadata from schemaFrom.
    */
-  def copyMetadata(schemaFrom: StructType, schemaTo: StructType): StructType = {
+  def copyMetadata(schemaFrom: StructType, schemaTo: StructType, overwrite: Boolean = false): StructType = {
+    def joinMetadata(from: Metadata, to: Metadata): Metadata = {
+      val newMetadataMerged = new MetadataBuilder
+
+      newMetadataMerged.withMetadata(from)
+      newMetadataMerged.withMetadata(to)
+
+      newMetadataMerged.build()
+    }
+
     @tailrec
     def processArray(ar: ArrayType, fieldFrom: StructField, fieldTo: StructField): ArrayType = {
       ar.elementType match {
@@ -267,15 +277,21 @@
     val newFields: Array[StructField] = schemaTo.fields.map { fieldTo =>
       fieldsMap.get(fieldTo.name) match {
         case Some(fieldFrom) =>
+          val newMetadata = if (overwrite) {
+            fieldFrom.metadata
+          } else {
+            joinMetadata(fieldFrom.metadata, fieldTo.metadata)
+          }
+
           fieldTo.dataType match {
             case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
               val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st).fields)
-              fieldTo.copy(dataType = newDataType, metadata = fieldFrom.metadata)
+              fieldTo.copy(dataType = newDataType, metadata = newMetadata)
             case at: ArrayType =>
               val newType = processArray(at, fieldFrom, fieldTo)
-              fieldTo.copy(dataType = newType, metadata = fieldFrom.metadata)
+              fieldTo.copy(dataType = newType, metadata = newMetadata)
             case _ =>
-              fieldTo.copy(metadata = fieldFrom.metadata)
+              fieldTo.copy(metadata = newMetadata)
           }
         case None =>
           fieldTo
diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
index 78feab6b..81a192fc 100644
--- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
@@ -603,22 +603,73 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
     }
   }
 
-  test("copyMetadata should copy metadata from one schema to another") {
+  test("copyMetadata should copy metadata from one schema to another when overwrite = false") {
     val df1 = List(1, 2, 3).toDF("col1")
     val df2 = List(1, 2, 3).toDF("col1")
 
     val metadata1 = new MetadataBuilder()
     metadata1.putString("comment", "Test")
 
+    val metadata2 = new MetadataBuilder()
+    metadata2.putLong("maxLength", 120)
+
+    val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+    val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
+
+    val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
+
+    val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata)
+
+    val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
+
+    assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+    assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
+  }
+
+  test("copyMetadata should not retain original metadata when overwrite = true") {
+    val df1 = List(1, 2, 3).toDF("col1")
+    val df2 = List(1, 2, 3).toDF("col1")
+
+    val metadata1 = new MetadataBuilder()
+    metadata1.putString("comment", "Test")
+
+    val metadata2 = new MetadataBuilder()
+    metadata2.putLong("maxLength", 120)
+
     val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+    val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
 
     val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
 
-    val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, df2.schema)
+    val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata, overwrite = true)
 
     val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
 
     assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+    assert(!newDf.schema.fields.head.metadata.contains("maxLength"))
+  }
+
+  test("Make sure flattening does not remove metadata") {
+    val df1 = List(1, 2, 3).toDF("col1")
+    val df2 = List(1, 2, 3).toDF("col1")
+
+    val metadata1 = new MetadataBuilder()
+    metadata1.putString("comment", "Test")
+
+    val metadata2 = new MetadataBuilder()
+    metadata2.putLong("maxLength", 120)
+
+    val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+    val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
+
+    val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
+
+    val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata)
+
+    val newDf = SparkUtils.unstructDataFrame(spark.createDataFrame(df2.rdd, schemaWithMetadata))
+
+    assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+    assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
   }
 
   test("Integral to decimal conversion for complex schema") {
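
Reviewer note: below is a minimal, standalone sketch (not part of the patch) of the merge semantics that joinMetadata relies on. It assumes only Spark's public MetadataBuilder API: withMetadata copies every entry of the given Metadata into the builder, so when both sides define the same key, the later withMetadata call wins. Because joinMetadata applies fieldTo's metadata last, existing schemaTo values take precedence on conflicting keys when overwrite = false, while keys unique to either side are preserved. The object and key names here are illustrative only.

    import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

    object MetadataMergeSketch {
      def main(args: Array[String]): Unit = {
        // Metadata coming from the source schema (the schemaFrom side).
        val from: Metadata = new MetadataBuilder()
          .putString("comment", "from source schema")
          .putLong("maxLength", 120)
          .build()

        // Metadata already present on the target schema (the schemaTo side).
        val to: Metadata = new MetadataBuilder()
          .putString("comment", "already on target")
          .build()

        // Same order as joinMetadata(from, to): 'to' is applied last,
        // so its values win on key conflicts.
        val merged: Metadata = new MetadataBuilder()
          .withMetadata(from)
          .withMetadata(to)
          .build()

        println(merged.getString("comment")) // "already on target" ('to' wins)
        println(merged.getLong("maxLength")) // 120 (key unique to 'from' is kept)
      }
    }

One design observation grounded in the patch itself: the recursive copyMetadata call for nested StructType fields does not forward the overwrite flag, so metadata of nested fields is always merged under the default overwrite = false, regardless of the outer call.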