From a83ab20d90009a9fd30580f2356f28e415d55752 Mon Sep 17 00:00:00 2001
From: Tomoyuki MORITA
Date: Tue, 3 Sep 2024 22:21:07 -0700
Subject: [PATCH] Add queryId Spark parameter to batch query (#2952)

Signed-off-by: Tomoyuki Morita
---
 .../spark/data/constants/SparkConstants.java  |  1 +
 .../spark/dispatcher/BatchQueryHandler.java   |  1 +
 .../dispatcher/StreamingQueryHandler.java     |  1 +
 .../SparkSubmitParametersBuilder.java         |  6 ++
 .../dispatcher/SparkQueryDispatcherTest.java  | 58 ++++++++++++++-----
 5 files changed, 52 insertions(+), 15 deletions(-)

diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java
index e87dbba03e..9b82022d8f 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java
@@ -86,6 +86,7 @@ public class SparkConstants {
       "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider";
   public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/";
   public static final String FLINT_JOB_QUERY = "spark.flint.job.query";
+  public static final String FLINT_JOB_QUERY_ID = "spark.flint.job.queryId";
   public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex";
   public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId";
 
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
index 36e4c227b8..c693656150 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
@@ -91,6 +91,7 @@ public DispatchQueryResponse submit(
         sparkSubmitParametersBuilderProvider
             .getSparkSubmitParametersBuilder()
             .clusterName(clusterName)
+            .queryId(context.getQueryId())
             .query(dispatchQueryRequest.getQuery())
             .dataSource(
                 context.getDataSourceMetadata(),
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java
index 80d4be27cf..51e245b57c 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java
@@ -82,6 +82,7 @@ public DispatchQueryResponse submit(
         sparkSubmitParametersBuilderProvider
             .getSparkSubmitParametersBuilder()
             .clusterName(clusterName)
+            .queryId(context.getQueryId())
             .query(dispatchQueryRequest.getQuery())
             .structuredStreaming(true)
             .dataSource(
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java
index d9d5859f64..db74d0a5a7 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java
@@ -20,6 +20,7 @@
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_QUERY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_QUERY_ID;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_PPL_EXTENSION;
@@ -108,6 +109,11 @@ public SparkSubmitParametersBuilder query(String query) {
     return this;
   }
 
+  public SparkSubmitParametersBuilder queryId(String queryId) {
+    setConfigItem(FLINT_JOB_QUERY_ID, queryId);
+    return this;
+  }
+
   public SparkSubmitParametersBuilder dataSource(
       DataSourceMetadata metadata,
       DispatchQueryRequest dispatchQueryRequest,
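The production change is small: a new FLINT_JOB_QUERY_ID constant, a queryId(...) builder method, and one extra call in each handler's builder chain, so both batch and streaming jobs now receive the query id as a Spark conf. The sketch below is a minimal, self-contained stand-in (not the real SparkSubmitParametersBuilder, whose setConfigItem internals are outside this patch); it models the config as an ordered map and assumes a " --conf key=value " rendering, purely to show the resulting parameter shape.

import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for SparkSubmitParametersBuilder: setConfigItem is
// modeled as an ordered map; the "--conf key=value" rendering is an assumption.
class SparkSubmitParametersSketch {
  private final Map<String, String> config = new LinkedHashMap<>();

  SparkSubmitParametersSketch queryId(String queryId) {
    config.put("spark.flint.job.queryId", queryId); // FLINT_JOB_QUERY_ID
    return this;
  }

  SparkSubmitParametersSketch query(String query) {
    config.put("spark.flint.job.query", query); // FLINT_JOB_QUERY
    return this;
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder(" --class org.apache.spark.sql.FlintJob ");
    config.forEach(
        (key, value) -> sb.append(" --conf ").append(key).append('=').append(value).append(' '));
    return sb.toString();
  }

  public static void main(String[] args) {
    // Mirrors the handler call chain above: ...queryId(...).query(...)
    System.out.println(
        new SparkSubmitParametersSketch()
            .queryId("my-query-id") // hypothetical id from a QueryIdProvider
            .query("select * from my_glue.default.http_logs"));
  }
}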
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index b6369292a6..1587ce6638 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -162,12 +162,14 @@ void setUp() {
   @Test
   void testDispatchSelectQuery() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
     tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
     tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
     String query = "select * from my_glue.default.http_logs";
-    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
@@ -206,12 +208,14 @@ void testDispatchSelectQuery() {
   @Test
   void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
     tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
     tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
     String query = "select * from my_glue.default.http_logs";
-    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
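Every dispatch test that reaches BatchQueryHandler or StreamingQueryHandler now stubs queryIdProvider, so the id baked into the expected parameter string is deterministic. The pattern is plain Mockito stubbing; here is a self-contained sketch, where QueryIdProvider is a hypothetical two-argument interface standing in for the real one.

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class QueryIdStubSketch {
  // Hypothetical stand-in for the real QueryIdProvider interface.
  interface QueryIdProvider {
    String getQueryId(Object dispatchQueryRequest, Object asyncQueryRequestContext);
  }

  public static void main(String[] args) {
    QueryIdProvider queryIdProvider = mock(QueryIdProvider.class);
    // Same shape as the stubbing added in each test: any arguments yield a fixed id.
    when(queryIdProvider.getQueryId(any(), any())).thenReturn("QUERY_ID");
    System.out.println(queryIdProvider.getQueryId(null, null)); // prints QUERY_ID
  }
}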
"streaming"); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", @@ -347,6 +353,7 @@ void testDispatchCreateAutoRefreshIndexQuery() { @Test void testDispatchCreateManualRefreshIndexQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -354,7 +361,8 @@ void testDispatchCreateManualRefreshIndexQuery() { String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = false)"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -383,12 +391,14 @@ void testDispatchCreateManualRefreshIndexQuery() { @Test void testDispatchWithPPLQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "source = my_glue.default.http_logs"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -448,12 +458,14 @@ void testDispatchWithSparkUDFQuery() { @Test void testInvalidSQLQueryDispatchToSpark() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "myselect 1"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -492,12 +504,14 @@ void testInvalidSQLQueryDispatchToSpark() { @Test void testDispatchQueryWithoutATableAndDataSourceName() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "show tables"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -526,6 +540,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { @Test void testDispatchIndexQueryWithoutADatasourceName() { 
@@ -526,6 +540,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() {
   @Test
   void testDispatchIndexQueryWithoutADatasourceName() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
     tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
@@ -534,7 +549,8 @@ void testDispatchIndexQueryWithoutADatasourceName() {
     String query =
         "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH"
             + " (auto_refresh = true)";
-    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID);
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
@@ -563,6 +579,7 @@ void testDispatchIndexQueryWithoutADatasourceName() {
   @Test
   void testDispatchMaterializedViewQuery() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
     tags.put(INDEX_TAG_KEY, "flint_mv_1");
@@ -570,7 +587,8 @@ void testDispatchMaterializedViewQuery() {
     tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText());
     String query =
         "CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)";
-    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID);
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:streaming:flint_mv_1",
@@ -599,12 +617,14 @@ void testDispatchMaterializedViewQuery() {
   @Test
   void testDispatchShowMVQuery() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
     tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
     tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
     String query = "SHOW MATERIALIZED VIEW IN mys3.default";
-    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
@@ -633,12 +653,14 @@ void testDispatchShowMVQuery() {
   @Test
   void testRefreshIndexQuery() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
     tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
     tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
     String query = "REFRESH SKIPPING INDEX ON my_glue.default.http_logs";
-    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
@@ -667,12 +689,14 @@ void testRefreshIndexQuery() {
   @Test
   void testDispatchDescribeIndexQuery() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
     tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
     tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
     String query = "DESCRIBE SKIPPING INDEX ON mys3.default.http_logs";
-    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
@@ -701,6 +725,7 @@ void testDispatchDescribeIndexQuery() {
   @Test
   void testDispatchAlterToAutoRefreshIndexQuery() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, "my_glue");
     tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
@@ -709,7 +734,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() {
     String query =
         "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
             + " (auto_refresh = true)";
-    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID);
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
@@ -1048,10 +1074,11 @@ void testDispatchQueryWithExtraSparkSubmitParameters() {
   }
 
   private String constructExpectedSparkSubmitParameterString(String query) {
-    return constructExpectedSparkSubmitParameterString(query, null);
+    return constructExpectedSparkSubmitParameterString(query, null, null);
   }
 
-  private String constructExpectedSparkSubmitParameterString(String query, String jobType) {
+  private String constructExpectedSparkSubmitParameterString(
+      String query, String jobType, String queryId) {
     query = "\"" + query + "\"";
     return " --class org.apache.spark.sql.FlintJob "
         + getConfParam(
@@ -1070,6 +1097,7 @@ private String constructExpectedSparkSubmitParameterString(String query, String
             "spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider",
             "spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions,org.opensearch.flint.spark.FlintPPLSparkExtensions",
             "spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")
+        + (queryId != null ? getConfParam("spark.flint.job.queryId=" + queryId) : "")
         + getConfParam("spark.flint.job.query=" + query)
         + (jobType != null ? getConfParam("spark.flint.job.type=" + jobType) : "")
         + getConfParam(
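The helper's last hunk shows how the expected string stays backward compatible: the spark.flint.job.queryId conf is emitted only when a non-null id is passed, so the two-argument overload (which forwards null) leaves older expectations unchanged. Below is a hedged reconstruction of that optional-parameter handling; the patch ends before getConfParam's body, so the " --conf %s " rendering used here is an assumption.

import java.util.Arrays;
import java.util.stream.Collectors;

class ExpectedParamsSketch {
  // Assumed rendering: each entry becomes " --conf <entry> ".
  static String getConfParam(String... params) {
    return Arrays.stream(params)
        .map(param -> String.format(" --conf %s ", param))
        .collect(Collectors.joining());
  }

  // Mirrors the null-guarded assembly in constructExpectedSparkSubmitParameterString.
  static String expected(String query, String jobType, String queryId) {
    query = "\"" + query + "\"";
    return " --class org.apache.spark.sql.FlintJob "
        + (queryId != null ? getConfParam("spark.flint.job.queryId=" + queryId) : "")
        + getConfParam("spark.flint.job.query=" + query)
        + (jobType != null ? getConfParam("spark.flint.job.type=" + jobType) : "");
  }

  public static void main(String[] args) {
    // Batch query with an id: queryId conf present, no job type conf.
    System.out.println(expected("select 1", null, "QUERY_ID"));
    // Legacy two-argument path (queryId == null): no queryId conf at all.
    System.out.println(expected("select 1", null, null));
  }
}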