diff --git a/.github/workflows/test-and-build-workflow.yml b/.github/workflows/test-and-build-workflow.yml index 7cae33f76..3c06acb61 100644 --- a/.github/workflows/test-and-build-workflow.yml +++ b/.github/workflows/test-and-build-workflow.yml @@ -25,5 +25,8 @@ jobs: - name: Integ Test run: sbt integtest/test + - name: Unit Test + run: sbt test + - name: Style check run: sbt scalafmtCheckAll diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index b4271360c..d50c0002e 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -6,6 +6,8 @@ package org.opensearch.flint.core; import java.util.List; + +import org.opensearch.client.RestHighLevelClient; import org.opensearch.flint.core.metadata.FlintMetadata; import org.opensearch.flint.core.storage.FlintReader; import org.opensearch.flint.core.storage.FlintWriter; @@ -71,4 +73,10 @@ public interface FlintClient { * @return {@link FlintWriter} */ FlintWriter createWriter(String indexName); + + /** + * Create {@link RestHighLevelClient}. + * @return {@link RestHighLevelClient} + */ + public RestHighLevelClient createClient(); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 4badfe8f4..ff2761856 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -163,7 +163,7 @@ public FlintWriter createWriter(String indexName) { return new OpenSearchWriter(createClient(), toLowercase(indexName), options.getRefreshPolicy()); } - private RestHighLevelClient createClient() { + @Override public RestHighLevelClient createClient() { RestClientBuilder restClientBuilder = RestClient.builder(new HttpHost(options.getHost(), options.getPort(), options.getScheme())); diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 0a853dc92..aeaa57499 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -12,7 +12,11 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} import org.opensearch.ExceptionsHelper -import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} +import org.opensearch.client.{RequestOptions, RestHighLevelClient} +import org.opensearch.cluster.metadata.MappingMetadata +import org.opensearch.common.settings.Settings +import org.opensearch.common.xcontent.XContentType +import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} import org.opensearch.flint.core.metadata.FlintMetadata import play.api.libs.json._ @@ -51,17 +55,19 @@ object FlintJob extends Logging { var dataToWrite: Option[DataFrame] = None try { - // flintClient needs spark session to be created first. Otherwise, we will have connection + // osClient needs spark session to be created first. Otherwise, we will have connection // exception from EMR-S to OS. - val flintClient = FlintClientBuilder.build(FlintSparkConf().flintOptions()) + val osClient = new OSClient(FlintSparkConf().flintOptions()) val futureMappingCheck = Future { - checkAndCreateIndex(flintClient, resultIndex) + checkAndCreateIndex(osClient, resultIndex) } val data = executeQuery(spark, query, dataSource) - val (correctMapping, error) = - ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) - dataToWrite = Some(if (correctMapping) data else getFailedData(spark, dataSource, error)) + val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) + dataToWrite = Some(mappingCheckResult match { + case Right(_) => data + case Left(error) => getFailedData(spark, dataSource, error) + }) } catch { case e: TimeoutException => val error = "Future operations timed out" @@ -238,7 +244,7 @@ object FlintJob extends Logging { compareJson(inputJson, mappingJson) } - def checkAndCreateIndex(flintClient: FlintClient, resultIndex: String): (Boolean, String) = { + def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = { // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, val mapping = """{ @@ -271,39 +277,26 @@ object FlintJob extends Logging { }""".stripMargin try { - val existingSchema = flintClient.getIndexMetadata(resultIndex).getContent + val existingSchema = osClient.getIndexMetadata(resultIndex) if (!isSuperset(existingSchema, mapping)) { - (false, s"The mapping of $resultIndex is incorrect.") + Left(s"The mapping of $resultIndex is incorrect.") } else { - (true, "") + Right(()) } } catch { case e: IllegalStateException if e.getCause().getMessage().contains("index_not_found_exception") => - handleIndexNotFoundException(flintClient, resultIndex, mapping) + osClient.createIndex(resultIndex, mapping) match { + case Right(_) => Right(()) + case Left(errorMsg) => Left(errorMsg) + } case e: Exception => val error = "Failed to verify existing mapping" logError(error, e) - (false, error) + Left(error) } } - def handleIndexNotFoundException( - flintClient: FlintClient, - resultIndex: String, - mapping: String): (Boolean, String) = { - try { - logInfo(s"create $resultIndex") - flintClient.createIndex(resultIndex, FlintMetadata.apply(mapping)) - logInfo(s"create $resultIndex successfully") - (true, "") - } catch { - case e: Exception => - val error = s"Failed to create result index $resultIndex" - logError(error, e) - (false, error) - } - } def executeQuery(spark: SparkSession, query: String, dataSource: String): DataFrame = { // Execute SQL query val result: DataFrame = spark.sql(query) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala new file mode 100644 index 000000000..e40e8dc7c --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.client.RequestOptions +import org.opensearch.client.indices.{CreateIndexRequest, GetIndexRequest, GetIndexResponse} +import org.opensearch.client.indices.CreateIndexRequest +import org.opensearch.common.xcontent.XContentType +import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions} + +import org.apache.spark.internal.Logging + +class OSClient(val flintOptions: FlintOptions) extends Logging { + + def getIndexMetadata(osIndexName: String): String = { + + using(FlintClientBuilder.build(flintOptions).createClient()) { client => + val request = new GetIndexRequest(osIndexName) + try { + val response = client.indices.get(request, RequestOptions.DEFAULT) + response.getMappings.get(osIndexName).source.string + } catch { + case e: Exception => + throw new IllegalStateException( + s"Failed to get OpenSearch index mapping for $osIndexName", + e) + } + } + } + + /** + * Create a new index with given mapping. + * + * @param osIndexName + * the name of the index + * @param mapping + * the mapping of the index + * @return + * use Either for representing success or failure. A Right value indicates success, while a + * Left value indicates an error. + */ + def createIndex(osIndexName: String, mapping: String): Either[String, Unit] = { + logInfo(s"create $osIndexName") + + using(FlintClientBuilder.build(flintOptions).createClient()) { client => + val request = new CreateIndexRequest(osIndexName) + request.mapping(mapping, XContentType.JSON) + + try { + client.indices.create(request, RequestOptions.DEFAULT) + logInfo(s"create $osIndexName successfully") + Right(()) + } catch { + case e: Exception => + val error = s"Failed to create result index $osIndexName" + logError(error, e) + Left(error) + } + } + } + + /** + * the loan pattern to manage resource. + * + * @param resource + * the resource to be managed + * @param f + * the function to be applied to the resource + * @tparam A + * the type of the resource + * @tparam B + * the type of the result + * @return + * the result of the function + */ + def using[A <: AutoCloseable, B](resource: A)(f: A => B): B = { + try { + f(resource) + } finally { + // client is guaranteed to be non-null + resource.close() + } + } + +}