Restrict the maximum size of collect set output
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Dec 27, 2023
1 parent 9061eb9 commit feb67b3
Showing 6 changed files with 195 additions and 284 deletions.
@@ -74,7 +74,7 @@ case class FlintSparkSkippingIndex(
     // Wrap aggregate function with output column name
     val namedAggFuncs =
       (outputNames, aggFuncs).zipped.map { case (name, aggFunc) =>
-        new Column(aggFunc.toAggregateExpression().as(name))
+        new Column(aggFunc.as(name))
       }

     df.getOrElse(spark.read.table(tableName))
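
Note on this hunk: each skipping strategy now returns ready-to-use Catalyst expressions (already wrapped via toAggregateExpression()), so the index builder only has to attach an output name. Below is a minimal standalone sketch of the same wrapping, using a hypothetical MAX aggregator over an assumed age column; it names the column via the public Column.as rather than the Catalyst DSL used in this file (illustration only, not code from the repository):

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.catalyst.expressions.aggregate.Max
import org.apache.spark.sql.functions.col

object NamedAggExample extends App {
  val spark = SparkSession.builder().master("local[*]").appName("named-agg").getOrCreate()
  import spark.implicits._

  // Strategy side: a Catalyst aggregate function wrapped into an AggregateExpression
  val aggExpr = Max(col("age").expr).toAggregateExpression()

  // Builder side: wrap the expression in a Column and give it an output name
  val maxAge: Column = new Column(aggExpr).as("MAX_age")

  Seq(("a", 1), ("a", 3), ("b", 2)).toDF("key", "age")
    .groupBy("key")
    .agg(maxAge)
    .show()
}
```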

@@ -10,7 +10,6 @@ import org.json4s.JsonAST.JString
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind

 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

 /**
  * Skipping index strategy that defines skipping data structure building and reading logic.
@@ -42,7 +41,7 @@ trait FlintSparkSkippingStrategy {
    * @return
    *   aggregators that generate skipping data structure
    */
-  def getAggregators: Seq[AggregateFunction]
+  def getAggregators: Seq[Expression]

   /**
    * Rewrite a filtering condition on source table into a new predicate on index data based on

@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}

 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Max, Min}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.functions.col

@@ -29,8 +29,11 @@ case class MinMaxSkippingStrategy(
   override def outputSchema(): Map[String, String] =
     Map(minColName -> columnType, maxColName -> columnType)

-  override def getAggregators: Seq[AggregateFunction] =
-    Seq(Min(col(columnName).expr), Max(col(columnName).expr))
+  override def getAggregators: Seq[Expression] = {
+    Seq(
+      Min(col(columnName).expr).toAggregateExpression(),
+      Max(col(columnName).expr).toAggregateExpression())
+  }

   override def rewritePredicate(predicate: Expression): Option[Expression] =
     predicate match {

@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}

 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, First}
+import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.functions.col

 /**
@@ -25,8 +25,8 @@ case class PartitionSkippingStrategy(
     Map(columnName -> columnType)
   }

-  override def getAggregators: Seq[AggregateFunction] = {
-    Seq(First(col(columnName).expr, ignoreNulls = true))
+  override def getAggregators: Seq[Expression] = {
+    Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression())
   }

   override def rewritePredicate(predicate: Expression): Option[Expression] =

@@ -7,10 +7,10 @@ package org.opensearch.flint.spark.skipping.valueset

 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
+import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT

 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, CollectSet}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._

 /**
  * Skipping strategy based on unique column value set.
@@ -24,8 +24,16 @@ case class ValueSetSkippingStrategy(
   override def outputSchema(): Map[String, String] =
     Map(columnName -> columnType)

-  override def getAggregators: Seq[AggregateFunction] =
-    Seq(CollectSet(col(columnName).expr))
+  override def getAggregators: Seq[Expression] = {
+    val limit = DEFAULT_VALUE_SET_SIZE_LIMIT
+    val collectSetLimit = collect_set(columnName)
+
+    // IF(ARRAY_SIZE(COLLECT_SET(col)) > default_limit, null, COLLECT_SET(col))
+    val aggregator =
+      when(size(collectSetLimit) > limit, lit(null))
+        .otherwise(collectSetLimit)
+    Seq(aggregator.expr)
+  }

   override def rewritePredicate(predicate: Expression): Option[Expression] =
     /*
@@ -38,3 +46,9 @@ case class ValueSetSkippingStrategy(
       case _ => None
     }
 }
+
+object ValueSetSkippingStrategy {
+
+  /** Default limit for value set size collected */
+  val DEFAULT_VALUE_SET_SIZE_LIMIT = 100
+}
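
For reference, the capped collect_set aggregator added above behaves like the following standalone sketch (hypothetical table and column names; the limit is shrunk to 2 for demonstration, whereas the commit's DEFAULT_VALUE_SET_SIZE_LIMIT is 100). Once a group's distinct-value count exceeds the limit, the stored value set becomes null instead of growing without bound:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object ValueSetLimitExample extends App {
  val spark = SparkSession.builder().master("local[*]").appName("value-set-limit").getOrCreate()
  import spark.implicits._

  val limit = 2 // demonstration value; the commit defaults to 100

  val df = Seq(
    ("file-1", "active"), ("file-1", "inactive"),
    ("file-2", "active"), ("file-2", "inactive"), ("file-2", "deleted"))
    .toDF("file_path", "status")

  val valueSet = collect_set(col("status"))
  val capped = when(size(valueSet) > limit, lit(null)).otherwise(valueSet).as("status")

  // file-1 keeps its value set; file-2 exceeds the limit and is stored as null
  df.groupBy("file_path").agg(capped).show(truncate = false)
}
```

Returning null for oversized sets keeps index documents bounded in size; presumably a null value set signals that the column's values for that file are unknown, so the file cannot be skipped on this column.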
(Diff for the sixth changed file is not shown.)
