Skip to content

Commit

Permalink
Merge branch 'main' into pr/issues/476
Browse files Browse the repository at this point in the history
  • Loading branch information
LantaoJin authored Jul 26, 2024
2 parents 9e2f85f + 98bd79a commit b343aad
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 51 deletions.
7 changes: 6 additions & 1 deletion DEVELOPER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ If you get integration test failures with error message "Previous attempts to fi
### AWS Integration Test
The integration folder contains tests for cloud server providers. For instance, test against AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain.
```
export AWS_OPENSEARCH_HOST=search-xxx.aos.us-west-2.on.aws
export AWS_OPENSEARCH_HOST=search-xxx.us-west-2.on.aws
export AWS_REGION=us-west-2
export AWS_EMRS_APPID=xxx
export AWS_EMRS_EXECUTION_ROLE=xxx
export AWS_S3_CODE_BUCKET=xxx
export AWS_S3_CODE_PREFIX=xxx
export AWS_OPENSEARCH_RESULT_INDEX=query_execution_result_glue
```
And run the
```
Expand Down
27 changes: 19 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import Dependencies._

lazy val scala212 = "2.12.14"
lazy val sparkVersion = "3.3.2"
// Spark jackson version. Spark jackson-module-scala strictly check the jackson-databind version hould compatbile
// https://github.com/FasterXML/jackson-module-scala/blob/2.18/src/main/scala/com/fasterxml/jackson/module/scala/JacksonModule.scala#L59
lazy val jacksonVersion = "2.13.4"

// The transitive opensearch jackson-databind dependency version should align with Spark jackson databind dependency version.
// Issue: https://github.com/opensearch-project/opensearch-spark/issues/442
Expand Down Expand Up @@ -49,7 +52,11 @@ lazy val commonSettings = Seq(
compileScalastyle := (Compile / scalastyle).toTask("").value,
Compile / compile := ((Compile / compile) dependsOn compileScalastyle).value,
testScalastyle := (Test / scalastyle).toTask("").value,
Test / test := ((Test / test) dependsOn testScalastyle).value)
Test / test := ((Test / test) dependsOn testScalastyle).value,
dependencyOverrides ++= Seq(
"com.fasterxml.jackson.core" % "jackson-core" % jacksonVersion,
"com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion
))

// running `scalafmtAll` includes all subprojects under root
lazy val root = (project in file("."))
Expand Down Expand Up @@ -218,9 +225,16 @@ lazy val integtest = (project in file("integ-test"))
commonSettings,
name := "integ-test",
scalaVersion := scala212,
inConfig(IntegrationTest)(Defaults.testSettings),
IntegrationTest / scalaSource := baseDirectory.value / "src/integration/scala",
IntegrationTest / parallelExecution := false,
javaOptions ++= Seq(
s"-DappJar=${(sparkSqlApplication / assembly).value.getAbsolutePath}",
s"-DextensionJar=${(flintSparkIntegration / assembly).value.getAbsolutePath}",
s"-DpplJar=${(pplSparkIntegration / assembly).value.getAbsolutePath}",
),
inConfig(IntegrationTest)(Defaults.testSettings ++ Seq(
IntegrationTest / scalaSource := baseDirectory.value / "src/integration/scala",
IntegrationTest / parallelExecution := false,
IntegrationTest / fork := true,
)),
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
Expand All @@ -229,10 +243,7 @@ lazy val integtest = (project in file("integ-test"))
"com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test",
"org.testcontainers" % "testcontainers" % "1.18.0" % "test",
"org.apache.iceberg" %% s"iceberg-spark-runtime-$sparkMinorVersion" % icebergVersion % "test",
"org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0" % "test",
// add opensearch-java client to get node stats
"org.opensearch.client" % "opensearch-java" % "2.6.0" % "test"
exclude ("com.fasterxml.jackson.core", "jackson-databind")),
"org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0" % "test"),
libraryDependencies ++= deps(sparkVersion),
Test / fullClasspath ++= Seq((flintSparkIntegration / assembly).value, (pplSparkIntegration / assembly).value,
(sparkSqlApplication / assembly).value
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.aws

import java.time.LocalDateTime

import scala.concurrent.duration.DurationInt

import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder
import com.amazonaws.services.emrserverless.model.{GetJobRunRequest, JobDriver, SparkSubmit, StartJobRunRequest}
import com.amazonaws.services.s3.AmazonS3ClientBuilder
import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import org.apache.spark.internal.Logging

class AWSEmrServerlessAccessTestSuite
extends AnyFlatSpec
with BeforeAndAfter
with Matchers
with Logging {

lazy val testHost: String = System.getenv("AWS_OPENSEARCH_HOST")
lazy val testPort: Int = -1
lazy val testRegion: String = System.getenv("AWS_REGION")
lazy val testScheme: String = "https"
lazy val testAuth: String = "sigv4"

lazy val testAppId: String = System.getenv("AWS_EMRS_APPID")
lazy val testExecutionRole: String = System.getenv("AWS_EMRS_EXECUTION_ROLE")
lazy val testS3CodeBucket: String = System.getenv("AWS_S3_CODE_BUCKET")
lazy val testS3CodePrefix: String = System.getenv("AWS_S3_CODE_PREFIX")
lazy val testResultIndex: String = System.getenv("AWS_OPENSEARCH_RESULT_INDEX")

"EMR Serverless job" should "run successfully" in {
val s3Client = AmazonS3ClientBuilder.standard().withRegion(testRegion).build()
val emrServerless = AWSEMRServerlessClientBuilder.standard().withRegion(testRegion).build()

val appJarPath =
sys.props.getOrElse("appJar", throw new IllegalArgumentException("appJar not set"))
val extensionJarPath = sys.props.getOrElse(
"extensionJar",
throw new IllegalArgumentException("extensionJar not set"))
val pplJarPath =
sys.props.getOrElse("pplJar", throw new IllegalArgumentException("pplJar not set"))

s3Client.putObject(
testS3CodeBucket,
s"$testS3CodePrefix/sql-job.jar",
new java.io.File(appJarPath))
s3Client.putObject(
testS3CodeBucket,
s"$testS3CodePrefix/extension.jar",
new java.io.File(extensionJarPath))
s3Client.putObject(
testS3CodeBucket,
s"$testS3CodePrefix/ppl.jar",
new java.io.File(pplJarPath))

val jobRunRequest = new StartJobRunRequest()
.withApplicationId(testAppId)
.withExecutionRoleArn(testExecutionRole)
.withName(s"integration-${LocalDateTime.now()}")
.withJobDriver(new JobDriver()
.withSparkSubmit(new SparkSubmit()
.withEntryPoint(s"s3://$testS3CodeBucket/$testS3CodePrefix/sql-job.jar")
.withEntryPointArguments(testResultIndex)
.withSparkSubmitParameters(s"--class org.apache.spark.sql.FlintJob --jars " +
s"s3://$testS3CodeBucket/$testS3CodePrefix/extension.jar," +
s"s3://$testS3CodeBucket/$testS3CodePrefix/ppl.jar " +
s"--conf spark.datasource.flint.host=$testHost " +
s"--conf spark.datasource.flint.port=-1 " +
s"--conf spark.datasource.flint.scheme=$testScheme " +
s"--conf spark.datasource.flint.auth=$testAuth " +
s"--conf spark.sql.catalog.glue=org.opensearch.sql.FlintDelegatingSessionCatalog " +
s"--conf spark.flint.datasource.name=glue " +
s"""--conf spark.flint.job.query="SELECT 1" """ +
s"--conf spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")))

val jobRunResponse = emrServerless.startJobRun(jobRunRequest)

val startTime = System.currentTimeMillis()
val timeout = 5.minutes.toMillis
var jobState = "STARTING"

while (System.currentTimeMillis() - startTime < timeout
&& (jobState != "FAILED" && jobState != "SUCCESS")) {
Thread.sleep(30000)
val request = new GetJobRunRequest()
.withApplicationId(testAppId)
.withJobRunId(jobRunResponse.getJobRunId)
jobState = emrServerless.getJobRun(request).getJobRun.getState
logInfo(s"Current job state: $jobState at ${System.currentTimeMillis()}")
}

jobState shouldBe "SUCCESS"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,69 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
}
}

test("create table with dummy location should fail with excepted error message") {
try {
createSession(jobRunId, "")
threadLocalFuture.set(startREPL())

val dummyLocation = "s3://path/to/dummy/location"
val testQueryId = "110"
val createTableStatement =
s"""
| CREATE TABLE $testTable
| (
| name STRING,
| age INT
| )
| USING CSV
| LOCATION '$dummyLocation'
| OPTIONS (
| header 'false',
| delimiter '\\t'
| )
|""".stripMargin
val createTableStatementId =
submitQuery(s"${makeJsonCompliant(createTableStatement)}", testQueryId)

val createTableStatementValidation: REPLResult => Boolean = result => {
assert(
result.results.size == 0,
s"expected result size is 0, but got ${result.results.size}")
assert(
result.schemas.size == 0,
s"expected schema size is 0, but got ${result.schemas.size}")
failureValidation(result)
true
}
pollForResultAndAssert(createTableStatementValidation, testQueryId)
assert(
!awaitConditionForStatementOrTimeout(
statement => {
statement.error match {
case Some(error)
if error == """{"Message":"Fail to run query. Cause: No FileSystem for scheme \"s3\""}""" =>
// Assertion passed
case _ =>
fail(s"Statement error is: ${statement.error}")
}
statement.state == "failed"
},
createTableStatementId),
s"Fail to verify for $createTableStatementId.")
// clean up
val dropStatement =
s"""DROP TABLE $testTable""".stripMargin
submitQuery(s"${makeJsonCompliant(dropStatement)}", "999")
} catch {
case e: Exception =>
logError("Unexpected exception", e)
assert(false, "Unexpected exception")
} finally {
waitREPLStop(threadLocalFuture.get())
threadLocalFuture.remove()
}
}

/**
* JSON does not support raw newlines (\n) in string values. All newlines must be escaped or
* removed when inside a JSON string. The same goes for tab characters, which should be
Expand Down
2 changes: 1 addition & 1 deletion ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ The next samples of PPL queries are currently supported:
- `where` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/where.rst)
- `fields` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/fields.rst)
- `head` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/head.rst)
- `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) (supports AVG, COUNT, MAX, MIN and SUM aggregation functions)
- `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) (supports AVG, COUNT, DISTINCT_COUNT, MAX, MIN and SUM aggregation functions)
- `sort` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst)
- `correlation` - [See details](../docs/PPL-Correlation-command.md)

Expand Down
9 changes: 1 addition & 8 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ commands
: whereCommand
| correlateCommand
| fieldsCommand
| renameCommand
| statsCommand
| sortCommand
| headCommand
Expand Down Expand Up @@ -224,8 +223,6 @@ statsFunction
: statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall
| COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall
| (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall
| percentileAggFunction # percentileAggFunctionCall
| takeAggFunction # takeAggFunctionCall
;

statsFunctionName
Expand Down Expand Up @@ -257,18 +254,14 @@ logicalExpression
| left = logicalExpression OR right = logicalExpression # logicalOr
| left = logicalExpression (AND)? right = logicalExpression # logicalAnd
| left = logicalExpression XOR right = logicalExpression # logicalXor
| booleanExpression # booleanExpr
| relevanceExpression # relevanceExpr
;

comparisonExpression
: left = valueExpression comparisonOperator right = valueExpression # compareExpr
;

valueExpression
: left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic
| left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic
| primaryExpression # valueExpressionDefault
: primaryExpression # valueExpressionDefault
| LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,6 @@ public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprCont
return new Compare(ctx.comparisonOperator().getText(), visit(ctx.left), visit(ctx.right));
}

/**
* Value Expression.
*/
@Override
public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) {
return new Function(
ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right)));
}

@Override
public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) {
return visit(ctx.valueExpression()); // Discard parenthesis around
Expand Down Expand Up @@ -172,20 +163,6 @@ public UnresolvedExpression visitPercentileAggFunction(OpenSearchPPLParser.Perce
Collections.singletonList(new Argument("rank", (Literal) visit(ctx.value))));
}

@Override
public UnresolvedExpression visitTakeAggFunctionCall(
OpenSearchPPLParser.TakeAggFunctionCallContext ctx) {
ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder();
builder.add(
new UnresolvedArgument(
"size",
ctx.takeAggFunction().size != null
? visit(ctx.takeAggFunction().size)
: new Literal(DEFAULT_TAKE_FUNCTION_SIZE_VALUE, DataType.INTEGER)));
return new AggregateFunction(
"take", visit(ctx.takeAggFunction().fieldExpression()), builder.build());
}

/**
* Eval function.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,18 +436,22 @@ trait FlintJobExecutor {

private def handleQueryException(
e: Exception,
message: String,
messagePrefix: String,
errorSource: Option[String] = None,
statusCode: Option[Int] = None): String = {

val errorDetails = Map("Message" -> s"$message: ${e.getMessage}") ++
val errorMessage = s"$messagePrefix: ${e.getMessage}"
val errorDetails = Map("Message" -> errorMessage) ++
errorSource.map("ErrorSource" -> _) ++
statusCode.map(code => "StatusCode" -> code.toString)

val errorJson = mapper.writeValueAsString(errorDetails)

statusCode.foreach { code =>
CustomLogging.logError(new OperationMessage("", code), e)
// CustomLogging will call log4j logger.error() underneath
statusCode match {
case Some(code) =>
CustomLogging.logError(new OperationMessage(errorMessage, code), e)
case None =>
CustomLogging.logError(errorMessage, e)
}

errorJson
Expand Down Expand Up @@ -491,16 +495,14 @@ trait FlintJobExecutor {
case r: SparkException =>
handleQueryException(r, ExceptionMessages.SparkExceptionErrorPrefix)
case r: Exception =>
val rootCauseClassName = ex.getClass.getName
val errMsg = ex.getMessage
logDebug(s"Root cause class name: $rootCauseClassName")
logDebug(s"Root cause error message: $errMsg")
val rootCauseClassName = r.getClass.getName
val errMsg = r.getMessage
if (rootCauseClassName == "org.apache.hadoop.hive.metastore.api.MetaException" &&
errMsg.contains("com.amazonaws.services.glue.model.AccessDeniedException")) {
val e = new SecurityException(ExceptionMessages.GlueAccessDeniedMessage)
handleQueryException(e, ExceptionMessages.QueryRunErrorPrefix)
} else {
handleQueryException(ex, ExceptionMessages.QueryRunErrorPrefix)
handleQueryException(r, ExceptionMessages.QueryRunErrorPrefix)
}
}
}
Expand Down

0 comments on commit b343aad

Please sign in to comment.