Update Scala syntax
Signed-off-by: Andy Kwok <[email protected]>
andy-k-improving committed Dec 13, 2024
1 parent d34abf1 commit 822ebd5
Showing 2 changed files with 34 additions and 26 deletions.
@@ -6,19 +6,25 @@
package org.opensearch.sql.ppl;

import org.apache.spark.sql.catalyst.TableIdentifier;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.Ascending$;
import org.apache.spark.sql.catalyst.expressions.Descending$;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Equality$;
import org.apache.spark.sql.catalyst.expressions.Explode;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GeneratorOuter;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate$;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns;
import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$;
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
import org.apache.spark.sql.catalyst.plans.logical.Generate;
@@ -32,6 +38,7 @@
import org.apache.spark.sql.execution.command.ExplainCommand;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.apache.spark.unsafe.types.UTF8String;
import org.opensearch.flint.spark.FlattenGenerator;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
@@ -266,37 +273,38 @@ public LogicalPlan visitAppendCol(AppendCol node, CatalystPlanContext context) {
final String APPENDCOL_ID = WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME;
final String TABLE_LHS = "T1";
final String TABLE_RHS = "T2";
scala.collection.mutable.Seq<Expression> fieldsToRemove = seq(
UnresolvedAttribute$.MODULE$.apply(TABLE_LHS + "." + APPENDCOL_ID),
UnresolvedAttribute$.MODULE$.apply(TABLE_RHS + "." + APPENDCOL_ID));
final Compare innerJoinCondition = new Compare("=",
new Field(QualifiedName.of(TABLE_LHS, APPENDCOL_ID)),
new Field(QualifiedName.of(TABLE_RHS, APPENDCOL_ID)));
final UnresolvedAttribute t1Attr = new UnresolvedAttribute(seq(TABLE_LHS, APPENDCOL_ID));
final UnresolvedAttribute t2Attr = new UnresolvedAttribute(seq(TABLE_RHS, APPENDCOL_ID));
final Seq<Expression> fieldsToRemove = seq(t1Attr, t2Attr);
final Node mainSearchNode = node.getChild().get(0);

// Add a new projection layer with * and ROW_NUMBER (Main-search)
LogicalPlan leftTemp = node.getChild().get(0).accept(this, context);
var mainSearch = getRowNumStarProjection(context, leftTemp, TABLE_LHS);
context.withSubqueryAlias(mainSearch);
LogicalPlan leftTemp = mainSearchNode.accept(this, context);
var mainSearchWithRowNumber = getRowNumStarProjection(context, leftTemp, TABLE_LHS);
context.withSubqueryAlias(mainSearchWithRowNumber);

// Traverse to look for relation clause then append it into the sub-search.
Relation relation = retrieveRelationClause(node.getChild().get(0));
Relation relation = retrieveRelationClause(mainSearchNode);
appendRelationClause(node.getSubSearch(), relation);

context.apply(left -> {
// Add a new projection layer with * and ROW_NUMBER (Sub-search)
LogicalPlan right = node.getSubSearch().accept(this, context);
var subSearch = getRowNumStarProjection(context, right, TABLE_RHS);
context.withSubqueryAlias(subSearch);
LogicalPlan subSearchNode = node.getSubSearch().accept(this, context);
var subSearchWithRowNumber = getRowNumStarProjection(context, subSearchNode, TABLE_RHS);
context.withSubqueryAlias(subSearchWithRowNumber);

Optional<Expression> joinCondition = Optional.of(innerJoinCondition)
.map(c -> expressionAnalyzer.analyzeJoinCondition(c, context));
context.retainAllNamedParseExpressions(p -> p);
context.retainAllPlans(p -> p);

LogicalPlan joinedQuery = join(mainSearch, subSearch, Join.JoinType.LEFT, joinCondition, new Join.JoinHint());
// Compose the join clause
LogicalPlan joinedQuery = join(
mainSearchWithRowNumber, subSearchWithRowNumber,
Join.JoinType.LEFT,
Optional.of(new EqualTo(t1Attr, t2Attr)),
new Join.JoinHint());

// Remove the APPEND_ID
return new org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns(fieldsToRemove, joinedQuery);
return new DataFrameDropColumns(fieldsToRemove, joinedQuery);
});
return context.getPlan();
}
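
A minimal sketch (editor's illustration, not part of the commit) of the join-condition change in the hunk above: the old code built a PPL Compare node and pushed it through expressionAnalyzer.analyzeJoinCondition, while the new code constructs the Catalyst EqualTo expression directly from the two UnresolvedAttributes, so the analyzer round-trip and its parse-expression bookkeeping are no longer involved in building the condition. Assuming the seq(...) helper and the WindowSpecTransformer constant from this file, the new construction amounts to:

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;

// Build T1.<row-number column> = T2.<row-number column> directly as a
// Catalyst expression; no analyzer round-trip is needed.
UnresolvedAttribute t1Attr = new UnresolvedAttribute(seq("T1", WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME));
UnresolvedAttribute t2Attr = new UnresolvedAttribute(seq("T2", WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME));
Expression joinCondition = new EqualTo(t1Attr, t2Attr);
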
@@ -334,24 +342,19 @@ private static Relation retrieveRelationClause(Node node) {

private org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias getRowNumStarProjection(CatalystPlanContext context, LogicalPlan lp, String alias) {

final String DUMMY_SORT_FIELD = "1";

expressionAnalyzer.visitLiteral(
new Literal(DUMMY_SORT_FIELD, DataType.STRING), context);
Expression strExp = context.popNamedParseExpressions().get();
// Literal("x")
SortOrder sortOrder = SortUtils.sortOrder(strExp, false);
SortOrder sortOrder = SortUtils.sortOrder(
new org.apache.spark.sql.catalyst.expressions.Literal(
UTF8String.fromString("1"), DataTypes.StringType), false);

NamedExpression appendCol = WindowSpecTransformer.buildRowNumber(seq(), seq(sortOrder));

List<NamedExpression> projectList = (context.getNamedParseExpressions().isEmpty())
? List.of(appendCol, UnresolvedStar$.MODULE$.apply(Option.empty()))
? List.of(appendCol, new UnresolvedStar(Option.empty()))
: List.of(appendCol);

LogicalPlan lpWithProjection = new org.apache.spark.sql.catalyst.plans.logical.Project(seq(
projectList), lp);
return SubqueryAlias$.MODULE$.apply(alias, lpWithProjection);

}

@Override
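A second sketch (editor's illustration, not part of the commit) for the getRowNumStarProjection change above: the dummy sort key used to round-trip a PPL Literal through expressionAnalyzer.visitLiteral and popNamedParseExpressions; the new code builds the Catalyst literal directly. Catalyst represents StringType values as UTF8String rather than java.lang.String, hence the new UTF8String import:

import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;

// Catalyst stores StringType constants as UTF8String, so the value is
// wrapped before the literal is constructed.
Literal dummySortKey = new Literal(UTF8String.fromString("1"), DataTypes.StringType);

Passing a raw Java String here would yield a literal that misbehaves at evaluation time, which is presumably the motivation for wrapping the constant.
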
@@ -97,6 +97,11 @@ class PPLLogicalPlanAppendColCommandTranslatorTestSuite
T12_COLUMNS_SEQ,
Join(t1, t2, LeftOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE)))

// scalastyle:off
println(logicalPlan)
println(result)
// scalastyle:on

comparePlans(logicalPlan, result, checkAnalysis = false)
}

