diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 0f2193ce0..83fd25a0c 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -16,3 +16,4 @@ This document contains a list of maintainers in this repo. See [opensearch-proje | Sean Kao | [seankao-az](https://github.com/seankao-az) | Amazon | | Anirudha Jadhav | [anirudha](https://github.com/anirudha) | Amazon | | Kaituo Li | [kaituo](https://github.com/kaituo) | Amazon | +| Louis Chu | [noCharger](https://github.com/noCharger) | Amazon | diff --git a/build.sbt b/build.sbt index 95324fc99..bcf21e444 100644 --- a/build.sbt +++ b/build.sbt @@ -7,6 +7,10 @@ import Dependencies._ lazy val scala212 = "2.12.14" lazy val sparkVersion = "3.3.2" lazy val opensearchVersion = "2.6.0" +lazy val icebergVersion = "1.5.0" + +val scalaMinorVersion = scala212.split("\\.").take(2).mkString(".") +val sparkMinorVersion = sparkVersion.split("\\.").take(2).mkString(".") ThisBuild / organization := "org.opensearch" @@ -172,6 +176,8 @@ lazy val integtest = (project in file("integ-test")) "org.scalatest" %% "scalatest" % "3.2.15" % "test", "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test", "org.testcontainers" % "testcontainers" % "1.18.0" % "test", + "org.apache.iceberg" %% s"iceberg-spark-runtime-$sparkMinorVersion" % icebergVersion % "test", + "org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0" % "test", // add opensearch-java client to get node stats "org.opensearch.client" % "opensearch-java" % "2.6.0" % "test" exclude ("com.fasterxml.jackson.core", "jackson-databind")), diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index ee38bbb9c..ee78aa512 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -102,9 +102,10 @@ OptimisticTransaction startTransaction(String indexName, String dataSourc * @return {@link FlintWriter} */ FlintWriter createWriter(String indexName); + /** * Create {@link IRestHighLevelClient}. * @return {@link IRestHighLevelClient} */ - public IRestHighLevelClient createClient(); + IRestHighLevelClient createClient(); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index 1282e1c94..9858ffd1e 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -5,7 +5,6 @@ package org.opensearch.flint.core; -import dev.failsafe.RetryPolicy; import java.io.Serializable; import java.util.Map; import org.opensearch.flint.core.http.FlintRetryOptions; @@ -46,10 +45,14 @@ public class FlintOptions implements Serializable { public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider"; + public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider"; + /** - * By default, customAWSCredentialsProvider is empty. use DefaultAWSCredentialsProviderChain. + * By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty. use DefaultAWSCredentialsProviderChain. 
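For reference, a minimal sketch of how the two provider settings resolve, assuming the `com.example.*` class names are placeholders; both getters fall back to `DEFAULT_AWS_CREDENTIALS_PROVIDER` (the empty string), in which case `DefaultAWSCredentialsProviderChain` is used downstream:

```scala
// Sketch only: reading both credential provider options from FlintOptions.
// The provider class names are hypothetical placeholders.
import scala.collection.JavaConverters._
import org.opensearch.flint.core.FlintOptions

val flintOptions = new FlintOptions(
  Map(
    FlintOptions.CUSTOM_AWS_CREDENTIALS_PROVIDER ->
      "com.example.PrimaryCredentialsProvider",
    FlintOptions.METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER ->
      "com.example.MetadataAccessCredentialsProvider").asJava)

// Each getter returns the configured class name, or "" when unset.
flintOptions.getCustomAwsCredentialsProvider
flintOptions.getMetadataAccessAwsCredentialsProvider
```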
*/ - public static final String DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER = ""; + public static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = ""; + + public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex"; /** * Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader} @@ -121,7 +124,11 @@ public String getAuth() { } public String getCustomAwsCredentialsProvider() { - return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, ""); + return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER); + } + + public String getMetadataAccessAwsCredentialsProvider() { + return options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER); } public String getUsername() { @@ -139,4 +146,8 @@ public int getSocketTimeoutMillis() { public String getDataSourceName() { return options.getOrDefault(DATA_SOURCE_NAME, ""); } + + public String getSystemIndexName() { + return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, ""); + } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java new file mode 100644 index 000000000..c3e65fef3 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java @@ -0,0 +1,99 @@ +package org.opensearch.flint.core.auth; + +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.Signer; +import org.apache.http.HttpException; +import org.apache.http.HttpRequest; +import org.apache.http.HttpRequestInterceptor; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.protocol.HttpContext; + +import java.io.IOException; +import java.net.URISyntaxException; + +/** + * Intercepts HTTP requests to sign them for AWS authentication, adjusting the signing process + * based on whether the request accesses metadata or not. + */ +public class ResourceBasedAWSRequestSigningApacheInterceptor implements HttpRequestInterceptor { + + private final String service; + private final String metadataAccessIdentifier; + final AWSRequestSigningApacheInterceptor primaryInterceptor; + final AWSRequestSigningApacheInterceptor metadataAccessInterceptor; + + /** + * Constructs an interceptor for AWS request signing with optional metadata access. + * + * @param service The AWS service name. + * @param signer The AWS request signer. + * @param primaryCredentialsProvider The credentials provider for general access. + * @param metadataAccessCredentialsProvider The credentials provider for metadata access. + * @param metadataAccessIdentifier Identifier for operations requiring metadata access. 
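A minimal wiring sketch in Scala for the constructor documented above; "es" is the service name the interceptor's process() method special-cases, while the signer and the metadata access identifier are placeholders (the identifier normally comes from `FlintOptions.getSystemIndexName()`):

```scala
// Sketch only: building the interceptor with separate primary and
// metadata-access credential providers. Identifier value is a placeholder.
import com.amazonaws.auth.{AWSCredentialsProvider, DefaultAWSCredentialsProviderChain, Signer}
import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor

def buildInterceptor(
    signer: Signer,
    primary: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain(),
    metadataAccess: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain())
    : ResourceBasedAWSRequestSigningApacheInterceptor =
  new ResourceBasedAWSRequestSigningApacheInterceptor(
    "es",            // AWS service name used for signing
    signer,          // pre-configured request signer
    primary,         // credentials for general index access
    metadataAccess,  // credentials for request/metadata index access
    "request_index") // placeholder metadata access identifier
```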
+ */ + public ResourceBasedAWSRequestSigningApacheInterceptor(final String service, + final Signer signer, + final AWSCredentialsProvider primaryCredentialsProvider, + final AWSCredentialsProvider metadataAccessCredentialsProvider, + final String metadataAccessIdentifier) { + this(service, + new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider), + new AWSRequestSigningApacheInterceptor(service, signer, metadataAccessCredentialsProvider), + metadataAccessIdentifier); + } + + // Test constructor allowing injection of mock interceptors + ResourceBasedAWSRequestSigningApacheInterceptor(final String service, + final AWSRequestSigningApacheInterceptor primaryInterceptor, + final AWSRequestSigningApacheInterceptor metadataAccessInterceptor, + final String metadataAccessIdentifier) { + this.service = service == null ? "unknown" : service; + this.primaryInterceptor = primaryInterceptor; + this.metadataAccessInterceptor = metadataAccessInterceptor; + this.metadataAccessIdentifier = metadataAccessIdentifier; + } + + /** + * Processes an HTTP request, signing it according to whether it requires metadata access. + * + * @param request The HTTP request to process. + * @param context The context in which the HTTP request is being processed. + * @throws HttpException If processing the HTTP request results in an exception. + * @throws IOException If an I/O error occurs. + */ + @Override + public void process(HttpRequest request, HttpContext context) throws HttpException, IOException { + String resourcePath = parseUriToPath(request); + if ("es".equals(this.service) && isMetadataAccess(resourcePath)) { + metadataAccessInterceptor.process(request, context); + } else { + primaryInterceptor.process(request, context); + } + } + + /** + * Extracts and returns the path component of a URI from an HTTP request. + * + * @param request The HTTP request from which to extract the URI path. + * @return The path component of the URI. + * @throws IOException If an error occurs parsing the URI. + */ + private String parseUriToPath(HttpRequest request) throws IOException { + try { + URIBuilder uriBuilder = new URIBuilder(request.getRequestLine().getUri()); + return uriBuilder.build().getRawPath(); + } catch (URISyntaxException e) { + throw new IOException("Invalid URI", e); + } + } + + /** + * Determines whether the accessed resource requires metadata credentials. + * + * @param resourcePath The path of the resource being accessed. + * @return true if the operation requires metadata access credentials, false otherwise. 
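Tying this back to FlintOpenSearchClient below, the identifier passed in is `FlintOptions.getSystemIndexName()`, so only "es" requests whose raw path contains that request/system index are signed with the metadata-access credentials. A rough illustration of the decision, with placeholder values:

```scala
// Illustration only: which credentials sign a given request path.
// The identifier would normally be FlintOptions.getSystemIndexName()
// (option key "spark.flint.job.requestIndex"); the values here are placeholders.
val metadataAccessIdentifier = "query_execution_request_mys3"

def signedWithMetadataCredentials(service: String, rawPath: String): Boolean =
  service == "es" && rawPath.contains(metadataAccessIdentifier)

signedWithMetadataCredentials("es", s"/$metadataAccessIdentifier/_doc/1") // true
signedWithMetadataCredentials("es", "/flint_covering_index/_search")      // false
```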
+ */ + private boolean isMetadataAccess(String resourcePath) { + return resourcePath.contains(metadataAccessIdentifier); + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 1c15af357..b03ac0c6f 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -46,7 +46,7 @@ import org.opensearch.flint.core.FlintClient; import org.opensearch.flint.core.FlintOptions; import org.opensearch.flint.core.IRestHighLevelClient; -import org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor; +import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor; import org.opensearch.flint.core.http.RetryableHttpAsyncClient; import org.opensearch.flint.core.metadata.FlintMetadata; import org.opensearch.flint.core.metadata.log.DefaultOptimisticTransaction; @@ -262,26 +262,30 @@ public IRestHighLevelClient createClient() { signer.setRegionName(options.getRegion()); // Use DefaultAWSCredentialsProviderChain by default. - final AtomicReference awsCredentialsProvider = - new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); - String providerClass = options.getCustomAwsCredentialsProvider(); - if (!Strings.isNullOrEmpty(providerClass)) { - try { - Class awsCredentialsProviderClass = Class.forName(providerClass); - Constructor ctor = awsCredentialsProviderClass.getDeclaredConstructor(); - ctor.setAccessible(true); - awsCredentialsProvider.set((AWSCredentialsProvider) ctor.newInstance()); - } catch (Exception e) { - throw new RuntimeException(e); - } + final AtomicReference customAWSCredentialsProvider = + new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); + String customProviderClass = options.getCustomAwsCredentialsProvider(); + if (!Strings.isNullOrEmpty(customProviderClass)) { + instantiateProvider(customProviderClass, customAWSCredentialsProvider); + } + + // Set metadataAccessAWSCredentialsProvider to customAWSCredentialsProvider by default for backwards compatibility + // unless a specific metadata access provider class name is provided + String metadataAccessProviderClass = options.getMetadataAccessAwsCredentialsProvider(); + final AtomicReference metadataAccessAWSCredentialsProvider = + new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); + if (Strings.isNullOrEmpty(metadataAccessProviderClass)) { + metadataAccessAWSCredentialsProvider.set(customAWSCredentialsProvider.get()); + } else { + instantiateProvider(metadataAccessProviderClass, metadataAccessAWSCredentialsProvider); } + restClientBuilder.setHttpClientConfigCallback(builder -> { - HttpAsyncClientBuilder delegate = - builder.addInterceptorLast( - new AWSRequestSigningApacheInterceptor( - signer.getServiceName(), signer, awsCredentialsProvider.get())); - return RetryableHttpAsyncClient.builder(delegate, options); - } + HttpAsyncClientBuilder delegate = builder.addInterceptorLast( + new ResourceBasedAWSRequestSigningApacheInterceptor( + signer.getServiceName(), signer, customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName())); + return RetryableHttpAsyncClient.builder(delegate, options); + } ); } else if (options.getAuth().equals(FlintOptions.BASIC_AUTH)) { CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); @@ -303,6 +307,20 @@ 
public IRestHighLevelClient createClient() { return new RestHighLevelClientWrapper(new RestHighLevelClient(restClientBuilder)); } + /** + * Attempts to instantiate the AWS credential provider using reflection. + */ + private void instantiateProvider(String providerClass, AtomicReference provider) { + try { + Class awsCredentialsProviderClass = Class.forName(providerClass); + Constructor ctor = awsCredentialsProviderClass.getDeclaredConstructor(); + ctor.setAccessible(true); + provider.set((AWSCredentialsProvider) ctor.newInstance()); + } catch (Exception e) { + throw new RuntimeException("Failed to instantiate AWSCredentialsProvider: " + providerClass, e); + } + } + /* * Constructs Flint metadata with latest metadata log entry attached if it's available. * It relies on FlintOptions to provide data source name. diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java new file mode 100644 index 000000000..0ef021b53 --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java @@ -0,0 +1,66 @@ +package org.opensearch.flint.core.auth; + +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.Signer; +import org.apache.http.HttpRequest; +import org.apache.http.message.BasicHttpRequest; +import org.apache.http.protocol.HttpContext; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.Mockito.*; + +public class ResourceBasedAWSRequestSigningApacheInterceptorTest { + + @Mock + private Signer mockSigner; + @Mock + private AWSCredentialsProvider mockPrimaryCredentialsProvider; + @Mock + private AWSCredentialsProvider mockMetadataAccessCredentialsProvider; + @Mock + private HttpContext mockContext; + @Captor + private ArgumentCaptor httpRequestCaptor; + + private ResourceBasedAWSRequestSigningApacheInterceptor interceptor; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + AWSRequestSigningApacheInterceptor primaryInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockPrimaryCredentialsProvider)); + AWSRequestSigningApacheInterceptor metadataInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockMetadataAccessCredentialsProvider)); + + interceptor = new ResourceBasedAWSRequestSigningApacheInterceptor( + "es", + primaryInterceptorSpy, + metadataInterceptorSpy, + "/metadata"); + } + + @Test + public void testProcessWithMetadataAccess() throws Exception { + HttpRequest request = new BasicHttpRequest("GET", "/es/metadata/resource"); + + interceptor.process(request, mockContext); + + verify(interceptor.metadataAccessInterceptor).process(httpRequestCaptor.capture(), eq(mockContext)); + verify(interceptor.primaryInterceptor, never()).process(any(HttpRequest.class), any(HttpContext.class)); + assert httpRequestCaptor.getValue().getRequestLine().getUri().contains("/metadata"); + } + + @Test + public void testProcessWithoutMetadataAccess() throws Exception { + HttpRequest request = new BasicHttpRequest("GET", "/es/regular/resource"); + + interceptor.process(request, mockContext); + + verify(interceptor.primaryInterceptor).process(httpRequestCaptor.capture(), eq(mockContext)); + 
verify(interceptor.metadataAccessInterceptor, never()).process(any(HttpRequest.class), any(HttpContext.class)); + assert !httpRequestCaptor.getValue().getRequestLine().getUri().contains("/metadata"); + } +} diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index fbbea9176..eb3a29adc 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -75,7 +75,7 @@ object FlintSparkConf { FlintConfig("spark.datasource.flint.customAWSCredentialsProvider") .datasourceOption() .doc("AWS customAWSCredentialsProvider") - .createWithDefault(FlintOptions.DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER) + .createWithDefault(FlintOptions.DEFAULT_AWS_CREDENTIALS_PROVIDER) val DOC_ID_COLUMN_NAME = FlintConfig("spark.datasource.flint.write.id_name") .datasourceOption() @@ -174,6 +174,10 @@ object FlintSparkConf { FlintConfig(s"spark.flint.job.inactivityLimitMillis") .doc("inactivity timeout") .createWithDefault(String.valueOf(FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS)) + val METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = + FlintConfig("spark.metadata.accessAWSCredentialsProvider") + .doc("AWS credentials provider for metadata access permission") + .createOptional() } /** @@ -234,6 +238,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable DATA_SOURCE_NAME, SESSION_ID, REQUEST_INDEX, + METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, EXCLUDE_JOB_IDS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .flatMap { diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index 7b875f63b..9cd5f60a7 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -407,9 +407,6 @@ class FlintSpark(val spark: SparkSession) extends Logging { case (true, false) => AUTO case (false, false) => FULL case (false, true) => INCREMENTAL - case (true, true) => - throw new IllegalArgumentException( - "auto_refresh and incremental_refresh options cannot both be true") } // validate allowed options depending on refresh mode diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala index adcb4c45f..106df276d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexBuilder.scala @@ -8,6 +8,7 @@ package org.opensearch.flint.spark import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.flint.spark.FlintSparkIndexOptions.empty +import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh import org.apache.spark.sql.catalog.Column import org.apache.spark.sql.catalyst.util.CharVarcharUtils @@ -59,7 +60,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { * ignore existing index */ def create(ignoreIfExists: Boolean = false): Unit = - flint.createIndex(buildIndex(), ignoreIfExists) + flint.createIndex(validateIndex(buildIndex()), ignoreIfExists) /** * Copy 
Flint index with updated options. @@ -80,7 +81,24 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) { val updatedMetadata = index .metadata() .copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava) - FlintSparkIndexFactory.create(updatedMetadata).get + validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get) + } + + /** + * Pre-validate index to ensure its validity. By default, this method validates index options by + * delegating to specific index refresh (index options are mostly serving index refresh). + * Subclasses can extend this method to include additional validation logic. + * + * @param index + * Flint index to be validated + * @return + * the index or exception occurred if validation failed + */ + protected def validateIndex(index: FlintSparkIndex): FlintSparkIndex = { + FlintSparkIndexRefresh + .create(index.name(), index) // TODO: remove first argument? + .validate(flint.spark) + index } /** diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala new file mode 100644 index 000000000..f689d9aee --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkValidationHelper.scala @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import java.io.IOException + +import org.apache.hadoop.fs.Path +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName} + +/** + * Flint Spark validation helper. + */ +trait FlintSparkValidationHelper extends Logging { + + /** + * Determines whether the source table(s) for a given Flint index are supported. + * + * @param spark + * Spark session + * @param index + * Flint index + * @return + * true if all non Hive, otherwise false + */ + def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = { + // Extract source table name (possibly more than one for MV query) + val tableNames = index match { + case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName) + case covering: FlintSparkCoveringIndex => Seq(covering.tableName) + case mv: FlintSparkMaterializedView => + spark.sessionState.sqlParser + .parsePlan(mv.query) + .collect { case relation: UnresolvedRelation => + qualifyTableName(spark, relation.tableName) + } + } + + // Validate if any source table is not supported (currently Hive only) + tableNames.exists { tableName => + val (catalog, ident) = parseTableName(spark, tableName) + val table = loadTable(catalog, ident).get + + // TODO: add allowed table provider list + DDLUtils.isHiveTable(Option(table.properties().get("provider"))) + } + } + + /** + * Checks whether a specified checkpoint location is accessible. Accessibility, in this context, + * means that the folder exists and the current Spark session has the necessary permissions to + * access it. 
+ * + * @param spark + * Spark session + * @param checkpointLocation + * checkpoint location + * @return + * true if accessible, otherwise false + */ + def isCheckpointLocationAccessible(spark: SparkSession, checkpointLocation: String): Boolean = { + try { + val checkpointManager = + CheckpointFileManager.create( + new Path(checkpointLocation), + spark.sessionState.newHadoopConf()) + + checkpointManager.exists(new Path(checkpointLocation)) + } catch { + case e: IOException => + logWarning(s"Failed to check if checkpoint location $checkpointLocation exists", e) + false + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala index 09428f80d..35902e184 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/AutoIndexRefresh.scala @@ -5,7 +5,9 @@ package org.opensearch.flint.spark.refresh -import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions} +import java.util.Collections + +import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexOptions, FlintSparkValidationHelper} import org.opensearch.flint.spark.FlintSparkIndex.{quotedTableName, StreamingRefresh} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, RefreshMode} @@ -23,10 +25,41 @@ import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger} * @param index * Flint index */ -class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) extends FlintSparkIndexRefresh { +class AutoIndexRefresh(indexName: String, index: FlintSparkIndex) + extends FlintSparkIndexRefresh + with FlintSparkValidationHelper { override def refreshMode: RefreshMode = AUTO + override def validate(spark: SparkSession): Unit = { + // Incremental refresh cannot enabled at the same time + val options = index.options + require( + !options.incrementalRefresh(), + "Incremental refresh cannot be enabled if auto refresh is enabled") + + // Hive table doesn't support auto refresh + require( + !isTableProviderSupported(spark, index), + "Index auto refresh doesn't support Hive table") + + // Checkpoint location is required if mandatory option set + val flintSparkConf = new FlintSparkConf(Collections.emptyMap[String, String]) + val checkpointLocation = options.checkpointLocation() + if (flintSparkConf.isCheckpointMandatory) { + require( + checkpointLocation.isDefined, + s"Checkpoint location is required if ${CHECKPOINT_MANDATORY.key} option enabled") + } + + // Checkpoint location must be accessible + if (checkpointLocation.isDefined) { + require( + isCheckpointLocationAccessible(spark, checkpointLocation.get), + s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to access") + } + } + override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = { val options = index.options val tableName = index.metadata().source diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FlintSparkIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FlintSparkIndexRefresh.scala index 3c929d8e3..0c6adb0bd 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FlintSparkIndexRefresh.scala +++ 
b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FlintSparkIndexRefresh.scala @@ -24,6 +24,20 @@ trait FlintSparkIndexRefresh extends Logging { */ def refreshMode: RefreshMode + /** + * Validates the current index refresh settings before the actual execution begins. This method + * checks for the integrity of the index refresh configurations and ensures that all options set + * for the current refresh mode are valid. This preemptive validation helps in identifying + * configuration issues before the refresh operation is initiated, minimizing runtime errors and + * potential inconsistencies. + * + * @param spark + * Spark session + * @throws IllegalArgumentException + * if any invalid or inapplicable config identified + */ + def validate(spark: SparkSession): Unit + /** * Start refreshing the index. * diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FullIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FullIndexRefresh.scala index be09c2c36..b2ce2ad34 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FullIndexRefresh.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/FullIndexRefresh.scala @@ -31,6 +31,11 @@ class FullIndexRefresh( override def refreshMode: RefreshMode = FULL + override def validate(spark: SparkSession): Unit = { + // Full refresh validates nothing for now, including Hive table validation. + // This allows users to continue using their existing Hive table with full refresh only. + } + override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = { logInfo(s"Start refreshing index $indexName in full mode") index diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala index 418ada902..8eb8d6f1f 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/refresh/IncrementalIndexRefresh.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.refresh -import org.opensearch.flint.spark.FlintSparkIndex +import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkValidationHelper} import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{INCREMENTAL, RefreshMode} import org.apache.spark.sql.SparkSession @@ -20,18 +20,31 @@ import org.apache.spark.sql.flint.config.FlintSparkConf * Flint index */ class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex) - extends FlintSparkIndexRefresh { + extends FlintSparkIndexRefresh + with FlintSparkValidationHelper { override def refreshMode: RefreshMode = INCREMENTAL + override def validate(spark: SparkSession): Unit = { + // Non-Hive table is required for incremental refresh + require( + !isTableProviderSupported(spark, index), + "Index incremental refresh doesn't support Hive table") + + // Checkpoint location is required regardless of mandatory option + val options = index.options + val checkpointLocation = options.checkpointLocation() + require( + options.checkpointLocation().nonEmpty, + "Checkpoint location is required by incremental refresh") + require( + isCheckpointLocationAccessible(spark, checkpointLocation.get), + s"Checkpoint location ${checkpointLocation.get} doesn't exist or no permission to 
access") + } + override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = { logInfo(s"Start refreshing index $indexName in incremental mode") - // TODO: move this to validation method together in future - if (index.options.checkpointLocation().isEmpty) { - throw new IllegalStateException("Checkpoint location is required by incremental refresh") - } - // Reuse auto refresh which uses AvailableNow trigger and will stop once complete val jobId = new AutoIndexRefresh(indexName, index) diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala index 149e8128b..3d643dde3 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala @@ -9,6 +9,7 @@ import java.util.Optional import scala.collection.JavaConverters._ +import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.http.FlintRetryOptions._ import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper @@ -62,6 +63,18 @@ class FlintSparkConfSuite extends FlintSuite { retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException" } + test("test metadata access AWS credentials provider option") { + withSparkConf("spark.metadata.accessAWSCredentialsProvider") { + spark.conf.set( + "spark.metadata.accessAWSCredentialsProvider", + "com.example.MetadataAccessCredentialsProvider") + val flintOptions = FlintSparkConf().flintOptions() + assert(flintOptions.getCustomAwsCredentialsProvider == "") + assert( + flintOptions.getMetadataAccessAwsCredentialsProvider == "com.example.MetadataAccessCredentialsProvider") + } + } + /** * Delete index `indexNames` after calling `f`. */ diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala index 59016d6bc..9697588d4 100644 --- a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -38,7 +38,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { super.beforeAll() // initialized after the container is started osClient = new OSClient(new FlintOptions(openSearchOptions.asJava)) - createPartitionedMultiRowTable(testTable) + createPartitionedMultiRowAddressTable(testTable) } protected override def afterEach(): Unit = { diff --git a/integ-test/src/test/scala/org/apache/spark/sql/SparkHiveSupportSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/SparkHiveSupportSuite.scala new file mode 100644 index 000000000..36a0b526d --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/SparkHiveSupportSuite.scala @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkConf +import org.apache.spark.sql.hive.HiveSessionStateBuilder +import org.apache.spark.sql.internal.{SessionState, StaticSQLConf} +import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession} + +/** + * Flint Spark base suite with Hive support enabled. 
Because enabling Hive support in Spark + * configuration alone is not adequate, as [[TestSparkSession]] disregards it and consistently + * creates its own instance of [[org.apache.spark.sql.test.TestSQLSessionStateBuilder]]. We need + * to override its session state with that of Hive in the meanwhile. + * + * Note that we need to extend [[SharedSparkSession]] to call super.sparkConf() method. + */ +trait SparkHiveSupportSuite extends SharedSparkSession { + + override protected def sparkConf: SparkConf = { + super.sparkConf + // Enable Hive support + .set(StaticSQLConf.CATALOG_IMPLEMENTATION.key, "hive") + // Use in-memory Derby as Hive metastore so no need to clean up metastore_db folder after test + .set("javax.jdo.option.ConnectionURL", "jdbc:derby:memory:metastore_db;create=true") + .set("hive.metastore.uris", "") + } + + override protected def createSparkSession: TestSparkSession = { + SparkSession.cleanupAnyExistingSession() + new FlintTestSparkSession(sparkConf) + } + + class FlintTestSparkSession(sparkConf: SparkConf) extends TestSparkSession(sparkConf) { self => + + override lazy val sessionState: SessionState = { + // Override to replace [[TestSQLSessionStateBuilder]] with Hive session state + new HiveSessionStateBuilder(spark, None).build() + } + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala index 99a73e627..e5aa7b4d1 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala @@ -26,7 +26,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { override def beforeAll(): Unit = { super.beforeAll() - createPartitionedTable(testTable) + createPartitionedAddressTable(testTable) } override def afterEach(): Unit = { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index 6991e60d8..dd15624cf 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -31,7 +31,7 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { override def beforeEach(): Unit = { super.beforeEach() - createPartitionedTable(testTable) + createPartitionedAddressTable(testTable) } override def afterEach(): Unit = { @@ -125,7 +125,7 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { test("create skipping index with auto refresh should fail if mandatory checkpoint enabled") { setFlintSparkConf(CHECKPOINT_MANDATORY, "true") try { - the[IllegalStateException] thrownBy { + the[IllegalArgumentException] thrownBy { sql(s""" | CREATE INDEX $testIndex ON $testTable | (name, age) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala index 98ce7b9b6..7b9624045 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexJobITSuite.scala @@ -25,7 +25,7 @@ class FlintSparkIndexJobITSuite extends OpenSearchTransactionSuite with Matchers override def beforeAll(): Unit = 
{ super.beforeAll() - createPartitionedTable(testTable) + createPartitionedAddressTable(testTable) } override def afterEach(): Unit = { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala index 219e0c900..d6028bcb0 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexMonitorITSuite.scala @@ -30,7 +30,7 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matc override def beforeAll(): Unit = { super.beforeAll() - createPartitionedTable(testTable) + createPartitionedAddressTable(testTable) // Replace mock executor with real one and change its delay val realExecutor = newDaemonThreadPoolScheduledExecutor("flint-index-heartbeat", 1) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexValidationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexValidationITSuite.scala new file mode 100644 index 000000000..ee7420d94 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexValidationITSuite.scala @@ -0,0 +1,186 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import java.util.{Locale, UUID} + +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, INCREMENTAL, RefreshMode} +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex +import org.scalatest.matchers.must.Matchers.have +import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} + +import org.apache.spark.sql.SparkHiveSupportSuite +import org.apache.spark.sql.flint.config.FlintSparkConf.CHECKPOINT_MANDATORY + +class FlintSparkIndexValidationITSuite extends FlintSparkSuite with SparkHiveSupportSuite { + + // Test Hive table name + private val testTable = "spark_catalog.default.index_validation_test" + + // Test create Flint index name and DDL statement + private val skippingIndexName = FlintSparkSkippingIndex.getSkippingIndexName(testTable) + private val createSkippingIndexStatement = + s"CREATE SKIPPING INDEX ON $testTable (name VALUE_SET)" + + private val coveringIndexName = + FlintSparkCoveringIndex.getFlintIndexName("ci_test", testTable) + private val createCoveringIndexStatement = + s"CREATE INDEX ci_test ON $testTable (name)" + + private val materializedViewName = + FlintSparkMaterializedView.getFlintIndexName("spark_catalog.default.mv_test") + private val createMaterializedViewStatement = + s"CREATE MATERIALIZED VIEW spark_catalog.default.mv_test AS SELECT * FROM $testTable" + + Seq(createSkippingIndexStatement, createCoveringIndexStatement, createMaterializedViewStatement) + .foreach { statement => + test( + s"should fail to create auto refresh Flint index if incremental refresh enabled: $statement") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING) USING JSON") + + the[IllegalArgumentException] thrownBy { + sql(s""" + | $statement + | WITH ( + | auto_refresh = true, + | incremental_refresh = true + | ) + |""".stripMargin) + } should have message + "requirement failed: Incremental refresh cannot be enabled if auto refresh is enabled" + } + } + } + + 
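The same checks also surface through the Scala builder API, since FlintSparkIndexBuilder now runs validateIndex() before creation; a minimal sketch mirroring the skipping index suite further below (assumes `flint` is a FlintSpark instance and `testTable` already exists):

```scala
// Sketch: incremental refresh without checkpoint_location now fails at create()
// time with "requirement failed: Checkpoint location is required by incremental refresh".
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexOptions}

def createWithoutCheckpoint(flint: FlintSpark, testTable: String): Unit =
  flint
    .skippingIndex()
    .onTable(testTable)
    .addPartitions("year", "month")
    .options(FlintSparkIndexOptions(Map("incremental_refresh" -> "true")))
    .create() // throws IllegalArgumentException from the new validation step
```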
Seq(createSkippingIndexStatement, createCoveringIndexStatement, createMaterializedViewStatement) + .foreach { statement => + test( + s"should fail to create auto refresh Flint index if checkpoint location mandatory: $statement") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING) USING JSON") + + the[IllegalArgumentException] thrownBy { + try { + setFlintSparkConf(CHECKPOINT_MANDATORY, "true") + sql(s""" + | $statement + | WITH ( + | auto_refresh = true + | ) + |""".stripMargin) + } finally { + setFlintSparkConf(CHECKPOINT_MANDATORY, "false") + } + } should have message + s"requirement failed: Checkpoint location is required if ${CHECKPOINT_MANDATORY.key} option enabled" + } + } + } + + Seq(createSkippingIndexStatement, createCoveringIndexStatement, createMaterializedViewStatement) + .foreach { statement => + test( + s"should fail to create incremental refresh Flint index without checkpoint location: $statement") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING) USING JSON") + + the[IllegalArgumentException] thrownBy { + sql(s""" + | $statement + | WITH ( + | incremental_refresh = true + | ) + |""".stripMargin) + } should have message + "requirement failed: Checkpoint location is required by incremental refresh" + } + } + } + + Seq( + (AUTO, createSkippingIndexStatement), + (AUTO, createCoveringIndexStatement), + (AUTO, createMaterializedViewStatement), + (INCREMENTAL, createSkippingIndexStatement), + (INCREMENTAL, createCoveringIndexStatement), + (INCREMENTAL, createMaterializedViewStatement)) + .foreach { case (refreshMode, statement) => + test( + s"should fail to create $refreshMode refresh Flint index if checkpoint location is inaccessible: $statement") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING) USING JSON") + + // Generate UUID as folder name to ensure the path not exist + val checkpointDir = s"/test/${UUID.randomUUID()}" + the[IllegalArgumentException] thrownBy { + sql(s""" + | $statement + | WITH ( + | ${optionName(refreshMode)} = true, + | checkpoint_location = "$checkpointDir" + | ) + |""".stripMargin) + } should have message + s"requirement failed: Checkpoint location $checkpointDir doesn't exist or no permission to access" + } + } + } + + Seq( + (AUTO, createSkippingIndexStatement), + (AUTO, createCoveringIndexStatement), + (AUTO, createMaterializedViewStatement), + (INCREMENTAL, createSkippingIndexStatement), + (INCREMENTAL, createCoveringIndexStatement), + (INCREMENTAL, createMaterializedViewStatement)) + .foreach { case (refreshMode, statement) => + test(s"should fail to create $refreshMode refresh Flint index on Hive table: $statement") { + withTempDir { checkpointDir => + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING)") + + the[IllegalArgumentException] thrownBy { + sql(s""" + | $statement + | WITH ( + | ${optionName(refreshMode)} = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}' + | ) + |""".stripMargin) + } should have message + s"requirement failed: Index ${lowercase(refreshMode)} refresh doesn't support Hive table" + } + } + } + } + + Seq( + (skippingIndexName, createSkippingIndexStatement), + (coveringIndexName, createCoveringIndexStatement), + (materializedViewName, createMaterializedViewStatement)).foreach { + case (flintIndexName, statement) => + test(s"should succeed to create full refresh Flint index on Hive table: $flintIndexName") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING)") + sql(s"INSERT INTO $testTable VALUES 
('test')") + + sql(statement) + flint.refreshIndex(flintIndexName) + flint.queryIndex(flintIndexName).count() shouldBe 1 + } + } + } + + private def lowercase(mode: RefreshMode): String = mode.toString.toLowerCase(Locale.ROOT) + + private def optionName(mode: RefreshMode): String = mode match { + case AUTO => "auto_refresh" + case INCREMENTAL => "incremental_refresh" + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala index 16d2b0b07..83fe1546c 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -43,56 +43,58 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { } test("create materialized view with metadata successfully") { - val indexOptions = - FlintSparkIndexOptions( - Map( - "auto_refresh" -> "true", - "checkpoint_location" -> "s3://test/", - "watermark_delay" -> "30 Seconds")) - flint - .materializedView() - .name(testMvName) - .query(testQuery) - .options(indexOptions) - .create() + withTempDir { checkpointDir => + val indexOptions = + FlintSparkIndexOptions( + Map( + "auto_refresh" -> "true", + "checkpoint_location" -> checkpointDir.getAbsolutePath, + "watermark_delay" -> "30 Seconds")) + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .options(indexOptions) + .create() - val index = flint.describeIndex(testFlintIndex) - index shouldBe defined - index.get.metadata().getContent should matchJson(s""" - | { - | "_meta": { - | "version": "${current()}", - | "name": "spark_catalog.default.mv_test_metrics", - | "kind": "mv", - | "source": "$testQuery", - | "indexedColumns": [ - | { - | "columnName": "startTime", - | "columnType": "timestamp" - | },{ - | "columnName": "count", - | "columnType": "bigint" - | }], - | "options": { - | "auto_refresh": "true", - | "incremental_refresh": "false", - | "checkpoint_location": "s3://test/", - | "watermark_delay": "30 Seconds" - | }, - | "latestId": "$testLatestId", - | "properties": {} - | }, - | "properties": { - | "startTime": { - | "type": "date", - | "format": "strict_date_optional_time_nanos" - | }, - | "count": { - | "type": "long" - | } - | } - | } - |""".stripMargin) + val index = flint.describeIndex(testFlintIndex) + index shouldBe defined + index.get.metadata().getContent should matchJson(s""" + | { + | "_meta": { + | "version": "${current()}", + | "name": "spark_catalog.default.mv_test_metrics", + | "kind": "mv", + | "source": "$testQuery", + | "indexedColumns": [ + | { + | "columnName": "startTime", + | "columnType": "timestamp" + | },{ + | "columnName": "count", + | "columnType": "bigint" + | }], + | "options": { + | "auto_refresh": "true", + | "incremental_refresh": "false", + | "checkpoint_location": "${checkpointDir.getAbsolutePath}", + | "watermark_delay": "30 Seconds" + | }, + | "latestId": "$testLatestId", + | "properties": {} + | }, + | "properties": { + | "startTime": { + | "type": "date", + | "format": "strict_date_optional_time_nanos" + | }, + | "count": { + | "type": "long" + | } + | } + | } + |""".stripMargin) + } } test("full refresh materialized view") { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index 8b724fde7..999fb3008 100644 
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -32,21 +32,20 @@ import org.apache.spark.sql.internal.SQLConf class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { /** Test table and index name */ - private val testTable = "spark_catalog.default.test" + private val testTable = "spark_catalog.default.skipping_test" private val testIndex = getSkippingIndexName(testTable) private val testLatestId = Base64.getEncoder.encodeToString(testIndex.getBytes) override def beforeEach(): Unit = { super.beforeEach() - createPartitionedMultiRowTable(testTable) + createPartitionedMultiRowAddressTable(testTable) } override def afterEach(): Unit = { - super.afterEach() - // Delete all test indices deleteTestIndex(testIndex) sql(s"DROP TABLE $testTable") + super.afterEach() } test("create skipping index with metadata successfully") { @@ -63,7 +62,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { index shouldBe defined index.get.metadata().getContent should matchJson(s"""{ | "_meta": { - | "name": "flint_spark_catalog_default_test_skipping_index", + | "name": "flint_spark_catalog_default_skipping_test_skipping_index", | "version": "${current()}", | "kind": "skipping", | "indexedColumns": [ @@ -101,7 +100,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | "columnName": "name", | "columnType": "string" | }], - | "source": "spark_catalog.default.test", + | "source": "spark_catalog.default.skipping_test", | "options": { | "auto_refresh": "false", | "incremental_refresh": "false" @@ -141,36 +140,39 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { } test("create skipping index with index options successfully") { - flint - .skippingIndex() - .onTable(testTable) - .addValueSet("address") - .options(FlintSparkIndexOptions(Map( - "auto_refresh" -> "true", - "refresh_interval" -> "1 Minute", - "checkpoint_location" -> "s3a://test/", - "index_settings" -> "{\"number_of_shards\": 3,\"number_of_replicas\": 2}"))) - .create() + withTempDir { checkpointDir => + flint + .skippingIndex() + .onTable(testTable) + .addValueSet("address") + .options(FlintSparkIndexOptions(Map( + "auto_refresh" -> "true", + "refresh_interval" -> "1 Minute", + "checkpoint_location" -> checkpointDir.getAbsolutePath, + "index_settings" -> "{\"number_of_shards\": 3,\"number_of_replicas\": 2}"))) + .create() - val index = flint.describeIndex(testIndex) - index shouldBe defined - val optionJson = compact(render(parse(index.get.metadata().getContent) \ "_meta" \ "options")) - optionJson should matchJson(""" - | { - | "auto_refresh": "true", - | "incremental_refresh": "false", - | "refresh_interval": "1 Minute", - | "checkpoint_location": "s3a://test/", - | "index_settings": "{\"number_of_shards\": 3,\"number_of_replicas\": 2}" - | } - |""".stripMargin) + val index = flint.describeIndex(testIndex) + index shouldBe defined + val optionJson = + compact(render(parse(index.get.metadata().getContent) \ "_meta" \ "options")) + optionJson should matchJson(s""" + | { + | "auto_refresh": "true", + | "incremental_refresh": "false", + | "refresh_interval": "1 Minute", + | "checkpoint_location": "${checkpointDir.getAbsolutePath}", + | "index_settings": "{\\"number_of_shards\\": 3,\\"number_of_replicas\\": 2}" + | } + |""".stripMargin) - // Load index options from index mapping (verify OS index setting in SQL IT) - index.get.options.autoRefresh() shouldBe true - 
index.get.options.refreshInterval() shouldBe Some("1 Minute") - index.get.options.checkpointLocation() shouldBe Some("s3a://test/") - index.get.options.indexSettings() shouldBe - Some("{\"number_of_shards\": 3,\"number_of_replicas\": 2}") + // Load index options from index mapping (verify OS index setting in SQL IT) + index.get.options.autoRefresh() shouldBe true + index.get.options.refreshInterval() shouldBe Some("1 Minute") + index.get.options.checkpointLocation() shouldBe Some(checkpointDir.getAbsolutePath) + index.get.options.indexSettings() shouldBe + Some("{\"number_of_shards\": 3,\"number_of_replicas\": 2}") + } } test("should not have ID column in index data") { @@ -233,16 +235,14 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { } test("should fail if incremental refresh without checkpoint location") { - flint - .skippingIndex() - .onTable(testTable) - .addPartitions("year", "month") - .options(FlintSparkIndexOptions(Map("incremental_refresh" -> "true"))) - .create() - - assertThrows[IllegalStateException] { - flint.refreshIndex(testIndex) - } + the[IllegalArgumentException] thrownBy { + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year", "month") + .options(FlintSparkIndexOptions(Map("incremental_refresh" -> "true"))) + .create() + } should have message "requirement failed: Checkpoint location is required by incremental refresh" } test("auto refresh skipping index successfully") { @@ -479,7 +479,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { // Table name without database name "default" val query = sql(s""" | SELECT name - | FROM test + | FROM skipping_test | WHERE year = 2023 |""".stripMargin) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala index 53d08bda7..cdc599233 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSqlITSuite.scala @@ -31,7 +31,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { override def beforeEach(): Unit = { super.beforeAll() - createPartitionedMultiRowTable(testTable) + createPartitionedMultiRowAddressTable(testTable) } protected override def afterEach(): Unit = { @@ -178,7 +178,7 @@ class FlintSparkSkippingIndexSqlITSuite extends FlintSparkSuite { test("create skipping index with auto refresh should fail if mandatory checkpoint enabled") { setFlintSparkConf(CHECKPOINT_MANDATORY, "true") try { - the[IllegalStateException] thrownBy { + the[IllegalArgumentException] thrownBy { sql(s""" | CREATE SKIPPING INDEX ON $testTable | ( year PARTITION ) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 4ab3a983b..0c6282bb6 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -5,9 +5,13 @@ package org.opensearch.flint.spark +import java.nio.file.{Files, Path, Paths, StandardCopyOption} +import java.util.Comparator import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture} +import scala.collection.immutable.Map import scala.concurrent.duration.TimeUnit +import scala.util.Try import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when @@ -16,9 
+20,11 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest import org.opensearch.client.RequestOptions import org.opensearch.client.indices.GetIndexRequest import org.opensearch.flint.OpenSearchSuite +import org.scalatest.prop.TableDrivenPropertyChecks.forAll import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite +import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.flint.config.FlintSparkConf.{CHECKPOINT_MANDATORY, HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY} import org.apache.spark.sql.streaming.StreamTest @@ -30,17 +36,22 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit /** Flint Spark high level API being tested */ lazy protected val flint: FlintSpark = new FlintSpark(spark) + lazy protected val tableType: String = "CSV" + lazy protected val tableOptions: String = "OPTIONS (header 'false', delimiter '\t')" + + override protected def sparkConf: SparkConf = { + val conf = super.sparkConf + .set(HOST_ENDPOINT.key, openSearchHost) + .set(HOST_PORT.key, openSearchPort.toString) + .set(REFRESH_POLICY.key, "true") + // Disable mandatory checkpoint for test convenience + .set(CHECKPOINT_MANDATORY.key, "false") + conf + } override def beforeAll(): Unit = { super.beforeAll() - setFlintSparkConf(HOST_ENDPOINT, openSearchHost) - setFlintSparkConf(HOST_PORT, openSearchPort) - setFlintSparkConf(REFRESH_POLICY, "true") - - // Disable mandatory checkpoint for test convenience - setFlintSparkConf(CHECKPOINT_MANDATORY, "false") - // Replace executor to avoid impact on IT. // TODO: Currently no IT test scheduler so no need to restore it back. val mockExecutor = mock[ScheduledExecutorService] @@ -73,6 +84,16 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit }) } + def deleteDirectory(dirPath: String): Try[Unit] = { + Try { + val directory = Paths.get(dirPath) + Files + .walk(directory) + .sorted(Comparator.reverseOrder()) + .forEach(Files.delete(_)) + } + } + protected def awaitStreamingComplete(jobId: String): Unit = { val job = spark.streams.get(jobId) failAfter(streamingTimeout) { @@ -80,7 +101,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit } } - protected def createPartitionedTable(testTable: String): Unit = { + protected def createPartitionedAddressTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable | ( @@ -88,11 +109,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | age INT, | address STRING | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) + | USING $tableType $tableOptions | PARTITIONED BY ( | year INT, | month INT @@ -112,24 +129,20 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | """.stripMargin) } - protected def createPartitionedMultiRowTable(testTable: String): Unit = { - sql(s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT, - | address STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) + protected def createPartitionedMultiRowAddressTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | address STRING + | ) + | USING $tableType $tableOptions + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) // Use hint to insert all rows in a single csv file sql(s""" @@ -152,21 
+165,103 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createPartitionedStateCountryTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING $tableType $tableOptions + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + } + + protected def createOccupationTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | occupation STRING, + | country STRING, + | salary INT + | ) + | USING $tableType $tableOptions + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Insert data into the new table + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 'Engineer', 'England' , 100000), + | ('Hello', 'Artist', 'USA', 70000), + | ('John', 'Doctor', 'Canada', 120000), + | ('David', 'Doctor', 'USA', 120000), + | ('David', 'Unemployed', 'Canada', 0), + | ('Jane', 'Scientist', 'Canada', 90000) + | """.stripMargin) + } + + protected def createHobbiesTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | country STRING, + | hobby STRING, + | language STRING + | ) + | USING $tableType $tableOptions + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Insert data into the new table + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 'USA', 'Fishing', 'English'), + | ('Hello', 'USA', 'Painting', 'English'), + | ('John', 'Canada', 'Reading', 'French'), + | ('Jim', 'Canada', 'Hiking', 'English'), + | ('Peter', 'Canada', 'Gaming', 'English'), + | ('Rick', 'USA', 'Swimming', 'English'), + | ('David', 'USA', 'Gardening', 'English'), + | ('Jane', 'Canada', 'Singing', 'French') + | """.stripMargin) + } + protected def createTimeSeriesTable(testTable: String): Unit = { sql(s""" - | CREATE TABLE $testTable - | ( - | time TIMESTAMP, - | name STRING, - | age INT, - | address STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - |""".stripMargin) + | CREATE TABLE $testTable + | ( + | time TIMESTAMP, + | name STRING, + | age INT, + | address STRING + | ) + | USING $tableType $tableOptions + |""".stripMargin) sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:01:00', 'A', 30, 'Seattle')") sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:10:00', 'B', 20, 'Seattle')") @@ -175,6 +270,48 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 03:00:00', 'E', 15, 'Vancouver')") } + protected def createTimeSeriesTransactionTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | transactionId STRING, + | transactionDate TIMESTAMP, + | productId STRING, + | productsAmount INT, + | customerId STRING + | ) + | USING $tableType $tableOptions + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + // -- Inserting records into the testTable for April 2023 + sql(s""" + | INSERT INTO $testTable PARTITION (year=2023, month=4) + | VALUES + | ('txn001', 
CAST('2023-04-01 10:30:00' AS TIMESTAMP), 'prod1', 2, 'cust1'), + | ('txn001', CAST('2023-04-01 14:30:00' AS TIMESTAMP), 'prod1', 4, 'cust1'), + | ('txn002', CAST('2023-04-02 11:45:00' AS TIMESTAMP), 'prod2', 1, 'cust2'), + | ('txn003', CAST('2023-04-03 12:15:00' AS TIMESTAMP), 'prod3', 3, 'cust1'), + | ('txn004', CAST('2023-04-04 09:50:00' AS TIMESTAMP), 'prod1', 1, 'cust3') + | """.stripMargin) + + // Update data insertion + // -- Inserting records into the testTable for May 2023 + sql(s""" + | INSERT INTO $testTable PARTITION (year=2023, month=5) + | VALUES + | ('txn005', CAST('2023-05-01 08:30:00' AS TIMESTAMP), 'prod2', 1, 'cust4'), + | ('txn006', CAST('2023-05-02 07:25:00' AS TIMESTAMP), 'prod4', 5, 'cust2'), + | ('txn007', CAST('2023-05-03 15:40:00' AS TIMESTAMP), 'prod3', 1, 'cust3'), + | ('txn007', CAST('2023-05-03 19:30:00' AS TIMESTAMP), 'prod3', 2, 'cust3'), + | ('txn008', CAST('2023-05-04 14:15:00' AS TIMESTAMP), 'prod1', 4, 'cust1') + | """.stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable ( diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala index 7ad03f84a..b2d489c81 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionITSuite.scala @@ -26,7 +26,7 @@ class FlintSparkTransactionITSuite extends OpenSearchTransactionSuite with Match override def beforeAll(): Unit = { super.beforeAll() - createPartitionedTable(testTable) + createPartitionedAddressTable(testTable) } override def afterEach(): Unit = { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala index c5ac0ab95..76da7e8c3 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkUpdateIndexITSuite.scala @@ -22,7 +22,7 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { override def beforeEach(): Unit = { super.beforeEach() - createPartitionedMultiRowTable(testTable) + createPartitionedMultiRowAddressTable(testTable) } override def afterEach(): Unit = { @@ -199,12 +199,13 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { Map( "auto_refresh" -> "false", "incremental_refresh" -> "true", - "refresh_interval" -> "1 Minute"), + "refresh_interval" -> "1 Minute", + "checkpoint_location" -> "s3a://test/"), Map( "auto_refresh" -> false, "incremental_refresh" -> true, "refresh_interval" -> Some("1 Minute"), - "checkpoint_location" -> None, + "checkpoint_location" -> Some("s3a://test/"), "watermark_delay" -> None)), ( Map("auto_refresh" -> "true"), @@ -223,12 +224,13 @@ class FlintSparkUpdateIndexITSuite extends FlintSparkSuite { Map( "auto_refresh" -> "false", "incremental_refresh" -> "true", + "checkpoint_location" -> "s3a://test/", "watermark_delay" -> "1 Minute"), Map( "auto_refresh" -> false, "incremental_refresh" -> true, "refresh_interval" -> None, - "checkpoint_location" -> None, + "checkpoint_location" -> Some("s3a://test/"), "watermark_delay" -> Some("1 Minute"))))), ( "convert to auto refresh with allowed options", diff --git 
a/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergCoveringIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergCoveringIndexITSuite.scala new file mode 100644 index 000000000..2675ef0cd --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergCoveringIndexITSuite.scala @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.iceberg + +import org.opensearch.flint.spark.FlintSparkCoveringIndexSqlITSuite + +class FlintSparkIcebergCoveringIndexITSuite + extends FlintSparkCoveringIndexSqlITSuite + with FlintSparkIcebergSuite {} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergMaterializedViewITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergMaterializedViewITSuite.scala new file mode 100644 index 000000000..ffb8a7d1b --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergMaterializedViewITSuite.scala @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.iceberg + +import org.opensearch.flint.spark.FlintSparkMaterializedViewSqlITSuite + +class FlintSparkIcebergMaterializedViewITSuite + extends FlintSparkMaterializedViewSqlITSuite + with FlintSparkIcebergSuite {} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSkippingIndexITSuite.scala new file mode 100644 index 000000000..ba24e3b2b --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSkippingIndexITSuite.scala @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.iceberg + +import org.opensearch.flint.spark.FlintSparkSkippingIndexSqlITSuite + +class FlintSparkIcebergSkippingIndexITSuite + extends FlintSparkSkippingIndexSqlITSuite + with FlintSparkIcebergSuite {} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSuite.scala new file mode 100644 index 000000000..2ae0d157a --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/iceberg/FlintSparkIcebergSuite.scala @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.iceberg + +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions +import org.opensearch.flint.spark.FlintSparkExtensions +import org.opensearch.flint.spark.FlintSparkSuite + +import org.apache.spark.SparkConf + +/** + * Flint Spark suite tailored for Iceberg. 
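The new Iceberg suites keep their warehouse under a per-suite directory and remove it in afterAll via the deleteDirectory helper added to FlintSparkSuite above. A standalone sketch of that recursive-delete technique (the demo object and sample path are illustrative):

```scala
import java.nio.file.{Files, Path, Paths}
import java.util.Comparator

import scala.util.Try

object RecursiveDeleteDemo extends App {
  // Walk the tree and sort in reverse order so children are deleted before
  // their parent directory, allowing a single pass over the stream.
  def deleteDirectory(dirPath: String): Try[Unit] = Try {
    val directory: Path = Paths.get(dirPath)
    Files
      .walk(directory)
      .sorted(Comparator.reverseOrder())
      .forEach(p => Files.delete(p))
  }

  // Failure (e.g. a missing directory) is captured in the Try rather than thrown.
  println(deleteDirectory("spark-warehouse/SampleSuite"))
}
```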
+ */ +trait FlintSparkIcebergSuite extends FlintSparkSuite { + + // Override table type to Iceberg for this suite + override lazy protected val tableType: String = "iceberg" + + // You can also override tableOptions if Iceberg requires different options + override lazy protected val tableOptions: String = "" + + // Override the sparkConf method to include Iceberg-specific configurations + override protected def sparkConf: SparkConf = { + val conf = super.sparkConf + // Set Iceberg-specific Spark configurations + .set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") + .set("spark.sql.catalog.spark_catalog.type", "hadoop") + .set("spark.sql.catalog.spark_catalog.warehouse", s"spark-warehouse/${suiteName}") + .set( + "spark.sql.extensions", + List( + classOf[IcebergSparkSessionExtensions].getName, + classOf[FlintSparkExtensions].getName).mkString(", ")) + conf + } + + override def afterAll(): Unit = { + deleteDirectory(s"spark-warehouse/${suiteName}") + super.afterAll() + } + +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala index be43447fe..1ece33ce1 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintPPLSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.ppl -import org.opensearch.flint.spark.{FlintPPLSparkExtensions, FlintSparkExtensions} +import org.opensearch.flint.spark.{FlintPPLSparkExtensions, FlintSparkExtensions, FlintSparkSuite} import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode @@ -14,17 +14,9 @@ import org.apache.spark.sql.flint.config.FlintSparkConf.OPTIMIZER_RULE_ENABLED import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -trait FlintPPLSuite extends SharedSparkSession { +trait FlintPPLSuite extends FlintSparkSuite { override protected def sparkConf: SparkConf = { - val conf = new SparkConf() - .set("spark.ui.enabled", "false") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) - // Disable ConvertToLocalRelation for better test coverage. Test cases built on - // LocalRelation will exercise the optimization rules better by disabling it as - // this rule may potentially block testing of other optimization rules such as - // ConstantPropagation etc. 
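FlintSparkIcebergSuite re-runs entire existing SQL IT suites against Iceberg simply by overriding tableType and tableOptions in a mixin, as the three one-line suites above show. A self-contained ScalaTest sketch of the same stackable-trait reuse pattern, with hypothetical suite names:

```scala
import org.scalatest.funsuite.AnyFunSuite

// Format-agnostic base: tests are written once against a pluggable table format.
trait FormatAgnosticSuite extends AnyFunSuite {
  protected lazy val tableType: String = "CSV"
  protected lazy val tableOptions: String = "OPTIONS (header 'false')"

  protected def createTableDdl(tableName: String): String =
    s"CREATE TABLE $tableName (name STRING, age INT) USING $tableType $tableOptions"

  test("generated DDL uses the format supplied by the concrete suite") {
    assert(createTableDdl("people").contains(s"USING $tableType"))
  }
}

// Mixin that re-targets every inherited test at Iceberg.
trait IcebergFormat extends FormatAgnosticSuite {
  override protected lazy val tableType: String = "iceberg"
  override protected lazy val tableOptions: String = ""
}

// Same test bodies, two table formats, no duplication.
class CsvFormatSuite extends FormatAgnosticSuite
class IcebergFormatSuite extends CsvFormatSuite with IcebergFormat
```

Because the mixin sits last in the linearization, every inherited test picks up the Iceberg DDL without being rewritten.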
- .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) + val conf = super.sparkConf .set( "spark.sql.extensions", List(classOf[FlintPPLSparkExtensions].getName, classOf[FlintSparkExtensions].getName) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala index 8dfde6c94..b3abf8438 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -24,35 +24,7 @@ class FlintSparkPPLAggregationWithSpanITSuite super.beforeAll() // Create test table - // Update table creation - sql(s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT, - | state STRING, - | country STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - - // Update data insertion - sql(s""" - | INSERT INTO $testTable - | PARTITION (year=2023, month=4) - | VALUES ('Jake', 70, 'California', 'USA'), - | ('Hello', 30, 'New York', 'USA'), - | ('John', 25, 'Ontario', 'Canada'), - | ('Jane', 20, 'Quebec', 'Canada') - | """.stripMargin) + createPartitionedStateCountryTable(testTable) } protected override def afterEach(): Unit = { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index e8533d831..745c354eb 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -24,35 +24,7 @@ class FlintSparkPPLAggregationsITSuite super.beforeAll() // Create test table - // Update table creation - sql(s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT, - | state STRING, - | country STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - - // Update data insertion - sql(s""" - | INSERT INTO $testTable - | PARTITION (year=2023, month=4) - | VALUES ('Jake', 70, 'California', 'USA'), - | ('Hello', 30, 'New York', 'USA'), - | ('John', 25, 'Ontario', 'Canada'), - | ('Jane', 20, 'Quebec', 'Canada') - | """.stripMargin) + createPartitionedStateCountryTable(testTable) } protected override def afterEach(): Unit = { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index 8f1d1bd1f..ba925339e 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -22,36 +22,9 @@ class FlintSparkPPLBasicITSuite override def beforeAll(): Unit = { super.beforeAll() + // Create test table - // Update table creation - sql(s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT, - | state STRING, - | country STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - - // Update data insertion - 
sql(s""" - | INSERT INTO $testTable - | PARTITION (year=2023, month=4) - | VALUES ('Jake', 70, 'California', 'USA'), - | ('Hello', 30, 'New York', 'USA'), - | ('John', 25, 'Ontario', 'Canada'), - | ('Jane', 20, 'Quebec', 'Canada') - | """.stripMargin) + createPartitionedStateCountryTable(testTable) } protected override def afterEach(): Unit = { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala index 575f09362..38fdcdbb9 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala @@ -28,100 +28,19 @@ class FlintSparkPPLCorrelationITSuite override def beforeAll(): Unit = { super.beforeAll() // Create test tables - sql(s""" - | CREATE TABLE $testTable1 - | ( - | name STRING, - | age INT, - | state STRING, - | country STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - - sql(s""" - | CREATE TABLE $testTable2 - | ( - | name STRING, - | occupation STRING, - | country STRING, - | salary INT - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - + createPartitionedStateCountryTable(testTable1) // Update data insertion sql(s""" | INSERT INTO $testTable1 | PARTITION (year=2023, month=4) - | VALUES ('Jake', 70, 'California', 'USA'), - | ('Hello', 30, 'New York', 'USA'), - | ('John', 25, 'Ontario', 'Canada'), - | ('Jim', 27, 'B.C', 'Canada'), + | VALUES ('Jim', 27, 'B.C', 'Canada'), | ('Peter', 57, 'B.C', 'Canada'), | ('Rick', 70, 'B.C', 'Canada'), - | ('David', 40, 'Washington', 'USA'), - | ('Jane', 20, 'Quebec', 'Canada') - | """.stripMargin) - // Insert data into the new table - sql(s""" - | INSERT INTO $testTable2 - | PARTITION (year=2023, month=4) - | VALUES ('Jake', 'Engineer', 'England' , 100000), - | ('Hello', 'Artist', 'USA', 70000), - | ('John', 'Doctor', 'Canada', 120000), - | ('David', 'Doctor', 'USA', 120000), - | ('David', 'Unemployed', 'Canada', 0), - | ('Jane', 'Scientist', 'Canada', 90000) + | ('David', 40, 'Washington', 'USA') | """.stripMargin) - sql(s""" - | CREATE TABLE $testTable3 - | ( - | name STRING, - | country STRING, - | hobby STRING, - | language STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - - // Insert data into the new table - sql(s""" - | INSERT INTO $testTable3 - | PARTITION (year=2023, month=4) - | VALUES ('Jake', 'USA', 'Fishing', 'English'), - | ('Hello', 'USA', 'Painting', 'English'), - | ('John', 'Canada', 'Reading', 'French'), - | ('Jim', 'Canada', 'Hiking', 'English'), - | ('Peter', 'Canada', 'Gaming', 'English'), - | ('Rick', 'USA', 'Swimming', 'English'), - | ('David', 'USA', 'Gardening', 'English'), - | ('Jane', 'Canada', 'Singing', 'French') - | """.stripMargin) + + createOccupationTable(testTable2) + createHobbiesTable(testTable3) } protected override def afterEach(): Unit = { @@ -701,7 +620,15 @@ class FlintSparkPPLCorrelationITSuite Row(70000.0, "Canada", 50L), Row(95000.0, "USA", 40L)) - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](2)) + // Define ordering for rows that first compares by age then by 
the third column and then by the country + implicit val rowOrdering: Ordering[Row] = new Ordering[Row] { + def compare(x: Row, y: Row): Int = { + val ageCompare = x.getAs[Long](2).compareTo(y.getAs[Long](2)) + if (ageCompare != 0) ageCompare + else x.getAs[String](1).compareTo(y.getAs[String](1)) + } + } + // Compare the results assert(results.sorted.sameElements(expectedResults.sorted)) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index 32c1baa0a..236c216cf 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -23,35 +23,7 @@ class FlintSparkPPLFiltersITSuite override def beforeAll(): Unit = { super.beforeAll() // Create test table - // Update table creation - sql(s""" - | CREATE TABLE $testTable - | ( - | name STRING, - | age INT, - | state STRING, - | country STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - - // Update data insertion - sql(s""" - | INSERT INTO $testTable - | PARTITION (year=2023, month=4) - | VALUES ('Jake', 70, 'California', 'USA'), - | ('Hello', 30, 'New York', 'USA'), - | ('John', 25, 'Ontario', 'Canada'), - | ('Jane', 20, 'Quebec', 'Canada') - | """.stripMargin) + createPartitionedStateCountryTable(testTable) } protected override def afterEach(): Unit = { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTimeWindowITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTimeWindowITSuite.scala index df77e0d90..fbae03fff 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTimeWindowITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTimeWindowITSuite.scala @@ -26,49 +26,7 @@ class FlintSparkPPLTimeWindowITSuite super.beforeAll() // Create test table // Update table creation - sql(s""" - | CREATE TABLE $testTable - | ( - | transactionId STRING, - | transactionDate TIMESTAMP, - | productId STRING, - | productsAmount INT, - | customerId STRING - | ) - | USING CSV - | OPTIONS ( - | header 'false', - | delimiter '\t' - | ) - | PARTITIONED BY ( - | year INT, - | month INT - | ) - |""".stripMargin) - - // Update data insertion - // -- Inserting records into the testTable for April 2023 - sql(s""" - |INSERT INTO $testTable PARTITION (year=2023, month=4) - |VALUES - |('txn001', CAST('2023-04-01 10:30:00' AS TIMESTAMP), 'prod1', 2, 'cust1'), - |('txn001', CAST('2023-04-01 14:30:00' AS TIMESTAMP), 'prod1', 4, 'cust1'), - |('txn002', CAST('2023-04-02 11:45:00' AS TIMESTAMP), 'prod2', 1, 'cust2'), - |('txn003', CAST('2023-04-03 12:15:00' AS TIMESTAMP), 'prod3', 3, 'cust1'), - |('txn004', CAST('2023-04-04 09:50:00' AS TIMESTAMP), 'prod1', 1, 'cust3') - | """.stripMargin) - - // Update data insertion - // -- Inserting records into the testTable for May 2023 - sql(s""" - |INSERT INTO $testTable PARTITION (year=2023, month=5) - |VALUES - |('txn005', CAST('2023-05-01 08:30:00' AS TIMESTAMP), 'prod2', 1, 'cust4'), - |('txn006', CAST('2023-05-02 07:25:00' AS TIMESTAMP), 'prod4', 5, 'cust2'), - |('txn007', CAST('2023-05-03 15:40:00' AS TIMESTAMP), 'prod3', 1, 'cust3'), - |('txn007', CAST('2023-05-03 19:30:00' AS TIMESTAMP), 'prod3', 2, 'cust3'), - |('txn008', CAST('2023-05-04 14:15:00' 
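The correlation suite above (and the time-window suite below) replace a single-key Ordering.by with a two-key Ordering[Row] so result comparison stays deterministic when the first key ties. Assuming spark-sql on the classpath, an equivalent and slightly more compact form uses tuple ordering; the sample rows are made up:

```scala
import org.apache.spark.sql.Row

object RowOrderingDemo extends App {
  // Rows shaped like the aggregation output: (average, country, count-like Long).
  val rows = Seq(
    Row(95000.0, "USA", 40L),
    Row(70000.0, "Canada", 50L),
    Row(70000.0, "Brazil", 50L)) // ties on the Long are broken by the String

  // Sort by the Long in column 2, then by the String in column 1.
  implicit val rowOrdering: Ordering[Row] =
    Ordering.by[Row, (Long, String)](r => (r.getAs[Long](2), r.getAs[String](1)))

  println(rows.sorted.mkString("\n"))
}
```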
AS TIMESTAMP), 'prod1', 4, 'cust1') - | """.stripMargin) + createTimeSeriesTransactionTable(testTable) } protected override def afterEach(): Unit = { @@ -274,12 +232,16 @@ class FlintSparkPPLTimeWindowITSuite "prod3", Timestamp.valueOf("2023-05-03 17:00:00"), Timestamp.valueOf("2023-05-04 17:00:00"))) - // Compare the results - implicit val timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { - def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + + // Define ordering for rows that first compares by the timestamp and then by the productId + implicit val rowOrdering: Ordering[Row] = new Ordering[Row] { + def compare(x: Row, y: Row): Int = { + val dateCompare = x.getAs[Timestamp](2).compareTo(y.getAs[Timestamp](2)) + if (dateCompare != 0) dateCompare + else x.getAs[String](1).compareTo(y.getAs[String](1)) + } } - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Timestamp](_.getAs[Timestamp](2)) assert(results.sorted.sameElements(expectedResults.sorted)) // Retrieve the logical plan diff --git a/project/Dependencies.scala b/project/Dependencies.scala index db92cf78f..047afb64c 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -12,6 +12,7 @@ object Dependencies { "org.apache.spark" %% "spark-core" % sparkVersion % "provided" withSources (), "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" withSources (), "org.json4s" %% "json4s-native" % "3.7.0-M5" % "test", + "org.apache.spark" %% "spark-hive" % sparkVersion % "test", "org.apache.spark" %% "spark-catalyst" % sparkVersion % "test" classifier "tests", "org.apache.spark" %% "spark-core" % sparkVersion % "test" classifier "tests", "org.apache.spark" %% "spark-sql" % sparkVersion % "test" classifier "tests") diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 66859c9f4..6ee7cc68e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -6,14 +6,15 @@ // defined in spark package so that I can use ThreadUtils package org.apache.spark.sql -import org.apache.spark.internal.Logging -import org.apache.spark.sql.flint.config.FlintSparkConf -import org.apache.spark.sql.types._ +import java.util.concurrent.atomic.AtomicInteger + import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge import play.api.libs.json._ -import java.util.concurrent.atomic.AtomicInteger +import org.apache.spark.internal.Logging +import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.sql.types._ /** * Spark SQL Application entrypoint @@ -46,7 +47,7 @@ object FlintJob extends Logging with FlintJobExecutor { conf.set(FlintSparkConf.JOB_TYPE.key, jobType) val dataSource = conf.get("spark.flint.datasource.name", "") - val query = queryOption.getOrElse(conf.get(FlintSparkConf.QUERY.key, "")) + val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, ""))) if (query.isEmpty) { throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.") } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 915de4089..c1d2bf79b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ 
b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -8,6 +8,7 @@ package org.apache.spark.sql import java.util.Locale import com.amazonaws.services.s3.model.AmazonS3Exception +import org.apache.commons.text.StringEscapeUtils.unescapeJava import org.opensearch.flint.core.IRestHighLevelClient import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter @@ -359,6 +360,14 @@ trait FlintJobExecutor { } } + /** + * Unescape the query string which is escaped for EMR spark submit parameter parsing. Ref: + * https://github.com/opensearch-project/sql/pull/2587 + */ + def unescapeQuery(query: String): String = { + unescapeJava(query) + } + def executeQuery( spark: SparkSession, query: String, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index b4c27f3c8..5a0918d4a 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -249,7 +249,7 @@ object FlintREPL extends Logging with FlintJobExecutor { if (defaultQuery.isEmpty) { throw new IllegalArgumentException("Query undefined for the streaming job.") } - defaultQuery + unescapeQuery(defaultQuery) } else "" } } diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 421457c4e..288eeb7c5 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -95,6 +95,18 @@ class FlintREPLTest query shouldBe "SELECT * FROM table" } + test( + "getQuery should return unescaped default query for streaming job if queryOption is None") { + val queryOption = None + val jobType = "streaming" + val conf = new SparkConf().set( + FlintSparkConf.QUERY.key, + "SELECT \\\"1\\\" UNION SELECT '\\\"1\\\"' UNION SELECT \\\"\\\\\\\"1\\\\\\\"\\\"") + + val query = FlintREPL.getQuery(queryOption, jobType, conf) + query shouldBe "SELECT \"1\" UNION SELECT '\"1\"' UNION SELECT \"\\\"1\\\"\"" + } + test( "getQuery should throw IllegalArgumentException if queryOption is None and default query is not defined for streaming job") { val queryOption = None
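The unescapeQuery change uses Apache Commons Text, which FlintJobExecutor already imports, to undo the Java-style escaping applied when the query travels through EMR spark-submit parameters. A small sketch of the round trip, with an illustrative query:

```scala
import org.apache.commons.text.StringEscapeUtils.unescapeJava

object UnescapeQueryDemo extends App {
  // As read from the Spark conf, quotes arrive Java-escaped: SELECT \"1\" FROM t
  val escapedFromConf = "SELECT \\\"1\\\" FROM t"

  // unescapeJava restores the text the user actually submitted.
  val query = unescapeJava(escapedFromConf)
  println(query) // SELECT "1" FROM t
  assert(query == "SELECT \"1\" FROM t")
}
```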