diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala new file mode 100644 index 000000000..2815f6031 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.{QueryTest, Row} + +class FlintSparkPPLLookupITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + private val lookupTable = "spark_catalog.default.flint_ppl_test_lookup" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + createOccupationTable(lookupTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl simple query test") { + val frame = sql(s""" + | source = $testTable | where age > 20 | lookup flint_ppl_test_lookup name + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + + assert(results.length == 3) + + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4, "Jake", "Engineer", "England", 100000, 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val expectedPlan: LogicalPlan = + Project( + Seq(UnresolvedStar(None)), + Join( + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")), + JoinType.apply("left"), + Option.empty, + JoinHint.NONE + ) + //UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + ) + // Compare the two plans + assert(expectedPlan === logicalPlan) + } +} + + diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 2c9f9c1ec..6d4f014d1 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -6,7 +6,6 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; -import org.apache.spark.sql.catalyst.analysis.UnresolvedFieldName; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.EqualTo; @@ -23,6 +22,7 @@ import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; @@ -63,6 +63,7 @@ import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.SortUtils; +import org.sparkproject.guava.collect.Iterables; import scala.Option; import scala.collection.Seq; @@ -266,6 +267,14 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { @Override public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { + Node root = node.getChild().get(0); + + while(!root.getChild().isEmpty()) { + root = root.getChild().get(0); + } + + org.opensearch.sql.ast.tree.Relation source = (org.opensearch.sql.ast.tree.Relation) root; + node.getChild().get(0).accept(this, context); //TODO: not sure how to implement appendonly @@ -275,7 +284,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { //TODO: use node.getCopyFieldList() to prefilter the right logical plan //and return only the fields listed there. rename fields when requested - Expression joinExpression = visitFieldMap(node.getMatchFieldList()); + Expression joinCondition = visitFieldMap(node.getMatchFieldList(), source.getTableQualifiedName().toString(), node.getIndexName(), context); return context.apply(p -> new Join( @@ -285,37 +294,37 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { JoinType.apply("left"), //https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-join.html - Option.apply(joinExpression), //which fields to join + Option.apply(joinCondition), //which fields to join JoinHint.NONE() //TODO: check, https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-hints.html#join-hints-types )); } - private Expression visitFieldMap(List fieldMap) { + private org.opensearch.sql.ast.expression.Field prefixField(List prefixParts, UnresolvedExpression field) { + org.opensearch.sql.ast.expression.Field in = (org.opensearch.sql.ast.expression.Field) field; + org.opensearch.sql.ast.expression.QualifiedName inq = (org.opensearch.sql.ast.expression.QualifiedName) in.getField(); + Iterable finalParts = Iterables.concat(prefixParts, inq.getParts()); + return new org.opensearch.sql.ast.expression.Field(new org.opensearch.sql.ast.expression.QualifiedName(finalParts), in.getFieldArgs()); + } + + private Expression visitFieldMap(List fieldMap, String sourceTableName, String lookupTableName, CatalystPlanContext context) { int size = fieldMap.size(); List allEqlExpressions = new ArrayList<>(size); for (Map map : fieldMap) { - Expression eql = new EqualTo(new UnresolvedFieldName(seq(of(((Field) map.getTarget()).getField().toString()))), - new UnresolvedFieldName(seq(of(((Field) map.getOrigin()).getField().toString())))); + + Expression origin = visitExpression(prefixField(of(sourceTableName.split("\\.")),map.getOrigin()), context); + Expression target = visitExpression(prefixField(of(lookupTableName.split("\\.")),map.getTarget()), context); + + //important + context.retainAllNamedParseExpressions(e -> e); + + Expression eql = new EqualTo(origin, target); allEqlExpressions.add(eql); } - if(size == 1) { - return allEqlExpressions.get(0); - } else if(size == 2) { - return new org.apache.spark.sql.catalyst.expressions.And(allEqlExpressions.get(0),allEqlExpressions.get(1)); - } else { - //2 and(1,2) -> 1 * and - //3 -> and(1, and(2,3)) -> 2 * and - //4 -> and(and(1,2), and(3,4)) -> 3 * and - //5 -> and(and(1, and(2,3)),and(4,5)) -> 4* and - //6 -> and(and(and(1,2), and(3,4)), and(5,6)) -> 5* and - - //TODO: implement - throw new RuntimeException("not implemented"); - } + return allEqlExpressions.stream().reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); } /** diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLookupTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLookupTranslatorTestSuite.scala index e0ab05432..5959d2193 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLookupTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanLookupTranslatorTestSuite.scala @@ -26,7 +26,7 @@ class PPLLogicalPlanLookupTranslatorTestSuite // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source = table | lookup a b,c as d appendonly=true q,w as z ", false), context) + planTransformer.visit(plan(pplParser, "source = table | lookup a b,c as d, e as f,g as b, j appendonly=true q,w as z ", false), context) val star = Seq(UnresolvedStar(None)) val priceField = UnresolvedAttribute("price")