diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala index 4788aa23f..ca96c126f 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala @@ -277,6 +277,26 @@ class FlintSparkPPLFillnullITSuite assert(ex.getMessage().contains("Syntax error ")) } + test("test fillnull with null_replacement type mismatch") { + val frame = sql(s""" + | source = $testTable | fillnull with cast(0 as long) in status_code + | """.stripMargin) + + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp", "status_code"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "/home", null, 200), + Row(2, "/about", "2023-10-01 10:05:00", 0), + Row(3, "/contact", "2023-10-01 10:10:00", 0), + Row(4, null, "2023-10-01 10:15:00", 301), + Row(5, null, "2023-10-01 10:20:00", 200), + Row(6, "/home", null, 403)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + private def fillNullExpectedPlan( nullReplacements: Seq[(String, Expression)], addDefaultProject: Boolean = true): LogicalPlan = { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index b78471591..d7f59bae3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Generate; import org.apache.spark.sql.catalyst.plans.logical.Limit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$; import org.apache.spark.sql.catalyst.plans.logical.Project$; import org.apache.spark.sql.execution.ExplainMode; import org.apache.spark.sql.execution.command.DescribeTableCommand; @@ -452,10 +453,30 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context) Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); // build the plan with the projection step context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); - LogicalPlan resultWithoutDuplicatedColumns = context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(toDrop), logicalPlan)); + LogicalPlan resultWithoutDuplicatedColumns = context.apply(dropOriginalColumns(p -> p.children().head(), toDrop)); return Objects.requireNonNull(resultWithoutDuplicatedColumns, "FillNull operation failed"); } + /** + * This method is used to generate DataFrameDropColumns operator for dropping duplicated columns + * in the original plan. Then achieving similar effect like updating columns. + * + * PLAN_ID_TAG is a mechanism inner Spark that explicitly specify a plan to resolve the + * UnresolvedAttributes. Set toDrop expressions' PLAN_ID_TAG to the same value as that of the + * original plan, so Spark will resolve them correctly by that plan instead of the child. + */ + private java.util.function.Function dropOriginalColumns( + java.util.function.Function findOriginalPlan, + List toDrop) { + return logicalPlan -> { + LogicalPlan originalPlan = findOriginalPlan.apply(logicalPlan); + long planId = logicalPlan.hashCode(); + originalPlan.setTagValue(LogicalPlan$.MODULE$.PLAN_ID_TAG(), planId); + toDrop.forEach(e -> e.setTagValue(LogicalPlan$.MODULE$.PLAN_ID_TAG(), planId)); + return DataFrameDropColumns$.MODULE$.apply(seq(toDrop), logicalPlan); + }; + } + @Override public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) { visitFirstChild(flatten, context);