From 8cb5d40b197678d44131a861e0996310cf1e5c14 Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Mon, 20 May 2024 17:29:50 -0700
Subject: [PATCH] Move to new tx support trait

Signed-off-by: Chen Dai
---
 .../opensearch/flint/spark/FlintSpark.scala   | 33 +------
 .../spark/FlintSparkTransactionSupport.scala  | 97 +++++++++++++++++++
 2 files changed, 101 insertions(+), 29 deletions(-)
 create mode 100644 flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkTransactionSupport.scala

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index 696a66749..6d9a38ab7 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGN
 /**
  * Flint Spark integration API entrypoint.
  */
-class FlintSpark(val spark: SparkSession) extends Logging {
+class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport with Logging {
 
   /** Flint spark configuration */
   private val flintSparkConf: FlintSparkConf =
@@ -42,7 +42,8 @@ class FlintSpark(val spark: SparkSession) extends Logging {
         IGNORE_DOC_ID_COLUMN.optionKey -> "true").asJava)
 
   /** Flint client for low-level index operation */
-  private val flintClient: FlintClient = FlintClientBuilder.build(flintSparkConf.flintOptions())
+  override protected val flintClient: FlintClient =
+    FlintClientBuilder.build(flintSparkConf.flintOptions())
 
   /** Required by json4s parse function */
   implicit val formats: Formats = Serialization.formats(NoTypeHints) + SkippingKindSerializer
@@ -51,7 +52,7 @@ class FlintSpark(val spark: SparkSession) extends Logging {
    * Data source name. Assign empty string in case of backward compatibility. TODO: remove this in
    * future
    */
-  private val dataSourceName: String =
+  override protected val dataSourceName: String =
     spark.conf.getOption("spark.flint.datasource.name").getOrElse("")
 
   /** Flint Spark index monitor */
@@ -432,30 +433,4 @@ class FlintSpark(val spark: SparkSession) extends Logging {
         indexRefresh.start(spark, flintSparkConf)
       })
   }
-
-  private def withTransaction[T](
-      indexName: String,
-      operationName: String,
-      forceInit: Boolean = false)(opBlock: OptimisticTransaction[T] => T): T = {
-    logInfo(s"Starting index operation [$operationName] with forceInit=$forceInit")
-    try {
-      val tx: OptimisticTransaction[T] =
-        flintClient.startTransaction(indexName, dataSourceName, forceInit)
-
-      val result = opBlock(tx)
-      logInfo(s"Index operation [$operationName] complete")
-      result
-    } catch {
-      case e: Exception =>
-        val detailedMessage =
-          s"Failed to execute index operation [$operationName] caused by ${e.getMessage}"
-        logError(detailedMessage, e)
-
-        // Re-throw directly if runtime exception or wrap it
-        e match {
-          case re: RuntimeException => throw re
-          case _ => throw new IllegalStateException(detailedMessage, e)
-        }
-    }
-  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkTransactionSupport.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkTransactionSupport.scala
new file mode 100644
index 000000000..0af04bed3
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkTransactionSupport.scala
@@ -0,0 +1,97 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.opensearch.flint.core.FlintClient
+import org.opensearch.flint.core.metadata.log.OptimisticTransaction
+
+import org.apache.spark.internal.Logging
+
+/**
+ * Provides transaction support with proper error handling and logging capabilities.
+ *
+ * @note
+ *   This trait requires the mixing class to extend Spark's `Logging` to utilize its logging
+ *   functionalities. Meanwhile it needs to provide `FlintClient` and data source name so this
+ *   trait can help create transaction context.
+ */
+trait FlintSparkTransactionSupport { self: Logging =>
+
+  /** Abstract FlintClient that needs to be defined in the mixing class */
+  protected def flintClient: FlintClient
+
+  /** Abstract data source name that needs to be defined in the mixing class */
+  protected def dataSourceName: String
+
+  /**
+   * Executes a block of code within a transaction context, handling and logging errors
+   * appropriately. This method logs the start and completion of the transaction and captures any
+   * exceptions that occur, enriching them with detailed error messages before re-throwing.
+   *
+   * @param indexName
+   *   the name of the index on which the operation is performed
+   * @param opName
+   *   the name of the operation, used for logging
+   * @param forceInit
+   *   a boolean flag indicating whether to force the initialization of the transaction
+   * @param opBlock
+   *   the operation block to execute within the transaction context, which takes an
+   *   `OptimisticTransaction` and returns a value of type `T`
+   * @tparam T
+   *   the type of the result produced by the operation block
+   * @return
+   *   the result of the operation block
+   * @throws IllegalStateException
+   *   if starting the transaction or running the operation block throws any `Exception`; the
+   *   message carries the root cause message and the original exception is attached as the cause
+   */
+  def withTransaction[T](indexName: String, opName: String, forceInit: Boolean = false)(
+      opBlock: OptimisticTransaction[T] => T): T = {
+    logInfo(s"Starting index operation [$opName] with forceInit=$forceInit")
+    try {
+      // Create transaction (only has side effect if forceInit is true)
+      val tx: OptimisticTransaction[T] =
+        flintClient.startTransaction(indexName, dataSourceName, forceInit)
+
+      val result = opBlock(tx)
+      logInfo(s"Index operation [$opName] complete")
+      result
+    } catch {
+      case e: Exception =>
+        // Extract and add root cause message to final error message
+        val rootCauseMessage = extractRootCause(e)
+        val detailedMessage =
+          s"Failed to execute index operation [$opName] caused by: $rootCauseMessage"
+        logError(detailedMessage, e)
+
+        // Re-throw with new detailed error message, keeping the original exception as the
+        // cause so its stack trace is still available to callers
+        throw new IllegalStateException(detailedMessage, e)
+    }
+  }
+
+  /**
+   * Finds the last exception in the cause chain of `e` and returns a printable description.
+   *
+   * @param e
+   *   the exception whose cause chain is walked
+   * @return
+   *   the root cause's localized message, else its message, else its `toString`
+   */
+  private def extractRootCause(e: Throwable): String = {
+    // Stop when there is no further cause or the cause points back at itself
+    @scala.annotation.tailrec
+    def rootOf(t: Throwable): Throwable = {
+      val next = t.getCause
+      if (next == null || (next eq t)) t else rootOf(next)
+    }
+
+    val root = rootOf(e)
+    Option(root.getLocalizedMessage)
+      .orElse(Option(root.getMessage))
+      .getOrElse(root.toString)
+  }
+}