Skip to content

Commit

Permalink
set the correlation scope parameter as optional
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Oct 19, 2023
1 parent c6649ad commit a3df76f
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ class FlintSparkPPLCorrelationITSuite
assert(
thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ")
}

test("create failing ppl correlation query with no scope - due to mismatch fields to mappings test") {
val thrown = intercept[IllegalStateException] {
val frame = sql(s"""
| source = $testTable1, $testTable2| correlate exact fields(name, country) mapping($testTable1.name = $testTable2.name)
| """.stripMargin)
}
assert(
thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ")
}

test(
"create failing ppl correlation query - due to mismatch correlation self type and source amount test") {
Expand Down Expand Up @@ -293,6 +303,60 @@ class FlintSparkPPLCorrelationITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test(
"create ppl correlation approximate query with two tables correlating on a single field and not scope test") {
val frame = sql(s"""
| source = $testTable1, $testTable2| correlate approximate fields(name) mapping($testTable1.name = $testTable2.name)
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4),
Row("David", 40, "Washington", "USA", 2023, 4, "David", "Unemployed", "Canada", 0, 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4),
Row(
"Jake",
70,
"California",
"USA",
2023,
4,
"Jake",
"Engineer",
"England",
100000,
2023,
4),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", "Canada", 90000, 2023, 4),
Row("Jim", 27, "B.C", "Canada", 2023, 4, null, null, null, null, null, null),
Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4),
Row("Peter", 57, "B.C", "Canada", 2023, 4, null, null, null, null, null, null),
Row("Rick", 70, "B.C", "Canada", 2023, 4, null, null, null, null, null, null))

implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
// Compare the results
assert(results.sorted.sameElements(expectedResults.sorted))

// Define unresolved relations
val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))
// Define join condition
val joinCondition =
EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name"))

// Create Join plan
val joinPlan = Join(table1, table2, FullOuter, Some(joinCondition), JoinHint.NONE)

// Add the projection
val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test(
"create ppl correlation query with with filters and two tables correlating on a two fields test") {
Expand Down Expand Up @@ -562,6 +626,64 @@ class FlintSparkPPLCorrelationITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
test(
"create ppl correlation (exact) query with two tables correlating by name,country and group by avg salary by age span (10 years bucket) with country filter without scope test") {
val frame = sql(s"""
| source = $testTable1, $testTable2 | where country = 'USA' OR country = 'England' |
| correlate exact fields(name) mapping($testTable1.name = $testTable2.name) |
| stats avg(salary) by span(age, 10) as age_span, $testTable2.country
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(120000.0, "USA", 40), Row(100000.0, "England", 70), Row(70000.0, "USA", 30))

implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
// Compare the results
assert(results.sorted.sameElements(expectedResults.sorted))

// Define unresolved relations
val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))

// Define filter expressions
val filter1Expr = Or(
EqualTo(UnresolvedAttribute("country"), Literal("USA")),
EqualTo(UnresolvedAttribute("country"), Literal("England")))
val filter2Expr = Or(
EqualTo(UnresolvedAttribute("country"), Literal("USA")),
EqualTo(UnresolvedAttribute("country"), Literal("England")))
// Define subquery aliases
val plan1 = Filter(filter1Expr, table1)
val plan2 = Filter(filter2Expr, table2)

// Define join condition
val joinCondition =
EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name"))

// Create Join plan
val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE)

val salaryField = UnresolvedAttribute("salary")
val countryField = UnresolvedAttribute(s"$testTable2.country")
val countryAlias = Alias(countryField, s"$testTable2.country")()
val star = Seq(UnresolvedStar(None))
val aggregateExpressions =
Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val aggregatePlan =
Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan)
// Add the projection
val expectedPlan = Project(star, aggregatePlan)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test(
"create ppl correlation (approximate) query with two tables correlating by name,country and group by avg salary by age span (10 years bucket) test") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ whereCommand
;

correlateCommand
: CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS scopeClause mappingList
: CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS (scopeClause)? mappingList
;

correlationType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ public class Scope extends Span {
public Scope(UnresolvedExpression field, UnresolvedExpression value, SpanUnit unit) {
super(field, value, unit);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,14 @@ public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext contex
context.reduce((left,right) -> {
visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context);
Seq<Expression> fields = context.retainAllNamedParseExpressions(e -> e);
expressionAnalyzer.visitSpan(node.getScope(), context);
Expression scope = context.popNamedParseExpressions().get();
if(!Objects.isNull(node.getScope())) {
// scope - this is a time base expression that timeframes the join to a specific period : (Time-field-name, value, unit)
expressionAnalyzer.visitSpan(node.getScope(), context);
context.popNamedParseExpressions().get();
}
expressionAnalyzer.visitCorrelationMapping(node.getMappingListContext(), context);
Seq<Expression> mapping = context.retainAllNamedParseExpressions(e -> e);
return join(node.getCorrelationType(), fields, scope, mapping, left, right);
return join(node.getCorrelationType(), fields, mapping, left, right);
});
return context.getPlan();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommand
ctx.fieldList().fieldExpression().stream()
.map(this::internalVisitExpression)
.collect(Collectors.toList()),
new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()),
Objects.isNull(ctx.scopeClause()) ? null : new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()),
expressionBuilder.visit(ctx.scopeClause().value),
SpanUnit.of(ctx.scopeClause().unit.getText())),
SpanUnit.of(Objects.isNull(ctx.scopeClause().unit) ? "" : ctx.scopeClause().unit.getText())),
Objects.isNull(ctx.mappingList()) ? new FieldsMapping(emptyList()) : new FieldsMapping(ctx.mappingList()
.mappingClause().stream()
.map(this::internalVisitExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ public interface JoinSpecTransformer {
/**
* @param correlationType the correlation type which can be exact (inner join) or approximate (outer join)
* @param fields - fields (columns) that needed to be joined by
* @param scope - this is a time base expression that timeframes the join to a specific period : (Time-field-name, value, unit)
* @param mapping - in case fields in different relations have different name, that can be aliased with the following names
* @return
*/
static LogicalPlan join(Correlation.CorrelationType correlationType, Seq<Expression> fields, Expression scope, Seq<Expression> mapping, LogicalPlan left, LogicalPlan right) {
static LogicalPlan join(Correlation.CorrelationType correlationType, Seq<Expression> fields, Seq<Expression> mapping, LogicalPlan left, LogicalPlan right) {
//create a join statement - which will replace all the different plans with a single plan which contains the joined plans
switch (correlationType) {
case self:
Expand Down

0 comments on commit a3df76f

Please sign in to comment.