diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index 7caa69293a..ae82386c3f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.asyncquery; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -20,7 +21,8 @@ public interface AsyncQueryExecutorService { * @param createAsyncQueryRequest createAsyncQueryRequest. * @return {@link CreateAsyncQueryResponse} */ - CreateAsyncQueryResponse createAsyncQuery(CreateAsyncQueryRequest createAsyncQueryRequest); + CreateAsyncQueryResponse createAsyncQuery( + CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext); /** * Returns async query response for a given queryId. diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index f2d8bdc2c5..e4818d737c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -18,6 +18,7 @@ import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; @@ -36,9 +37,9 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService @Override public CreateAsyncQueryResponse createAsyncQuery( - CreateAsyncQueryRequest createAsyncQueryRequest) { + CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext) { SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext); DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( @@ -48,7 +49,7 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), sparkExecutionEngineConfig.getClusterName(), - sparkExecutionEngineConfig.getSparkSubmitParameters(), + sparkExecutionEngineConfig.getSparkSubmitParameterModifier(), createAsyncQueryRequest.getSessionId())); asyncQueryJobMetadataStorageService.storeJobMetadata( AsyncQueryJobMetadata.builder() diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullRequestContext.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullRequestContext.java new file mode 100644 index 0000000000..e106f57cff --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullRequestContext.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery.model; + +/** An implementation of RequestContext for where context is not required */ +public class NullRequestContext implements RequestContext { + @Override + public Object getAttribute(String name) { + return null; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/RequestContext.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/RequestContext.java new file mode 100644 index 0000000000..3a0f350701 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/RequestContext.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery.model; + +/** Context interface to provide additional request related information */ +public interface RequestContext { + Object getAttribute(String name); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index d54b6c29af..6badea6a74 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -21,11 +21,13 @@ import java.util.function.Supplier; import lombok.AllArgsConstructor; import lombok.RequiredArgsConstructor; +import lombok.Setter; import org.apache.commons.lang3.BooleanUtils; import org.apache.commons.text.StringEscapeUtils; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.datasources.auth.AuthenticationType; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; /** Define Spark Submit Parameters. */ @@ -40,7 +42,24 @@ public class SparkSubmitParameters { private final Map config; /** Extra parameters to append finally */ - private String extraParameters; + @Setter private String extraParameters; + + public void setConfigItem(String key, String value) { + config.put(key, value); + } + + public void deleteConfigItem(String key) { + config.remove(key); + } + + public static Builder builder() { + return Builder.builder(); + } + + public SparkSubmitParameters acceptModifier(SparkSubmitParameterModifier modifier) { + modifier.modifyParameters(this); + return this; + } public static class Builder { @@ -180,17 +199,16 @@ public Builder extraParameters(String params) { return this; } - public Builder sessionExecution(String sessionId, String datasourceName) { - config.put(FLINT_JOB_REQUEST_INDEX, OpenSearchStateStoreUtil.getIndexName(datasourceName)); - config.put(FLINT_JOB_SESSION_ID, sessionId); - return this; - } - public SparkSubmitParameters build() { return new SparkSubmitParameters(className, config, extraParameters); } } + public void sessionExecution(String sessionId, String datasourceName) { + config.put(FLINT_JOB_REQUEST_INDEX, OpenSearchStateStoreUtil.getIndexName(datasourceName)); + config.put(FLINT_JOB_SESSION_ID, sessionId); + } + @Override public String toString() { StringBuilder stringBuilder = new StringBuilder(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java index e0cc5ea397..4250d32b0e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java @@ -13,6 +13,7 @@ import java.security.AccessController; import java.security.PrivilegedAction; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.NullRequestContext; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; @@ -32,7 +33,8 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor @Override public EMRServerlessClient getClient() { SparkExecutionEngineConfig sparkExecutionEngineConfig = - this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig( + new NullRequestContext()); validateSparkExecutionEngineConfig(sparkExecutionEngineConfig); if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) { region = sparkExecutionEngineConfig.getRegion(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java b/spark/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java new file mode 100644 index 0000000000..f1831c9786 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java @@ -0,0 +1,15 @@ +package org.opensearch.sql.spark.config; + +import lombok.AllArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; + +@AllArgsConstructor +public class OpenSearchSparkSubmitParameterModifier implements SparkSubmitParameterModifier { + + private String extraParameters; + + @Override + public void modifyParameters(SparkSubmitParameters parameters) { + parameters.setExtraParameters(this.extraParameters); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java index 537a635150..92636c3cfb 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java @@ -1,8 +1,8 @@ package org.opensearch.sql.spark.config; import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Data; -import lombok.NoArgsConstructor; /** * POJO for spark Execution Engine Config. Interface between {@link @@ -10,12 +10,12 @@ * SparkExecutionEngineConfigSupplier} */ @Data -@NoArgsConstructor +@Builder @AllArgsConstructor public class SparkExecutionEngineConfig { private String applicationId; private String region; private String executionRoleARN; - private String sparkSubmitParameters; + private SparkSubmitParameterModifier sparkSubmitParameterModifier; private String clusterName; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java index 108cb07daf..b5d061bad3 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java @@ -1,5 +1,7 @@ package org.opensearch.sql.spark.config; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; + /** Interface for extracting and providing SparkExecutionEngineConfig */ public interface SparkExecutionEngineConfigSupplier { @@ -8,5 +10,5 @@ public interface SparkExecutionEngineConfigSupplier { * * @return {@link SparkExecutionEngineConfig}. */ - SparkExecutionEngineConfig getSparkExecutionEngineConfig(); + SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java index f4c32f24eb..70d628b958 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java @@ -9,6 +9,7 @@ import org.apache.commons.lang3.StringUtils; import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; @AllArgsConstructor public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier { @@ -16,27 +17,26 @@ public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEng private Settings settings; @Override - public SparkExecutionEngineConfig getSparkExecutionEngineConfig() { + public SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext) { String sparkExecutionEngineConfigSettingString = this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); - SparkExecutionEngineConfig sparkExecutionEngineConfig = new SparkExecutionEngineConfig(); + SparkExecutionEngineConfig.SparkExecutionEngineConfigBuilder builder = + SparkExecutionEngineConfig.builder(); if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) { - SparkExecutionEngineConfigClusterSetting sparkExecutionEngineConfigClusterSetting = + SparkExecutionEngineConfigClusterSetting setting = AccessController.doPrivileged( (PrivilegedAction) () -> SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig( sparkExecutionEngineConfigSettingString)); - sparkExecutionEngineConfig.setApplicationId( - sparkExecutionEngineConfigClusterSetting.getApplicationId()); - sparkExecutionEngineConfig.setExecutionRoleARN( - sparkExecutionEngineConfigClusterSetting.getExecutionRoleARN()); - sparkExecutionEngineConfig.setSparkSubmitParameters( - sparkExecutionEngineConfigClusterSetting.getSparkSubmitParameters()); - sparkExecutionEngineConfig.setRegion(sparkExecutionEngineConfigClusterSetting.getRegion()); + builder.applicationId(setting.getApplicationId()); + builder.executionRoleARN(setting.getExecutionRoleARN()); + builder.sparkSubmitParameterModifier( + new OpenSearchSparkSubmitParameterModifier(setting.getSparkSubmitParameters())); + builder.region(setting.getRegion()); } ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME); - sparkExecutionEngineConfig.setClusterName(clusterName.value()); - return sparkExecutionEngineConfig; + builder.clusterName(clusterName.value()); + return builder.build(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java new file mode 100644 index 0000000000..e79e0f85e3 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java @@ -0,0 +1,7 @@ +package org.opensearch.sql.spark.config; + +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; + +public interface SparkSubmitParameterModifier { + void modifyParameters(SparkSubmitParameters parameters); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 92feba9941..b9436b0801 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -47,8 +47,10 @@ public class SparkConstants { public static final String SPARK_DRIVER_ENV_JAVA_HOME_KEY = "spark.emr-serverless.driverEnv.JAVA_HOME"; public static final String SPARK_EXECUTOR_ENV_JAVA_HOME_KEY = "spark.executorEnv.JAVA_HOME"; + // Used for logging/metrics in Spark (driver) public static final String SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY = "spark.emr-serverless.driverEnv.FLINT_CLUSTER_NAME"; + // Used for logging/metrics in Spark (executor) public static final String SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY = "spark.executorEnv.FLINT_CLUSTER_NAME"; public static final String FLINT_INDEX_STORE_HOST_KEY = "spark.datasource.flint.host"; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index d06153bf79..85f7a3d8dd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -81,12 +81,12 @@ public DispatchQueryResponse submit( clusterName + ":" + JobType.BATCH.getText(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() + SparkSubmitParameters.builder() .clusterName(clusterName) .dataSource(context.getDataSourceMetadata()) .query(dispatchQueryRequest.getQuery()) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() + .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()) .toString(), tags, false, diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 7475c5a7ae..552ddeb76e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -102,11 +102,12 @@ public DispatchQueryResponse submit( clusterName, dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() + SparkSubmitParameters.builder() .className(FLINT_SESSION_CLASS_NAME) .clusterName(clusterName) .dataSource(dataSourceMetadata) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()), + .build() + .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()), tags, dataSourceMetadata.getResultIndex(), dataSourceMetadata.getName())); diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 4a9b1ce5d5..886e7d176a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -69,13 +69,13 @@ public DispatchQueryResponse submit( jobName, dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() + SparkSubmitParameters.builder() .clusterName(clusterName) .dataSource(dataSourceMetadata) .query(dispatchQueryRequest.getQuery()) .structuredStreaming(true) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() + .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()) .toString(), tags, indexQueryDetails.getFlintIndexOptions().autoRefresh(), diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java index 6aa28227a1..601103254f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -8,6 +8,7 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; import org.opensearch.sql.spark.rest.model.LangType; @AllArgsConstructor @@ -21,8 +22,8 @@ public class DispatchQueryRequest { private final String executionRoleARN; private final String clusterName; - /** Optional extra Spark submit parameters to include in final request */ - private String extraSparkSubmitParams; + /* extension point to modify or add spark submit parameter */ + private final SparkSubmitParameterModifier sparkSubmitParameterModifier; /** Optional sessionId. */ private String sessionId; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index 419b125ab9..d138e5f05d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -16,7 +16,7 @@ public class CreateSessionRequest { private final String clusterName; private final String applicationId; private final String executionRoleArn; - private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder; + private final SparkSubmitParameters sparkSubmitParameters; private final Map tags; private final String resultIndex; private final String datasourceName; @@ -26,7 +26,7 @@ public StartJobRequest getStartJobRequest(String sessionId) { clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId, applicationId, executionRoleArn, - sparkSubmitParametersBuilder.build().toString(), + sparkSubmitParameters.toString(), tags, resultIndex); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 760c898825..6eace80da4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -53,8 +53,12 @@ public void open(CreateSessionRequest createSessionRequest) { try { // append session id; createSessionRequest - .getSparkSubmitParametersBuilder() - .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName()); + .getSparkSubmitParameters() + .acceptModifier( + (parameters) -> { + parameters.sessionExecution( + sessionId.getSessionId(), createSessionRequest.getDatasourceName()); + }); createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId()); StartJobRequest startJobRequest = createSessionRequest.getStartJobRequest(sessionId.getSessionId()); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java index 4e2102deed..d669875304 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java @@ -18,6 +18,7 @@ import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.model.NullRequestContext; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; @@ -64,7 +65,8 @@ protected void doExecute( CreateAsyncQueryRequest createAsyncQueryRequest = request.getCreateAsyncQueryRequest(); CreateAsyncQueryResponse createAsyncQueryResponse = - asyncQueryExecutorService.createAsyncQuery(createAsyncQueryRequest); + asyncQueryExecutorService.createAsyncQuery( + createAsyncQueryRequest, new NullRequestContext()); String responseContent = new JsonResponseFormatter(JsonResponseFormatter.Style.PRETTY) { @Override diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 74b18d0332..fc02ce1e6f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -31,6 +31,8 @@ import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.datasources.exceptions.DatasourceDisabledException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.NullRequestContext; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionState; @@ -42,6 +44,7 @@ import org.opensearch.sql.spark.rest.model.LangType; public class AsyncQueryExecutorServiceImplSpecTest extends AsyncQueryExecutorServiceSpec { + RequestContext requestContext = new NullRequestContext(); @Disabled("batch query is unsupported") public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { @@ -56,7 +59,8 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { // 1. create async query. CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertFalse(clusterService().state().routingTable().hasIndex(SPARK_REQUEST_BUFFER_INDEX_NAME)); emrsClient.startJobRunCalled(1); @@ -86,12 +90,14 @@ public void sessionLimitNotImpactBatchQuery() { // 1. create async query. CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); emrsClient.startJobRunCalled(1); CreateAsyncQueryResponse resp2 = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); emrsClient.startJobRunCalled(2); } @@ -105,7 +111,8 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { enableSession(false); CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); String params = emrsClient.getJobRequest().getSparkSubmitParams(); assertNull(response.getSessionId()); assertTrue(params.contains(String.format("--class %s", DEFAULT_CLASS_NAME))); @@ -119,7 +126,8 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { enableSession(true); response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); params = emrsClient.getJobRequest().getSparkSubmitParams(); assertTrue(params.contains(String.format("--class %s", FLINT_SESSION_CLASS_NAME))); assertTrue( @@ -139,7 +147,8 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { // 1. create async query. CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(response.getSessionId()); Optional statementModel = statementStorageService.getStatement(response.getQueryId(), MYS3_DATASOURCE); @@ -171,14 +180,16 @@ public void reuseSessionWhenCreateAsyncQuery() { // 1. create async query. CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(first.getSessionId()); // 2. reuse session id CreateAsyncQueryResponse second = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId())); + "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId()), + requestContext); assertEquals(first.getSessionId(), second.getSessionId()); assertNotEquals(first.getQueryId(), second.getQueryId()); @@ -194,7 +205,9 @@ public void reuseSessionWhenCreateAsyncQuery() { 2, search( QueryBuilders.boolQuery() - .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE)) + .must( + QueryBuilders.termQuery( + "type", STATEMENT_DOC_TYPE)) .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); Optional firstModel = @@ -220,7 +233,8 @@ public void batchQueryHasTimeout() { enableSession(false); CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertEquals(120L, (long) emrsClient.getJobRequest().executionTimeout()); } @@ -236,7 +250,8 @@ public void interactiveQueryNoTimeout() { enableSession(true); asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertEquals(0L, (long) emrsClient.getJobRequest().executionTimeout()); } @@ -269,7 +284,7 @@ public void datasourceWithBasicAuth() { enableSession(true); asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null), requestContext); String params = emrsClient.getJobRequest().getSparkSubmitParams(); assertTrue(params.contains(String.format("--conf spark.datasource.flint.auth=basic"))); assertTrue( @@ -291,7 +306,8 @@ public void withSessionCreateAsyncQueryFailed() { // 1. create async query. CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("myselect 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("myselect 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(response.getSessionId()); Optional statementModel = statementStorageService.getStatement(response.getQueryId(), MYS3_DATASOURCE); @@ -341,7 +357,8 @@ public void createSessionMoreThanLimitFailed() { // 1. create async query. CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(first.getSessionId()); setSessionState(first.getSessionId(), SessionState.RUNNING); @@ -351,7 +368,8 @@ public void createSessionMoreThanLimitFailed() { ConcurrencyLimitExceededException.class, () -> asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null))); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext)); assertEquals("domain concurrent active session can not exceed 1", exception.getMessage()); } @@ -369,7 +387,8 @@ public void recreateSessionIfNotReady() { // 1. create async query. CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(first.getSessionId()); // set sessionState to FAIL @@ -379,7 +398,8 @@ public void recreateSessionIfNotReady() { CreateAsyncQueryResponse second = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId())); + "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId()), + requestContext); assertNotEquals(first.getSessionId(), second.getSessionId()); @@ -390,7 +410,8 @@ public void recreateSessionIfNotReady() { CreateAsyncQueryResponse third = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - "select 1", MYS3_DATASOURCE, LangType.SQL, second.getSessionId())); + "select 1", MYS3_DATASOURCE, LangType.SQL, second.getSessionId()), + requestContext); assertNotEquals(second.getSessionId(), third.getSessionId()); } @@ -408,7 +429,8 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - "SHOW SCHEMAS IN " + MYS3_DATASOURCE, MYS3_DATASOURCE, LangType.SQL, null)); + "SHOW SCHEMAS IN " + MYS3_DATASOURCE, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(first.getSessionId()); // set sessionState to RUNNING @@ -421,7 +443,8 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { "SHOW SCHEMAS IN " + MYS3_DATASOURCE, MYS3_DATASOURCE, LangType.SQL, - first.getSessionId())); + first.getSessionId()), + requestContext); assertEquals(first.getSessionId(), second.getSessionId()); @@ -435,7 +458,8 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { "SHOW SCHEMAS IN " + MYGLUE_DATASOURCE, MYGLUE_DATASOURCE, LangType.SQL, - second.getSessionId())); + second.getSessionId()), + requestContext); assertNotEquals(second.getSessionId(), third.getSessionId()); } @@ -452,7 +476,8 @@ public void recreateSessionIfStale() { // 1. create async query. CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(first.getSessionId()); // set sessionState to RUNNING @@ -462,7 +487,8 @@ public void recreateSessionIfStale() { CreateAsyncQueryResponse second = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId())); + "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId()), + requestContext); assertEquals(first.getSessionId(), second.getSessionId()); @@ -480,7 +506,8 @@ public void recreateSessionIfStale() { CreateAsyncQueryResponse third = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - "select 1", MYS3_DATASOURCE, LangType.SQL, second.getSessionId())); + "select 1", MYS3_DATASOURCE, LangType.SQL, second.getSessionId()), + requestContext); assertNotEquals(second.getSessionId(), third.getSessionId()); } finally { // set timeout setting to 0 @@ -509,7 +536,8 @@ public void submitQueryInInvalidSessionWillCreateNewSession() { CreateAsyncQueryResponse asyncQuery = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - "select 1", MYS3_DATASOURCE, LangType.SQL, invalidSessionId.getSessionId())); + "select 1", MYS3_DATASOURCE, LangType.SQL, invalidSessionId.getSessionId()), + requestContext); assertNotNull(asyncQuery.getSessionId()); assertNotEquals(invalidSessionId.getSessionId(), asyncQuery.getSessionId()); } @@ -542,7 +570,7 @@ public void datasourceNameIncludeUppercase() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", "TESTS3", LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", "TESTS3", LangType.SQL, null), requestContext); String params = emrsClient.getJobRequest().getSparkSubmitParams(); assertNotNull(response.getSessionId()); @@ -564,7 +592,8 @@ public void concurrentSessionLimitIsDomainLevel() { // 1. create async query. CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(first.getSessionId()); setSessionState(first.getSessionId(), SessionState.RUNNING); @@ -574,8 +603,8 @@ public void concurrentSessionLimitIsDomainLevel() { ConcurrencyLimitExceededException.class, () -> asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest( - "select 1", MYGLUE_DATASOURCE, LangType.SQL, null))); + new CreateAsyncQueryRequest("select 1", MYGLUE_DATASOURCE, LangType.SQL, null), + requestContext)); assertEquals("domain concurrent active session can not exceed 1", exception.getMessage()); } @@ -595,7 +624,8 @@ public void testDatasourceDisabled() { // 1. create async query. try { asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), + requestContext); fail("It should have thrown DataSourceDisabledException"); } catch (DatasourceDisabledException exception) { Assertions.assertEquals("Datasource mys3 is disabled.", exception.getMessage()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index a5dee8f4e8..b10d54683d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -33,8 +33,11 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.config.OpenSearchSparkSubmitParameterModifier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -50,6 +53,8 @@ public class AsyncQueryExecutorServiceImplTest { private AsyncQueryExecutorService jobExecutorService; @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; + @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; + @Mock private RequestContext requestContext; private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); @BeforeEach @@ -66,10 +71,10 @@ void testCreateAsyncQuery() { CreateAsyncQueryRequest createAsyncQueryRequest = new CreateAsyncQueryRequest( "select * from my_glue.default.http_logs", "my_glue", LangType.SQL); - when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn( new SparkExecutionEngineConfig( - EMRS_APPLICATION_ID, "eu-west-1", EMRS_EXECUTION_ROLE, null, TEST_CLUSTER_NAME)); + EMRS_APPLICATION_ID, "eu-west-1", EMRS_EXECUTION_ROLE, sparkSubmitParameterModifier, TEST_CLUSTER_NAME)); DispatchQueryRequest expectedDispatchQueryRequest = new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -77,54 +82,56 @@ void testCreateAsyncQuery() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier); when(sparkQueryDispatcher.dispatch(expectedDispatchQueryRequest)) .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, null, null)); CreateAsyncQueryResponse createAsyncQueryResponse = - jobExecutorService.createAsyncQuery(createAsyncQueryRequest); + jobExecutorService.createAsyncQuery(createAsyncQueryRequest, requestContext); verify(asyncQueryJobMetadataStorageService, times(1)) .storeJobMetadata(getAsyncQueryJobMetadata()); - verify(sparkExecutionEngineConfigSupplier, times(1)).getSparkExecutionEngineConfig(); + verify(sparkExecutionEngineConfigSupplier, times(1)).getSparkExecutionEngineConfig(requestContext); verify(sparkQueryDispatcher, times(1)).dispatch(expectedDispatchQueryRequest); Assertions.assertEquals(QUERY_ID.getId(), createAsyncQueryResponse.getQueryId()); } @Test void testCreateAsyncQueryWithExtraSparkSubmitParameter() { - when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + OpenSearchSparkSubmitParameterModifier modifier = + new OpenSearchSparkSubmitParameterModifier("--conf spark.dynamicAllocation.enabled=false"); + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn( new SparkExecutionEngineConfig( EMRS_APPLICATION_ID, "eu-west-1", - EMRS_APPLICATION_ID, - "--conf spark.dynamicAllocation.enabled=false", + EMRS_EXECUTION_ROLE, + modifier, TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, null, null)); jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - "select * from my_glue.default.http_logs", "my_glue", LangType.SQL)); + "select * from my_glue.default.http_logs", "my_glue", LangType.SQL), + requestContext); verify(sparkQueryDispatcher, times(1)) .dispatch( - argThat( - actualReq -> - actualReq - .getExtraSparkSubmitParams() - .equals("--conf spark.dynamicAllocation.enabled=false"))); + argThat(actualReq -> actualReq.getSparkSubmitParameterModifier().equals(modifier))); } @Test void testGetAsyncQueryResultsWithJobNotFoundException() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.empty()); + AsyncQueryNotFoundException asyncQueryNotFoundException = Assertions.assertThrows( AsyncQueryNotFoundException.class, () -> jobExecutorService.getAsyncQueryResults(EMR_JOB_ID)); + Assertions.assertEquals( "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); verifyNoInteractions(sparkQueryDispatcher); @@ -173,9 +180,11 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException { void testCancelJobWithJobNotFound() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.empty()); + AsyncQueryNotFoundException asyncQueryNotFoundException = Assertions.assertThrows( AsyncQueryNotFoundException.class, () -> jobExecutorService.cancelQuery(EMR_JOB_ID)); + Assertions.assertEquals( "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); verifyNoInteractions(sparkQueryDispatcher); @@ -187,7 +196,9 @@ void testCancelJob() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.of(getAsyncQueryJobMetadata())); when(sparkQueryDispatcher.cancelJob(getAsyncQueryJobMetadata())).thenReturn(EMR_JOB_ID); + String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, jobId); verifyNoInteractions(sparkExecutionEngineConfigSupplier); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 85bb92bba2..b05da017d5 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -52,9 +52,11 @@ import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.config.OpenSearchSparkSubmitParameterModifier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; @@ -97,6 +99,7 @@ public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase { protected StateStore stateStore; protected SessionStorageService sessionStorageService; protected StatementStorageService statementStorageService; + protected RequestContext requestContext; @Override protected Collection> nodePlugins() { @@ -332,8 +335,14 @@ public EMRServerlessClient getClient() { } } - public SparkExecutionEngineConfig sparkExecutionEngineConfig() { - return new SparkExecutionEngineConfig("appId", "us-west-2", "roleArn", "", "myCluster"); + public SparkExecutionEngineConfig sparkExecutionEngineConfig(RequestContext requestContext) { + return SparkExecutionEngineConfig.builder() + .applicationId("appId") + .region("us-west-2") + .executionRoleARN("roleArn") + .sparkSubmitParameterModifier(new OpenSearchSparkSubmitParameterModifier("")) + .clusterName("myCluster") + .build(); } public void enableSession(boolean enabled) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index f2c3bda026..3ab558616b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -26,6 +26,8 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob; +import org.opensearch.sql.spark.asyncquery.model.NullRequestContext; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; @@ -37,6 +39,7 @@ import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; public class AsyncQueryGetResultSpecTest extends AsyncQueryExecutorServiceSpec { + RequestContext requestContext = new NullRequestContext(); /** Mock Flint index and index state */ private final FlintDatasetMock mockIndex = @@ -435,7 +438,8 @@ public JSONObject getResultWithQueryId(String queryId, String resultIndex) { }); this.createQueryResponse = queryService.createAsyncQuery( - new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); } AssertionHelper withInteraction(Interaction interaction) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java index d49e3883da..4786e496e0 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java @@ -76,7 +76,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -144,7 +145,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -225,7 +227,8 @@ public CancelJobRunResult cancelJobRun( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -288,7 +291,8 @@ public void testAlterIndexQueryConvertingToAutoRefresh() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result assertEquals( @@ -353,7 +357,8 @@ public void testAlterIndexQueryWithOutAnyAutoRefresh() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result assertEquals( @@ -427,7 +432,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -501,7 +507,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -569,7 +576,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -630,7 +638,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -693,7 +702,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -756,7 +766,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -816,7 +827,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -874,7 +886,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -940,7 +953,8 @@ public CancelJobRunResult cancelJobRun( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -1004,7 +1018,8 @@ public CancelJobRunResult cancelJobRun( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -1069,7 +1084,8 @@ public CancelJobRunResult cancelJobRun( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 09addccdbb..486ccf7031 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -135,7 +135,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(response.getQueryId()); assertTrue(clusterService.state().routingTable().hasIndex(mockDS.indexName)); @@ -185,7 +186,8 @@ public CancelJobRunResult cancelJobRun( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = @@ -224,7 +226,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryResults = @@ -260,7 +263,8 @@ public CancelJobRunResult cancelJobRun( // 1.drop index CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest(mockDS.query, MYGLUE_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest(mockDS.query, MYGLUE_DATASOURCE, LangType.SQL, null), + requestContext); // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = @@ -302,7 +306,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(response.getQueryId()); assertTrue(clusterService.state().routingTable().hasIndex(mockDS.indexName)); @@ -361,7 +366,8 @@ public CancelJobRunResult cancelJobRun( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = @@ -407,7 +413,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryResults = @@ -452,7 +459,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result assertEquals( @@ -502,7 +510,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -549,7 +558,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result assertEquals( @@ -595,7 +605,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result assertEquals( @@ -649,7 +660,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); AsyncQueryExecutionResponse asyncQueryExecutionResponse = asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); @@ -693,7 +705,8 @@ public CancelJobRunResult cancelJobRun( // 1.drop index CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest(mockDS.query, MYGLUE_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest(mockDS.query, MYGLUE_DATASOURCE, LangType.SQL, null), + requestContext); // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = @@ -740,7 +753,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryResults = @@ -769,7 +783,8 @@ public void concurrentRefreshJobLimitNotApplied() { + "l_quantity) WITH (auto_refresh = true)"; CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNull(response.getSessionId()); } @@ -797,7 +812,8 @@ public void concurrentRefreshJobLimitAppliedToDDLWithAuthRefresh() { ConcurrencyLimitExceededException.class, () -> asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null))); + new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext)); assertEquals("domain concurrent refresh job can not exceed 1", exception.getMessage()); } @@ -823,7 +839,8 @@ public void concurrentRefreshJobLimitAppliedToRefresh() { ConcurrencyLimitExceededException.class, () -> asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null))); + new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext)); assertEquals("domain concurrent refresh job can not exceed 1", exception.getMessage()); } @@ -845,7 +862,8 @@ public void concurrentRefreshJobLimitNotAppliedToDDL() { CreateAsyncQueryResponse asyncQueryResponse = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); assertNotNull(asyncQueryResponse.getSessionId()); } @@ -877,7 +895,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 1. submit create / refresh index query CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // 2. cancel query IllegalArgumentException exception = @@ -920,7 +939,8 @@ public GetJobRunResult getJobRunResult( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.refreshQuery, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.refreshQuery, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // mock index state. flintIndexJob.refreshing(); @@ -964,7 +984,8 @@ public GetJobRunResult getJobRunResult( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( - mockDS.refreshQuery, MYS3_DATASOURCE, LangType.SQL, null)); + mockDS.refreshQuery, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); // mock index state. flintIndexJob.active(); @@ -1010,7 +1031,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { "REFRESH INDEX covering_corrupted ON my_glue.mydb.http_logs", MYS3_DATASOURCE, LangType.SQL, - null)); + null), + requestContext); // mock index state. flintIndexJob.refreshing(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java index 14bb225c96..c289bbe53f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java @@ -171,7 +171,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // Vacuum index CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest(mockDS.query, MYS3_DATASOURCE, LangType.SQL, null)); + new CreateAsyncQueryRequest(mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), + requestContext); return asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java index e732cf698c..10f12251b0 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java @@ -5,8 +5,11 @@ package org.opensearch.sql.spark.asyncquery.model; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; import org.junit.jupiter.api.Test; @@ -14,7 +17,7 @@ public class SparkSubmitParametersTest { @Test public void testBuildWithoutExtraParameters() { - String params = SparkSubmitParameters.Builder.builder().build().toString(); + String params = SparkSubmitParameters.builder().build().toString(); assertNotNull(params); } @@ -22,7 +25,7 @@ public void testBuildWithoutExtraParameters() { @Test public void testBuildWithExtraParameters() { String params = - SparkSubmitParameters.Builder.builder().extraParameters("--conf A=1").build().toString(); + SparkSubmitParameters.builder().extraParameters("--conf A=1").build().toString(); // Assert the conf is included with a space assertTrue(params.endsWith(" --conf A=1")); @@ -32,7 +35,7 @@ public void testBuildWithExtraParameters() { public void testBuildQueryString() { String rawQuery = "SHOW tables LIKE \"%\";"; String expectedQueryInParams = "\"SHOW tables LIKE \\\"%\\\";\""; - String params = SparkSubmitParameters.Builder.builder().query(rawQuery).build().toString(); + String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); assertTrue(params.contains(expectedQueryInParams)); } @@ -40,7 +43,7 @@ public void testBuildQueryString() { public void testBuildQueryStringNestedQuote() { String rawQuery = "SELECT '\"1\"'"; String expectedQueryInParams = "\"SELECT '\\\"1\\\"'\""; - String params = SparkSubmitParameters.Builder.builder().query(rawQuery).build().toString(); + String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); assertTrue(params.contains(expectedQueryInParams)); } @@ -48,7 +51,34 @@ public void testBuildQueryStringNestedQuote() { public void testBuildQueryStringSpecialCharacter() { String rawQuery = "SELECT '{\"test ,:+\\\"inner\\\"/\\|?#><\"}'"; String expectedQueryInParams = "SELECT '{\\\"test ,:+\\\\\\\"inner\\\\\\\"/\\\\|?#><\\\"}'"; - String params = SparkSubmitParameters.Builder.builder().query(rawQuery).build().toString(); + String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); assertTrue(params.contains(expectedQueryInParams)); } + + @Test + public void testOverrideConfigItem() { + SparkSubmitParameters params = SparkSubmitParameters.builder().build(); + params.setConfigItem(SPARK_JARS_KEY, "Overridden"); + String result = params.toString(); + + assertTrue(result.contains(String.format("%s=Overridden", SPARK_JARS_KEY))); + } + + @Test + public void testDeleteConfigItem() { + SparkSubmitParameters params = SparkSubmitParameters.builder().build(); + params.deleteConfigItem(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); + String result = params.toString(); + + assertFalse(result.contains(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY)); + } + + @Test + public void testAddConfigItem() { + SparkSubmitParameters params = SparkSubmitParameters.builder().build(); + params.setConfigItem("AdditionalKey", "Value"); + String result = params.toString(); + + assertTrue(result.contains("AdditionalKey=Value")); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java index 9bfed9f498..562fc84eca 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.client; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; import org.junit.jupiter.api.Assertions; @@ -12,7 +13,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.constants.TestConstants; @@ -24,7 +24,7 @@ public class EMRServerlessClientFactoryImplTest { @Test public void testGetClient() { - when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn(createSparkExecutionEngineConfig()); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); @@ -35,7 +35,7 @@ public void testGetClient() { @Test public void testGetClientWithChangeInSetting() { SparkExecutionEngineConfig sparkExecutionEngineConfig = createSparkExecutionEngineConfig(); - when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn(sparkExecutionEngineConfig); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); @@ -46,7 +46,7 @@ public void testGetClientWithChangeInSetting() { Assertions.assertEquals(emrServerlessClient1, emrserverlessClient); sparkExecutionEngineConfig.setRegion(TestConstants.US_WEST_REGION); - when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn(sparkExecutionEngineConfig); EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(); Assertions.assertNotEquals(emrServerlessClient2, emrserverlessClient); @@ -55,7 +55,7 @@ public void testGetClientWithChangeInSetting() { @Test public void testGetClientWithException() { - when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()).thenReturn(null); + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())).thenReturn(null); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); IllegalArgumentException illegalArgumentException = @@ -69,8 +69,9 @@ public void testGetClientWithException() { @Test public void testGetClientWithExceptionWithNullRegion() { - SparkExecutionEngineConfig sparkExecutionEngineConfig = new SparkExecutionEngineConfig(); - when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + SparkExecutionEngineConfig sparkExecutionEngineConfig = + SparkExecutionEngineConfig.builder().build(); + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn(sparkExecutionEngineConfig); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); @@ -84,13 +85,12 @@ public void testGetClientWithExceptionWithNullRegion() { } private SparkExecutionEngineConfig createSparkExecutionEngineConfig() { - SparkExecutionEngineConfig sparkExecutionEngineConfig = new SparkExecutionEngineConfig(); - sparkExecutionEngineConfig.setRegion(TestConstants.US_EAST_REGION); - sparkExecutionEngineConfig.setExecutionRoleARN(TestConstants.EMRS_EXECUTION_ROLE); - sparkExecutionEngineConfig.setSparkSubmitParameters( - SparkSubmitParameters.Builder.builder().build().toString()); - sparkExecutionEngineConfig.setClusterName(TestConstants.TEST_CLUSTER_NAME); - sparkExecutionEngineConfig.setApplicationId(TestConstants.EMRS_APPLICATION_ID); - return sparkExecutionEngineConfig; + return SparkExecutionEngineConfig.builder() + .region(TestConstants.US_EAST_REGION) + .executionRoleARN(TestConstants.EMRS_EXECUTION_ROLE) + .sparkSubmitParameterModifier((sparkSubmitParameters) -> {}) + .clusterName(TestConstants.TEST_CLUSTER_NAME) + .applicationId(TestConstants.EMRS_APPLICATION_ID) + .build(); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 225a43a526..16c37ad299 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -68,7 +68,7 @@ void testStartJobRun() { when(emrServerless.startJobRun(any())).thenReturn(response); EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); - String parameters = SparkSubmitParameters.Builder.builder().query(QUERY).build().toString(); + String parameters = SparkSubmitParameters.builder().query(QUERY).build().toString(); emrServerlessClient.startJobRun( new StartJobRequest( diff --git a/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java index 298a56b17a..03bfde88a3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java @@ -10,11 +10,14 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; @ExtendWith(MockitoExtension.class) public class SparkExecutionEngineConfigSupplierImplTest { @Mock private Settings settings; + @Mock private RequestContext requestContext; @Test void testGetSparkExecutionEngineConfig() { @@ -30,17 +33,20 @@ void testGetSparkExecutionEngineConfig() { + "}"); when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext); + Assertions.assertEquals("00fd775baqpu4g0p", sparkExecutionEngineConfig.getApplicationId()); Assertions.assertEquals( "arn:aws:iam::270824043731:role/emr-job-execution-role", sparkExecutionEngineConfig.getExecutionRoleARN()); Assertions.assertEquals("eu-west-1", sparkExecutionEngineConfig.getRegion()); - Assertions.assertEquals( - "--conf spark.dynamicAllocation.enabled=false", - sparkExecutionEngineConfig.getSparkSubmitParameters()); Assertions.assertEquals(TEST_CLUSTER_NAME, sparkExecutionEngineConfig.getClusterName()); + SparkSubmitParameters parameters = SparkSubmitParameters.builder().build(); + sparkExecutionEngineConfig.getSparkSubmitParameterModifier().modifyParameters(parameters); + Assertions.assertTrue( + parameters.toString().contains("--conf spark.dynamicAllocation.enabled=false")); } @Test @@ -50,12 +56,14 @@ void testGetSparkExecutionEngineConfigWithNullSetting() { when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)).thenReturn(null); when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext); + Assertions.assertNull(sparkExecutionEngineConfig.getApplicationId()); Assertions.assertNull(sparkExecutionEngineConfig.getExecutionRoleARN()); Assertions.assertNull(sparkExecutionEngineConfig.getRegion()); - Assertions.assertNull(sparkExecutionEngineConfig.getSparkSubmitParameters()); + Assertions.assertNull(sparkExecutionEngineConfig.getSparkSubmitParameterModifier()); Assertions.assertEquals(TEST_CLUSTER_NAME, sparkExecutionEngineConfig.getClusterName()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index aade6ff63b..7d43ccc7e3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -26,6 +26,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -46,6 +47,7 @@ class IndexDMLHandlerTest { @Mock private FlintIndexMetadataService flintIndexMetadataService; @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private FlintIndexOpFactory flintIndexOpFactory; + @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; @Test public void getResponseFromExecutor() { @@ -70,7 +72,8 @@ public void testWhenIndexDetailsAreNotFound() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier); DataSourceMetadata metadata = new DataSourceMetadata.Builder() .setName("mys3") @@ -113,7 +116,8 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier); DataSourceMetadata metadata = new DataSourceMetadata.Builder() .setName("mys3") diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 08aa0e4d0e..cfb340abc3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -62,6 +62,7 @@ import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.JobType; @@ -90,6 +91,7 @@ public class SparkQueryDispatcherTest { @Mock private LeaseManager leaseManager; @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private FlintIndexOpFactory flintIndexOpFactory; + @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; @Mock(answer = RETURNS_DEEP_STUBS) private Session session; @@ -158,7 +160,8 @@ void testDispatchSelectQuery() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -206,7 +209,8 @@ void testDispatchSelectQueryWithLakeFormation() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -253,7 +257,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -299,7 +304,8 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -414,7 +420,8 @@ void testDispatchIndexQuery() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -461,7 +468,8 @@ void testDispatchWithPPLQuery() { "my_glue", LangType.PPL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -508,7 +516,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -559,7 +568,8 @@ void testDispatchIndexQueryWithoutADatasourceName() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -610,7 +620,8 @@ void testDispatchMaterializedViewQuery() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -657,7 +668,8 @@ void testDispatchShowMVQuery() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -704,7 +716,8 @@ void testRefreshIndexQuery() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -751,7 +764,8 @@ void testDispatchDescribeIndexQuery() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME)); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -776,7 +790,8 @@ void testDispatchWithWrongURI() { "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME))); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier))); Assertions.assertEquals( "Bad URI in indexstore configuration of the : my_glue datasoure.", @@ -800,7 +815,8 @@ void testDispatchWithUnSupportedDataSourceType() { "my_prometheus", LangType.SQL, EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME))); + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier))); Assertions.assertEquals( "UnSupported datasource type for async queries:: PROMETHEUS", @@ -1182,7 +1198,7 @@ private DispatchQueryRequest constructDispatchQueryRequest( langType, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME, - extraParameters, + (parameters) -> parameters.setExtraParameters(extraParameters), null); } @@ -1194,7 +1210,7 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME, - null, + sparkSubmitParameterModifier, sessionId); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java index 54451effed..6c1514e6e4 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java @@ -18,7 +18,7 @@ public static CreateSessionRequest createSessionRequest() { TEST_CLUSTER_NAME, "appId", "arn", - SparkSubmitParameters.Builder.builder(), + SparkSubmitParameters.builder().build(), new HashMap<>(), "resultIndex", TEST_DATASOURCE_NAME); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 190f62135b..2a4d33726b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -7,6 +7,8 @@ package org.opensearch.sql.spark.transport; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -69,9 +71,11 @@ public void testDoExecute() { CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(pluginSettings.getSettingValue(Settings.Key.ASYNC_QUERY_ENABLED)).thenReturn(true); - when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) + when(jobExecutorService.createAsyncQuery(eq(createAsyncQueryRequest), any())) .thenReturn(new CreateAsyncQueryResponse("123", null)); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); CreateAsyncQueryActionResponse createAsyncQueryActionResponse = createJobActionResponseArgumentCaptor.getValue(); @@ -87,9 +91,11 @@ public void testDoExecuteWithSessionId() { CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(pluginSettings.getSettingValue(Settings.Key.ASYNC_QUERY_ENABLED)).thenReturn(true); - when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) + when(jobExecutorService.createAsyncQuery(eq(createAsyncQueryRequest), any())) .thenReturn(new CreateAsyncQueryResponse("123", MOCK_SESSION_ID)); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); CreateAsyncQueryActionResponse createAsyncQueryActionResponse = createJobActionResponseArgumentCaptor.getValue(); @@ -107,9 +113,11 @@ public void testDoExecuteWithException() { when(pluginSettings.getSettingValue(Settings.Key.ASYNC_QUERY_ENABLED)).thenReturn(true); doThrow(new RuntimeException("Error")) .when(jobExecutorService) - .createAsyncQuery(createAsyncQueryRequest); + .createAsyncQuery(eq(createAsyncQueryRequest), any()); + action.doExecute(task, request, actionListener); - verify(jobExecutorService, times(1)).createAsyncQuery(createAsyncQueryRequest); + + verify(jobExecutorService, times(1)).createAsyncQuery(eq(createAsyncQueryRequest), any()); Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); Exception exception = exceptionArgumentCaptor.getValue(); Assertions.assertTrue(exception instanceof RuntimeException); @@ -123,8 +131,10 @@ public void asyncQueryDisabled() { CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(pluginSettings.getSettingValue(Settings.Key.ASYNC_QUERY_ENABLED)).thenReturn(false); + action.doExecute(task, request, actionListener); - verify(jobExecutorService, never()).createAsyncQuery(createAsyncQueryRequest); + + verify(jobExecutorService, never()).createAsyncQuery(eq(createAsyncQueryRequest), any()); Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); Exception exception = exceptionArgumentCaptor.getValue(); Assertions.assertTrue(exception instanceof IllegalAccessException);