Restrict value set size using new CollectSetLimit function #207

Closed
@@ -139,7 +139,12 @@ indexColTypeList
;

indexColType
: identifier skipType=(PARTITION | VALUE_SET | MIN_MAX)
: identifier skipType=(PARTITION | MIN_MAX)
| identifier valueSetType
;

valueSetType
: VALUE_SET (LEFT_PAREN limit=INTEGER_VALUE RIGHT_PAREN)?
;

indexName
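The grammar keeps the bare VALUE_SET form and adds an optional parenthesized size limit. A minimal sketch of both forms, assuming a SparkSession (spark) with the Flint SQL extensions enabled and hypothetical table and column names:

// Sketch only: "alb_logs", "elb_status_code" and "client_ip" are hypothetical, and the
// surrounding CREATE SKIPPING INDEX statement is assumed from the existing Flint SQL syntax.
spark.sql(
  "CREATE SKIPPING INDEX ON alb_logs (" +
    "elb_status_code VALUE_SET, " + // bare form: default size limit (100)
    "client_ip VALUE_SET(50))")     // new form: value set capped at 50 distinct values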
@@ -50,7 +50,10 @@ object FlintSparkIndexFactory {
case PARTITION =>
PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
case VALUE_SET =>
ValueSetSkippingStrategy(columnName = columnName, columnType = columnType)
ValueSetSkippingStrategy(
columnName = columnName,
columnType = columnType,
properties = getSkippingProperties(colInfo))
case MIN_MAX =>
MinMaxSkippingStrategy(columnName = columnName, columnType = columnType)
case other =>
@@ -90,4 +93,13 @@ object FlintSparkIndexFactory {
Some(value.asInstanceOf[String])
}
}

private def getSkippingProperties(
colInfo: java.util.Map[String, AnyRef]): Map[String, String] = {
colInfo
.get("properties")
.asInstanceOf[java.util.Map[String, String]]
.asScala
.toMap
}
}
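For reference, getSkippingProperties expects each colInfo entry to carry the nested "properties" map that FlintSparkSkippingIndex serializes below. A small self-contained sketch of that shape (column values are hypothetical):

import java.util.{Map => JMap}
import scala.collection.JavaConverters._

// Hypothetical metadata entry, mirroring the keys written by FlintSparkSkippingIndex.
val colInfo: JMap[String, AnyRef] = Map[String, AnyRef](
  "kind" -> "VALUE_SET",
  "columnName" -> "elb_status_code",
  "columnType" -> "int",
  "properties" -> Map("limit" -> "50").asJava).asJava

// Same conversion as getSkippingProperties: pull the nested Java map back into a Scala Map.
val properties: Map[String, String] =
  colInfo.get("properties").asInstanceOf[JMap[String, String]].asScala.toMap
// properties == Map("limit" -> "50")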
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.function

import scala.collection.mutable

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, CollectSet, ImperativeAggregate}
import org.apache.spark.sql.types.DataType

/**
* Collects a set of unique values up to a maximum size limit.
*/
@ExpressionDescription(
usage =
"_FUNC_(expr, limit) - Collects and returns a set of unique elements up to maximum limit.",
examples = """
Examples:
> SELECT _FUNC_(col, 2) FROM VALUES (1), (2), (1) AS tab(col);
[1,2]
> SELECT _FUNC_(col, 1) FROM VALUES (1), (2), (1) AS tab(col);
[1]
""",
note = """
The function is non-deterministic because the order of collected results depends
on the order of the rows which may be non-deterministic after a shuffle.
""",
group = "agg_funcs",
since = "2.0.0")
case class CollectSetLimit(
child: Expression,
limit: Int,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends Collect[mutable.HashSet[Any]] {

/** Delegate to collect set (because Scala prohibits case-to-case inheritance) */
private val collectSet = CollectSet(child, mutableAggBufferOffset, inputAggBufferOffset)

override def update(buffer: mutable.HashSet[Any], input: InternalRow): mutable.HashSet[Any] = {
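// Only grow the set while it is below the limit; once full, remaining input rows are ignored.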
if (buffer.size < limit) {
super.update(buffer, input)
} else {
buffer
}
}

override protected def convertToBufferElement(value: Any): Any =
collectSet.convertToBufferElement(value)

override protected val bufferElementType: DataType = collectSet.bufferElementType

override def createAggregationBuffer(): mutable.HashSet[Any] =
collectSet.createAggregationBuffer()

override def eval(buffer: mutable.HashSet[Any]): Any = collectSet.eval(buffer)

override protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)

override def withNewMutableAggBufferOffset(
newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def prettyName: String = "collect_set_limit"
}

object CollectSetLimit {

/** Function DSL */
def collect_set_limit(columnName: String, limit: Int): Column =
new Column(
CollectSetLimit(new Column(columnName).expr, limit)
.toAggregateExpression())
}
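A quick usage sketch of the DSL above; the data and column name are hypothetical, and spark refers to an active SparkSession:

import org.opensearch.flint.spark.function.CollectSetLimit.collect_set_limit
import spark.implicits._

// Three distinct status codes but a limit of 2: at most two values are collected,
// and which two is non-deterministic, as noted in the expression description.
val df = Seq(200, 404, 200, 503).toDF("status")
df.agg(collect_set_limit("status", 2)).show()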
@@ -50,7 +50,8 @@ case class FlintSparkSkippingIndex(
Map[String, AnyRef](
"kind" -> col.kind.toString,
"columnName" -> col.columnName,
"columnType" -> col.columnType).asJava)
"columnType" -> col.columnType,
"properties" -> col.properties.asJava).asJava)
.toArray

val fieldTypes =
@@ -74,7 +75,7 @@ case class FlintSparkSkippingIndex(
// Wrap aggregate function with output column name
val namedAggFuncs =
(outputNames, aggFuncs).zipped.map { case (name, aggFunc) =>
new Column(aggFunc.toAggregateExpression().as(name))
new Column(aggFunc.as(name))
}

df.getOrElse(spark.read.table(tableName))
@@ -155,14 +156,20 @@ object FlintSparkSkippingIndex {
*
* @param colName
* indexed column name
* @param properties
* value set skipping properties
* @return
* index builder
*/
def addValueSet(colName: String): Builder = {
def addValueSet(colName: String, properties: Map[String, String] = Map.empty): Builder = {
require(tableName.nonEmpty, "table name cannot be empty")

val col = findColumn(colName)
addIndexedColumn(ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType))
addIndexedColumn(
ValueSetSkippingStrategy(
columnName = col.name,
columnType = col.dataType,
properties = properties))
this
}
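A short sketch of the updated builder call, assuming the Builder type nested under FlintSparkSkippingIndex and hypothetical column names:

import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

// Hedged sketch: `builder` is assumed to come from the existing builder entry point.
def addValueSetColumns(
    builder: FlintSparkSkippingIndex.Builder): FlintSparkSkippingIndex.Builder =
  builder
    .addValueSet("elb_status_code")                 // default limit (100)
    .addValueSet("client_ip", Map("limit" -> "50")) // explicit limit of 50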

@@ -32,6 +32,11 @@ trait FlintSparkSkippingStrategy {
*/
val columnType: String

/**
* Skipping algorithm properties.
*/
val properties: Map[String, String] = Map.empty

/**
* @return
* output schema mapping from Flint field name to Flint field type
@@ -42,7 +47,7 @@
* @return
* aggregators that generate skipping data structure
*/
def getAggregators: Seq[AggregateFunction]
def getAggregators: Seq[Expression]

/**
* Rewrite a filtering condition on source table into a new predicate on index data based on
@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}

import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Max, Min}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.functions.col

@@ -29,8 +29,10 @@ case class MinMaxSkippingStrategy(
override def outputSchema(): Map[String, String] =
Map(minColName -> columnType, maxColName -> columnType)

override def getAggregators: Seq[AggregateFunction] =
Seq(Min(col(columnName).expr), Max(col(columnName).expr))
override def getAggregators: Seq[Expression] =
Seq(
Min(col(columnName).expr).toAggregateExpression(),
Max(col(columnName).expr).toAggregateExpression())

override def rewritePredicate(predicate: Expression): Option[Expression] =
predicate match {
@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, First}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.functions.col

/**
@@ -25,8 +25,8 @@ case class PartitionSkippingStrategy(
Map(columnName -> columnType)
}

override def getAggregators: Seq[AggregateFunction] = {
Seq(First(col(columnName).expr, ignoreNulls = true))
override def getAggregators: Seq[Expression] = {
Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression())
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
@@ -5,27 +5,40 @@

package org.opensearch.flint.spark.skipping.valueset

import org.opensearch.flint.spark.function.CollectSetLimit.collect_set_limit
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, CollectSet}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions._

/**
* Skipping strategy based on unique column value set.
*/
case class ValueSetSkippingStrategy(
override val kind: SkippingKind = VALUE_SET,
override val columnName: String,
override val columnType: String)
override val columnType: String,
override val properties: Map[String, String] = Map.empty)
extends FlintSparkSkippingStrategy {

override def outputSchema(): Map[String, String] =
Map(columnName -> columnType)

override def getAggregators: Seq[AggregateFunction] =
Seq(CollectSet(col(columnName).expr))
override def getAggregators: Seq[Expression] = {
val limit = getValueSetSizeLimit()
val aggregator =
if (limit == 0) {
collect_set(columnName)
} else {
// Using collect_set for now instead of collect_set_limit(columnName, limit + 1)
val collectSetLimit = collect_set(columnName)
when(size(collectSetLimit) > limit, lit(null))
.otherwise(collectSetLimit)
}
Seq(aggregator.expr)
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
/*
@@ -37,4 +50,13 @@ case class ValueSetSkippingStrategy(
Some((col(columnName) === value).expr)
case _ => None
}

private def getValueSetSizeLimit(): Int =
properties.get("limit").map(_.toInt).getOrElse(DEFAULT_VALUE_SET_SIZE_LIMIT)
}

object ValueSetSkippingStrategy {

/** Default limit on the size of the collected value set */
val DEFAULT_VALUE_SET_SIZE_LIMIT = 100
}
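To illustrate how the new properties flow into the strategy (column name and type are hypothetical):

import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

// With no "limit" property, DEFAULT_VALUE_SET_SIZE_LIMIT (100) applies;
// a limit of 0 disables the size check and falls back to plain collect_set.
val defaultLimit = ValueSetSkippingStrategy(
  columnName = "elb_status_code",
  columnType = "int")

// With an explicit limit, the generated aggregator is roughly:
//   CASE WHEN size(collect_set(elb_status_code)) > 50 THEN NULL
//        ELSE collect_set(elb_status_code) END
// i.e. an oversized value set collapses to null rather than being stored.
val explicitLimit = ValueSetSkippingStrategy(
  columnName = "elb_status_code",
  columnType = "int",
  properties = Map("limit" -> "50"))

val aggregators = explicitLimit.getAggregators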
@@ -10,7 +10,7 @@ import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION}
import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._
@@ -42,11 +42,19 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A

ctx.indexColTypeList().indexColType().forEach { colTypeCtx =>
val colName = colTypeCtx.identifier().getText
val skipType = SkippingKind.withName(colTypeCtx.skipType.getText)
skipType match {
case PARTITION => indexBuilder.addPartitions(colName)
case VALUE_SET => indexBuilder.addValueSet(colName)
case MIN_MAX => indexBuilder.addMinMax(colName)
if (colTypeCtx.skipType == null) {
if (colTypeCtx.valueSetType().limit == null) {
indexBuilder.addValueSet(colName)
} else {
indexBuilder
.addValueSet(colName, Map("limit" -> colTypeCtx.valueSetType().limit.getText))
}
} else {
val skipType = SkippingKind.withName(colTypeCtx.skipType.getText)
skipType match {
case PARTITION => indexBuilder.addPartitions(colName)
case MIN_MAX => indexBuilder.addMinMax(colName)
}
}
}

@@ -0,0 +1,49 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.function

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.when
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar.mock

import org.apache.spark.FlintSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

class CollectSetLimitSuite extends FlintSuite with Matchers {

var expression: Expression = _

override def beforeEach(): Unit = {
super.beforeEach()

expression = mock[Expression]
when(expression.eval(any[InternalRow])).thenAnswer { invocation =>
val row = invocation.getArgument[InternalRow](0)
val firstValue = row.getInt(0)
Literal(firstValue)
}
}

test("should collect unique elements") {
val collectSetLimit = CollectSetLimit(expression, limit = 2)
var buffer = collectSetLimit.createAggregationBuffer()

Seq(InternalRow(1), InternalRow(1)).foreach(row =>
buffer = collectSetLimit.update(buffer, row))
assert(buffer.size == 1)
}

test("should collect unique elements up to the limit") {
val collectSetLimit = CollectSetLimit(expression, limit = 2)
var buffer = collectSetLimit.createAggregationBuffer()

Seq(InternalRow(1), InternalRow(2), InternalRow(3)).foreach(row =>
buffer = collectSetLimit.update(buffer, row))
assert(buffer.size == 2)
}
}