Skip to content

Commit

Permalink
#683 Add support for '_' for hierarchical key generation at leaf level.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Jul 16, 2024
1 parent 4872933 commit 6368a8b
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 5 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,23 @@ val df = spark
.load("examples/multisegment_data/COMP.DETAILS.SEP30.DATA.dat")
```

Sometimes, the leaf level has many segments. In this case, you can use `_` as the segment id list to mean
'all remaining segment ids', like this:

```scala
val df = spark
.read
.format("cobol")
.option("copybook_contents", copybook)
.option("record_format", "V")
.option("segment_field", "SEGMENT_ID")
.option("segment_id_level0", "C")
.option("segment_id_level1", "_")
.load("examples/multisegment_data/COMP.DETAILS.SEP30.DATA.dat")
```

Both of the code snippets above produce the same result.

The resulting table will look like this:
```
df.show(10)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

package za.co.absa.cobrix.cobol.reader.iterator

import za.co.absa.cobrix.cobol.reader.parameters.ParameterParsingUtils

final class SegmentIdAccumulator (segmentIds: scala.collection.Seq[String], segmentIdPrefix: String, val fileId: Int) {
private val segmentIdsArr = segmentIds.toArray.map(_.split(","))
private val segmentIdsArr = ParameterParsingUtils.splitSegmentIds(segmentIds)

private val segmentIdCount = segmentIds.size
private val segmentIdAccumulator = new Array[Long](segmentIdCount + 1)
private var currentLevel = -1
Expand Down Expand Up @@ -77,7 +80,7 @@ final class SegmentIdAccumulator (segmentIds: scala.collection.Seq[String], segm
var level: Option[Int] = None
var i = 0
while (level.isEmpty && i<segmentIdCount) {
if (segmentIdsArr(i).contains(id))
if (segmentIdsArr(i).contains(id) || segmentIdsArr(i).contains("_"))
level = Some(i)
i += 1
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.cobrix.cobol.reader.parameters

/** Helper routines for parsing the segment-id options used in hierarchical record id generation. */
object ParameterParsingUtils {
  /** Splits segment ids defined in spark-cobol options for hierarchical id generation.
    *
    * Each input element is a comma-separated list of ids. Every id is trimmed and the
    * '*' wildcard is normalized to '_' so downstream code only has to handle one wildcard form.
    */
  def splitSegmentIds(segmentIdsToSplit: scala.collection.Seq[String]): Array[Array[String]] = {
    // Trim a single id and canonicalize the '*' wildcard to '_'.
    def normalize(rawId: String): String = {
      val trimmed = rawId.trim()
      if (trimmed == "*") "_" else trimmed
    }

    segmentIdsToSplit
      .map(_.split(',').map(normalize))
      .toArray
  }

  /** Validates segment ids for hierarchical record id generation.
    *
    * A wildcard id ('_' or '*') is only allowed at the leaf level, and must be the
    * only id at its level; otherwise an IllegalArgumentException is thrown.
    * NOTE(review): when input comes through [[splitSegmentIds]], '*' has already been
    * normalized to '_', so the '*' branch only fires for ids split elsewhere.
    */
  def validateSegmentIds(segmentIds: Array[Array[String]]): Unit = {
    val leafLevel = segmentIds.length - 1
    for ((ids, level) <- segmentIds.zipWithIndex) {
      // '_' is checked before '*' to keep the original error-message precedence.
      Seq("_", "*").foreach { wildcard =>
        if (ids.contains(wildcard) && level < leafLevel)
          throw new IllegalArgumentException(s"The '$wildcard' as a segment id can only be used on the leaf level (segment_id_level$leafLevel), found at 'segment_id_level$level'")
      }
      if ((ids.contains("*") || ids.contains("_")) && ids.length > 1)
        throw new IllegalArgumentException(s"The '*' or '_' as a segment id cannot be used with other ids 'segment_id_level$level = ${ids.mkString(",")}' is incorrect.")
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.cobrix.cobol.reader.parameters

import org.scalatest.wordspec.AnyWordSpec

class ParameterParsingUtilsSuite extends AnyWordSpec {
  "splitSegmentIds()" should {
    "split input segment ids" in {
      val result = ParameterParsingUtils.splitSegmentIds(Seq("A,B,C", "D,E,F"))

      assert(result.length == 2)
      assert(result(0).sameElements(Array("A", "B", "C")))
      assert(result(1).sameElements(Array("D", "E", "F")))
    }

    "trim if split with spaces" in {
      // Whitespace around each id must be stripped during the split.
      val result = ParameterParsingUtils.splitSegmentIds(Seq("A, B, C", "D, E, F"))

      assert(result.length == 2)
      assert(result(0).sameElements(Array("A", "B", "C")))
      assert(result(1).sameElements(Array("D", "E", "F")))
    }

    "handle empty strings" in {
      // An empty option value yields a single empty-string id, not an empty array.
      val result = ParameterParsingUtils.splitSegmentIds(Seq("", ""))

      assert(result.length == 2)
      assert(result(0).head == "")
      assert(result(1).head == "")
    }
  }

  "validateSegmentIds()" should {
    "validate segment ids" in {
      // Plain ids on every level are accepted without throwing.
      ParameterParsingUtils.validateSegmentIds(
        Array(
          Array("A", "B", "C"),
          Array("D", "E", "F")
        )
      )
    }

    "throw an exception if '_' is used on the wrong level" in {
      val ex = intercept[IllegalArgumentException] {
        ParameterParsingUtils.validateSegmentIds(
          Array(
            Array("_"),
            Array("A", "B", "C")
          )
        )
      }

      assert(ex.getMessage.contains("The '_' as a segment id can only be used on the leaf level (segment_id_level1), found at 'segment_id_level0'"))
    }

    "throw an exception if '*' is used on the wrong level" in {
      val ex = intercept[IllegalArgumentException] {
        ParameterParsingUtils.validateSegmentIds(
          Array(
            Array("A"),
            Array("B"),
            Array("*"),
            Array("C")
          )
        )
      }

      assert(ex.getMessage.contains("The '*' as a segment id can only be used on the leaf level (segment_id_level3), found at 'segment_id_level2'"))
    }

    "throw an exception if '*' or '_' is used with other ids" in {
      val ex = intercept[IllegalArgumentException] {
        ParameterParsingUtils.validateSegmentIds(
          Array(
            Array("A", "B", "C"),
            Array("D", "*", "F", "G")
          )
        )
      }

      assert(ex.getMessage.contains("'*' or '_' as a segment id cannot be used with other ids"))
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -551,15 +551,19 @@ object CobolParametersParser extends Logging {
private def parseMultisegmentParameters(params: Parameters): Option[MultisegmentParameters] = {
if (params.contains(PARAM_SEGMENT_FIELD)) {
val levels = parseSegmentLevels(params)
Some(MultisegmentParameters
(
val multiseg = MultisegmentParameters(
params(PARAM_SEGMENT_FIELD),
params.get(PARAM_SEGMENT_FILTER).map(_.split(',')),
levels,
params.getOrElse(PARAM_SEGMENT_ID_PREFIX, ""),
getSegmentIdRedefineMapping(params),
getSegmentRedefineParents(params)
))
)

val segmentIds = ParameterParsingUtils.splitSegmentIds(multiseg.segmentLevelIds)
ParameterParsingUtils.validateSegmentIds(segmentIds)

Some(multiseg)
}
else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,72 @@ class Test17HierarchicalSpec extends AnyWordSpec with SparkTestBase with CobolTe

testData(actualDf, actualResultsPath, expectedResultsPath)
}

// Verifies hierarchical record id generation when the leaf level (segment_id_level2)
// uses the '_' wildcard instead of an explicit list of the remaining segment ids.
"return a dataframe with ids generated when _ notation is used" in {
val df = spark
.read
.format("cobol")
.option("copybook", inputCopybookPath)
.option("pedantic", "true")
.option("record_format", "V")
.option("generate_record_id", "true")
.option("schema_retention_policy", "collapse_root")
.option("segment_field", "SEGMENT_ID")
.option("segment_id_level0", "1")
.option("segment_id_level1", "2,5")
// '_' matches every segment id not claimed by levels 0 and 1.
.option("segment_id_level2", "_")
.option("segment_id_prefix", "A")
.option("redefine_segment_id_map:1", "COMPANY => 1")
.option("redefine-segment-id-map:2", "DEPT => 2")
.option("redefine-segment-id-map:3", "EMPLOYEE => 3")
.option("redefine-segment-id-map:4", "OFFICE => 4")
.option("redefine-segment-id-map:5", "CUSTOMER => 5")
.option("redefine-segment-id-map:6", "CONTACT => 6")
.option("redefine-segment-id-map:7", "CONTRACT => 7")
.load(inputDataPath)

// NOTE(review): compares against the same expected schema/results files as the
// explicit-id variant of this test — the wildcard must not change the output.
testSchema(df, actualSchemaPath, expectedSchemaPath)

// Deterministic ordering so the JSON comparison is stable across runs.
val actualDf = df
.orderBy("File_Id", "Record_Id")
.toJSON
.take(300)

testData(actualDf, actualResultsPath, expectedResultsPath)
}

// Verifies that the '*' wildcard at the leaf level behaves identically to '_'
// ('*' is normalized to '_' when the options are parsed).
"return a dataframe with ids generated when * notation is used" in {
val df = spark
.read
.format("cobol")
.option("copybook", inputCopybookPath)
.option("pedantic", "true")
.option("record_format", "V")
.option("generate_record_id", "true")
.option("schema_retention_policy", "collapse_root")
.option("segment_field", "SEGMENT_ID")
.option("segment_id_level0", "1")
.option("segment_id_level1", "2,5")
// '*' is the alternative spelling of the leaf-level wildcard.
.option("segment_id_level2", "*")
.option("segment_id_prefix", "A")
.option("redefine_segment_id_map:1", "COMPANY => 1")
.option("redefine-segment-id-map:2", "DEPT => 2")
.option("redefine-segment-id-map:3", "EMPLOYEE => 3")
.option("redefine-segment-id-map:4", "OFFICE => 4")
.option("redefine-segment-id-map:5", "CUSTOMER => 5")
.option("redefine-segment-id-map:6", "CONTACT => 6")
.option("redefine-segment-id-map:7", "CONTRACT => 7")
.load(inputDataPath)

// NOTE(review): reuses the expected schema/results of the '_' and explicit-id
// variants — all three notations must produce identical output.
testSchema(df, actualSchemaPath, expectedSchemaPath)

// Deterministic ordering so the JSON comparison is stable across runs.
val actualDf = df
.orderBy("File_Id", "Record_Id")
.toJSON
.take(300)

testData(actualDf, actualResultsPath, expectedResultsPath)
}
}

"read as a hierarchical file with parent child relationships defined" should {
Expand Down

0 comments on commit 6368a8b

Please sign in to comment.