Support struct field as indexed column (#213)
* Support struct field in index builder

Signed-off-by: Chen Dai <[email protected]>

* Add IT

Signed-off-by: Chen Dai <[email protected]>

* Fix IT after merge main

Signed-off-by: Chen Dai <[email protected]>

* Implement simple query rewrite and update IT

Signed-off-by: Chen Dai <[email protected]>

* Address PR comments

Signed-off-by: Chen Dai <[email protected]>

---------

Signed-off-by: Chen Dai <[email protected]>
dai-chen authored Feb 7, 2024
1 parent f4744ab commit 851ade3
Showing 8 changed files with 183 additions and 46 deletions.
@@ -9,6 +9,7 @@ import java.util.concurrent.ScheduledExecutorService

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.{ShutdownHookManager, ThreadUtils}

/**
@@ -120,4 +121,17 @@ package object flint {
def loadTable(catalog: CatalogPlugin, ident: Identifier): Option[Table] = {
CatalogV2Util.loadTable(catalog, ident)
}

/**
* Find the field with the given name under the root field recursively.
*
* @param rootField
* root field struct
* @param fieldName
* field name to search, possibly dot-separated for nested fields
* @return
* the matching field if found, otherwise None
*/
def findField(rootField: StructType, fieldName: String): Option[StructField] = {
rootField.findNestedField(fieldName.split('.')).map(_._2)
}
}
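As a usage note, here is a minimal sketch of how this helper resolves a dot-separated path against a nested schema. The schema below is hypothetical (it mirrors the integration test added later in this commit); findField simply splits on '.' and delegates to StructType.findNestedField.

import org.apache.spark.sql.flint.findField
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical nested schema: address.second.city is a leaf string field.
val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("address", StructType(Seq(
    StructField("second", StructType(Seq(
      StructField("city", StringType)))))))))

findField(schema, "name")                // roughly Some(StructField(name, StringType, ...))
findField(schema, "address.second.city") // roughly Some(StructField(city, StringType, ...))
findField(schema, "address.zip")         // None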
@@ -9,8 +9,8 @@ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty

import org.apache.spark.sql.catalog.Column
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.flint.{findField, loadTable, parseTableName, qualifyTableName}
import org.apache.spark.sql.types.{StructField, StructType}

/**
* Flint Spark index builder base class.
@@ -27,15 +27,14 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
protected var indexOptions: FlintSparkIndexOptions = empty

/** All columns of the given source table */
lazy protected val allColumns: Map[String, Column] = {
lazy protected val allColumns: StructType = {
require(qualifiedTableName.nonEmpty, "Source table name is not provided")

val (catalog, ident) = parseTableName(flint.spark, qualifiedTableName)
val table = loadTable(catalog, ident).getOrElse(
throw new IllegalStateException(s"Table $qualifiedTableName is not found"))

val allFields = table.schema().fields
allFields.map { field => field.name -> convertFieldToColumn(field) }.toMap
table.schema()
}

/**
@@ -83,14 +82,14 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
* Find column with the given name.
*/
protected def findColumn(colName: String): Column =
allColumns.getOrElse(
colName,
throw new IllegalArgumentException(s"Column $colName does not exist"))
findField(allColumns, colName)
.map(field => convertFieldToColumn(colName, field))
.getOrElse(throw new IllegalArgumentException(s"Column $colName does not exist"))

private def convertFieldToColumn(field: StructField): Column = {
private def convertFieldToColumn(colName: String, field: StructField): Column = {
// Ref to CatalogImpl.listColumns(): Varchar/Char is StringType with real type name in metadata
new Column(
name = field.name,
name = colName,
description = field.getComment().orNull,
dataType =
CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType).catalogString,
@@ -9,7 +9,9 @@ import org.json4s.CustomSerializer
import org.json4s.JsonAST.JString
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField}
import org.apache.spark.sql.functions.col

/**
* Skipping index strategy that defines skipping data structure building and reading logic.
@@ -82,4 +84,39 @@ object FlintSparkSkippingStrategy {
{ case kind: SkippingKind =>
JString(kind.toString)
}))

/**
* Extractor that matches the given expression with the indexed column in the skipping index.
*
* @param indexColName
* indexed column name
*/
case class IndexColumnExtractor(indexColName: String) {

def unapply(expr: Expression): Option[Column] = {
val colName = extractColumnName(expr).mkString(".")
if (colName == indexColName) {
Some(col(indexColName))
} else {
None
}
}

/*
* In Spark, after analysis, nested field "a.b.c" becomes:
* GetStructField(name="c",
* child=GetStructField(name="b",
* child=AttributeReference(name="a")))
* TODO: To support any index expression, analyze index expression string
*/
private def extractColumnName(expr: Expression): Seq[String] = {
expr match {
case attr: Attribute =>
Seq(attr.name)
case GetStructField(child, _, Some(name)) =>
extractColumnName(child) :+ name
case _ => Seq.empty
}
}
}
}
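To illustrate, a small hedged sketch of the extractor against an analyzed plan. The table-less query below is purely illustrative; it only serves to let the analyzer build the GetStructField chain described in the comment above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.EqualTo
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor

val spark = SparkSession.builder().appName("extractor-demo").master("local[1]").getOrCreate()

// The analyzer turns "address.second.city" into nested GetStructField over the "address"
// attribute; the extractor folds that chain back into a dotted name and compares it
// with indexColName.
val analyzed = spark
  .sql("SELECT named_struct('second', named_struct('city', 'seattle')) AS address")
  .filter("address.second.city = 'seattle'")
  .queryExecution.analyzed
val condition = analyzed.collectFirst { case f: Filter => f.condition }.get

val IndexColumn = IndexColumnExtractor("address.second.city")
condition match {
  case EqualTo(IndexColumn(indexCol), literal) =>
    println(s"matched indexed column $indexCol against $literal")
  case other =>
    println(s"no match: $other")
}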
@@ -6,9 +6,11 @@
package org.opensearch.flint.spark.skipping.minmax

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}

import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.functions.col
@@ -35,19 +37,20 @@ case class MinMaxSkippingStrategy(
Max(col(columnName).expr).toAggregateExpression())
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
override def rewritePredicate(predicate: Expression): Option[Expression] = {
val IndexColumn = MinMaxIndexColumnExtractor(IndexColumnExtractor(columnName))
predicate match {
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
Some((col(minColName) <= value && col(maxColName) >= value).expr)
case LessThan(AttributeReference(`columnName`, _, _, _), value: Literal) =>
Some((col(minColName) < value).expr)
case LessThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) =>
Some((col(minColName) <= value).expr)
case GreaterThan(AttributeReference(`columnName`, _, _, _), value: Literal) =>
Some((col(maxColName) > value).expr)
case GreaterThanOrEqual(AttributeReference(`columnName`, _, _, _), value: Literal) =>
Some((col(maxColName) >= value).expr)
case In(column @ AttributeReference(`columnName`, _, _, _), AllLiterals(literals)) =>
case EqualTo(IndexColumn(minIndexCol, maxIndexCol), value: Literal) =>
Some((minIndexCol <= value && maxIndexCol >= value).expr)
case LessThan(IndexColumn(minIndexCol, _), value: Literal) =>
Some((minIndexCol < value).expr)
case LessThanOrEqual(IndexColumn(minIndexCol, _), value: Literal) =>
Some((minIndexCol <= value).expr)
case GreaterThan(IndexColumn(_, maxIndexCol), value: Literal) =>
Some((maxIndexCol > value).expr)
case GreaterThanOrEqual(IndexColumn(_, maxIndexCol), value: Literal) =>
Some((maxIndexCol >= value).expr)
case In(column @ IndexColumn(_), AllLiterals(literals)) =>
/*
* First, convert IN to approximate range check: min(in_list) <= col <= max(in_list)
* to avoid long and maybe unnecessary comparison expressions.
@@ -62,9 +65,21 @@
rewritePredicate(LessThanOrEqual(column, Literal(maxVal))).get))
case _ => None
}
}

/** Extractor that returns the min/max index columns if the given expression matches */
private case class MinMaxIndexColumnExtractor(IndexColumn: IndexColumnExtractor) {

def unapply(expr: Expression): Option[(Column, Column)] = {
expr match {
case IndexColumn(_) => Some((col(minColName), col(maxColName)))
case _ => None
}
}
}

/** Need this because Scala pattern match doesn't work for generic types like Seq[Literal] */
object AllLiterals {
private object AllLiterals {
def unapply(values: Seq[Expression]): Option[Seq[Literal]] = {
if (values.forall(_.isInstanceOf[Literal])) {
Some(values.asInstanceOf[Seq[Literal]])
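A hedged sketch of the resulting rewrite. The nested column, its type, and the case-class arguments columnName/columnType are assumptions for illustration; the actual min/max output column names come from the strategy itself.

import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal}
import org.apache.spark.sql.functions.col
import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy

// Hypothetical min-max skipping strategy on a nested integer field.
val strategy = MinMaxSkippingStrategy(
  columnName = "address.second.floor",
  columnType = "int")

// "address.second.floor > 30" can only be satisfied by source files whose max exceeds 30,
// so the predicate is rewritten against the index's max column; unsupported shapes yield None.
val rewritten = strategy.rewritePredicate(
  GreaterThan(col("address.second.floor").expr, Literal(30)))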
@@ -6,9 +6,10 @@
package org.opensearch.flint.spark.skipping.partition

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.functions.col

@@ -29,11 +30,13 @@ case class PartitionSkippingStrategy(
Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression())
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
override def rewritePredicate(predicate: Expression): Option[Expression] = {
val IndexColumn = IndexColumnExtractor(columnName)
predicate match {
// Column has the same name in the index data, so just rewrite to the same equality
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
Some((col(columnName) === value).expr)
case EqualTo(IndexColumn(indexCol), value: Literal) =>
Some((indexCol === value).expr)
case _ => None
}
}
}
@@ -6,10 +6,11 @@
package org.opensearch.flint.spark.skipping.valueset

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.{DEFAULT_VALUE_SET_MAX_SIZE, VALUE_SET_MAX_SIZE_KEY}

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
import org.apache.spark.sql.functions._

/**
@@ -44,17 +45,19 @@ case class ValueSetSkippingStrategy(
Seq(aggregator.expr)
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
override def rewritePredicate(predicate: Expression): Option[Expression] = {
/*
* This is supposed to be rewritten to ARRAY_CONTAINS(columnName, value).
* However, due to the push-down limitation in Spark, we keep the equality as is.
*/
val IndexColumn = IndexColumnExtractor(columnName)
predicate match {
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
case EqualTo(IndexColumn(indexCol), value: Literal) =>
// Value set may be null due to the maximum size limit restriction
Some((isnull(col(columnName)) || col(columnName) === value).expr)
Some((isnull(indexCol) || indexCol === value).expr)
case _ => None
}
}
}

object ValueSetSkippingStrategy {
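A similarly hedged sketch for the value-set strategy; the constructor arguments are assumed from the surrounding file. The rewrite keeps an equality but also tolerates a null value set, which is produced when the max-size limit was exceeded.

import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
import org.apache.spark.sql.functions.col
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

// Hypothetical value-set skipping strategy on the nested city field.
val strategy = ValueSetSkippingStrategy(
  columnName = "address.second.city",
  columnType = "string")

// "address.second.city = 'Seattle'" is rewritten to
// isnull(address.second.city) OR address.second.city = 'Seattle',
// because a null value set (size limit exceeded) must not skip any file.
val rewritten = strategy.rewritePredicate(
  EqualTo(col("address.second.city").expr, Literal("Seattle")))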
@@ -5,7 +5,6 @@

package org.opensearch.flint.spark

import org.scalatest.matchers.must.Matchers.contain
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
@@ -17,7 +16,11 @@ class FlintSparkIndexBuilderSuite extends FlintSuite {

sql("""
| CREATE TABLE spark_catalog.default.test
| ( name STRING, age INT )
| (
| name STRING,
| age INT,
| address STRUCT<first: STRING, second: STRUCT<city: STRING, street: STRING>>
| )
| USING JSON
""".stripMargin)
}
@@ -28,21 +31,31 @@ class FlintSparkIndexBuilderSuite extends FlintSuite {
super.afterAll()
}

test("find column type") {
builder()
.onTable("test")
.expectTableName("spark_catalog.default.test")
.expectColumn("name", "string")
.expectColumn("age", "int")
.expectColumn("address", "struct<first:string,second:struct<city:string,street:string>>")
.expectColumn("address.first", "string")
.expectColumn("address.second", "struct<city:string,street:string>")
.expectColumn("address.second.city", "string")
.expectColumn("address.second.street", "string")
}

test("should qualify table name in default database") {
builder()
.onTable("test")
.expectTableName("spark_catalog.default.test")
.expectAllColumns("name", "age")

builder()
.onTable("default.test")
.expectTableName("spark_catalog.default.test")
.expectAllColumns("name", "age")

builder()
.onTable("spark_catalog.default.test")
.expectTableName("spark_catalog.default.test")
.expectAllColumns("name", "age")
}

test("should qualify table name and get columns in other database") {
Expand All @@ -54,23 +67,19 @@ class FlintSparkIndexBuilderSuite extends FlintSuite {
builder()
.onTable("test2")
.expectTableName("spark_catalog.mydb.test2")
.expectAllColumns("address")

builder()
.onTable("mydb.test2")
.expectTableName("spark_catalog.mydb.test2")
.expectAllColumns("address")

builder()
.onTable("spark_catalog.mydb.test2")
.expectTableName("spark_catalog.mydb.test2")
.expectAllColumns("address")

// Can parse any specified table name
builder()
.onTable("spark_catalog.default.test")
.expectTableName("spark_catalog.default.test")
.expectAllColumns("name", "age")
} finally {
sql("DROP DATABASE mydb CASCADE")
sql("USE default")
@@ -96,8 +105,10 @@ class FlintSparkIndexBuilderSuite extends FlintSuite {
this
}

def expectAllColumns(expected: String*): FakeFlintSparkIndexBuilder = {
allColumns.keys should contain theSameElementsAs expected
def expectColumn(expectName: String, expectType: String): FakeFlintSparkIndexBuilder = {
val column = findColumn(expectName)
column.name shouldBe expectName
column.dataType shouldBe expectType
this
}

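Taken together, a hedged end-to-end sketch of what this change enables. The builder method names (skippingIndex/onTable/addValueSet/create) are assumed from the existing FlintSpark API and are not part of this diff; the table matches the suite above.

import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.FlintSpark

val spark = SparkSession.builder().appName("struct-skipping-demo").master("local[1]").getOrCreate()
val flint = new FlintSpark(spark)

// Index a struct field of spark_catalog.default.test (schema as in the suite above).
flint
  .skippingIndex()
  .onTable("spark_catalog.default.test")
  .addValueSet("address.second.city")
  .create()

// Once refreshed, filters on the nested field can be rewritten to consult the skipping index.
spark.sql(
  "SELECT name FROM spark_catalog.default.test WHERE address.second.city = 'Seattle'")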