Skip to content

Commit

Permalink
add basic span support for aggregate based queries
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Sep 13, 2023
1 parent fe11134 commit 7e5e0d1
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.flint.spark

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, EqualTo, GreaterThan, LessThan, LessThanOrEqual, Literal, Not, Or}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.{QueryTest, Row}
Expand Down Expand Up @@ -631,4 +631,140 @@ class FlintSparkPPLITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* +--------+-------+-----------+
* |age_span| count_age|
* +--------+-------+-----------+
* | 20| 2 |
* | 30| 1 |
* | 70| 1 |
* +--------+-------+-----------+
*/
/**
 * Expected result:
 * +--------+---------+
 * |age_span|count_age|
 * +--------+---------+
 * |      20|        2|
 * |      30|        1|
 * |      70|        1|
 * +--------+---------+
 */
test("create ppl simple count age by span of interval of 10 years query test ") {
  val frame = sql(
    s"""
       | source = $testTable| stats count(age) by span(age, 10) as age_span
       | """.stripMargin)

  // Materialize the query results.
  val results: Array[Row] = frame.collect()
  // Expected rows are (count, span-bucket); buckets are ages floored to the nearest 10.
  val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L))

  // Compare order-insensitively by sorting both sides on the span column (index 1).
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
  assert(results.sorted.sameElements(expectedResults.sorted))

  // Verify the unresolved logical plan produced by the PPL-to-Catalyst translation.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  val star = Seq(UnresolvedStar(None))
  val ageField = UnresolvedAttribute("age")
  val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

  val aggregateExpressions =
    Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")()
  // span(age, 10) lowers to FLOOR(age / 10) * 10.
  val span =
    Alias(Multiply(Floor(Divide(ageField, Literal(10))), Literal(10)), "span (age,10,NONE)")()
  val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table)
  val expectedPlan = Project(star, aggregatePlan)

  // Compare the two plans by their string rendering.
  assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* +--------+-------+-----------+
* |age_span| average_age|
* +--------+-------+-----------+
* | 20| 22.5 |
* | 30| 30 |
* | 70| 70 |
* +--------+-------+-----------+
*/
/**
 * Expected result:
 * +--------+-----------+
 * |age_span|average_age|
 * +--------+-----------+
 * |      20|       22.5|
 * |      30|       30.0|
 * |      70|       70.0|
 * +--------+-----------+
 */
test("create ppl simple avg age by span of interval of 10 years query test ") {
  val frame = sql(
    s"""
       | source = $testTable| stats avg(age) by span(age, 10) as age_span
       | """.stripMargin)

  // Materialize the query results.
  val results: Array[Row] = frame.collect()
  // Expected rows are (average, span-bucket); buckets are ages floored to the nearest 10.
  val expectedResults: Array[Row] = Array(Row(70D, 70L), Row(30D, 30L), Row(22.5D, 20L))

  // Compare order-insensitively by sorting both sides on the average column (index 0).
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
  assert(results.sorted.sameElements(expectedResults.sorted))

  // Verify the unresolved logical plan produced by the PPL-to-Catalyst translation.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  val star = Seq(UnresolvedStar(None))
  val ageField = UnresolvedAttribute("age")
  val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

  val aggregateExpressions =
    Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
  // span(age, 10) lowers to FLOOR(age / 10) * 10.
  val span =
    Alias(Multiply(Floor(Divide(ageField, Literal(10))), Literal(10)), "span (age,10,NONE)")()
  val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table)
  val expectedPlan = Project(star, aggregatePlan)

  // Compare the two plans by their string rendering.
  assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* +--------+-------+-----------+
* |age_span|country|average_age|
* +--------+-------+-----------+
* | 20| Canada| 22.5|
* | 30| USA| 30|
* | 70| USA| 70|
* +--------+-------+-----------+
*/
/**
 * Expected result once span + additional group-by is supported:
 * +--------+-------+-----------+
 * |age_span|country|average_age|
 * +--------+-------+-----------+
 * |      20| Canada|       22.5|
 * |      30|    USA|       30.0|
 * |      70|    USA|       70.0|
 * +--------+-------+-----------+
 */
ignore("create ppl average age by span of interval of 10 years group by country query test ") {
  val frame = sql(
    s"""
       | source = $testTable | stats avg(age) by span(age, 10) as age_span, country
       | """.stripMargin)

  // Materialize the query results.
  val results: Array[Row] = frame.collect()
  // Expected rows are (average, country, span-bucket), matching the table above.
  val expectedResults: Array[Row] =
    Array(Row(22.5D, "Canada", 20L), Row(30D, "USA", 30L), Row(70D, "USA", 70L))

  // Compare order-insensitively by sorting both sides on the span column (index 2).
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](2))
  assert(results.sorted.sameElements(expectedResults.sorted))

  // Verify the unresolved logical plan produced by the PPL-to-Catalyst translation.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  val star = Seq(UnresolvedStar(None))
  val countryField = UnresolvedAttribute("country")
  val ageField = UnresolvedAttribute("age")
  val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))

  // NOTE(review): query aggregates avg(age) grouped by BOTH the span bucket and country,
  // so both grouping expressions must appear in the expected Aggregate node.
  val countryAlias = Alias(countryField, "country")()
  val aggregateExpressions =
    Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
  // span(age, 10) lowers to FLOOR(age / 10) * 10.
  val span =
    Alias(Multiply(Floor(Divide(ageField, Literal(10))), Literal(10)), "span (age,10,NONE)")()
  val aggregatePlan =
    Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table)
  val expectedPlan = Project(star, aggregatePlan)

  // Compare the two plans by their string rendering.
  assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ public Span(UnresolvedExpression field, UnresolvedExpression value, SpanUnit uni
this.unit = unit;
}

/** @return the expression (typically a field reference) the span bucket is computed over */
public UnresolvedExpression getField() {
return field;
}

/** @return the span interval expression (e.g. the literal 10 in span(age, 10)) */
public UnresolvedExpression getValue() {
return value;
}

/** @return the unit of the span interval (NONE for plain numeric spans) */
public SpanUnit getUnit() {
return unit;
}

@Override
public List<UnresolvedExpression> getChild() {
return ImmutableList.of(field, value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.Divide;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Floor;
import org.apache.spark.sql.catalyst.expressions.Multiply;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
Expand All @@ -34,6 +37,7 @@
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.statement.Explain;
import org.opensearch.sql.ast.statement.Query;
Expand All @@ -55,6 +59,7 @@
import scala.collection.Seq;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static com.google.common.base.Strings.isNullOrEmpty;
Expand Down Expand Up @@ -147,25 +152,33 @@ public String visitAggregation(Aggregation node, CatalystPlanContext context) {
final String visitExpressionList = visitExpressionList(node.getAggExprList(), context);
final String group = visitExpressionList(node.getGroupExprList(), context);


if(!isNullOrEmpty(group)) {
NamedExpression namedExpression = (NamedExpression) context.getNamedParseExpressions().peek();
Seq<NamedExpression> namedExpressionSeq = asScalaBuffer(context.getNamedParseExpressions().stream()
.map(v->(NamedExpression)v).collect(Collectors.toList())).toSeq();
//now remove all context.getNamedParseExpressions()
context.getNamedParseExpressions().retainAll(emptyList());
context.plan(p->new Aggregate(asScalaBuffer(singletonList((Expression) namedExpression)),namedExpressionSeq,p));
if (!isNullOrEmpty(group)) {
extractedAggregation(context);
}
UnresolvedExpression span = node.getSpan();
if (!Objects.isNull(span)) {
span.accept(this, context);
extractedAggregation(context);
}
return format(
"%s | stats %s",
child, String.join(" ", visitExpressionList, groupBy(group)).trim());
}

@Override
public String visitSpan(Span node, CatalystPlanContext context) {
return super.visitSpan(node, context);
private static void extractedAggregation(CatalystPlanContext context) {
NamedExpression namedExpression = (NamedExpression) context.getNamedParseExpressions().peek();
Seq<NamedExpression> namedExpressionSeq = asScalaBuffer(context.getNamedParseExpressions().stream()
.map(v -> (NamedExpression) v).collect(Collectors.toList())).toSeq();
//now remove all context.getNamedParseExpressions()
context.getNamedParseExpressions().retainAll(emptyList());
context.plan(p -> new Aggregate(asScalaBuffer(singletonList((Expression) namedExpression)), namedExpressionSeq, p));
}

/**
 * Delegates alias handling to the expression analyzer, which is responsible for
 * wrapping the aliased expression in a Catalyst {@code Alias}.
 */
@Override
public String visitAlias(Alias node, CatalystPlanContext context) {
return expressionAnalyzer.visitAlias(node, context);
}

@Override
public String visitRareTopN(RareTopN node, CatalystPlanContext context) {
final String child = node.getChild().get(0).accept(this, context);
Expand All @@ -190,7 +203,7 @@ public String visitProject(Project node, CatalystPlanContext context) {

// Create a projection list from the existing expressions
Seq<?> projectList = asScalaBuffer(context.getNamedParseExpressions()).toSeq();
if(!projectList.isEmpty()) {
if (!projectList.isEmpty()) {
// build the plan with the projection step
context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project((Seq<NamedExpression>) projectList, p));
}
Expand Down Expand Up @@ -296,7 +309,7 @@ public String visitAnd(And node, CatalystPlanContext context) {
String left = node.getLeft().accept(this, context);
String right = node.getRight().accept(this, context);
context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.And(
(Expression) context.getNamedParseExpressions().pop(),context.getNamedParseExpressions().pop()));
(Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
return format("%s and %s", left, right);
}

Expand All @@ -305,7 +318,7 @@ public String visitOr(Or node, CatalystPlanContext context) {
String left = node.getLeft().accept(this, context);
String right = node.getRight().accept(this, context);
context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Or(
(Expression) context.getNamedParseExpressions().pop(),context.getNamedParseExpressions().pop()));
(Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
return format("%s or %s", left, right);
}

Expand All @@ -314,7 +327,7 @@ public String visitXor(Xor node, CatalystPlanContext context) {
String left = node.getLeft().accept(this, context);
String right = node.getRight().accept(this, context);
context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(
(Expression) context.getNamedParseExpressions().pop(),context.getNamedParseExpressions().pop()));
(Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
return format("%s xor %s", left, right);
}

Expand All @@ -328,7 +341,14 @@ public String visitNot(Not node, CatalystPlanContext context) {

/**
 * Lowers a PPL {@code span(field, value, unit)} expression into the Catalyst bucketing
 * expression {@code FLOOR(field / value) * value} and pushes it onto the context's
 * expression stack. Returns the textual PPL rendering of the span.
 */
@Override
public String visitSpan(Span node, CatalystPlanContext context) {
    String field = node.getField().accept(this, context);
    String value = node.getValue().accept(this, context);
    String unit = node.getUnit().name();

    // Operands were pushed in visit order (field first, then value), so pop value, then field.
    Expression valueExpression = context.getNamedParseExpressions().pop();
    Expression fieldExpression = context.getNamedParseExpressions().pop();
    // Bucket the field: FLOOR(field / value) * value.
    context.getNamedParseExpressions().push(new Multiply(new Floor(new Divide(fieldExpression, valueExpression)), valueExpression));
    return format("span (%s,%s,%s)", field, value, unit);
}

@Override
Expand Down Expand Up @@ -366,7 +386,7 @@ public String visitField(Field node, CatalystPlanContext context) {
@Override
public String visitAllFields(AllFields node, CatalystPlanContext context) {
// Case of aggregation step - no start projection can be added
if(!context.getNamedParseExpressions().isEmpty()) {
if (!context.getNamedParseExpressions().isEmpty()) {
// if named expression exist - just return their names
return context.getNamedParseExpressions().peek().toString();
} else {
Expand All @@ -376,6 +396,11 @@ public String visitAllFields(AllFields node, CatalystPlanContext context) {
}
}

// NOTE(review): this override only delegates to the superclass and adds no behavior;
// it can be removed unless it is kept as an explicit extension point for window support.
@Override
public String visitWindowFunction(WindowFunction node, CatalystPlanContext context) {
return super.visitWindowFunction(node, context);
}

@Override
public String visitAlias(Alias node, CatalystPlanContext context) {
String expr = node.getDelegated().accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

class PPLLogicalPlanSimpleTranslatorTestSuite
class PPLLogicalAdvancedTranslatorTestSuite
extends SparkFunSuite
with Matchers {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, EqualTo, Floor, Literal, Multiply}
import org.apache.spark.sql.catalyst.plans.logical._
import org.junit.Assert.assertEquals
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
Expand Down Expand Up @@ -104,5 +104,39 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
}

test("create ppl simple avg age by span of interval of 10 years query test ") {
  val context = new CatalystPlanContext
  val logPlan =
    planTrnasformer.visit(
      plan(pplParser, "source = table | stats avg(age) by span(age, 10) as age_span", false),
      context)

  // Build the plan we expect the translator to emit for an avg-by-span query.
  val projectList = Seq(UnresolvedStar(None))
  val age = UnresolvedAttribute("age")
  val source = UnresolvedRelation(Seq("table"))

  val avgAlias = Alias(UnresolvedFunction(Seq("AVG"), Seq(age), isDistinct = false), "avg(age)")()
  // span(age, 10) is lowered to FLOOR(age / 10) * 10.
  val spanAlias = Alias(Multiply(Floor(Divide(age, Literal(10))), Literal(10)), "span (age,10,NONE)")()
  val expectedPlan = Project(projectList, Aggregate(Seq(spanAlias), Seq(avgAlias, spanAlias), source))

  assertEquals(logPlan, "source=[table] | stats avg(age) | fields + *")
  assert(compareByString(expectedPlan) === compareByString(context.getPlan))
}

ignore("create ppl simple avg age by span of interval of 10 years by country query test ") {
  val context = new CatalystPlanContext
  val logPlan =
    planTrnasformer.visit(
      plan(pplParser, "source = table | stats avg(age) by span(age, 10) as age_span, country", false),
      context)

  // Build the plan expected once span combined with an additional group-by is supported.
  val star = Seq(UnresolvedStar(None))
  val ageField = UnresolvedAttribute("age")
  val countryField = UnresolvedAttribute("country")
  val tableRelation = UnresolvedRelation(Seq("table"))

  // NOTE(review): the query groups by BOTH the span bucket and country, so country must
  // appear in the expected Aggregate's grouping and output lists.
  val countryAlias = Alias(countryField, "country")()
  val aggregateExpressions =
    Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
  // span(age, 10) is lowered to FLOOR(age / 10) * 10.
  val span =
    Alias(Multiply(Floor(Divide(ageField, Literal(10))), Literal(10)), "span (age,10,NONE)")()
  val aggregatePlan =
    Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), tableRelation)
  val expectedPlan = Project(star, aggregatePlan)

  assertEquals(logPlan, "source=[table] | stats avg(age) | fields + *")
  assert(compareByString(expectedPlan) === compareByString(context.getPlan))
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

class PPLLogicalPlanComplexQueriesTranslatorTestSuite
class PPLLogicalPlanBasicQueriesTranslatorTestSuite
extends SparkFunSuite
with Matchers {

Expand Down

0 comments on commit 7e5e0d1

Please sign in to comment.