update correlation command
add test parts

Signed-off-by: YANGDB <[email protected]>
YANG-DB committed Oct 11, 2023
1 parent f998427 commit 4ee2fbf
Showing 14 changed files with 544 additions and 309 deletions.

Large diffs are not rendered by default.

@@ -0,0 +1,156 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.JoinHint.NONE
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLCorrelationITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Test table names */
  private val testTable1 = "spark_catalog.default.flint_ppl_test1"
  private val testTable2 = "spark_catalog.default.flint_ppl_test2"

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Create test tables
    sql(s"""
         | CREATE TABLE $testTable1
         | (
         |   name STRING,
         |   age INT,
         |   state STRING,
         |   country STRING
         | )
         | USING CSV
         | OPTIONS (
         |   header 'false',
         |   delimiter '\t'
         | )
         | PARTITIONED BY (
         |   year INT,
         |   month INT
         | )
         |""".stripMargin)

    sql(s"""
         | CREATE TABLE $testTable2
         | (
         |   name STRING,
         |   occupation STRING,
         |   salary INT
         | )
         | USING CSV
         | OPTIONS (
         |   header 'false',
         |   delimiter '\t'
         | )
         | PARTITIONED BY (
         |   year INT,
         |   month INT
         | )
         |""".stripMargin)

    // Insert data into the first table
    sql(s"""
         | INSERT INTO $testTable1
         | PARTITION (year=2023, month=4)
         | VALUES ('Jake', 70, 'California', 'USA'),
         |        ('Hello', 30, 'New York', 'USA'),
         |        ('John', 25, 'Ontario', 'Canada'),
         |        ('Jane', 20, 'Quebec', 'Canada')
         | """.stripMargin)
    // Insert data into the second table
    sql(s"""
         | INSERT INTO $testTable2
         | PARTITION (year=2023, month=4)
         | VALUES ('Jake', 'Engineer', 100000),
         |        ('Hello', 'Artist', 70000),
         |        ('John', 'Doctor', 120000),
         |        ('Jane', 'Scientist', 90000)
         | """.stripMargin)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Stop all streaming jobs if any
    spark.streams.active.foreach { job =>
      job.stop()
      job.awaitTermination()
    }
  }

test("create ppl correlation query with two tables correlating on a single field test") {
val joinQuery =
s"""
| SELECT a.name, a.age, a.state, a.country, b.occupation, b.salary
| FROM $testTable1 AS a
| JOIN $testTable2 AS b
| ON a.name = b.name
| WHERE a.year = 2023 AND a.month = 4 AND b.year = 2023 AND b.month = 4
|""".stripMargin

val result = spark.sql(joinQuery)
result.show()

val frame = sql(s"""
| source = $testTable1, $testTable2| where year = 2023 AND month = 4 | correlate exact fields(name) scope(month, 1W) mapping($testTable1.name = $testTable2.name)
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4, "Jake", "Engineer", 100000, 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", 70000, 2023, 4),
Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", 120000, 2023, 4),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", 90000, 2023, 4))

implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
// Compare the results
assert(results.sorted.sameElements(expectedResults.sorted))

// Define unresolved relations
val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))

// Define filter expressions
val filter1Expr = And(
EqualTo(UnresolvedAttribute("year"), Literal(2023)),
EqualTo(UnresolvedAttribute("month"), Literal(4)))
val filter2Expr = And(
EqualTo(UnresolvedAttribute("year"), Literal(2023)),
EqualTo(UnresolvedAttribute("month"), Literal(4)))
// Define subquery aliases
val plan1 = Filter(filter1Expr, table1)
val plan2 = Filter(filter2Expr, table2)

// Define join condition
val joinCondition =
EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name"))

// Create Join plan
val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE)

// Add the projection
val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

}
@@ -158,7 +158,6 @@ class FlintSparkPPLFiltersITSuite
     // Define the expected results
     val expectedResults: Array[Row] = Array(Row("Jane", 20), Row("Jake", 70), Row("Hello", 30))
     // Compare the results
-    // Compare the results
     implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
     assert(results.sorted.sameElements(expectedResults.sorted))
@@ -38,6 +38,7 @@ ML: 'ML';

 //CORRELATION KEYWORDS
 CORRELATE: 'CORRELATE';
+SELF: 'SELF';
 EXACT: 'EXACT';
 APPROXIMATE: 'APPROXIMATE';
 SCOPE: 'SCOPE';
ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 (21 changes: 11 additions & 10 deletions)
@@ -62,32 +62,33 @@ describeCommand
   ;

 showDataSourcesCommand
-  : SHOW DATASOURCES
-  ;
+   : SHOW DATASOURCES
+   ;

 whereCommand
-  : WHERE logicalExpression
-  ;
+   : WHERE logicalExpression
+   ;

 correlateCommand
-  : CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS scopeClause (mappingList)?
-  ;
+   : CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS scopeClause mappingList
+   ;

 correlationType
-  : EXACT
+   : SELF
+   | EXACT
   | APPROXIMATE
   ;

 scopeClause
-  : SCOPE LT_PRTHS fieldExpression COMMA value = literalValue (unit = timespanUnit)? RT_PRTHS
-  ;
+   : SCOPE LT_PRTHS fieldExpression COMMA value = literalValue (unit = timespanUnit)? RT_PRTHS
+   ;

 mappingList
   : MAPPING LT_PRTHS ( mappingClause (COMMA mappingClause)* ) RT_PRTHS
   ;

 mappingClause
-  : qualifiedName EQUAL qualifiedName
+   : left = qualifiedName comparisonOperator right = qualifiedName # mappingCompareExpr
   ;

 fieldsCommand
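Net effect of the parser changes: a new SELF correlation type, a mandatory mapping list, and mapping clauses that accept any comparison operator rather than only EQUAL. Queries along these lines should now parse; table names, field names, and scope values below are illustrative only, not taken from the commit:

source = table1, table2 | correlate exact fields(name) scope(month, 1W) mapping(table1.name = table2.name)
source = table1, table2 | correlate approximate fields(name) scope(day, 1) mapping(table1.name = table2.name)
source = table1 | correlate self fields(name) scope(month, 1W) mapping(table1.name = table1.name)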
@@ -16,6 +16,7 @@
 import org.opensearch.sql.ast.expression.Compare;
 import org.opensearch.sql.ast.expression.EqualTo;
 import org.opensearch.sql.ast.expression.Field;
+import org.opensearch.sql.ast.expression.FieldsMapping;
 import org.opensearch.sql.ast.expression.Function;
 import org.opensearch.sql.ast.expression.In;
 import org.opensearch.sql.ast.expression.Interval;
@@ -99,6 +100,10 @@ public T visitCorrelation(Correlation node, C context) {
     return visitChildren(node, context);
   }

+  public T visitCorrelationMapping(FieldsMapping node, C context) {
+    return visitChildren(node, context);
+  }
+
   public T visitProject(Project node, C context) {
     return visitChildren(node, context);
   }
@@ -6,15 +6,17 @@

 public class FieldsMapping extends UnresolvedExpression {

     private final List<UnresolvedExpression> fieldsMappingList;

     public <R> FieldsMapping(List<UnresolvedExpression> fieldsMappingList) {
         this.fieldsMappingList = fieldsMappingList;
     }
+
+    public List<UnresolvedExpression> getChild() {
+        return fieldsMappingList;
+    }

     @Override
     public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
-        return nodeVisitor.visit(this, context);
+        return nodeVisitor.visitCorrelationMapping(this, context);
     }
 }
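The accept override now dispatches to the dedicated visitCorrelationMapping callback instead of the generic visit. A minimal sketch of a concrete visitor using it follows; the class name and string rendering are hypothetical, and only visitCorrelationMapping and getChild come from this change (import paths assumed to match the surrounding diff):

import java.util.stream.Collectors;

import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.FieldsMapping;

// Hypothetical visitor that renders a mapping list back as PPL-like text.
public class MappingPrinter extends AbstractNodeVisitor<String, Void> {
    @Override
    public String visitCorrelationMapping(FieldsMapping node, Void context) {
        // getChild() exposes the individual "left = right" mapping expressions.
        return node.getChild().stream()
                .map(Object::toString)
                .collect(Collectors.joining(", ", "mapping(", ")"));
    }
}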
@@ -58,6 +58,7 @@ public FieldsMapping getMappingListContext() {
 }

 public enum CorrelationType {
+    self,
     exact,
     approximate
 }
@@ -6,7 +6,6 @@
 package org.opensearch.sql.ppl;

 import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.catalyst.plans.logical.Union;
 import scala.collection.Seq;
@@ -37,7 +36,11 @@ public class CatalystPlanContext {
      * Grouping NamedExpression contextual parameters
      **/
     private final Stack<org.apache.spark.sql.catalyst.expressions.Expression> groupingParseExpressions = new Stack<>();

+    public Stack<LogicalPlan> getPlanBranches() {
+        return planBranches;
+    }
+
     public LogicalPlan getPlan() {
         if (this.planBranches.size() == 1) {
             return planBranches.peek();
@@ -58,9 +61,10 @@ public Stack<Expression> getGroupingParseExpressions() {
     /**
      * append context with evolving plan
      *
      * @param plan
+     * @return
      */
-    public void with(LogicalPlan plan) {
-        this.planBranches.push(plan);
+    public LogicalPlan with(LogicalPlan plan) {
+        return this.planBranches.push(plan);
     }

     public LogicalPlan plan(Function<LogicalPlan, LogicalPlan> transformFunction) {
@@ -69,12 +73,22 @@ public LogicalPlan plan(Function<LogicalPlan, LogicalPlan> transformFunction) {
     }

     /**
+     * retain all logical plans branches
+     * @return
+     */
+    public <T> Seq<T> retainAllPlans(Function<LogicalPlan, T> transformFunction) {
+        Seq<T> plans = seq(getPlanBranches().stream().map(transformFunction).collect(Collectors.toList()));
+        getPlanBranches().retainAll(emptyList());
+        return plans;
+    }
+    /**
+     *
      * retain all expressions and clear expression stack
      * @return
      */
     public <T> Seq<T> retainAllNamedParseExpressions(Function<Expression, T> transformFunction) {
         Seq<T> aggregateExpressions = seq(getNamedParseExpressions().stream()
-                .map(transformFunction::apply).collect(Collectors.toList()));
+                .map(transformFunction).collect(Collectors.toList()));
         getNamedParseExpressions().retainAll(emptyList());
         return aggregateExpressions;
     }
@@ -85,7 +99,7 @@ public <T> Seq<T> retainAllNamedParseExpressions(Function<Expression, T> transformFunction) {
      */
     public <T> Seq<T> retainAllGroupingNamedParseExpressions(Function<Expression, T> transformFunction) {
         Seq<T> aggregateExpressions = seq(getGroupingParseExpressions().stream()
-                .map(transformFunction::apply).collect(Collectors.toList()));
+                .map(transformFunction).collect(Collectors.toList()));
         getGroupingParseExpressions().retainAll(emptyList());
         return aggregateExpressions;
     }
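with now returns the pushed plan, and retainAllPlans drains every branch in one call — the shape a multi-source correlate needs. A rough sketch of the intended call pattern, assuming pre-built scan plans; the helper class and its arguments are hypothetical, while with, getPlanBranches, and retainAllPlans come from this change:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.opensearch.sql.ppl.CatalystPlanContext;
import scala.collection.Seq;

class CorrelationContextSketch {
    // Hypothetical caller: one branch is pushed per correlated source,
    // then the whole stack is drained so a Join can be assembled from it.
    static Seq<LogicalPlan> drainBranches(LogicalPlan scan1, LogicalPlan scan2) {
        CatalystPlanContext context = new CatalystPlanContext();
        context.with(scan1); // first correlated source
        context.with(scan2); // second correlated source
        // Maps each branch (identity here) and clears the stack in one pass.
        return context.retainAllPlans(p -> p);
    }
}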
