From 94e1809bf0c533392d3ca70ad9e82873dce4184d Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Thu, 7 Mar 2024 12:04:59 -0800 Subject: [PATCH] Change emr job names based on the query type Signed-off-by: Vamsi Manohar --- spark/src/main/antlr/SqlBaseParser.g4 | 4 +- .../spark/client/EmrServerlessClientImpl.java | 5 +- .../spark/dispatcher/BatchQueryHandler.java | 3 +- .../sql/spark/dispatcher/IndexDMLHandler.java | 6 - .../dispatcher/InteractiveQueryHandler.java | 3 +- .../dispatcher/SparkQueryDispatcher.java | 2 - .../dispatcher/StreamingQueryHandler.java | 7 +- .../session/CreateSessionRequest.java | 7 +- .../execution/session/InteractiveSession.java | 7 +- .../spark/dispatcher/IndexDMLHandlerTest.java | 2 +- .../dispatcher/SparkQueryDispatcherTest.java | 352 ++++++------------ .../session/InteractiveSessionTest.java | 16 +- .../model/CreateAsyncQueryRequestTest.java | 9 +- 13 files changed, 163 insertions(+), 260 deletions(-) diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4 index 07fa56786b..801cc62491 100644 --- a/spark/src/main/antlr/SqlBaseParser.g4 +++ b/spark/src/main/antlr/SqlBaseParser.g4 @@ -989,7 +989,7 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast - | primaryExpression collateClause #collate + | primaryExpression collateClause #collate | primaryExpression DOUBLE_COLON dataType #castByColon | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first @@ -1096,7 +1096,7 @@ colPosition ; collateClause - : COLLATE collationName=stringLit + : COLLATE collationName=identifier ; type diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java index 913e1ac378..82644a2fb2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -19,6 +19,7 @@ import com.amazonaws.services.emrserverless.model.StartJobRunResult; import java.security.AccessController; import java.security.PrivilegedAction; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.sql.legacy.metrics.MetricName; @@ -29,6 +30,8 @@ public class EmrServerlessClientImpl implements EMRServerlessClient { private final AWSEMRServerless emrServerless; private static final Logger logger = LogManager.getLogger(EmrServerlessClientImpl.class); + private static final int MAX_JOB_NAME_LENGTH = 255; + private static final String GENERIC_INTERNAL_SERVER_ERROR_MESSAGE = "Internal Server Error."; public EmrServerlessClientImpl(AWSEMRServerless emrServerless) { @@ -43,7 +46,7 @@ public String startJobRun(StartJobRequest startJobRequest) { : startJobRequest.getResultIndex(); StartJobRunRequest request = new StartJobRunRequest() - .withName(startJobRequest.getJobName()) + .withName(StringUtils.truncate(startJobRequest.getJobName(), MAX_JOB_NAME_LENGTH)) .withApplicationId(startJobRequest.getApplicationId()) .withExecutionRoleArn(startJobRequest.getExecutionRoleArn()) .withTags(startJobRequest.getTags()) diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 46dec38038..ecab31ebc9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -68,7 +68,6 @@ public DispatchQueryResponse submit( leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource())); String clusterName = dispatchQueryRequest.getClusterName(); - String jobName = clusterName + ":" + "non-index-query"; Map tags = context.getTags(); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); @@ -76,7 +75,7 @@ public DispatchQueryResponse submit( StartJobRequest startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), - jobName, + clusterName + ":" + JobType.BATCH.getText(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), SparkSubmitParameters.Builder.builder() diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index a03cd64986..f153e94713 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -15,9 +15,7 @@ import org.apache.logging.log4j.Logger; import org.json.JSONObject; import org.opensearch.client.Client; -import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; @@ -44,10 +42,6 @@ public class IndexDMLHandler extends AsyncQueryHandler { private final EMRServerlessClient emrServerlessClient; - private final DataSourceService dataSourceService; - - private final DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; - private final JobExecutionResponseReader jobExecutionResponseReader; private final FlintIndexMetadataReader flintIndexMetadataReader; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 1afba22db7..7602988d26 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -71,7 +71,6 @@ public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { Session session = null; String clusterName = dispatchQueryRequest.getClusterName(); - String jobName = clusterName + ":" + "non-index-query"; Map tags = context.getTags(); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); @@ -94,7 +93,7 @@ public DispatchQueryResponse submit( session = sessionManager.createSession( new CreateSessionRequest( - jobName, + clusterName, dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), SparkSubmitParameters.Builder.builder() diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 498a3b9af5..5b5745d438 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -127,8 +127,6 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) { return new IndexDMLHandler( emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, client, diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 75337a3dad..b64c4ffc8d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -44,12 +44,17 @@ public DispatchQueryResponse submit( leaseManager.borrow(new LeaseRequest(JobType.STREAMING, dispatchQueryRequest.getDatasource())); String clusterName = dispatchQueryRequest.getClusterName(); - String jobName = clusterName + ":" + "index-query"; IndexQueryDetails indexQueryDetails = context.getIndexQueryDetails(); Map tags = context.getTags(); tags.put(INDEX_TAG_KEY, indexQueryDetails.openSearchIndexName()); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); + String jobName = + clusterName + + ":" + + JobType.STREAMING.getText() + + ":" + + indexQueryDetails.openSearchIndexName(); StartJobRequest startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index b2201fbd01..855e1ce5b2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -9,10 +9,11 @@ import lombok.Data; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.dispatcher.model.JobType; @Data public class CreateSessionRequest { - private final String jobName; + private final String clusterName; private final String applicationId; private final String executionRoleArn; private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder; @@ -20,10 +21,10 @@ public class CreateSessionRequest { private final String resultIndex; private final String datasourceName; - public StartJobRequest getStartJobRequest() { + public StartJobRequest getStartJobRequest(String sessionId) { return new InteractiveSessionStartJobRequest( "select 1", - jobName, + clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId, applicationId, executionRoleArn, sparkSubmitParametersBuilder.build().toString(), diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index dd413674a1..254c5a34b4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -20,6 +20,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.execution.statement.QueryRequest; import org.opensearch.sql.spark.execution.statement.Statement; import org.opensearch.sql.spark.execution.statement.StatementId; @@ -55,8 +56,10 @@ public void open(CreateSessionRequest createSessionRequest) { .getSparkSubmitParametersBuilder() .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName()); createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId()); - String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); - String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); + StartJobRequest startJobRequest = + createSessionRequest.getStartJobRequest(sessionId.getSessionId()); + String jobID = serverlessClient.startJobRun(startJobRequest); + String applicationId = startJobRequest.getApplicationId(); sessionModel = initInteractiveSession( diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index 01c46c3c0b..ec82488749 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -16,7 +16,7 @@ class IndexDMLHandlerTest { @Test public void getResponseFromExecutor() { JSONObject result = - new IndexDMLHandler(null, null, null, null, null, null, null).getResponseFromExecutor(null); + new IndexDMLHandler(null, null, null, null, null).getResponseFromExecutor(null); assertEquals("running", result.getString(STATUS_FIELD)); assertEquals("", result.getString(ERROR_FIELD)); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 2a499e7d30..867e1c94c4 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -141,17 +141,17 @@ void testDispatchSelectQuery() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } }); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:batch", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -165,16 +165,6 @@ void testDispatchSelectQuery() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -196,17 +186,17 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); } }); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:batch", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -220,16 +210,6 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -249,17 +229,17 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { { } }); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:batch", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithNoAuth(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -273,16 +253,6 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -368,17 +338,17 @@ void testDispatchIndexQuery() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } })); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - true, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -392,16 +362,6 @@ void testDispatchIndexQuery() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - true, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -422,17 +382,17 @@ void testDispatchWithPPLQuery() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } }); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:batch", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -446,16 +406,6 @@ void testDispatchWithPPLQuery() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -476,17 +426,17 @@ void testDispatchQueryWithoutATableAndDataSourceName() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } }); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:batch", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -500,16 +450,6 @@ void testDispatchQueryWithoutATableAndDataSourceName() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -534,17 +474,17 @@ void testDispatchIndexQueryWithoutADatasourceName() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } })); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - true, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -558,16 +498,6 @@ void testDispatchIndexQueryWithoutADatasourceName() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - true, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -592,17 +522,17 @@ void testDispatchMaterializedViewQuery() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } })); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - true, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:streaming:flint_mv_1", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -616,16 +546,6 @@ void testDispatchMaterializedViewQuery() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - true, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -646,17 +566,17 @@ void testDispatchShowMVQuery() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } }); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:batch", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -670,16 +590,6 @@ void testDispatchShowMVQuery() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -700,17 +610,17 @@ void testRefreshIndexQuery() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } }); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:batch", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -724,16 +634,6 @@ void testRefreshIndexQuery() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); @@ -754,17 +654,17 @@ void testDispatchDescribeIndexQuery() { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } }); - when(emrServerlessClient.startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - any()))) - .thenReturn(EMR_JOB_ID); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:batch", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -778,16 +678,6 @@ void testDispatchDescribeIndexQuery() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - StartJobRequest expected = - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - null); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); verifyNoInteractions(flintIndexMetadataReader); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 338da431fb..5669716684 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -25,6 +25,7 @@ import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchIntegTestCase; @@ -33,6 +34,7 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase { private static final String DS_NAME = "mys3"; private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); + public static final String TEST_CLUSTER_NAME = "TEST_CLUSTER"; private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; @@ -54,9 +56,10 @@ public void clean() { @Test public void openCloseSession() { + SessionId sessionId = SessionId.newSessionId(DS_NAME); InteractiveSession session = InteractiveSession.builder() - .sessionId(SessionId.newSessionId(DS_NAME)) + .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) .build(); @@ -69,6 +72,8 @@ public void openCloseSession() { .assertAppId("appId") .assertJobId("jobId"); emrsClient.startJobRunCalled(1); + emrsClient.assertJobNameOfLastRequest( + TEST_CLUSTER_NAME + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId.getSessionId()); // close session testSession.close(); @@ -193,7 +198,7 @@ public TestSession close() { public static CreateSessionRequest createSessionRequest() { return new CreateSessionRequest( - "jobName", + TEST_CLUSTER_NAME, "appId", "arn", SparkSubmitParameters.Builder.builder(), @@ -207,8 +212,11 @@ public static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; private int cancelJobRunCalled = 0; + private StartJobRequest startJobRequest; + @Override public String startJobRun(StartJobRequest startJobRequest) { + this.startJobRequest = startJobRequest; startJobRunCalled++; return "jobId"; } @@ -231,5 +239,9 @@ public void startJobRunCalled(int expectedTimes) { public void cancelJobRunCalled(int expectedTimes) { assertEquals(expectedTimes, cancelJobRunCalled); } + + public void assertJobNameOfLastRequest(String expectedJobName) { + assertEquals(expectedJobName, startJobRequest.getJobName()); + } } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java index 24f5a9d6fe..de38ca0e3c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java @@ -49,11 +49,10 @@ public void fromXContentWithDuplicateFields() throws IOException { Assertions.assertThrows( IllegalArgumentException.class, () -> CreateAsyncQueryRequest.fromXContentParser(xContentParser(request))); - Assertions.assertEquals( - "Error while parsing the request body: Duplicate field 'datasource'\n" - + " at [Source: REDACTED (`StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION` disabled);" - + " line: 3, column: 15]", - illegalArgumentException.getMessage()); + Assertions.assertTrue( + illegalArgumentException + .getMessage() + .contains("Error while parsing the request body: Duplicate field 'datasource'")); } @Test