diff --git a/build.sbt b/build.sbt index 0143e438..6622e26f 100644 --- a/build.sbt +++ b/build.sbt @@ -92,7 +92,16 @@ lazy val sparkCobol = (project in file("spark-cobol")) log.info(s"Building with Spark ${sparkVersion(scalaVersion.value)}, Scala ${scalaVersion.value}") sparkVersion(scalaVersion.value) }, - (Compile / compile) := ((Compile / compile) dependsOn printSparkVersion).value, + Compile / compile := ((Compile / compile) dependsOn printSparkVersion).value, + Compile / unmanagedSourceDirectories += { + val sourceDir = (Compile / sourceDirectory).value + CrossVersion.partialVersion(scalaVersion.value) match { + case Some((2, n)) if n == 11 => sourceDir / "scala_2.11" + case Some((2, n)) if n == 12 => sourceDir / "scala_2.12" + case Some((2, n)) if n == 13 => sourceDir / "scala_2.13" + case _ => throw new RuntimeException("Unsupported Scala version") + } + }, libraryDependencies ++= SparkCobolDependencies(scalaVersion.value) :+ getScalaDependency(scalaVersion.value), dependencyOverrides ++= SparkCobolDependenciesOverride, Test / fork := true, // Spark tests fail randomly otherwise diff --git a/spark-cobol/pom.xml b/spark-cobol/pom.xml index 30e3a721..e8d93870 100644 --- a/spark-cobol/pom.xml +++ b/spark-cobol/pom.xml @@ -59,13 +59,33 @@ - - - - src/main/resources - true - - - + + + + src/main/resources + true + + + + + org.codehaus.mojo + build-helper-maven-plugin + 3.0.0 + + + generate-sources + + add-source + + + + src/main/scala_${scala.compat.version} + + + + + + + \ No newline at end of file diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala index 800ec866..81dc01c6 100644 --- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala @@ -19,7 +19,8 @@ package za.co.absa.cobrix.spark.cobol.utils import com.fasterxml.jackson.databind.ObjectMapper import org.apache.hadoop.fs.FileSystem import org.apache.spark.SparkContext -import org.apache.spark.sql.functions.{concat_ws, expr, max} +import org.apache.spark.sql.functions.{array, col, expr, max, struct} +import za.co.absa.cobrix.spark.cobol.utils.impl.HofsWrapper.transform import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, SparkSession} import za.co.absa.cobrix.cobol.internal.Logging @@ -178,6 +179,48 @@ object SparkUtils extends Logging { df.select(fields.toSeq: _*) } + def mapPrimitives(df: DataFrame)(f: (StructField, Column) => Column): DataFrame = { + def mapField(column: Column, field: StructField): Column = { + field.dataType match { + case st: StructType => + val columns = st.fields.map(f => mapField(column.getField(field.name), f)) + struct(columns: _*).as(field.name) + case ar: ArrayType => + mapArray(ar, column, field.name).as(field.name) + case _ => + f(field, column).as(field.name) + } + } + + def mapArray(arr: ArrayType, column: Column, columnName: String): Column = { + arr.elementType match { + case st: StructType => + transform(column, c => { + val columns = st.fields.map(f => mapField(c.getField(f.name), f)) + struct(columns: _*) + }) + case ar: ArrayType => + array(mapArray(ar, column, columnName)) + case p => + array(f(StructField(columnName, p), column)) + } + } + + val columns = df.schema.fields.map(f => mapField(col(f.name), f)) + df.select(columns: _*) + } + + def covertIntegralToDecimal(df: DataFrame): DataFrame = { + mapPrimitives(df) { (field, c) => + val metadata = field.metadata + if (metadata.contains("precision") && (field.dataType == LongType || field.dataType == IntegerType || field.dataType == ShortType)) { + val precision = metadata.getLong("precision").toInt + c.cast(DecimalType(precision, 0)).as(field.name) + } else { + c + } + } + } /** * Given an instance of DataFrame returns a dataframe where all primitive fields are converted to String diff --git a/spark-cobol/src/main/scala_2.11/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala b/spark-cobol/src/main/scala_2.11/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala new file mode 100644 index 00000000..1546b6d6 --- /dev/null +++ b/spark-cobol/src/main/scala_2.11/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala @@ -0,0 +1,33 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.spark.cobol.utils.impl + +import org.apache.spark.sql.Column + +object HofsWrapper { + /** + * Applies the function `f` to every element in the `array`. The method is an equivalent to the `map` function + * from functional programming. + * + * The method is not available in Scala 2.11 and Spark < 3.0 + */ + def transform( + array: Column, + f: Column => Column): Column = { + throw new IllegalArgumentException("Array transformation is not available for Scala 2.11 and Spark < 3.0.") + } +} diff --git a/spark-cobol/src/main/scala_2.12/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala b/spark-cobol/src/main/scala_2.12/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala new file mode 100644 index 00000000..266a5499 --- /dev/null +++ b/spark-cobol/src/main/scala_2.12/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala @@ -0,0 +1,38 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.spark.cobol.utils.impl + +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions.{transform => sparkTransform} + +object HofsWrapper { + /** + * Applies the function `f` to every element in the `array`. The method is an equivalent to the `map` function + * from functional programming. + * + * (The idea comes from https://github.com/AbsaOSS/spark-hats/blob/v0.3.0/src/main/scala_2.12/za/co/absa/spark/hats/HofsWrapper.scala) + * + * @param array A column of arrays + * @param f A function transforming individual elements of the array + * @return A column of arrays with transformed elements + */ + def transform( + array: Column, + f: Column => Column): Column = { + sparkTransform(array, f) + } +} diff --git a/spark-cobol/src/main/scala_2.13/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala b/spark-cobol/src/main/scala_2.13/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala new file mode 100644 index 00000000..266a5499 --- /dev/null +++ b/spark-cobol/src/main/scala_2.13/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala @@ -0,0 +1,38 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.spark.cobol.utils.impl + +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions.{transform => sparkTransform} + +object HofsWrapper { + /** + * Applies the function `f` to every element in the `array`. The method is an equivalent to the `map` function + * from functional programming. + * + * (The idea comes from https://github.com/AbsaOSS/spark-hats/blob/v0.3.0/src/main/scala_2.12/za/co/absa/spark/hats/HofsWrapper.scala) + * + * @param array A column of arrays + * @param f A function transforming individual elements of the array + * @return A column of arrays with transformed elements + */ + def transform( + array: Column, + f: Column => Column): Column = { + sparkTransform(array, f) + } +} diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala new file mode 100644 index 00000000..991e521e --- /dev/null +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala @@ -0,0 +1,85 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.spark.cobol.source.fixtures + +import org.scalatest.{Assertion, Suite} + +trait TextComparisonFixture { + this: Suite => + + protected def compareText(actual: String, expected: String): Assertion = { + if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) { + fail(renderTextDifference(actual, expected)) + } else { + succeed + } + } + + protected def compareTextVertical(actual: String, expected: String): Unit = { + if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) { + fail(s"ACTUAL:\n$actual\nEXPECTED: \n$expected") + } + } + + protected def renderTextDifference(textActual: String, textExpected: String): String = { + val t1 = textActual.replaceAll("\\r\\n", "\\n").split('\n') + val t2 = textExpected.replaceAll("\\r\\n", "\\n").split('\n') + + val maxLen = Math.max(getMaxStrLen(t1), getMaxStrLen(t2)) + val header = s" ${rightPad("ACTUAL:", maxLen)} ${rightPad("EXPECTED:", maxLen)}\n" + + val stringBuilder = new StringBuilder + stringBuilder.append(header) + + val linesCount = Math.max(t1.length, t2.length) + var i = 0 + while (i < linesCount) { + val a = if (i < t1.length) t1(i) else "" + val b = if (i < t2.length) t2(i) else "" + + val marker1 = if (a != b) ">" else " " + val marker2 = if (a != b) "<" else " " + + val comparisonText = s"$marker1${rightPad(a, maxLen)} ${rightPad(b, maxLen)}$marker2\n" + stringBuilder.append(comparisonText) + + i += 1 + } + + val footer = s"\nACTUAL:\n$textActual" + stringBuilder.append(footer) + stringBuilder.toString() + } + + def getMaxStrLen(text: Seq[String]): Int = { + if (text.isEmpty) { + 0 + } else { + text.maxBy(_.length).length + } + } + + def rightPad(s: String, length: Int): String = { + if (s.length < length) { + s + " " * (length - s.length) + } else if (s.length > length) { + s.take(length) + } else { + s + } + } +} diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala index 3455745a..51320b61 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala @@ -16,17 +16,17 @@ package za.co.absa.cobrix.spark.cobol.utils -import org.apache.spark.sql.types.{ArrayType, LongType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.scalatest.funsuite.AnyFunSuite -import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase import org.slf4j.LoggerFactory -import za.co.absa.cobrix.spark.cobol.source.fixtures.BinaryFileFixture +import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase +import za.co.absa.cobrix.spark.cobol.source.fixtures.{BinaryFileFixture, TextComparisonFixture} import za.co.absa.cobrix.spark.cobol.utils.TestUtils._ import java.nio.charset.StandardCharsets -import scala.collection.immutable +import scala.util.Properties -class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixture { +class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixture with TextComparisonFixture { import spark.implicits._ @@ -377,7 +377,7 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt assert(dfFlattened.count() == 0) } - test("Schema with multiple OCCURS should properly determine array sized") { + test("Schema with multiple OCCURS should properly determine array sizes") { val copyBook: String = """ 01 RECORD. | 02 COUNT PIC 9(1). @@ -429,6 +429,46 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt } } + test("Integral to decimal conversion for complex schema") { + val expectedSchema = + """|root + | |-- COUNT: decimal(1,0) (nullable = true) + | |-- GROUP: array (nullable = true) + | | |-- element: struct (containsNull = false) + | | | |-- INNER_COUNT: decimal(1,0) (nullable = true) + | | | |-- INNER_GROUP: array (nullable = true) + | | | | |-- element: struct (containsNull = false) + | | | | | |-- FIELD: decimal(1,0) (nullable = true) + |""".stripMargin + + val copyBook: String = + """ 01 RECORD. + | 02 COUNT PIC 9(1). + | 02 GROUP OCCURS 2 TIMES. + | 03 INNER-COUNT PIC S9(1). + | 03 INNER-GROUP OCCURS 3 TIMES. + | 04 FIELD PIC 9. + |""".stripMargin + + withTempTextFile("fletten", "test", StandardCharsets.UTF_8, "") { filePath => + val df = spark.read + .format("cobol") + .option("copybook_contents", copyBook) + .option("pedantic", "true") + .option("record_format", "D") + .option("metadata", "extended") + .load(filePath) + + if (!Properties.versionString.startsWith("2.")) { + // This method only works with Scala 2.12+ and Spark 3.0+ + val actualDf = SparkUtils.covertIntegralToDecimal(df) + val actualSchema = actualDf.schema.treeString + + compareText(actualSchema, expectedSchema) + } + } + } + private def assertSchema(actualSchema: String, expectedSchema: String): Unit = { if (actualSchema != expectedSchema) { logger.error(s"EXPECTED:\n$expectedSchema")