Add queryId Spark parameter to batch query (#2952)
Signed-off-by: Tomoyuki Morita <[email protected]>
ykmr1224 authored Sep 4, 2024
1 parent c13f770 commit a83ab20
Showing 5 changed files with 52 additions and 15 deletions.
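The change threads the dispatched query's ID through the Spark submit parameter builder, so both the batch and streaming handlers now attach it to the job as the spark.flint.job.queryId configuration entry. Assuming the test helper getConfParam (see the last hunk) renders each entry as its own --conf token (the exact spacing is an assumption, not shown in this diff), the generated parameter string gains a fragment along these lines, where QUERY_ID stands in for the test fixture's placeholder value:

 --class org.apache.spark.sql.FlintJob ... --conf spark.flint.job.queryId=QUERY_ID --conf spark.flint.job.query="select * from my_glue.default.http_logs" ...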
@@ -86,6 +86,7 @@ public class SparkConstants {
"com.amazonaws.emr.AssumeRoleAWSCredentialsProvider";
public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/";
public static final String FLINT_JOB_QUERY = "spark.flint.job.query";
public static final String FLINT_JOB_QUERY_ID = "spark.flint.job.queryId";
public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex";
public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId";

@@ -91,6 +91,7 @@ public DispatchQueryResponse submit(
 sparkSubmitParametersBuilderProvider
     .getSparkSubmitParametersBuilder()
     .clusterName(clusterName)
+    .queryId(context.getQueryId())
     .query(dispatchQueryRequest.getQuery())
     .dataSource(
         context.getDataSourceMetadata(),
@@ -82,6 +82,7 @@ public DispatchQueryResponse submit(
 sparkSubmitParametersBuilderProvider
     .getSparkSubmitParametersBuilder()
     .clusterName(clusterName)
+    .queryId(context.getQueryId())
     .query(dispatchQueryRequest.getQuery())
     .structuredStreaming(true)
     .dataSource(
@@ -20,6 +20,7 @@
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_QUERY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_QUERY_ID;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_PPL_EXTENSION;
@@ -108,6 +109,11 @@ public SparkSubmitParametersBuilder query(String query) {
   return this;
 }

+public SparkSubmitParametersBuilder queryId(String queryId) {
+  setConfigItem(FLINT_JOB_QUERY_ID, queryId);
+  return this;
+}
+
 public SparkSubmitParametersBuilder dataSource(
     DataSourceMetadata metadata,
     DispatchQueryRequest dispatchQueryRequest,
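For illustration, here is a minimal sketch (not part of this commit) of how code running inside the Spark job could read the ID back; it assumes only that setConfigItem ultimately lands the key/value pair in the job's SparkConf:

import org.apache.spark.sql.SparkSession;

public class QueryIdProbe {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("queryIdProbe").getOrCreate();
    // Fall back to a sentinel when the dispatcher did not set the ID.
    String queryId =
        spark.sparkContext().getConf().get("spark.flint.job.queryId", "unset");
    System.out.println("spark.flint.job.queryId = " + queryId);
    spark.stop();
  }
}

This is the payoff of the new builder method: the ID travels with the job, so job-side logs and metrics can be correlated with the originating async query.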
@@ -162,12 +162,14 @@ void setUp() {
 @Test
 void testDispatchSelectQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query = "select * from my_glue.default.http_logs";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -206,12 +208,14 @@ void testDispatchSelectQuery() {
 @Test
 void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query = "select * from my_glue.default.http_logs";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -310,6 +314,7 @@ void testDispatchSelectQueryFailedCreateSession() {
 @Test
 void testDispatchCreateAutoRefreshIndexQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
@@ -318,7 +323,8 @@ void testDispatchCreateAutoRefreshIndexQuery() {
   String query =
       "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH"
           + " (auto_refresh = true)";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
@@ -347,14 +353,16 @@ void testDispatchCreateAutoRefreshIndexQuery() {
 @Test
 void testDispatchCreateManualRefreshIndexQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, "my_glue");
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query =
       "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH"
           + " (auto_refresh = false)";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -383,12 +391,14 @@ void testDispatchCreateManualRefreshIndexQuery() {
 @Test
 void testDispatchWithPPLQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query = "source = my_glue.default.http_logs";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -448,12 +458,14 @@ void testDispatchWithSparkUDFQuery() {
 @Test
 void testInvalidSQLQueryDispatchToSpark() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query = "myselect 1";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -492,12 +504,14 @@ void testInvalidSQLQueryDispatchToSpark() {
 @Test
 void testDispatchQueryWithoutATableAndDataSourceName() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query = "show tables";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -526,6 +540,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() {
 @Test
 void testDispatchIndexQueryWithoutADatasourceName() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
@@ -534,7 +549,8 @@ void testDispatchIndexQueryWithoutADatasourceName() {
   String query =
       "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH"
           + " (auto_refresh = true)";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
@@ -563,14 +579,16 @@ void testDispatchIndexQueryWithoutADatasourceName() {
 @Test
 void testDispatchMaterializedViewQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(INDEX_TAG_KEY, "flint_mv_1");
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText());
   String query =
       "CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:streaming:flint_mv_1",
@@ -599,12 +617,14 @@ void testDispatchMaterializedViewQuery() {
 @Test
 void testDispatchShowMVQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query = "SHOW MATERIALIZED VIEW IN mys3.default";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -633,12 +653,14 @@ void testDispatchShowMVQuery() {
 @Test
 void testRefreshIndexQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query = "REFRESH SKIPPING INDEX ON my_glue.default.http_logs";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -667,12 +689,14 @@ void testRefreshIndexQuery() {
 @Test
 void testDispatchDescribeIndexQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
   tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
   tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
   String query = "DESCRIBE SKIPPING INDEX ON mys3.default.http_logs";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:batch",
@@ -701,6 +725,7 @@ void testDispatchDescribeIndexQuery() {
 @Test
 void testDispatchAlterToAutoRefreshIndexQuery() {
   when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
   HashMap<String, String> tags = new HashMap<>();
   tags.put(DATASOURCE_TAG_KEY, "my_glue");
   tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
@@ -709,7 +734,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() {
   String query =
       "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
           + " (auto_refresh = true)";
-  String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
+  String sparkSubmitParameters =
+      constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID);
   StartJobRequest expected =
       new StartJobRequest(
           "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
@@ -1048,10 +1074,11 @@ void testDispatchQueryWithExtraSparkSubmitParameters() {
 }

 private String constructExpectedSparkSubmitParameterString(String query) {
-  return constructExpectedSparkSubmitParameterString(query, null);
+  return constructExpectedSparkSubmitParameterString(query, null, null);
 }

-private String constructExpectedSparkSubmitParameterString(String query, String jobType) {
+private String constructExpectedSparkSubmitParameterString(
+    String query, String jobType, String queryId) {
   query = "\"" + query + "\"";
   return " --class org.apache.spark.sql.FlintJob "
       + getConfParam(
@@ -1070,6 +1097,7 @@ private String constructExpectedSparkSubmitParameterString(String query, String jobType) {
"spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider",
"spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions,org.opensearch.flint.spark.FlintPPLSparkExtensions",
"spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")
+ (queryId != null ? getConfParam("spark.flint.job.queryId=" + queryId) : "")
+ getConfParam("spark.flint.job.query=" + query)
+ (jobType != null ? getConfParam("spark.flint.job.type=" + jobType) : "")
+ getConfParam(
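The null guard above keeps the queryId entry out of the expected string for the older overloads, mirroring the existing jobType handling. For reference, a hypothetical shape of the getConfParam helper these assertions rely on (the real implementation lives elsewhere in this test class and is not shown in the diff; the exact spacing is an assumption):

private String getConfParam(String... params) {
  // Render each entry as its own " --conf <key>=<value>" token.
  StringBuilder sb = new StringBuilder();
  for (String param : params) {
    sb.append(" --conf ").append(param);
  }
  return sb.toString();
}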
