Refactor query param (#2519)
* Refactor query param

Signed-off-by: Louis Chu <[email protected]>

* Reduce scope of changes

Signed-off-by: Louis Chu <[email protected]>

---------

Signed-off-by: Louis Chu <[email protected]>
noCharger authored Mar 13, 2024
1 parent 353b0d7 commit ee2dbd5
Showing 12 changed files with 59 additions and 49 deletions.

@@ -85,6 +85,11 @@ public Builder clusterName(String clusterName) {
return this;
}

public Builder query(String query) {
config.put(FLINT_JOB_QUERY, query);
return this;
}

public Builder dataSource(DataSourceMetadata metadata) {
if (DataSourceType.S3GLUE.equals(metadata.getConnector())) {
String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN);
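
The new query(String) method above folds the query into the builder's configuration map under FLINT_JOB_QUERY instead of leaving it for the entry-point arguments. A minimal usage sketch with placeholder values (the output shape follows the tests further down):

// Sketch, not part of the diff: placeholder query, no data source configured.
String params =
    SparkSubmitParameters.Builder.builder()
        .clusterName("TEST_CLUSTER")
        .query("SELECT 1")
        .build()
        .toString();
// Per the tests below, params now contains:
//   --conf spark.flint.job.query=SELECT 1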

@@ -56,7 +56,7 @@ public String startJobRun(StartJobRequest startJobRequest) {
.withSparkSubmit(
new SparkSubmit()
.withEntryPoint(SPARK_SQL_APPLICATION_JAR)
.withEntryPointArguments(startJobRequest.getQuery(), resultIndex)
.withEntryPointArguments(resultIndex)
.withSparkSubmitParameters(startJobRequest.getSparkSubmitParams())));

StartJobRunResult startJobRunResult =

@@ -19,7 +19,6 @@ public class StartJobRequest {

public static final Long DEFAULT_JOB_TIMEOUT = 120L;

private final String query;
private final String jobName;
private final String applicationId;
private final String executionRoleArn;
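
With the query field removed here, StartJobRequest shrinks to seven constructor arguments; every call site in this diff drops the leading query string and carries it inside the spark-submit parameters instead. A call-site sketch with placeholder values (argument roles inferred from the call sites and tests below):

// Sketch, not part of the diff: argument order per the updated call sites.
String sparkSubmitParams =
    SparkSubmitParameters.Builder.builder().query("SELECT 1").build().toString();
StartJobRequest request =
    new StartJobRequest(
        "TEST_CLUSTER:batch",   // jobName
        "application-id",       // applicationId
        "execution-role-arn",   // executionRoleArn
        sparkSubmitParams,      // spark-submit string, now carrying the query
        new HashMap<>(),        // tags
        false,                  // false for batch, true for streaming
        null);                  // resultIndex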

@@ -89,8 +89,9 @@ public class SparkConstants {
public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER =
"com.amazonaws.emr.AssumeRoleAWSCredentialsProvider";
public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/";

public static final String FLINT_JOB_QUERY = "spark.flint.job.query";
public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex";
public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId";

public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL";
}
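
The Spark application that consumes these constants is not part of this diff, so the following is only a hypothetical sketch of the consumer side: after this change the job reads the query back from its Spark configuration, while the result index stays the sole entry-point argument.

import org.apache.spark.sql.SparkSession;

// Hypothetical consumer sketch; the real Flint job lives outside this repository.
public class FlintJobQuerySketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();
    String query = spark.conf().get("spark.flint.job.query"); // FLINT_JOB_QUERY
    String resultIndex = args[0]; // the only remaining entry-point argument
    spark.sql(query).show();     // writing results to resultIndex omitted
  }
}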

@@ -74,13 +74,13 @@ public DispatchQueryResponse submit(
tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
StartJobRequest startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
clusterName + ":" + JobType.BATCH.getText(),
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
.clusterName(clusterName)
.dataSource(context.getDataSourceMetadata())
.query(dispatchQueryRequest.getQuery())
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.toString(),

@@ -57,13 +57,13 @@ public DispatchQueryResponse submit(
+ indexQueryDetails.openSearchIndexName();
StartJobRequest startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
jobName,
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.Builder.builder()
.clusterName(clusterName)
.dataSource(dataSourceMetadata)
.query(dispatchQueryRequest.getQuery())
.structuredStreaming(true)
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()

@@ -23,7 +23,6 @@ public class CreateSessionRequest {

public StartJobRequest getStartJobRequest(String sessionId) {
return new InteractiveSessionStartJobRequest(
"select 1",
clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId,
applicationId,
executionRoleArn,

@@ -34,22 +33,13 @@ public StartJobRequest getStartJobRequest(String sessionId) {

static class InteractiveSessionStartJobRequest extends StartJobRequest {
public InteractiveSessionStartJobRequest(
String query,
String jobName,
String applicationId,
String executionRoleArn,
String sparkSubmitParams,
Map<String, String> tags,
String resultIndex) {
super(
query,
jobName,
applicationId,
executionRoleArn,
sparkSubmitParams,
tags,
false,
resultIndex);
super(jobName, applicationId, executionRoleArn, sparkSubmitParams, tags, false, resultIndex);
}

/** Interactive queries keep running. */

@@ -27,4 +27,11 @@ public void testBuildWithExtraParameters() {
// Assert the conf is included with a space
assertTrue(params.endsWith(" --conf A=1"));
}

@Test
public void testBuildQueryString() {
String query = "SHOW tables LIKE \"%\";";
String params = SparkSubmitParameters.Builder.builder().query(query).build().toString();
assertTrue(params.contains(query));
}
}

@@ -42,6 +42,7 @@
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;

@ExtendWith(MockitoExtension.class)
public class EmrServerlessClientImplTest {

@@ -66,13 +67,14 @@ void testStartJobRun() {
when(emrServerless.startJobRun(any())).thenReturn(response);

EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless);
String parameters = SparkSubmitParameters.Builder.builder().query(QUERY).build().toString();

emrServerlessClient.startJobRun(
new StartJobRequest(
QUERY,
EMRS_JOB_NAME,
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,
SPARK_SUBMIT_PARAMETERS,
parameters,
new HashMap<>(),
false,
DEFAULT_RESULT_INDEX));

@@ -83,8 +85,14 @@ void testStartJobRun() {
Assertions.assertEquals(
ENTRY_POINT_START_JAR, startJobRunRequest.getJobDriver().getSparkSubmit().getEntryPoint());
Assertions.assertEquals(
List.of(QUERY, DEFAULT_RESULT_INDEX),
List.of(DEFAULT_RESULT_INDEX),
startJobRunRequest.getJobDriver().getSparkSubmit().getEntryPointArguments());
Assertions.assertTrue(
startJobRunRequest
.getJobDriver()
.getSparkSubmit()
.getSparkSubmitParameters()
.contains(QUERY));
}

@Test

@@ -97,7 +105,6 @@ void testStartJobRunWithErrorMetric() {
() ->
emrServerlessClient.startJobRun(
new StartJobRequest(
QUERY,
EMRS_JOB_NAME,
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -116,7 +123,6 @@ void testStartJobRunResultIndex() {
EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless);
emrServerlessClient.startJobRun(
new StartJobRequest(
QUERY,
EMRS_JOB_NAME,
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -185,7 +191,6 @@ void testStartJobRunWithLongJobName() {
EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless);
emrServerlessClient.startJobRun(
new StartJobRequest(
QUERY,
RandomStringUtils.random(300),
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -20,10 +20,10 @@ void executionTimeout() {
}

private StartJobRequest onDemandJob() {
return new StartJobRequest("", "", "", "", "", Map.of(), false, null);
return new StartJobRequest("", "", "", "", Map.of(), false, null);
}

private StartJobRequest streamingJob() {
return new StartJobRequest("", "", "", "", "", Map.of(), true, null);
return new StartJobRequest("", "", "", "", Map.of(), true, null);
}
}

@@ -140,10 +140,10 @@ void testDispatchSelectQuery() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
});
},
query);
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:batch",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -186,10 +186,10 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
put(FLINT_INDEX_STORE_AUTH_USERNAME, "username");
put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password");
}
});
},
query);
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:batch",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -229,10 +229,10 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() {
new HashMap<>() {
{
}
});
},
query);
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:batch",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -342,10 +342,10 @@ void testDispatchIndexQuery() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
}));
},
query));
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -388,10 +388,10 @@ void testDispatchWithPPLQuery() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
});
},
query);
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:batch",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -432,10 +432,10 @@ void testDispatchQueryWithoutATableAndDataSourceName() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
});
},
query);
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:batch",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -481,10 +481,10 @@ void testDispatchIndexQueryWithoutADatasourceName() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
}));
},
query));
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -530,10 +530,10 @@ void testDispatchMaterializedViewQuery() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
}));
},
query));
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:streaming:flint_mv_1",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -575,10 +575,10 @@ void testDispatchShowMVQuery() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
});
},
query);
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:batch",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -620,10 +620,10 @@ void testRefreshIndexQuery() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
});
},
query);
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:batch",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -665,10 +665,10 @@ void testDispatchDescribeIndexQuery() {
{
put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
}
});
},
query);
StartJobRequest expected =
new StartJobRequest(
query,
"TEST_CLUSTER:batch",
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,

@@ -938,7 +938,7 @@ void testDispatchQueryWithExtraSparkSubmitParameters() {
}

private String constructExpectedSparkSubmitParameterString(
String auth, Map<String, String> authParams) {
String auth, Map<String, String> authParams, String query) {
StringBuilder authParamConfigBuilder = new StringBuilder();
for (String key : authParams.keySet()) {
authParamConfigBuilder.append(" --conf ");

@@ -978,7 +978,10 @@ private String constructExpectedSparkSubmitParameterString(
+ " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"
+ " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog "
+ " --conf spark.flint.datasource.name=my_glue "
+ authParamConfigBuilder;
+ authParamConfigBuilder
+ " --conf spark.flint.job.query="
+ query
+ " ";
}

private String withStructuredStreaming(String parameters) {

@@ -43,7 +43,7 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase {
@Before
public void setup() {
emrsClient = new TestEMRServerlessClient();
startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, "");
startJobRequest = new StartJobRequest("", "appId", "", "", new HashMap<>(), false, "");
stateStore = new StateStore(client(), clusterService());
}
