diff --git a/build.sbt b/build.sbt
index 0143e438..6622e26f 100644
--- a/build.sbt
+++ b/build.sbt
@@ -92,7 +92,16 @@ lazy val sparkCobol = (project in file("spark-cobol"))
log.info(s"Building with Spark ${sparkVersion(scalaVersion.value)}, Scala ${scalaVersion.value}")
sparkVersion(scalaVersion.value)
},
- (Compile / compile) := ((Compile / compile) dependsOn printSparkVersion).value,
+ Compile / compile := ((Compile / compile) dependsOn printSparkVersion).value,
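+ // Scala-version-specific sources (such as HofsWrapper) live in separate directories,
+ // one per supported Scala version.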
+ Compile / unmanagedSourceDirectories += {
+ val sourceDir = (Compile / sourceDirectory).value
+ CrossVersion.partialVersion(scalaVersion.value) match {
+ case Some((2, n)) if n == 11 => sourceDir / "scala_2.11"
+ case Some((2, n)) if n == 12 => sourceDir / "scala_2.12"
+ case Some((2, n)) if n == 13 => sourceDir / "scala_2.13"
+ case _ => throw new RuntimeException("Unsupported Scala version")
+ }
+ },
libraryDependencies ++= SparkCobolDependencies(scalaVersion.value) :+ getScalaDependency(scalaVersion.value),
dependencyOverrides ++= SparkCobolDependenciesOverride,
Test / fork := true, // Spark tests fail randomly otherwise
diff --git a/spark-cobol/pom.xml b/spark-cobol/pom.xml
index 30e3a721..e8d93870 100644
--- a/spark-cobol/pom.xml
+++ b/spark-cobol/pom.xml
@@ -59,13 +59,33 @@
-    <build>
-        <resources>
-            <resource>
-                <directory>src/main/resources</directory>
-                <filtering>true</filtering>
-            </resource>
-        </resources>
-    </build>
+    <build>
+        <resources>
+            <resource>
+                <directory>src/main/resources</directory>
+                <filtering>true</filtering>
+            </resource>
+        </resources>
+        <plugins>
+            <plugin>
+                <groupId>org.codehaus.mojo</groupId>
+                <artifactId>build-helper-maven-plugin</artifactId>
+                <version>3.0.0</version>
+                <executions>
+                    <execution>
+                        <phase>generate-sources</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                        </goals>
+                        <configuration>
+                            <sources>
+                                <source>src/main/scala_${scala.compat.version}</source>
+                            </sources>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
 </project>
\ No newline at end of file
diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
index 800ec866..81dc01c6 100644
--- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
+++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
@@ -19,7 +19,8 @@ package za.co.absa.cobrix.spark.cobol.utils
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.SparkContext
-import org.apache.spark.sql.functions.{concat_ws, expr, max}
+import org.apache.spark.sql.functions.{array, col, expr, max, struct}
+import za.co.absa.cobrix.spark.cobol.utils.impl.HofsWrapper.transform
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import za.co.absa.cobrix.cobol.internal.Logging
@@ -178,6 +179,48 @@ object SparkUtils extends Logging {
df.select(fields.toSeq: _*)
}
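+
+ /**
+ * Applies the given function to every primitive field of a DataFrame, recursing into nested structs
+ * and arrays, and returns a DataFrame assembled from the transformed columns.
+ *
+ * For instance, an illustrative call that casts every primitive field to string could look like:
+ * {{{
+ * val stringDf = SparkUtils.mapPrimitives(df)((_, c) => c.cast(StringType))
+ * }}}
+ *
+ * @param df A DataFrame
+ * @param f A function mapping a primitive field and its column to a replacement column
+ * @return A DataFrame with all primitive fields transformed
+ */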
+ def mapPrimitives(df: DataFrame)(f: (StructField, Column) => Column): DataFrame = {
+ def mapField(column: Column, field: StructField): Column = {
+ field.dataType match {
+ case st: StructType =>
+ val columns = st.fields.map(f => mapField(column.getField(f.name), f))
+ struct(columns: _*).as(field.name)
+ case ar: ArrayType =>
+ mapArray(ar, column, field.name).as(field.name)
+ case _ =>
+ f(field, column).as(field.name)
+ }
+ }
+
+ def mapArray(arr: ArrayType, column: Column, columnName: String): Column = {
+ arr.elementType match {
+ case st: StructType =>
+ transform(column, c => {
+ val columns = st.fields.map(f => mapField(c.getField(f.name), f))
+ struct(columns: _*)
+ })
+ case ar: ArrayType =>
+ array(mapArray(ar, column, columnName))
+ case p =>
+ array(f(StructField(columnName, p), column))
+ }
+ }
+
+ val columns = df.schema.fields.map(f => mapField(col(f.name), f))
+ df.select(columns: _*)
+ }
+
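+ /**
+ * Converts integral fields (short, integer, long) that carry a 'precision' metadata entry into
+ * decimal fields of that precision and scale 0. All other fields are left unchanged.
+ *
+ * @param df A DataFrame (for example, one read by the Cobrix source with extended metadata enabled)
+ * @return A DataFrame in which such integral fields become decimals
+ */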
+ def covertIntegralToDecimal(df: DataFrame): DataFrame = {
+ mapPrimitives(df) { (field, c) =>
+ val metadata = field.metadata
+ if (metadata.contains("precision") && (field.dataType == LongType || field.dataType == IntegerType || field.dataType == ShortType)) {
+ val precision = metadata.getLong("precision").toInt
+ c.cast(DecimalType(precision, 0)).as(field.name)
+ } else {
+ c
+ }
+ }
+ }
/**
* Given an instance of DataFrame returns a dataframe where all primitive fields are converted to String
diff --git a/spark-cobol/src/main/scala_2.11/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala b/spark-cobol/src/main/scala_2.11/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala
new file mode 100644
index 00000000..1546b6d6
--- /dev/null
+++ b/spark-cobol/src/main/scala_2.11/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2018 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.cobrix.spark.cobol.utils.impl
+
+import org.apache.spark.sql.Column
+
+object HofsWrapper {
+ /**
+ * Applies the function `f` to every element in the `array`. The method is equivalent to the `map` function
+ * from functional programming.
+ *
+ * Spark's `transform` higher-order function is not available for Scala 2.11 / Spark < 3.0, so this
+ * implementation always throws.
+ */
+ def transform(
+ array: Column,
+ f: Column => Column): Column = {
+ throw new IllegalArgumentException("Array transformation is not available for Scala 2.11 and Spark < 3.0.")
+ }
+}
diff --git a/spark-cobol/src/main/scala_2.12/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala b/spark-cobol/src/main/scala_2.12/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala
new file mode 100644
index 00000000..266a5499
--- /dev/null
+++ b/spark-cobol/src/main/scala_2.12/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2018 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.cobrix.spark.cobol.utils.impl
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions.{transform => sparkTransform}
+
+object HofsWrapper {
+ /**
+ * Applies the function `f` to every element in the `array`. The method is equivalent to the `map` function
+ * from functional programming.
+ *
+ * (The idea comes from https://github.com/AbsaOSS/spark-hats/blob/v0.3.0/src/main/scala_2.12/za/co/absa/spark/hats/HofsWrapper.scala)
+ *
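+ * For example, `transform(col("amounts"), x => x + 1)` increments every element of a hypothetical
+ * `amounts` array column.
+ *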
+ * @param array A column of arrays
+ * @param f A function transforming individual elements of the array
+ * @return A column of arrays with transformed elements
+ */
+ def transform(
+ array: Column,
+ f: Column => Column): Column = {
+ sparkTransform(array, f)
+ }
+}
diff --git a/spark-cobol/src/main/scala_2.13/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala b/spark-cobol/src/main/scala_2.13/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala
new file mode 100644
index 00000000..266a5499
--- /dev/null
+++ b/spark-cobol/src/main/scala_2.13/za/co/absa/cobrix/spark/cobol/utils/impl/HofsWrapper.scala
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2018 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.cobrix.spark.cobol.utils.impl
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions.{transform => sparkTransform}
+
+object HofsWrapper {
+ /**
+ * Applies the function `f` to every element in the `array`. The method is equivalent to the `map` function
+ * from functional programming.
+ *
+ * (The idea comes from https://github.com/AbsaOSS/spark-hats/blob/v0.3.0/src/main/scala_2.12/za/co/absa/spark/hats/HofsWrapper.scala)
+ *
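+ * For example, `transform(col("amounts"), x => x + 1)` increments every element of a hypothetical
+ * `amounts` array column.
+ *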
+ * @param array A column of arrays
+ * @param f A function transforming individual elements of the array
+ * @return A column of arrays with transformed elements
+ */
+ def transform(
+ array: Column,
+ f: Column => Column): Column = {
+ sparkTransform(array, f)
+ }
+}
diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala
new file mode 100644
index 00000000..991e521e
--- /dev/null
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2018 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.cobrix.spark.cobol.source.fixtures
+
+import org.scalatest.{Assertion, Suite}
+
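+/**
+ * A test fixture for comparing multi-line text. Line breaks are ignored when comparing strings;
+ * when a comparison fails, the actual and expected text are rendered side by side with differing lines marked.
+ */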
+trait TextComparisonFixture {
+ this: Suite =>
+
+ protected def compareText(actual: String, expected: String): Assertion = {
+ if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) {
+ fail(renderTextDifference(actual, expected))
+ } else {
+ succeed
+ }
+ }
+
+ protected def compareTextVertical(actual: String, expected: String): Unit = {
+ if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) {
+ fail(s"ACTUAL:\n$actual\nEXPECTED: \n$expected")
+ }
+ }
+
+ protected def renderTextDifference(textActual: String, textExpected: String): String = {
+ val t1 = textActual.replace("\r\n", "\n").split('\n')
+ val t2 = textExpected.replace("\r\n", "\n").split('\n')
+
+ val maxLen = Math.max(getMaxStrLen(t1), getMaxStrLen(t2))
+ val header = s" ${rightPad("ACTUAL:", maxLen)} ${rightPad("EXPECTED:", maxLen)}\n"
+
+ val stringBuilder = new StringBuilder
+ stringBuilder.append(header)
+
+ val linesCount = Math.max(t1.length, t2.length)
+ var i = 0
+ while (i < linesCount) {
+ val a = if (i < t1.length) t1(i) else ""
+ val b = if (i < t2.length) t2(i) else ""
+
+ val marker1 = if (a != b) ">" else " "
+ val marker2 = if (a != b) "<" else " "
+
+ val comparisonText = s"$marker1${rightPad(a, maxLen)} ${rightPad(b, maxLen)}$marker2\n"
+ stringBuilder.append(comparisonText)
+
+ i += 1
+ }
+
+ val footer = s"\nACTUAL:\n$textActual"
+ stringBuilder.append(footer)
+ stringBuilder.toString()
+ }
+
+ def getMaxStrLen(text: Seq[String]): Int = {
+ if (text.isEmpty) {
+ 0
+ } else {
+ text.maxBy(_.length).length
+ }
+ }
+
+ def rightPad(s: String, length: Int): String = {
+ if (s.length < length) {
+ s + " " * (length - s.length)
+ } else if (s.length > length) {
+ s.take(length)
+ } else {
+ s
+ }
+ }
+}
diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
index 3455745a..51320b61 100644
--- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
@@ -16,17 +16,17 @@
package za.co.absa.cobrix.spark.cobol.utils
-import org.apache.spark.sql.types.{ArrayType, LongType, MetadataBuilder, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.scalatest.funsuite.AnyFunSuite
-import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
import org.slf4j.LoggerFactory
-import za.co.absa.cobrix.spark.cobol.source.fixtures.BinaryFileFixture
+import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
+import za.co.absa.cobrix.spark.cobol.source.fixtures.{BinaryFileFixture, TextComparisonFixture}
import za.co.absa.cobrix.spark.cobol.utils.TestUtils._
import java.nio.charset.StandardCharsets
-import scala.collection.immutable
+import scala.util.Properties
-class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixture {
+class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixture with TextComparisonFixture {
import spark.implicits._
@@ -377,7 +377,7 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
assert(dfFlattened.count() == 0)
}
- test("Schema with multiple OCCURS should properly determine array sized") {
+ test("Schema with multiple OCCURS should properly determine array sizes") {
val copyBook: String =
""" 01 RECORD.
| 02 COUNT PIC 9(1).
@@ -429,6 +429,46 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
}
}
+ test("Integral to decimal conversion for complex schema") {
+ val expectedSchema =
+ """|root
+ | |-- COUNT: decimal(1,0) (nullable = true)
+ | |-- GROUP: array (nullable = true)
+ | | |-- element: struct (containsNull = false)
+ | | | |-- INNER_COUNT: decimal(1,0) (nullable = true)
+ | | | |-- INNER_GROUP: array (nullable = true)
+ | | | | |-- element: struct (containsNull = false)
+ | | | | | |-- FIELD: decimal(1,0) (nullable = true)
+ |""".stripMargin
+
+ val copyBook: String =
+ """ 01 RECORD.
+ | 02 COUNT PIC 9(1).
+ | 02 GROUP OCCURS 2 TIMES.
+ | 03 INNER-COUNT PIC S9(1).
+ | 03 INNER-GROUP OCCURS 3 TIMES.
+ | 04 FIELD PIC 9.
+ |""".stripMargin
+
+ withTempTextFile("flatten", "test", StandardCharsets.UTF_8, "") { filePath =>
+ val df = spark.read
+ .format("cobol")
+ .option("copybook_contents", copyBook)
+ .option("pedantic", "true")
+ .option("record_format", "D")
+ .option("metadata", "extended")
+ .load(filePath)
+
+ if (!Properties.versionString.contains("2.11")) {
+ // This method only works with Scala 2.12+ and Spark 3.0+
+ val actualDf = SparkUtils.covertIntegralToDecimal(df)
+ val actualSchema = actualDf.schema.treeString
+
+ compareText(actualSchema, expectedSchema)
+ }
+ }
+ }
+
private def assertSchema(actualSchema: String, expectedSchema: String): Unit = {
if (actualSchema != expectedSchema) {
logger.error(s"EXPECTED:\n$expectedSchema")