From c7aeab99163e75e7d9298d58806383634a6b1c4d Mon Sep 17 00:00:00 2001
From: Doe-Ed <63307024+Doe-Ed@users.noreply.github.com>
Date: Mon, 18 May 2020 17:13:57 +0300
Subject: [PATCH] fix(SelectiveMerge):always override with df2's column (if
 exists) (#322)

* fix(SelectiveMerge):always override with df2's column (if exists)

* fix(SelectiveMerge):Favor not favour
---
 .../metorikku/code/steps/SelectiveMerge.scala | 64 +++++++++----------
 .../code/steps/test/SelectiveMergeTests.scala |  9 +--
 2 files changed, 35 insertions(+), 38 deletions(-)

diff --git a/src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala b/src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala
index 88d4cd7a2..bce47d5a4 100644
--- a/src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala
+++ b/src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala
@@ -5,14 +5,13 @@ import org.apache.log4j.{LogManager, Logger}
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.{Column, DataFrame}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.StructType
 
 object SelectiveMerge {
   private val message = "You need to send 3 parameters with the names of the dataframes to merge and the key(s) to merge on" +
-    "(merged df1 into df2 favouring values from df2): df1, df2, Seq[String]"
+    "(merged df1 into df2 favoring values from df2): df1, df2, Seq[String]"
   private val log: Logger = LogManager.getLogger(this.getClass)
-  private val colRenamePrefix = "df2_"
-  private val colRenamePrefixLen = colRenamePrefix.length
+  private val colRenameSuffixLength = 10000 // (5 digits)
+  private val colRenamePrefix = scala.util.Random.nextInt(colRenameSuffixLength).toString
   private class InputMatcher[K](ks: K*) {
     def unapplySeq[V](m: Map[K, V]): Option[Seq[V]] = if (ks.forall(m.contains)) Some(ks.map(m)) else None
   }
@@ -37,6 +36,7 @@ object SelectiveMerge {
           df1.createOrReplaceTempView(dataFrameName)
         }
         else {
+          logOverrides(df1, df2, joinKeys)
           merge(df1, df2, joinKeys).createOrReplaceTempView(dataFrameName)
         }
       }
@@ -44,6 +44,19 @@ object SelectiveMerge {
     }
   }
 
+  def logOverrides(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]): Unit = {
+    val df1SchemaTitles = df1.schema.map(f => f.name).toList
+    val df2SchemaTitles = df2.schema.map(f => f.name).toList
+
+    val overridenColumns = df2SchemaTitles.filter(p => df1SchemaTitles.contains(p) && !joinKeys.contains(p))
+    val df1OnlyColumns = df1SchemaTitles diff df2SchemaTitles
+    val df2OnlyColumns = df2SchemaTitles diff df1SchemaTitles
+
+    log.info("DF1 columns which will be overriden: " + overridenColumns)
+    log.info("DF1 columns which are not found in DF2: " + df1OnlyColumns)
+    log.info("DF2 columns which are not found in DF1: " + df2OnlyColumns)
+  }
+
   def merge(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]): DataFrame = {
     val mergedDf = outerJoinWithAliases(df1, df2, joinKeys)
     overrideConflictingValues(df1, df2, mergedDf, joinKeys)
@@ -62,17 +75,6 @@ object SelectiveMerge {
     ).join(df1, joinKeys,"outer")
   }
 
-  def isMemberOfDf1(schema: StructType, name: String): Boolean = {
-    val schemaNames = schema.map(f => f.name)
-
-    if (name.startsWith(colRenamePrefix)) {
-      schemaNames.contains(name.substring(colRenamePrefixLen))
-    }
-    else {
-      false
-    }
-  }
-
   def getMergedSchema(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]): Seq[Column] = {
     val mergedSchemaNames = (df1.schema.map(f => f.name) ++ df2.schema.map(f => f.name)).distinct
 
@@ -88,36 +90,30 @@ object SelectiveMerge {
     mergedSchema
   }
 
-
   def overrideConflictingValues(df1: DataFrame, df2: DataFrame, mergedDf: DataFrame, joinKeys: Seq[String]): DataFrame = {
     val mergedSchema = getMergedSchema(df1, df2, joinKeys)
 
-    val mergedDfBuilder = mergedDf.select(
+    mergedDf.select(
       mergedSchema.map{
         case (currColumn: Column) => {
           val colName = currColumn.expr.asInstanceOf[NamedExpression].name
-          // Column is a part of df1 and doesn't belong to the join keys.
-          if (isMemberOfDf1(df1.schema, colName) && !joinKeys.contains(colName)) {
-            val origColName = colName.substring(colRenamePrefixLen)
-            when(mergedDf(colName).isNotNull, mergedDf(colName).cast(df1.schema(origColName).dataType))
-              .otherwise(df1(origColName))
-              .alias(origColName)
+          val colNameArr = colName.split(colRenamePrefix)
+          val colNameOrig = if (colNameArr.size > 1) colNameArr(1) else colName
+
+          // Belongs to DF2, override.
+          if (colNameArr.size > 1) {
+            mergedDf(colName).alias(colNameOrig)
+          }
+          // Is the join key(s)
+          else if (joinKeys.contains(colName)) {
+            mergedDf(colName)
           }
-          // Column doesn't belong to df1 or is join key
+          // Only exists in DF1.
          else {
-            // Column doesn't belong to df1
-            if (colName.startsWith(colRenamePrefix)) {
-              currColumn.alias(colName.substring(colRenamePrefixLen))
-            }
-            // Column is the merge key
-            else {
-              currColumn
-            }
+            df1(colName)
           }
         }
       }: _*
     )
-
-    mergedDfBuilder
   }
 }
diff --git a/src/test/scala/com/yotpo/metorikku/code/steps/test/SelectiveMergeTests.scala b/src/test/scala/com/yotpo/metorikku/code/steps/test/SelectiveMergeTests.scala
index 3496b2386..99dad6a1b 100644
--- a/src/test/scala/com/yotpo/metorikku/code/steps/test/SelectiveMergeTests.scala
+++ b/src/test/scala/com/yotpo/metorikku/code/steps/test/SelectiveMergeTests.scala
@@ -4,6 +4,7 @@ import com.yotpo.metorikku.code.steps.SelectiveMerge
 import com.yotpo.metorikku.code.steps.SelectiveMerge.merge
 import com.yotpo.metorikku.exceptions.MetorikkuException
 import org.apache.log4j.{Level, LogManager, Logger}
+import org.apache.spark
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
 import org.scalatest.{FunSuite, _}
@@ -74,7 +75,7 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach {
     val simpleDataExpectedAfterMerge = Seq(
       ("James", new Integer(1) /* Salary */, new Integer(33) /* age */, new Integer(111) /* fake */,
         new Integer(1111) /* fake2 */, new Integer(333) /* bonus */),
-      ("Maria", new Integer(2) /* Salary */, new Integer(22) /* age */, new Integer(222) /* fake */,
+      ("Maria", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */, new Integer(222) /* fake */,
         new Integer(2222) /* fake2 */, null.asInstanceOf[Integer] /* bonus */),
       ("Jen", new Integer(4) /* Salary */, new Integer(44) /* age */, null.asInstanceOf[Integer] /* fake */,
         null.asInstanceOf[Integer] /* fake2 */, new Integer(444) /* bonus */),
@@ -122,7 +123,7 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach {
     val simpleDataExpectedAfterMerge = Seq(
       ("James", "Sharon" /* Last Name */, new Integer(1) /* Salary */, new Integer(33) /* age */,
         new Integer(111) /* fake */, new Integer(1111) /* fake2 */, new Integer(333) /* bonus */),
-      ("Maria", "Bob" /* Last Name */, null.asInstanceOf[Integer] /* Salary */, new Integer(22) /* age */,
+      ("Maria", "Bob" /* Last Name */, null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */,
         new Integer(222) /* fake */, new Integer(2222) /* fake2 */, null.asInstanceOf[Integer] /* bonus */),
       ("Jen", null.asInstanceOf[String] /* Last Name */, new Integer(4) /* Salary */, new Integer(44) /* age */,
         null.asInstanceOf[Integer] /* fake */, null.asInstanceOf[Integer] /* fake2 */, new Integer(444) /* bonus */),
@@ -157,11 +158,11 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach {
     val simpleDataExpectedAfterMerge = Seq(
      ("James", new Integer(10) /* Salary */, new Integer(33) /* age */,
         new Integer(333) /* Bonus */, new Integer(3333) /* fake */),
-      ("Maria", new Integer(2) /* Salary */, new Integer(22) /* age */,
+      ("Maria", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */,
         null.asInstanceOf[Integer] /* Bonus */, null.asInstanceOf[Integer] /* fake */),
       ("Jen", new Integer(4) /* Salary */, new Integer(44) /* age */,
         new Integer(444) /* Bonus */, new Integer(4444) /* fake */),
-      ("Albert",new Integer(3) /* Salary */, new Integer(33) /* age */,
+      ("Albert", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */,
         null.asInstanceOf[Integer] /* Bonus */, null.asInstanceOf[Integer] /* fake */)
     )
     val expectedDf = simpleDataExpectedAfterMerge.toDF("employee_name", "salary", "age", "bonus", "fake")
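
A minimal usage sketch of the behavior after this patch (not part of the patch itself), assuming a local SparkSession and the public SelectiveMerge.merge helper changed above: for any non-key column that appears in both dataframes, df2's value now always overrides df1's, so a row that is missing from df2 ends up with null in that column instead of falling back to df1's value, while columns that exist only in df1 are kept as-is.

    // Hypothetical, self-contained example; the object and app names are illustrative only.
    import com.yotpo.metorikku.code.steps.SelectiveMerge.merge
    import org.apache.spark.sql.SparkSession

    object SelectiveMergeExample extends App {
      val spark = SparkSession.builder()
        .master("local[*]")
        .appName("selective-merge-example")
        .getOrCreate()
      import spark.implicits._

      // "salary" exists in both frames; "age" exists only in df1.
      val df1 = Seq(("James", 1, 33), ("Maria", 2, 22)).toDF("employee_name", "salary", "age")
      val df2 = Seq(("James", 10)).toDF("employee_name", "salary")

      // James: salary is overridden by df2 (10), age is kept from df1 (33).
      // Maria: absent from df2, so salary becomes null (no fallback to df1's 2); age stays 22.
      merge(df1, df2, Seq("employee_name")).show()

      spark.stop()
    }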