Skip to content

Commit

Permalink
#156 Add sum of truncated values as measure
Browse files: browse the repository at this point in the history
  • Loading branch information
Zejnilovic committed Oct 11, 2024
1 parent 9609081 commit d5df5a0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
5 changes: 4 additions & 1 deletion atum/src/main/scala/za/co/absa/atum/core/ControlType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ object ControlType {
case object DistinctCount extends ControlType("distinctCount", false)
case object AggregatedTotal extends ControlType("aggregatedTotal", true)
case object AbsAggregatedTotal extends ControlType("absAggregatedTotal", true)
case object AggregatedTruncTotal extends ControlType("aggregatedTruncTotal", true)
case object AbsAggregatedTruncTotal extends ControlType("absAggregatedTruncTotal", true)
case object HashCrc32 extends ControlType("hashCrc32", false)

val values: Seq[ControlType] = Seq(Count, DistinctCount, AggregatedTotal, AbsAggregatedTotal, HashCrc32)
val values: Seq[ControlType] = Seq(Count, DistinctCount, AggregatedTotal, AbsAggregatedTotal,
AggregatedTruncTotal, AbsAggregatedTruncTotal, HashCrc32)
val valueNames: Seq[String] = values.map(_.value)

def getNormalizedValueName(input: String): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ object MeasurementProcessor {
.agg(sum(col(aggColName))).collect()(0)(0)
if (v == null) "" else v.toString
}
case AggregatedTruncTotal =>
(ds: Dataset[Row]) => {
val aggCol = sum(col(valueColumnName).cast(LongType))
aggregateColumn(ds, controlCol, aggCol)
}
case AbsAggregatedTruncTotal =>
(ds: Dataset[Row]) => {
val aggCol = sum(abs(col(valueColumnName).cast(LongType)))
aggregateColumn(ds, controlCol, aggCol)
}
}
}

Expand Down
36 changes: 33 additions & 3 deletions atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa
)
))

val measurementsIntOverflow = List(
val measurementsIntOverflow: Seq[Measurement] = List(
Measurement(
controlName = "RecordCount",
controlType = ControlType.Count.value,
Expand Down Expand Up @@ -112,7 +112,7 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa
assert(newMeasurements == measurementsIntOverflow)
}

val measurementsAggregation = List(
val measurementsAggregation: Seq[Measurement] = List(
Measurement(
controlName = "RecordCount",
controlType = ControlType.Count.value,
Expand Down Expand Up @@ -304,7 +304,7 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa
assert(newMeasurements == measurements3)
}

val measurementsWithHash = List(
val measurementsWithHash: Seq[Measurement] = List(
Measurement(
controlName = "RecordCount",
controlType = ControlType.Count.value,
Expand Down Expand Up @@ -394,4 +394,34 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa
assert(newMeasurements == measurementsAggregationShort)
}

// Expected measurements for the two truncating aggregations over the "price" column.
// The control types come from the ControlType constants, as in the other fixtures of
// this spec, so a rename of a control type cannot silently desynchronise the test.
//
// With the three prices used in the test below (-1000.000001, 1000.9, 999.999999),
// dropping the fractional part of each value gives -1000, 1000 and 999:
//   aggregatedTruncTotal    = -1000 + 1000 + 999     = 999
//   absAggregatedTruncTotal = |-1000| + 1000 + 999   = 2999
val measurementsAggregatedTruncTotal: Seq[Measurement] = List(
  Measurement(
    controlName = "aggregatedTruncTotal",
    controlType = ControlType.AggregatedTruncTotal.value,
    controlCol = "price",
    controlValue = "999"
  ),
  Measurement(
    controlName = "absAggregatedTruncTotal",
    controlType = ControlType.AbsAggregatedTruncTotal.value,
    controlCol = "price",
    controlValue = "2999"
  )
)

"aggregatedTruncTotal types" should "return truncated sum of values" in {
  // Every price carries a fractional part, so a plain (non-truncating) sum would not
  // produce the whole numbers expected by the fixture above.
  // NOTE(review): assumes `schema` reads "price" as a fractional type - confirm.
  val jsonRows = List(
    s"""{"id": ${Long.MaxValue}, "price": -1000.000001, "order": { "orderid": 1, "items": 1 } } """,
    s"""{"id": ${Long.MinValue}, "price": 1000.9, "order": { "orderid": -1, "items": -1 } } """,
    s"""{"id": ${Long.MinValue}, "price": 999.999999, "order": { "orderid": -1, "items": -1 } } """
  )
  val inputDataJson = spark.sparkContext.parallelize(jsonRows)
  val df = spark.read
    .schema(schema)
    .json(inputDataJson.toDS)

  // The processor recomputes each control value from the data; the result must equal
  // the fixture, including the pre-computed control values.
  val processor = new MeasurementProcessor(measurementsAggregatedTruncTotal)
  val newMeasurements = processor.measureDataset(df)

  assert(newMeasurements == measurementsAggregatedTruncTotal)
}

}

0 comments on commit d5df5a0

Please sign in to comment.