
Commit

Merge branch 'main' into fix-show-index-wrong-result
dai-chen committed May 14, 2024
2 parents cb17ec2 + 422dae7 commit 60fd1ee
Showing 10 changed files with 192 additions and 131 deletions.
2 changes: 2 additions & 0 deletions build.sbt
@@ -208,6 +208,8 @@ lazy val sparkSqlApplication = (project in file("spark-sql-application"))
libraryDependencies ++= deps(sparkVersion),
libraryDependencies ++= Seq(
"com.typesafe.play" %% "play-json" % "2.9.2",
"com.amazonaws" % "aws-java-sdk-glue" % "1.12.568" % "provided"
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
// handle AmazonS3Exception
"com.amazonaws" % "aws-java-sdk-s3" % "1.12.568" % "provided"
// the transitive jackson.core dependency conflicts with existing scala
IRestHighLevelClient.java
@@ -28,6 +28,8 @@
import org.opensearch.client.indices.PutMappingRequest;
import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
import org.opensearch.client.RequestOptions;
import org.opensearch.flint.core.logging.CustomLogging;
import org.opensearch.flint.core.logging.OperationMessage;
import org.opensearch.flint.core.metrics.MetricsUtil;

import java.io.Closeable;
@@ -90,7 +92,11 @@ static void recordOperationSuccess(String metricNamePrefix) {
static void recordOperationFailure(String metricNamePrefix, Exception e) {
OpenSearchException openSearchException = extractOpenSearchException(e);
int statusCode = openSearchException != null ? openSearchException.status().getStatus() : 500;

if (openSearchException != null) {
CustomLogging.logError(new OperationMessage("OpenSearch Operation failed.", statusCode), openSearchException);
} else {
CustomLogging.logError("OpenSearch Operation failed with an exception.", e);
}
if (statusCode == 403) {
String forbiddenErrorMetricName = metricNamePrefix + ".403.count";
MetricsUtil.incrementCounter(forbiddenErrorMetricName);
@@ -104,15 +110,16 @@ static void recordOperationFailure(String metricNamePrefix, Exception e) {
* Extracts an OpenSearchException from the given Throwable.
* Checks if the Throwable is an instance of OpenSearchException or caused by one.
*
-   * @param ex the exception to be checked
+   * @param e the exception to be checked
   * @return the extracted OpenSearchException, or null if not found
   */
-  private static OpenSearchException extractOpenSearchException(Throwable ex) {
-    if (ex instanceof OpenSearchException) {
-      return (OpenSearchException) ex;
-    } else if (ex.getCause() instanceof OpenSearchException) {
-      return (OpenSearchException) ex.getCause();
+  static OpenSearchException extractOpenSearchException(Throwable e) {
+    if (e instanceof OpenSearchException) {
+      return (OpenSearchException) e;
+    } else if (e.getCause() == null) {
+      return null;
+    } else {
+      return extractOpenSearchException(e.getCause());
    }
-    return null;
  }
}
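Illustrative sketch (not part of this commit): the rewritten extractOpenSearchException above walks the full cause chain recursively instead of checking only the immediate cause. The standalone Scala sketch below (hypothetical names, any Scala 2.12+ compiler) shows why that matters for exceptions nested more than one level deep.

import scala.annotation.tailrec
import scala.reflect.ClassTag

object CauseChainDemo extends App {
  // Recursive walk over Throwable.getCause, mirroring the new extractOpenSearchException logic.
  @tailrec
  def findCause[T <: Throwable](e: Throwable)(implicit ct: ClassTag[T]): Option[T] = e match {
    case t: T                    => Some(t)
    case _ if e.getCause == null => None
    case _                       => findCause[T](e.getCause)
  }

  // The target is two causes deep: the recursive walk finds it, while a check of only
  // the exception and its direct cause would have returned null.
  val nested = new RuntimeException(
    "outer",
    new RuntimeException("middle", new IllegalStateException("inner")))

  println(findCause[IllegalStateException](nested)) // Some(java.lang.IllegalStateException: inner)
}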
CustomLogging.java
@@ -44,9 +44,9 @@ public class CustomLogging {
private static final Map<String, BiConsumer<String, Throwable>> logLevelActions = new HashMap<>();

static {
-    String[] parts = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", UNKNOWN + ":" + UNKNOWN).split(":");
+    DOMAIN_NAME = System.getenv().getOrDefault("FLINT_CLUSTER_NAME", UNKNOWN + ":" + UNKNOWN);
+    String[] parts = DOMAIN_NAME.split(":");
     CLIENT_ID = parts.length == 2 ? parts[0] : UNKNOWN;
-    DOMAIN_NAME = parts.length == 2 ? parts[1] : UNKNOWN;

logLevelActions.put("DEBUG", logger::debug);
logLevelActions.put("INFO", logger::info);
@@ -78,10 +78,6 @@ private static String convertToJson(Map<String, Object> logEventMap) {
* @return A map representation of the log event.
*/
protected static Map<String, Object> constructLogEventMap(String level, Object content, Throwable throwable) {
-    if (content == null) {
-      throw new IllegalArgumentException("Log message must not be null");
-    }
-
Map<String, Object> logEventMap = new LinkedHashMap<>();
Map<String, Object> body = new LinkedHashMap<>();
constructMessageBody(content, body);
@@ -105,6 +101,11 @@ protected static Map<String, Object> constructLogEventMap(String level, Object c
}

private static void constructMessageBody(Object content, Map<String, Object> body) {
if (content == null) {
body.put("message", "");
return;
}

if (content instanceof Message) {
Message message = (Message) content;
body.put("message", message.getFormattedMessage());
@@ -151,6 +152,10 @@ public static void logError(Object message) {
log("ERROR", message, null);
}

public static void logError(Throwable throwable) {
log("ERROR", "", throwable);
}

public static void logError(Object message, Throwable throwable) {
log("ERROR", message, throwable);
}
MetricConstants.java
@@ -27,6 +27,11 @@ public final class MetricConstants {
*/
public static final String S3_ERR_CNT_METRIC = "s3.error.count";

/**
* Metric name for counting the errors encountered with Amazon Glue operations.
*/
public static final String GLUE_ERR_CNT_METRIC = "glue.error.count";

/**
* Metric name for counting the number of sessions currently running.
*/
FlintJob.scala
@@ -8,6 +8,7 @@ package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import org.opensearch.flint.core.logging.CustomLogging
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
import play.api.libs.json._
@@ -28,28 +29,17 @@ import org.apache.spark.sql.types._
*/
object FlintJob extends Logging with FlintJobExecutor {
def main(args: Array[String]): Unit = {
-    val (queryOption, resultIndex) = args.length match {
-      case 1 =>
-        (None, args(0)) // Starting from OS 2.13, resultIndex is the only argument
-      case 2 =>
-        (
-          Some(args(0)),
-          args(1)
-        ) // Before OS 2.13, there are two arguments, the second one is resultIndex
-      case _ =>
-        throw new IllegalArgumentException(
-          "Unsupported number of arguments. Expected 1 or 2 arguments.")
-    }
+    val (queryOption, resultIndex) = parseArgs(args)

val conf = createSparkConf()
val jobType = conf.get("spark.flint.job.type", "batch")
-    logInfo(s"""Job type is: ${jobType}""")
+    CustomLogging.logInfo(s"""Job type is: ${jobType}""")
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.")
logAndThrow(s"Query undefined for the ${jobType} job.")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
FlintJobExecutor.scala
@@ -7,23 +7,31 @@ package org.apache.spark.sql

import java.util.Locale

import com.amazonaws.services.glue.model.{AccessDeniedException, AWSGlueException}
import com.amazonaws.services.s3.model.AmazonS3Exception
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.logging.{CustomLogging, OperationMessage}
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import play.api.libs.json._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY
import org.apache.spark.sql.types._
import org.apache.spark.sql.util._

trait FlintJobExecutor {
this: Logging =>

val mapper = new ObjectMapper()
mapper.registerModule(DefaultScalaModule)

var currentTimeProvider: TimeProvider = new RealTimeProvider()
var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory()
var envinromentProvider: EnvironmentProvider = new RealEnvironment()
@@ -65,6 +73,9 @@ trait FlintJobExecutor {
"sessionId": {
"type": "keyword"
},
"jobType": {
"type": "keyword"
},
"updateTime": {
"type": "date",
"format": "strict_date_time||epoch_millis"
@@ -190,6 +201,7 @@ trait FlintJobExecutor {
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))
@@ -218,6 +230,7 @@ trait FlintJobExecutor {
queryId,
query,
sessionId,
spark.conf.get(FlintSparkConf.JOB_TYPE.key),
endTime,
endTime - startTime))

@@ -248,6 +261,7 @@ trait FlintJobExecutor {
StructField("queryId", StringType, nullable = true),
StructField("queryText", StringType, nullable = true),
StructField("sessionId", StringType, nullable = true),
StructField("jobType", StringType, nullable = true),
// number is not nullable
StructField("updateTime", LongType, nullable = false),
StructField("queryRunTime", LongType, nullable = true)))
@@ -267,6 +281,7 @@ trait FlintJobExecutor {
queryId,
query,
sessionId,
spark.conf.get(FlintSparkConf.JOB_TYPE.key),
endTime,
endTime - startTime))

@@ -330,7 +345,7 @@ trait FlintJobExecutor {
val inputJson = Json.parse(input)
val mappingJson = Json.parse(mapping)

-    compareJson(inputJson, mappingJson)
+    compareJson(inputJson, mappingJson) || compareJson(mappingJson, inputJson)
}

def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = {
@@ -411,68 +426,81 @@ trait FlintJobExecutor {
private def handleQueryException(
e: Exception,
message: String,
-      spark: SparkSession,
-      dataSource: String,
-      query: String,
-      queryId: String,
-      sessionId: String): String = {
-    val error = s"$message: ${e.getMessage}"
-    logError(error, e)
-    error
+      errorSource: Option[String] = None,
+      statusCode: Option[Int] = None): String = {
+
+    val errorDetails = Map("Message" -> s"$message: ${e.getMessage}") ++
+      errorSource.map("ErrorSource" -> _) ++
+      statusCode.map(code => "StatusCode" -> code.toString)
+
+    val errorJson = mapper.writeValueAsString(errorDetails)
+
+    statusCode.foreach { code =>
+      CustomLogging.logError(new OperationMessage("", code), e)
+    }
+
+    errorJson
}

def getRootCause(e: Throwable): Throwable = {
if (e.getCause == null) e
else getRootCause(e.getCause)
}

-  def processQueryException(
-      ex: Exception,
-      spark: SparkSession,
-      dataSource: String,
-      query: String,
-      queryId: String,
-      sessionId: String): String = {
+  /**
+   * This method converts query exception into error string, which then persist to query result
+   * metadata
+   */
+  def processQueryException(ex: Exception): String = {
     getRootCause(ex) match {
       case r: ParseException =>
-        handleQueryException(r, "Syntax error", spark, dataSource, query, queryId, sessionId)
+        handleQueryException(r, "Syntax error")
case r: AmazonS3Exception =>
incrementCounter(MetricConstants.S3_ERR_CNT_METRIC)
handleQueryException(
r,
"Fail to read data from S3. Cause",
-          spark,
-          dataSource,
-          query,
-          queryId,
-          sessionId)
-      case r: AnalysisException =>
+          Some(r.getServiceName),
+          Some(r.getStatusCode))
+      case r: AWSGlueException =>
+        incrementCounter(MetricConstants.GLUE_ERR_CNT_METRIC)
+        // Redact Access denied in AWS Glue service
+        r match {
+          case accessDenied: AccessDeniedException =>
+            accessDenied.setErrorMessage(
+              "Access denied in AWS Glue service. Please check permissions.")
+          case _ => // No additional action for other types of AWSGlueException
+        }
         handleQueryException(
           r,
-          "Fail to analyze query. Cause",
-          spark,
-          dataSource,
-          query,
-          queryId,
-          sessionId)
+          "Fail to read data from Glue. Cause",
+          Some(r.getServiceName),
+          Some(r.getStatusCode))
+      case r: AnalysisException =>
+        handleQueryException(r, "Fail to analyze query. Cause")
case r: SparkException =>
-        handleQueryException(
-          r,
-          "Spark exception. Cause",
-          spark,
-          dataSource,
-          query,
-          queryId,
-          sessionId)
+        handleQueryException(r, "Spark exception. Cause")
case r: Exception =>
-        handleQueryException(
-          r,
-          "Fail to run query, cause",
-          spark,
-          dataSource,
-          query,
-          queryId,
-          sessionId)
+        handleQueryException(r, "Fail to run query. Cause")
}
}

def parseArgs(args: Array[String]): (Option[String], String) = {
args match {
case Array(resultIndex) =>
(None, resultIndex) // Starting from OS 2.13, resultIndex is the only argument
case Array(query, resultIndex) =>
(
Some(query),
resultIndex
) // Before OS 2.13, there are two arguments, the second one is resultIndex
case _ => logAndThrow("Unsupported number of arguments. Expected 1 or 2 arguments.")
}
}

def logAndThrow(errorMessage: String): Nothing = {
val t = new IllegalArgumentException(errorMessage)
CustomLogging.logError(t)
throw t
}
}
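Illustrative sketch (not part of this commit): with the reworked handleQueryException above, query failures are persisted as a JSON error string built from a Scala Map and serialized with Jackson's DefaultScalaModule. The standalone sketch below (hypothetical helper and example values; assumes jackson-databind and jackson-module-scala are on the classpath) shows roughly what an S3 failure with a service name and status code would serialize to.

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

object ErrorJsonDemo extends App {
  val mapper = new ObjectMapper()
  mapper.registerModule(DefaultScalaModule)

  // Mirrors the Map built in handleQueryException: Message is always present,
  // ErrorSource and StatusCode only when the caller supplies them.
  def errorJson(
      message: String,
      cause: Throwable,
      errorSource: Option[String] = None,
      statusCode: Option[Int] = None): String = {
    val details = Map("Message" -> s"$message: ${cause.getMessage}") ++
      errorSource.map("ErrorSource" -> _) ++
      statusCode.map(code => "StatusCode" -> code.toString)
    mapper.writeValueAsString(details)
  }

  val e = new RuntimeException("Access Denied")
  println(errorJson("Fail to read data from S3. Cause", e, Some("Amazon S3"), Some(403)))
  // {"Message":"Fail to read data from S3. Cause: Access Denied","ErrorSource":"Amazon S3","StatusCode":"403"}
}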