Skip to content

Commit

Permalink
Extract util methods from each AST builder
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen committed Sep 18, 2023
1 parent 1637cab commit afe9d3e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@

package org.opensearch.flint.spark.sql

import org.antlr.v4.runtime.tree.RuleNode
import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.PropertyListContext
import org.opensearch.flint.spark.sql.covering.FlintSparkCoveringIndexAstBuilder
import org.opensearch.flint.spark.sql.skipping.FlintSparkSkippingIndexAstBuilder

import org.apache.spark.sql.catalyst.plans.logical.Command

/**
* Flint Spark AST builder that builds Spark command for Flint index statement.
* This class mix-in all other AST builders and provides util methods.
*/
class FlintSparkSqlAstBuilder
extends FlintSparkSqlExtensionsBaseVisitor[Command]
Expand All @@ -21,3 +25,46 @@ class FlintSparkSqlAstBuilder
/** Keep the most recent non-null child result; otherwise retain the current aggregate. */
override def aggregateResult(aggregate: Command, nextResult: Command): Command =
  Option(nextResult).getOrElse(aggregate)
}

object FlintSparkSqlAstBuilder {

  /**
   * Check if auto_refresh is set to true in the given property list.
   * Only the first "auto_refresh" property encountered is considered.
   *
   * @param ctx
   *   property list context; may be null when no property list is present
   * @return
   *   true if an auto_refresh property exists and its value parses as true, false otherwise
   * @throws IllegalArgumentException
   *   if the auto_refresh value is not a valid boolean literal
   */
  def isAutoRefreshEnabled(ctx: PropertyListContext): Boolean = {
    if (ctx == null) {
      false
    } else {
      // Avoid `return` inside a lambda: it is a non-local return implemented by
      // throwing NonLocalReturnControl, which broad catch clauses can swallow and
      // which is removed in Scala 3. Stream.findFirst short-circuits the same way.
      val autoRefresh = ctx
        .property()
        .stream()
        .filter(p => p.key.getText == "auto_refresh")
        .findFirst()
      if (autoRefresh.isPresent) autoRefresh.get.value.getText.toBoolean else false
    }
  }

  /**
   * Get full table name, qualifying with the current database if not specified.
   *
   * @param flint
   *   Flint Spark which has access to Spark Catalog
   * @param tableNameCtx
   *   table name parse tree node
   * @return
   *   fully qualified table name in "db.table" format
   */
  def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = {
    val tableName = tableNameCtx.getText
    if (tableName.contains(".")) {
      tableName
    } else {
      // Qualify with the session's current database from the Spark catalog
      s"${flint.spark.catalog.currentDatabase}.$tableName"
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

package org.opensearch.flint.spark.sql.covering

import org.antlr.v4.runtime.tree.RuleNode
import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName
import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{CreateCoveringIndexStatementContext, PropertyListContext}
import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, isAutoRefreshEnabled}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.CreateCoveringIndexStatementContext

import org.apache.spark.sql.catalyst.plans.logical.Command

Expand Down Expand Up @@ -44,29 +43,4 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[C
Seq.empty
}
}

/**
 * Check whether auto_refresh is set to true in the property list.
 * Returns false when no property list is given or the property is absent.
 *
 * @param ctx
 *   property list context; may be null
 */
private def isAutoRefreshEnabled(ctx: PropertyListContext): Boolean = {
  if (ctx == null) {
    false
  } else {
    // Replace non-local `return` from inside the forEach lambda (throws
    // NonLocalReturnControl) with short-circuiting stream lookup.
    val autoRefresh = ctx
      .property()
      .stream()
      .filter(p => p.key.getText == "auto_refresh")
      .findFirst()
    if (autoRefresh.isPresent) autoRefresh.get.value.getText.toBoolean else false
  }
}

/** Qualify the parsed table name with the current database when no database prefix is given. */
private def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = {
  val name = tableNameCtx.getText
  if (name.contains(".")) name
  else s"${flint.spark.catalog.currentDatabase}.$name"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor}
import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, isAutoRefreshEnabled}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._

import org.apache.spark.sql.Row
Expand Down Expand Up @@ -85,31 +86,6 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[C
Seq.empty
}

/**
 * Check whether auto_refresh is set to true in the property list.
 * Returns false when no property list is given or the property is absent.
 *
 * @param ctx
 *   property list context; may be null
 */
private def isAutoRefreshEnabled(ctx: PropertyListContext): Boolean = {
  if (ctx == null) {
    false
  } else {
    // Replace non-local `return` from inside the forEach lambda (throws
    // NonLocalReturnControl) with short-circuiting stream lookup.
    val autoRefresh = ctx
      .property()
      .stream()
      .filter(p => p.key.getText == "auto_refresh")
      .findFirst()
    if (autoRefresh.isPresent) autoRefresh.get.value.getText.toBoolean else false
  }
}

/** Derive the Flint skipping index name from the fully qualified source table name. */
private def getSkippingIndexName(flint: FlintSpark, tableNameCtx: RuleNode): String = {
  val qualifiedTable = getFullTableName(flint, tableNameCtx)
  FlintSparkSkippingIndex.getSkippingIndexName(qualifiedTable)
}

/** Return the table name qualified with the current catalog database if unqualified. */
private def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = {
  val tableName = tableNameCtx.getText
  if (tableName.contains(".")) {
    tableName
  } else {
    val currentDb = flint.spark.catalog.currentDatabase
    s"$currentDb.$tableName"
  }
}
}

0 comments on commit afe9d3e

Please sign in to comment.