diff --git a/README.md b/README.md index 017b4a1c2..f9568838e 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # OpenSearch Flint -OpenSearch Flint is ... It consists of two modules: +OpenSearch Flint is ... It consists of four modules: - `flint-core`: a module that contains Flint specification and client. +- `flint-commons`: a module that provides a shared library of utilities and common functionalities, designed to easily extend Flint's capabilities. - `flint-spark-integration`: a module that provides Spark integration for Flint and derived dataset based on it. - `ppl-spark-integration`: a module that provides PPL query execution on top of Spark See [PPL repository](https://github.com/opensearch-project/piped-processing-language). diff --git a/build.sbt b/build.sbt index f12a19647..9d419decb 100644 --- a/build.sbt +++ b/build.sbt @@ -48,7 +48,7 @@ lazy val commonSettings = Seq( // running `scalafmtAll` includes all subprojects under root lazy val root = (project in file(".")) - .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest) + .aggregate(flintCommons, flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest) .disablePlugins(AssemblyPlugin) .settings(name := "flint", publish / skip := true) @@ -84,6 +84,37 @@ lazy val flintCore = (project in file("flint-core")) libraryDependencies ++= deps(sparkVersion), publish / skip := true) +lazy val flintCommons = (project in file("flint-commons")) + .settings( + commonSettings, + name := "flint-commons", + scalaVersion := scala212, + libraryDependencies ++= Seq( + "org.scalactic" %% "scalactic" % "3.2.15" % "test", + "org.scalatest" %% "scalatest" % "3.2.15" % "test", + "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", + "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test", + ), + libraryDependencies ++= deps(sparkVersion), + publish / skip := true, + assembly / test := (Test / test).value, + assembly / assemblyOption ~= { + _.withIncludeScala(false) + }, + assembly / assemblyMergeStrategy := { + case PathList(ps@_*) if ps.last endsWith ("module-info.class") => + MergeStrategy.discard + case PathList("module-info.class") => MergeStrategy.discard + case PathList("META-INF", "versions", xs@_, "module-info.class") => + MergeStrategy.discard + case x => + val oldStrategy = (assembly / assemblyMergeStrategy).value + oldStrategy(x) + }, + ) + .enablePlugins(AssemblyPlugin) + + lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) .enablePlugins(AssemblyPlugin, Antlr4Plugin) .settings( @@ -121,7 +152,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) assembly / test := (Test / test).value) lazy val flintSparkIntegration = (project in file("flint-spark-integration")) - .dependsOn(flintCore) + .dependsOn(flintCore, flintCommons) .enablePlugins(AssemblyPlugin, Antlr4Plugin) .settings( commonSettings, @@ -166,7 +197,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) // Test assembly package with integration test. lazy val integtest = (project in file("integ-test")) - .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test") + .dependsOn(flintCommons % "test->test", flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test") .settings( commonSettings, name := "integ-test", diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala new file mode 100644 index 000000000..109bf654a --- /dev/null +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/ContextualDataStore.scala @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.data + +/** + * Provides a mutable map to store and retrieve contextual data using key-value pairs. + */ +trait ContextualDataStore { + + /** Holds the contextual data as key-value pairs. */ + var context: Map[String, Any] = Map.empty + + /** + * Adds a key-value pair to the context map. + */ + def setContextValue(key: String, value: Any): Unit = { + context += (key -> value) + } + + /** + * Retrieves the value associated with a key from the context map. + */ + def getContextValue(key: String): Option[Any] = { + context.get(key) + } +} diff --git a/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala new file mode 100644 index 000000000..dbe73e9a5 --- /dev/null +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/FlintStatement.scala @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.data + +import org.json4s.{Formats, NoTypeHints} +import org.json4s.JsonAST.JString +import org.json4s.native.JsonMethods.parse +import org.json4s.native.Serialization + +object StatementStates { + val RUNNING = "running" + val SUCCESS = "success" + val FAILED = "failed" + val WAITING = "waiting" +} + +/** + * Represents a statement processed in the Flint job. + * + * @param state + * The current state of the statement. + * @param query + * SQL-like query string that the statement will execute. + * @param statementId + * Unique identifier for the type of statement. + * @param queryId + * Unique identifier for the query. + * @param submitTime + * Timestamp when the statement was submitted. + * @param error + * Optional error message if the statement fails. + * @param statementContext + * Additional context for the statement as key-value pairs. + */ +class FlintStatement( + var state: String, + val query: String, + // statementId is the statement type doc id + val statementId: String, + val queryId: String, + val submitTime: Long, + var error: Option[String] = None, + statementContext: Map[String, Any] = Map.empty[String, Any]) + extends ContextualDataStore { + context = statementContext + + def running(): Unit = state = StatementStates.RUNNING + def complete(): Unit = state = StatementStates.SUCCESS + def fail(): Unit = state = StatementStates.FAILED + def isRunning: Boolean = state == StatementStates.RUNNING + def isComplete: Boolean = state == StatementStates.SUCCESS + def isFailed: Boolean = state == StatementStates.FAILED + def isWaiting: Boolean = state == StatementStates.WAITING + + // Does not include context, which could contain sensitive information. + override def toString: String = + s"FlintStatement(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" +} + +object FlintStatement { + + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + def deserialize(statement: String): FlintStatement = { + val meta = parse(statement) + val state = (meta \ "state").extract[String] + val query = (meta \ "query").extract[String] + val statementId = (meta \ "statementId").extract[String] + val queryId = (meta \ "queryId").extract[String] + val submitTime = (meta \ "submitTime").extract[Long] + val maybeError: Option[String] = (meta \ "error") match { + case JString(str) => Some(str) + case _ => None + } + + new FlintStatement(state, query, statementId, queryId, submitTime, maybeError) + } + + def serialize(flintStatement: FlintStatement): String = { + // we only need to modify state and error + Serialization.write( + Map("state" -> flintStatement.state, "error" -> flintStatement.error.getOrElse(""))) + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala b/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala similarity index 76% rename from flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala rename to flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala index 9911a3b6c..c5eaee4f1 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala +++ b/flint-commons/src/main/scala/org/opensearch/flint/data/InteractiveSession.scala @@ -3,42 +3,78 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.app +package org.opensearch.flint.data import java.util.{Map => JavaMap} import scala.collection.JavaConverters._ -import scala.collection.mutable import org.json4s.{Formats, JNothing, JNull, NoTypeHints} import org.json4s.JsonAST.{JArray, JString} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization -// lastUpdateTime is added to FlintInstance to track the last update time of the instance. Its unit is millisecond. -class FlintInstance( +object SessionStates { + val RUNNING = "running" + val COMPLETE = "complete" + val FAILED = "failed" + val WAITING = "waiting" +} + +/** + * Represents an interactive session for job and state management. + * + * @param applicationId + * Unique identifier for the EMR-S application. + * @param jobId + * Identifier for the specific EMR-S job. + * @param sessionId + * Unique session identifier. + * @param state + * Current state of the session. + * @param lastUpdateTime + * Timestamp of the last update. + * @param jobStartTime + * Start time of the job. + * @param excludedJobIds + * List of job IDs that are excluded. + * @param error + * Optional error message. + * @param sessionContext + * Additional context for the session. + */ +class InteractiveSession( val applicationId: String, val jobId: String, - // sessionId is the session type doc id val sessionId: String, var state: String, val lastUpdateTime: Long, val jobStartTime: Long = 0, val excludedJobIds: Seq[String] = Seq.empty[String], - val error: Option[String] = None) { + val error: Option[String] = None, + sessionContext: Map[String, Any] = Map.empty[String, Any]) + extends ContextualDataStore { + context = sessionContext // Initialize the context from the constructor + + def isRunning: Boolean = state == SessionStates.RUNNING + def isComplete: Boolean = state == SessionStates.COMPLETE + def isFailed: Boolean = state == SessionStates.FAILED + def isWaiting: Boolean = state == SessionStates.WAITING + override def toString: String = { val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") val errorStr = error.getOrElse("None") + // Does not include context, which could contain sensitive information. s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)" } } -object FlintInstance { +object InteractiveSession { implicit val formats: Formats = Serialization.formats(NoTypeHints) - def deserialize(job: String): FlintInstance = { + def deserialize(job: String): InteractiveSession = { val meta = parse(job) val applicationId = (meta \ "applicationId").extract[String] val state = (meta \ "state").extract[String] @@ -64,7 +100,7 @@ object FlintInstance { case _ => None } - new FlintInstance( + new InteractiveSession( applicationId, jobId, sessionId, @@ -75,7 +111,7 @@ object FlintInstance { maybeError) } - def deserializeFromMap(source: JavaMap[String, AnyRef]): FlintInstance = { + def deserializeFromMap(source: JavaMap[String, AnyRef]): InteractiveSession = { // Since we are dealing with JavaMap, we convert it to a Scala mutable Map for ease of use. val scalaSource = source.asScala @@ -105,7 +141,7 @@ object FlintInstance { } // Construct a new FlintInstance with the extracted values. - new FlintInstance( + new InteractiveSession( applicationId, jobId, sessionId, @@ -133,7 +169,10 @@ object FlintInstance { * @return * serialized Flint session */ - def serialize(job: FlintInstance, currentTime: Long, includeJobId: Boolean = true): String = { + def serialize( + job: InteractiveSession, + currentTime: Long, + includeJobId: Boolean = true): String = { val baseMap = Map( "type" -> "session", "sessionId" -> job.sessionId, @@ -159,7 +198,7 @@ object FlintInstance { Serialization.write(resultMap) } - def serializeWithoutJobId(job: FlintInstance, currentTime: Long): String = { + def serializeWithoutJobId(job: InteractiveSession, currentTime: Long): String = { serialize(job, currentTime, includeJobId = false) } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala b/flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala similarity index 78% rename from flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala rename to flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala index 8ece6ba8a..f69fe70b4 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/app/FlintInstanceTest.scala +++ b/flint-commons/src/test/scala/org/opensearch/flint/data/InteractiveSessionTest.scala @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.flint.app +package org.opensearch.flint.data import java.util.{HashMap => JavaHashMap} @@ -11,12 +11,12 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite -class FlintInstanceTest extends SparkFunSuite with Matchers { +class InteractiveSessionTest extends SparkFunSuite with Matchers { test("deserialize should correctly parse a FlintInstance with excludedJobIds from JSON") { val json = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"excludeJobIds":["job-101","job-202"]}""" - val instance = FlintInstance.deserialize(json) + val instance = InteractiveSession.deserialize(json) instance.applicationId shouldBe "app-123" instance.jobId shouldBe "job-456" @@ -30,7 +30,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { test("serialize should correctly produce JSON from a FlintInstance with excludedJobIds") { val excludedJobIds = Seq("job-101", "job-202") - val instance = new FlintInstance( + val instance = new InteractiveSession( "app-123", "job-456", "session-789", @@ -39,7 +39,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { 1620000001000L, excludedJobIds) val currentTime = System.currentTimeMillis() - val json = FlintInstance.serializeWithoutJobId(instance, currentTime) + val json = InteractiveSession.serializeWithoutJobId(instance, currentTime) json should include(""""applicationId":"app-123"""") json should not include (""""jobId":"job-456"""") @@ -56,7 +56,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { test("deserialize should correctly handle an empty excludedJobIds field in JSON") { val jsonWithoutExcludedJobIds = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"jobStartTime":1620000001000}""" - val instance = FlintInstance.deserialize(jsonWithoutExcludedJobIds) + val instance = InteractiveSession.deserialize(jsonWithoutExcludedJobIds) instance.excludedJobIds shouldBe empty } @@ -64,13 +64,13 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { test("deserialize should correctly handle error field in JSON") { val jsonWithError = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"FAILED","lastUpdateTime":1620000000000,"jobStartTime":1620000001000,"error":"Some error occurred"}""" - val instance = FlintInstance.deserialize(jsonWithError) + val instance = InteractiveSession.deserialize(jsonWithError) instance.error shouldBe Some("Some error occurred") } test("serialize should include error when present in FlintInstance") { - val instance = new FlintInstance( + val instance = new InteractiveSession( "app-123", "job-456", "session-789", @@ -80,7 +80,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { Seq.empty[String], Some("Some error occurred")) val currentTime = System.currentTimeMillis() - val json = FlintInstance.serializeWithoutJobId(instance, currentTime) + val json = InteractiveSession.serializeWithoutJobId(instance, currentTime) json should include(""""error":"Some error occurred"""") } @@ -96,7 +96,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { sourceMap.put("excludeJobIds", java.util.Arrays.asList("job2", "job3")) sourceMap.put("error", "An error occurred") - val result = FlintInstance.deserializeFromMap(sourceMap) + val result = InteractiveSession.deserializeFromMap(sourceMap) assert(result.applicationId == "app1") assert(result.jobId == "job1") @@ -114,7 +114,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { sourceMap.put("lastUpdateTime", "1234567890") assertThrows[ClassCastException] { - FlintInstance.deserializeFromMap(sourceMap) + InteractiveSession.deserializeFromMap(sourceMap) } } @@ -129,7 +129,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { sourceMap.put("excludeJobIds", java.util.Arrays.asList("job2", "job3")) sourceMap.put("error", "An error occurred") - val result = FlintInstance.deserializeFromMap(sourceMap) + val result = InteractiveSession.deserializeFromMap(sourceMap) assert(result.applicationId == "app1") assert(result.jobId == "job1") @@ -144,7 +144,7 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { test("deserialize should correctly parse a FlintInstance without jobStartTime from JSON") { val json = """{"applicationId":"app-123","jobId":"job-456","sessionId":"session-789","state":"RUNNING","lastUpdateTime":1620000000000,"excludeJobIds":["job-101","job-202"]}""" - val instance = FlintInstance.deserialize(json) + val instance = InteractiveSession.deserialize(json) instance.applicationId shouldBe "app-123" instance.jobId shouldBe "job-456" @@ -155,4 +155,23 @@ class FlintInstanceTest extends SparkFunSuite with Matchers { instance.excludedJobIds should contain allOf ("job-101", "job-202") instance.error shouldBe None } + + test("sessionContext should add/get key-value pairs to/from the context") { + val session = + new InteractiveSession("app-123", "job-456", "session-789", "RUNNING", 1620000000000L) + session.context shouldBe empty + + session.setContextValue("key1", "value1") + session.setContextValue("key2", 42) + + session.context should contain("key1" -> "value1") + session.context should contain("key2" -> 42) + + session.getContextValue("key1") shouldBe Some("value1") + session.getContextValue("key2") shouldBe Some(42) + session.getContextValue("key3") shouldBe None // Test for a key that does not exist + + session.setContextValue("key1", "updatedValue") + session.getContextValue("key1") shouldBe Some("updatedValue") + } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index ee78aa512..b9ef05851 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -5,10 +5,9 @@ package org.opensearch.flint.core; -import java.util.List; +import java.util.Map; import org.opensearch.flint.core.metadata.FlintMetadata; -import org.opensearch.flint.core.metadata.log.OptimisticTransaction; import org.opensearch.flint.core.storage.FlintReader; import org.opensearch.flint.core.storage.FlintWriter; @@ -18,27 +17,6 @@ */ public interface FlintClient { - /** - * Start a new optimistic transaction. - * - * @param indexName index name - * @param dataSourceName TODO: read from elsewhere in future - * @return transaction handle - */ - OptimisticTransaction startTransaction(String indexName, String dataSourceName); - - /** - * - * Start a new optimistic transaction. - * - * @param indexName index name - * @param dataSourceName TODO: read from elsewhere in future - * @param forceInit forceInit create empty translog if not exist. - * @return transaction handle - */ - OptimisticTransaction startTransaction(String indexName, String dataSourceName, - boolean forceInit); - /** * Create a Flint index with the metadata given. * @@ -59,9 +37,10 @@ OptimisticTransaction startTransaction(String indexName, String dataSourc * Retrieve all metadata for Flint index whose name matches the given pattern. * * @param indexNamePattern index name pattern - * @return all matched index metadata + * @return map where the keys are the matched index names, and the values are + * corresponding index metadata */ - List getAllIndexMetadata(String indexNamePattern); + Map getAllIndexMetadata(String indexNamePattern); /** * Retrieve metadata in a Flint index. diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogService.java b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogService.java new file mode 100644 index 000000000..a356a456f --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogService.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metadata.log; + +import java.util.Optional; + +/** + * Flint metadata log service provides API for metadata log related operations on a Flint index + * regardless of underlying storage. + */ +public interface FlintMetadataLogService { + + /** + * Start a new optimistic transaction. + * + * @param indexName index name + * @param forceInit force init transaction and create empty metadata log if not exist + * @return transaction handle + */ + OptimisticTransaction startTransaction(String indexName, boolean forceInit); + + /** + * Start a new optimistic transaction. + * + * @param indexName index name + * @return transaction handle + */ + default OptimisticTransaction startTransaction(String indexName) { + return startTransaction(indexName, false); + } + + /** + * Get metadata log for index. + * + * @param indexName index name + * @return optional metadata log + */ + Optional> getIndexMetadataLog(String indexName); + + /** + * Record heartbeat timestamp for index streaming job. + * + * @param indexName index name + */ + void recordHeartbeat(String indexName); +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogServiceBuilder.java b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogServiceBuilder.java new file mode 100644 index 000000000..3e2556f57 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/log/FlintMetadataLogServiceBuilder.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metadata.log; + +import org.opensearch.flint.core.FlintOptions; +import org.opensearch.flint.core.storage.FlintOpenSearchMetadataLogService; + +/** + * {@link FlintMetadataLogService} builder. + */ +public class FlintMetadataLogServiceBuilder { + public static FlintMetadataLogService build(FlintOptions options) { + return new FlintOpenSearchMetadataLogService(options); + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index c1b884241..36db4a040 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -7,31 +7,17 @@ import static org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import java.io.IOException; -import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; import java.util.stream.Collectors; -import org.apache.http.HttpHost; -import org.apache.http.auth.AuthScope; -import org.apache.http.auth.UsernamePasswordCredentials; -import org.apache.http.client.CredentialsProvider; -import org.apache.http.impl.client.BasicCredentialsProvider; -import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.client.RequestOptions; -import org.opensearch.client.RestClient; -import org.opensearch.client.RestClientBuilder; -import org.opensearch.client.RestHighLevelClient; import org.opensearch.client.indices.CreateIndexRequest; import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.client.indices.GetIndexResponse; @@ -45,20 +31,13 @@ import org.opensearch.flint.core.FlintClient; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.core.IRestHighLevelClient; -import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor; -import org.opensearch.flint.core.http.RetryableHttpAsyncClient; import org.opensearch.flint.core.metadata.FlintMetadata; -import org.opensearch.flint.core.metadata.log.DefaultOptimisticTransaction; -import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry; -import org.opensearch.flint.core.metadata.log.OptimisticTransaction; -import org.opensearch.flint.core.RestHighLevelClientWrapper; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; import scala.Option; -import scala.Some; /** * Flint client implementation for OpenSearch storage. @@ -67,8 +46,6 @@ public class FlintOpenSearchClient implements FlintClient { private static final Logger LOG = Logger.getLogger(FlintOpenSearchClient.class.getName()); - private static final String SERVICE_NAME = "es"; - /** * {@link NamedXContentRegistry} from {@link SearchModule} used for construct {@link QueryBuilder} from DSL query string. @@ -85,47 +62,12 @@ public class FlintOpenSearchClient implements FlintClient { private final static Set INVALID_INDEX_NAME_CHARS = Set.of(' ', ',', ':', '"', '+', '/', '\\', '|', '?', '#', '>', '<'); - /** - * Metadata log index name prefix - */ - public final static String META_LOG_NAME_PREFIX = ".query_execution_request"; - private final FlintOptions options; public FlintOpenSearchClient(FlintOptions options) { this.options = options; } - @Override - public OptimisticTransaction startTransaction( - String indexName, String dataSourceName, boolean forceInit) { - LOG.info("Starting transaction on index " + indexName + " and data source " + dataSourceName); - String metaLogIndexName = constructMetaLogIndexName(dataSourceName); - try (IRestHighLevelClient client = createClient()) { - if (client.doesIndexExist(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { - LOG.info("Found metadata log index " + metaLogIndexName); - } else { - if (forceInit) { - createIndex(metaLogIndexName, FlintMetadataLogEntry.QUERY_EXECUTION_REQUEST_MAPPING(), - Some.apply(FlintMetadataLogEntry.QUERY_EXECUTION_REQUEST_SETTINGS())); - } else { - String errorMsg = "Metadata log index not found " + metaLogIndexName; - LOG.warning(errorMsg); - throw new IllegalStateException(errorMsg); - } - } - return new DefaultOptimisticTransaction<>(dataSourceName, - new FlintOpenSearchMetadataLog(this, indexName, metaLogIndexName)); - } catch (IOException e) { - throw new IllegalStateException("Failed to check if index metadata log index exists " + metaLogIndexName, e); - } - } - - @Override - public OptimisticTransaction startTransaction(String indexName, String dataSourceName) { - return startTransaction(indexName, dataSourceName, false); - } - @Override public void createIndex(String indexName, FlintMetadata metadata) { LOG.info("Creating Flint index " + indexName + " with metadata " + metadata); @@ -159,7 +101,7 @@ public boolean exists(String indexName) { } @Override - public List getAllIndexMetadata(String indexNamePattern) { + public Map getAllIndexMetadata(String indexNamePattern) { LOG.info("Fetching all Flint index metadata for pattern " + indexNamePattern); String osIndexNamePattern = sanitizeIndexName(indexNamePattern); try (IRestHighLevelClient client = createClient()) { @@ -167,11 +109,13 @@ public List getAllIndexMetadata(String indexNamePattern) { GetIndexResponse response = client.getIndex(request, RequestOptions.DEFAULT); return Arrays.stream(response.getIndices()) - .map(index -> constructFlintMetadata( - index, - response.getMappings().get(index).source().toString(), - response.getSettings().get(index).toString())) - .collect(Collectors.toList()); + .collect(Collectors.toMap( + index -> index, + index -> FlintMetadata.apply( + response.getMappings().get(index).source().toString(), + response.getSettings().get(index).toString() + ) + )); } catch (Exception e) { throw new IllegalStateException("Failed to get Flint index metadata for " + osIndexNamePattern, e); } @@ -187,7 +131,7 @@ public FlintMetadata getIndexMetadata(String indexName) { MappingMetadata mapping = response.getMappings().get(osIndexName); Settings settings = response.getSettings().get(osIndexName); - return constructFlintMetadata(indexName, mapping.source().string(), settings.toString()); + return FlintMetadata.apply(mapping.source().string(), settings.toString()); } catch (Exception e) { throw new IllegalStateException("Failed to get Flint index metadata for " + osIndexName, e); } @@ -254,102 +198,7 @@ public FlintWriter createWriter(String indexName) { @Override public IRestHighLevelClient createClient() { - RestClientBuilder - restClientBuilder = - RestClient.builder(new HttpHost(options.getHost(), options.getPort(), options.getScheme())); - - // SigV4 support - if (options.getAuth().equals(FlintOptions.SIGV4_AUTH)) { - // Use DefaultAWSCredentialsProviderChain by default. - final AtomicReference customAWSCredentialsProvider = - new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); - String customProviderClass = options.getCustomAwsCredentialsProvider(); - if (!Strings.isNullOrEmpty(customProviderClass)) { - instantiateProvider(customProviderClass, customAWSCredentialsProvider); - } - - // Set metadataAccessAWSCredentialsProvider to customAWSCredentialsProvider by default for backwards compatibility - // unless a specific metadata access provider class name is provided - String metadataAccessProviderClass = options.getMetadataAccessAwsCredentialsProvider(); - final AtomicReference metadataAccessAWSCredentialsProvider = - new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); - - String metaLogIndexName = constructMetaLogIndexName(options.getDataSourceName()); - String systemIndexName = Strings.isNullOrEmpty(options.getSystemIndexName()) ? metaLogIndexName : options.getSystemIndexName(); - - if (Strings.isNullOrEmpty(metadataAccessProviderClass)) { - metadataAccessAWSCredentialsProvider.set(customAWSCredentialsProvider.get()); - } else { - instantiateProvider(metadataAccessProviderClass, metadataAccessAWSCredentialsProvider); - } - - restClientBuilder.setHttpClientConfigCallback(builder -> { - HttpAsyncClientBuilder delegate = builder.addInterceptorLast( - new ResourceBasedAWSRequestSigningApacheInterceptor( - SERVICE_NAME, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName)); - return RetryableHttpAsyncClient.builder(delegate, options); - } - ); - } else if (options.getAuth().equals(FlintOptions.BASIC_AUTH)) { - CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider.setCredentials( - AuthScope.ANY, - new UsernamePasswordCredentials(options.getUsername(), options.getPassword())); - restClientBuilder.setHttpClientConfigCallback(builder -> { - HttpAsyncClientBuilder delegate = builder.setDefaultCredentialsProvider(credentialsProvider); - return RetryableHttpAsyncClient.builder(delegate, options); - }); - } else { - restClientBuilder.setHttpClientConfigCallback(delegate -> - RetryableHttpAsyncClient.builder(delegate, options)); - } - - final RequestConfigurator callback = new RequestConfigurator(options); - restClientBuilder.setRequestConfigCallback(callback); - - return new RestHighLevelClientWrapper(new RestHighLevelClient(restClientBuilder)); - } - - /** - * Attempts to instantiate the AWS credential provider using reflection. - */ - private void instantiateProvider(String providerClass, AtomicReference provider) { - try { - Class awsCredentialsProviderClass = Class.forName(providerClass); - Constructor ctor = awsCredentialsProviderClass.getDeclaredConstructor(); - ctor.setAccessible(true); - provider.set((AWSCredentialsProvider) ctor.newInstance()); - } catch (Exception e) { - throw new RuntimeException("Failed to instantiate AWSCredentialsProvider: " + providerClass, e); - } - } - - /* - * Constructs Flint metadata with latest metadata log entry attached if it's available. - * It relies on FlintOptions to provide data source name. - */ - private FlintMetadata constructFlintMetadata(String indexName, String mapping, String settings) { - String dataSourceName = options.getDataSourceName(); - String metaLogIndexName = dataSourceName.isEmpty() ? META_LOG_NAME_PREFIX - : META_LOG_NAME_PREFIX + "_" + dataSourceName; - Optional latest = Optional.empty(); - - try (IRestHighLevelClient client = createClient()) { - if (client.doesIndexExist(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { - LOG.info("Found metadata log index " + metaLogIndexName); - FlintOpenSearchMetadataLog metadataLog = - new FlintOpenSearchMetadataLog(this, indexName, metaLogIndexName); - latest = metadataLog.getLatest(); - } - } catch (IOException e) { - throw new IllegalStateException("Failed to check if index metadata log index exists " + metaLogIndexName, e); - } - - if (latest.isEmpty()) { - return FlintMetadata.apply(mapping, settings); - } else { - return FlintMetadata.apply(mapping, settings, latest.get()); - } + return OpenSearchClientUtils.createClient(options); } /* @@ -388,8 +237,4 @@ private String sanitizeIndexName(String indexName) { String encoded = percentEncode(indexName); return toLowercase(encoded); } - - private String constructMetaLogIndexName(String dataSourceName) { - return dataSourceName.isEmpty() ? META_LOG_NAME_PREFIX : META_LOG_NAME_PREFIX + "_" + dataSourceName; - } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java index 7195ae177..6aea13436 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java @@ -23,7 +23,7 @@ import org.opensearch.client.RequestOptions; import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.flint.core.FlintClient; +import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.core.IRestHighLevelClient; import org.opensearch.flint.core.metadata.log.FlintMetadataLog; import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry; @@ -37,23 +37,23 @@ public class FlintOpenSearchMetadataLog implements FlintMetadataLog getLatest() { LOG.info("Fetching latest log entry with id " + latestId); - try (IRestHighLevelClient client = flintClient.createClient()) { + try (IRestHighLevelClient client = createOpenSearchClient()) { GetResponse response = - client.get(new GetRequest(metaLogIndexName, latestId), RequestOptions.DEFAULT); + client.get(new GetRequest(metadataLogIndexName, latestId), RequestOptions.DEFAULT); if (response.isExists()) { FlintMetadataLogEntry latest = new FlintMetadataLogEntry( @@ -102,10 +102,10 @@ public Optional getLatest() { @Override public void purge() { LOG.info("Purging log entry with id " + latestId); - try (IRestHighLevelClient client = flintClient.createClient()) { + try (IRestHighLevelClient client = createOpenSearchClient()) { DeleteResponse response = client.delete( - new DeleteRequest(metaLogIndexName, latestId), RequestOptions.DEFAULT); + new DeleteRequest(metadataLogIndexName, latestId), RequestOptions.DEFAULT); LOG.info("Purged log entry with result " + response.getResult()); } catch (Exception e) { @@ -129,7 +129,7 @@ private FlintMetadataLogEntry createLogEntry(FlintMetadataLogEntry logEntry) { return writeLogEntry(logEntryWithId, client -> client.index( new IndexRequest() - .index(metaLogIndexName) + .index(metadataLogIndexName) .id(logEntryWithId.id()) .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) .source(logEntryWithId.toJson(), XContentType.JSON), @@ -140,7 +140,7 @@ private FlintMetadataLogEntry updateLogEntry(FlintMetadataLogEntry logEntry) { LOG.info("Updating log entry " + logEntry); return writeLogEntry(logEntry, client -> client.update( - new UpdateRequest(metaLogIndexName, logEntry.id()) + new UpdateRequest(metadataLogIndexName, logEntry.id()) .doc(logEntry.toJson(), XContentType.JSON) .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) .setIfSeqNo(logEntry.seqNo()) @@ -151,7 +151,7 @@ private FlintMetadataLogEntry updateLogEntry(FlintMetadataLogEntry logEntry) { private FlintMetadataLogEntry writeLogEntry( FlintMetadataLogEntry logEntry, CheckedFunction write) { - try (IRestHighLevelClient client = flintClient.createClient()) { + try (IRestHighLevelClient client = createOpenSearchClient()) { // Write (create or update) the doc DocWriteResponse response = write.apply(client); @@ -173,14 +173,18 @@ private FlintMetadataLogEntry writeLogEntry( } private boolean exists() { - LOG.info("Checking if Flint index exists " + metaLogIndexName); - try (IRestHighLevelClient client = flintClient.createClient()) { - return client.doesIndexExist(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT); + LOG.info("Checking if Flint index exists " + metadataLogIndexName); + try (IRestHighLevelClient client = createOpenSearchClient()) { + return client.doesIndexExist(new GetIndexRequest(metadataLogIndexName), RequestOptions.DEFAULT); } catch (IOException e) { - throw new IllegalStateException("Failed to check if Flint index exists " + metaLogIndexName, e); + throw new IllegalStateException("Failed to check if Flint index exists " + metadataLogIndexName, e); } } + private IRestHighLevelClient createOpenSearchClient() { + return OpenSearchClientUtils.createClient(options); + } + @FunctionalInterface public interface CheckedFunction { R apply(T t) throws IOException; diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLogService.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLogService.java new file mode 100644 index 000000000..f04a3bc67 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLogService.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import java.io.IOException; +import java.util.Optional; +import java.util.logging.Logger; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.indices.CreateIndexRequest; +import org.opensearch.client.indices.GetIndexRequest; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.flint.core.FlintOptions; +import org.opensearch.flint.core.IRestHighLevelClient; +import org.opensearch.flint.core.metadata.log.DefaultOptimisticTransaction; +import org.opensearch.flint.core.metadata.log.FlintMetadataLog; +import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry; +import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState$; +import org.opensearch.flint.core.metadata.log.FlintMetadataLogService; +import org.opensearch.flint.core.metadata.log.OptimisticTransaction; + +/** + * Flint metadata log service implementation for OpenSearch storage. + */ +public class FlintOpenSearchMetadataLogService implements FlintMetadataLogService { + + private static final Logger LOG = Logger.getLogger(FlintOpenSearchMetadataLogService.class.getName()); + + public final static String METADATA_LOG_INDEX_NAME_PREFIX = ".query_execution_request"; + + private final FlintOptions options; + private final String dataSourceName; + private final String metadataLogIndexName; + + public FlintOpenSearchMetadataLogService(FlintOptions options) { + this.options = options; + this.dataSourceName = options.getDataSourceName(); + this.metadataLogIndexName = constructMetadataLogIndexName(); + } + + @Override + public OptimisticTransaction startTransaction(String indexName, boolean forceInit) { + LOG.info("Starting transaction on index " + indexName + " and data source " + dataSourceName); + Optional> metadataLog = getIndexMetadataLog(indexName, forceInit); + if (metadataLog.isEmpty()) { + String errorMsg = "Metadata log index not found " + metadataLogIndexName; + throw new IllegalStateException(errorMsg); + } + return new DefaultOptimisticTransaction<>(dataSourceName, metadataLog.get()); + } + + @Override + public Optional> getIndexMetadataLog(String indexName) { + return getIndexMetadataLog(indexName, false); + } + + @Override + public void recordHeartbeat(String indexName) { + startTransaction(indexName) + .initialLog(latest -> latest.state() == IndexState$.MODULE$.REFRESHING()) + .finalLog(latest -> latest) // timestamp will update automatically + .commit(latest -> null); + } + + private Optional> getIndexMetadataLog(String indexName, boolean initIfNotExist) { + LOG.info("Getting metadata log for index " + indexName + " and data source " + dataSourceName); + try (IRestHighLevelClient client = createOpenSearchClient()) { + if (client.doesIndexExist(new GetIndexRequest(metadataLogIndexName), RequestOptions.DEFAULT)) { + LOG.info("Found metadata log index " + metadataLogIndexName); + } else { + if (initIfNotExist) { + initIndexMetadataLog(); + } else { + String errorMsg = "Metadata log index not found " + metadataLogIndexName; + LOG.warning(errorMsg); + return Optional.empty(); + } + } + return Optional.of(new FlintOpenSearchMetadataLog(options, indexName, metadataLogIndexName)); + } catch (IOException e) { + throw new IllegalStateException("Failed to check if index metadata log index exists " + metadataLogIndexName, e); + } + } + + private void initIndexMetadataLog() { + LOG.info("Initializing metadata log index " + metadataLogIndexName); + try (IRestHighLevelClient client = createOpenSearchClient()) { + CreateIndexRequest request = new CreateIndexRequest(metadataLogIndexName); + request.mapping(FlintMetadataLogEntry.QUERY_EXECUTION_REQUEST_MAPPING(), XContentType.JSON); + request.settings(FlintMetadataLogEntry.QUERY_EXECUTION_REQUEST_SETTINGS(), XContentType.JSON); + client.createIndex(request, RequestOptions.DEFAULT); + } catch (Exception e) { + throw new IllegalStateException("Failed to initialize metadata log index " + metadataLogIndexName, e); + } + } + + private String constructMetadataLogIndexName() { + return dataSourceName.isEmpty() ? METADATA_LOG_INDEX_NAME_PREFIX : METADATA_LOG_INDEX_NAME_PREFIX + "_" + dataSourceName; + } + + private IRestHighLevelClient createOpenSearchClient() { + return OpenSearchClientUtils.createClient(options); + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java new file mode 100644 index 000000000..c047ced51 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import java.lang.reflect.Constructor; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.http.HttpHost; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.Strings; +import org.opensearch.flint.core.FlintOptions; +import org.opensearch.flint.core.IRestHighLevelClient; +import org.opensearch.flint.core.RestHighLevelClientWrapper; +import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor; +import org.opensearch.flint.core.http.RetryableHttpAsyncClient; + +/** + * Utility functions to create {@link IRestHighLevelClient}. + */ +public class OpenSearchClientUtils { + + private static final String SERVICE_NAME = "es"; + + /** + * Metadata log index name prefix + */ + public final static String META_LOG_NAME_PREFIX = ".query_execution_request"; + + public static IRestHighLevelClient createClient(FlintOptions options) { + RestClientBuilder + restClientBuilder = + RestClient.builder(new HttpHost(options.getHost(), options.getPort(), options.getScheme())); + + if (options.getAuth().equals(FlintOptions.SIGV4_AUTH)) { + restClientBuilder = configureSigV4Auth(restClientBuilder, options); + } else if (options.getAuth().equals(FlintOptions.BASIC_AUTH)) { + restClientBuilder = configureBasicAuth(restClientBuilder, options); + } else { + restClientBuilder = configureDefaultAuth(restClientBuilder, options); + } + + final RequestConfigurator callback = new RequestConfigurator(options); + restClientBuilder.setRequestConfigCallback(callback); + + return new RestHighLevelClientWrapper(new RestHighLevelClient(restClientBuilder)); + } + + private static RestClientBuilder configureSigV4Auth(RestClientBuilder restClientBuilder, FlintOptions options) { + // Use DefaultAWSCredentialsProviderChain by default. + final AtomicReference customAWSCredentialsProvider = + new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); + String customProviderClass = options.getCustomAwsCredentialsProvider(); + if (!Strings.isNullOrEmpty(customProviderClass)) { + instantiateProvider(customProviderClass, customAWSCredentialsProvider); + } + + // Set metadataAccessAWSCredentialsProvider to customAWSCredentialsProvider by default for backwards compatibility + // unless a specific metadata access provider class name is provided + String metadataAccessProviderClass = options.getMetadataAccessAwsCredentialsProvider(); + final AtomicReference metadataAccessAWSCredentialsProvider = + new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); + + String metaLogIndexName = constructMetaLogIndexName(options.getDataSourceName()); + String systemIndexName = Strings.isNullOrEmpty(options.getSystemIndexName()) ? metaLogIndexName : options.getSystemIndexName(); + + if (Strings.isNullOrEmpty(metadataAccessProviderClass)) { + metadataAccessAWSCredentialsProvider.set(customAWSCredentialsProvider.get()); + } else { + instantiateProvider(metadataAccessProviderClass, metadataAccessAWSCredentialsProvider); + } + + restClientBuilder.setHttpClientConfigCallback(builder -> { + HttpAsyncClientBuilder delegate = builder.addInterceptorLast( + new ResourceBasedAWSRequestSigningApacheInterceptor( + SERVICE_NAME, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName)); + return RetryableHttpAsyncClient.builder(delegate, options); + } + ); + + return restClientBuilder; + } + + private static RestClientBuilder configureBasicAuth(RestClientBuilder restClientBuilder, FlintOptions options) { + CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials( + AuthScope.ANY, + new UsernamePasswordCredentials(options.getUsername(), options.getPassword())); + restClientBuilder.setHttpClientConfigCallback(builder -> { + HttpAsyncClientBuilder delegate = builder.setDefaultCredentialsProvider(credentialsProvider); + return RetryableHttpAsyncClient.builder(delegate, options); + }); + + return restClientBuilder; + } + + private static RestClientBuilder configureDefaultAuth(RestClientBuilder restClientBuilder, FlintOptions options) { + // No auth + restClientBuilder.setHttpClientConfigCallback(delegate -> + RetryableHttpAsyncClient.builder(delegate, options)); + return restClientBuilder; + } + + /** + * Attempts to instantiate the AWS credential provider using reflection. + */ + private static void instantiateProvider(String providerClass, AtomicReference provider) { + try { + Class awsCredentialsProviderClass = Class.forName(providerClass); + Constructor ctor = awsCredentialsProviderClass.getDeclaredConstructor(); + ctor.setAccessible(true); + provider.set((AWSCredentialsProvider) ctor.newInstance()); + } catch (Exception e) { + throw new RuntimeException("Failed to instantiate AWSCredentialsProvider: " + providerClass, e); + } + } + + private static String constructMetaLogIndexName(String dataSourceName) { + return dataSourceName.isEmpty() ? META_LOG_NAME_PREFIX : META_LOG_NAME_PREFIX + "_" + dataSourceName; + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala deleted file mode 100644 index 7624c2c54..000000000 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintCommand.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.flint.app - -import org.json4s.{Formats, NoTypeHints} -import org.json4s.JsonAST.JString -import org.json4s.native.JsonMethods.parse -import org.json4s.native.Serialization - -class FlintCommand( - var state: String, - val query: String, - // statementId is the statement type doc id - val statementId: String, - val queryId: String, - val submitTime: Long, - var error: Option[String] = None) { - def running(): Unit = { - state = "running" - } - - def complete(): Unit = { - state = "success" - } - - def fail(): Unit = { - state = "failed" - } - - def isRunning(): Boolean = { - state == "running" - } - - def isComplete(): Boolean = { - state == "success" - } - - def isFailed(): Boolean = { - state == "failed" - } - - def isWaiting(): Boolean = { - state == "waiting" - } - - override def toString: String = { - s"FlintCommand(state=$state, query=$query, statementId=$statementId, queryId=$queryId, submitTime=$submitTime, error=$error)" - } -} - -object FlintCommand { - - implicit val formats: Formats = Serialization.formats(NoTypeHints) - - def deserialize(command: String): FlintCommand = { - val meta = parse(command) - val state = (meta \ "state").extract[String] - val query = (meta \ "query").extract[String] - val statementId = (meta \ "statementId").extract[String] - val queryId = (meta \ "queryId").extract[String] - val submitTime = (meta \ "submitTime").extract[Long] - val maybeError: Option[String] = (meta \ "error") match { - case JString(str) => Some(str) - case _ => None - } - - new FlintCommand(state, query, statementId, queryId, submitTime, maybeError) - } - - def serialize(flintCommand: FlintCommand): String = { - // we only need to modify state and error - Serialization.write( - Map("state" -> flintCommand.state, "error" -> flintCommand.error.getOrElse(""))) - } -} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index 848bbe61f..df7c92636 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -10,6 +10,8 @@ import scala.collection.JavaConverters._ import org.json4s.{Formats, NoTypeHints} import org.json4s.native.Serialization import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} +import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.core.metadata.log.{FlintMetadataLogService, FlintMetadataLogServiceBuilder} import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._ import org.opensearch.flint.core.metadata.log.OptimisticTransaction.NO_LOG_ENTRY import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN @@ -43,19 +45,15 @@ class FlintSpark(val spark: SparkSession) extends Logging { /** Flint client for low-level index operation */ private val flintClient: FlintClient = FlintClientBuilder.build(flintSparkConf.flintOptions()) + private val flintMetadataLogService: FlintMetadataLogService = + FlintMetadataLogServiceBuilder.build(flintSparkConf.flintOptions()) + /** Required by json4s parse function */ implicit val formats: Formats = Serialization.formats(NoTypeHints) + SkippingKindSerializer - /** - * Data source name. Assign empty string in case of backward compatibility. TODO: remove this in - * future - */ - private val dataSourceName: String = - spark.conf.getOption("spark.flint.datasource.name").getOrElse("") - /** Flint Spark index monitor */ val flintIndexMonitor: FlintSparkIndexMonitor = - new FlintSparkIndexMonitor(spark, flintClient, dataSourceName) + new FlintSparkIndexMonitor(spark, flintMetadataLogService) /** * Create index builder for creating index with fluent API. @@ -105,8 +103,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { } else { val metadata = index.metadata() try { - flintClient - .startTransaction(indexName, dataSourceName, true) + flintMetadataLogService + .startTransaction(indexName, true) .initialLog(latest => latest.state == EMPTY || latest.state == DELETED) .transientLog(latest => latest.copy(state = CREATING)) .finalLog(latest => latest.copy(state = ACTIVE)) @@ -141,8 +139,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { val indexRefresh = FlintSparkIndexRefresh.create(indexName, index) try { - flintClient - .startTransaction(indexName, dataSourceName) + flintMetadataLogService + .startTransaction(indexName) .initialLog(latest => latest.state == ACTIVE) .transientLog(latest => latest.copy(state = REFRESHING, createTime = System.currentTimeMillis())) @@ -179,6 +177,10 @@ class FlintSpark(val spark: SparkSession) extends Logging { flintClient .getAllIndexMetadata(indexNamePattern) .asScala + .map { case (indexName, metadata) => + attachLatestLogEntry(indexName, metadata) + } + .toList .flatMap(FlintSparkIndexFactory.create) } else { Seq.empty @@ -197,7 +199,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { logInfo(s"Describing index name $indexName") if (flintClient.exists(indexName)) { val metadata = flintClient.getIndexMetadata(indexName) - FlintSparkIndexFactory.create(metadata) + val metadataWithEntry = attachLatestLogEntry(indexName, metadata) + FlintSparkIndexFactory.create(metadataWithEntry) } else { Option.empty } @@ -248,8 +251,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { logInfo(s"Deleting Flint index $indexName") if (flintClient.exists(indexName)) { try { - flintClient - .startTransaction(indexName, dataSourceName) + flintMetadataLogService + .startTransaction(indexName) .initialLog(latest => latest.state == ACTIVE || latest.state == REFRESHING || latest.state == FAILED) .transientLog(latest => latest.copy(state = DELETING)) @@ -283,8 +286,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { logInfo(s"Vacuuming Flint index $indexName") if (flintClient.exists(indexName)) { try { - flintClient - .startTransaction(indexName, dataSourceName) + flintMetadataLogService + .startTransaction(indexName) .initialLog(latest => latest.state == DELETED) .transientLog(latest => latest.copy(state = VACUUMING)) .finalLog(_ => NO_LOG_ENTRY) @@ -314,8 +317,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { val index = describeIndex(indexName) if (index.exists(_.options.autoRefresh())) { try { - flintClient - .startTransaction(indexName, dataSourceName) + flintMetadataLogService + .startTransaction(indexName) .initialLog(latest => Set(ACTIVE, REFRESHING, FAILED).contains(latest.state)) .transientLog(latest => latest.copy(state = RECOVERING, createTime = System.currentTimeMillis())) @@ -345,8 +348,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { * interim, but metadata log get deleted by this cleanup process. */ logWarning("Cleaning up metadata log as index data has been deleted") - flintClient - .startTransaction(indexName, dataSourceName) + flintMetadataLogService + .startTransaction(indexName) .initialLog(_ => true) .finalLog(_ => NO_LOG_ENTRY) .commit(_ => {}) @@ -389,6 +392,27 @@ class FlintSpark(val spark: SparkSession) extends Logging { } } + /** + * Attaches latest log entry to metadata if available. + * + * @param indexName + * index name + * @param metadata + * base flint metadata + * @return + * flint metadata with latest log entry attached if available + */ + private def attachLatestLogEntry(indexName: String, metadata: FlintMetadata): FlintMetadata = { + val latest = flintMetadataLogService + .getIndexMetadataLog(indexName) + .flatMap(_.getLatest) + if (latest.isPresent) { + metadata.copy(latestLogEntry = Some(latest.get())) + } else { + metadata + } + } + /** * Validate the index update options are allowed. * @param originalOptions @@ -435,8 +459,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { private def updateIndexAutoToManual(index: FlintSparkIndex): Option[String] = { val indexName = index.name val indexLogEntry = index.latestLogEntry.get - flintClient - .startTransaction(indexName, dataSourceName) + flintMetadataLogService + .startTransaction(indexName) .initialLog(latest => latest.state == REFRESHING && latest.seqNo == indexLogEntry.seqNo && latest.primaryTerm == indexLogEntry.primaryTerm) .transientLog(latest => latest.copy(state = UPDATING)) @@ -454,8 +478,8 @@ class FlintSpark(val spark: SparkSession) extends Logging { val indexName = index.name val indexLogEntry = index.latestLogEntry.get val indexRefresh = FlintSparkIndexRefresh.create(indexName, index) - flintClient - .startTransaction(indexName, dataSourceName) + flintMetadataLogService + .startTransaction(indexName) .initialLog(latest => latest.state == ACTIVE && latest.seqNo == indexLogEntry.seqNo && latest.primaryTerm == indexLogEntry.primaryTerm) .transientLog(latest => diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala index d3f3ff0ee..815dfa71a 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexMonitor.scala @@ -16,8 +16,8 @@ import scala.sys.addShutdownHook import dev.failsafe.{Failsafe, RetryPolicy} import dev.failsafe.event.ExecutionAttemptedEvent import dev.failsafe.function.CheckedRunnable -import org.opensearch.flint.core.FlintClient import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{FAILED, REFRESHING} +import org.opensearch.flint.core.metadata.log.FlintMetadataLogService import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil} import org.apache.spark.internal.Logging @@ -30,15 +30,12 @@ import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor * * @param spark * Spark session - * @param flintClient - * Flint client - * @param dataSourceName - * data source name + * @param flintMetadataLogService + * Flint metadata log service */ class FlintSparkIndexMonitor( spark: SparkSession, - flintClient: FlintClient, - dataSourceName: String) + flintMetadataLogService: FlintMetadataLogService) extends Logging { /** Task execution initial delay in seconds */ @@ -80,7 +77,8 @@ class FlintSparkIndexMonitor( */ def stopMonitor(indexName: String): Unit = { logInfo(s"Cancelling scheduled task for index $indexName") - val task = FlintSparkIndexMonitor.indexMonitorTracker.remove(indexName) + // Hack: Don't remove because awaitMonitor API requires Flint index name. + val task = FlintSparkIndexMonitor.indexMonitorTracker.get(indexName) if (task.isDefined) { task.get.cancel(true) } else { @@ -119,26 +117,25 @@ class FlintSparkIndexMonitor( logInfo(s"Streaming job $name terminated without exception") } catch { case e: Throwable => - /** - * Transition the index state to FAILED upon encountering an exception. Retry in case - * conflicts with final transaction in scheduled task. - * ``` - * TODO: - * 1) Determine the appropriate state code based on the type of exception encountered - * 2) Record and persist the error message of the root cause for further diagnostics. - * ``` - */ - logError(s"Streaming job $name terminated with exception", e) - retry { - flintClient - .startTransaction(name, dataSourceName) - .initialLog(latest => latest.state == REFRESHING) - .finalLog(latest => latest.copy(state = FAILED)) - .commit(_ => {}) - } + logError(s"Streaming job $name terminated with exception: ${e.getMessage}") + retryUpdateIndexStateToFailed(name) } } else { - logInfo(s"Index monitor for [$indexName] not found") + logInfo(s"Index monitor for [$indexName] not found.") + + /* + * Streaming job exits early. Try to find Flint index name in monitor list. + * Assuming: 1) there are at most 1 entry in the list, otherwise index name + * must be given upon this method call; 2) this await API must be called for + * auto refresh index, otherwise index state will be updated mistakenly. + */ + val name = FlintSparkIndexMonitor.indexMonitorTracker.keys.headOption + if (name.isDefined) { + logInfo(s"Found index name in index monitor task list: ${name.get}") + retryUpdateIndexStateToFailed(name.get) + } else { + logInfo(s"Index monitor task list is empty") + } } } @@ -159,11 +156,7 @@ class FlintSparkIndexMonitor( try { if (isStreamingJobActive(indexName)) { logInfo("Streaming job is still active") - flintClient - .startTransaction(indexName, dataSourceName) - .initialLog(latest => latest.state == REFRESHING) - .finalLog(latest => latest) // timestamp will update automatically - .commit(_ => {}) + flintMetadataLogService.recordHeartbeat(indexName) } else { logError("Streaming job is not active. Cancelling monitor task") stopMonitor(indexName) @@ -199,6 +192,26 @@ class FlintSparkIndexMonitor( } } + /** + * Transition the index state to FAILED upon encountering an exception. Retry in case conflicts + * with final transaction in scheduled task. + * ``` + * TODO: + * 1) Determine the appropriate state code based on the type of exception encountered + * 2) Record and persist the error message of the root cause for further diagnostics. + * ``` + */ + private def retryUpdateIndexStateToFailed(indexName: String): Unit = { + logInfo(s"Updating index state to failed for $indexName") + retry { + flintMetadataLogService + .startTransaction(indexName) + .initialLog(latest => latest.state == REFRESHING) + .finalLog(latest => latest.copy(state = FAILED)) + .commit(_ => {}) + } + } + private def retry(operation: => Unit): Unit = { // Retry policy for 3 times every 1 second val retryPolicy = RetryPolicy diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index 74626d25d..d7c6ddf81 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -196,6 +196,22 @@ object FlintSparkMaterializedView { this } + override protected def validateIndex(index: FlintSparkIndex): FlintSparkIndex = { + /* + * Validate if duplicate column names in the output schema. + * MV query may be empty in the case of ALTER index statement. + */ + if (query.nonEmpty) { + val outputColNames = flint.spark.sql(query).schema.map(_.name) + require( + outputColNames.distinct.length == outputColNames.length, + "Duplicate columns found in materialized view query output") + } + + // Continue to perform any additional index validation + super.validateIndex(index) + } + override protected def buildIndex(): FlintSparkIndex = { // TODO: change here and FlintDS class to support complex field type in future val outputSchema = flint.spark diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala index 990b9e449..7318e5c7c 100644 --- a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -17,13 +17,15 @@ import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest import org.opensearch.action.get.GetRequest import org.opensearch.client.RequestOptions import org.opensearch.flint.core.FlintOptions -import org.opensearch.flint.spark.FlintSparkSuite +import org.opensearch.flint.spark.{FlintSparkIndexMonitor, FlintSparkSuite} import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.scalatest.matchers.must.Matchers.{contain, defined} import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE import org.apache.spark.sql.flint.config.FlintSparkConf._ +import org.apache.spark.sql.streaming.StreamingQueryListener +import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.util.MockEnvironment import org.apache.spark.util.ThreadUtils @@ -46,6 +48,11 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { protected override def beforeEach(): Unit = { super.beforeEach() + + // Clear up because awaitMonitor will assume single name in tracker + FlintSparkIndexMonitor.indexMonitorTracker.values.foreach(_.cancel(true)) + FlintSparkIndexMonitor.indexMonitorTracker.clear() + createPartitionedMultiRowAddressTable(testTable) } @@ -195,6 +202,42 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { } } + test("create skipping index with auto refresh and streaming job early exit") { + // Custom listener to force streaming job to fail at the beginning + val listener = new StreamingQueryListener { + override def onQueryStarted(event: QueryStartedEvent): Unit = { + logInfo("Stopping streaming job intentionally") + spark.streams.active.find(_.name == event.name).get.stop() + } + override def onQueryProgress(event: QueryProgressEvent): Unit = {} + override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {} + } + + try { + spark.streams.addListener(listener) + val query = + s""" + | CREATE SKIPPING INDEX ON $testTable + | (name VALUE_SET) + | WITH (auto_refresh = true) + | """.stripMargin + val jobRunId = "00ff4o3b5091080q" + threadLocalFuture.set(startJob(query, jobRunId)) + + // Assert streaming job must exit + Thread.sleep(5000) + pollForResultAndAssert(_ => true, jobRunId) + spark.streams.active.exists(_.name == testIndex) shouldBe false + + // Assert Flint index transitioned to FAILED state after waiting seconds + Thread.sleep(2000L) + val latestId = Base64.getEncoder.encodeToString(testIndex.getBytes) + latestLogEntry(latestId) should contain("state" -> "failed") + } finally { + spark.streams.removeListener(listener) + } + } + test("create skipping index with non-existent table") { val query = s""" diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala index f2d4be911..1c0b27674 100644 --- a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -15,9 +15,9 @@ import scala.util.control.Breaks.{break, breakable} import org.opensearch.OpenSearchStatusException import org.opensearch.flint.OpenSearchSuite -import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.core.{FlintClient, FlintOptions} import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} +import org.opensearch.flint.data.{FlintStatement, InteractiveSession} import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkFunSuite @@ -546,28 +546,28 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { } def awaitConditionForStatementOrTimeout( - expected: FlintCommand => Boolean, + expected: FlintStatement => Boolean, statementId: String): Boolean = { - awaitConditionOrTimeout[FlintCommand]( + awaitConditionOrTimeout[FlintStatement]( osClient, expected, statementId, 10000, requestIndex, - FlintCommand.deserialize, + FlintStatement.deserialize, "statement") } def awaitConditionForSessionOrTimeout( - expected: FlintInstance => Boolean, + expected: InteractiveSession => Boolean, sessionId: String): Boolean = { - awaitConditionOrTimeout[FlintInstance]( + awaitConditionOrTimeout[InteractiveSession]( osClient, expected, sessionId, 10000, requestIndex, - FlintInstance.deserialize, + InteractiveSession.deserialize, "session") } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala index 1e2219600..f37bb53f7 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/OpenSearchTransactionSuite.scala @@ -19,7 +19,7 @@ import org.opensearch.common.xcontent.XContentType import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.{QUERY_EXECUTION_REQUEST_MAPPING, QUERY_EXECUTION_REQUEST_SETTINGS} import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.IndexState -import org.opensearch.flint.core.storage.FlintOpenSearchClient._ +import org.opensearch.flint.core.storage.FlintOpenSearchMetadataLogService.METADATA_LOG_INDEX_NAME_PREFIX import org.opensearch.flint.spark.FlintSparkSuite import org.apache.spark.sql.flint.config.FlintSparkConf.DATA_SOURCE_NAME @@ -31,7 +31,7 @@ import org.apache.spark.sql.flint.config.FlintSparkConf.DATA_SOURCE_NAME trait OpenSearchTransactionSuite extends FlintSparkSuite { val testDataSourceName = "myglue" - lazy val testMetaLogIndex: String = META_LOG_NAME_PREFIX + "_" + testDataSourceName + lazy val testMetaLogIndex: String = METADATA_LOG_INDEX_NAME_PREFIX + "_" + testDataSourceName override def beforeAll(): Unit = { super.beforeAll() diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala new file mode 100644 index 000000000..f8a8c2164 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/core/FlintMetadataLogITSuite.scala @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core + +import java.util.Base64 + +import scala.collection.JavaConverters._ + +import org.opensearch.flint.OpenSearchTransactionSuite +import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry +import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._ +import org.opensearch.flint.core.metadata.log.FlintMetadataLogService +import org.opensearch.flint.core.storage.FlintOpenSearchMetadataLogService +import org.opensearch.index.seqno.SequenceNumbers.{UNASSIGNED_PRIMARY_TERM, UNASSIGNED_SEQ_NO} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.flint.config.FlintSparkConf.DATA_SOURCE_NAME + +class FlintMetadataLogITSuite extends OpenSearchTransactionSuite with Matchers { + + val testFlintIndex = "flint_test_index" + val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) + val testCreateTime = 1234567890123L + val flintMetadataLogEntry = FlintMetadataLogEntry( + id = testLatestId, + seqNo = UNASSIGNED_SEQ_NO, + primaryTerm = UNASSIGNED_PRIMARY_TERM, + createTime = testCreateTime, + state = ACTIVE, + dataSource = testDataSourceName, + error = "") + + var flintMetadataLogService: FlintMetadataLogService = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val options = openSearchOptions + (DATA_SOURCE_NAME.key -> testDataSourceName) + val flintOptions = new FlintOptions(options.asJava) + flintMetadataLogService = new FlintOpenSearchMetadataLogService(flintOptions) + } + + test("should fail if metadata log index doesn't exists") { + val options = openSearchOptions + (DATA_SOURCE_NAME.key -> "non-exist-datasource") + val flintMetadataLogService = + new FlintOpenSearchMetadataLogService(new FlintOptions(options.asJava)) + + the[IllegalStateException] thrownBy { + flintMetadataLogService.startTransaction(testFlintIndex) + } + } + + test("should get index metadata log without log entry") { + val metadataLog = flintMetadataLogService.getIndexMetadataLog(testFlintIndex) + metadataLog.isPresent shouldBe true + metadataLog.get.getLatest shouldBe empty + } + + test("should get index metadata log with log entry") { + createLatestLogEntry(flintMetadataLogEntry) + val metadataLog = flintMetadataLogService.getIndexMetadataLog(testFlintIndex) + metadataLog.isPresent shouldBe true + + val latest = metadataLog.get.getLatest + latest.isPresent shouldBe true + latest.get.id shouldBe testLatestId + latest.get.createTime shouldBe testCreateTime + latest.get.dataSource shouldBe testDataSourceName + latest.get.error shouldBe "" + } + + test("should not get index metadata log if not exist") { + val options = openSearchOptions + (DATA_SOURCE_NAME.key -> "non-exist-datasource") + val flintMetadataLogService = + new FlintOpenSearchMetadataLogService(new FlintOptions(options.asJava)) + val metadataLog = flintMetadataLogService.getIndexMetadataLog(testFlintIndex) + metadataLog.isPresent shouldBe false + } + + test("should update timestamp when record heartbeat") { + val refreshingLogEntry = flintMetadataLogEntry.copy(state = REFRESHING) + createLatestLogEntry(refreshingLogEntry) + val updateTimeBeforeHeartbeat = + latestLogEntry(testLatestId).get("lastUpdateTime").get.asInstanceOf[Long] + flintMetadataLogService.recordHeartbeat(testFlintIndex) + latestLogEntry(testLatestId) + .get("lastUpdateTime") + .get + .asInstanceOf[Long] should be > updateTimeBeforeHeartbeat + } + + test("should fail when record heartbeat if index not refreshing") { + createLatestLogEntry(flintMetadataLogEntry) + the[IllegalStateException] thrownBy { + flintMetadataLogService.recordHeartbeat(testFlintIndex) + } + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala index 6eab292e2..1373654aa 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import org.scalatestplus.mockito.MockitoSugar.mock -import org.apache.spark.sql.flint.config.FlintSparkConf.{REFRESH_POLICY, SCROLL_DURATION, SCROLL_SIZE} +import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, REFRESH_POLICY, SCROLL_DURATION, SCROLL_SIZE} class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with Matchers { @@ -30,12 +30,6 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M behavior of "Flint OpenSearch client" - it should "throw IllegalStateException if metadata log index doesn't exists" in { - the[IllegalStateException] thrownBy { - flintClient.startTransaction("test", "non-exist-index") - } - } - it should "create index successfully" in { val indexName = "test" val content = @@ -130,8 +124,8 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M val allMetadata = flintClient.getAllIndexMetadata("flint_*_index") allMetadata should have size 2 - allMetadata.forEach(metadata => metadata.getContent should not be empty) - allMetadata.forEach(metadata => metadata.indexSettings should not be empty) + allMetadata.values.forEach(metadata => metadata.getContent should not be empty) + allMetadata.values.forEach(metadata => metadata.indexSettings should not be empty) } it should "convert index name to all lowercase" in { diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala index 7dc5c695c..6da232389 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/FlintTransactionITSuite.scala @@ -11,15 +11,13 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.json4s.{Formats, NoTypeHints} import org.json4s.native.{JsonMethods, Serialization} -import org.mockito.Mockito.when import org.opensearch.flint.OpenSearchTransactionSuite -import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState._ -import org.opensearch.flint.core.storage.FlintOpenSearchClient +import org.opensearch.flint.core.metadata.log.FlintMetadataLogService +import org.opensearch.flint.core.storage.FlintOpenSearchMetadataLogService import org.opensearch.index.seqno.SequenceNumbers.{UNASSIGNED_PRIMARY_TERM, UNASSIGNED_SEQ_NO} import org.scalatest.matchers.should.Matchers -import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.sql.flint.config.FlintSparkConf.DATA_SOURCE_NAME @@ -27,17 +25,18 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { val testFlintIndex = "flint_test_index" val testLatestId: String = Base64.getEncoder.encodeToString(testFlintIndex.getBytes) - var flintClient: FlintClient = _ + var flintMetadataLogService: FlintMetadataLogService = _ override def beforeAll(): Unit = { super.beforeAll() val options = openSearchOptions + (DATA_SOURCE_NAME.key -> testDataSourceName) - flintClient = new FlintOpenSearchClient(new FlintOptions(options.asJava)) + flintMetadataLogService = new FlintOpenSearchMetadataLogService( + new FlintOptions(options.asJava)) } test("empty metadata log entry content") { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(latest => { latest.id shouldBe testLatestId latest.state shouldBe EMPTY @@ -50,45 +49,6 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { .commit(_ => {}) } - test("get index metadata with latest log entry") { - val testCreateTime = 1234567890123L - val flintMetadataLogEntry = FlintMetadataLogEntry( - id = testLatestId, - seqNo = UNASSIGNED_SEQ_NO, - primaryTerm = UNASSIGNED_PRIMARY_TERM, - createTime = testCreateTime, - state = ACTIVE, - dataSource = testDataSourceName, - error = "") - val metadata = mock[FlintMetadata] - when(metadata.getContent).thenReturn("{}") - when(metadata.indexSettings).thenReturn(None) - when(metadata.latestLogEntry).thenReturn(Some(flintMetadataLogEntry)) - - flintClient.createIndex(testFlintIndex, metadata) - createLatestLogEntry(flintMetadataLogEntry) - - val latest = flintClient.getIndexMetadata(testFlintIndex).latestLogEntry - latest.isDefined shouldBe true - latest.get.id shouldBe testLatestId - latest.get.createTime shouldBe testCreateTime - latest.get.dataSource shouldBe testDataSourceName - latest.get.error shouldBe "" - - deleteTestIndex(testFlintIndex) - } - - test("should get empty metadata log entry") { - val metadata = mock[FlintMetadata] - when(metadata.getContent).thenReturn("{}") - when(metadata.indexSettings).thenReturn(None) - flintClient.createIndex(testFlintIndex, metadata) - - flintClient.getIndexMetadata(testFlintIndex).latestLogEntry shouldBe empty - - deleteTestIndex(testFlintIndex) - } - test("should preserve original values when transition") { val testCreateTime = 1234567890123L createLatestLogEntry( @@ -101,8 +61,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { dataSource = testDataSourceName, error = "")) - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(latest => { latest.id shouldBe testLatestId latest.createTime shouldBe testCreateTime @@ -125,8 +85,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { } test("should transit from initial to final log if initial log is empty") { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(latest => { latest.state shouldBe EMPTY true @@ -139,8 +99,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { } test("should transit from initial to final log directly if no transient log") { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(_ => true) .finalLog(latest => latest.copy(state = ACTIVE)) .commit(_ => latestLogEntry(testLatestId) should contain("state" -> "empty")) @@ -161,8 +121,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { dataSource = testDataSourceName, error = "")) - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(latest => { latest.state shouldBe ACTIVE true @@ -176,8 +136,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { test("should exit if initial log entry doesn't meet precondition") { the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(_ => false) .transientLog(latest => latest.copy(state = ACTIVE)) .finalLog(latest => latest) @@ -190,8 +150,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { test("should fail if initial log entry updated by others when updating transient log entry") { the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(_ => true) .transientLog(latest => { // This update will happen first and thus cause version conflict as expected @@ -206,8 +166,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { test("should fail if transient log entry updated by others when updating final log entry") { the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(_ => true) .transientLog(latest => { @@ -224,8 +184,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { test("should rollback to initial log if transaction operation failed") { // Use create index scenario in this test case the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(_ => true) .transientLog(latest => latest.copy(state = CREATING)) .finalLog(latest => latest.copy(state = ACTIVE)) @@ -249,8 +209,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { error = "")) the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(_ => true) .transientLog(latest => latest.copy(state = REFRESHING)) .finalLog(_ => throw new RuntimeException("Mock final log error")) @@ -265,8 +225,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { "should not necessarily rollback if transaction operation failed but no transient action") { // Use create index scenario in this test case the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(_ => true) .finalLog(latest => latest.copy(state = ACTIVE)) .commit(_ => throw new RuntimeException("Mock operation error")) @@ -278,8 +238,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { test("forceInit translog, even index is deleted before startTransaction") { deleteIndex(testMetaLogIndex) - flintClient - .startTransaction(testFlintIndex, testDataSourceName, true) + flintMetadataLogService + .startTransaction(testFlintIndex, true) .initialLog(latest => { latest.id shouldBe testLatestId latest.state shouldBe EMPTY @@ -298,8 +258,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { test("should fail if index is deleted before initial operation") { the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(latest => { deleteIndex(testMetaLogIndex) true @@ -312,8 +272,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { test("should fail if index is deleted before transient operation") { the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(latest => true) .transientLog(latest => { deleteIndex(testMetaLogIndex) @@ -326,8 +286,8 @@ class FlintTransactionITSuite extends OpenSearchTransactionSuite with Matchers { test("should fail if index is deleted before final operation") { the[IllegalStateException] thrownBy { - flintClient - .startTransaction(testFlintIndex, testDataSourceName) + flintMetadataLogService + .startTransaction(testFlintIndex) .initialLog(latest => true) .transientLog(latest => { latest.copy(state = CREATING) }) .finalLog(latest => { diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala index 3b317a0fe..fa7f75b81 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala @@ -10,15 +10,15 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.action.get.{GetRequest, GetResponse} import org.opensearch.client.RequestOptions import org.opensearch.flint.OpenSearchTransactionSuite -import org.opensearch.flint.app.FlintInstance import org.opensearch.flint.core.storage.{FlintOpenSearchClient, OpenSearchUpdater} +import org.opensearch.flint.data.InteractiveSession import org.scalatest.matchers.should.Matchers class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { val sessionId = "sessionId" val timestamp = 1700090926955L val flintJob = - new FlintInstance( + new InteractiveSession( "applicationId", "jobId", sessionId, @@ -38,7 +38,7 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { } test("upsert flintJob should success") { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) getFlintInstance(sessionId)._2.lastUpdateTime shouldBe timestamp } @@ -46,15 +46,15 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { deleteIndex(testMetaLogIndex) the[IllegalStateException] thrownBy { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) } } test("update flintJob should success") { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) val newTimestamp = 1700090926956L - updater.update(sessionId, FlintInstance.serialize(flintJob, newTimestamp)) + updater.update(sessionId, InteractiveSession.serialize(flintJob, newTimestamp)) getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp } @@ -62,25 +62,25 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { deleteIndex(testMetaLogIndex) the[IllegalStateException] thrownBy { - updater.update(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.update(sessionId, InteractiveSession.serialize(flintJob, timestamp)) } } test("updateIf flintJob should success") { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) val (resp, latest) = getFlintInstance(sessionId) val newTimestamp = 1700090926956L updater.updateIf( sessionId, - FlintInstance.serialize(latest, newTimestamp), + InteractiveSession.serialize(latest, newTimestamp), resp.getSeqNo, resp.getPrimaryTerm) getFlintInstance(sessionId)._2.lastUpdateTime shouldBe newTimestamp } test("index is deleted when updateIf flintJob should throw IllegalStateException") { - updater.upsert(sessionId, FlintInstance.serialize(flintJob, timestamp)) + updater.upsert(sessionId, InteractiveSession.serialize(flintJob, timestamp)) val (resp, latest) = getFlintInstance(sessionId) deleteIndex(testMetaLogIndex) @@ -88,15 +88,15 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers { the[IllegalStateException] thrownBy { updater.updateIf( sessionId, - FlintInstance.serialize(latest, timestamp), + InteractiveSession.serialize(latest, timestamp), resp.getSeqNo, resp.getPrimaryTerm) } } - def getFlintInstance(docId: String): (GetResponse, FlintInstance) = { + def getFlintInstance(docId: String): (GetResponse, InteractiveSession) = { val response = openSearchClient.get(new GetRequest(testMetaLogIndex, docId), RequestOptions.DEFAULT) - (response, FlintInstance.deserializeFromMap(response.getSourceAsMap)) + (response, InteractiveSession.deserializeFromMap(response.getSourceAsMap)) } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala index 2627ed964..1e2d68b8e 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala @@ -188,6 +188,19 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc latestLog should contain("state" -> "failed") } + test( + "await monitor terminated with streaming job exit early should update index state to failed") { + // Terminate streaming job intentionally before await + spark.streams.active.find(_.name == testFlintIndex).get.stop() + + // Await until streaming job terminated + flint.flintIndexMonitor.awaitMonitor() + + // Assert index state is active now + val latestLog = latestLogEntry(testLatestId) + latestLog should contain("state" -> "failed") + } + private def getLatestTimestamp: (Long, Long) = { val latest = latestLogEntry(testLatestId) (latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long]) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index 8dfde3439..3a17cb8b1 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -16,7 +16,7 @@ import org.json4s.native.Serialization import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.FlintOpenSearchClient import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName -import org.scalatest.matchers.must.Matchers.defined +import org.scalatest.matchers.must.Matchers.{defined, have} import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} import org.apache.spark.sql.Row @@ -251,6 +251,26 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { metadata.indexedColumns.map(_.asScala("columnName")) shouldBe Seq("start.time", "count") } + Seq( + s"SELECT name, name FROM $testTable", + s"SELECT name AS dup_col, age AS dup_col FROM $testTable") + .foreach { query => + test(s"should fail to create materialized view if duplicate columns in $query") { + the[IllegalArgumentException] thrownBy { + withTempDir { checkpointDir => + sql(s""" + | CREATE MATERIALIZED VIEW $testMvName + | AS $query + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}' + | ) + |""".stripMargin) + } + } should have message "requirement failed: Duplicate columns found in materialized view query output" + } + } + test("show all materialized views in catalog and database") { // Show in catalog flint.materializedView().name("spark_catalog.default.mv1").query(testQuery).create() diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 36432f016..8cad8844b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -18,13 +18,13 @@ import com.codahale.metrics.Timer import org.json4s.native.Serialization import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings -import org.opensearch.flint.app.{FlintCommand, FlintInstance} -import org.opensearch.flint.app.FlintInstance.formats import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer} import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.flint.data.{FlintStatement, InteractiveSession} +import org.opensearch.flint.data.InteractiveSession.formats import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf @@ -57,8 +57,8 @@ object FlintREPL extends Logging with FlintJobExecutor { @volatile var earlyExitFlag: Boolean = false - def updateSessionIndex(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { - updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) + def updateSessionIndex(flintStatement: FlintStatement, updater: OpenSearchUpdater): Unit = { + updater.update(flintStatement.statementId, FlintStatement.serialize(flintStatement)) } private val sessionRunningCount = new AtomicInteger(0) @@ -411,7 +411,7 @@ object FlintREPL extends Logging with FlintJobExecutor { excludeJobIds: Seq[String] = Seq.empty[String]): Unit = { val includeJobId = !excludeJobIds.isEmpty && !excludeJobIds.contains(jobId) val currentTime = currentTimeProvider.currentEpochMillis() - val flintJob = new FlintInstance( + val flintJob = new InteractiveSession( applicationId, jobId, sessionId, @@ -421,9 +421,9 @@ object FlintREPL extends Logging with FlintJobExecutor { excludeJobIds) val serializedFlintInstance = if (includeJobId) { - FlintInstance.serialize(flintJob, currentTime, true) + InteractiveSession.serialize(flintJob, currentTime, true) } else { - FlintInstance.serializeWithoutJobId(flintJob, currentTime) + InteractiveSession.serializeWithoutJobId(flintJob, currentTime) } flintSessionIndexUpdater.upsert(sessionId, serializedFlintInstance) logInfo( @@ -456,11 +456,11 @@ object FlintREPL extends Logging with FlintJobExecutor { private def getExistingFlintInstance( osClient: OSClient, sessionIndex: String, - sessionId: String): Option[FlintInstance] = Try( + sessionId: String): Option[InteractiveSession] = Try( osClient.getDoc(sessionIndex, sessionId)) match { case Success(getResponse) if getResponse.isExists() => Option(getResponse.getSourceAsMap) - .map(FlintInstance.deserializeFromMap) + .map(InteractiveSession.deserializeFromMap) case Failure(exception) => CustomLogging.logError( s"Failed to retrieve existing FlintInstance: ${exception.getMessage}", @@ -474,7 +474,7 @@ object FlintREPL extends Logging with FlintJobExecutor { jobId: String, sessionId: String, jobStartTime: Long, - errorMessage: String): FlintInstance = new FlintInstance( + errorMessage: String): InteractiveSession = new InteractiveSession( applicationId, jobId, sessionId, @@ -484,19 +484,19 @@ object FlintREPL extends Logging with FlintJobExecutor { error = Some(errorMessage)) private def updateFlintInstance( - flintInstance: FlintInstance, + flintInstance: InteractiveSession, flintSessionIndexUpdater: OpenSearchUpdater, sessionId: String): Unit = { val currentTime = currentTimeProvider.currentEpochMillis() flintSessionIndexUpdater.upsert( sessionId, - FlintInstance.serializeWithoutJobId(flintInstance, currentTime)) + InteractiveSession.serializeWithoutJobId(flintInstance, currentTime)) } /** - * handling the case where a command's execution fails, updates the flintCommand with the error - * and failure status, and then write the result to result index. Thus, an error is written to - * both result index or statement model in request index + * handling the case where a command's execution fails, updates the flintStatement with the + * error and failure status, and then write the result to result index. Thus, an error is + * written to both result index or statement model in request index * * @param spark * spark session @@ -504,7 +504,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * data source * @param error * error message - * @param flintCommand + * @param flintStatement * flint command * @param sessionId * session id @@ -517,26 +517,26 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, dataSource: String, error: String, - flintCommand: FlintCommand, + flintStatement: FlintStatement, sessionId: String, startTime: Long): DataFrame = { - flintCommand.fail() - flintCommand.error = Some(error) + flintStatement.fail() + flintStatement.error = Some(error) super.getFailedData( spark, dataSource, error, - flintCommand.queryId, - flintCommand.query, + flintStatement.queryId, + flintStatement.query, sessionId, startTime, currentTimeProvider) } - def processQueryException(ex: Exception, flintCommand: FlintCommand): String = { + def processQueryException(ex: Exception, flintStatement: FlintStatement): String = { val error = super.processQueryException(ex) - flintCommand.fail() - flintCommand.error = Some(error) + flintStatement.fail() + flintStatement.error = Some(error) error } @@ -570,12 +570,12 @@ object FlintREPL extends Logging with FlintJobExecutor { } else { val statementTimerContext = getTimerContext( MetricConstants.STATEMENT_PROCESSING_TIME_METRIC) - val flintCommand = processCommandInitiation(flintReader, flintSessionIndexUpdater) + val flintStatement = processCommandInitiation(flintReader, flintSessionIndexUpdater) val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( recordedVerificationResult, spark, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -587,7 +587,7 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = returnedVerificationResult finalizeCommand( dataToWrite, - flintCommand, + flintStatement, resultIndex, flintSessionIndexUpdater, osClient, @@ -606,7 +606,7 @@ object FlintREPL extends Logging with FlintJobExecutor { * * @param dataToWrite * data to write - * @param flintCommand + * @param flintStatement * flint command * @param resultIndex * result index @@ -615,28 +615,28 @@ object FlintREPL extends Logging with FlintJobExecutor { */ private def finalizeCommand( dataToWrite: Option[DataFrame], - flintCommand: FlintCommand, + flintStatement: FlintStatement, resultIndex: String, flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, statementTimerContext: Timer.Context): Unit = { try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) - if (flintCommand.isRunning() || flintCommand.isWaiting()) { + if (flintStatement.isRunning || flintStatement.isWaiting) { // we have set failed state in exception handling - flintCommand.complete() + flintStatement.complete() } - updateSessionIndex(flintCommand, flintSessionIndexUpdater) - recordStatementStateChange(flintCommand, statementTimerContext) + updateSessionIndex(flintStatement, flintSessionIndexUpdater) + recordStatementStateChange(flintStatement, statementTimerContext) } catch { // e.g., maybe due to authentication service connection issue // or invalid catalog (e.g., we are operating on data not defined in provided data source) case e: Exception => - val error = s"""Fail to write result of ${flintCommand}, cause: ${e.getMessage}""" + val error = s"""Fail to write result of ${flintStatement}, cause: ${e.getMessage}""" CustomLogging.logError(error, e) - flintCommand.fail() - updateSessionIndex(flintCommand, flintSessionIndexUpdater) - recordStatementStateChange(flintCommand, statementTimerContext) + flintStatement.fail() + updateSessionIndex(flintStatement, flintSessionIndexUpdater) + recordStatementStateChange(flintStatement, statementTimerContext) } } @@ -644,7 +644,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, dataSource: String, error: String, - flintCommand: FlintCommand, + flintStatement: FlintStatement, sessionId: String, startTime: Long): Option[DataFrame] = { /* @@ -659,20 +659,20 @@ object FlintREPL extends Logging with FlintJobExecutor { * of Spark jobs. In the context of Spark SQL, this typically happens when we perform * actions that require the computation of results that need to be collected or stored. */ - spark.sparkContext.cancelJobGroup(flintCommand.queryId) + spark.sparkContext.cancelJobGroup(flintStatement.queryId) Some( handleCommandFailureAndGetFailedData( spark, dataSource, error, - flintCommand, + flintStatement, sessionId, startTime)) } def executeAndHandle( spark: SparkSession, - flintCommand: FlintCommand, + flintStatement: FlintStatement, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -683,7 +683,7 @@ object FlintREPL extends Logging with FlintJobExecutor { Some( executeQueryAsync( spark, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -692,17 +692,17 @@ object FlintREPL extends Logging with FlintJobExecutor { queryWaitTimeMillis)) } catch { case e: TimeoutException => - val error = s"Executing ${flintCommand.query} timed out" + val error = s"Executing ${flintStatement.query} timed out" CustomLogging.logError(error, e) - handleCommandTimeout(spark, dataSource, error, flintCommand, sessionId, startTime) + handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) case e: Exception => - val error = processQueryException(e, flintCommand) + val error = processQueryException(e, flintStatement) Some( handleCommandFailureAndGetFailedData( spark, dataSource, error, - flintCommand, + flintStatement, sessionId, startTime)) } @@ -711,14 +711,14 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processStatementOnVerification( recordedVerificationResult: VerificationResult, spark: SparkSession, - flintCommand: FlintCommand, + flintStatement: FlintStatement, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, futureMappingCheck: Future[Either[String, Unit]], resultIndex: String, queryExecutionTimeout: Duration, - queryWaitTimeMillis: Long): (Option[DataFrame], VerificationResult) = { + queryWaitTimeMillis: Long) = { val startTime: Long = currentTimeProvider.currentEpochMillis() var verificationResult = recordedVerificationResult var dataToWrite: Option[DataFrame] = None @@ -730,7 +730,7 @@ object FlintREPL extends Logging with FlintJobExecutor { case Right(_) => dataToWrite = executeAndHandle( spark, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -745,7 +745,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, dataSource, error, - flintCommand, + flintStatement, sessionId, startTime)) } @@ -754,7 +754,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val error = s"Getting the mapping of index $resultIndex timed out" CustomLogging.logError(error, e) dataToWrite = - handleCommandTimeout(spark, dataSource, error, flintCommand, sessionId, startTime) + handleCommandTimeout(spark, dataSource, error, flintStatement, sessionId, startTime) case NonFatal(e) => val error = s"An unexpected error occurred: ${e.getMessage}" CustomLogging.logError(error, e) @@ -763,7 +763,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, dataSource, error, - flintCommand, + flintStatement, sessionId, startTime)) } @@ -773,13 +773,13 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, dataSource, err, - flintCommand, + flintStatement, sessionId, startTime)) case VerifiedWithoutError => dataToWrite = executeAndHandle( spark, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -788,13 +788,13 @@ object FlintREPL extends Logging with FlintJobExecutor { queryWaitTimeMillis) } - logInfo(s"command complete: $flintCommand") + logInfo(s"command complete: $flintStatement") (dataToWrite, verificationResult) } def executeQueryAsync( spark: SparkSession, - flintCommand: FlintCommand, + flintStatement: FlintStatement, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -802,21 +802,21 @@ object FlintREPL extends Logging with FlintJobExecutor { queryExecutionTimeOut: Duration, queryWaitTimeMillis: Long): DataFrame = { if (currentTimeProvider - .currentEpochMillis() - flintCommand.submitTime > queryWaitTimeMillis) { + .currentEpochMillis() - flintStatement.submitTime > queryWaitTimeMillis) { handleCommandFailureAndGetFailedData( spark, dataSource, "wait timeout", - flintCommand, + flintStatement, sessionId, startTime) } else { val futureQueryExecution = Future { executeQuery( spark, - flintCommand.query, + flintStatement.query, dataSource, - flintCommand.queryId, + flintStatement.queryId, sessionId, false) }(executionContext) @@ -827,16 +827,16 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def processCommandInitiation( flintReader: FlintReader, - flintSessionIndexUpdater: OpenSearchUpdater): FlintCommand = { + flintSessionIndexUpdater: OpenSearchUpdater): FlintStatement = { val command = flintReader.next() logDebug(s"raw command: $command") - val flintCommand = FlintCommand.deserialize(command) - logDebug(s"command: $flintCommand") - flintCommand.running() - logDebug(s"command running: $flintCommand") - updateSessionIndex(flintCommand, flintSessionIndexUpdater) + val flintStatement = FlintStatement.deserialize(command) + logDebug(s"command: $flintStatement") + flintStatement.running() + logDebug(s"command running: $flintStatement") + updateSessionIndex(flintStatement, flintSessionIndexUpdater) statementRunningCount.incrementAndGet() - flintCommand + flintStatement } private def createQueryReader( @@ -932,11 +932,11 @@ object FlintREPL extends Logging with FlintJobExecutor { flintSessionIndexUpdater: OpenSearchUpdater, sessionId: String, sessionTimerContext: Timer.Context): Unit = { - val flintInstance = FlintInstance.deserializeFromMap(source) + val flintInstance = InteractiveSession.deserializeFromMap(source) flintInstance.state = "dead" flintSessionIndexUpdater.updateIf( sessionId, - FlintInstance.serializeWithoutJobId( + InteractiveSession.serializeWithoutJobId( flintInstance, currentTimeProvider.currentEpochMillis()), getResponse.getSeqNo, @@ -1131,15 +1131,15 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def recordStatementStateChange( - flintCommand: FlintCommand, + flintStatement: FlintStatement, statementTimerContext: Timer.Context): Unit = { stopTimer(statementTimerContext) if (statementRunningCount.get() > 0) { statementRunningCount.decrementAndGet() } - if (flintCommand.isComplete()) { + if (flintStatement.isComplete) { incrementCounter(MetricConstants.STATEMENT_SUCCESS_METRIC) - } else if (flintCommand.isFailed()) { + } else if (flintStatement.isFailed) { incrementCounter(MetricConstants.STATEMENT_FAILED_METRIC) } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 3582bcf09..f315dc836 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -88,11 +88,11 @@ case class JobOperator( } try { - // Wait for streaming job complete if no error and there is streaming job running - if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) { + // Wait for streaming job complete if no error + if (!exceptionThrown && streaming) { // Clean Spark shuffle data after each microBatch. spark.streams.addListener(new ShuffleCleaner(spark)) - // Await streaming job thread to finish before the main thread terminates + // Await index monitor before the main thread terminates new FlintSpark(spark).flintIndexMonitor.awaitMonitor() } else { logInfo(s""" diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 1a6aea4f4..546cd8e97 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -22,8 +22,8 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse -import org.opensearch.flint.app.FlintCommand import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} +import org.opensearch.flint.data.FlintStatement import org.opensearch.search.sort.SortOrder import org.scalatestplus.mockito.MockitoSugar @@ -245,7 +245,7 @@ class FlintREPLTest val expected = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) - val flintCommand = new FlintCommand("failed", "select 1", "30", "10", currentTime, None) + val flintStatement = new FlintStatement("failed", "select 1", "30", "10", currentTime, None) try { FlintREPL.currentTimeProvider = new MockTimeProvider(currentTime) @@ -256,12 +256,12 @@ class FlintREPLTest spark, dataSourceName, error, - flintCommand, + flintStatement, "20", currentTime - queryRunTime) assertEqualDataframe(expected, result) - assert("failed" == flintCommand.state) - assert(error == flintCommand.error.get) + assert("failed" == flintStatement.state) + assert(error == flintStatement.error.get) } finally { spark.close() FlintREPL.currentTimeProvider = new RealTimeProvider() @@ -448,18 +448,18 @@ class FlintREPLTest exception.setErrorCode("AccessDeniedException") exception.setServiceName("AWSGlue") - val mockFlintCommand = mock[FlintCommand] + val mockFlintStatement = mock[FlintStatement] val expectedError = ( """{"Message":"Fail to read data from Glue. Cause: Access denied in AWS Glue service. Please check permissions. (Service: AWSGlue; """ + """Status Code: 400; Error Code: AccessDeniedException; Request ID: null; Proxy: null)",""" + """"ErrorSource":"AWSGlue","StatusCode":"400"}""" ) - val result = FlintREPL.processQueryException(exception, mockFlintCommand) + val result = FlintREPL.processQueryException(exception, mockFlintStatement) result shouldEqual expectedError - verify(mockFlintCommand).fail() - verify(mockFlintCommand).error = Some(expectedError) + verify(mockFlintStatement).fail() + verify(mockFlintStatement).error = Some(expectedError) assert(result == expectedError) } @@ -574,7 +574,7 @@ class FlintREPLTest test("executeAndHandle should handle TimeoutException properly") { val mockSparkSession = mock[SparkSession] - val mockFlintCommand = mock[FlintCommand] + val mockFlintStatement = mock[FlintStatement] val mockConf = mock[RuntimeConfig] when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) @@ -588,8 +588,8 @@ class FlintREPLTest val startTime = System.currentTimeMillis() val expectedDataFrame = mock[DataFrame] - when(mockFlintCommand.query).thenReturn("SELECT 1") - when(mockFlintCommand.submitTime).thenReturn(Instant.now().toEpochMilli()) + when(mockFlintStatement.query).thenReturn("SELECT 1") + when(mockFlintStatement.submitTime).thenReturn(Instant.now().toEpochMilli()) // When the `sql` method is called, execute the custom Answer that introduces a delay when(mockSparkSession.sql(any[String])).thenAnswer(new Answer[DataFrame] { override def answer(invocation: InvocationOnMock): DataFrame = { @@ -610,7 +610,7 @@ class FlintREPLTest val result = FlintREPL.executeAndHandle( mockSparkSession, - mockFlintCommand, + mockFlintStatement, dataSource, sessionId, executionContext, @@ -631,8 +631,8 @@ class FlintREPLTest when(mockSparkSession.conf).thenReturn(mockConf) when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key)) .thenReturn(FlintSparkConf.JOB_TYPE.defaultValue.get) - val flintCommand = - new FlintCommand( + val flintStatement = + new FlintStatement( "Running", "select * from default.http_logs limit1 1", "10", @@ -661,7 +661,7 @@ class FlintREPLTest val result = FlintREPL.executeAndHandle( mockSparkSession, - flintCommand, + flintStatement, dataSource, sessionId, executionContext, @@ -671,8 +671,8 @@ class FlintREPLTest // Verify that ParseException was caught and handled result should not be None // or result.isDefined shouldBe true - flintCommand.error should not be None - flintCommand.error.get should include("Syntax error:") + flintStatement.error should not be None + flintStatement.error.get should include("Syntax error:") } finally threadPool.shutdown() }