diff --git a/.github/workflows/test-and-build-workflow.yml b/.github/workflows/test-and-build-workflow.yml index 7cae33f76..3c06acb61 100644 --- a/.github/workflows/test-and-build-workflow.yml +++ b/.github/workflows/test-and-build-workflow.yml @@ -25,5 +25,8 @@ jobs: - name: Integ Test run: sbt integtest/test + - name: Unit Test + run: sbt test + - name: Style check run: sbt scalafmtCheckAll diff --git a/build.sbt b/build.sbt index bc018c265..9389384fd 100644 --- a/build.sbt +++ b/build.sbt @@ -58,7 +58,12 @@ lazy val flintCore = (project in file("flint-core")) "org.opensearch.client" % "opensearch-rest-high-level-client" % opensearchVersion exclude ("org.apache.logging.log4j", "log4j-api"), "com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided" - exclude ("com.fasterxml.jackson.core", "jackson-databind")), + exclude ("com.fasterxml.jackson.core", "jackson-databind"), + "org.scalactic" %% "scalactic" % "3.2.15" % "test", + "org.scalatest" %% "scalatest" % "3.2.15" % "test", + "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", + "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test", + "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test"), publish / skip := true) lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) diff --git a/docs/index.md b/docs/index.md index 44b0052b0..8afdc1fbc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -32,20 +32,17 @@ Currently, Flint metadata is only static configuration without version control a ```json { - "version": "0.1", - "indexConfig": { - "kind": "skipping", - "properties": { - "indexedColumns": [{ - "kind": "...", - "columnName": "...", - "columnType": "..." - }] - } - }, - "source": "alb_logs", - "state": "active", - "enabled": true + "version": "0.1.0", + "name": "...", + "kind": "skipping", + "source": "...", + "indexedColumns": [{ + "kind": "...", + "columnName": "...", + "columnType": "..." + }], + "options": { }, + "properties": { } } ``` @@ -199,6 +196,8 @@ User can provide the following options in `WITH` clause of create statement: + `checkpoint_location`: a string as the location path for incremental refresh job checkpoint. The location has to be a path in an HDFS compatible file system and only applicable when auto refresh enabled. If unspecified, temporary checkpoint directory will be used and may result in checkpoint data lost upon restart. + `index_settings`: a JSON string as index settings for OpenSearch index that will be created. Please follow the format in OpenSearch documentation. If unspecified, default OpenSearch index settings will be applied. +Note that the index option name is case-sensitive. + ```sql WITH ( auto_refresh = [true|false], diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index b4271360c..d50c0002e 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -6,6 +6,8 @@ package org.opensearch.flint.core; import java.util.List; + +import org.opensearch.client.RestHighLevelClient; import org.opensearch.flint.core.metadata.FlintMetadata; import org.opensearch.flint.core.storage.FlintReader; import org.opensearch.flint.core.storage.FlintWriter; @@ -71,4 +73,10 @@ public interface FlintClient { * @return {@link FlintWriter} */ FlintWriter createWriter(String indexName); + + /** + * Create {@link RestHighLevelClient}. 
+ * @return {@link RestHighLevelClient} + */ + public RestHighLevelClient createClient(); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintJsonHelper.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintJsonHelper.scala new file mode 100644 index 000000000..4c1991edc --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintJsonHelper.scala @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metadata + +import java.nio.charset.StandardCharsets.UTF_8 + +import org.opensearch.common.bytes.BytesReference +import org.opensearch.common.xcontent._ +import org.opensearch.common.xcontent.json.JsonXContent + +/** + * JSON parsing and building helper. + */ +object FlintJsonHelper { + + /** + * Build JSON by creating JSON builder and pass it to the given function. + * + * @param block + * building logic with JSON builder + * @return + * JSON string + */ + def buildJson(block: XContentBuilder => Unit): String = { + val builder: XContentBuilder = XContentFactory.jsonBuilder + builder.startObject + block(builder) + builder.endObject() + BytesReference.bytes(builder).utf8ToString + } + + /** + * Add an object field of the name to the JSON builder and continue building it with the given + * function. + * + * @param builder + * JSON builder + * @param name + * field name + * @param block + * building logic on the JSON field + */ + def objectField(builder: XContentBuilder, name: String)(block: => Unit): Unit = { + builder.startObject(name) + block + builder.endObject() + } + + /** + * Add an optional object field of the name to the JSON builder. Add an empty object field if + * the value is null. + * + * @param builder + * JSON builder + * @param name + * field name + * @param value + * field value + */ + def optionalObjectField(builder: XContentBuilder, name: String, value: AnyRef): Unit = { + if (value == null) { + builder.startObject(name).endObject() + } else { + builder.field(name, value) + } + } + + /** + * Create a XContent JSON parser on the given JSON string. + * + * @param json + * JSON string + * @return + * JSON parser + */ + def createJsonParser(json: String): XContentParser = { + JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + json.getBytes(UTF_8)) + } + + /** + * Parse the given JSON string by creating JSON parser and pass it to the parsing function. + * + * @param json + * JSON string + * @param block + * parsing logic with the parser + */ + def parseJson(json: String)(block: (XContentParser, String) => Unit): Unit = { + val parser = createJsonParser(json) + + // Read first root object token and start parsing + parser.nextToken() + parseObjectField(parser)(block) + } + + /** + * Parse each inner field in the object field with the given parsing function. + * + * @param parser + * JSON parser + * @param block + * parsing logic on each inner field + */ + def parseObjectField(parser: XContentParser)(block: (XContentParser, String) => Unit): Unit = { + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + val fieldName: String = parser.currentName() + parser.nextToken() // Move to the field value + + block(parser, fieldName) + } + } + + /** + * Parse each inner field in the array field. 
+ * + * @param parser + * JSON parser + * @param block + * parsing logic on each inner field + */ + def parseArrayField(parser: XContentParser)(block: => Unit): Unit = { + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + block + } + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.java b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.java deleted file mode 100644 index 6773c3897..000000000 --- a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.flint.core.metadata; - -/** - * Flint metadata follows Flint index specification and defines metadata - * for a Flint index regardless of query engine integration and storage. - */ -public class FlintMetadata { - - // TODO: define metadata format and create strong-typed class - private final String content; - - // TODO: piggyback optional index settings and will refactor as above - private String indexSettings; - - public FlintMetadata(String content) { - this.content = content; - } - - public FlintMetadata(String content, String indexSettings) { - this.content = content; - this.indexSettings = indexSettings; - } - - public String getContent() { - return content; - } - - public String getIndexSettings() { - return indexSettings; - } - - public void setIndexSettings(String indexSettings) { - this.indexSettings = indexSettings; - } -} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.scala b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.scala new file mode 100644 index 000000000..ea0fb0f98 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/metadata/FlintMetadata.scala @@ -0,0 +1,232 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metadata + +import java.util + +import org.opensearch.flint.core.FlintVersion +import org.opensearch.flint.core.FlintVersion.current +import org.opensearch.flint.core.metadata.FlintJsonHelper._ + +/** + * Flint metadata follows Flint index specification and defines metadata for a Flint index + * regardless of query engine integration and storage. + */ +case class FlintMetadata( + /** Flint spec version */ + version: FlintVersion, + /** Flint index name */ + name: String, + /** Flint index kind */ + kind: String, + /** Flint index source that index data derived from */ + source: String, + /** Flint indexed column list */ + indexedColumns: Array[util.Map[String, AnyRef]] = Array(), + /** Flint indexed options. TODO: move to properties? */ + options: util.Map[String, AnyRef] = new util.HashMap[String, AnyRef], + /** Flint index properties for any custom fields */ + properties: util.Map[String, AnyRef] = new util.HashMap[String, AnyRef], + /** Flint index schema */ + schema: util.Map[String, AnyRef] = new util.HashMap[String, AnyRef], + /** Optional Flint index settings. TODO: move elsewhere? */ + indexSettings: Option[String]) { + + require(version != null, "version is required") + require(name != null, "name is required") + require(kind != null, "kind is required") + require(source != null, "source is required") + + /** + * Generate JSON content as index metadata. 
+ * + * @return + * JSON content + */ + def getContent: String = { + try { + buildJson(builder => { + // Add _meta field + objectField(builder, "_meta") { + builder + .field("version", version.version) + .field("name", name) + .field("kind", kind) + .field("source", source) + .field("indexedColumns", indexedColumns) + + optionalObjectField(builder, "options", options) + optionalObjectField(builder, "properties", properties) + } + + // Add properties (schema) field + builder.field("properties", schema) + }) + } catch { + case e: Exception => + throw new IllegalStateException("Failed to jsonify Flint metadata", e) + } + } +} + +object FlintMetadata { + + /** + * Construct Flint metadata with JSON content and index settings. + * + * @param content + * JSON content + * @param settings + * index settings + * @return + * Flint metadata + */ + def apply(content: String, settings: String): FlintMetadata = { + val metadata = FlintMetadata(content) + metadata.copy(indexSettings = Option(settings)) + } + + /** + * Parse the given JSON content and construct Flint metadata class. + * + * @param content + * JSON content + * @return + * Flint metadata + */ + def apply(content: String): FlintMetadata = { + try { + val builder = new FlintMetadata.Builder() + parseJson(content) { (parser, fieldName) => + { + fieldName match { + case "_meta" => + parseObjectField(parser) { (parser, innerFieldName) => + { + innerFieldName match { + case "version" => builder.version(FlintVersion.apply(parser.text())) + case "name" => builder.name(parser.text()) + case "kind" => builder.kind(parser.text()) + case "source" => builder.source(parser.text()) + case "indexedColumns" => + parseArrayField(parser) { + builder.addIndexedColumn(parser.map()) + } + case "options" => builder.options(parser.map()) + case "properties" => builder.properties(parser.map()) + case _ => // Handle other fields as needed + } + } + } + case "properties" => + builder.schema(parser.map()) + } + } + } + builder.build() + } catch { + case e: Exception => + throw new IllegalStateException("Failed to parse metadata JSON", e) + } + } + + def builder(): FlintMetadata.Builder = new Builder + + /** + * Flint index metadata builder that can be extended by subclass to provide more custom build + * method. 
+ */ + class Builder { + private var version: FlintVersion = FlintVersion.current() + private var name: String = "" + private var kind: String = "" + private var source: String = "" + private var options: util.Map[String, AnyRef] = new util.HashMap[String, AnyRef]() + private var indexedColumns: Array[util.Map[String, AnyRef]] = Array() + private var properties: util.Map[String, AnyRef] = new util.HashMap[String, AnyRef]() + private var schema: util.Map[String, AnyRef] = new util.HashMap[String, AnyRef]() + private var indexSettings: Option[String] = None + + def version(version: FlintVersion): this.type = { + this.version = version + this + } + + def name(name: String): this.type = { + this.name = name + this + } + + def kind(kind: String): this.type = { + this.kind = kind + this + } + + def source(source: String): this.type = { + this.source = source + this + } + + def options(options: util.Map[String, AnyRef]): this.type = { + this.options = options + this + } + + def indexedColumns(indexedColumns: Array[util.Map[String, AnyRef]]): this.type = { + this.indexedColumns = indexedColumns + this + } + + def addIndexedColumn(indexCol: util.Map[String, AnyRef]): this.type = { + indexedColumns = indexedColumns :+ indexCol + this + } + + def properties(properties: util.Map[String, AnyRef]): this.type = { + this.properties = properties + this + } + + def addProperty(key: String, value: AnyRef): this.type = { + properties.put(key, value) + this + } + + def schema(schema: util.Map[String, AnyRef]): this.type = { + this.schema = schema + this + } + + def schema(schema: String): this.type = { + parseJson(schema) { (parser, fieldName) => + fieldName match { + case "properties" => this.schema = parser.map() + case _ => // do nothing + } + } + this + } + + def indexSettings(indexSettings: String): this.type = { + this.indexSettings = Option(indexSettings) + this + } + + // Build method to create the FlintMetadata instance + def build(): FlintMetadata = { + FlintMetadata( + if (version == null) current() else version, + name, + kind, + source, + indexedColumns, + options, + properties, + schema, + indexSettings) + } + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index b973385d8..ff2761856 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -47,6 +47,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; +import scala.Option; /** * Flint client implementation for OpenSearch storage. 
@@ -73,8 +74,9 @@ public FlintOpenSearchClient(FlintOptions options) { CreateIndexRequest request = new CreateIndexRequest(osIndexName); request.mapping(metadata.getContent(), XContentType.JSON); - if (metadata.getIndexSettings() != null) { - request.settings(metadata.getIndexSettings(), XContentType.JSON); + Option settings = metadata.indexSettings(); + if (settings.isDefined()) { + request.settings(settings.get(), XContentType.JSON); } client.indices().create(request, RequestOptions.DEFAULT); } catch (Exception e) { @@ -98,7 +100,7 @@ public FlintOpenSearchClient(FlintOptions options) { GetIndexResponse response = client.indices().get(request, RequestOptions.DEFAULT); return Arrays.stream(response.getIndices()) - .map(index -> new FlintMetadata( + .map(index -> FlintMetadata.apply( response.getMappings().get(index).source().toString(), response.getSettings().get(index).toString())) .collect(Collectors.toList()); @@ -115,7 +117,7 @@ public FlintOpenSearchClient(FlintOptions options) { MappingMetadata mapping = response.getMappings().get(osIndexName); Settings settings = response.getSettings().get(osIndexName); - return new FlintMetadata(mapping.source().string(), settings.toString()); + return FlintMetadata.apply(mapping.source().string(), settings.toString()); } catch (Exception e) { throw new IllegalStateException("Failed to get Flint index metadata for " + osIndexName, e); } @@ -161,7 +163,7 @@ public FlintWriter createWriter(String indexName) { return new OpenSearchWriter(createClient(), toLowercase(indexName), options.getRefreshPolicy()); } - private RestHighLevelClient createClient() { + @Override public RestHighLevelClient createClient() { RestClientBuilder restClientBuilder = RestClient.builder(new HttpHost(options.getHost(), options.getPort(), options.getScheme())); diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataSuite.scala b/flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataSuite.scala new file mode 100644 index 000000000..dc2f5fe6a --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/metadata/FlintMetadataSuite.scala @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metadata + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson +import org.opensearch.flint.core.FlintVersion.current +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class FlintMetadataSuite extends AnyFlatSpec with Matchers { + + /** Test Flint index meta JSON string */ + val testMetadataJson: String = s""" + | { + | "_meta": { + | "version": "${current()}", + | "name": "test_index", + | "kind": "test_kind", + | "source": "test_source_table", + | "indexedColumns": [ + | { + | "test_field": "spark_type" + | }], + | "options": {}, + | "properties": {} + | }, + | "properties": { + | "test_field": { + | "type": "os_type" + | } + | } + | } + |""".stripMargin + + val testIndexSettingsJson: String = + """ + | { "number_of_shards": 3 } + |""".stripMargin + + "constructor" should "deserialize the given JSON and assign parsed value to field" in { + val metadata = FlintMetadata(testMetadataJson, testIndexSettingsJson) + + metadata.version shouldBe current() + metadata.name shouldBe "test_index" + metadata.kind shouldBe "test_kind" + metadata.source shouldBe "test_source_table" + metadata.indexedColumns shouldBe Array(Map("test_field" 
-> "spark_type").asJava) + metadata.schema shouldBe Map("test_field" -> Map("type" -> "os_type").asJava).asJava + } + + "getContent" should "serialize all fields to JSON" in { + val builder = new FlintMetadata.Builder + builder.name("test_index") + builder.kind("test_kind") + builder.source("test_source_table") + builder.addIndexedColumn(Map[String, AnyRef]("test_field" -> "spark_type").asJava); + builder.schema("""{"properties": {"test_field": {"type": "os_type"}}}""") + + val metadata = builder.build() + metadata.getContent should matchJson(testMetadataJson) + } +} diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala index b848f47b4..0bac6ac73 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala @@ -5,6 +5,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog._ /** @@ -12,6 +13,20 @@ import org.apache.spark.sql.connector.catalog._ */ package object flint { + /** + * Convert the given logical plan to Spark data frame. + * + * @param spark + * Spark session + * @param logicalPlan + * logical plan + * @return + * data frame + */ + def logicalPlanToDataFrame(spark: SparkSession, logicalPlan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, logicalPlan) + } + /** * Qualify a given table name. * diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index 4a4885ecb..9c78a07f8 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -7,31 +7,22 @@ package org.opensearch.flint.spark import scala.collection.JavaConverters._ -import org.json4s.{Formats, JArray, NoTypeHints} -import org.json4s.JsonAST.{JField, JObject} -import org.json4s.native.JsonMethods.parse +import org.json4s.{Formats, NoTypeHints} import org.json4s.native.Serialization import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} -import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode} -import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN +import org.opensearch.flint.spark.FlintSparkIndex.{ID_COLUMN, StreamingRefresh} import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex -import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex -import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE -import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.{SkippingKind, SkippingKindSerializer} -import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET} -import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy -import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy -import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer -import 
org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.SaveMode._ import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN} -import org.apache.spark.sql.streaming.OutputMode.Append -import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger} /** * Flint Spark integration API entrypoint. @@ -71,6 +62,16 @@ class FlintSpark(val spark: SparkSession) { new FlintSparkCoveringIndex.Builder(this) } + /** + * Create materialized view builder for creating mv with fluent API. + * + * @return + * mv builder + */ + def materializedView(): FlintSparkMaterializedView.Builder = { + new FlintSparkMaterializedView.Builder(this) + } + /** * Create the given index with metadata. * @@ -87,7 +88,6 @@ class FlintSpark(val spark: SparkSession) { } } else { val metadata = index.metadata() - index.options.indexSettings().foreach(metadata.setIndexSettings) flintClient.createIndex(indexName, metadata) } } @@ -105,12 +105,13 @@ class FlintSpark(val spark: SparkSession) { def refreshIndex(indexName: String, mode: RefreshMode): Option[String] = { val index = describeIndex(indexName) .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist")) - val tableName = getSourceTableName(index) + val options = index.options + val tableName = index.metadata().source - // Write Flint index data to Flint data source (shared by both refresh modes for now) - def writeFlintIndex(df: DataFrame): Unit = { + // Batch refresh Flint index from the given source data frame + def batchRefresh(df: Option[DataFrame] = None): Unit = { index - .build(df) + .build(spark, df) .write .format(FLINT_DATASOURCE) .options(flintSparkConf.properties) @@ -122,36 +123,37 @@ class FlintSpark(val spark: SparkSession) { case FULL if isIncrementalRefreshing(indexName) => throw new IllegalStateException( s"Index $indexName is incremental refreshing and cannot be manual refreshed") + case FULL => - writeFlintIndex( - spark.read - .table(tableName)) + batchRefresh() None + // Flint index has specialized logic and capability for incremental refresh + case INCREMENTAL if index.isInstanceOf[StreamingRefresh] => + val job = + index + .asInstanceOf[StreamingRefresh] + .buildStream(spark) + .writeStream + .queryName(indexName) + .format(FLINT_DATASOURCE) + .options(flintSparkConf.properties) + .addIndexOptions(options) + .start(indexName) + Some(job.id.toString) + + // Otherwise, fall back to foreachBatch + batch refresh case INCREMENTAL => - // TODO: Use Foreach sink for now. Need to move this to FlintSparkSkippingIndex - // once finalized. Otherwise, covering index/MV may have different logic. 
val job = spark.readStream .table(tableName) .writeStream .queryName(indexName) - .outputMode(Append()) - - index.options - .checkpointLocation() - .foreach(location => job.option("checkpointLocation", location)) - index.options - .refreshInterval() - .foreach(interval => job.trigger(Trigger.ProcessingTime(interval))) - - val jobId = - job - .foreachBatch { (batchDF: DataFrame, _: Long) => - writeFlintIndex(batchDF) - } - .start() - .id - Some(jobId.toString) + .addIndexOptions(options) + .foreachBatch { (batchDF: DataFrame, _: Long) => + batchRefresh(Some(batchDF)) + } + .start() + Some(job.id.toString) } } @@ -164,7 +166,10 @@ class FlintSpark(val spark: SparkSession) { * Flint index list */ def describeIndexes(indexNamePattern: String): Seq[FlintSparkIndex] = { - flintClient.getAllIndexMetadata(indexNamePattern).asScala.map(deserialize) + flintClient + .getAllIndexMetadata(indexNamePattern) + .asScala + .map(FlintSparkIndexFactory.create) } /** @@ -178,7 +183,8 @@ class FlintSpark(val spark: SparkSession) { def describeIndex(indexName: String): Option[FlintSparkIndex] = { if (flintClient.exists(indexName)) { val metadata = flintClient.getIndexMetadata(indexName) - Some(deserialize(metadata)) + val index = FlintSparkIndexFactory.create(metadata) + Some(index) } else { Option.empty } @@ -224,60 +230,29 @@ class FlintSpark(val spark: SparkSession) { } } - // TODO: Remove all parsing logic below once Flint spec finalized and FlintMetadata strong typed - private def getSourceTableName(index: FlintSparkIndex): String = { - val json = parse(index.metadata().getContent) - (json \ "_meta" \ "source").extract[String] - } + // Using Scala implicit class to avoid breaking method chaining of Spark data frame fluent API + private implicit class FlintDataStreamWriter(val dataStream: DataStreamWriter[Row]) { - /* - * For now, deserialize skipping strategies out of Flint metadata json - * ex. 
extract Seq(Partition("year", "int"), ValueList("name")) from - * { "_meta": { "indexedColumns": [ {...partition...}, {...value list...} ] } } - * - */ - private def deserialize(metadata: FlintMetadata): FlintSparkIndex = { - val meta = parse(metadata.getContent) \ "_meta" - val indexName = (meta \ "name").extract[String] - val tableName = (meta \ "source").extract[String] - val indexType = (meta \ "kind").extract[String] - val indexedColumns = (meta \ "indexedColumns").asInstanceOf[JArray] - val indexOptions = FlintSparkIndexOptions( - (meta \ "options") - .asInstanceOf[JObject] - .obj - .map { case JField(key, value) => - key -> value.values.toString - } - .toMap) + def addIndexOptions(options: FlintSparkIndexOptions): DataStreamWriter[Row] = { + dataStream + .addCheckpointLocation(options.checkpointLocation()) + .addRefreshInterval(options.refreshInterval()) + } - indexType match { - case SKIPPING_INDEX_TYPE => - val strategies = indexedColumns.arr.map { colInfo => - val skippingKind = SkippingKind.withName((colInfo \ "kind").extract[String]) - val columnName = (colInfo \ "columnName").extract[String] - val columnType = (colInfo \ "columnType").extract[String] + def addCheckpointLocation(checkpointLocation: Option[String]): DataStreamWriter[Row] = { + if (checkpointLocation.isDefined) { + dataStream.option("checkpointLocation", checkpointLocation.get) + } else { + dataStream + } + } - skippingKind match { - case PARTITION => - PartitionSkippingStrategy(columnName = columnName, columnType = columnType) - case VALUE_SET => - ValueSetSkippingStrategy(columnName = columnName, columnType = columnType) - case MIN_MAX => - MinMaxSkippingStrategy(columnName = columnName, columnType = columnType) - case other => - throw new IllegalStateException(s"Unknown skipping strategy: $other") - } - } - new FlintSparkSkippingIndex(tableName, strategies, indexOptions) - case COVERING_INDEX_TYPE => - new FlintSparkCoveringIndex( - indexName, - tableName, - indexedColumns.arr.map { obj => - ((obj \ "columnName").extract[String], (obj \ "columnType").extract[String]) - }.toMap, - indexOptions) + def addRefreshInterval(refreshInterval: Option[String]): DataStreamWriter[Row] = { + if (refreshInterval.isDefined) { + dataStream.trigger(Trigger.ProcessingTime(refreshInterval.get)) + } else { + dataStream + } } } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index a19e603dc..0586bfc49 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -5,9 +5,13 @@ package org.opensearch.flint.spark +import scala.collection.JavaConverters.mapAsJavaMapConverter + import org.opensearch.flint.core.metadata.FlintMetadata -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.flint.datatype.FlintDataType +import org.apache.spark.sql.types.StructType /** * Flint index interface in Spark. @@ -40,16 +44,36 @@ trait FlintSparkIndex { * Build a data frame to represent index data computation logic. Upper level code decides how to * use this, ex. batch or streaming, fully or incremental refresh. * + * @param spark + * Spark session for implementation class to use as needed * @param df - * data frame to append building logic + * data frame to append building logic. 
If none, implementation class create source data frame + * on its own * @return * index building data frame */ - def build(df: DataFrame): DataFrame + def build(spark: SparkSession, df: Option[DataFrame]): DataFrame } object FlintSparkIndex { + /** + * Interface indicates a Flint index has custom streaming refresh capability other than foreach + * batch streaming. + */ + trait StreamingRefresh { + + /** + * Build streaming refresh data frame. + * + * @param spark + * Spark session + * @return + * data frame represents streaming logic + */ + def buildStream(spark: SparkSession): DataFrame + } + /** * ID column name. */ @@ -81,4 +105,42 @@ object FlintSparkIndex { .map(value => key -> value)) .toMap } + + /** + * Create Flint metadata builder with common fields. + * + * @param index + * Flint index + * @return + * Flint metadata builder + */ + def metadataBuilder(index: FlintSparkIndex): FlintMetadata.Builder = { + val builder = new FlintMetadata.Builder() + // Common fields + builder.kind(index.kind) + builder.options(index.options.optionsWithDefault.mapValues(_.asInstanceOf[AnyRef]).asJava) + + // Index properties + val envs = populateEnvToMetadata + if (envs.nonEmpty) { + builder.addProperty("env", envs.asJava) + } + + // Optional index settings + val settings = index.options.indexSettings() + if (settings.isDefined) { + builder.indexSettings(settings.get) + } + builder + } + + def generateSchemaJSON(allFieldTypes: Map[String, String]): String = { + val catalogDDL = + allFieldTypes + .map { case (colName, colType) => s"$colName $colType not null" } + .mkString(",") + + val structType = StructType.fromDDL(catalogDDL) + FlintDataType.serialize(structType) + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala new file mode 100644 index 000000000..cda11405c --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import scala.collection.JavaConverters.mapAsScalaMapConverter + +import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET} +import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy +import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy +import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy + +/** + * Flint Spark index factory that encapsulates specific Flint index instance creation. This is for + * internal code use instead of user facing API. + */ +object FlintSparkIndexFactory { + + /** + * Creates Flint index from generic Flint metadata. 
+ * + * @param metadata + * Flint metadata + * @return + * Flint index + */ + def create(metadata: FlintMetadata): FlintSparkIndex = { + val indexOptions = FlintSparkIndexOptions( + metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) + + // Convert generic Map[String,AnyRef] in metadata to specific data structure in Flint index + metadata.kind match { + case SKIPPING_INDEX_TYPE => + val strategies = metadata.indexedColumns.map { colInfo => + val skippingKind = SkippingKind.withName(getString(colInfo, "kind")) + val columnName = getString(colInfo, "columnName") + val columnType = getString(colInfo, "columnType") + + skippingKind match { + case PARTITION => + PartitionSkippingStrategy(columnName = columnName, columnType = columnType) + case VALUE_SET => + ValueSetSkippingStrategy(columnName = columnName, columnType = columnType) + case MIN_MAX => + MinMaxSkippingStrategy(columnName = columnName, columnType = columnType) + case other => + throw new IllegalStateException(s"Unknown skipping strategy: $other") + } + } + FlintSparkSkippingIndex(metadata.source, strategies, indexOptions) + case COVERING_INDEX_TYPE => + FlintSparkCoveringIndex( + metadata.name, + metadata.source, + metadata.indexedColumns.map { colInfo => + getString(colInfo, "columnName") -> getString(colInfo, "columnType") + }.toMap, + indexOptions) + case MV_INDEX_TYPE => + FlintSparkMaterializedView( + metadata.name, + metadata.source, + metadata.indexedColumns.map { colInfo => + getString(colInfo, "columnName") -> getString(colInfo, "columnType") + }.toMap, + indexOptions) + } + } + + private def getString(map: java.util.Map[String, AnyRef], key: String): String = { + map.get(key).asInstanceOf[String] + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala index c6f546605..b3e7535c3 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala @@ -5,6 +5,9 @@ package org.opensearch.flint.spark +import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, INDEX_SETTINGS, OptionName, REFRESH_INTERVAL} +import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames + /** * Flint Spark index configurable options. * @@ -13,13 +16,15 @@ package org.opensearch.flint.spark */ case class FlintSparkIndexOptions(options: Map[String, String]) { + validateOptionNames(options) + /** * Is Flint index auto refreshed or manual refreshed. * * @return * auto refresh option value */ - def autoRefresh(): Boolean = options.getOrElse("auto_refresh", "false").toBoolean + def autoRefresh(): Boolean = getOptionValue(AUTO_REFRESH).getOrElse("false").toBoolean /** * The refresh interval (only valid if auto refresh enabled). @@ -27,7 +32,7 @@ case class FlintSparkIndexOptions(options: Map[String, String]) { * @return * refresh interval expression */ - def refreshInterval(): Option[String] = options.get("refresh_interval") + def refreshInterval(): Option[String] = getOptionValue(REFRESH_INTERVAL) /** * The checkpoint location which maybe required by Flint index's refresh. 
@@ -35,7 +40,7 @@ case class FlintSparkIndexOptions(options: Map[String, String]) { * @return * checkpoint location path */ - def checkpointLocation(): Option[String] = options.get("checkpoint_location") + def checkpointLocation(): Option[String] = getOptionValue(CHECKPOINT_LOCATION) /** * The index settings for OpenSearch index created. @@ -43,7 +48,25 @@ case class FlintSparkIndexOptions(options: Map[String, String]) { * @return * index setting JSON */ - def indexSettings(): Option[String] = options.get("index_settings") + def indexSettings(): Option[String] = getOptionValue(INDEX_SETTINGS) + + /** + * @return + * all option values and fill default value if unspecified + */ + def optionsWithDefault: Map[String, String] = { + val map = Map.newBuilder[String, String] + map ++= options + + if (!options.contains(AUTO_REFRESH.toString)) { + map += (AUTO_REFRESH.toString -> autoRefresh().toString) + } + map.result() + } + + private def getOptionValue(name: OptionName): Option[String] = { + options.get(name.toString) + } } object FlintSparkIndexOptions { @@ -52,4 +75,28 @@ object FlintSparkIndexOptions { * Empty options */ val empty: FlintSparkIndexOptions = FlintSparkIndexOptions(Map.empty) + + /** + * Option name Enum. + */ + object OptionName extends Enumeration { + type OptionName = Value + val AUTO_REFRESH: OptionName.Value = Value("auto_refresh") + val REFRESH_INTERVAL: OptionName.Value = Value("refresh_interval") + val CHECKPOINT_LOCATION: OptionName.Value = Value("checkpoint_location") + val INDEX_SETTINGS: OptionName.Value = Value("index_settings") + } + + /** + * Validate option names and throw exception if any unknown found. + * + * @param options + * options given + */ + def validateOptionNames(options: Map[String, String]): Unit = { + val allOptions = OptionName.values.map(_.toString) + val invalidOptions = options.keys.filterNot(allOptions.contains) + + require(invalidOptions.isEmpty, s"option name ${invalidOptions.mkString(",")} is invalid") + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index 3db325c3e..e9c2b5be5 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -5,19 +5,15 @@ package org.opensearch.flint.spark.covering -import org.json4s.{Formats, NoTypeHints} -import org.json4s.JsonAST.{JArray, JObject, JString} -import org.json4s.native.JsonMethods.{compact, parse, render} -import org.json4s.native.Serialization +import scala.collection.JavaConverters.mapAsJavaMapConverter + import org.opensearch.flint.core.metadata.FlintMetadata -import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, populateEnvToMetadata} +import org.opensearch.flint.spark._ +import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE} -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.flint.datatype.FlintDataType -import org.apache.spark.sql.types.StructType +import 
org.apache.spark.sql._ /** * Flint covering index in Spark. @@ -38,61 +34,30 @@ case class FlintSparkCoveringIndex( require(indexedColumns.nonEmpty, "indexed columns must not be empty") - /** Required by json4s write function */ - implicit val formats: Formats = Serialization.formats(NoTypeHints) - override val kind: String = COVERING_INDEX_TYPE override def name(): String = getFlintIndexName(indexName, tableName) override def metadata(): FlintMetadata = { - new FlintMetadata(s"""{ - | "_meta": { - | "name": "$indexName", - | "kind": "$kind", - | "indexedColumns": $getMetaInfo, - | "source": "$tableName", - | "options": $getIndexOptions, - | "properties": $getIndexProperties - | }, - | "properties": $getSchema - | } - |""".stripMargin) - } - - override def build(df: DataFrame): DataFrame = { - val colNames = indexedColumns.keys.toSeq - df.select(colNames.head, colNames.tail: _*) - } - - // TODO: refactor all these once Flint metadata spec finalized - private def getMetaInfo: String = { - val objects = indexedColumns.map { case (colName, colType) => - JObject("columnName" -> JString(colName), "columnType" -> JString(colType)) - }.toList - Serialization.write(JArray(objects)) - } - - private def getIndexOptions: String = { - Serialization.write(options.options) - } - - private def getIndexProperties: String = { - val envMap = populateEnvToMetadata - if (envMap.isEmpty) { - "{}" - } else { - s"""{ "env": ${Serialization.write(envMap)} }""" + val indexColumnMaps = { + indexedColumns.map { case (colName, colType) => + Map[String, AnyRef]("columnName" -> colName, "columnType" -> colType).asJava + }.toArray } + val schemaJson = generateSchemaJSON(indexedColumns) + + metadataBuilder(this) + .name(indexName) + .source(tableName) + .indexedColumns(indexColumnMaps) + .schema(schemaJson) + .build() } - private def getSchema: String = { - val catalogDDL = - indexedColumns - .map { case (colName, colType) => s"$colName $colType not null" } - .mkString(",") - val properties = FlintDataType.serialize(StructType.fromDDL(catalogDDL)) - compact(render(parse(properties) \ "properties")) + override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { + val colNames = indexedColumns.keys.toSeq + df.getOrElse(spark.read.table(tableName)) + .select(colNames.head, colNames.tail: _*) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala new file mode 100644 index 000000000..ee58ec7f5 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -0,0 +1,194 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.mv + +import java.util.Locale + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} +import org.opensearch.flint.spark.FlintSparkIndex.{generateSchemaJSON, metadataBuilder, StreamingRefresh} +import org.opensearch.flint.spark.FlintSparkIndexOptions.empty +import org.opensearch.flint.spark.function.TumbleFunction +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE} + +import org.apache.spark.sql.{DataFrame, SparkSession} +import 
org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.flint.{logicalPlanToDataFrame, qualifyTableName} + +/** + * Flint materialized view in Spark. + * + * @param mvName + * MV name + * @param query + * source query that generates MV data + * @param outputSchema + * output schema + * @param options + * index options + */ +case class FlintSparkMaterializedView( + mvName: String, + query: String, + outputSchema: Map[String, String], + override val options: FlintSparkIndexOptions = empty) + extends FlintSparkIndex + with StreamingRefresh { + + /** TODO: add it to index option */ + private val watermarkDelay = "0 Minute" + + override val kind: String = MV_INDEX_TYPE + + override def name(): String = getFlintIndexName(mvName) + + override def metadata(): FlintMetadata = { + val indexColumnMaps = + outputSchema.map { case (colName, colType) => + Map[String, AnyRef]("columnName" -> colName, "columnType" -> colType).asJava + }.toArray + val schemaJson = generateSchemaJSON(outputSchema) + + metadataBuilder(this) + .name(mvName) + .source(query) + .indexedColumns(indexColumnMaps) + .schema(schemaJson) + .build() + } + + override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { + require(df.isEmpty, "materialized view doesn't support reading from other data frame") + + spark.sql(query) + } + + override def buildStream(spark: SparkSession): DataFrame = { + val batchPlan = spark.sql(query).queryExecution.logical + + /* + * Convert unresolved batch plan to streaming plan by: + * 1.Insert Watermark operator below Aggregate (required by Spark streaming) + * 2.Set isStreaming flag to true in Relation operator + */ + val streamingPlan = batchPlan transform { + case WindowingAggregate(agg, timeCol) => + agg.copy(child = watermark(timeCol, watermarkDelay, agg.child)) + + case relation: UnresolvedRelation if !relation.isStreaming => + relation.copy(isStreaming = true) + } + logicalPlanToDataFrame(spark, streamingPlan) + } + + private def watermark(timeCol: Attribute, delay: String, child: LogicalPlan) = { + EventTimeWatermark(timeCol, IntervalUtils.fromIntervalString(delay), child) + } + + /** + * Extractor that extract event time column out of Aggregate operator. 
+ */ + private object WindowingAggregate { + + def unapply(agg: Aggregate): Option[(Aggregate, Attribute)] = { + val winFuncs = agg.groupingExpressions.collect { + case func: UnresolvedFunction if isWindowingFunction(func) => + func + } + + if (winFuncs.size != 1) { + throw new IllegalStateException( + "A windowing function is required for streaming aggregation") + } + + // Assume first aggregate item must be time column + val winFunc = winFuncs.head + val timeCol = winFunc.arguments.head.asInstanceOf[Attribute] + Some(agg, timeCol) + } + + private def isWindowingFunction(func: UnresolvedFunction): Boolean = { + val funcName = func.nameParts.mkString(".").toLowerCase(Locale.ROOT) + val funcIdent = FunctionIdentifier(funcName) + + // TODO: support other window functions + funcIdent == TumbleFunction.identifier + } + } +} + +object FlintSparkMaterializedView { + + /** MV index type name */ + val MV_INDEX_TYPE = "mv" + + /** + * Get index name following the convention "flint_" + qualified MV name (replace dot with + * underscore). + * + * @param mvName + * MV name + * @return + * Flint index name + */ + def getFlintIndexName(mvName: String): String = { + require( + mvName.split("\\.").length >= 3, + "Qualified materialized view name catalog.database.mv is required") + + s"flint_${mvName.replace(".", "_")}" + } + + /** Builder class for MV build */ + class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { + private var mvName: String = "" + private var query: String = "" + + /** + * Set MV name. + * + * @param mvName + * MV name + * @return + * builder + */ + def name(mvName: String): Builder = { + this.mvName = qualifyTableName(flint.spark, mvName) + this + } + + /** + * Set MV query. + * + * @param query + * MV query + * @return + * builder + */ + def query(query: String): Builder = { + this.query = query + this + } + + override protected def buildIndex(): FlintSparkIndex = { + // TODO: change here and FlintDS class to support complex field type in future + val outputSchema = flint.spark + .sql(query) + .schema + .map { field => + field.name -> field.dataType.typeName + } + .toMap + FlintSparkMaterializedView(mvName, query, outputSchema, indexOptions) + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala index dd9cb6bdf..eb2075b63 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala @@ -5,25 +5,20 @@ package org.opensearch.flint.spark.skipping -import org.json4s._ -import org.json4s.native.JsonMethods._ -import org.json4s.native.Serialization -import org.opensearch.flint.core.FlintVersion +import scala.collection.JavaConverters.mapAsJavaMapConverter + import org.opensearch.flint.core.metadata.FlintMetadata -import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, populateEnvToMetadata, ID_COLUMN} +import org.opensearch.flint.spark._ +import org.opensearch.flint.spark.FlintSparkIndex._ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE} -import 
org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression -import org.apache.spark.sql.flint.datatype.FlintDataType import org.apache.spark.sql.functions.{col, input_file_name, sha1} -import org.apache.spark.sql.types.StructType /** * Flint skipping index in Spark. @@ -33,17 +28,14 @@ import org.apache.spark.sql.types.StructType * @param indexedColumns * indexed column list */ -class FlintSparkSkippingIndex( +case class FlintSparkSkippingIndex( tableName: String, - val indexedColumns: Seq[FlintSparkSkippingStrategy], + indexedColumns: Seq[FlintSparkSkippingStrategy], override val options: FlintSparkIndexOptions = empty) extends FlintSparkIndex { require(indexedColumns.nonEmpty, "indexed columns must not be empty") - /** Required by json4s write function */ - implicit val formats: Formats = Serialization.formats(NoTypeHints) + SkippingKindSerializer - /** Skipping index type */ override val kind: String = SKIPPING_INDEX_TYPE @@ -52,22 +44,30 @@ class FlintSparkSkippingIndex( } override def metadata(): FlintMetadata = { - new FlintMetadata(s"""{ - | "_meta": { - | "name": "${name()}", - | "version": "${FlintVersion.current()}", - | "kind": "$SKIPPING_INDEX_TYPE", - | "indexedColumns": $getMetaInfo, - | "source": "$tableName", - | "options": $getIndexOptions, - | "properties": $getIndexProperties - | }, - | "properties": $getSchema - | } - |""".stripMargin) + val indexColumnMaps = + indexedColumns + .map(col => + Map[String, AnyRef]( + "kind" -> col.kind.toString, + "columnName" -> col.columnName, + "columnType" -> col.columnType).asJava) + .toArray + + val fieldTypes = + indexedColumns + .flatMap(_.outputSchema()) + .toMap + (FILE_PATH_COLUMN -> "string") + val schemaJson = generateSchemaJSON(fieldTypes) + + metadataBuilder(this) + .name(name()) + .source(tableName) + .indexedColumns(indexColumnMaps) + .schema(schemaJson) + .build() } - override def build(df: DataFrame): DataFrame = { + override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { val outputNames = indexedColumns.flatMap(_.outputSchema().keys) val aggFuncs = indexedColumns.flatMap(_.getAggregators) @@ -77,40 +77,11 @@ class FlintSparkSkippingIndex( new Column(aggFunc.toAggregateExpression().as(name)) } - df.groupBy(input_file_name().as(FILE_PATH_COLUMN)) + df.getOrElse(spark.read.table(tableName)) + .groupBy(input_file_name().as(FILE_PATH_COLUMN)) .agg(namedAggFuncs.head, namedAggFuncs.tail: _*) .withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN))) } - - private def getMetaInfo: String = { - Serialization.write(indexedColumns) - } - - private def getIndexOptions: String = { - Serialization.write(options.options) - } - - private def getIndexProperties: String = { - val envMap = populateEnvToMetadata - if (envMap.isEmpty) { - "{}" - } else { - s"""{ "env": ${Serialization.write(envMap)} }""" - } - } - - private def getSchema: String = { - val allFieldTypes = - indexedColumns.flatMap(_.outputSchema()).toMap + (FILE_PATH_COLUMN -> "string") - val catalogDDL = - allFieldTypes - .map { case (colName, colType) => s"$colName $colType not null" } - .mkString(",") - val allFieldSparkTypes = StructType.fromDDL(catalogDDL) - 
// Convert StructType to {"properties": ...} and only need the properties value - val properties = FlintDataType.serialize(allFieldSparkTypes) - compact(render(parse(properties) \ "properties")) - } } object FlintSparkSkippingIndex { diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala new file mode 100644 index 000000000..160a4c9d3 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, INDEX_SETTINGS, REFRESH_INTERVAL} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.FlintSuite + +class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { + + test("should return lowercase name as option name") { + AUTO_REFRESH.toString shouldBe "auto_refresh" + REFRESH_INTERVAL.toString shouldBe "refresh_interval" + CHECKPOINT_LOCATION.toString shouldBe "checkpoint_location" + INDEX_SETTINGS.toString shouldBe "index_settings" + } + + test("should return specified option value") { + val options = FlintSparkIndexOptions( + Map( + "auto_refresh" -> "true", + "refresh_interval" -> "1 Minute", + "checkpoint_location" -> "s3://test/", + "index_settings" -> """{"number_of_shards": 3}""")) + + options.autoRefresh() shouldBe true + options.refreshInterval() shouldBe Some("1 Minute") + options.checkpointLocation() shouldBe Some("s3://test/") + options.indexSettings() shouldBe Some("""{"number_of_shards": 3}""") + } + + test("should return default option value if unspecified") { + val options = FlintSparkIndexOptions(Map.empty) + + options.autoRefresh() shouldBe false + options.refreshInterval() shouldBe empty + options.checkpointLocation() shouldBe empty + options.indexSettings() shouldBe empty + options.optionsWithDefault should contain("auto_refresh" -> "false") + } + + test("should return default option value if unspecified with specified value") { + val options = FlintSparkIndexOptions(Map("refresh_interval" -> "1 Minute")) + + options.optionsWithDefault shouldBe Map( + "auto_refresh" -> "false", + "refresh_interval" -> "1 Minute") + } + + test("should report error if any unknown option name") { + the[IllegalArgumentException] thrownBy + FlintSparkIndexOptions(Map("autoRefresh" -> "true")) + + the[IllegalArgumentException] thrownBy + FlintSparkIndexOptions(Map("AUTO_REFRESH" -> "true")) + + the[IllegalArgumentException] thrownBy { + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "indexSetting" -> "test")) + } should have message "requirement failed: option name indexSetting is invalid" + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala new file mode 100644 index 000000000..c28495c69 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.mv + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import 
org.opensearch.flint.spark.FlintSparkIndexOptions +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} +import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} +import org.scalatestplus.mockito.MockitoSugar.mock + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.dsl.expressions.{count, intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String + +/** + * This UT include test cases for building API which make use of real SparkSession. This is + * because SparkSession.sessionState is private val and hard to mock but it's required in + * logicalPlanToDataFrame() -> DataRows.of(). + */ +class FlintSparkMaterializedViewSuite extends FlintSuite { + + val testMvName = "spark_catalog.default.mv" + val testQuery = "SELECT 1" + + test("get name") { + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + mv.name() shouldBe "flint_spark_catalog_default_mv" + } + + test("should fail if get name with unqualified MV name") { + the[IllegalArgumentException] thrownBy + FlintSparkMaterializedView("mv", testQuery, Map.empty).name() + + the[IllegalArgumentException] thrownBy + FlintSparkMaterializedView("default.mv", testQuery, Map.empty).name() + } + + test("get metadata") { + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map("test_col" -> "integer")) + + val metadata = mv.metadata() + metadata.name shouldBe mv.mvName + metadata.kind shouldBe MV_INDEX_TYPE + metadata.source shouldBe "SELECT 1" + metadata.indexedColumns shouldBe Array( + Map("columnName" -> "test_col", "columnType" -> "integer").asJava) + metadata.schema shouldBe Map("test_col" -> Map("type" -> "integer").asJava).asJava + } + + test("get metadata with index options") { + val indexSettings = """{"number_of_shards": 2}""" + val indexOptions = + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "index_settings" -> indexSettings)) + val mv = FlintSparkMaterializedView( + testMvName, + testQuery, + Map("test_col" -> "integer"), + indexOptions) + + mv.metadata().options shouldBe Map( + "auto_refresh" -> "true", + "index_settings" -> indexSettings).asJava + mv.metadata().indexSettings shouldBe Some(indexSettings) + } + + test("build batch data frame") { + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + mv.build(spark, None).collect() shouldBe Array(Row(1)) + } + + test("should fail if build given other source data frame") { + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + the[IllegalArgumentException] thrownBy mv.build(spark, Some(mock[DataFrame])) + } + + test("build stream should insert watermark operator and replace batch relation") { + val testTable = "mv_build_test" + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + + val testQuery = + s""" + | SELECT + | 
window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '1 Minute') + |""".stripMargin + + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val actualPlan = mv.buildStream(spark).queryExecution.logical + assert( + actualPlan.sameSemantics( + streamingRelation(testTable) + .watermark($"time", "0 Minute") + .groupBy($"TUMBLE".function($"time", "1 Minute"))( + $"window.start" as "startTime", + count(1) as "count"))) + } + } + + test("build stream with filtering query") { + val testTable = "mv_build_test" + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + + val testQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | WHERE age > 30 + | GROUP BY TUMBLE(time, '1 Minute') + |""".stripMargin + + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val actualPlan = mv.buildStream(spark).queryExecution.logical + assert( + actualPlan.sameSemantics( + streamingRelation(testTable) + .where($"age" > 30) + .watermark($"time", "0 Minute") + .groupBy($"TUMBLE".function($"time", "1 Minute"))( + $"window.start" as "startTime", + count(1) as "count"))) + } + } + + test("build stream with non-aggregate query") { + val testTable = "mv_build_test" + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT name, age FROM $testTable WHERE age > 30", + Map.empty) + val actualPlan = mv.buildStream(spark).queryExecution.logical + + assert( + actualPlan.sameSemantics( + streamingRelation(testTable) + .where($"age" > 30) + .select($"name", $"age"))) + } + } + + test("build stream should fail if there is aggregation but no windowing function") { + val testTable = "mv_build_test" + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Map.empty) + + the[IllegalStateException] thrownBy + mv.buildStream(spark) + } + } +} + +/** + * Helper method that extends LogicalPlan with more methods by Scala implicit class. 
+ */ +object FlintSparkMaterializedViewSuite { + + def streamingRelation(tableName: String): UnresolvedRelation = { + UnresolvedRelation( + TableIdentifier(tableName), + CaseInsensitiveStringMap.empty(), + isStreaming = true) + } + + implicit class StreamingDslLogicalPlan(val logicalPlan: LogicalPlan) { + + def watermark(colName: Attribute, interval: String): DslLogicalPlan = { + EventTimeWatermark( + colName, + IntervalUtils.stringToInterval(UTF8String.fromString(interval)), + logicalPlan) + } + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala index b31e18480..d52c43842 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala @@ -5,11 +5,14 @@ package org.opensearch.flint.spark.skipping +import scala.collection.JavaConverters.mapAsJavaMapConverter + import org.json4s.native.JsonMethods.parse import org.mockito.Mockito.when import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN -import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COLUMN +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE} +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.scalatestplus.mockito.MockitoSugar.mock @@ -27,6 +30,25 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { index.name() shouldBe "flint_spark_catalog_default_test_skipping_index" } + test("get index metadata") { + val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) + when(indexCol.columnName).thenReturn("test_field") + when(indexCol.columnType).thenReturn("integer") + when(indexCol.outputSchema()).thenReturn(Map("test_field" -> "integer")) + val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) + + val metadata = index.metadata() + metadata.kind shouldBe SKIPPING_INDEX_TYPE + metadata.name shouldBe index.name() + metadata.source shouldBe testTable + metadata.indexedColumns shouldBe Array( + Map( + "kind" -> SkippingKind.PARTITION.toString, + "columnName" -> "test_field", + "columnType" -> "integer").asJava) + } + test("can build index building job with unique ID column") { val indexCol = mock[FlintSparkSkippingStrategy] when(indexCol.outputSchema()).thenReturn(Map("name" -> "string")) @@ -34,12 +56,13 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age") - val indexDf = index.build(df) + val indexDf = index.build(spark, Some(df)) indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN) } test("can build index for boolean column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("boolean_col" -> "boolean")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("boolean_col").expr))) @@ -59,6 +82,7 @@ class FlintSparkSkippingIndexSuite extends 
FlintSuite { test("can build index for string column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("string_col" -> "string")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("string_col").expr))) @@ -80,6 +104,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for varchar column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("varchar_col" -> "varchar(20)")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("varchar_col").expr))) @@ -99,6 +124,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for char column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("char_col" -> "char(20)")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("char_col").expr))) @@ -118,6 +144,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for long column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("long_col" -> "bigint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("long_col").expr))) @@ -137,6 +164,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for int column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("int_col" -> "int")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("int_col").expr))) @@ -156,6 +184,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for short column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("short_col" -> "smallint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("short_col").expr))) @@ -175,6 +204,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for byte column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("byte_col" -> "tinyint")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("byte_col").expr))) @@ -194,6 +224,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for double column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("double_col" -> "double")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("double_col").expr))) @@ -213,6 +244,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for float column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("float_col" -> "float")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("float_col").expr))) @@ -232,6 +264,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for timestamp column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) 
when(indexCol.outputSchema()).thenReturn(Map("timestamp_col" -> "timestamp")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("timestamp_col").expr))) @@ -252,6 +285,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for date column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()).thenReturn(Map("date_col" -> "date")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("date_col").expr))) @@ -272,6 +306,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { test("can build index for struct column") { val indexCol = mock[FlintSparkSkippingStrategy] + when(indexCol.kind).thenReturn(SkippingKind.PARTITION) when(indexCol.outputSchema()) .thenReturn(Map("struct_col" -> "struct")) when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("struct_col").expr))) diff --git a/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala index 9d34b6f2a..5c799128c 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala @@ -7,10 +7,10 @@ package org.opensearch.flint.core import scala.collection.JavaConverters._ -import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods.parse import org.json4s.native.Serialization +import org.mockito.Mockito.when import org.opensearch.client.json.jackson.JacksonJsonpMapper import org.opensearch.client.opensearch.OpenSearchClient import org.opensearch.client.transport.rest_client.RestClientTransport @@ -19,6 +19,7 @@ import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.core.storage.FlintOpenSearchClient import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY @@ -34,7 +35,7 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M val content = """ { | "_meta": { - | "kind": "SkippingIndex" + | "kind": "test_kind" | }, | "properties": { | "age": { @@ -43,41 +44,51 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M | } | } |""".stripMargin - flintClient.createIndex(indexName, new FlintMetadata(content)) + + val metadata = mock[FlintMetadata] + when(metadata.getContent).thenReturn(content) + when(metadata.indexSettings).thenReturn(None) + flintClient.createIndex(indexName, metadata) flintClient.exists(indexName) shouldBe true - flintClient.getIndexMetadata(indexName).getContent should matchJson(content) + flintClient.getIndexMetadata(indexName).kind shouldBe "test_kind" } it should "create index with settings" in { val indexName = "flint_test_with_settings" val indexSettings = "{\"number_of_shards\": 3,\"number_of_replicas\": 2}" - flintClient.createIndex(indexName, new FlintMetadata("{}", indexSettings)) + val metadata = mock[FlintMetadata] + when(metadata.getContent).thenReturn("{}") + when(metadata.indexSettings).thenReturn(Some(indexSettings)) + flintClient.createIndex(indexName, metadata) flintClient.exists(indexName) shouldBe true // OS uses full setting name ("index" prefix) and store as string implicit val formats: Formats = 
Serialization.formats(NoTypeHints) - val settings = parse(flintClient.getIndexMetadata(indexName).getIndexSettings) + val settings = parse(flintClient.getIndexMetadata(indexName).indexSettings.get) (settings \ "index.number_of_shards").extract[String] shouldBe "3" (settings \ "index.number_of_replicas").extract[String] shouldBe "2" } it should "get all index metadata with the given index name pattern" in { - flintClient.createIndex("flint_test_1_index", new FlintMetadata("{}")) - flintClient.createIndex("flint_test_2_index", new FlintMetadata("{}")) + val metadata = mock[FlintMetadata] + when(metadata.getContent).thenReturn("{}") + when(metadata.indexSettings).thenReturn(None) + flintClient.createIndex("flint_test_1_index", metadata) + flintClient.createIndex("flint_test_2_index", metadata) val allMetadata = flintClient.getAllIndexMetadata("flint_*_index") allMetadata should have size 2 - allMetadata.forEach(metadata => metadata.getContent shouldBe "{}") - allMetadata.forEach(metadata => metadata.getIndexSettings should not be empty) + allMetadata.forEach(metadata => metadata.getContent should not be empty) + allMetadata.forEach(metadata => metadata.indexSettings should not be empty) } it should "convert index name to all lowercase" in { val indexName = "flint_ELB_logs_index" flintClient.createIndex( indexName, - new FlintMetadata("""{"properties": {"test": { "type": "integer" } } }""")) + FlintMetadata("""{"properties": {"test": { "type": "integer" } } }""")) flintClient.exists(indexName) shouldBe true flintClient.getIndexMetadata(indexName) should not be null diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala index ac0b33746..a4b0069dd 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala @@ -6,6 +6,7 @@ package org.opensearch.flint.spark import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson +import org.opensearch.flint.core.FlintVersion.current import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL} import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName import org.scalatest.matchers.must.Matchers.defined @@ -45,6 +46,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { index shouldBe defined index.get.metadata().getContent should matchJson(s"""{ | "_meta": { + | "version": "${current()}", | "name": "name_and_age", | "kind": "covering", | "indexedColumns": [ @@ -57,7 +59,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { | "columnType": "int" | }], | "source": "spark_catalog.default.ci_test", - | "options": {}, + | "options": { "auto_refresh": "false" }, | "properties": {} | }, | "properties": { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index 892a8faa4..627e11f52 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -16,7 +16,7 @@ import org.opensearch.flint.core.storage.FlintOpenSearchClient import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName import 
org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.scalatest.matchers.must.Matchers.defined -import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} import org.apache.spark.sql.Row @@ -58,7 +58,7 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { indexData.count() shouldBe 2 } - test("create skipping index with streaming job options") { + test("create covering index with streaming job options") { withTempDir { checkpointDir => sql(s""" | CREATE INDEX $testIndex ON $testTable ( name ) @@ -77,7 +77,7 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { } } - test("create skipping index with index settings") { + test("create covering index with index settings") { sql(s""" | CREATE INDEX $testIndex ON $testTable ( name ) | WITH ( @@ -89,11 +89,20 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { val flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)) implicit val formats: Formats = Serialization.formats(NoTypeHints) - val settings = parse(flintClient.getIndexMetadata(testFlintIndex).getIndexSettings) + val settings = parse(flintClient.getIndexMetadata(testFlintIndex).indexSettings.get) (settings \ "index.number_of_shards").extract[String] shouldBe "2" (settings \ "index.number_of_replicas").extract[String] shouldBe "3" } + test("create covering index with invalid option") { + the[IllegalArgumentException] thrownBy + sql(s""" + | CREATE INDEX $testIndex ON $testTable + | (name, age) + | WITH (autoRefresh = true) + | """.stripMargin) + } + test("create covering index with manual refresh") { sql(s""" | CREATE INDEX $testIndex ON $testTable diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala new file mode 100644 index 000000000..29ab433c6 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -0,0 +1,213 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import java.sql.Timestamp + +import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson +import org.opensearch.flint.core.FlintVersion.current +import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL} +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName +import org.scalatest.matchers.must.Matchers.defined +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper + +import org.apache.spark.sql.{DataFrame, Row} + +class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { + + /** Test table, MV, index name and query */ + private val testTable = "spark_catalog.default.mv_test" + private val testMvName = "spark_catalog.default.mv_test_metrics" + private val testFlintIndex = getFlintIndexName(testMvName) + private val testQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '10 Minutes') + |""".stripMargin + + override def beforeAll(): Unit = { + super.beforeAll() + createTimeSeriesTable(testTable) + } + + override def afterEach(): Unit = { + super.afterEach() + flint.deleteIndex(testFlintIndex) + } + + test("create materialized view with metadata successfully") { + val indexOptions = + 
FlintSparkIndexOptions(Map("auto_refresh" -> "true", "checkpoint_location" -> "s3://test/")) + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .options(indexOptions) + .create() + + val index = flint.describeIndex(testFlintIndex) + index shouldBe defined + index.get.metadata().getContent should matchJson(s""" + | { + | "_meta": { + | "version": "${current()}", + | "name": "spark_catalog.default.mv_test_metrics", + | "kind": "mv", + | "source": "$testQuery", + | "indexedColumns": [ + | { + | "columnName": "startTime", + | "columnType": "timestamp" + | },{ + | "columnName": "count", + | "columnType": "long" + | }], + | "options": { + | "auto_refresh": "true", + | "checkpoint_location": "s3://test/" + | }, + | "properties": {} + | }, + | "properties": { + | "startTime": { + | "type": "date", + | "format": "strict_date_optional_time_nanos" + | }, + | "count": { + | "type": "long" + | } + | } + | } + |""".stripMargin) + } + + // TODO: fix this windowing function unable to be used in GROUP BY + ignore("full refresh materialized view") { + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .create() + + flint.refreshIndex(testFlintIndex, FULL) + + val indexData = flint.queryIndex(testFlintIndex) + checkAnswer( + indexData, + Seq( + Row(timestamp("2023-10-01 00:00:00"), 1), + Row(timestamp("2023-10-01 00:10:00"), 2), + Row(timestamp("2023-10-01 01:00:00"), 1), + Row(timestamp("2023-10-01 02:00:00"), 1))) + } + + test("incremental refresh materialized view") { + withIncrementalMaterializedView(testQuery) { indexData => + checkAnswer( + indexData.select("startTime", "count"), + Seq( + Row(timestamp("2023-10-01 00:00:00"), 1), + Row(timestamp("2023-10-01 00:10:00"), 2), + Row(timestamp("2023-10-01 01:00:00"), 1) + /* + * The last row is pending to fire upon watermark + * Row(timestamp("2023-10-01 02:00:00"), 1) + */ + )) + } + } + + test("incremental refresh materialized view with larger window") { + val largeWindowQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '1 Hour') + |""".stripMargin + + withIncrementalMaterializedView(largeWindowQuery) { indexData => + checkAnswer( + indexData.select("startTime", "count"), + Seq( + Row(timestamp("2023-10-01 00:00:00"), 3), + Row(timestamp("2023-10-01 01:00:00"), 1) + /* + * The last row is pending to fire upon watermark + * Row(timestamp("2023-10-01 02:00:00"), 1) + */ + )) + } + } + + test("incremental refresh materialized view with filtering query") { + val filterQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | WHERE address = 'Seattle' + | GROUP BY TUMBLE(time, '10 Minutes') + |""".stripMargin + + withIncrementalMaterializedView(filterQuery) { indexData => + checkAnswer( + indexData.select("startTime", "count"), + Seq( + Row(timestamp("2023-10-01 00:00:00"), 1) + /* + * The last row is pending to fire upon watermark + * Row(timestamp("2023-10-01 00:10:00"), 1) + */ + )) + } + } + + test("incremental refresh materialized view with non-aggregate query") { + val nonAggQuery = + s""" + | SELECT name, age + | FROM $testTable + | WHERE age <= 30 + |""".stripMargin + + withIncrementalMaterializedView(nonAggQuery) { indexData => + checkAnswer(indexData.select("name", "age"), Seq(Row("A", 30), Row("B", 20), Row("E", 15))) + } + } + + private def timestamp(ts: String): Timestamp = Timestamp.valueOf(ts) + + private def withIncrementalMaterializedView(query: String)( + codeBlock: DataFrame => Unit): 
Unit = { + withTempDir { checkpointDir => + val indexOptions = FlintSparkIndexOptions( + Map("auto_refresh" -> "true", "checkpoint_location" -> checkpointDir.getAbsolutePath)) + + flint + .materializedView() + .name(testMvName) + .query(query) + .options(indexOptions) + .create() + + flint + .refreshIndex(testFlintIndex, INCREMENTAL) + .map(awaitStreamingComplete) + .orElse(throw new RuntimeException) + + val indexData = flint.queryIndex(testFlintIndex) + + // Execute the code block + codeBlock(indexData) + } + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index da61feebc..40cb5c201 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -79,7 +79,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | "columnType": "int" | }], | "source": "spark_catalog.default.test", - | "options": {}, + | "options": { "auto_refresh": "false" }, | "properties": {} | }, | "properties": { @@ -105,7 +105,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | } |""".stripMargin) - index.get.options shouldBe FlintSparkIndexOptions.empty + index.get.options shouldBe FlintSparkIndexOptions(Map("auto_refresh" -> "false")) } test("create skipping index with index options successfully") { @@ -517,7 +517,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | "columnType": "struct" | }], | "source": "$testTable", - | "options": {}, + | "options": { "auto_refresh": "false" }, | "properties": {} | }, | "properties": { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index a688b1370..bfbeba9c3 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -15,7 +15,7 @@ import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.FlintOpenSearchClient import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.scalatest.matchers.must.Matchers.defined -import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} import org.apache.spark.sql.Row import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE @@ -94,11 +94,20 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { val flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)) implicit val formats: Formats = Serialization.formats(NoTypeHints) - val settings = parse(flintClient.getIndexMetadata(testIndex).getIndexSettings) + val settings = parse(flintClient.getIndexMetadata(testIndex).indexSettings.get) (settings \ "index.number_of_shards").extract[String] shouldBe "3" (settings \ "index.number_of_replicas").extract[String] shouldBe "2" } + test("create skipping index with invalid option") { + the[IllegalArgumentException] thrownBy + sql(s""" + | CREATE SKIPPING INDEX ON $testTable + | ( year PARTITION ) + | WITH (autoRefresh = true) + | """.stripMargin) + } + test("create skipping index with manual refresh") { sql(s""" | CREATE 
SKIPPING INDEX ON $testTable diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index d1f01caca..2b93ca12a 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -28,6 +28,13 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit setFlintSparkConf(REFRESH_POLICY, "true") } + protected def awaitStreamingComplete(jobId: String): Unit = { + val job = spark.streams.get(jobId) + failAfter(streamingTimeout) { + job.processAllAvailable() + } + } + protected def createPartitionedTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable @@ -59,4 +66,27 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | VALUES ('World', 25, 'Portland') | """.stripMargin) } + + protected def createTimeSeriesTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | time TIMESTAMP, + | name STRING, + | age INT, + | address STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + |""".stripMargin) + + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:01:00', 'A', 30, 'Seattle')") + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:10:00', 'B', 20, 'Seattle')") + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:15:00', 'C', 35, 'Portland')") + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 01:00:00', 'D', 40, 'Portland')") + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 03:00:00', 'E', 15, 'Vancouver')") + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index d12d03565..51bf4d734 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -12,7 +12,11 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.concurrent.duration.{Duration, MINUTES} import org.opensearch.ExceptionsHelper -import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} +import org.opensearch.client.{RequestOptions, RestHighLevelClient} +import org.opensearch.cluster.metadata.MappingMetadata +import org.opensearch.common.settings.Settings +import org.opensearch.common.xcontent.XContentType +import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} import org.opensearch.flint.core.metadata.FlintMetadata import play.api.libs.json._ @@ -51,17 +55,19 @@ object FlintJob extends Logging { var dataToWrite: Option[DataFrame] = None try { - // flintClient needs spark session to be created first. Otherwise, we will have connection + // osClient needs spark session to be created first. Otherwise, we will have connection // exception from EMR-S to OS. 
- val flintClient = FlintClientBuilder.build(FlintSparkConf().flintOptions()) + val osClient = new OSClient(FlintSparkConf().flintOptions()) val futureMappingCheck = Future { - checkAndCreateIndex(flintClient, resultIndex) + checkAndCreateIndex(osClient, resultIndex) } val data = executeQuery(spark, query, dataSource) - val (correctMapping, error) = - ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) - dataToWrite = Some(if (correctMapping) data else getFailedData(spark, dataSource, error)) + val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES)) + dataToWrite = Some(mappingCheckResult match { + case Right(_) => data + case Left(error) => getFailedData(spark, dataSource, error) + }) } catch { case e: TimeoutException => val error = "Future operations timed out" @@ -238,7 +244,7 @@ object FlintJob extends Logging { compareJson(inputJson, mappingJson) } - def checkAndCreateIndex(flintClient: FlintClient, resultIndex: String): (Boolean, String) = { + def checkAndCreateIndex(osClient: OSClient, resultIndex: String): Either[String, Unit] = { // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, val mapping = """{ @@ -271,39 +277,31 @@ object FlintJob extends Logging { }""".stripMargin try { - val existingSchema = flintClient.getIndexMetadata(resultIndex).getContent + val existingSchema = osClient.getIndexMetadata(resultIndex) if (!isSuperset(existingSchema, mapping)) { - (false, s"The mapping of $resultIndex is incorrect.") + Left(s"The mapping of $resultIndex is incorrect.") } else { - (true, "") + Right(()) } } catch { case e: IllegalStateException if e.getCause().getMessage().contains("index_not_found_exception") => - handleIndexNotFoundException(flintClient, resultIndex, mapping) + try { + osClient.createIndex(resultIndex, mapping) + Right(()) + } catch { + case e: Exception => + val error = s"Failed to create result index $resultIndex" + logError(error, e) + Left(error) + } case e: Exception => val error = "Failed to verify existing mapping" logError(error, e) - (false, error) + Left(error) } } - def handleIndexNotFoundException( - flintClient: FlintClient, - resultIndex: String, - mapping: String): (Boolean, String) = { - try { - logInfo(s"create $resultIndex") - flintClient.createIndex(resultIndex, new FlintMetadata(mapping)) - logInfo(s"create $resultIndex successfully") - (true, "") - } catch { - case e: Exception => - val error = s"Failed to create result index $resultIndex" - logError(error, e) - (false, error) - } - } def executeQuery(spark: SparkSession, query: String, dataSource: String): DataFrame = { // Execute SQL query val result: DataFrame = spark.sql(query) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala new file mode 100644 index 000000000..cf2a5860d --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.opensearch.client.RequestOptions +import org.opensearch.client.indices.{CreateIndexRequest, GetIndexRequest, GetIndexResponse} +import org.opensearch.client.indices.CreateIndexRequest +import org.opensearch.common.xcontent.XContentType +import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions} + +import org.apache.spark.internal.Logging + +class OSClient(val 
flintOptions: FlintOptions) extends Logging {
+
+  def getIndexMetadata(osIndexName: String): String = {
+
+    using(FlintClientBuilder.build(flintOptions).createClient()) { client =>
+      val request = new GetIndexRequest(osIndexName)
+      try {
+        val response = client.indices.get(request, RequestOptions.DEFAULT)
+        response.getMappings.get(osIndexName).source.string
+      } catch {
+        case e: Exception =>
+          throw new IllegalStateException(
+            s"Failed to get OpenSearch index mapping for $osIndexName",
+            e)
+      }
+    }
+  }
+
+  /**
+   * Create a new index with the given mapping.
+   *
+   * @param osIndexName
+   *   the name of the index
+   * @param mapping
+   *   the mapping of the index
+   * @throws IllegalStateException
+   *   if the index cannot be created
+   */
+  def createIndex(osIndexName: String, mapping: String): Unit = {
+    logInfo(s"create $osIndexName")
+
+    using(FlintClientBuilder.build(flintOptions).createClient()) { client =>
+      val request = new CreateIndexRequest(osIndexName)
+      request.mapping(mapping, XContentType.JSON)
+
+      try {
+        client.indices.create(request, RequestOptions.DEFAULT)
+        logInfo(s"create $osIndexName successfully")
+      } catch {
+        case e: Exception =>
+          throw new IllegalStateException(s"Failed to create index $osIndexName", e)
+      }
+    }
+  }
+
+  /**
+   * The loan pattern to manage a resource.
+   *
+   * @param resource
+   *   the resource to be managed
+   * @param f
+   *   the function to be applied to the resource
+   * @tparam A
+   *   the type of the resource
+   * @tparam B
+   *   the type of the result
+   * @return
+   *   the result of the function
+   */
+  def using[A <: AutoCloseable, B](resource: A)(f: A => B): B = {
+    try {
+      f(resource)
+    } finally {
+      // the resource is guaranteed to be non-null here
+      resource.close()
+    }
+  }
+
+}
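
Note on the `using` helper added in `OSClient.scala` above: it is a standard loan-pattern wrapper, which is why `getIndexMetadata` and `createIndex` never have to close the underlying REST client themselves. The following is a minimal standalone sketch (not part of this change set); `FakeResource` is a hypothetical stand-in for `RestHighLevelClient`, used only to keep the example self-contained and runnable.

```scala
// Sketch of the loan pattern used by OSClient.using: acquire, apply, always close.
object LoanPatternSketch extends App {

  // Hypothetical resource standing in for RestHighLevelClient.
  final class FakeResource extends AutoCloseable {
    def query(): String = "ok"
    override def close(): Unit = println("resource closed")
  }

  // Same shape as OSClient.using.
  def using[A <: AutoCloseable, B](resource: A)(f: A => B): B =
    try f(resource)
    finally resource.close()

  // The resource is closed on the happy path...
  val result = using(new FakeResource)(_.query())
  println(s"result = $result")

  // ...and also when the block throws, before the exception propagates.
  try using(new FakeResource)(_ => throw new IllegalStateException("boom"))
  catch { case e: IllegalStateException => println(s"caught: ${e.getMessage}") }
}
```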