Skip to content

Commit

Permalink
fix(SelectiveMerge): always override with df2's column (if it exists) (#322)
Browse files Browse the repository at this point in the history
* fix(SelectiveMerge):always override with df2's column (if exists)

* fix(SelectiveMerge):Favor not favour
  • Loading branch information
Doe-Ed authored May 18, 2020
1 parent f077829 commit c7aeab9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 38 deletions.
64 changes: 30 additions & 34 deletions src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType

object SelectiveMerge {
private val message = "You need to send 3 parameters with the names of the dataframes to merge and the key(s) to merge on" +
"(merged df1 into df2 favouring values from df2): df1, df2, Seq[String]"
"(merged df1 into df2 favoring values from df2): df1, df2, Seq[String]"
private val log: Logger = LogManager.getLogger(this.getClass)
private val colRenamePrefix = "df2_"
private val colRenamePrefixLen = colRenamePrefix.length
private val colRenameSuffixLength = 10000 // (5 digits)
private val colRenamePrefix = scala.util.Random.nextInt(colRenameSuffixLength).toString
private class InputMatcher[K](ks: K*) {
  // Extractor that succeeds only when every requested key is present in the
  // map, yielding the corresponding values in the order the keys were given.
  def unapplySeq[V](m: Map[K, V]): Option[Seq[V]] = {
    val allKeysPresent = ks.forall(key => m.contains(key))
    if (allKeysPresent) Some(ks.map(key => m(key))) else None
  }
}
Expand All @@ -37,13 +36,27 @@ object SelectiveMerge {
df1.createOrReplaceTempView(dataFrameName)
}
else {
logOverrides(df1, df2, joinKeys)
merge(df1, df2, joinKeys).createOrReplaceTempView(dataFrameName)
}
}
case _ => throw MetorikkuException(message)
}
}

/**
  * Logs, for operator visibility before the merge, which df1 columns will be
  * overridden by df2's values and which columns exist on only one side.
  *
  * @param df1      base DataFrame (its conflicting values will be replaced)
  * @param df2      overriding DataFrame (its values win on shared columns)
  * @param joinKeys columns used to join; never reported as overridden
  */
def logOverrides(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]): Unit = {
  val df1SchemaTitles = df1.schema.map(f => f.name)
  val df2SchemaTitles = df2.schema.map(f => f.name)
  // Set lookup avoids a linear scan of df1's columns for every df2 column.
  val df1TitleSet = df1SchemaTitles.toSet

  // Shared non-key columns take df2's value after the merge.
  val overriddenColumns = df2SchemaTitles.filter(p => df1TitleSet.contains(p) && !joinKeys.contains(p))
  val df1OnlyColumns = df1SchemaTitles diff df2SchemaTitles
  val df2OnlyColumns = df2SchemaTitles diff df1SchemaTitles

  // "overriden" typo fixed to "overridden" in the emitted log message.
  log.info("DF1 columns which will be overridden: " + overriddenColumns)
  log.info("DF1 columns which are not found in DF2: " + df1OnlyColumns)
  log.info("DF2 columns which are not found in DF1: " + df2OnlyColumns)
}

def merge(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]): DataFrame = {
val mergedDf = outerJoinWithAliases(df1, df2, joinKeys)
overrideConflictingValues(df1, df2, mergedDf, joinKeys)
Expand All @@ -62,17 +75,6 @@ object SelectiveMerge {
).join(df1, joinKeys,"outer")
}

// Returns true when `name` is a prefix-renamed column (carries colRenamePrefix)
// whose original name — prefix stripped — exists in the supplied schema.
// NOTE(review): relies on colRenamePrefix/colRenamePrefixLen declared elsewhere
// in this object; confirm those constants still exist before reusing this.
def isMemberOfDf1(schema: StructType, name: String): Boolean = {
  val schemaNames = schema.map(f => f.name)

  if (name.startsWith(colRenamePrefix)) {
    // Strip the rename prefix and look the bare column name up in df1's schema.
    schemaNames.contains(name.substring(colRenamePrefixLen))
  }
  else {
    // Without the prefix the column cannot be a renamed df1 collision column.
    false
  }
}

def getMergedSchema(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]): Seq[Column] = {
val mergedSchemaNames = (df1.schema.map(f => f.name) ++ df2.schema.map(f => f.name)).distinct

Expand All @@ -88,36 +90,30 @@ object SelectiveMerge {
mergedSchema
}


def overrideConflictingValues(df1: DataFrame, df2: DataFrame, mergedDf: DataFrame, joinKeys: Seq[String]): DataFrame = {
val mergedSchema = getMergedSchema(df1, df2, joinKeys)

val mergedDfBuilder = mergedDf.select(
mergedDf.select(
mergedSchema.map{
case (currColumn: Column) => {
val colName = currColumn.expr.asInstanceOf[NamedExpression].name
// Column is a part of df1 and doesn't belong to the join keys.
if (isMemberOfDf1(df1.schema, colName) && !joinKeys.contains(colName)) {
val origColName = colName.substring(colRenamePrefixLen)
when(mergedDf(colName).isNotNull, mergedDf(colName).cast(df1.schema(origColName).dataType))
.otherwise(df1(origColName))
.alias(origColName)
val colNameArr = colName.split(colRenamePrefix)
val colNameOrig = if (colNameArr.size > 1) colNameArr(1) else colName

// Belongs to DF2, override.
if (colNameArr.size > 1) {
mergedDf(colName).alias(colNameOrig)
}
// Is the join key(s)
else if (joinKeys.contains(colName)) {
mergedDf(colName)
}
// Column doesn't belong to df1 or is join key
// Only exists in DF1.
else {
// Column doesn't belong to df1
if (colName.startsWith(colRenamePrefix)) {
currColumn.alias(colName.substring(colRenamePrefixLen))
}
// Column is the merge key
else {
currColumn
}
df1(colName)
}
}
}: _*
)

mergedDfBuilder
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.yotpo.metorikku.code.steps.SelectiveMerge
import com.yotpo.metorikku.code.steps.SelectiveMerge.merge
import com.yotpo.metorikku.exceptions.MetorikkuException
import org.apache.log4j.{Level, LogManager, Logger}
import org.apache.spark
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.scalatest.{FunSuite, _}
Expand Down Expand Up @@ -74,7 +75,7 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach {
val simpleDataExpectedAfterMerge = Seq(
("James", new Integer(1) /* Salary */, new Integer(33) /* age */, new Integer(111) /* fake */,
new Integer(1111) /* fake2 */, new Integer(333) /* bonus */),
("Maria", new Integer(2) /* Salary */, new Integer(22) /* age */, new Integer(222) /* fake */,
("Maria", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */, new Integer(222) /* fake */,
new Integer(2222) /* fake2 */, null.asInstanceOf[Integer] /* bonus */),
("Jen", new Integer(4) /* Salary */, new Integer(44) /* age */, null.asInstanceOf[Integer] /* fake */,
null.asInstanceOf[Integer] /* fake2 */, new Integer(444) /* bonus */),
Expand Down Expand Up @@ -122,7 +123,7 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach {
val simpleDataExpectedAfterMerge = Seq(
("James", "Sharon" /* Last Name */, new Integer(1) /* Salary */, new Integer(33) /* age */,
new Integer(111) /* fake */, new Integer(1111) /* fake2 */, new Integer(333) /* bonus */),
("Maria", "Bob" /* Last Name */, null.asInstanceOf[Integer] /* Salary */, new Integer(22) /* age */,
("Maria", "Bob" /* Last Name */, null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */,
new Integer(222) /* fake */, new Integer(2222) /* fake2 */, null.asInstanceOf[Integer] /* bonus */),
("Jen", null.asInstanceOf[String] /* Last Name */, new Integer(4) /* Salary */, new Integer(44) /* age */,
null.asInstanceOf[Integer] /* fake */, null.asInstanceOf[Integer] /* fake2 */, new Integer(444) /* bonus */),
Expand Down Expand Up @@ -157,11 +158,11 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach {
val simpleDataExpectedAfterMerge = Seq(
("James", new Integer(10) /* Salary */, new Integer(33) /* age */,
new Integer(333) /* Bonus */, new Integer(3333) /* fake */),
("Maria", new Integer(2) /* Salary */, new Integer(22) /* age */,
("Maria", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */,
null.asInstanceOf[Integer] /* Bonus */, null.asInstanceOf[Integer] /* fake */),
("Jen", new Integer(4) /* Salary */, new Integer(44) /* age */,
new Integer(444) /* Bonus */, new Integer(4444) /* fake */),
("Albert",new Integer(3) /* Salary */, new Integer(33) /* age */,
("Albert", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */,
null.asInstanceOf[Integer] /* Bonus */, null.asInstanceOf[Integer] /* fake */)
)
val expectedDf = simpleDataExpectedAfterMerge.toDF("employee_name", "salary", "age", "bonus", "fake")
Expand Down

0 comments on commit c7aeab9

Please sign in to comment.