diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index 47ade0f87..175436fbf 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -12,7 +12,7 @@ import org.json4s.native.Serialization
 import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
 import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._
 import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
-import org.opensearch.flint.spark.FlintSparkIndex.{ID_COLUMN, StreamingRefresh}
+import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, ID_COLUMN, StreamingRefresh}
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
@@ -339,7 +339,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
         logInfo("Start refreshing index in foreach streaming style")
         val job = spark.readStream
           .options(options.extraSourceOptions(tableName))
-          .table(tableName)
+          .table(quotedTableName(tableName))
           .writeStream
           .queryName(indexName)
           .addSinkOptions(options)
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index af1e9fa74..248d105a2 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -95,6 +95,21 @@
     s"flint_${parts(0)}_${parts(1)}_${parts.drop(2).mkString(".")}"
   }
 
+  /**
+   * Add backticks to the table name to escape special characters.
+   *
+   * @param fullTableName
+   *   source full table name
+   * @return
+   *   quoted table name
+   */
+  def quotedTableName(fullTableName: String): String = {
+    require(fullTableName.split('.').length >= 3, s"Table name $fullTableName is not qualified")
+
+    val parts = fullTableName.split('.')
+    s"${parts(0)}.${parts(1)}.`${parts.drop(2).mkString(".")}`"
+  }
+
   /**
    * Populate environment variables to persist in Flint metadata.
   *
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
index cdb3a3462..e23126c68 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
@@ -9,7 +9,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.spark._
-import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder}
+import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder, quotedTableName}
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}
 
@@ -60,7 +60,7 @@ case class FlintSparkCoveringIndex(
 
   override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
     val colNames = indexedColumns.keys.toSeq
-    val job = df.getOrElse(spark.read.table(tableName))
+    val job = df.getOrElse(spark.read.table(quotedTableName(tableName)))
 
     // Add optional filtering condition
     filterCondition
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index d83af5df5..d6ed6b07e 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -77,7 +77,7 @@ case class FlintSparkSkippingIndex(
          new Column(aggFunc.toAggregateExpression().as(name))
       }
 
-    df.getOrElse(spark.read.table(tableName))
+    df.getOrElse(spark.read.table(quotedTableName(tableName)))
      .groupBy(input_file_name().as(FILE_PATH_COLUMN))
      .agg(namedAggFuncs.head, namedAggFuncs.tail: _*)
      .withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN)))
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
index f52e6ef85..1cce47d1a 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
@@ -5,6 +5,7 @@
 
 package org.opensearch.flint.spark.covering
 
+import org.scalatest.matchers.must.Matchers.contain
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
 import org.apache.spark.FlintSuite
@@ -30,6 +31,24 @@ class FlintSparkCoveringIndexSuite extends FlintSuite {
     }
   }
 
+  test("can build index building job with unique ID column") {
+    val index =
+      new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string"))
+
+    val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
+    val indexDf = index.build(spark, Some(df))
+    indexDf.schema.fieldNames should contain only ("name")
+  }
+
+  test("can build index on table name with special characters") {
+    val testTableSpecial = "spark_catalog.default.test/2023/10"
+    val index = new FlintSparkCoveringIndex("ci", testTableSpecial, Map("name" -> "string"))
+
+    val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
+    val indexDf = index.build(spark, Some(df))
+    indexDf.schema.fieldNames should contain only ("name")
+  }
+
   test("should fail if no indexed column given") {
     assertThrows[IllegalArgumentException] {
       new FlintSparkCoveringIndex("ci", "default.test", Map.empty)
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
index 491b7811a..4f0d084c9 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
@@ -66,6 +66,18 @@ class FlintSparkSkippingIndexSuite extends FlintSuite {
     indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
   }
 
+  test("can build index on table name with special characters") {
+    val testTableSpecial = "spark_catalog.default.test/2023/10"
+    val indexCol = mock[FlintSparkSkippingStrategy]
+    when(indexCol.outputSchema()).thenReturn(Map("name" -> "string"))
+    when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("name").expr)))
+    val index = new FlintSparkSkippingIndex(testTableSpecial, Seq(indexCol))
+
+    val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
+    val indexDf = index.build(spark, Some(df))
+    indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
+  }
+
   test("can build index for boolean column") {
     val indexCol = mock[FlintSparkSkippingStrategy]
     when(indexCol.kind).thenReturn(SkippingKind.PARTITION)
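
Note: the following is an illustrative sketch (not part of the patch) of what the new quotedTableName helper produces, based only on the implementation added in FlintSparkIndex.scala above. Only the table part of the qualified name (everything after catalog and database) is wrapped in backticks, so dots or slashes in the table name are treated as part of a single identifier when the source table is read:

  import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName

  quotedTableName("spark_catalog.default.test/2023/10")  // "spark_catalog.default.`test/2023/10`"
  quotedTableName("spark_catalog.default.test")           // "spark_catalog.default.`test`"
  quotedTableName("default.test")                          // fails the require check (not fully qualified)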