diff --git a/src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala b/src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala index dce7a054..9505b195 100644 --- a/src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala +++ b/src/main/scala/com/yotpo/metorikku/code/steps/SelectiveMerge.scala @@ -3,9 +3,11 @@ package com.yotpo.metorikku.code.steps import com.yotpo.metorikku.exceptions.MetorikkuException import org.apache.log4j.{LogManager, Logger} import org.apache.spark.sql.catalyst.expressions.NamedExpression -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.functions._ +import scala.collection.mutable.ListBuffer + object SelectiveMerge { private val message = "You need to send 3 parameters with the names of the dataframes to merge and the key(s) to merge on" + "(merged df1 into df2 favoring values from df2): df1, df2, Seq[String]" @@ -58,8 +60,28 @@ object SelectiveMerge { } def merge(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]): DataFrame = { - val mergedDf = outerJoinWithAliases(df1, df2, joinKeys) - overrideConflictingValues(df1, df2, mergedDf, joinKeys) + val df2NoStaleEntries = removeStaleEntries(df1, df2, joinKeys) + val mergedDf = outerJoinWithAliases(df1, df2NoStaleEntries, joinKeys) + overrideConflictingValues(df1, df2NoStaleEntries, mergedDf, joinKeys) + } + + def removeStaleEntries(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]): DataFrame = { + var df2New = df2 + for (key <- joinKeys) { + val diff = df2.select(col(key)).except(df1.select(col(key))) + val toRemoveBuff = new ListBuffer[Row]() + + val localIter = diff.toLocalIterator() + while(localIter.hasNext) { + toRemoveBuff += localIter.next() + } + + val toRemove = toRemoveBuff.toList.map(r => r.get(0)) + + df2New = df2New.filter(!df2New(key).isin(toRemove:_*)) + } + + df2New } def outerJoinWithAliases(df1: DataFrame, df2: DataFrame, joinKeys: Seq[String]):
DataFrame = { @@ -91,6 +113,8 @@ object SelectiveMerge { } def overrideConflictingValues(df1: DataFrame, df2: DataFrame, mergedDf: DataFrame, joinKeys: Seq[String]): DataFrame = { + val df1SchemaNames = df1.schema.map(f => f.name) + val mergedSchema = getMergedSchema(df1, df2, joinKeys) mergedDf.select( @@ -100,9 +124,16 @@ object SelectiveMerge { val colNameArr = colName.split(colRenamePrefix) val colNameOrig = if (colNameArr.size > 1) colNameArr(1) else colName - // Belongs to DF2, override. + // Column appears in DF2, override unless the row only belongs to DF1 if (colNameArr.size > 1) { - mergedDf(colName).alias(colNameOrig) + if (df1SchemaNames.contains(colNameOrig)) { + when(mergedDf(colName).isNotNull, mergedDf(colName).cast(df1.schema(colNameOrig).dataType)) + .otherwise(df1(colNameOrig)) + .alias(colNameOrig) + } + else { + mergedDf(colName).alias(colNameOrig) + } } // Is the join key(s) else if (joinKeys.contains(colName)) { diff --git a/src/test/scala/com/yotpo/metorikku/code/steps/test/SelectiveMergeTests.scala b/src/test/scala/com/yotpo/metorikku/code/steps/test/SelectiveMergeTests.scala index 44ab88c1..e4006b93 100644 --- a/src/test/scala/com/yotpo/metorikku/code/steps/test/SelectiveMergeTests.scala +++ b/src/test/scala/com/yotpo/metorikku/code/steps/test/SelectiveMergeTests.scala @@ -53,47 +53,36 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach { showString.invoke(df, 10.asInstanceOf[Object], 20.asInstanceOf[Object], false.asInstanceOf[Object]).asInstanceOf[String] } - test("Selective merge") { + test("Equal number of columns") { val sparkSession = SparkSession.builder.appName("test").getOrCreate() - val sqlContext= sparkSession.sqlContext + val sqlContext= sparkSession.sqlContext import sqlContext.implicits._ val employeeData1 = Seq( - ("James", 1, 11, 111, 1111), - ("Maria", 2, 22, 222, 2222) ) - val df1 = employeeData1.toDF("employee_name", "salary", "age",
"fake", "fake2") + val df1 = employeeData1.toDF("employee_name", "age", "salary") val employeeData2 = Seq( - ("James", 1, 33, 333), - ("Jen", 4, 44, 444), - ("Jeff", 5, 55, 555) + ("James", 21, 1), + ("Jen", 40, 2) ) - val df2 = employeeData2.toDF("employee_name", "salary", "age", "bonus") + val df2 = employeeData2.toDF("employee_name", "age", "new_formula") val simpleDataExpectedAfterMerge = Seq( - ("James", Integer.valueOf(1) /* Salary */, Integer.valueOf(33) /* age */, Integer.valueOf(111) /* fake */, - Integer.valueOf(1111) /* fake2 */, Integer.valueOf(333) /* bonus */), - ("Maria", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */, Integer.valueOf(222) /* fake */, - Integer.valueOf(2222) /* fake2 */, null.asInstanceOf[Integer] /* bonus */), - ("Jen", Integer.valueOf(4) /* Salary */, Integer.valueOf(44) /* age */, null.asInstanceOf[Integer] /* fake */, - null.asInstanceOf[Integer] /* fake2 */, Integer.valueOf(444) /* bonus */), - ("Jeff", Integer.valueOf(5) /* Salary */, Integer.valueOf(55) /* age */, null.asInstanceOf[Integer] /* fake */, - null.asInstanceOf[Integer] /* fake2 */, Integer.valueOf(555) /* bonus */) + ("James", Integer.valueOf(21) /* age */, Integer.valueOf(10000) /* salary */, Integer.valueOf(1) /* new formula */), + ("Maria", Integer.valueOf(30) /* age */, Integer.valueOf(20000) /* salary */, null.asInstanceOf[Integer] /* new_formula */) ) - val expectedDf = simpleDataExpectedAfterMerge.toDF("employee_name", "salary", "age", "fake", "fake2", "bonus") + val expectedDf = simpleDataExpectedAfterMerge.toDF("employee_name", "age", "salary", "new_formula") val simpleDataNotExpectedAfterMerge = Seq( - ("James", Integer.valueOf(10) /* Salary */, Integer.valueOf(33) /* age */, Integer.valueOf(111) /* fake */, - Integer.valueOf(1111) /* fake2 */,
null.asInstanceOf[Integer] /* bonus */), - ("Jen", Integer.valueOf(40) /* Salary */, Integer.valueOf(44) /* age */, null.asInstanceOf[Integer] /* fake */, - null.asInstanceOf[Integer] /* fake2 */, Integer.valueOf(444) /* bonus */), - ("Jeff", Integer.valueOf(50) /* Salary */, Integer.valueOf(55) /* age */, null.asInstanceOf[Integer] /* fake */, - null.asInstanceOf[Integer] /* fake2 */, Integer.valueOf(555) /* bonus */) + ("James", Integer.valueOf(10) /* age */, Integer.valueOf(33) /* salary */, Integer.valueOf(111) /* new formula */), + ("Maria", Integer.valueOf(20) /* age */, Integer.valueOf(22) /* salary */, Integer.valueOf(222) /* new formula */), + ("Jen", Integer.valueOf(40) /* age */, Integer.valueOf(44) /* salary */, null.asInstanceOf[Integer] /* new formula */), + ("Jeff", Integer.valueOf(50) /* age */, Integer.valueOf(55) /* salary */, null.asInstanceOf[Integer] /* new formula */) ) - val notExpectedDf = simpleDataNotExpectedAfterMerge.toDF("employee_name", "salary", "age", "fake", "fake2", "bonus") + val notExpectedDf = simpleDataNotExpectedAfterMerge.toDF("employee_name", "age", "salary", "new_formula") val mergedDf = merge(df1, df2, Seq("employee_name")) @@ -101,9 +90,9 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach { assertSuccess(mergedDf, notExpectedDf, isEqual = false) } - test("String and numbers mixed fields") { + test("Df2 has more columns") { val sparkSession = SparkSession.builder.appName("test").getOrCreate() - val sqlContext= sparkSession.sqlContext + val sqlContext= sparkSession.sqlContext import sqlContext.implicits._ val employeeData1 = Seq( @@ -120,14 +109,10 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach { val df2 = employeeData2.toDF("employee_name", "salary", "age", "bonus") val simpleDataExpectedAfterMerge = Seq( - ("James", "Sharon" /* Last Name */, Integer.valueOf(1) /* Salary */, Integer.valueOf(33) /* age */, - Integer.valueOf(111) /* fake */, Integer.valueOf(1111) /* fake2 */, Integer.valueOf(333) /* bonus
*/), - ("Maria", "Bob" /* Last Name */, null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */, - Integer.valueOf(222) /* fake */, Integer.valueOf(2222) /* fake2 */, null.asInstanceOf[Integer] /* bonus */), - ("Jen", null.asInstanceOf[String] /* Last Name */, Integer.valueOf(4) /* Salary */, Integer.valueOf(44) /* age */, - null.asInstanceOf[Integer] /* fake */, null.asInstanceOf[Integer] /* fake2 */, Integer.valueOf(444) /* bonus */), - ("Jeff", null.asInstanceOf[String] /* Last Name */, Integer.valueOf(5) /* Salary */, Integer.valueOf(55) /* age */, - null.asInstanceOf[Integer] /* fake */, null.asInstanceOf[Integer] /* fake2 */, Integer.valueOf(555) /* bonus */) + ("James", "Sharon" /* Last Name */, Integer.valueOf(1) /* Salary */, Integer.valueOf(33) /* age */, + Integer.valueOf(111) /* fake */, Integer.valueOf(1111) /* fake2 */, Integer.valueOf(333) /* bonus */), + ("Maria", "Bob" /* Last Name */, null.asInstanceOf[Integer] /* Salary */, Integer.valueOf(22) /* age */, + Integer.valueOf(222) /* fake */, Integer.valueOf(2222) /* fake2 */, null.asInstanceOf[Integer] /* bonus */) ) val expectedDf = simpleDataExpectedAfterMerge.toDF("employee_name", "last_name", "salary", "age", "fake", "fake2", "bonus") @@ -136,9 +121,9 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach { assertSuccess(mergedDf, expectedDf, isEqual = true) } - test("df2 has more columns") { + test("df1 has more columns") { val sparkSession = SparkSession.builder.appName("test").getOrCreate() - val sqlContext= sparkSession.sqlContext + val sqlContext= sparkSession.sqlContext import sqlContext.implicits._ val employeeData1 = Seq( @@ -155,13 +140,11 @@ class SelectiveMergeTests extends FunSuite with BeforeAndAfterEach { val df2 = employeeData2.toDF("employee_name", "salary", "age", "bonus", "fake") val simpleDataExpectedAfterMerge = Seq( - ("James", Integer.valueOf(10) /* Salary */, Integer.valueOf(33) /* age */, - Integer.valueOf(333) /* Bonus */, Integer.valueOf(3333) /*
fake */), - ("Maria", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */, + ("James", Integer.valueOf(10) /* Salary */, Integer.valueOf(33) /* age */, + Integer.valueOf(333) /* Bonus */, Integer.valueOf(3333) /* fake */), + ("Maria", Integer.valueOf(2) /* Salary */, Integer.valueOf(22) /* age */, null.asInstanceOf[Integer] /* Bonus */, null.asInstanceOf[Integer] /* fake */), - ("Jen", Integer.valueOf(4) /* Salary */, Integer.valueOf(44) /* age */, - Integer.valueOf(444) /* Bonus */, Integer.valueOf(4444) /* fake */), - ("Albert", null.asInstanceOf[Integer] /* Salary */, null.asInstanceOf[Integer] /* age */, + ("Albert", Integer.valueOf(3) /* Salary */, Integer.valueOf(33) /* age */, null.asInstanceOf[Integer] /* Bonus */, null.asInstanceOf[Integer] /* fake */) ) val expectedDf = simpleDataExpectedAfterMerge.toDF("employee_name", "salary", "age", "bonus", "fake")