From aa95886b5389c905695d39a51536b109ffd3ffe6 Mon Sep 17 00:00:00 2001
From: Ruslan Iushchenko
Date: Mon, 3 Jun 2024 18:18:33 +0200
Subject: [PATCH] #678 Add metadata to the target schema when converting
 integral data types to decimals.

---
 .../cobrix/spark/cobol/utils/SparkUtils.scala | 61 ++++++++++++++++++-
 .../spark/cobol/utils/SparkUtilsSuite.scala   |  3 +
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
index 39184aed..f059ecf8 100644
--- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
+++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
@@ -179,6 +179,60 @@ object SparkUtils extends Logging {
     df.select(fields.toSeq: _*)
   }
 
+  /**
+    * Copies metadata from one schema to another as long as names and data types are the same.
+    *
+    * @param schemaFrom Schema to copy metadata from.
+    * @param schemaTo   Schema to copy metadata to.
+    * @return Same schema as schemaTo with metadata from schemaFrom.
+    */
+  def copyMetadata(schemaFrom: StructType, schemaTo: StructType): StructType = {
+    @tailrec
+    def processArray(ar: ArrayType, fieldFrom: StructField, fieldTo: StructField): ArrayType = {
+      ar.elementType match {
+        case st: StructType if fieldFrom.dataType.isInstanceOf[ArrayType] && fieldFrom.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType] =>
+          val innerStructFrom = fieldFrom.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
+          val newDataType = StructType(copyMetadata(innerStructFrom, st).fields)
+          ArrayType(newDataType, ar.containsNull)
+        case at: ArrayType =>
+          processArray(at, fieldFrom, fieldTo)
+        case p =>
+          ArrayType(p, ar.containsNull)
+      }
+    }
+
+    val fieldsMap = schemaFrom.fields.map(f => (f.name, f)).toMap
+
+    val newFields: Array[StructField] = schemaTo.fields.map { fieldTo =>
+      fieldsMap.get(fieldTo.name) match {
+        case Some(fieldFrom) =>
+          fieldTo.dataType match {
+            case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
+              val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st).fields)
+              fieldTo.copy(dataType = newDataType, metadata = fieldFrom.metadata)
+            case at: ArrayType =>
+              val newType = processArray(at, fieldFrom, fieldTo)
+              fieldTo.copy(dataType = newType, metadata = fieldFrom.metadata)
+            case _ =>
+              fieldTo.copy(metadata = fieldFrom.metadata)
+          }
+        case None =>
+          fieldTo
+      }
+    }
+
+    StructType(newFields)
+  }
+
+  /**
+    * Allows mapping every primitive field in a dataframe with a Spark expression.
+    *
+    * The metadata of the original schema is retained.
+    *
+    * @param df The dataframe to map.
+    * @param f  The function to apply to each primitive field.
+    * @return The new dataframe with the mapping applied.
+    */
   def mapPrimitives(df: DataFrame)(f: (StructField, Column) => Column): DataFrame = {
     def mapField(column: Column, field: StructField): Column = {
       field.dataType match {
@@ -207,7 +261,10 @@ object SparkUtils extends Logging {
     }
 
     val columns = df.schema.fields.map(f => mapField(col(f.name), f))
-    df.select(columns: _*)
+    val newDf = df.select(columns: _*)
+    val newSchema = copyMetadata(df.schema, newDf.schema)
+
+    df.sparkSession.createDataFrame(newDf.rdd, newSchema)
   }
 
   def covertIntegralToDecimal(df: DataFrame): DataFrame = {
@@ -325,7 +382,7 @@ object SparkUtils extends Logging {
     val fileSystem = FileSystem.get(conf)
     val hdfsBlockSize = HDFSUtils.getHDFSDefaultBlockSizeMB(fileSystem)
     hdfsBlockSize match {
-      case None => logger.info(s"Unable to get HDFS default block size.")
+      case None => logger.info("Unable to get HDFS default block size.")
       case Some(size) => logger.info(s"HDFS default block size = $size MB.")
     }
 
diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
index 8e18575a..3f1ee467 100644
--- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
@@ -466,6 +466,9 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
 
     assert(actualDf.schema.fields.head.metadata.json.nonEmpty)
     assert(actualDf.schema.fields(1).metadata.json.nonEmpty)
+    assert(actualDf.schema.fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty)
+    assert(actualDf.schema.fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).metadata.json.nonEmpty)
+    assert(actualDf.schema.fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty)
 
     compareText(actualSchema, expectedSchema)
   }
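
Below is a minimal usage sketch, not part of the patch, illustrating the behaviour this change enables: after mapPrimitives rewrites primitive columns, field metadata from the source schema is carried over to the result. It assumes an existing SparkSession named "spark"; the field name "num" and the metadata key "precision" are made up for illustration.

// Usage sketch (assumptions: a SparkSession named `spark` is in scope; the
// field name "num" and the metadata key "precision" are hypothetical).
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import za.co.absa.cobrix.spark.cobol.utils.SparkUtils

val meta = new MetadataBuilder().putLong("precision", 4).build()
val schema = StructType(Seq(StructField("num", IntegerType, nullable = true, meta)))
val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(1), Row(2))), schema)

// Cast every primitive field to decimal; before this patch the underlying
// select() produced a schema without the original field metadata.
val converted = SparkUtils.mapPrimitives(df) { (field, c) =>
  c.cast(DecimalType(10, 0)).as(field.name)
}

// With the patch, metadata from the source schema is copied onto the new schema.
assert(converted.schema("num").metadata.getLong("precision") == 4L)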