Make it basically work

salyh committed Jul 2, 2024
1 parent 0935fcc commit 0a81105

Showing 3 changed files with 113 additions and 21 deletions.

@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.{QueryTest, Row}

class FlintSparkPPLLookupITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Test table and index name */
  private val testTable = "spark_catalog.default.flint_ppl_test"
  private val lookupTable = "spark_catalog.default.flint_ppl_test_lookup"

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Create test table
    createPartitionedStateCountryTable(testTable)
    createOccupationTable(lookupTable)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Stop all streaming jobs if any
    spark.streams.active.foreach { job =>
      job.stop()
      job.awaitTermination()
    }
  }

test("create ppl simple query test") {
val frame = sql(s"""
| source = $testTable | where age > 20 | lookup flint_ppl_test_lookup name
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()

assert(results.length == 3)

// Define the expected results
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4, "Jake", "Engineer", "England", 100000, 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4),
Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4))
    // Compare the results
    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    assert(results.sorted.sameElements(expectedResults.sorted))

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Define the expected logical plan
    val expectedPlan: LogicalPlan =
      Project(
        Seq(UnresolvedStar(None)),
        Join(
          UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")),
          UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")),
          JoinType.apply("left"),
          Option.empty,
          JoinHint.NONE
        )
        // UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
      )
    // Compare the two plans
    assert(expectedPlan === logicalPlan)
  }
}
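
Note that the expected plan above reuses the flint_ppl_test relation on both sides of the Join and passes Option.empty for the condition, while the visitor changes further down join against the lookup table named in the query and attach a real join condition. A hedged sketch of what a fully spelled-out expectation might look like; the unqualified lookup relation and the attribute prefixes are assumptions, not taken from this commit:

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.EqualTo
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, Project}

object ExpectedLookupPlanSketch {
  // Hypothetical fully specified expectation: source left-joined to the lookup
  // relation on name = name, with both attributes qualified by their tables.
  val sketch: Project = Project(
    Seq(UnresolvedStar(None)),
    Join(
      UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")),
      UnresolvedRelation(Seq("flint_ppl_test_lookup")),
      JoinType.apply("left"),
      Some(
        EqualTo(
          UnresolvedAttribute(Seq("spark_catalog", "default", "flint_ppl_test", "name")),
          UnresolvedAttribute(Seq("flint_ppl_test_lookup", "name")))),
      JoinHint.NONE))
}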


@@ -6,7 +6,6 @@
package org.opensearch.sql.ppl;

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFieldName;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
@@ -23,6 +22,7 @@
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.AllFields;
@@ -63,6 +63,7 @@
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
import org.opensearch.sql.ppl.utils.SortUtils;
import org.sparkproject.guava.collect.Iterables;
import scala.Option;
import scala.collection.Seq;

@@ -266,6 +267,14 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) {

@Override
public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
Node root = node.getChild().get(0);

while(!root.getChild().isEmpty()) {
root = root.getChild().get(0);
}

org.opensearch.sql.ast.tree.Relation source = (org.opensearch.sql.ast.tree.Relation) root;

node.getChild().get(0).accept(this, context);

//TODO: not sure how to implement appendonly
@@ -275,7 +284,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
//TODO: use node.getCopyFieldList() to prefilter the right logical plan
//and return only the fields listed there. rename fields when requested

Expression joinExpression = visitFieldMap(node.getMatchFieldList());
Expression joinCondition = visitFieldMap(node.getMatchFieldList(), source.getTableQualifiedName().toString(), node.getIndexName(), context);

return context.apply(p -> new Join(

@@ -285,37 +294,37 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {

JoinType.apply("left"), //https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-join.html

Option.apply(joinExpression), //which fields to join
Option.apply(joinCondition), //which fields to join

JoinHint.NONE() //TODO: check, https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-hints.html#join-hints-types
));
}
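
The visitLookup TODOs above leave appendonly and getCopyFieldList unimplemented. A minimal sketch of the copy-field prefilter, assuming a Project over the lookup relation before the join is an acceptable way to keep only the requested fields; the helper name, its parameters, and the renaming behavior are hypothetical, not part of this commit:

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

object LookupPrefilterSketch {
  // copyFields maps a lookup column to the name it should carry in the output,
  // e.g. Map("salary" -> "income"); match fields are kept so the join can still resolve.
  def prefilterLookup(
      lookupTable: Seq[String],
      matchFields: Seq[String],
      copyFields: Map[String, String]): LogicalPlan = {
    val matched: Seq[NamedExpression] = matchFields.map(UnresolvedAttribute(_))
    val copied: Seq[NamedExpression] = copyFields.toSeq.map { case (column, alias) =>
      Alias(UnresolvedAttribute(column), alias)()
    }
    // Restrict the right-hand side of the join to match fields plus renamed copy fields.
    Project(matched ++ copied, UnresolvedRelation(lookupTable))
  }
}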

private Expression visitFieldMap(List<Map> fieldMap) {
private org.opensearch.sql.ast.expression.Field prefixField(List<String> prefixParts, UnresolvedExpression field) {
org.opensearch.sql.ast.expression.Field in = (org.opensearch.sql.ast.expression.Field) field;
org.opensearch.sql.ast.expression.QualifiedName inq = (org.opensearch.sql.ast.expression.QualifiedName) in.getField();
Iterable finalParts = Iterables.concat(prefixParts, inq.getParts());
return new org.opensearch.sql.ast.expression.Field(new org.opensearch.sql.ast.expression.QualifiedName(finalParts), in.getFieldArgs());
}

private Expression visitFieldMap(List<Map> fieldMap, String sourceTableName, String lookupTableName, CatalystPlanContext context) {
int size = fieldMap.size();

List<Expression> allEqlExpressions = new ArrayList<>(size);

for (Map map : fieldMap) {
Expression eql = new EqualTo(new UnresolvedFieldName(seq(of(((Field) map.getTarget()).getField().toString()))),
new UnresolvedFieldName(seq(of(((Field) map.getOrigin()).getField().toString()))));

Expression origin = visitExpression(prefixField(of(sourceTableName.split("\\.")),map.getOrigin()), context);
Expression target = visitExpression(prefixField(of(lookupTableName.split("\\.")),map.getTarget()), context);

//important
context.retainAllNamedParseExpressions(e -> e);

Expression eql = new EqualTo(origin, target);
allEqlExpressions.add(eql);
}

if(size == 1) {
return allEqlExpressions.get(0);
} else if(size == 2) {
return new org.apache.spark.sql.catalyst.expressions.And(allEqlExpressions.get(0),allEqlExpressions.get(1));
} else {
//2 and(1,2) -> 1 * and
//3 -> and(1, and(2,3)) -> 2 * and
//4 -> and(and(1,2), and(3,4)) -> 3 * and
//5 -> and(and(1, and(2,3)),and(4,5)) -> 4* and
//6 -> and(and(and(1,2), and(3,4)), and(5,6)) -> 5* and

//TODO: implement
throw new RuntimeException("not implemented");
}
return allEqlExpressions.stream().reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null);
}

/**
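The removed if/else above handled only one or two match fields and threw "not implemented" beyond that; the new visitFieldMap folds the per-field equalities with reduce(And::new). A minimal standalone sketch of the same fold, with illustrative table and column names:

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression}

object JoinConditionSketch {
  def main(args: Array[String]): Unit = {
    // One EqualTo per matched field, each side qualified with its table name.
    val equalities: Seq[Expression] = Seq(
      EqualTo(
        UnresolvedAttribute(Seq("flint_ppl_test", "name")),
        UnresolvedAttribute(Seq("flint_ppl_test_lookup", "name"))),
      EqualTo(
        UnresolvedAttribute(Seq("flint_ppl_test", "country")),
        UnresolvedAttribute(Seq("flint_ppl_test_lookup", "country"))))

    // Left fold: eq1, eq2, eq3 becomes And(And(eq1, eq2), eq3); a single
    // equality is returned as-is, so the old size == 1 / size == 2 cases fall out.
    val condition: Expression = equalities.reduce[Expression]((left, right) => And(left, right))
    println(condition)
  }
}
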
@@ -26,7 +26,7 @@ class PPLLogicalPlanLookupTranslatorTestSuite
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source = table | lookup a b,c as d appendonly=true q,w as z ", false), context)
planTransformer.visit(plan(pplParser, "source = table | lookup a b,c as d, e as f,g as b, j appendonly=true q,w as z ", false), context)
val star = Seq(UnresolvedStar(None))

val priceField = UnresolvedAttribute("price")
