diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index 56a928aa5..609fd7b4c 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -14,6 +14,7 @@ import org.opensearch.flint.core.metadata.FlintJsonHelper._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.util.quoteIfNeeded
 import org.apache.spark.sql.flint.datatype.FlintDataType
 import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json}
 import org.apache.spark.sql.types.{MapType, StructType}
@@ -150,9 +151,9 @@ object FlintSparkIndex extends Logging {
     val allOutputCols = df.schema.fields.map { field =>
       field.dataType match {
         case _: StructType | _: MapType =>
-          to_json(col(field.name))
+          to_json(col(quoteIfNeeded(field.name)))
         case _ =>
-          col(field.name)
+          col(quoteIfNeeded(field.name))
       }
     }
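The `quoteIfNeeded` change above is the actual fix: `col("test.name")` parses the dot as nested-field access (field `name` inside a struct column `test`), so a top-level column literally named `test.name` never resolves, and ID generation fails for aggregations with dotted aliases. A minimal standalone sketch of the behavior (the object name and sample data are illustrative, not part of this change):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.functions.col

// Illustrative sketch only: shows why dotted column names must be
// backtick-quoted before being handed to col().
object QuoteIfNeededDemo extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  import spark.implicits._

  val df = Seq(("Alice", 1L), ("Bob", 2L)).toDF("test.name", "count")

  assert(quoteIfNeeded("count") == "count")            // simple name, unchanged
  assert(quoteIfNeeded("test.name") == "`test.name`")  // dotted alias, backquoted

  // df.select(col("test.name"))  // would fail: no struct column named "test"
  df.select(col(quoteIfNeeded("test.name"))).show()    // resolves the literal name

  spark.stop()
}
```

Since `quoteIfNeeded` only backtick-quotes non-trivial identifiers, plain names such as `count` pass through unchanged, so existing indexes are unaffected.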
diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
index 1d301087f..a6d771534 100644
--- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
@@ -6,9 +6,12 @@
 package org.apache.spark
 
 import org.opensearch.flint.spark.FlintSparkExtensions
+import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
 
-import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.expressions.{Alias, CodegenObjectFactoryMode, Expression}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf}
 import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE}
 import org.apache.spark.sql.internal.SQLConf
@@ -68,4 +71,15 @@ trait FlintSuite extends SharedSparkSession {
       setFlintSparkConf(METADATA_CACHE_WRITE, "false")
     }
   }
+
+  protected implicit class DataFrameExtensions(val df: DataFrame) {
+
+    def idColumn(): Option[Expression] = {
+      df.queryExecution.logical.collectFirst { case Project(projectList, _) =>
+        projectList.collectFirst { case Alias(child, ID_COLUMN) =>
+          child
+        }
+      }.flatten
+    }
+  }
 }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
index 8415613e8..6d0d972f6 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
@@ -10,7 +10,11 @@ import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.FlintSuite
 import org.apache.spark.sql.{QueryTest, Row}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Add, ConcatWs, Literal, Sha1, StructsToJson}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 
 class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
 
@@ -19,6 +23,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
     val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10"))
 
     val resultDf = addIdColumn(df, options)
+    resultDf.idColumn() shouldBe Some(Add(UnresolvedAttribute("id"), Literal(10)))
     checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12)))
   }
 
@@ -47,7 +52,37 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
     val options = FlintSparkIndexOptions.empty
 
     val resultDf = addIdColumn(df, options)
-    resultDf.columns should contain(ID_COLUMN)
+    resultDf.idColumn() shouldBe Some(
+      Sha1(
+        ConcatWs(
+          Seq(
+            Literal(UTF8String.fromString("\0"), StringType),
+            UnresolvedAttribute(Seq("name")),
+            UnresolvedAttribute(Seq("count"))))))
+    resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
+  }
+
+  test("should add ID column for aggregated query with quoted alias") {
+    val df = spark
+      .createDataFrame(
+        sparkContext.parallelize(
+          Seq(
+            Row(1, "Alice", Row("WA", "Seattle")),
+            Row(2, "Bob", Row("OR", "Portland")),
+            Row(3, "Alice", Row("WA", "Seattle")))),
+        StructType.fromDDL("id INT, name STRING, address STRUCT<state: STRING, city: STRING>"))
+      .toDF("id", "name", "address")
+      .groupBy(col("name").as("test.name"), col("address").as("test.address"))
+      .count()
+    val options = FlintSparkIndexOptions.empty
+
+    val resultDf = addIdColumn(df, options)
+    resultDf.idColumn() shouldBe Some(
+      Sha1(ConcatWs(Seq(
+        Literal(UTF8String.fromString("\0"), StringType),
+        UnresolvedAttribute(Seq("test.name")),
+        new StructsToJson(UnresolvedAttribute(Seq("test.address"))),
+        UnresolvedAttribute(Seq("count"))))))
     resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
   }
 
@@ -92,7 +127,20 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
     val options = FlintSparkIndexOptions.empty
 
     val resultDf = addIdColumn(aggregatedDf, options)
-    resultDf.columns should contain(ID_COLUMN)
+    resultDf.idColumn() shouldBe Some(
+      Sha1(ConcatWs(Seq(
+        Literal(UTF8String.fromString("\0"), StringType),
+        UnresolvedAttribute(Seq("boolean_col")),
+        UnresolvedAttribute(Seq("string_col")),
+        UnresolvedAttribute(Seq("long_col")),
+        UnresolvedAttribute(Seq("int_col")),
+        UnresolvedAttribute(Seq("double_col")),
+        UnresolvedAttribute(Seq("float_col")),
+        UnresolvedAttribute(Seq("timestamp_col")),
+        UnresolvedAttribute(Seq("date_col")),
+        new StructsToJson(UnresolvedAttribute(Seq("struct_col"))),
+        UnresolvedAttribute(Seq("subfield2")),
+        UnresolvedAttribute(Seq("count"))))))
     resultDf.select(ID_COLUMN).distinct().count() shouldBe 1
   }
 }
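The tightened assertions above pin down the exact expression tree `addIdColumn` builds for aggregated queries: a SHA-1 over every output column, joined with a `"\0"` separator, with struct columns serialized through `to_json`. A condensed sketch of that generation step (mirrors the `FlintSparkIndex` hunk above; simplified, not the production entry point):

```scala
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.functions.{col, concat_ws, sha1, to_json}
import org.apache.spark.sql.types.{MapType, StructType}

// Condensed, illustrative version of the ID generation logic under test.
object IdExpressionSketch {

  def generateId(df: DataFrame): Column = {
    // Struct and map columns are serialized to JSON first, since concat_ws
    // needs string-convertible inputs; every name goes through quoteIfNeeded
    // so dotted aliases like "test.name" survive col() parsing.
    val allOutputCols = df.schema.fields.map { field =>
      field.dataType match {
        case _: StructType | _: MapType => to_json(col(quoteIfNeeded(field.name)))
        case _                          => col(quoteIfNeeded(field.name))
      }
    }
    // "\0" separates the column values; sha1 makes the ID content-derived
    sha1(concat_ws("\0", allOutputCols: _*))
  }
}
```

Because the hash is content-derived, identical aggregate rows always map to the same document ID while distinct rows get distinct IDs, which is what the `distinct().count()` assertions verify.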
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
index 200efbe97..7774eb2fb 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
@@ -7,10 +7,9 @@ package org.opensearch.flint.spark.mv
 
 import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter}
 
-import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
 import org.opensearch.flint.spark.FlintSparkIndexOptions
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
-import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, DataFrameIdColumnExtractor, StreamingDslLogicalPlan}
+import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan}
 import org.scalatest.matchers.should.Matchers._
 import org.scalatestplus.mockito.MockitoSugar.mock
 
@@ -19,8 +18,8 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper}
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Expression, Literal, Sha1}
-import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, ConcatWs, Literal, Sha1}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.types.StringType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -385,15 +384,4 @@ object FlintSparkMaterializedViewSuite {
       logicalPlan)
     }
   }
-
-  implicit class DataFrameIdColumnExtractor(val df: DataFrame) {
-
-    def idColumn(): Option[Expression] = {
-      df.queryExecution.logical.collectFirst { case Project(projectList, _) =>
-        projectList.collectFirst { case Alias(child, ID_COLUMN) =>
-          child
-        }
-      }.flatten
-    }
-  }
 }
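With the extractor promoted out of `FlintSparkMaterializedViewSuite` into the shared `FlintSuite` trait, any suite can assert on the planted ID expression without executing the query. A hypothetical usage fragment (suite name and data are illustrative, not part of this change):

```scala
import org.apache.spark.FlintSuite
import org.opensearch.flint.spark.FlintSparkIndex.addIdColumn
import org.opensearch.flint.spark.FlintSparkIndexOptions

// Hypothetical example: any test mixing in FlintSuite now gets df.idColumn().
class IdColumnUsageExample extends FlintSuite {

  test("idColumn() recovers the expression aliased as ID_COLUMN") {
    val df = spark.range(2).toDF("id")
    val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10"))

    // addIdColumn appends Alias(<id expression>, ID_COLUMN) to the outermost
    // Project; the implicit extractor digs the child expression back out
    val resultDf = addIdColumn(df, options)
    assert(resultDf.idColumn().isDefined)
  }
}
```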