Skip to content

Commit

Permalink
Provide a way to modify spark parameters
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>
  • Loading branch information
ykmr1224 committed May 22, 2024
1 parent 3a28d2a commit 92197a8
Show file tree
Hide file tree
Showing 34 changed files with 423 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.spark.asyncquery;

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;

Expand All @@ -20,7 +21,8 @@ public interface AsyncQueryExecutorService {
* @param createAsyncQueryRequest createAsyncQueryRequest.
* @return {@link CreateAsyncQueryResponse}
*/
CreateAsyncQueryResponse createAsyncQuery(CreateAsyncQueryRequest createAsyncQueryRequest);
CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext);

/**
* Returns async query response for a given queryId.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
Expand All @@ -36,9 +37,9 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService

@Override
public CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest) {
CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext) {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext);
DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
new DispatchQueryRequest(
Expand All @@ -48,7 +49,7 @@ public CreateAsyncQueryResponse createAsyncQuery(
createAsyncQueryRequest.getLang(),
sparkExecutionEngineConfig.getExecutionRoleARN(),
sparkExecutionEngineConfig.getClusterName(),
sparkExecutionEngineConfig.getSparkSubmitParameters(),
sparkExecutionEngineConfig.getSparkSubmitParameterModifier(),
createAsyncQueryRequest.getSessionId()));
asyncQueryJobMetadataStorageService.storeJobMetadata(
AsyncQueryJobMetadata.builder()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.asyncquery.model;

/** An implementation of RequestContext for where context is not required */
public class NullRequestContext implements RequestContext {
@Override
public Object getAttribute(String name) {
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.asyncquery.model;

/** Context interface to provide additional request related information */
public interface RequestContext {
Object getAttribute(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import java.util.function.Supplier;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.datasources.auth.AuthenticationType;
import org.opensearch.sql.spark.config.SparkSubmitParameterModifier;
import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;

/** Define Spark Submit Parameters. */
Expand All @@ -40,7 +42,24 @@ public class SparkSubmitParameters {
private final Map<String, String> config;

/** Extra parameters to append finally */
private String extraParameters;
@Setter private String extraParameters;

public void setConfigItem(String key, String value) {
config.put(key, value);
}

public void deleteConfigItem(String key) {
config.remove(key);
}

public static Builder builder() {
return Builder.builder();
}

public SparkSubmitParameters acceptModifier(SparkSubmitParameterModifier modifier) {
modifier.modifyParameters(this);
return this;
}

public static class Builder {

Expand Down Expand Up @@ -180,17 +199,16 @@ public Builder extraParameters(String params) {
return this;
}

public Builder sessionExecution(String sessionId, String datasourceName) {
config.put(FLINT_JOB_REQUEST_INDEX, OpenSearchStateStoreUtil.getIndexName(datasourceName));
config.put(FLINT_JOB_SESSION_ID, sessionId);
return this;
}

public SparkSubmitParameters build() {
return new SparkSubmitParameters(className, config, extraParameters);
}
}

public void sessionExecution(String sessionId, String datasourceName) {
config.put(FLINT_JOB_REQUEST_INDEX, OpenSearchStateStoreUtil.getIndexName(datasourceName));
config.put(FLINT_JOB_SESSION_ID, sessionId);
}

@Override
public String toString() {
StringBuilder stringBuilder = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.asyncquery.model.NullRequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;

Expand All @@ -32,7 +33,8 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor
@Override
public EMRServerlessClient getClient() {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(
new NullRequestContext());
validateSparkExecutionEngineConfig(sparkExecutionEngineConfig);
if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) {
region = sparkExecutionEngineConfig.getRegion();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.opensearch.sql.spark.config;

import lombok.AllArgsConstructor;
import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;

@AllArgsConstructor
public class OpenSearchSparkSubmitParameterModifier implements SparkSubmitParameterModifier {

private String extraParameters;

@Override
public void modifyParameters(SparkSubmitParameters parameters) {
parameters.setExtraParameters(this.extraParameters);
}
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
package org.opensearch.sql.spark.config;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
* POJO for spark Execution Engine Config. Interface between {@link
* org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService} and {@link
* SparkExecutionEngineConfigSupplier}
*/
@Data
@NoArgsConstructor
@Builder
@AllArgsConstructor
public class SparkExecutionEngineConfig {
private String applicationId;
private String region;
private String executionRoleARN;
private String sparkSubmitParameters;
private SparkSubmitParameterModifier sparkSubmitParameterModifier;
private String clusterName;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.opensearch.sql.spark.config;

import org.opensearch.sql.spark.asyncquery.model.RequestContext;

/** Interface for extracting and providing SparkExecutionEngineConfig */
public interface SparkExecutionEngineConfigSupplier {

Expand All @@ -8,5 +10,5 @@ public interface SparkExecutionEngineConfigSupplier {
*
* @return {@link SparkExecutionEngineConfig}.
*/
SparkExecutionEngineConfig getSparkExecutionEngineConfig();
SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,34 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.cluster.ClusterName;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;

@AllArgsConstructor
public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier {

private Settings settings;

@Override
public SparkExecutionEngineConfig getSparkExecutionEngineConfig() {
public SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext) {
String sparkExecutionEngineConfigSettingString =
this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG);
SparkExecutionEngineConfig sparkExecutionEngineConfig = new SparkExecutionEngineConfig();
SparkExecutionEngineConfig.SparkExecutionEngineConfigBuilder builder =
SparkExecutionEngineConfig.builder();
if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) {
SparkExecutionEngineConfigClusterSetting sparkExecutionEngineConfigClusterSetting =
SparkExecutionEngineConfigClusterSetting setting =
AccessController.doPrivileged(
(PrivilegedAction<SparkExecutionEngineConfigClusterSetting>)
() ->
SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig(
sparkExecutionEngineConfigSettingString));
sparkExecutionEngineConfig.setApplicationId(
sparkExecutionEngineConfigClusterSetting.getApplicationId());
sparkExecutionEngineConfig.setExecutionRoleARN(
sparkExecutionEngineConfigClusterSetting.getExecutionRoleARN());
sparkExecutionEngineConfig.setSparkSubmitParameters(
sparkExecutionEngineConfigClusterSetting.getSparkSubmitParameters());
sparkExecutionEngineConfig.setRegion(sparkExecutionEngineConfigClusterSetting.getRegion());
builder.applicationId(setting.getApplicationId());
builder.executionRoleARN(setting.getExecutionRoleARN());
builder.sparkSubmitParameterModifier(
new OpenSearchSparkSubmitParameterModifier(setting.getSparkSubmitParameters()));
builder.region(setting.getRegion());
}
ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME);
sparkExecutionEngineConfig.setClusterName(clusterName.value());
return sparkExecutionEngineConfig;
builder.clusterName(clusterName.value());
return builder.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.opensearch.sql.spark.config;

import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;

public interface SparkSubmitParameterModifier {
void modifyParameters(SparkSubmitParameters parameters);
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ public class SparkConstants {
public static final String SPARK_DRIVER_ENV_JAVA_HOME_KEY =
"spark.emr-serverless.driverEnv.JAVA_HOME";
public static final String SPARK_EXECUTOR_ENV_JAVA_HOME_KEY = "spark.executorEnv.JAVA_HOME";
// Used for logging/metrics in Spark (driver)
public static final String SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY =
"spark.emr-serverless.driverEnv.FLINT_CLUSTER_NAME";
// Used for logging/metrics in Spark (executor)
public static final String SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY =
"spark.executorEnv.FLINT_CLUSTER_NAME";
public static final String FLINT_INDEX_STORE_HOST_KEY = "spark.datasource.flint.host";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ public DispatchQueryResponse submit(
clusterName + ":" + JobType.BATCH.getText(),
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
SparkSubmitParameters.builder()
.clusterName(clusterName)
.dataSource(context.getDataSourceMetadata())
.query(dispatchQueryRequest.getQuery())
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier())
.toString(),
tags,
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,12 @@ public DispatchQueryResponse submit(
clusterName,
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
SparkSubmitParameters.builder()
.className(FLINT_SESSION_CLASS_NAME)
.clusterName(clusterName)
.dataSource(dataSourceMetadata)
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()),
.build()
.acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()),
tags,
dataSourceMetadata.getResultIndex(),
dataSourceMetadata.getName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ public DispatchQueryResponse submit(
jobName,
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
SparkSubmitParameters.builder()
.clusterName(clusterName)
.dataSource(dataSourceMetadata)
.query(dispatchQueryRequest.getQuery())
.structuredStreaming(true)
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier())
.toString(),
tags,
indexQueryDetails.getFlintIndexOptions().autoRefresh(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.config.SparkSubmitParameterModifier;
import org.opensearch.sql.spark.rest.model.LangType;

@AllArgsConstructor
Expand All @@ -21,8 +22,8 @@ public class DispatchQueryRequest {
private final String executionRoleARN;
private final String clusterName;

/** Optional extra Spark submit parameters to include in final request */
private String extraSparkSubmitParams;
/* extension point to modify or add spark submit parameter */
private final SparkSubmitParameterModifier sparkSubmitParameterModifier;

/** Optional sessionId. */
private String sessionId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class CreateSessionRequest {
private final String clusterName;
private final String applicationId;
private final String executionRoleArn;
private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder;
private final SparkSubmitParameters sparkSubmitParameters;
private final Map<String, String> tags;
private final String resultIndex;
private final String datasourceName;
Expand All @@ -26,7 +26,7 @@ public StartJobRequest getStartJobRequest(String sessionId) {
clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId,
applicationId,
executionRoleArn,
sparkSubmitParametersBuilder.build().toString(),
sparkSubmitParameters.toString(),
tags,
resultIndex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ public void open(CreateSessionRequest createSessionRequest) {
try {
// append session id;
createSessionRequest
.getSparkSubmitParametersBuilder()
.sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
.getSparkSubmitParameters()
.acceptModifier(
(parameters) -> {
parameters.sessionExecution(
sessionId.getSessionId(), createSessionRequest.getDatasourceName());
});
createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId());
StartJobRequest startJobRequest =
createSessionRequest.getStartJobRequest(sessionId.getSessionId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.model.NullRequestContext;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest;
Expand Down Expand Up @@ -64,7 +65,8 @@ protected void doExecute(

CreateAsyncQueryRequest createAsyncQueryRequest = request.getCreateAsyncQueryRequest();
CreateAsyncQueryResponse createAsyncQueryResponse =
asyncQueryExecutorService.createAsyncQuery(createAsyncQueryRequest);
asyncQueryExecutorService.createAsyncQuery(
createAsyncQueryRequest, new NullRequestContext());
String responseContent =
new JsonResponseFormatter<CreateAsyncQueryResponse>(JsonResponseFormatter.Style.PRETTY) {
@Override
Expand Down
Loading

0 comments on commit 92197a8

Please sign in to comment.