From 514d51a506bbe910c9e085be54fb7c5117f4e6d9 Mon Sep 17 00:00:00 2001
From: Peng Huo
Date: Thu, 27 Jun 2024 16:43:10 -0700
Subject: [PATCH] support shard level split on read path

Signed-off-by: Peng Huo
---
 .../opensearch/flint/core/FlintClient.java   | 10 +++
 .../core/storage/FlintOpenSearchClient.java  | 22 ++++++-
 .../core/storage/OpenSearchScrollReader.java | 39 ++++++++++--
 .../catalog/OpenSearchCatalog.scala          |  2 -
 .../opensearch/table/OpenSearchTable.scala   | 62 +++++++++++++++++++
 .../opensearch/table/PartitionInfo.scala     | 49 +++++++++++++++
 .../spark/opensearch/table/ShardInfo.scala   | 16 +++++
 .../flint/FlintPartitionReaderFactory.scala  | 13 ++--
 .../spark/sql/flint/FlintReadOnlyTable.scala | 19 ++----
 .../apache/spark/sql/flint/FlintScan.scala   | 16 +++--
 .../spark/sql/flint/FlintScanBuilder.scala   |  5 +-
 11 files changed, 219 insertions(+), 34 deletions(-)
 create mode 100644 flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/OpenSearchTable.scala
 create mode 100644 flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/PartitionInfo.scala
 create mode 100644 flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/ShardInfo.scala

diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java
index e5e18f126..0e9cc57b9 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java
@@ -74,6 +74,16 @@ public interface FlintClient {
    */
  FlintReader createReader(String indexName, String query);

+  /**
+   * Create {@link FlintReader}.
+   *
+   * @param indexName index name.
+   * @param query DSL query. A null query means match_all.
+   * @param shardId shard id.
+   * @return {@link FlintReader}.
+   */
+  FlintReader createReader(String indexName, String query, String shardId);
+
  /**
   * Create {@link FlintWriter}.
   *
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java
index 2a3bf2da8..1867a27ad 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java
@@ -14,6 +14,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.logging.Logger;
 import java.util.stream.Collectors;
 import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
@@ -62,6 +63,9 @@ public class FlintOpenSearchClient implements FlintClient {
  private final static Set<Character> INVALID_INDEX_NAME_CHARS =
      Set.of(' ', ',', ':', '"', '+', '/', '\\', '|', '?', '#', '>', '<');

+  private final static Function<String, String> SHARD_ID_PREFERENCE =
+      shardId -> shardId == null ? shardId : "_shards:" + shardId;
+
  private final FlintOptions options;

  public FlintOpenSearchClient(FlintOptions options) {
@@ -173,7 +177,20 @@ public void deleteIndex(String indexName) {
   */
  @Override
  public FlintReader createReader(String indexName, String query) {
-    LOG.info("Creating Flint index reader for " + indexName + " with query " + query);
+    return createReader(indexName, query, null);
+  }
+
+  /**
+   * Create {@link FlintReader}.
+   *
+   * @param indexName index name.
+   * @param query DSL query. 
DSL query is null means match_all + * @param shardId shardId + * @return + */ + @Override + public FlintReader createReader(String indexName, String query, String shardId) { + LOG.info("Creating Flint index reader for " + indexName + " with query " + query + " shardId " + shardId); try { QueryBuilder queryBuilder = new MatchAllQueryBuilder(); if (!Strings.isNullOrEmpty(query)) { @@ -185,7 +202,8 @@ public FlintReader createReader(String indexName, String query) { return new OpenSearchScrollReader(createClient(), sanitizeIndexName(indexName), new SearchSourceBuilder().query(queryBuilder), - options); + options, + SHARD_ID_PREFERENCE.apply(shardId)); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchScrollReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchScrollReader.java index 9cba0c97c..32b658c65 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchScrollReader.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchScrollReader.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; @@ -35,8 +37,23 @@ public class OpenSearchScrollReader extends OpenSearchReader { private String scrollId = null; - public OpenSearchScrollReader(IRestHighLevelClient client, String indexName, SearchSourceBuilder searchSourceBuilder, FlintOptions options) { - super(client, new SearchRequest().indices(indexName).source(searchSourceBuilder.size(options.getScrollSize()))); + public OpenSearchScrollReader( + IRestHighLevelClient client, + String indexName, + SearchSourceBuilder searchSourceBuilder, + FlintOptions options) { + this(client, indexName, searchSourceBuilder, options, null); + } + + public OpenSearchScrollReader( + IRestHighLevelClient client, + String indexName, + SearchSourceBuilder searchSourceBuilder, + FlintOptions options, + String preference) { + super(client, + applyPreference(preference).apply(new SearchRequest().indices(indexName) + .source(searchSourceBuilder.size(options.getScrollSize())))); this.options = options; this.scrollDuration = TimeValue.timeValueMinutes(options.getScrollDuration()); } @@ -52,8 +69,8 @@ Optional search(SearchRequest request) throws IOException { return Optional.of(response); } else { try { - return Optional - .of(client.scroll(new SearchScrollRequest().scroll(scrollDuration).scrollId(scrollId), + return Optional.of(client.scroll(new SearchScrollRequest().scroll(scrollDuration) + .scrollId(scrollId), RequestOptions.DEFAULT)); } catch (OpenSearchStatusException e) { LOG.log(Level.WARNING, "scroll context not exist", e); @@ -75,8 +92,10 @@ void clean() throws IOException { } } catch (OpenSearchStatusException e) { // OpenSearch throw exception if scroll already closed. 
https://github.com/opensearch-project/OpenSearch/issues/11121 - LOG.log(Level.WARNING, "close scroll exception, it is a known bug https://github" + - ".com/opensearch-project/OpenSearch/issues/11121.", e); + LOG.log( + Level.WARNING, + "close scroll exception, it is a known bug https://github" + ".com/opensearch-project/OpenSearch/issues/11121.", + e); } finally { scrollId = null; } @@ -88,4 +107,12 @@ void clean() throws IOException { public String getScrollId() { return scrollId; } + + static private Function applyPreference(String preference) { + if (Strings.isNullOrEmpty(preference)) { + return searchRequest -> searchRequest; + } else { + return searchRequest -> searchRequest.preference(preference); + } + } } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/catalog/OpenSearchCatalog.scala b/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/catalog/OpenSearchCatalog.scala index 3594f41de..c295eb339 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/catalog/OpenSearchCatalog.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/catalog/OpenSearchCatalog.scala @@ -39,12 +39,10 @@ class OpenSearchCatalog extends CatalogPlugin with TableCatalog with Logging { override def name(): String = catalogName - @throws[NoSuchNamespaceException] override def listTables(namespace: Array[String]): Array[Identifier] = { throw new UnsupportedOperationException("OpenSearchCatalog does not support listTables") } - @throws[NoSuchTableException] override def loadTable(ident: Identifier): Table = { logInfo(s"Loading table ${ident.name()}") if (!ident.namespace().exists(n => OpenSearchCatalog.isDefaultNamespace(n))) { diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/OpenSearchTable.scala b/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/OpenSearchTable.scala new file mode 100644 index 000000000..80eab850f --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/OpenSearchTable.scala @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.opensearch.table + +import scala.collection.JavaConverters._ + +import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions} +import org.opensearch.flint.core.metadata.FlintMetadata + +import org.apache.spark.sql.flint.datatype.FlintDataType +import org.apache.spark.sql.types.StructType + +/** + * Represents an OpenSearch table. + * + * @param tableName + * The name of the table. + * @param metadata + * Metadata of the table. + */ +case class OpenSearchTable(tableName: String, metadata: Map[String, FlintMetadata]) { + /* + * FIXME. we use first index schema in multiple indices. we should merge StructType to widen type + */ + lazy val schema: StructType = { + metadata.values.headOption + .map(m => FlintDataType.deserialize(m.getContent)) + .getOrElse(StructType(Nil)) + } + + lazy val partitions: Array[PartitionInfo] = { + metadata.map { case (partitionName, metadata) => + PartitionInfo.apply(partitionName, metadata.indexSettings.get) + }.toArray + } +} + +object OpenSearchTable { + + /** + * Creates an OpenSearchTable instance. + * + * @param tableName + * tableName support (1) single index name. (2) wildcard index name. (3) comma sep index name. + * @param options + * The options for Flint. + * @return + * An instance of OpenSearchTable. 
+ */ + def apply(tableName: String, options: FlintOptions): OpenSearchTable = { + OpenSearchTable( + tableName, + FlintClientBuilder + .build(options) + .getAllIndexMetadata(tableName.split(","): _*) + .asScala + .toMap) + } +} diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/PartitionInfo.scala b/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/PartitionInfo.scala new file mode 100644 index 000000000..ec3453618 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/PartitionInfo.scala @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.opensearch.table + +import org.json4s.{Formats, NoTypeHints} +import org.json4s.jackson.JsonMethods +import org.json4s.native.Serialization + +/** + * Represents information about a partition in OpenSearch. Partition is backed by OpenSearch + * Index. Each partition contain a list of Shards + * + * @param partitionName + * partition name. + * @param shards + * shards. + */ +case class PartitionInfo(partitionName: String, shards: Array[ShardInfo]) {} + +object PartitionInfo { + implicit val formats: Formats = Serialization.formats(NoTypeHints) + + /** + * Creates a PartitionInfo instance. + * + * @param partitionName The name of the partition. + * @param settings The settings of the partition. + * @return An instance of PartitionInfo. + */ + def apply(partitionName: String, settings: String): PartitionInfo = { + val shards = + Range.apply(0, numberOfShards(settings)).map(id => ShardInfo(partitionName, id)).toArray + PartitionInfo(partitionName, shards) + } + + /** + * Extracts the number of shards from the settings string. + * + * @param settingStr The settings string. + * @return The number of shards. + */ + def numberOfShards(settingStr: String): Int = { + val setting = JsonMethods.parse(settingStr) + (setting \ "index.number_of_shards").extract[String].toInt + } +} diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/ShardInfo.scala b/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/ShardInfo.scala new file mode 100644 index 000000000..7946bf1cb --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/apache/spark/opensearch/table/ShardInfo.scala @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.opensearch.table + +/** + * Represents information about a shard in OpenSearch. + * + * @param indexName + * The name of the index. + * @param id + * The ID of the shard. 
+ */ +case class ShardInfo(indexName: String, id: Int) diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionReaderFactory.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionReaderFactory.scala index ebd46d625..eebad81c8 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionReaderFactory.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionReaderFactory.scala @@ -15,14 +15,19 @@ import org.apache.spark.sql.flint.storage.FlintQueryCompiler import org.apache.spark.sql.types.StructType case class FlintPartitionReaderFactory( - tableName: String, schema: StructType, options: FlintSparkConf, pushedPredicates: Array[Predicate]) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val query = FlintQueryCompiler(schema).compile(pushedPredicates) - val flintClient = FlintClientBuilder.build(options.flintOptions()) - new FlintPartitionReader(flintClient.createReader(tableName, query), schema, options) + partition match { + case OpenSearchSplit(shardInfo) => + val query = FlintQueryCompiler(schema).compile(pushedPredicates) + val flintClient = FlintClientBuilder.build(options.flintOptions()) + new FlintPartitionReader( + flintClient.createReader(shardInfo.indexName, query, shardInfo.id.toString), + schema, + options) + } } } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintReadOnlyTable.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintReadOnlyTable.scala index ac83d2ef6..b1ec83cae 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintReadOnlyTable.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintReadOnlyTable.scala @@ -12,6 +12,7 @@ import scala.collection.JavaConverters._ import org.opensearch.flint.core.FlintClientBuilder import org.apache.spark.opensearch.catalog.OpenSearchCatalog +import org.apache.spark.opensearch.table.OpenSearchTable import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE} @@ -41,25 +42,17 @@ class FlintReadOnlyTable( lazy val name: String = flintSparkConf.tableName() - // todo. currently, we use first index schema in multiple indices. 
we should merge StructType - // to widen type + lazy val openSearchTable: OpenSearchTable = + OpenSearchTable.apply(name, flintSparkConf.flintOptions()) + lazy val schema: StructType = { - userSpecifiedSchema.getOrElse { - FlintClientBuilder - .build(flintSparkConf.flintOptions()) - .getAllIndexMetadata(OpenSearchCatalog.indexNames(name): _*) - .values() - .asScala - .headOption - .map(m => FlintDataType.deserialize(m.getContent)) - .getOrElse(StructType(Nil)) - } + userSpecifiedSchema.getOrElse { openSearchTable.schema } } override def capabilities(): util.Set[TableCapability] = util.EnumSet.of(BATCH_READ) override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - FlintScanBuilder(name, schema, flintSparkConf) + FlintScanBuilder(openSearchTable, schema, flintSparkConf) } } diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala index 154e95476..c6e03e858 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala @@ -5,13 +5,14 @@ package org.apache.spark.sql.flint +import org.apache.spark.opensearch.table.{OpenSearchTable, ShardInfo} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.StructType case class FlintScan( - tableName: String, + table: OpenSearchTable, schema: StructType, options: FlintSparkConf, pushedPredicates: Array[Predicate]) @@ -21,11 +22,11 @@ case class FlintScan( override def readSchema(): StructType = schema override def planInputPartitions(): Array[InputPartition] = { - Array(OpenSearchInputPartition()) + table.partitions.flatMap(p => p.shards.map(s => OpenSearchSplit(s))).toArray } override def createReaderFactory(): PartitionReaderFactory = { - FlintPartitionReaderFactory(tableName, schema, options, pushedPredicates) + FlintPartitionReaderFactory(schema, options, pushedPredicates) } override def toBatch: Batch = this @@ -40,5 +41,10 @@ case class FlintScan( private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") } -// todo. add partition support. -private[spark] case class OpenSearchInputPartition() extends InputPartition {} +/** + * Each OpenSearchSplit is backed by an OpenSearch shard. 
+ *
+ * @param shardInfo
+ *   information about the OpenSearch shard backing this split.
+ */
+private[spark] case class OpenSearchSplit(shardInfo: ShardInfo) extends InputPartition {}
diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala
index 71bfe36e8..8d8d02c02 100644
--- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala
@@ -6,13 +6,14 @@
 package org.apache.spark.sql.flint

 import org.apache.spark.internal.Logging
+import org.apache.spark.opensearch.table.OpenSearchTable
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownV2Filters}
 import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.flint.storage.FlintQueryCompiler
 import org.apache.spark.sql.types.StructType

-case class FlintScanBuilder(tableName: String, schema: StructType, options: FlintSparkConf)
+case class FlintScanBuilder(table: OpenSearchTable, schema: StructType, options: FlintSparkConf)
     extends ScanBuilder
     with SupportsPushDownV2Filters
     with Logging {
@@ -20,7 +21,7 @@ case class FlintScanBuilder(tableName: String, schema: StructType, options: Flin
   private var pushedPredicate = Array.empty[Predicate]

   override def build(): Scan = {
-    FlintScan(tableName, schema, options, pushedPredicate)
+    FlintScan(table, schema, options, pushedPredicate)
   }

   override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
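
For reference, each OpenSearchSplit above turns into a search scoped to one shard using the standard OpenSearch "_shards:<id>" search preference. A minimal sketch of the request one split produces (illustrative only; the method name is made up, and match_all stands in for the predicates that the real reader compiles via FlintQueryCompiler):

import org.opensearch.action.search.SearchRequest
import org.opensearch.index.query.QueryBuilders
import org.opensearch.search.builder.SearchSourceBuilder

// Build a search request pinned to a single shard of an index; the scroll
// pages served for this request then come from that shard only.
def shardScopedRequest(indexName: String, shardId: Int, scrollSize: Int): SearchRequest =
  new SearchRequest()
    .indices(indexName)
    .source(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()).size(scrollSize))
    .preference(s"_shards:$shardId")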
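
PartitionInfo derives the shard count from the index settings attached to the Flint metadata. A small sketch of that parsing, assuming the flat settings form ("index.number_of_shards" as a top-level key) that numberOfShards expects; the settings literal here is made up for illustration:

import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.JsonMethods
import org.json4s.native.Serialization

// Hypothetical settings string; real values come from FlintMetadata.indexSettings.
val settings = """{"index.number_of_shards":"3","index.number_of_replicas":"1"}"""
implicit val formats: Formats = Serialization.formats(NoTypeHints)
val shardCount = (JsonMethods.parse(settings) \ "index.number_of_shards").extract[String].toInt
// shardCount == 3, so this partition expands into ShardInfo(indexName, 0), (1) and (2),
// and FlintScan plans one OpenSearchSplit per ShardInfo.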
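
End to end, a read through the OpenSearch catalog is now parallelized per shard. A rough usage sketch, assuming a catalog registered under the arbitrary name "dev", the catalog's default namespace, and an index with three primary shards; connection options depend on FlintSparkConf and are left as a placeholder:

// spark-shell --conf spark.sql.catalog.dev=org.apache.spark.opensearch.catalog.OpenSearchCatalog \
//             --conf <connection options from FlintSparkConf>
val df = spark.sql("SELECT * FROM dev.default.my_index")
// Previously the scan planned a single input partition per table; with shard-level
// splits the partition count follows the shard count of the matched indices.
println(df.rdd.getNumPartitions) // e.g. 3 for a 3-shard index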