Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PPL Parse command #595

Merged
merged 13 commits into from
Aug 23, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,42 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
}
}

/**
 * Creates `testTable` with `name`, `age`, `email` and `street_address` columns, partitioned by
 * (`year`, `month`), and loads ten fixed rows into it.
 *
 * Each row is written with its own `INSERT` statement so that it lands in the partition named
 * by its own `year`/`month` values.
 *
 * @param testTable
 *   name of the table to create and populate
 */
protected def createPartitionedGrokEmailTable(testTable: String): Unit = {
  sql(s"""
       | CREATE TABLE $testTable
       | (
       | name STRING,
       | age INT,
       | email STRING,
       | street_address STRING
       | )
       | USING $tableType $tableOptions
       | PARTITIONED BY (
       | year INT,
       | month INT
       | )
       |""".stripMargin)

  // Every address must contain '@' followed by a host: the PPL parse tests run
  // '.+@(?<host>.+)' over this column and assert on the extracted host.
  // NOTE(review): these literals were restored from an obfuscated "[email protected]"
  // rendering; the domains match the hosts asserted by FlintSparkPPLParseITSuite, but confirm
  // the exact addresses against the original commit.
  val data = Seq(
    ("Alice", 30, "alice@example.com", "123 Main St, Seattle", 2023, 4),
    ("Bob", 55, "bob@test.org", "456 Elm St, Portland", 2023, 5),
    ("Charlie", 65, "charlie@domain.net", "789 Pine St, San Francisco", 2023, 4),
    ("David", 19, "david@anotherdomain.com", "101 Maple St, New York", 2023, 5),
    ("Eve", 21, "eve@examples.com", "202 Oak St, Boston", 2023, 4),
    ("Frank", 76, "frank@sample.org", "303 Cedar St, Austin", 2023, 5),
    ("Grace", 41, "grace@demo.net", "404 Birch St, Chicago", 2023, 4),
    ("Hank", 32, "hank@demonstration.com", "505 Spruce St, Miami", 2023, 5),
    ("Ivy", 9, "ivy@examples.com", "606 Fir St, Denver", 2023, 4),
    ("Jack", 12, "jack@sample.net", "707 Ash St, Seattle", 2023, 5))

  data.foreach { case (name, age, email, streetAddress, year, month) =>
    sql(s"""
         | INSERT INTO $testTable
         | PARTITION (year=$year, month=$month)
         | VALUES ('$name', $age, '$email', '$streetAddress')
         | """.stripMargin)
  }
}
protected def createPartitionedAddressTable(testTable: String): Unit = {
sql(s"""
| CREATE TABLE $testTable
Expand Down Expand Up @@ -241,6 +277,39 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| """.stripMargin)
}

/**
 * Creates `testTable` with `name`, `occupation`, `country` and `salary` columns, partitioned by
 * (`year`, `month`), then loads ten rows into the `year=2023, month=4` partition using a single
 * multi-row `INSERT`.
 *
 * @param testTable
 *   name of the table to create and populate
 */
protected def createOccupationTopRareTable(testTable: String): Unit = {
  // DDL: four data columns plus the two partition columns.
  val createTableStatement =
    s"""
       | CREATE TABLE $testTable
       | (
       | name STRING,
       | occupation STRING,
       | country STRING,
       | salary INT
       | )
       | USING $tableType $tableOptions
       | PARTITIONED BY (
       | year INT,
       | month INT
       | )
       |""".stripMargin

  // DML: all rows go to one static partition, so one statement is enough.
  val insertRowsStatement =
    s"""
       | INSERT INTO $testTable
       | PARTITION (year=2023, month=4)
       | VALUES ('Jake', 'Engineer', 'England' , 100000),
       | ('Hello', 'Artist', 'USA', 70000),
       | ('John', 'Doctor', 'Canada', 120000),
       | ('Rachel', 'Doctor', 'Canada', 220000),
       | ('Henry', 'Doctor', 'Canada', 220000),
       | ('David', 'Engineer', 'USA', 320000),
       | ('Barty', 'Engineer', 'USA', 120000),
       | ('David', 'Unemployed', 'Canada', 0),
       | ('Jane', 'Scientist', 'Canada', 90000),
       | ('Philip', 'Scientist', 'Canada', 190000)
       | """.stripMargin

  sql(createTableStatement)
  sql(insertRowsStatement)
}

protected def createHobbiesTable(testTable: String): Unit = {
sql(s"""
| CREATE TABLE $testTable
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import scala.reflect.internal.Reporter.Count

import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NullsLast, RegExpExtract, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort}
import org.apache.spark.sql.streaming.StreamTest

/**
 * Integration tests for the PPL `parse` command.
 *
 * Every test runs `parse email '.+@(?<host>.+)'` against the table built by
 * `createPartitionedGrokEmailTable`, then checks both the collected rows and the unresolved
 * logical plan produced for the query.
 */
class FlintSparkPPLParseITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Test table and index name */
  private val testTable = "spark_catalog.default.flint_ppl_test"

  /** Unresolved reference to the `email` source column. */
  private val emailAttribute = UnresolvedAttribute("email")

  /** Unresolved reference to the `host` field extracted by `parse`. */
  private val hostAttribute = UnresolvedAttribute("host")

  /**
   * Expected plan for `source = testTable | parse email '.+@(?<host>.+)'`: a projection that
   * adds `host` as capture group 1 of the pattern (with the group name removed) on top of the
   * source relation.
   *
   * The group index is an integer literal, matching what the structurally compared plans in this
   * suite require. Built on every call so each expected plan gets its own `Alias` instance.
   */
  private def expectedParsePlan: LogicalPlan = {
    val hostExpression = Alias(
      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))),
      "host")()
    Project(
      Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
      UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
  }

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Create test table
    createPartitionedGrokEmailTable(testTable)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Stop all streaming jobs if any
    spark.streams.active.foreach { job =>
      job.stop()
      job.awaitTermination()
    }
  }

  test("test parse email expressions parsing") {
    val frame = sql(s"""
         | source = $testTable| parse email '.+@(?<host>.+)' | fields email, host ;
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    // Define the expected results: each host is the part of the address after '@'.
    // NOTE(review): addresses restored from an obfuscated "[email protected]" rendering; they
    // must equal the rows inserted by createPartitionedGrokEmailTable - confirm against it.
    val expectedResults: Array[Row] = Array(
      Row("charlie@domain.net", "domain.net"),
      Row("david@anotherdomain.com", "anotherdomain.com"),
      Row("hank@demonstration.com", "demonstration.com"),
      Row("alice@example.com", "example.com"),
      Row("frank@sample.org", "sample.org"),
      Row("grace@demo.net", "demo.net"),
      Row("jack@sample.net", "sample.net"),
      Row("eve@examples.com", "examples.com"),
      Row("ivy@examples.com", "examples.com"),
      Row("bob@test.org", "test.org"))

    // Compare the results; no sort in the query, so order both sides by email first
    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
    assert(results.sorted.sameElements(expectedResults.sorted))

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Define the expected logical plan
    val expectedPlan = Project(Seq(emailAttribute, hostAttribute), expectedParsePlan)
    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
  }

  test("test parse email expressions parsing filter & sort by age") {
    val frame = sql(s"""
         | source = $testTable| parse email '.+@(?<host>.+)' | where age > 45 | sort - age | fields age, email, host ;
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    // Define the expected results; the query sorts by age descending, so order matters here
    val expectedResults: Array[Row] = Array(
      Row(76, "frank@sample.org", "sample.org"),
      Row(65, "charlie@domain.net", "domain.net"),
      Row(55, "bob@test.org", "test.org"))

    // Compare the results
    assert(results.sameElements(expectedResults))

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Define the expected logical plan
    val ageAttribute = UnresolvedAttribute("age")
    val expectedPlan = Project(
      Seq(ageAttribute, emailAttribute, hostAttribute),
      Sort(
        Seq(SortOrder(ageAttribute, Descending, NullsLast, Seq.empty)),
        global = true,
        Filter(GreaterThan(ageAttribute, Literal(45)), expectedParsePlan)))
    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
  }

  test("test parse email expressions and group by count host ") {
    val frame = sql(s"""
         | source = $testTable| parse email '.+@(?<host>.+)' | stats count() by host
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    // Define the expected results
    val expectedResults: Array[Row] = Array(
      Row(1L, "demonstration.com"),
      Row(1L, "example.com"),
      Row(1L, "domain.net"),
      Row(1L, "anotherdomain.com"),
      Row(1L, "sample.org"),
      Row(1L, "demo.net"),
      Row(1L, "sample.net"),
      Row(2L, "examples.com"),
      Row(1L, "test.org"))

    // Sort both the results and the expected results
    implicit val rowOrdering: Ordering[Row] = Ordering.by(r => (r.getLong(0), r.getString(1)))
    assert(results.sorted.sameElements(expectedResults.sorted))

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical

    // Define the expected logical plan
    val expectedPlan = Project(
      Seq(UnresolvedStar(None)), // Matches the '*' in the Project
      Aggregate(
        Seq(Alias(hostAttribute, "host")()), // Group by 'host'
        Seq(
          Alias(
            UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
            "count()")(),
          Alias(hostAttribute, "host")()),
        expectedParsePlan))
    // Compare the logical plans
    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
  }

  test("test parse email expressions and top count_host ") {
    val frame = sql(s"""
         | source = $testTable| parse email '.+@(?<host>.+)' | top 1 host
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    // Define the expected results
    val expectedResults: Array[Row] = Array(Row(2L, "examples.com"))

    // Sort both the results and the expected results
    implicit val rowOrdering: Ordering[Row] = Ordering.by(r => (r.getLong(0), r.getString(1)))
    assert(results.sorted.sameElements(expectedResults.sorted))

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical

    // The same COUNT(host) aggregate appears as the sort key and in the aggregate output list
    def countHost: Alias =
      Alias(
        UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false),
        "count_host")()

    val sortedPlan = Sort(
      Seq(SortOrder(countHost, Descending, NullsLast, Seq.empty)),
      global = true,
      Aggregate(Seq(hostAttribute), Seq(countHost, hostAttribute), expectedParsePlan))
    // Define the expected logical plan
    val expectedPlan = Project(
      Seq(UnresolvedStar(None)), // Matches the '*' in the Project
      GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan)))
    // Compare the logical plans
    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
  }
}
Loading
Loading