diff --git a/docs/index.md b/docs/index.md
index 8cf60f24a..84ba54d4b 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -301,6 +301,8 @@ WITH (
 
 Currently Flint index job ID is same as internal Flint index name in [OpenSearch](./index.md#OpenSearch) section below.
 
+- **Recover Job**: Initiates a restart of the index refresh job and transitions the Flint index to the 'refreshing' state. It also cleans up the metadata log entry if the Flint data index no longer exists in OpenSearch.
+
 ```sql
 RECOVER INDEX JOB
 ```
diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala
index cf2cd2b6e..06c92882b 100644
--- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala
+++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala
@@ -9,6 +9,7 @@ import java.util.concurrent.ScheduledExecutorService
 
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog._
+import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.util.{ShutdownHookManager, ThreadUtils}
 
 /**
@@ -120,4 +121,17 @@ package object flint {
   def loadTable(catalog: CatalogPlugin, ident: Identifier): Option[Table] = {
     CatalogV2Util.loadTable(catalog, ident)
   }
+
+  /**
+   * Find the field with the given (possibly dot-separated) name under the root struct recursively.
+   *
+   * @param rootField
+   *   root struct field
+   * @param fieldName
+   *   field name to search for
+   * @return
+   */
+  def findField(rootField: StructType, fieldName: String): Option[StructField] = {
+    rootField.findNestedField(fieldName.split('.')).map(_._2)
+  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index c197a0bd4..dc85affb1 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -303,6 +303,20 @@ class FlintSpark(val spark: SparkSession) extends Logging {
         }
       } else {
         logInfo("Index to be recovered either doesn't exist or not auto refreshed")
+        if (index.isEmpty) {
+          /*
+           * If execution reaches this point, it indicates that the Flint index is corrupted.
+           * In such cases, clean up the metadata log, as the index data no longer exists.
+           * There is a very small possibility that users may recreate the index in the
+           * interim, in which case the metadata log would still be deleted by this cleanup.
+           */
+          logWarning("Cleaning up metadata log as index data has been deleted")
+          flintClient
+            .startTransaction(indexName, dataSourceName)
+            .initialLog(_ => true)
+            .finalLog(_ => NO_LOG_ENTRY)
+            .commit(_ => {})
+        }
         false
       }
     }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala
index d429244d1..0b2a84519 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala
@@ -9,8 +9,8 @@ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 
 import org.apache.spark.sql.catalog.Column
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}
-import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.flint.{findField, loadTable, parseTableName, qualifyTableName}
+import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
  * Flint Spark index builder base class.
@@ -27,15 +27,14 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
   protected var indexOptions: FlintSparkIndexOptions = empty
 
   /** All columns of the given source table */
-  lazy protected val allColumns: Map[String, Column] = {
+  lazy protected val allColumns: StructType = {
     require(qualifiedTableName.nonEmpty, "Source table name is not provided")
 
     val (catalog, ident) = parseTableName(flint.spark, qualifiedTableName)
     val table = loadTable(catalog, ident).getOrElse(
       throw new IllegalStateException(s"Table $qualifiedTableName is not found"))
 
-    val allFields = table.schema().fields
-    allFields.map { field => field.name -> convertFieldToColumn(field) }.toMap
+    table.schema()
   }
 
   /**
@@ -83,14 +82,14 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
    * Find column with the given name.
    */
   protected def findColumn(colName: String): Column =
-    allColumns.getOrElse(
-      colName,
-      throw new IllegalArgumentException(s"Column $colName does not exist"))
+    findField(allColumns, colName)
+      .map(field => convertFieldToColumn(colName, field))
+      .getOrElse(throw new IllegalArgumentException(s"Column $colName does not exist"))
 
-  private def convertFieldToColumn(field: StructField): Column = {
+  private def convertFieldToColumn(colName: String, field: StructField): Column = {
     // Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata
     new Column(
-      name = field.name,
+      name = colName,
       description = field.getComment().orNull,
       dataType =
         CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString,
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index 6c87924e7..de2ea772d 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -9,7 +9,9 @@ import org.json4s.CustomSerializer
 import org.json4s.JsonAST.JString
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField}
+import org.apache.spark.sql.functions.col
 
 /**
  * Skipping index strategy that defines skipping data structure building and reading logic.
@@ -82,4 +84,39 @@ object FlintSparkSkippingStrategy {
           { case kind: SkippingKind =>
             JString(kind.toString)
           }))
+
+  /**
+   * Extractor that matches the given expression against the indexed column expression in a skipping index.
+   *
+   * @param indexColName
+   *   indexed column name
+   */
+  case class IndexColumnExtractor(indexColName: String) {
+
+    def unapply(expr: Expression): Option[Column] = {
+      val colName = extractColumnName(expr).mkString(".")
+      if (colName == indexColName) {
+        Some(col(indexColName))
+      } else {
+        None
+      }
+    }
+
+    /*
+     * In Spark, after analysis, nested field "a.b.c" becomes:
+     *   GetStructField(name="c",
+     *     child=GetStructField(name="b",
+     *       child=AttributeReference(name="a")))
+     * TODO: To support any index expression, analyze index expression string
+     */
+    private def extractColumnName(expr: Expression): Seq[String] = {
+      expr match {
+        case attr: Attribute =>
+          Seq(attr.name)
+        case GetStructField(child, _, Some(name)) =>
+          extractColumnName(child) :+ name
+        case _ => Seq.empty
+      }
+    }
+  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
index f7745e7a8..edcc24c26 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategy.scala
@@ -6,9 +6,11 @@ package org.opensearch.flint.spark.skipping.minmax
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}
 
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.functions.col
 
@@ -35,19 +37,20 @@ case class MinMaxSkippingStrategy(
       Max(col(columnName).expr).toAggregateExpression())
   }
 
-  override def rewritePredicate(predicate: Expression): Option[Expression] =
+  override def rewritePredicate(predicate: Expression): Option[Expression] = {
+    val IndexColumn = MinMaxIndexColumnExtractor(IndexColumnExtractor(columnName))
     predicate match {
-      case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
-        Some((col(minColName) <= value && col(maxColName) >= value).expr)
-      case LessThan(AttributeReference(`columnName`, _, _, _), value: Literal) =>
-        Some((col(minColName) < value).expr)
-      case LessThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) =>
-        Some((col(minColName) <= value).expr)
-      case GreaterThan(AttributeReference(`columnName`, _, _, _), value: Literal) =>
-        Some((col(maxColName) > value).expr)
-      case GreaterThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) =>
-        Some((col(maxColName) >= value).expr)
-      case In(column @ AttributeReference(`columnName`, _, _, _), AllLiterals(literals)) =>
+      case EqualTo(IndexColumn(minIndexCol, maxIndexCol), value: Literal) =>
+        Some((minIndexCol <= value && maxIndexCol >= value).expr)
+      case LessThan(IndexColumn(minIndexCol, _), value: Literal) =>
+        Some((minIndexCol < value).expr)
+      case LessThanOrEqual(IndexColumn(minIndexCol, _), value: Literal) =>
+        Some((minIndexCol <= value).expr)
+      case GreaterThan(IndexColumn(_, maxIndexCol), value: Literal) =>
+        Some((maxIndexCol > value).expr)
+      case GreaterThanOrEqual(IndexColumn(_, maxIndexCol), value: Literal) =>
+        Some((maxIndexCol >= value).expr)
+      case In(column @ IndexColumn(_), AllLiterals(literals)) =>
         /*
          * First, convert IN to approximate range check: min(in_list) <= col <= max(in_list)
          * to avoid long and maybe unnecessary comparison expressions.
@@ -62,9 +65,21 @@ case class MinMaxSkippingStrategy(
             rewritePredicate(LessThanOrEqual(column, Literal(maxVal))).get))
       case _ => None
     }
+  }
+
+  /** Extractor that returns the MinMax index columns if the given expression matches */
+  private case class MinMaxIndexColumnExtractor(IndexColumn: IndexColumnExtractor) {
+
+    def unapply(expr: Expression): Option[(Column, Column)] = {
+      expr match {
+        case IndexColumn(_) => Some((col(minColName), col(maxColName)))
+        case _ => None
+      }
+    }
+  }
 
   /** Need this because Scala pattern match doesn't work for generic type like Seq[Literal] */
-  object AllLiterals {
+  private object AllLiterals {
     def unapply(values: Seq[Expression]): Option[Seq[Literal]] = {
       if (values.forall(_.isInstanceOf[Literal])) {
         Some(values.asInstanceOf[Seq[Literal]])
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
index 18fec0642..21d6dc836 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
@@ -6,9 +6,10 @@ package org.opensearch.flint.spark.skipping.partition
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.functions.col
 
@@ -29,11 +30,13 @@ case class PartitionSkippingStrategy(
     Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression())
   }
 
-  override def rewritePredicate(predicate: Expression): Option[Expression] =
+  override def rewritePredicate(predicate: Expression): Option[Expression] = {
+    val IndexColumn = IndexColumnExtractor(columnName)
     predicate match {
       // Column has same name in index data, so just rewrite to the same equation
-      case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
-        Some((col(columnName) === value).expr)
+      case EqualTo(IndexColumn(indexCol), value: Literal) =>
+        Some((indexCol === value).expr)
       case _ => None
     }
+  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
index 1db9e3d32..18f573949 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
@@ -6,10 +6,11 @@ package org.opensearch.flint.spark.skipping.valueset
 
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
 import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.{DEFAULT_VALUE_SET_MAX_SIZE, VALUE_SET_MAX_SIZE_KEY}
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
 import org.apache.spark.sql.functions._
 
 /**
@@ -44,17 +45,19 @@ case class ValueSetSkippingStrategy(
     Seq(aggregator.expr)
   }
 
-  override def rewritePredicate(predicate: Expression): Option[Expression] =
+  override def rewritePredicate(predicate: Expression): Option[Expression] = {
     /*
      * This is supposed to be rewritten to ARRAY_CONTAINS(columName, value).
      * However, due to push down limitation in Spark, we keep the equation.
      */
+    val IndexColumn = IndexColumnExtractor(columnName)
     predicate match {
-      case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
+      case EqualTo(IndexColumn(indexCol), value: Literal) =>
         // Value set maybe null due to maximum size limit restriction
-        Some((isnull(col(columnName)) || col(columnName) === value).expr)
+        Some((isnull(indexCol) || indexCol === value).expr)
       case _ => None
     }
+  }
 }
 
 object ValueSetSkippingStrategy {
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala
index 0cd4a5293..a4ca4430a 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexBuilderSuite.scala
@@ -5,7 +5,6 @@
 
 package org.opensearch.flint.spark
 
-import org.scalatest.matchers.must.Matchers.contain
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
 import org.apache.spark.FlintSuite
@@ -17,7 +16,11 @@ class FlintSparkIndexBuilderSuite extends FlintSuite {
     sql("""
        | CREATE TABLE spark_catalog.default.test
-       | ( name STRING, age INT )
+       | (
+       |   name STRING,
+       |   age INT,
+       |   address STRUCT<first: STRING, second: STRUCT<city: STRING, street: STRING>>
+       | )
        | USING JSON
      """.stripMargin)
   }
 
@@ -28,21 +31,31 @@
     super.afterAll()
   }
 
+  test("find column type") {
+    builder()
+      .onTable("test")
+      .expectTableName("spark_catalog.default.test")
+      .expectColumn("name", "string")
+      .expectColumn("age", "int")
+      .expectColumn("address", "struct<first:string,second:struct<city:string,street:string>>")
+      .expectColumn("address.first", "string")
+      .expectColumn("address.second", "struct<city:string,street:string>")
+      .expectColumn("address.second.city", "string")
+      .expectColumn("address.second.street", "string")
+  }
+
   test("should qualify table name in default database") {
     builder()
       .onTable("test")
       .expectTableName("spark_catalog.default.test")
-      .expectAllColumns("name", "age")
 
     builder()
       .onTable("default.test")
       .expectTableName("spark_catalog.default.test")
-      .expectAllColumns("name", "age")
 
     builder()
       .onTable("spark_catalog.default.test")
       .expectTableName("spark_catalog.default.test")
-      .expectAllColumns("name", "age")
   }
 
   test("should qualify table name and get columns in other database") {
@@ -54,23 +67,19 @@ class FlintSparkIndexBuilderSuite extends FlintSuite {
       builder()
         .onTable("test2")
         .expectTableName("spark_catalog.mydb.test2")
-        .expectAllColumns("address")
 
       builder()
         .onTable("mydb.test2")
         .expectTableName("spark_catalog.mydb.test2")
-        .expectAllColumns("address")
 
       builder()
         .onTable("spark_catalog.mydb.test2")
         .expectTableName("spark_catalog.mydb.test2")
-        .expectAllColumns("address")
 
       // Can parse any specified table name
       builder()
         .onTable("spark_catalog.default.test")
         .expectTableName("spark_catalog.default.test")
-        .expectAllColumns("name", "age")
     } finally {
       sql("DROP DATABASE mydb CASCADE")
       sql("USE default")
@@ -96,8 +105,10 @@ class FlintSparkIndexBuilderSuite extends FlintSuite {
       this
     }
 
-    def expectAllColumns(expected: String*): FakeFlintSparkIndexBuilder = {
-      allColumns.keys should contain theSameElementsAs expected
+    def expectColumn(expectName: String, expectType: String): FakeFlintSparkIndexBuilder = {
+      val column = findColumn(expectName)
+      column.name shouldBe expectName
+      column.dataType shouldBe expectType
       this
     }
 
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
index 3612d3101..e68efdb7e 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
@@ -671,6 +671,61 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
     deleteTestIndex(testIndex)
   }
 
+  test("build skipping index for nested field and rewrite applicable query") {
+    val testTable = "spark_catalog.default.nested_field_table"
+    val testIndex = getSkippingIndexName(testTable)
+    withTable(testTable) {
+      sql(s"""
+           | CREATE TABLE $testTable
+           | (
+           |   int_col INT,
+           |   struct_col STRUCT<field1: STRUCT<subfield: STRING>, field2: INT>
+           | )
+           | USING JSON
+           |""".stripMargin)
+      sql(s"""
+           | INSERT INTO $testTable
+           | SELECT /*+ COALESCE(1) */ *
+           | FROM VALUES
+           |   ( 30, STRUCT(STRUCT("value1"),123) ),
+           |   ( 40, STRUCT(STRUCT("value2"),456) )
+           |""".stripMargin)
+      sql(s"""
+           | INSERT INTO $testTable
+           | VALUES ( 50, STRUCT(STRUCT("value3"),789) )
+           |""".stripMargin)
+
+      flint
+        .skippingIndex()
+        .onTable(testTable)
+        .addMinMax("struct_col.field2")
+        .addValueSet("struct_col.field1.subfield")
+        .create()
+      flint.refreshIndex(testIndex)
+
+      // FIXME: add assertion on index data once https://github.com/opensearch-project/opensearch-spark/issues/233 is fixed
+      // Query rewrite on nested field
+      val query1 =
+        sql(s"SELECT int_col FROM $testTable WHERE struct_col.field2 = 456".stripMargin)
+      checkAnswer(query1, Row(40))
+      query1.queryExecution.executedPlan should
+        useFlintSparkSkippingFileIndex(
+          hasIndexFilter(
+            col("MinMax_struct_col.field2_0") <= 456 && col("MinMax_struct_col.field2_1") >= 456))
+
+      // Query rewrite on deeply nested field
+      val query2 = sql(
+        s"SELECT int_col FROM $testTable WHERE struct_col.field1.subfield = 'value3'".stripMargin)
+      checkAnswer(query2, Row(50))
+      query2.queryExecution.executedPlan should
+        useFlintSparkSkippingFileIndex(
+          hasIndexFilter(isnull(col("struct_col.field1.subfield")) ||
+            col("struct_col.field1.subfield") === "value3"))
+
+      deleteTestIndex(testIndex)
+    }
+  }
+
   // Custom matcher to check if a SparkPlan uses FlintSparkSkippingFileIndex
   def useFlintSparkSkippingFileIndex(
       subMatcher: Matcher[FlintSparkSkippingFileIndex]): Matcher[SparkPlan] = {
@@ -703,8 +758,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
 
       MatchResult(
         hasExpectedFilter,
-        "FlintSparkSkippingFileIndex does not have expected filter",
-        "FlintSparkSkippingFileIndex has expected filter")
+        s"FlintSparkSkippingFileIndex does not have expected filter: ${fileIndex.indexFilter}",
+        s"FlintSparkSkippingFileIndex has expected filter: ${fileIndex.indexFilter}")
     }
   }
 
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala
index fc4e4638d..a2b93648e 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala
@@ -142,4 +142,23 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match
         .create()
     } should have message s"Flint index $testFlintIndex already exists"
   }
+
+  test("should clean up metadata log entry if index data has been deleted") {
+    flint
+      .skippingIndex()
+      .onTable(testTable)
+      .addPartitions("year", "month")
+      .options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
+      .create()
+    flint.refreshIndex(testFlintIndex)
+
+    // Simulate the situation where the user deletes the index data directly and the refresh job exits
+    spark.streams.active.find(_.name == testFlintIndex).get.stop()
+    deleteIndex(testFlintIndex)
+
+    // Index state is refreshing; the recover API is expected to clean it up
+    latestLogEntry(testLatestId) should contain("state" -> "refreshing")
+    flint.recoverIndex(testFlintIndex)
+    latestLogEntry(testLatestId) shouldBe empty
+  }
 }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index c60d250ea..a702d2c64 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -84,8 +84,8 @@ case class JobOperator(
     }
 
     try {
-      // Stop SparkSession if streaming job succeeds
-      if (!exceptionThrown && streaming) {
+      // Wait for the streaming job to complete if there was no error and a streaming job is still running
+      if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
         // wait if any child thread to finish before the main thread terminates
         spark.streams.awaitAnyTermination()
       }
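Below is a minimal, self-contained Scala sketch (not part of the patch above) of the nested-field lookup that the new `findField` helper performs by delegating to `StructType.findNestedField` from inside the `org.apache.spark.sql.flint` package object. The `resolveNestedField` function and the sample schema are illustrative assumptions only; they mirror the dot-separated path walk rather than the exact Spark internals.

```scala
import org.apache.spark.sql.types._

object NestedFieldLookupSketch extends App {

  // Schema shaped like the test table in FlintSparkIndexBuilderSuite:
  // name STRING, age INT, address STRUCT<first: STRING, second: STRUCT<city: STRING, street: STRING>>
  val schema = StructType(
    Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField(
        "address",
        StructType(Seq(
          StructField("first", StringType),
          StructField(
            "second",
            StructType(
              Seq(StructField("city", StringType), StructField("street", StringType)))))))))

  // Hand-rolled equivalent of the lookup findField delegates to: walk the
  // dot-separated path one struct level at a time and return the leaf
  // StructField only if every step resolves to an existing field.
  def resolveNestedField(struct: StructType, path: Seq[String]): Option[StructField] =
    path match {
      case Seq(leaf) => struct.fields.find(_.name == leaf)
      case head +: rest =>
        struct.fields.find(_.name == head).flatMap {
          case StructField(_, nested: StructType, _, _) => resolveNestedField(nested, rest)
          case _ => None // intermediate path element is not a struct
        }
      case _ => None // empty path
    }

  println(resolveNestedField(schema, "address.second.city".split('.').toSeq)) // Some(StructField(city,...))
  println(resolveNestedField(schema, "address.missing".split('.').toSeq)) // None
}
```

This is also why `IndexColumnExtractor` reassembles the dot-separated name from the analyzed `GetStructField`/`Attribute` chain: the skipping index stores nested columns under their full path (for example `struct_col.field1.subfield`), so predicate rewriting fires only when the reconstructed path matches the indexed column name exactly.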