Skip to content

Commit

Permalink
Fix unit test failures in SparkQueryDispatcherTest
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>
  • Loading branch information
ykmr1224 committed Apr 29, 2024
1 parent bbed24a commit 02137dc
Showing 1 changed file with 33 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ void setUp() {
emrServerlessClientFactory);
sparkQueryDispatcher =
new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory);
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
}

@Test
void testDispatchSelectQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -163,6 +163,7 @@ void testDispatchSelectQuery() {
LangType.SQL,
EMRS_EXECUTION_ROLE,
TEST_CLUSTER_NAME));

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Expand All @@ -171,6 +172,7 @@ void testDispatchSelectQuery() {

@Test
void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -199,6 +201,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
new DispatchQueryRequest(
Expand All @@ -208,6 +211,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
LangType.SQL,
EMRS_EXECUTION_ROLE,
TEST_CLUSTER_NAME));

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Expand All @@ -216,6 +220,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {

@Test
void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -324,6 +329,7 @@ void testDispatchSelectQueryFailedCreateSession() {

@Test
void testDispatchIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
Expand Down Expand Up @@ -374,6 +380,7 @@ void testDispatchIndexQuery() {

@Test
void testDispatchWithPPLQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -401,6 +408,7 @@ void testDispatchWithPPLQuery() {
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
new DispatchQueryRequest(
Expand All @@ -410,6 +418,7 @@ void testDispatchWithPPLQuery() {
LangType.PPL,
EMRS_EXECUTION_ROLE,
TEST_CLUSTER_NAME));

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Expand All @@ -418,6 +427,7 @@ void testDispatchWithPPLQuery() {

@Test
void testDispatchQueryWithoutATableAndDataSourceName() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -463,6 +473,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() {

@Test
void testDispatchIndexQueryWithoutADatasourceName() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
Expand Down Expand Up @@ -512,6 +523,7 @@ void testDispatchIndexQueryWithoutADatasourceName() {

@Test
void testDispatchMaterializedViewQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(INDEX_TAG_KEY, "flint_mv_1");
Expand Down Expand Up @@ -561,6 +573,7 @@ void testDispatchMaterializedViewQuery() {

@Test
void testDispatchShowMVQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -606,6 +619,7 @@ void testDispatchShowMVQuery() {

@Test
void testRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -651,6 +665,7 @@ void testRefreshIndexQuery() {

@Test
void testDispatchDescribeIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -699,6 +714,7 @@ void testDispatchWithWrongURI() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
.thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax());
String query = "select * from my_glue.default.http_logs";

IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
IllegalArgumentException.class,
Expand All @@ -711,6 +727,7 @@ void testDispatchWithWrongURI() {
LangType.SQL,
EMRS_EXECUTION_ROLE,
TEST_CLUSTER_NAME)));

Assertions.assertEquals(
"Bad URI in indexstore configuration of the : my_glue datasoure.",
illegalArgumentException.getMessage());
Expand All @@ -721,6 +738,7 @@ void testDispatchWithUnSupportedDataSourceType() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus"))
.thenReturn(constructPrometheusDataSourceType());
String query = "select * from my_prometheus.default.http_logs";

UnsupportedOperationException unsupportedOperationException =
Assertions.assertThrows(
UnsupportedOperationException.class,
Expand All @@ -733,19 +751,23 @@ void testDispatchWithUnSupportedDataSourceType() {
LangType.SQL,
EMRS_EXECUTION_ROLE,
TEST_CLUSTER_NAME)));

Assertions.assertEquals(
"UnSupported datasource type for async queries:: PROMETHEUS",
unsupportedOperationException.getMessage());
}

@Test
void testCancelJob() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
.withJobRunId(EMR_JOB_ID)
.withApplicationId(EMRS_APPLICATION_ID));

String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata());

Assertions.assertEquals(QUERY_ID.getId(), queryId);
}

Expand Down Expand Up @@ -800,24 +822,29 @@ void testCancelQueryWithInvalidStatementId() {

@Test
void testCancelQueryWithNoSessionId() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
.withJobRunId(EMR_JOB_ID)
.withApplicationId(EMRS_APPLICATION_ID));

String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata());

Assertions.assertEquals(QUERY_ID.getId(), queryId);
}

@Test
void testGetQueryResponse() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID))
.thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING)));

// simulate result index is not created yet
when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null))
.thenReturn(new JSONObject());

JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata());

Assertions.assertEquals("PENDING", result.get("status"));
}

Expand All @@ -827,10 +854,10 @@ void testGetQueryResponseWithSession() {
doReturn(Optional.of(statement)).when(session).get(any());
when(statement.getStatementModel().getError()).thenReturn("mock error");
doReturn(StatementState.WAITING).when(statement).getStatementState();

doReturn(new JSONObject())
.when(jobExecutionResponseReader)
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());

JSONObject result =
sparkQueryDispatcher.getQueryResponse(
asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID));
Expand All @@ -845,6 +872,7 @@ void testGetQueryResponseWithInvalidSession() {
doReturn(new JSONObject())
.when(jobExecutionResponseReader)
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());

IllegalArgumentException exception =
Assertions.assertThrows(
IllegalArgumentException.class,
Expand All @@ -871,6 +899,7 @@ void testGetQueryResponseWithStatementNotExist() {
() ->
sparkQueryDispatcher.getQueryResponse(
asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)));

verifyNoInteractions(emrServerlessClient);
Assertions.assertEquals(
"no statement found. " + new StatementId(MOCK_STATEMENT_ID), exception.getMessage());
Expand Down Expand Up @@ -904,6 +933,7 @@ void testGetQueryResponseWithSuccess() {

@Test
void testDispatchQueryWithExtraSparkSubmitParameters() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
.thenReturn(dataSourceMetadata);
Expand Down

0 comments on commit 02137dc

Please sign in to comment.