diff --git a/build.sbt b/build.sbt
index 0dcfb8af7..938f19a64 100644
--- a/build.sbt
+++ b/build.sbt
@@ -139,6 +139,10 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
         val oldStrategy = (assembly / assemblyMergeStrategy).value
         oldStrategy(x)
     },
+    assembly / assemblyExcludedJars := {
+      val cp = (assembly / fullClasspath).value
+      cp filter { file => file.data.getName.contains("LogsConnectorSpark")}
+    },
     assembly / test := (Test / test).value)
 
 // Test assembly package with integration test.
diff --git a/flint-spark-integration/lib/LogsConnectorSpark-1.0.jar b/flint-spark-integration/lib/LogsConnectorSpark-1.0.jar
new file mode 100644
index 000000000..0aa50bbb2
Binary files /dev/null and b/flint-spark-integration/lib/LogsConnectorSpark-1.0.jar differ
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
index 11f8ad304..83f8def4d 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
@@ -5,14 +5,19 @@
 
 package org.opensearch.flint.spark.skipping
 
+import com.amazon.awslogsdataaccesslayer.connectors.spark.LogsTable
 import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
 
+import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or, Predicate}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier}
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.flint.qualifyTableName
 
 /**
@@ -57,6 +62,46 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
       } else {
         filter
       }
+    case filter @ Filter(
+          condition: Predicate,
+          relation @ DataSourceV2Relation(table, _, Some(catalog), Some(identifier), _))
+        if hasNoDisjunction(condition) &&
+          // Check if query plan already rewritten
+          table.isInstanceOf[LogsTable] && !table.asInstanceOf[LogsTable].hasFileIndexScan() =>
+      val index = flint.describeIndex(getIndexName(catalog, identifier))
+      if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
+        val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
+        val indexFilter = rewriteToIndexFilter(skippingIndex, condition)
+        /*
+         * Replace original LogsTable with a new one with file index scan:
+         *  Filter(a=b)
+         *  |- DataSourceV2Relation(A)
+         *     |- LogsTable <== replaced with a new LogsTable with file index scan
+         */
+        if (indexFilter.isDefined) {
+          val indexScan = flint.queryIndex(skippingIndex.name())
+          val selectFileIndexScan =
+            // Non hybrid scan
+            // TODO: refactor common logic with file-based skipping index
+            indexScan
+              .filter(new Column(indexFilter.get))
+              .select(FILE_PATH_COLUMN)
+
+          // Construct LogsTable with file index scan
+          // It will build scan operator using log file ids collected from file index scan
+          val logsTable = table.asInstanceOf[LogsTable]
+          val newTable = new LogsTable(
+            logsTable.schema(),
+            logsTable.options(),
+            selectFileIndexScan,
+            logsTable.processedFields())
+          filter.copy(child = relation.copy(table = newTable))
+        } else {
+          filter
+        }
+      } else {
+        filter
+      }
   }
 
   private def getIndexName(table: CatalogTable): String = {
@@ -67,6 +112,11 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
     getSkippingIndexName(qualifiedTableName)
   }
 
+  private def getIndexName(catalog: CatalogPlugin, identifier: Identifier): String = {
+    val qualifiedTableName = s"${catalog.name}.${identifier}"
+    getSkippingIndexName(qualifiedTableName)
+  }
+
   private def hasNoDisjunction(condition: Expression): Boolean = {
     condition.collectFirst { case Or(_, _) => true
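Note (not part of the patch): below is a minimal driver sketch of the query shape the new DataSourceV2Relation case is intended to rewrite. The extensions class name, catalog, and table names are assumptions for illustration; only the rule behavior itself comes from the diff above.

import org.apache.spark.sql.SparkSession

object LogsTableSkippingIndexSketch {
  def main(args: Array[String]): Unit = {
    // Assumes the Flint Spark extensions are on the classpath; the class name
    // below follows the Flint project's documented setting, not this diff.
    val spark = SparkSession
      .builder()
      .appName("logs-table-skipping-index-sketch")
      .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
      .getOrCreate()

    // A purely conjunctive filter (no OR) over a LogsTable-backed relation is
    // the shape the new case matches: the rule looks up the table's skipping
    // index, rewrites the predicate into an index filter, selects matching
    // file paths from the index scan, and swaps in a new LogsTable built from
    // that file index scan. The catalog/table below are hypothetical.
    spark
      .sql("""SELECT status, COUNT(*)
             |FROM logs_catalog.default.http_logs
             |WHERE year = 2023 AND status = 500
             |GROUP BY status""".stripMargin)
      .explain(true) // inspect the optimized plan for the replaced LogsTable

    spark.stop()
  }
}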