[Backport 0.5-nexus] Extract source table names from mv query #860

Merged (1 commit) on Nov 1, 2024
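In short: building a materialized view now records its qualified source table names both on the FlintSparkMaterializedView case class and in the index metadata properties, and FlintSparkIndexFactory re-derives them from the MV query when older metadata lacks the property. A hedged sketch of the new surface, assuming a SparkSession in scope named spark; the MV name, query, and table below are hypothetical:

import org.opensearch.flint.spark.mv.FlintSparkMaterializedView

val query = "SELECT status, COUNT(*) AS cnt FROM spark_catalog.default.http_logs GROUP BY status"

// New helper in this PR: collect and qualify every table referenced by the MV query.
val sourceTables = FlintSparkMaterializedView.extractSourceTableNames(spark, query)
// e.g. Array("spark_catalog.default.http_logs")

// sourceTables is the new third field of the case class; options and latestLogEntry keep their defaults.
val mv = FlintSparkMaterializedView(
  "spark_catalog.default.http_count_view", // hypothetical MV name
  query,
  sourceTables,
  Map("status" -> "string", "cnt" -> "bigint"))

// metadata() now writes the names under the "sourceTables" property, which
// FlintSparkIndexFactory and FlintMetadataCache read back.
mv.metadata().properties.get("sourceTables")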
@@ -182,7 +182,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
attachLatestLogEntry(indexName, metadata)
}
.toList
- .flatMap(FlintSparkIndexFactory.create)
+ .flatMap(metadata => FlintSparkIndexFactory.create(spark, metadata))
} else {
Seq.empty
}
@@ -201,7 +201,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
if (flintClient.exists(indexName)) {
val metadata = flintIndexMetadataService.getIndexMetadata(indexName)
val metadataWithEntry = attachLatestLogEntry(indexName, metadata)
- FlintSparkIndexFactory.create(metadataWithEntry)
+ FlintSparkIndexFactory.create(spark, metadataWithEntry)
} else {
Option.empty
}
@@ -326,7 +326,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
val index = describeIndex(indexName)

if (index.exists(_.options.autoRefresh())) {
- val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(index.get).get
+ val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(spark, index.get).get
FlintSparkIndexRefresh
.create(updatedIndex.name(), updatedIndex)
.validate(spark)
@@ -92,7 +92,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
val updatedMetadata = index
.metadata()
.copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
- validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get)
+ validateIndex(FlintSparkIndexFactory.create(flint.spark, updatedMetadata).get)
}

/**
@@ -25,6 +25,7 @@ import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

import org.apache.spark.internal.Logging
+ import org.apache.spark.sql.SparkSession

/**
* Flint Spark index factory that encapsulates specific Flint index instance creation. This is for
@@ -35,14 +36,16 @@ object FlintSparkIndexFactory extends Logging {
/**
* Creates Flint index from generic Flint metadata.
*
+ * @param spark
+ * Spark session
* @param metadata
* Flint metadata
* @return
* Flint index instance, or None if any error during creation
*/
- def create(metadata: FlintMetadata): Option[FlintSparkIndex] = {
+ def create(spark: SparkSession, metadata: FlintMetadata): Option[FlintSparkIndex] = {
try {
- Some(doCreate(metadata))
+ Some(doCreate(spark, metadata))
} catch {
case e: Exception =>
logWarning(s"Failed to create Flint index from metadata $metadata", e)
@@ -53,24 +56,26 @@
/**
* Creates Flint index with default options.
*
+ * @param spark
+ * Spark session
* @param index
* Flint index
- * @param metadata
- * Flint metadata
* @return
* Flint index with default options
*/
- def createWithDefaultOptions(index: FlintSparkIndex): Option[FlintSparkIndex] = {
+ def createWithDefaultOptions(
+ spark: SparkSession,
+ index: FlintSparkIndex): Option[FlintSparkIndex] = {
val originalOptions = index.options
val updatedOptions =
FlintSparkIndexOptions.updateOptionsWithDefaults(index.name(), originalOptions)
val updatedMetadata = index
.metadata()
.copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
- this.create(updatedMetadata)
+ this.create(spark, updatedMetadata)
}

- private def doCreate(metadata: FlintMetadata): FlintSparkIndex = {
+ private def doCreate(spark: SparkSession, metadata: FlintMetadata): FlintSparkIndex = {
val indexOptions = FlintSparkIndexOptions(
metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap)
val latestLogEntry = metadata.latestLogEntry
@@ -118,6 +123,7 @@ object FlintSparkIndexFactory extends Logging {
FlintSparkMaterializedView(
metadata.name,
metadata.source,
+ getMvSourceTables(spark, metadata),
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
@@ -134,6 +140,15 @@
.toMap
}

+ private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = {
+ val sourceTables = getArrayString(metadata.properties, "sourceTables")
+ if (sourceTables.isEmpty) {
+ FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source)
+ } else {
+ sourceTables
+ }
+ }

private def getString(map: java.util.Map[String, AnyRef], key: String): String = {
map.get(key).asInstanceOf[String]
}
@@ -146,4 +161,12 @@
Some(value.asInstanceOf[String])
}
}

+ private def getArrayString(map: java.util.Map[String, AnyRef], key: String): Array[String] = {
+ map.get(key) match {
+ case list: java.util.ArrayList[_] =>
+ list.toArray.map(_.toString)
+ case _ => Array.empty[String]
+ }
+ }
}
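Side note on getArrayString above: it matches on java.util.ArrayList rather than a Scala or Java array, presumably because properties read back from serialized index metadata arrive as Java collections. A minimal, self-contained sketch of that conversion (values hypothetical):

import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

val props = new JHashMap[String, AnyRef]()
val tables = new JArrayList[String]()
tables.add("spark_catalog.default.http_logs")
props.put("sourceTables", tables)

// Same shape as getArrayString: any other value falls back to an empty array,
// which then triggers the extractSourceTableNames fallback in getMvSourceTables.
val sourceTables: Array[String] = props.get("sourceTables") match {
  case list: JArrayList[_] => list.toArray.map(_.toString)
  case _ => Array.empty[String]
}
// sourceTables: Array("spark_catalog.default.http_logs")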
@@ -11,9 +11,8 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
- import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.command.DDLUtils
- import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}
+ import org.apache.spark.sql.flint.{loadTable, parseTableName}

/**
* Flint Spark validation helper.
@@ -31,16 +30,10 @@ trait FlintSparkValidationHelper extends Logging {
* true if all non Hive, otherwise false
*/
def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = {
- // Extract source table name (possibly more than one for MV query)
val tableNames = index match {
case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName)
case covering: FlintSparkCoveringIndex => Seq(covering.tableName)
- case mv: FlintSparkMaterializedView =>
- spark.sessionState.sqlParser
- .parsePlan(mv.query)
- .collect { case relation: UnresolvedRelation =>
- qualifyTableName(spark, relation.tableName)
- }
+ case mv: FlintSparkMaterializedView => mv.sourceTables.toSeq
}

// Validate if any source table is not supported (currently Hive only)
@@ -10,6 +10,7 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter
import org.opensearch.flint.common.metadata.FlintMetadata
import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.spark.FlintSparkIndexOptions
+ import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser

/**
@@ -46,9 +47,7 @@ case class FlintMetadataCache(

object FlintMetadataCache {

- // TODO: constant for version
- val mockTableName =
- "dataSourceName.default.logGroups(logGroupIdentifier:['arn:aws:logs:us-east-1:123456:test-llt-xa', 'arn:aws:logs:us-east-1:123456:sample-lg-1'])"
+ val metadataCacheVersion = "1.0"

def apply(metadata: FlintMetadata): FlintMetadataCache = {
val indexOptions = FlintSparkIndexOptions(
@@ -61,14 +60,22 @@
} else {
None
}
+ val sourceTables = metadata.kind match {
+ case MV_INDEX_TYPE =>
+ metadata.properties.get("sourceTables") match {
+ case list: java.util.ArrayList[_] =>
+ list.toArray.map(_.toString)
+ case _ => Array.empty[String]
+ }
+ case _ => Array(metadata.source)
+ }
val lastRefreshTime: Option[Long] = metadata.latestLogEntry.flatMap { entry =>
entry.lastRefreshCompleteTime match {
case FlintMetadataLogEntry.EMPTY_TIMESTAMP => None
case timestamp => Some(timestamp)
}
}

- // TODO: get source tables from metadata
- FlintMetadataCache("1.0", refreshInterval, Array(mockTableName), lastRefreshTime)
+ FlintMetadataCache(metadataCacheVersion, refreshInterval, sourceTables, lastRefreshTime)
}
}
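For reference, a worked illustration of the branch above (values hypothetical): MV cache entries now carry their real source tables instead of the removed mockTableName placeholder, while every other index kind keeps reporting its single source table.

// MV metadata:    kind == MV_INDEX_TYPE, properties("sourceTables") = ["spark_catalog.default.http_logs", "spark_catalog.default.vpc_logs"]
//                 -> sourceTables = Array("spark_catalog.default.http_logs", "spark_catalog.default.vpc_logs")
// any other kind: sourceTables = Array(metadata.source), i.e. the single source table name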
@@ -34,6 +34,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* MV name
* @param query
* source query that generates MV data
+ * @param sourceTables
+ * source table names
* @param outputSchema
* output schema
* @param options
@@ -44,6 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class FlintSparkMaterializedView(
mvName: String,
query: String,
+ sourceTables: Array[String],
outputSchema: Map[String, String],
override val options: FlintSparkIndexOptions = empty,
override val latestLogEntry: Option[FlintMetadataLogEntry] = None)
@@ -64,6 +67,7 @@ case class FlintSparkMaterializedView(
metadataBuilder(this)
.name(mvName)
.source(query)
.addProperty("sourceTables", sourceTables)
.indexedColumns(indexColumnMaps)
.schema(schema)
.build()
@@ -165,10 +169,30 @@
flintIndexNamePrefix(mvName)
}

+ /**
+ * Extract source table names (possibly more than one) from the query.
+ *
+ * @param spark
+ * Spark session
+ * @param query
+ * source query that generates MV data
+ * @return
+ * source table names
+ */
+ def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = {
+ spark.sessionState.sqlParser
+ .parsePlan(query)
+ .collect { case relation: UnresolvedRelation =>
+ qualifyTableName(spark, relation.tableName)
+ }
+ .toArray
+ }

/** Builder class for MV build */
class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) {
private var mvName: String = ""
private var query: String = ""
+ private var sourceTables: Array[String] = Array.empty[String]

/**
* Set MV name.
@@ -193,6 +217,7 @@
*/
def query(query: String): Builder = {
this.query = query
+ this.sourceTables = extractSourceTableNames(flint.spark, query)
this
}

@@ -221,7 +246,7 @@
field.name -> field.dataType.simpleString
}
.toMap
- FlintSparkMaterializedView(mvName, query, outputSchema, indexOptions)
+ FlintSparkMaterializedView(mvName, query, sourceTables, outputSchema, indexOptions)
}
}
}
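A hedged usage sketch of extractSourceTableNames, assuming a SparkSession named spark as in the overview sketch (the table name is hypothetical): the helper collects every UnresolvedRelation from the parsed query and qualifies it via qualifyTableName, and since query() on the Builder now calls it, build() picks the names up without them being passed explicitly.

val mvQuery =
  """SELECT status, COUNT(*) AS cnt
    |FROM spark_catalog.default.http_logs
    |GROUP BY status""".stripMargin

FlintSparkMaterializedView.extractSourceTableNames(spark, mvQuery)
// => Array("spark_catalog.default.http_logs")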
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.scalatest.matchers.should.Matchers._

import org.apache.spark.FlintSuite

class FlintSparkIndexFactorySuite extends FlintSuite {

test("create mv should generate source tables if missing in metadata") {
val testTable = "spark_catalog.default.mv_build_test"
val testMvName = "spark_catalog.default.mv"
val testQuery = s"SELECT * FROM $testTable"

val content =
s""" {
| "_meta": {
| "kind": "$MV_INDEX_TYPE",
| "indexedColumns": [
| {
| "columnType": "int",
| "columnName": "age"
| }
| ],
| "name": "$testMvName",
| "source": "$testQuery"
| },
| "properties": {
| "age": {
| "type": "integer"
| }
| }
| }
|""".stripMargin

val metadata = FlintOpenSearchIndexMetadataService.deserialize(content)
val index = FlintSparkIndexFactory.create(spark, metadata)
index shouldBe defined
index.get
.asInstanceOf[FlintSparkMaterializedView]
.sourceTables should contain theSameElementsAs Array(testTable)
}
}