generate source tables for old mv without new prop
Signed-off-by: Sean Kao <[email protected]>
seankao-az committed Oct 31, 2024
1 parent 372ea26 commit 559b393
Showing 6 changed files with 125 additions and 51 deletions.
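In short, materialized view (MV) indexes created before the sourceTables metadata property was introduced carry no stored list of source tables. This commit makes FlintSparkIndexFactory fall back to deriving that list by parsing the MV's defining query (see getMvSourceTables and the extractSourceTableNames helper now exposed on FlintSparkMaterializedView in the diff below), so older MV indexes keep resolving their source tables without being re-created. The following is a minimal illustrative sketch of that fallback, not the exact Flint code; the object and method names are hypothetical, and it omits the qualifyTableName step the real implementation applies to each relation:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation

object SourceTableFallbackSketch {
  // If the stored "sourceTables" property is empty (an MV created before the
  // property existed), derive the table names by parsing the MV's defining query.
  def resolveSourceTables(
      spark: SparkSession,
      storedSourceTables: Array[String],
      mvQuery: String): Array[String] = {
    if (storedSourceTables.nonEmpty) {
      storedSourceTables
    } else {
      // UnresolvedRelation nodes are the raw table references in the parsed,
      // not-yet-analyzed plan; collect one name per referenced table.
      spark.sessionState.sqlParser
        .parsePlan(mvQuery)
        .collect { case relation: UnresolvedRelation => relation.tableName }
        .toArray
    }
  }
}

The new FlintSparkIndexFactorySuite test further down exercises exactly this case: metadata without the property still yields an MV whose source tables are recovered from the source query.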
@@ -183,7 +183,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
attachLatestLogEntry(indexName, metadata)
}
.toList
.flatMap(FlintSparkIndexFactory.create)
.flatMap(metadata => FlintSparkIndexFactory.create(spark, metadata))
} else {
Seq.empty
}
@@ -202,7 +202,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
if (flintClient.exists(indexName)) {
val metadata = flintIndexMetadataService.getIndexMetadata(indexName)
val metadataWithEntry = attachLatestLogEntry(indexName, metadata)
FlintSparkIndexFactory.create(metadataWithEntry)
FlintSparkIndexFactory.create(spark, metadataWithEntry)
} else {
Option.empty
}
@@ -327,7 +327,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
val index = describeIndex(indexName)

if (index.exists(_.options.autoRefresh())) {
val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(index.get).get
val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(spark, index.get).get
FlintSparkIndexRefresh
.create(updatedIndex.name(), updatedIndex)
.validate(spark)
@@ -92,7 +92,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
val updatedMetadata = index
.metadata()
.copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get)
validateIndex(FlintSparkIndexFactory.create(flint.spark, updatedMetadata).get)
}

/**
@@ -25,6 +25,7 @@ import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession

/**
* Flint Spark index factory that encapsulates specific Flint index instance creation. This is for
@@ -35,14 +36,16 @@ object FlintSparkIndexFactory extends Logging {
/**
* Creates Flint index from generic Flint metadata.
*
* @param spark
* Spark session
* @param metadata
* Flint metadata
* @return
* Flint index instance, or None if any error during creation
*/
def create(metadata: FlintMetadata): Option[FlintSparkIndex] = {
def create(spark: SparkSession, metadata: FlintMetadata): Option[FlintSparkIndex] = {
try {
Some(doCreate(metadata))
Some(doCreate(spark, metadata))
} catch {
case e: Exception =>
logWarning(s"Failed to create Flint index from metadata $metadata", e)
@@ -53,24 +56,26 @@
/**
* Creates Flint index with default options.
*
* @param spark
* Spark session
* @param index
* Flint index
* @param metadata
* Flint metadata
* @return
* Flint index with default options
*/
def createWithDefaultOptions(index: FlintSparkIndex): Option[FlintSparkIndex] = {
def createWithDefaultOptions(
spark: SparkSession,
index: FlintSparkIndex): Option[FlintSparkIndex] = {
val originalOptions = index.options
val updatedOptions =
FlintSparkIndexOptions.updateOptionsWithDefaults(index.name(), originalOptions)
val updatedMetadata = index
.metadata()
.copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
this.create(updatedMetadata)
this.create(spark, updatedMetadata)
}

private def doCreate(metadata: FlintMetadata): FlintSparkIndex = {
private def doCreate(spark: SparkSession, metadata: FlintMetadata): FlintSparkIndex = {
val indexOptions = FlintSparkIndexOptions(
metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap)
val latestLogEntry = metadata.latestLogEntry
@@ -118,7 +123,7 @@ object FlintSparkIndexFactory extends Logging {
FlintSparkMaterializedView(
metadata.name,
metadata.source,
getArrayString(metadata.properties, "sourceTables"),
getMvSourceTables(spark, metadata),
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
@@ -135,6 +140,15 @@
.toMap
}

private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = {
val sourceTables = getArrayString(metadata.properties, "sourceTables")
if (sourceTables.isEmpty) {
FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source)
} else {
sourceTables
}
}

private def getString(map: java.util.Map[String, AnyRef], key: String): String = {
map.get(key).asInstanceOf[String]
}
@@ -34,6 +34,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* MV name
* @param query
* source query that generates MV data
* @param sourceTables
* source table names
* @param outputSchema
* output schema
* @param options
@@ -167,6 +169,25 @@ object FlintSparkMaterializedView {
flintIndexNamePrefix(mvName)
}

/**
* Extract source table names (possibly more than one) from the query.
*
* @param spark
* Spark session
* @param query
* source query that generates MV data
* @return
* source table names
*/
def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = {
spark.sessionState.sqlParser
.parsePlan(query)
.collect { case relation: UnresolvedRelation =>
qualifyTableName(spark, relation.tableName)
}
.toArray
}

/** Builder class for MV build */
class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) {
private var mvName: String = ""
@@ -196,8 +217,7 @@
*/
def query(query: String): Builder = {
this.query = query
// Extract source table names (possibly more than one)
this.sourceTables = extractSourceTableNames(query)
this.sourceTables = extractSourceTableNames(flint.spark, query)
this
}

@@ -228,14 +248,5 @@
.toMap
FlintSparkMaterializedView(mvName, query, sourceTables, outputSchema, indexOptions)
}

private def extractSourceTableNames(query: String): Array[String] = {
flint.spark.sessionState.sqlParser
.parsePlan(query)
.collect { case relation: UnresolvedRelation =>
qualifyTableName(flint.spark, relation.tableName)
}
.toArray
}
}
}
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.scalatest.matchers.should.Matchers._

import org.apache.spark.FlintSuite

class FlintSparkIndexFactorySuite extends FlintSuite {

/** Test table, MV name and query */
val testTable = "spark_catalog.default.mv_build_test"
val testMvName = "spark_catalog.default.mv"
val testQuery = s"SELECT * FROM $testTable"

test("create mv should generate source tables if missing in metadata") {
val content =
s""" {
| "_meta": {
| "kind": "$MV_INDEX_TYPE",
| "indexedColumns": [
| {
| "columnType": "int",
| "columnName": "age"
| }
| ],
| "name": "$testMvName",
| "source": "SELECT age FROM $testTable"
| },
| "properties": {
| "age": {
| "type": "integer"
| }
| }
| }
|""".stripMargin

val metadata = FlintOpenSearchIndexMetadataService.deserialize(content)
val index = FlintSparkIndexFactory.create(spark, metadata)
index shouldBe defined
index.get
.asInstanceOf[FlintSparkMaterializedView]
.sourceTables should contain theSameElementsAs Array(testTable)
}
}
@@ -18,9 +18,9 @@ import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.storage.{FlintOpenSearchIndexMetadataService, OpenSearchClientUtils}
import org.opensearch.flint.spark.FlintSparkIndex.quotedTableName
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{extractSourceTableNames, getFlintIndexName}
import org.opensearch.flint.spark.scheduler.OpenSearchAsyncQueryScheduler
import org.scalatest.matchers.must.Matchers.{contain, defined}
import org.scalatest.matchers.must.Matchers._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.sql.{DataFrame, Row}
@@ -52,6 +52,29 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
deleteTestIndex(testFlintIndex)
}

test("extract source table names from materialized view source query successfully") {
val testComplexQuery = s"""
| SELECT *
| FROM (
| SELECT 1
| FROM table1
| LEFT JOIN `table2`
| )
| UNION ALL
| SELECT 1
| FROM spark_catalog.default.`table/3`
| INNER JOIN spark_catalog.default.`table.4`
|""".stripMargin
extractSourceTableNames(flint.spark, testComplexQuery) should contain theSameElementsAs
Array(
"spark_catalog.default.table1",
"spark_catalog.default.table2",
"spark_catalog.default.`table/3`",
"spark_catalog.default.`table.4`")

extractSourceTableNames(flint.spark, "SELECT 1") should have size 0
}

test("create materialized view with metadata successfully") {
withTempDir { checkpointDir =>
val indexOptions =
@@ -111,27 +134,7 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
}

test("create materialized view should parse source tables successfully") {
val testTable1 = "table1"
val testTable2 = "`table2`"
val testTable3 = "spark_catalog.default.table3"
val testTable4 = "spark_catalog.default.`table4`"
createTimeSeriesTable(testTable1)
createTimeSeriesTable(testTable2)
createTimeSeriesTable(testTable3)
createTimeSeriesTable(testTable4)
val indexOptions = FlintSparkIndexOptions(Map.empty)
val testQuery = s"""
| SELECT *
| FROM (
| SELECT 1
| FROM $testTable1 t1
| LEFT JOIN $testTable2 t2
| )
| UNION ALL
| SELECT 1
| FROM $testTable3 t3
| INNER JOIN $testTable4 t4
|""".stripMargin
flint
.materializedView()
.name(testMvName)
@@ -143,12 +146,7 @@
index shouldBe defined
index.get
.asInstanceOf[FlintSparkMaterializedView]
.sourceTables should contain theSameElementsAs
Array(
"spark_catalog.default.table1",
"spark_catalog.default.table2",
"spark_catalog.default.table3",
"spark_catalog.default.table4")
.sourceTables should contain theSameElementsAs Array(testTable)
}

test("create materialized view with default checkpoint location successfully") {