Add ast builder and SQL IT
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Nov 2, 2023
1 parent 7b39800 commit 7bafbc7
Showing 4 changed files with 45 additions and 12 deletions.
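
In short, this change lets a skipping index be built over only the rows that match an optional WHERE clause: the SQL AST builder picks up the predicate text and the index refresh job applies it before the per-file aggregation. A minimal usage sketch, assuming a SparkSession with the Flint SQL extension enabled; the alb_logs table and its column are placeholders, not taken from this commit:

    // Hypothetical example; table and column names are illustrative only.
    spark.sql("""
         | CREATE SKIPPING INDEX ON alb_logs
         | (
         |   elb_status_code VALUE_SET
         | )
         | WHERE elb_status_code = 500
         | WITH (auto_refresh = true)
         | """.stripMargin)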
@@ -81,7 +81,14 @@ case class FlintSparkSkippingIndex(
       new Column(aggFunc.toAggregateExpression().as(name))
     }
 
-    df.getOrElse(spark.read.table(tableName))
+    var job = df.getOrElse(spark.read.table(tableName))
+
+    // Add optional filtering condition
+    if (filterCondition.isDefined) {
+      job = job.where(filterCondition.get)
+    }
+
+    job
       .groupBy(input_file_name().as(FILE_PATH_COLUMN))
       .agg(namedAggFuncs.head, namedAggFuncs.tail: _*)
       .withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN)))
@@ -43,6 +43,10 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
       }
     }
 
+    if (ctx.whereClause() != null) {
+      indexBuilder.filterBy(getSqlText(ctx.whereClause().filterCondition()))
+    }
+
     val ignoreIfExists = ctx.EXISTS() != null
     val indexOptions = visitPropertyList(ctx.propertyList())
     indexBuilder
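
The filterBy call above lands on the index builder, which is not touched in this diff; presumably it just records the raw predicate text so that it surfaces as the filterCondition consumed in the first hunk. A rough sketch under that assumption, not the actual implementation:

    // Assumed builder-side shape (not shown in this commit): keep the WHERE
    // text verbatim so it ends up as filterCondition on the created index.
    class Builder {
      private var filterCondition: Option[String] = None // hypothetical field

      def filterBy(condition: String): Builder = {
        filterCondition = Some(condition)
        this
      }
    }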
@@ -14,7 +14,7 @@ import org.json4s.native.Serialization
 import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.storage.FlintOpenSearchClient
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
-import org.scalatest.matchers.must.Matchers.{defined, have}
+import org.scalatest.matchers.must.Matchers.defined
 import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the}
 
 import org.apache.spark.sql.Row
@@ -50,18 +50,28 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite {
          | WITH (auto_refresh = true)
          | """.stripMargin)
 
-    // Wait for streaming job complete current micro batch
-    val job = spark.streams.active.find(_.name == testIndex)
-    job shouldBe defined
-    failAfter(streamingTimeout) {
-      job.get.processAllAvailable()
-    }
-
-    val indexData = spark.read.format(FLINT_DATASOURCE).load(testIndex)
+    val indexData = awaitStreamingDataComplete(testIndex)
     flint.describeIndex(testIndex) shouldBe defined
     indexData.count() shouldBe 2
   }
 
+  test("create skipping index with filtering condition") {
+    sql(s"""
+         | CREATE SKIPPING INDEX ON $testTable
+         | (
+         |   year PARTITION,
+         |   name VALUE_SET,
+         |   age MIN_MAX
+         | )
+         | WHERE address = 'Portland'
+         | WITH (auto_refresh = true)
+         | """.stripMargin)
+
+    val indexData = awaitStreamingDataComplete(testIndex)
+    flint.describeIndex(testIndex) shouldBe defined
+    indexData.count() shouldBe 1
+  }
+
   test("create skipping index with streaming job options") {
     withTempDir { checkpointDir =>
       sql(s"""
@@ -6,11 +6,13 @@
 package org.opensearch.flint.spark
 
 import org.opensearch.flint.OpenSearchSuite
 
 import org.apache.spark.FlintSuite
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{DataFrame, QueryTest}
 import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
 import org.apache.spark.sql.flint.config.FlintSparkConf.{CHECKPOINT_MANDATORY, HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY}
 import org.apache.spark.sql.streaming.StreamTest
+import org.scalatest.matchers.must.Matchers.defined
+import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
 /**
  * Flint Spark suite trait that initializes [[FlintSpark]] API instance.
@@ -31,6 +33,16 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
     setFlintSparkConf(CHECKPOINT_MANDATORY, "false")
   }
 
+  protected def awaitStreamingDataComplete(flintIndexName: String): DataFrame = {
+    val job = spark.streams.active.find(_.name == flintIndexName)
+    job shouldBe defined
+
+    failAfter(streamingTimeout) {
+      job.get.processAllAvailable()
+    }
+    spark.read.format(FLINT_DATASOURCE).load(flintIndexName)
+  }
+
   protected def awaitStreamingComplete(jobId: String): Unit = {
     val job = spark.streams.get(jobId)
     failAfter(streamingTimeout) {