diff --git a/build.sbt b/build.sbt index 4cf923fc2..c8c94ad1c 100644 --- a/build.sbt +++ b/build.sbt @@ -208,6 +208,8 @@ lazy val sparkSqlApplication = (project in file("spark-sql-application")) libraryDependencies ++= deps(sparkVersion), libraryDependencies ++= Seq( "com.typesafe.play" %% "play-json" % "2.9.2", + "com.amazonaws" % "aws-java-sdk-glue" % "1.12.568" % "provided" + exclude ("com.fasterxml.jackson.core", "jackson-databind"), // handle AmazonS3Exception "com.amazonaws" % "aws-java-sdk-s3" % "1.12.568" % "provided" // the transitive jackson.core dependency conflicts with existing scala diff --git a/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java b/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java index 12a5646f3..5c1080f8c 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/IRestHighLevelClient.java @@ -28,6 +28,8 @@ import org.opensearch.client.indices.PutMappingRequest; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.client.RequestOptions; +import org.opensearch.flint.core.logging.CustomLogging; +import org.opensearch.flint.core.logging.OperationMessage; import org.opensearch.flint.core.metrics.MetricsUtil; import java.io.Closeable; @@ -90,7 +92,11 @@ static void recordOperationSuccess(String metricNamePrefix) { static void recordOperationFailure(String metricNamePrefix, Exception e) { OpenSearchException openSearchException = extractOpenSearchException(e); int statusCode = openSearchException != null ? openSearchException.status().getStatus() : 500; - + if (openSearchException != null) { + CustomLogging.logError(new OperationMessage("OpenSearch Operation failed.", statusCode), openSearchException); + } else { + CustomLogging.logError("OpenSearch Operation failed with an exception.", e); + } if (statusCode == 403) { String forbiddenErrorMetricName = metricNamePrefix + ".403.count"; MetricsUtil.incrementCounter(forbiddenErrorMetricName); @@ -104,15 +110,16 @@ static void recordOperationFailure(String metricNamePrefix, Exception e) { * Extracts an OpenSearchException from the given Throwable. * Checks if the Throwable is an instance of OpenSearchException or caused by one. 
* - * @param ex the exception to be checked + * @param e the exception to be checked * @return the extracted OpenSearchException, or null if not found */ - private static OpenSearchException extractOpenSearchException(Throwable ex) { - if (ex instanceof OpenSearchException) { - return (OpenSearchException) ex; - } else if (ex.getCause() instanceof OpenSearchException) { - return (OpenSearchException) ex.getCause(); + static OpenSearchException extractOpenSearchException(Throwable e) { + if (e instanceof OpenSearchException) { + return (OpenSearchException) e; + } else if (e.getCause() == null) { + return null; + } else { + return extractOpenSearchException(e.getCause()); } - return null; } } diff --git a/flint-core/src/main/java/org/opensearch/flint/core/logging/CustomLogging.java b/flint-core/src/main/java/org/opensearch/flint/core/logging/CustomLogging.java index 8908e763b..d79147ae5 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/logging/CustomLogging.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/logging/CustomLogging.java @@ -44,9 +44,9 @@ public class CustomLogging { private static final Map> logLevelActions = new HashMap<>(); static { - String[] parts = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", UNKNOWN + ":" + UNKNOWN).split(":"); + DOMAIN_NAME = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", UNKNOWN + ":" + UNKNOWN); + String[] parts = DOMAIN_NAME.split(":"); CLIENT_ID = parts.length == 2 ? parts[0] : UNKNOWN; - DOMAIN_NAME = parts.length == 2 ? parts[1] : UNKNOWN; logLevelActions.put("DEBUG", logger::debug); logLevelActions.put("INFO", logger::info); @@ -78,10 +78,6 @@ private static String convertToJson(Map logEventMap) { * @return A map representation of the log event. */ protected static Map constructLogEventMap(String level, Object content, Throwable throwable) { - if (content == null) { - throw new IllegalArgumentException("Log message must not be null"); - } - Map logEventMap = new LinkedHashMap<>(); Map body = new LinkedHashMap<>(); constructMessageBody(content, body); @@ -105,6 +101,11 @@ protected static Map constructLogEventMap(String level, Object c } private static void constructMessageBody(Object content, Map body) { + if (content == null) { + body.put("message", ""); + return; + } + if (content instanceof Message) { Message message = (Message) content; body.put("message", message.getFormattedMessage()); @@ -151,6 +152,10 @@ public static void logError(Object message) { log("ERROR", message, null); } + public static void logError(Throwable throwable) { + log("ERROR", "", throwable); + } + public static void logError(Object message, Throwable throwable) { log("ERROR", message, throwable); } diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java index 6a081a740..4cdfcee01 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java @@ -27,6 +27,11 @@ public final class MetricConstants { */ public static final String S3_ERR_CNT_METRIC = "s3.error.count"; + /** + * Metric name for counting the errors encountered with Amazon Glue operations. + */ + public static final String GLUE_ERR_CNT_METRIC = "glue.error.count"; + /** * Metric name for counting the number of sessions currently running. 
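The rewritten extractOpenSearchException above no longer stops at the immediate cause: it recurses through the whole cause chain and returns null only once the chain is exhausted, so an OpenSearchException wrapped two or more levels deep is still surfaced. A minimal Scala sketch of the same traversal, generalized over the target type (names here are illustrative, not the production Java helper):

    import scala.annotation.tailrec
    import scala.reflect.ClassTag

    object CauseChain {
      // Walk e, e.getCause, e.getCause.getCause, ... until an exception of type T
      // is found or the chain ends, mirroring the recursive helper in the diff.
      @tailrec
      def findCause[T <: Throwable](e: Throwable)(implicit ct: ClassTag[T]): Option[T] =
        e match {
          case matched: T              => Some(matched)
          case _ if e.getCause == null => None
          case _                       => findCause[T](e.getCause)
        }
    }

    // CauseChain.findCause[IllegalStateException](
    //   new RuntimeException(new SecurityException(new IllegalStateException("boom"))))
    // returns Some(boom) -- a depth-2 cause the old two-level check would have missed.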
*/ diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index f582f9f45..bba999110 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -8,6 +8,7 @@ package org.apache.spark.sql import java.util.concurrent.atomic.AtomicInteger +import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge import play.api.libs.json._ @@ -28,28 +29,17 @@ import org.apache.spark.sql.types._ */ object FlintJob extends Logging with FlintJobExecutor { def main(args: Array[String]): Unit = { - val (queryOption, resultIndex) = args.length match { - case 1 => - (None, args(0)) // Starting from OS 2.13, resultIndex is the only argument - case 2 => - ( - Some(args(0)), - args(1) - ) // Before OS 2.13, there are two arguments, the second one is resultIndex - case _ => - throw new IllegalArgumentException( - "Unsupported number of arguments. Expected 1 or 2 arguments.") - } + val (queryOption, resultIndex) = parseArgs(args) val conf = createSparkConf() val jobType = conf.get("spark.flint.job.type", "batch") - logInfo(s"""Job type is: ${jobType}""") + CustomLogging.logInfo(s"""Job type is: ${jobType}""") conf.set(FlintSparkConf.JOB_TYPE.key, jobType) val dataSource = conf.get("spark.flint.datasource.name", "") val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, ""))) if (query.isEmpty) { - throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.") + logAndThrow(s"Query undefined for the ${jobType} job.") } // https://github.com/opensearch-project/opensearch-spark/issues/138 /* diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 1e5df21e1..665ec5a27 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -7,9 +7,13 @@ package org.apache.spark.sql import java.util.Locale +import com.amazonaws.services.glue.model.{AccessDeniedException, AWSGlueException} import com.amazonaws.services.s3.model.AmazonS3Exception +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.commons.text.StringEscapeUtils.unescapeJava import org.opensearch.flint.core.IRestHighLevelClient +import org.opensearch.flint.core.logging.{CustomLogging, OperationMessage} import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import play.api.libs.json._ @@ -17,6 +21,7 @@ import play.api.libs.json._ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY import org.apache.spark.sql.types._ import org.apache.spark.sql.util._ @@ -24,6 +29,9 @@ import org.apache.spark.sql.util._ trait FlintJobExecutor { this: Logging => + val mapper = new ObjectMapper() + mapper.registerModule(DefaultScalaModule) + var currentTimeProvider: TimeProvider = new RealTimeProvider() 
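FlintJobExecutor now holds a single Jackson ObjectMapper with DefaultScalaModule registered, so the Scala Maps built for error reporting serialize as plain JSON objects. A minimal sketch of that setup and what it produces (object name and sample values are illustrative):

    import com.fasterxml.jackson.databind.ObjectMapper
    import com.fasterxml.jackson.module.scala.DefaultScalaModule

    object JsonSketch {
      // DefaultScalaModule teaches Jackson about Scala collections and Option;
      // without it a scala.collection.immutable.Map is not rendered as a JSON object.
      private val mapper = new ObjectMapper()
      mapper.registerModule(DefaultScalaModule)

      def toJson(details: Map[String, String]): String =
        mapper.writeValueAsString(details)
    }

    // JsonSketch.toJson(Map("Message" -> "Syntax error", "StatusCode" -> "400"))
    //   => {"Message":"Syntax error","StatusCode":"400"}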
var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() var envinromentProvider: EnvironmentProvider = new RealEnvironment() @@ -65,6 +73,9 @@ trait FlintJobExecutor { "sessionId": { "type": "keyword" }, + "jobType": { + "type": "keyword" + }, "updateTime": { "type": "date", "format": "strict_date_time||epoch_millis" @@ -190,6 +201,7 @@ trait FlintJobExecutor { StructField("queryId", StringType, nullable = true), StructField("queryText", StringType, nullable = true), StructField("sessionId", StringType, nullable = true), + StructField("jobType", StringType, nullable = true), // number is not nullable StructField("updateTime", LongType, nullable = false), StructField("queryRunTime", LongType, nullable = true))) @@ -218,6 +230,7 @@ trait FlintJobExecutor { queryId, query, sessionId, + spark.conf.get(FlintSparkConf.JOB_TYPE.key), endTime, endTime - startTime)) @@ -248,6 +261,7 @@ trait FlintJobExecutor { StructField("queryId", StringType, nullable = true), StructField("queryText", StringType, nullable = true), StructField("sessionId", StringType, nullable = true), + StructField("jobType", StringType, nullable = true), // number is not nullable StructField("updateTime", LongType, nullable = false), StructField("queryRunTime", LongType, nullable = true))) @@ -267,6 +281,7 @@ trait FlintJobExecutor { queryId, query, sessionId, + spark.conf.get(FlintSparkConf.JOB_TYPE.key), endTime, endTime - startTime)) @@ -330,7 +345,7 @@ trait FlintJobExecutor { val inputJson = Json.parse(input) val mappingJson = Json.parse(mapping) - compareJson(inputJson, mappingJson) + compareJson(inputJson, mappingJson) || compareJson(mappingJson, inputJson) } def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = { @@ -411,14 +426,20 @@ trait FlintJobExecutor { private def handleQueryException( e: Exception, message: String, - spark: SparkSession, - dataSource: String, - query: String, - queryId: String, - sessionId: String): String = { - val error = s"$message: ${e.getMessage}" - logError(error, e) - error + errorSource: Option[String] = None, + statusCode: Option[Int] = None): String = { + + val errorDetails = Map("Message" -> s"$message: ${e.getMessage}") ++ + errorSource.map("ErrorSource" -> _) ++ + statusCode.map(code => "StatusCode" -> code.toString) + + val errorJson = mapper.writeValueAsString(errorDetails) + + statusCode.foreach { code => + CustomLogging.logError(new OperationMessage("", code), e) + } + + errorJson } def getRootCause(e: Throwable): Throwable = { @@ -426,53 +447,60 @@ trait FlintJobExecutor { else getRootCause(e.getCause) } - def processQueryException( - ex: Exception, - spark: SparkSession, - dataSource: String, - query: String, - queryId: String, - sessionId: String): String = { + /** + * This method converts query exception into error string, which then persist to query result + * metadata + */ + def processQueryException(ex: Exception): String = { getRootCause(ex) match { case r: ParseException => - handleQueryException(r, "Syntax error", spark, dataSource, query, queryId, sessionId) + handleQueryException(r, "Syntax error") case r: AmazonS3Exception => incrementCounter(MetricConstants.S3_ERR_CNT_METRIC) handleQueryException( r, "Fail to read data from S3. 
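The reworked handleQueryException above returns a JSON error string rather than a plain message, and it only attaches ErrorSource and StatusCode when the caller actually knows them, by appending Options to the map. A small self-contained sketch of that optional-field pattern (the service name and codes are placeholders):

    object ErrorDetails {
      // A None contributes no entry, so the resulting map (and the JSON produced
      // from it by the shared mapper) only carries fields that are known.
      def build(message: String,
                errorSource: Option[String],
                statusCode: Option[Int]): Map[String, String] =
        Map("Message" -> message) ++
          errorSource.map("ErrorSource" -> _) ++
          statusCode.map(code => "StatusCode" -> code.toString)
    }

    // ErrorDetails.build("Fail to read data from S3. Cause: access denied", Some("Amazon S3"), Some(403))
    //   => Map("Message" -> ..., "ErrorSource" -> "Amazon S3", "StatusCode" -> "403")
    // ErrorDetails.build("Syntax error: mismatched input", None, None)
    //   => Map("Message" -> "Syntax error: mismatched input")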
Cause", - spark, - dataSource, - query, - queryId, - sessionId) - case r: AnalysisException => + Some(r.getServiceName), + Some(r.getStatusCode)) + case r: AWSGlueException => + incrementCounter(MetricConstants.GLUE_ERR_CNT_METRIC) + // Redact Access denied in AWS Glue service + r match { + case accessDenied: AccessDeniedException => + accessDenied.setErrorMessage( + "Access denied in AWS Glue service. Please check permissions.") + case _ => // No additional action for other types of AWSGlueException + } handleQueryException( r, - "Fail to analyze query. Cause", - spark, - dataSource, - query, - queryId, - sessionId) + "Fail to read data from Glue. Cause", + Some(r.getServiceName), + Some(r.getStatusCode)) + case r: AnalysisException => + handleQueryException(r, "Fail to analyze query. Cause") case r: SparkException => - handleQueryException( - r, - "Spark exception. Cause", - spark, - dataSource, - query, - queryId, - sessionId) + handleQueryException(r, "Spark exception. Cause") case r: Exception => - handleQueryException( - r, - "Fail to run query, cause", - spark, - dataSource, - query, - queryId, - sessionId) + handleQueryException(r, "Fail to run query. Cause") } } + + def parseArgs(args: Array[String]): (Option[String], String) = { + args match { + case Array(resultIndex) => + (None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument + case Array(query, resultIndex) => + ( + Some(query), + resultIndex + ) // Before OS 2.13, there are two arguments, the second one is resultIndex + case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.") + } + } + + def logAndThrow(errorMessage: String): Nothing = { + val t = new IllegalArgumentException(errorMessage) + CustomLogging.logError(t) + throw t + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index b96163693..36432f016 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -21,6 +21,7 @@ import org.opensearch.common.Strings import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.app.FlintInstance.formats import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} @@ -67,7 +68,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val (queryOption, resultIndex) = parseArgs(args) if (Strings.isNullOrEmpty(resultIndex)) { - throw new IllegalArgumentException("resultIndex is not set") + logAndThrow("resultIndex is not set") } // init SparkContext @@ -84,7 +85,7 @@ object FlintREPL extends Logging with FlintJobExecutor { conf.set("spark.sql.defaultCatalog", dataSource) val jobType = conf.get(FlintSparkConf.JOB_TYPE.key, FlintSparkConf.JOB_TYPE.defaultValue.get) - logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""") + CustomLogging.logInfo(s"""Job type is: ${FlintSparkConf.JOB_TYPE.defaultValue.get}""") conf.set(FlintSparkConf.JOB_TYPE.key, jobType) val query = getQuery(queryOption, jobType, conf) @@ -109,10 +110,10 @@ object FlintREPL extends Logging with FlintJobExecutor { val sessionId: Option[String] = 
Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) if (sessionIndex.isEmpty) { - throw new IllegalArgumentException(FlintSparkConf.REQUEST_INDEX.key + " is not set") + logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set") } if (sessionId.isEmpty) { - throw new IllegalArgumentException(FlintSparkConf.SESSION_ID.key + " is not set") + logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set") } val spark = createSparkSession(conf) @@ -238,27 +239,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - def parseArgs(args: Array[String]): (Option[String], String) = { - args.length match { - case 1 => - (None, args(0)) // Starting from OS 2.13, resultIndex is the only argument - case 2 => - ( - Some(args(0)), - args(1) - ) // Before OS 2.13, there are two arguments, the second one is resultIndex - case _ => - throw new IllegalArgumentException( - "Unsupported number of arguments. Expected 1 or 2 arguments.") - } - } - def getQuery(queryOption: Option[String], jobType: String, conf: SparkConf): String = { queryOption.getOrElse { if (jobType.equalsIgnoreCase("streaming")) { val defaultQuery = conf.get(FlintSparkConf.QUERY.key, "") if (defaultQuery.isEmpty) { - throw new IllegalArgumentException("Query undefined for the streaming job.") + logAndThrow("Query undefined for the streaming job.") } unescapeQuery(defaultQuery) } else "" @@ -456,7 +442,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionIndex: String, sessionTimerContext: Timer.Context): Unit = { val error = s"Session error: ${e.getMessage}" - logError(error, e) + CustomLogging.logError(error, e) val flintInstance = getExistingFlintInstance(osClient, sessionIndex, sessionId) .getOrElse(createFailedFlintInstance(applicationId, jobId, sessionId, jobStartTime, error)) @@ -476,7 +462,9 @@ object FlintREPL extends Logging with FlintJobExecutor { Option(getResponse.getSourceAsMap) .map(FlintInstance.deserializeFromMap) case Failure(exception) => - logError(s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", exception) + CustomLogging.logError( + s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", + exception) None case _ => None } @@ -545,19 +533,8 @@ object FlintREPL extends Logging with FlintJobExecutor { currentTimeProvider) } - def processQueryException( - ex: Exception, - spark: SparkSession, - dataSource: String, - flintCommand: FlintCommand, - sessionId: String): String = { - val error = super.processQueryException( - ex, - spark, - dataSource, - flintCommand.query, - flintCommand.queryId, - sessionId) + def processQueryException(ex: Exception, flintCommand: FlintCommand): String = { + val error = super.processQueryException(ex) flintCommand.fail() flintCommand.error = Some(error) error @@ -656,7 +633,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // or invalid catalog (e.g., we are operating on data not defined in provided data source) case e: Exception => val error = s"""Fail to write result of ${flintCommand}, cause: ${e.getMessage}""" - logError(error, e) + CustomLogging.logError(error, e) flintCommand.fail() updateSessionIndex(flintCommand, flintSessionIndexUpdater) recordStatementStateChange(flintCommand, statementTimerContext) @@ -683,7 +660,6 @@ object FlintREPL extends Logging with FlintJobExecutor { * actions that require the computation of results that need to be collected or stored. 
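Most of the argument and configuration checks in FlintREPL now go through logAndThrow, so each failure is emitted through CustomLogging before the IllegalArgumentException propagates. Declaring logAndThrow to return Nothing is what also lets it sit in expression positions, such as the `case _ =>` branch of parseArgs, because Nothing conforms to every expected type. A minimal sketch of the pattern, with a plain stderr logger standing in for CustomLogging:

    object Args {
      // Nothing is a subtype of every type, so this call type-checks wherever an
      // (Option[String], String) -- or anything else -- is expected.
      def logAndThrow(errorMessage: String): Nothing = {
        val t = new IllegalArgumentException(errorMessage)
        System.err.println(s"ERROR: $errorMessage") // CustomLogging.logError(t) in the real code
        throw t
      }

      def parseArgs(args: Array[String]): (Option[String], String) = args match {
        case Array(resultIndex)        => (None, resultIndex)
        case Array(query, resultIndex) => (Some(query), resultIndex)
        case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.")
      }
    }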
*/ spark.sparkContext.cancelJobGroup(flintCommand.queryId) - logError(error) Some( handleCommandFailureAndGetFailedData( spark, @@ -716,15 +692,11 @@ object FlintREPL extends Logging with FlintJobExecutor { queryWaitTimeMillis)) } catch { case e: TimeoutException => - handleCommandTimeout( - spark, - dataSource, - s"Executing ${flintCommand.query} timed out", - flintCommand, - sessionId, - startTime) + val error = s"Executing ${flintCommand.query} timed out" + CustomLogging.logError(error, e) + handleCommandTimeout(spark, dataSource, error, flintCommand, sessionId, startTime) case e: Exception => - val error = processQueryException(e, spark, dataSource, flintCommand.query, "", "") + val error = processQueryException(e, flintCommand) Some( handleCommandFailureAndGetFailedData( spark, @@ -780,10 +752,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } catch { case e: TimeoutException => val error = s"Getting the mapping of index $resultIndex timed out" + CustomLogging.logError(error, e) dataToWrite = handleCommandTimeout(spark, dataSource, error, flintCommand, sessionId, startTime) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" + CustomLogging.logError(error, e) dataToWrite = Some( handleCommandFailureAndGetFailedData( spark, @@ -1014,13 +988,13 @@ object FlintREPL extends Logging with FlintJobExecutor { case ie: InterruptedException => // Preserve the interrupt status Thread.currentThread().interrupt() - logError("HeartBeatUpdater task was interrupted", ie) + CustomLogging.logError("HeartBeatUpdater task was interrupted", ie) incrementCounter( MetricConstants.REQUEST_METADATA_HEARTBEAT_FAILED_METRIC ) // Record heartbeat failure metric // maybe due to invalid sequence number or primary term case e: Exception => - logWarning( + CustomLogging.logWarning( s"""Fail to update the last update time of the flint instance ${sessionId}""", e) incrementCounter( @@ -1080,7 +1054,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } catch { // still proceed since we are not sure what happened (e.g., OpenSearch cluster may be unresponsive) case e: Exception => - logError(s"""Fail to find id ${sessionId} from session index.""", e) + CustomLogging.logError(s"""Fail to find id ${sessionId} from session index.""", e) true } } @@ -1125,10 +1099,13 @@ object FlintREPL extends Logging with FlintJobExecutor { if e.getCause != null && e.getCause.isInstanceOf[ConnectException] => retries += 1 val delay = initialDelay * math.pow(2, retries - 1).toLong - logError(s"Fail to connect to OpenSearch cluster. Retrying in $delay...", e) + CustomLogging.logError( + s"Fail to connect to OpenSearch cluster. 
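The reconnection loop here doubles the wait on every attempt: with retries counted from 1, the sleep before the next try is initialDelay * 2^(retries - 1). A small sketch of the resulting schedule (initialDelay and maxRetries are whatever the caller supplies, as in the REPL):

    import scala.concurrent.duration._

    object Backoff {
      // For initialDelay = 1.second and maxRetries = 3 this yields 1s, 2s, 4s.
      def delays(initialDelay: FiniteDuration, maxRetries: Int): Seq[FiniteDuration] =
        (1 to maxRetries).map(retry => initialDelay * math.pow(2, retry - 1).toLong)
    }

Non-connection failures still re-raise immediately; the only change in that branch is that they are logged through CustomLogging first.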
Retrying in $delay...", + e) Thread.sleep(delay.toMillis) case e: Exception => + CustomLogging.logError(e) throw e } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 5969f0573..6421c7d57 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -66,7 +66,7 @@ case class JobOperator( dataToWrite = Some( getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) case e: Exception => - val error = processQueryException(e, spark, dataSource, query, "", "") + val error = processQueryException(e) dataToWrite = Some( getFailedData(spark, dataSource, error, "", query, "", startTime, currentTimeProvider)) } finally { diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala index aceb9468f..19f596e31 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala @@ -6,6 +6,7 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider} @@ -13,7 +14,7 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() - + spark.conf.set(FlintSparkConf.JOB_TYPE.key, "streaming") // Define input dataframe val inputSchema = StructType( Seq( @@ -38,6 +39,7 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { StructField("queryId", StringType, nullable = true), StructField("queryText", StringType, nullable = true), StructField("sessionId", StringType, nullable = true), + StructField("jobType", StringType, nullable = true), StructField("updateTime", LongType, nullable = false), StructField("queryRunTime", LongType, nullable = false))) @@ -61,6 +63,7 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { "10", "select 1", "20", + "streaming", currentTime, queryRunTime)) val expected: DataFrame = @@ -82,7 +85,7 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { } test("test isSuperset") { - // note in input false has enclosed double quotes, while mapping just has false + // Note in input false has enclosed double quotes, while mapping just has false val input = """{"dynamic":"false","properties":{"result":{"type":"object"},"schema":{"type":"object"}, |"applicationId":{"type":"keyword"},"jobRunId":{ @@ -90,12 +93,17 @@ class FlintJobTest extends SparkFunSuite with JobMatchers { |"error":{"type":"text"}}} |""".stripMargin val mapping = - """{"dynamic":false,"properties":{"result":{"type":"object"},"schema":{"type":"object"}, - |"jobRunId":{"type":"keyword"},"applicationId":{ - |"type":"keyword"},"dataSourceName":{"type":"keyword"},"status":{"type":"keyword"}}} + """{"dynamic":"false","properties":{"result":{"type":"object"},"schema":{"type":"object"}, "jobType":{"type": "keyword"}, + |"applicationId":{"type":"keyword"},"jobRunId":{ + |"type":"keyword"},"dataSourceName":{"type":"keyword"},"status":{"type":"keyword"}, |"error":{"type":"text"}}} |""".stripMargin + + // Assert that input is a superset of mapping assert(FlintJob.isSuperset(input, mapping)) + + 
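With the symmetric compareJson change, isSuperset now accepts containment in either direction; the assertion above exercises input covering mapping, and the one just after it exercises the reverse. A hedged play-json sketch of the idea (illustrative only, not the project's actual compareJson):

    import play.api.libs.json.{JsObject, JsValue, Json}

    object MappingCompat {
      // jsonA "covers" jsonB if every field of jsonB appears in jsonA with an
      // equal value, recursing into nested objects.
      def covers(jsonA: JsValue, jsonB: JsValue): Boolean = (jsonA, jsonB) match {
        case (a: JsObject, b: JsObject) =>
          b.fields.forall { case (key, bValue) =>
            a.value.get(key).exists(aValue => covers(aValue, bValue))
          }
        case (a, b) => a == b
      }

      // Accept the existing result-index mapping if either side covers the other.
      def isCompatible(input: String, mapping: String): Boolean = {
        val (i, m) = (Json.parse(input), Json.parse(mapping))
        covers(i, m) || covers(m, i)
      }
    }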
// Assert that mapping is a superset of input + assert(FlintJob.isSuperset(mapping, input)) } test("default streaming query maxExecutors is 10") { diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index ea789c161..1a6aea4f4 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -15,6 +15,7 @@ import scala.concurrent.duration._ import scala.concurrent.duration.{Duration, MINUTES} import scala.reflect.runtime.universe.TypeTag +import com.amazonaws.services.glue.model.AccessDeniedException import com.codahale.metrics.Timer import org.mockito.ArgumentMatchersSugar import org.mockito.Mockito._ @@ -216,6 +217,7 @@ class FlintREPLTest StructField("queryId", StringType, nullable = true), StructField("queryText", StringType, nullable = true), StructField("sessionId", StringType, nullable = true), + StructField("jobType", StringType, nullable = true), StructField("updateTime", LongType, nullable = false), StructField("queryRunTime", LongType, nullable = false))) @@ -235,10 +237,11 @@ class FlintREPLTest "10", "select 1", "20", + "interactive", currentTime, queryRunTime)) val spark = SparkSession.builder().appName("Test").master("local").getOrCreate() - + spark.conf.set(FlintSparkConf.JOB_TYPE.key, FlintSparkConf.JOB_TYPE.defaultValue.get) val expected = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) @@ -436,6 +439,31 @@ class FlintREPLTest assert(result) } + test("processQueryException should handle exceptions, fail the command, and set the error") { + val exception = new AccessDeniedException( + "Unable to verify existence of default database: com.amazonaws.services.glue.model.AccessDeniedException: " + + "User: ****** is not authorized to perform: glue:GetDatabase on resource: ****** " + + "because no identity-based policy allows the glue:GetDatabase action") + exception.setStatusCode(400) + exception.setErrorCode("AccessDeniedException") + exception.setServiceName("AWSGlue") + + val mockFlintCommand = mock[FlintCommand] + val expectedError = ( + """{"Message":"Fail to read data from Glue. Cause: Access denied in AWS Glue service. Please check permissions. 
(Service: AWSGlue; """ + + """Status Code: 400; Error Code: AccessDeniedException; Request ID: null; Proxy: null)",""" + + """"ErrorSource":"AWSGlue","StatusCode":"400"}""" + ) + + val result = FlintREPL.processQueryException(exception, mockFlintCommand) + + result shouldEqual expectedError + verify(mockFlintCommand).fail() + verify(mockFlintCommand).error = Some(expectedError) + + assert(result == expectedError) + } + test("Doc Exists and excludeJobIds is an ArrayList Containing JobId") { val sessionId = "session123" val jobId = "jobABC" @@ -547,10 +575,13 @@ class FlintREPLTest test("executeAndHandle should handle TimeoutException properly") { val mockSparkSession = mock[SparkSession] val mockFlintCommand = mock[FlintCommand] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) + .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) // val mockExecutionContextExecutor: ExecutionContextExecutor = mock[ExecutionContextExecutor] val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - try { val dataSource = "someDataSource" val sessionId = "someSessionId" @@ -596,6 +627,10 @@ class FlintREPLTest test("executeAndHandle should handle ParseException properly") { val mockSparkSession = mock[SparkSession] + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) + .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) val flintCommand = new FlintCommand( "Running", @@ -606,7 +641,6 @@ class FlintREPLTest None) val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor("flint-repl", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - try { val dataSource = "someDataSource" val sessionId = "someSessionId" @@ -1020,6 +1054,11 @@ class FlintREPLTest val sparkContext = mock[SparkContext] when(mockSparkSession.sparkContext).thenReturn(sparkContext) + val mockConf = mock[RuntimeConfig] + when(mockSparkSession.conf).thenReturn(mockConf) + when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) + .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) + when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) val flintSessionIndexUpdater = mock[OpenSearchUpdater]
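Because the result schema now carries jobType, every production path that builds a result row reads spark.conf.get(FlintSparkConf.JOB_TYPE.key), which is why these REPL tests stub RuntimeConfig on the mocked SparkSession. A hypothetical shared helper the tests could use for that stubbing (the helper name is illustrative and not part of this PR):

    import org.apache.spark.sql.{RuntimeConfig, SparkSession}
    import org.apache.spark.sql.flint.config.FlintSparkConf
    import org.mockito.Mockito.{mock, when}

    object TestStubs {
      // Stub spark.conf so code under test can read the job type without a real session.
      def stubJobType(spark: SparkSession,
                      jobType: String = FlintSparkConf.JOB_TYPE.defaultValue.get): Unit = {
        val conf = mock(classOf[RuntimeConfig])
        when(spark.conf).thenReturn(conf)
        when(conf.get(FlintSparkConf.JOB_TYPE.key)).thenReturn(jobType)
      }
    }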