Fix ppl describe bug #612 (#656)
* fix describe ppl command to be aware of 3-part FQN table names of the form schema.catalog.table

Signed-off-by: YANGDB <[email protected]>

YANG-DB authored Sep 13, 2024
1 parent c6b388a commit 0be5697
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 66 deletions.
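
The fix widens the PPL `describe` command from one- and two-part table names to the full three-part form. A minimal sketch of the three shapes the changed tests exercise, assuming a SparkSession `spark` with the PPL extensions registered (an assumption; the session setup is not part of this diff):

// Hypothetical driver; the table names are taken from the tests below.
Seq(
  "describe flint_ppl_test",           // table only
  "describe default.flint_ppl_test",   // two-part FQN: database.table
  "describe schema.default.http_logs"  // three-part FQN: schema.catalog.table (new)
).foreach(q => spark.sql(q).show(truncate = false))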
@@ -39,48 +39,83 @@ class FlintSparkPPLBasicITSuite
}

test("describe (extended) table query test") {
val testTableQuoted = "`spark_catalog`.`default`.`flint_ppl_test`"
Seq(testTable, testTableQuoted).foreach { table =>
val frame = sql(s"""
val frame = sql(s"""
describe flint_ppl_test
""".stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("name", "string", null),
Row("age", "int", null),
Row("state", "string", null),
Row("country", "string", null),
Row("year", "int", null),
Row("month", "int", null),
Row("# Partition Information", "", ""),
Row("# col_name", "data_type", "comment"),
Row("year", "int", null),
Row("month", "int", null))

// Convert actual results to a Set for quick lookup
val resultsSet: Set[Row] = results.toSet
// Check that each expected row is present in the actual results
expectedResults.foreach { expectedRow =>
assert(
resultsSet.contains(expectedRow),
s"Expected row $expectedRow not found in results")
}
// Retrieve the logical plan
val logicalPlan: LogicalPlan =
frame.queryExecution.commandExecuted.asInstanceOf[CommandResult].commandLogicalPlan
// Define the expected logical plan
val expectedPlan: LogicalPlan =
DescribeTableCommand(
TableIdentifier("flint_ppl_test"),
Map.empty[String, String],
isExtended = true,
output = DescribeRelation.getOutputAttrs)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("name", "string", null),
Row("age", "int", null),
Row("state", "string", null),
Row("country", "string", null),
Row("year", "int", null),
Row("month", "int", null),
Row("# Partition Information", "", ""),
Row("# col_name", "data_type", "comment"),
Row("year", "int", null),
Row("month", "int", null))

// Convert actual results to a Set for quick lookup
val resultsSet: Set[Row] = results.toSet
// Check that each expected row is present in the actual results
expectedResults.foreach { expectedRow =>
assert(resultsSet.contains(expectedRow), s"Expected row $expectedRow not found in results")
}
// Retrieve the logical plan
val logicalPlan: LogicalPlan =
frame.queryExecution.commandExecuted.asInstanceOf[CommandResult].commandLogicalPlan
// Define the expected logical plan
val expectedPlan: LogicalPlan =
DescribeTableCommand(
TableIdentifier("flint_ppl_test"),
Map.empty[String, String],
isExtended = true,
output = DescribeRelation.getOutputAttrs)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("describe (extended) FQN (2 parts) table query test") {
val frame = sql(s"""
describe default.flint_ppl_test
""".stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("name", "string", null),
Row("age", "int", null),
Row("state", "string", null),
Row("country", "string", null),
Row("year", "int", null),
Row("month", "int", null),
Row("# Partition Information", "", ""),
Row("# col_name", "data_type", "comment"),
Row("year", "int", null),
Row("month", "int", null))

// Convert actual results to a Set for quick lookup
val resultsSet: Set[Row] = results.toSet
// Check that each expected row is present in the actual results
expectedResults.foreach { expectedRow =>
assert(resultsSet.contains(expectedRow), s"Expected row $expectedRow not found in results")
}
// Retrieve the logical plan
val logicalPlan: LogicalPlan =
frame.queryExecution.commandExecuted.asInstanceOf[CommandResult].commandLogicalPlan
// Define the expected logical plan
val expectedPlan: LogicalPlan =
DescribeTableCommand(
TableIdentifier("flint_ppl_test", Option("default")),
Map.empty[String, String],
isExtended = true,
output = DescribeRelation.getOutputAttrs)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("create ppl simple query test") {
@@ -132,6 +132,11 @@ public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(1),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)));
} else if (node.getTableQualifiedName().getParts().size() == 3) {
identifier = new TableIdentifier(
node.getTableQualifiedName().getParts().get(2),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0)),
Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(1)));
} else {
throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName()
+ " Syntax: [ database_name. ] table_name");
@@ -164,10 +169,10 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
@Override
public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
context.reduce((left,right) -> {
context.reduce((left, right) -> {
visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context);
Seq<Expression> fields = context.retainAllNamedParseExpressions(e -> e);
if(!Objects.isNull(node.getScope())) {
if (!Objects.isNull(node.getScope())) {
// scope - this is a time-based expression that limits the join to a specific period: (time-field-name, value, unit)
expressionAnalyzer.visitSpan(node.getScope(), context);
context.popNamedParseExpressions().get();
@@ -188,7 +193,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContex
//add group by fields to context
context.getGroupingParseExpressions().addAll(groupExpList);
}

UnresolvedExpression span = node.getSpan();
if (!Objects.isNull(span)) {
span.accept(this, context);
@@ -212,8 +217,8 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContex
context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan));
}
//visit TopAggregation results limit
if((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) {
context.apply(p ->(LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
if ((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) {
context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p));
}
return logicalPlan;
@@ -296,7 +301,7 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
List<UnresolvedExpression> aliases = new ArrayList<>();
List<Let> letExpressions = node.getExpressionList();
for(Let let : letExpressions) {
for (Let let : letExpressions) {
Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression());
aliases.add(alias);
}
@@ -353,7 +358,7 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) {
visitFieldList(node.getFields(), context);
// Columns to deduplicate
Seq<org.apache.spark.sql.catalyst.expressions.Attribute> dedupeFields
= context.retainAllNamedParseExpressions(e -> (org.apache.spark.sql.catalyst.expressions.Attribute) e);
= context.retainAllNamedParseExpressions(e -> (org.apache.spark.sql.catalyst.expressions.Attribute) e);
// Although we can also use the Window operator to translate this as allowedDuplication > 1 did,
// adding Aggregate operator could achieve better performance.
if (allowedDuplication == 1) {
@@ -388,6 +393,7 @@ public Expression visitLiteral(Literal node, CatalystPlanContext context) {

/**
* generic binary (And, Or, Xor, ...) arithmetic expression resolver
*
* @param node
* @param transformer
* @param context
@@ -398,11 +404,11 @@ public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction<Expres
Optional<Expression> left = context.popNamedParseExpressions();
node.getRight().accept(this, context);
Optional<Expression> right = context.popNamedParseExpressions();
if(left.isPresent() && right.isPresent()) {
return transformer.apply(left.get(),right.get());
} else if(left.isPresent()) {
if (left.isPresent() && right.isPresent()) {
return transformer.apply(left.get(), right.get());
} else if (left.isPresent()) {
return context.getNamedParseExpressions().push(left.get());
} else if(right.isPresent()) {
} else if (right.isPresent()) {
return context.getNamedParseExpressions().push(right.get());
}
return null;
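
The reindented visitBinaryArithmetic keeps its original null-safe shape: apply the operator only when both operands resolved, otherwise pass through whichever side is present. The same decision table as a tiny stand-alone Scala sketch (combine is a hypothetical helper, not the project's API):

// Combine optional operands the way visitBinaryArithmetic does.
def combine[E](left: Option[E], right: Option[E])(op: (E, E) => E): Option[E] =
  (left, right) match {
    case (Some(l), Some(r)) => Some(op(l, r)) // both resolved: apply the operator
    case (Some(l), None)    => Some(l)        // only one side resolved: pass it through
    case (None, Some(r))    => Some(r)
    case (None, None)       => None           // neither side resolved
  }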
@@ -412,25 +418,25 @@ public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction<Expres
@Override
public Expression visitAnd(And node, CatalystPlanContext context) {
return visitBinaryArithmetic(node,
(left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context);
(left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context);
}

@Override
public Expression visitOr(Or node, CatalystPlanContext context) {
return visitBinaryArithmetic(node,
(left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context);
(left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context);
}

@Override
public Expression visitXor(Xor node, CatalystPlanContext context) {
return visitBinaryArithmetic(node,
(left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context);
(left, right) -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context);
}

@Override
public Expression visitNot(Not node, CatalystPlanContext context) {
node.getExpression().accept(this, context);
Optional<Expression> arg = context.popNamedParseExpressions();
Optional<Expression> arg = context.popNamedParseExpressions();
return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null);
}

@@ -474,7 +480,7 @@ public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext con
}
return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts())));
}

@Override
public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) {
return node.getChild().stream().map(expression ->
@@ -513,18 +519,18 @@ public Expression visitEval(Eval node, CatalystPlanContext context) {
@Override
public Expression visitFunction(Function node, CatalystPlanContext context) {
List<Expression> arguments =
node.getFuncArgs().stream()
.map(
unresolvedExpression -> {
var ret = analyze(unresolvedExpression, context);
if (ret == null) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", unresolvedExpression));
} else {
return context.popNamedParseExpressions().get();
}
})
.collect(Collectors.toList());
node.getFuncArgs().stream()
.map(
unresolvedExpression -> {
var ret = analyze(unresolvedExpression, context);
if (ret == null) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", unresolvedExpression));
} else {
return context.popNamedParseExpressions().get();
}
})
.collect(Collectors.toList());
Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments);
return context.getNamedParseExpressions().push(function);
}
@@ -36,6 +36,19 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
thrown.getMessage === "Invalid table name: t.b.c.d Syntax: [ database_name. ] table_name")
}

test("test describe FQN table clause") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "describe schema.default.http_logs", false), context)

val expectedPlan = DescribeTableCommand(
TableIdentifier("http_logs", Option("schema"), Option("default")),
Map.empty[String, String],
isExtended = true,
output = DescribeRelation.getOutputAttrs)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple describe clause") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "describe t", false), context)
