diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
index 14107712f1..ea3f9a1eea 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
@@ -42,18 +42,22 @@ public CreateAsyncQueryResponse createAsyncQuery(
         sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext);
     DispatchQueryResponse dispatchQueryResponse =
         sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                sparkExecutionEngineConfig.getApplicationId(),
-                createAsyncQueryRequest.getQuery(),
-                createAsyncQueryRequest.getDatasource(),
-                createAsyncQueryRequest.getLang(),
-                sparkExecutionEngineConfig.getExecutionRoleARN(),
-                sparkExecutionEngineConfig.getClusterName(),
-                sparkExecutionEngineConfig.getSparkSubmitParameterModifier(),
-                createAsyncQueryRequest.getSessionId()));
+            DispatchQueryRequest.builder()
+                .accountId(sparkExecutionEngineConfig.getAccountId())
+                .applicationId(sparkExecutionEngineConfig.getApplicationId())
+                .query(createAsyncQueryRequest.getQuery())
+                .datasource(createAsyncQueryRequest.getDatasource())
+                .langType(createAsyncQueryRequest.getLang())
+                .executionRoleARN(sparkExecutionEngineConfig.getExecutionRoleARN())
+                .clusterName(sparkExecutionEngineConfig.getClusterName())
+                .sparkSubmitParameterModifier(
+                    sparkExecutionEngineConfig.getSparkSubmitParameterModifier())
+                .sessionId(createAsyncQueryRequest.getSessionId())
+                .build());
     asyncQueryJobMetadataStorageService.storeJobMetadata(
         AsyncQueryJobMetadata.builder()
             .queryId(dispatchQueryResponse.getQueryId())
+            .accountId(sparkExecutionEngineConfig.getAccountId())
             .applicationId(sparkExecutionEngineConfig.getApplicationId())
             .jobId(dispatchQueryResponse.getJobId())
             .resultIndex(dispatchQueryResponse.getResultIndex())
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
index e1f30edc10..1ffb780ef1 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
@@ -20,6 +20,8 @@
 @EqualsAndHashCode(callSuper = false)
 public class AsyncQueryJobMetadata extends StateModel {
   private final String queryId;
+  // optional: accountId for EMRS cluster
+  private final String accountId;
   private final String applicationId;
   private final String jobId;
   private final String resultIndex;
@@ -44,6 +46,7 @@ public static AsyncQueryJobMetadata copy(
       AsyncQueryJobMetadata copy, ImmutableMap<String, Object> metadata) {
     return builder()
         .queryId(copy.queryId)
+        .accountId(copy.accountId)
         .applicationId(copy.getApplicationId())
         .jobId(copy.getJobId())
         .resultIndex(copy.getResultIndex())
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java
index 4250d32b0e..2bbbd1f968 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java
@@ -59,6 +59,7 @@ private void validateSparkExecutionEngineConfig(
   }
   private EMRServerlessClient createEMRServerlessClient(String awsRegion) {
+    // TODO: accountId is not handled yet (the client is always created for the same AWS account)
     return AccessController.doPrivileged(
         (PrivilegedAction<EMRServerlessClient>)
             () -> {
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java
index b532c439c0..173b40d453 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java
@@ -20,6 +20,8 @@ public class StartJobRequest {
   public static final Long DEFAULT_JOB_TIMEOUT = 120L;
 
   private final String jobName;
+  // optional
+  private final String accountId;
   private final String applicationId;
   private final String executionRoleArn;
   private final String sparkSubmitParams;
diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java
index 92636c3cfb..51407111b6 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java
@@ -1,6 +1,5 @@
 package org.opensearch.sql.spark.config;
 
-import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.Data;
 
@@ -11,8 +10,8 @@
  */
 @Data
 @Builder
-@AllArgsConstructor
 public class SparkExecutionEngineConfig {
+  private String accountId;
   private String applicationId;
   private String region;
   private String executionRoleARN;
diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java
index b3f1295faa..338107f8a3 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java
@@ -16,6 +16,8 @@
 @Data
 @JsonIgnoreProperties(ignoreUnknown = true)
 public class SparkExecutionEngineConfigClusterSetting {
+  // optional
+  private String accountId;
   private String applicationId;
   private String region;
   private String executionRoleARN;
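// NOTE (illustrative sketch, not part of the diff): with the new optional field, the cluster-level
// Spark execution engine config JSON deserialized into SparkExecutionEngineConfigClusterSetting
// could look like the example below. All values are placeholders; accountId can simply be omitted
// when the EMR Serverless application lives in the same AWS account as the cluster.
//
//   {
//     "accountId": "123456789012",
//     "applicationId": "00fd775baqpu4g0p",
//     "region": "eu-west-1",
//     "executionRoleARN": "arn:aws:iam::123456789012:role/emr-job-execution-role"
//   }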
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
index 3bdbd8ca74..a88fe485fe 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
@@ -79,6 +79,7 @@ public DispatchQueryResponse submit(
     StartJobRequest startJobRequest =
         new StartJobRequest(
             clusterName + ":" + JobType.BATCH.getText(),
+            dispatchQueryRequest.getAccountId(),
            dispatchQueryRequest.getApplicationId(),
             dispatchQueryRequest.getExecutionRoleARN(),
             SparkSubmitParameters.builder()
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java
index e41f4a49fd..bfab3a946b 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java
@@ -100,6 +100,7 @@ public DispatchQueryResponse submit(
         sessionManager.createSession(
             new CreateSessionRequest(
                 clusterName,
+                dispatchQueryRequest.getAccountId(),
                 dispatchQueryRequest.getApplicationId(),
                 dispatchQueryRequest.getExecutionRoleARN(),
                 SparkSubmitParameters.builder()
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java
index 0649e81418..7b317d2218 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java
@@ -66,6 +66,7 @@ public DispatchQueryResponse submit(
     StartJobRequest startJobRequest =
         new StartJobRequest(
             jobName,
+            dispatchQueryRequest.getAccountId(),
             dispatchQueryRequest.getApplicationId(),
             dispatchQueryRequest.getExecutionRoleARN(),
             SparkSubmitParameters.builder()
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java
index 601103254f..066349873a 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java
@@ -6,15 +6,16 @@
 package org.opensearch.sql.spark.dispatcher.model;
 
 import lombok.AllArgsConstructor;
+import lombok.Builder;
 import lombok.Data;
-import lombok.RequiredArgsConstructor;
 import org.opensearch.sql.spark.config.SparkSubmitParameterModifier;
 import org.opensearch.sql.spark.rest.model.LangType;
 
 @AllArgsConstructor
 @Data
-@RequiredArgsConstructor // required explicitly
+@Builder
 public class DispatchQueryRequest {
+  private final String accountId;
   private final String applicationId;
   private final String query;
   private final String datasource;
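// Illustrative usage (not part of the diff): with @Builder replacing the explicit constructors,
// call sites name each field and may leave the new optional accountId unset for same-account
// deployments. The values below are placeholders.
//
//   DispatchQueryRequest request =
//       DispatchQueryRequest.builder()
//           .applicationId("00fd775baqpu4g0p")
//           .query("select * from my_glue.default.http_logs")
//           .datasource("my_glue")
//           .langType(LangType.SQL)
//           .executionRoleARN("arn:aws:iam::123456789012:role/emr-job-execution-role")
//           .clusterName("TEST_CLUSTER")
//           .build();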
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java
index d138e5f05d..4170f0c2d6 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java
@@ -14,6 +14,7 @@
 @Data
 public class CreateSessionRequest {
   private final String clusterName;
+  private final String accountId;
   private final String applicationId;
   private final String executionRoleArn;
   private final SparkSubmitParameters sparkSubmitParameters;
@@ -24,6 +25,7 @@ public class CreateSessionRequest {
   public StartJobRequest getStartJobRequest(String sessionId) {
     return new InteractiveSessionStartJobRequest(
         clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId,
+        accountId,
         applicationId,
         executionRoleArn,
         sparkSubmitParameters.toString(),
@@ -34,12 +36,21 @@ public StartJobRequest getStartJobRequest(String sessionId) {
   static class InteractiveSessionStartJobRequest extends StartJobRequest {
     public InteractiveSessionStartJobRequest(
         String jobName,
+        String accountId,
         String applicationId,
         String executionRoleArn,
         String sparkSubmitParams,
         Map<String, String> tags,
         String resultIndex) {
-      super(jobName, applicationId, executionRoleArn, sparkSubmitParams, tags, false, resultIndex);
+      super(
+          jobName,
+          accountId,
+          applicationId,
+          executionRoleArn,
+          sparkSubmitParams,
+          tags,
+          false,
+          resultIndex);
     }
 
     /** Interactive query keep running. */
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
index 9920fb9aec..eaa69d9386 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
@@ -60,10 +60,11 @@ public void open(CreateSessionRequest createSessionRequest) {
           createSessionRequest.getStartJobRequest(sessionId.getSessionId());
       String jobID = serverlessClient.startJobRun(startJobRequest);
       String applicationId = startJobRequest.getApplicationId();
+      String accountId = createSessionRequest.getAccountId();
 
       sessionModel =
           initInteractiveSession(
-              applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
+              accountId, applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
       sessionStorageService.createSession(sessionModel);
     } catch (VersionConflictEngineException e) {
       String errorMsg = "session already exist. " + sessionId;
@@ -99,6 +100,7 @@ public StatementId submit(QueryRequest request) {
       Statement st =
           Statement.builder()
               .sessionId(sessionId)
+              .accountId(sessionModel.getAccountId())
               .applicationId(sessionModel.getApplicationId())
               .jobId(sessionModel.getJobId())
               .statementStorageService(statementStorageService)
@@ -130,6 +132,7 @@ public Optional<Statement> get(StatementId stID) {
             model ->
                 Statement.builder()
                     .sessionId(sessionId)
+                    .accountId(model.getAccountId())
                     .applicationId(model.getApplicationId())
                     .jobId(model.getJobId())
                     .statementId(model.getStatementId())
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
index b79bef7b27..07a011515d 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
@@ -24,6 +24,8 @@ public class SessionModel extends StateModel {
   private final SessionType sessionType;
   private final SessionId sessionId;
   private final SessionState sessionState;
+  // optional: accountId for EMRS cluster
+  private final String accountId;
   private final String applicationId;
   private final String jobId;
   private final String datasourceName;
@@ -37,6 +39,7 @@ public static SessionModel of(SessionModel copy, ImmutableMap<String, Object> metadata) {
         .sessionId(new SessionId(copy.sessionId.getSessionId()))
         .sessionState(copy.sessionState)
         .datasourceName(copy.datasourceName)
+        .accountId(copy.accountId)
         .applicationId(copy.getApplicationId())
         .jobId(copy.jobId)
         .error(UNKNOWN)
@@ -53,6 +56,7 @@ public static SessionModel copyWithState(
         .sessionId(new SessionId(copy.sessionId.getSessionId()))
         .sessionState(state)
         .datasourceName(copy.datasourceName)
+        .accountId(copy.getAccountId())
         .applicationId(copy.getApplicationId())
         .jobId(copy.jobId)
         .error(UNKNOWN)
@@ -62,13 +66,14 @@ public static SessionModel copyWithState(
   }
 
   public static SessionModel initInteractiveSession(
-      String applicationId, String jobId, SessionId sid, String datasourceName) {
+      String accountId, String applicationId, String jobId, SessionId sid, String datasourceName) {
     return builder()
         .version("1.0")
         .sessionType(INTERACTIVE)
         .sessionId(sid)
         .sessionState(NOT_STARTED)
         .datasourceName(datasourceName)
+        .accountId(accountId)
         .applicationId(applicationId)
         .jobId(jobId)
         .error(UNKNOWN)
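// Illustrative usage (not part of the diff): because accountId is optional on SessionModel,
// sessions persisted before this change deserialize with a null accountId, and the old
// same-account code path can pass null explicitly. Placeholder values below.
//
//   SessionModel model =
//       SessionModel.initInteractiveSession(
//           null, "00fd775baqpu4g0p", "job-1", new SessionId("session-1"), "my_glue");
//   assert model.getAccountId() == null;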
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
index b0205aec64..d87d9fa89f 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java
@@ -25,6 +25,8 @@ public class Statement {
   private static final Logger LOG = LogManager.getLogger();
 
   private final SessionId sessionId;
+  // optional
+  private final String accountId;
   private final String applicationId;
   private final String jobId;
   private final StatementId statementId;
@@ -42,6 +44,7 @@ public void open() {
       statementModel =
           submitStatement(
               sessionId,
+              accountId,
               applicationId,
               jobId,
               statementId,
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
index 86e8d6e156..451cd8cd15 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
@@ -24,6 +24,8 @@ public class StatementModel extends StateModel {
   private final StatementState statementState;
   private final StatementId statementId;
   private final SessionId sessionId;
+  // optional: accountId for EMRS cluster
+  private final String accountId;
   private final String applicationId;
   private final String jobId;
   private final LangType langType;
@@ -39,6 +41,7 @@ public static StatementModel copy(StatementModel copy, ImmutableMap<String, Object> metadata) {
     return builder()
         .indexState(copy.indexState)
+        .accountId(copy.accountId)
         .applicationId(copy.applicationId)
         .jobId(copy.jobId)
         .latestId(copy.latestId)
@@ -42,6 +44,7 @@ public static FlintIndexStateModel copyWithState(
       FlintIndexStateModel copy, FlintIndexState state, ImmutableMap<String, Object> metadata) {
     return builder()
         .indexState(state)
+        .accountId(copy.accountId)
         .applicationId(copy.applicationId)
         .jobId(copy.jobId)
         .latestId(copy.latestId)
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
index 96ed18e897..9f258fb2a1 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
@@ -72,21 +72,23 @@ void testCreateAsyncQuery() {
             "select * from my_glue.default.http_logs", "my_glue", LangType.SQL);
     when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any()))
         .thenReturn(
-            new SparkExecutionEngineConfig(
-                EMRS_APPLICATION_ID,
-                "eu-west-1",
-                EMRS_EXECUTION_ROLE,
-                sparkSubmitParameterModifier,
-                TEST_CLUSTER_NAME));
+            SparkExecutionEngineConfig.builder()
+                .applicationId(EMRS_APPLICATION_ID)
+                .region("eu-west-1")
+                .executionRoleARN(EMRS_EXECUTION_ROLE)
+                .sparkSubmitParameterModifier(sparkSubmitParameterModifier)
+                .clusterName(TEST_CLUSTER_NAME)
+                .build());
     DispatchQueryRequest expectedDispatchQueryRequest =
-        new DispatchQueryRequest(
-            EMRS_APPLICATION_ID,
-            "select * from my_glue.default.http_logs",
-            "my_glue",
-            LangType.SQL,
-            EMRS_EXECUTION_ROLE,
-            TEST_CLUSTER_NAME,
-            sparkSubmitParameterModifier);
+        DispatchQueryRequest.builder()
+            .applicationId(EMRS_APPLICATION_ID)
+            .query("select * from my_glue.default.http_logs")
+            .datasource("my_glue")
+            .langType(LangType.SQL)
+            .executionRoleARN(EMRS_EXECUTION_ROLE)
+            .clusterName(TEST_CLUSTER_NAME)
+            .sparkSubmitParameterModifier(sparkSubmitParameterModifier)
+            .build();
     when(sparkQueryDispatcher.dispatch(expectedDispatchQueryRequest))
         .thenReturn(
             DispatchQueryResponse.builder()
@@ -114,12 +116,14 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() {
         new OpenSearchSparkSubmitParameterModifier("--conf spark.dynamicAllocation.enabled=false");
     when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any()))
         .thenReturn(
-            new SparkExecutionEngineConfig(
-                EMRS_APPLICATION_ID,
-                "eu-west-1",
-                EMRS_EXECUTION_ROLE,
-                modifier,
-                TEST_CLUSTER_NAME));
+            SparkExecutionEngineConfig.builder()
+                .applicationId(EMRS_APPLICATION_ID)
+                .region("eu-west-1")
+                .executionRoleARN(EMRS_EXECUTION_ROLE)
+                .sparkSubmitParameterModifier(modifier)
+                .clusterName(TEST_CLUSTER_NAME)
+                .build());
     when(sparkQueryDispatcher.dispatch(any()))
         .thenReturn(
             DispatchQueryResponse.builder()
diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java
index 16c37ad299..9ea7e91c54 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java
@@ -73,6 +73,7 @@ void testStartJobRun() {
     emrServerlessClient.startJobRun(
         new StartJobRequest(
             EMRS_JOB_NAME,
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             parameters,
@@ -109,6 +110,7 @@ void testStartJobRunWithErrorMetric() {
     emrServerlessClient.startJobRun(
         new StartJobRequest(
             EMRS_JOB_NAME,
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             SPARK_SUBMIT_PARAMETERS,
@@ -127,6 +129,7 @@ void testStartJobRunResultIndex() {
     emrServerlessClient.startJobRun(
         new StartJobRequest(
             EMRS_JOB_NAME,
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             SPARK_SUBMIT_PARAMETERS,
@@ -217,6 +220,7 @@ void testStartJobRunWithLongJobName() {
     emrServerlessClient.startJobRun(
         new StartJobRequest(
             RandomStringUtils.random(300),
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             SPARK_SUBMIT_PARAMETERS,
@@ -240,6 +244,7 @@ void testStartJobRunThrowsValidationException() {
     emrServerlessClient.startJobRun(
         new StartJobRequest(
             EMRS_JOB_NAME,
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             SPARK_SUBMIT_PARAMETERS,
diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java
index 3671cfaa42..ac5b0dd750 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java
@@ -20,10 +20,10 @@ void executionTimeout() {
   }
 
   private StartJobRequest onDemandJob() {
-    return new StartJobRequest("", "", "", "", Map.of(), false, null);
+    return new StartJobRequest("", null, "", "", "", Map.of(), false, null);
   }
 
   private StartJobRequest streamingJob() {
-    return new StartJobRequest("", "", "", "", Map.of(), true, null);
+    return new StartJobRequest("", null, "", "", "", Map.of(), true, null);
   }
 }
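// Reviewer note (illustrative, not part of the diff): accountId becomes the second positional
// constructor argument of StartJobRequest, so call sites now read roughly as below; the trailing
// parameter names are assumptions inferred from the fields and call sites visible in this diff.
//
//   new StartJobRequest(
//       jobName, accountId, applicationId, executionRoleArn, sparkSubmitParams,
//       tags, /* isStructuredStreaming */ false, resultIndex);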
diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java
index 2e536ef6b3..877d6ec32b 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java
@@ -117,14 +117,15 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() {
   }
 
   private DispatchQueryRequest getDispatchQueryRequest(String query) {
-    return new DispatchQueryRequest(
-        EMRS_APPLICATION_ID,
-        query,
-        "my_glue",
-        LangType.SQL,
-        EMRS_EXECUTION_ROLE,
-        TEST_CLUSTER_NAME,
-        sparkSubmitParameterModifier);
+    return DispatchQueryRequest.builder()
+        .applicationId(EMRS_APPLICATION_ID)
+        .query(query)
+        .datasource("my_glue")
+        .langType(LangType.SQL)
+        .executionRoleARN(EMRS_EXECUTION_ROLE)
+        .clusterName(TEST_CLUSTER_NAME)
+        .sparkSubmitParameterModifier(sparkSubmitParameterModifier)
+        .build();
   }
 
   @Test
diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index 5d04c86cce..bd9a0f2507 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -140,6 +140,7 @@ void testDispatchSelectQuery() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -153,14 +154,15 @@
     DispatchQueryResponse dispatchQueryResponse =
         sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+            DispatchQueryRequest.builder()
+                .applicationId(EMRS_APPLICATION_ID)
+                .query(query)
+                .datasource("my_glue")
+                .langType(LangType.SQL)
+                .executionRoleARN(EMRS_EXECUTION_ROLE)
+                .clusterName(TEST_CLUSTER_NAME)
+                .sparkSubmitParameterModifier(sparkSubmitParameterModifier)
+                .build());
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -189,6 +191,7 @@ void testDispatchSelectQueryWithLakeFormation() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -201,15 +204,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
     Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
@@ -237,6 +232,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -249,15 +245,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -284,6 +272,7 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -296,16 +285,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
-
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
     Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
@@ -400,6 +380,7 @@ void testDispatchIndexQuery() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -412,15 +393,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -448,6 +421,7 @@ void testDispatchWithPPLQuery() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -461,14 +435,7 @@
     DispatchQueryResponse dispatchQueryResponse =
         sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.PPL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+            getBaseDispatchQueryRequestBuilder(query).langType(LangType.PPL).build());
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -496,6 +463,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -508,15 +476,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -548,6 +508,7 @@ void testDispatchIndexQueryWithoutADatasourceName() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -560,15 +521,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -600,6 +553,7 @@ void testDispatchMaterializedViewQuery() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:streaming:flint_mv_1",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -612,15 +566,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -648,6 +594,7 @@ void testDispatchShowMVQuery() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -660,15 +607,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -696,6 +635,7 @@ void testRefreshIndexQuery() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -708,15 +648,7 @@
         .thenReturn(dataSourceMetadata);
 
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -744,6 +676,7 @@ void testDispatchDescribeIndexQuery() {
     StartJobRequest expected =
         new StartJobRequest(
             "TEST_CLUSTER:batch",
+            null,
             EMRS_APPLICATION_ID,
             EMRS_EXECUTION_ROLE,
             sparkSubmitParameters,
@@ -756,15 +689,7 @@
         .thenReturn(dataSourceMetadata);
     DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                EMRS_APPLICATION_ID,
-                query,
-                "my_glue",
-                LangType.SQL,
-                EMRS_EXECUTION_ROLE,
-                TEST_CLUSTER_NAME,
-                sparkSubmitParameterModifier));
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query));
 
     verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
     Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
@@ -781,16 +706,7 @@ void testDispatchWithWrongURI() {
     IllegalArgumentException illegalArgumentException =
         Assertions.assertThrows(
             IllegalArgumentException.class,
-            () ->
-                sparkQueryDispatcher.dispatch(
-                    new DispatchQueryRequest(
-                        EMRS_APPLICATION_ID,
-                        query,
-                        "my_glue",
-                        LangType.SQL,
-                        EMRS_EXECUTION_ROLE,
-                        TEST_CLUSTER_NAME,
-                        sparkSubmitParameterModifier)));
+            () -> sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query)));
 
     Assertions.assertEquals(
         "Bad URI in indexstore configuration of the : my_glue datasoure.",
@@ -808,14 +724,7 @@
             UnsupportedOperationException.class,
             () ->
                 sparkQueryDispatcher.dispatch(
-                    new DispatchQueryRequest(
-                        EMRS_APPLICATION_ID,
-                        query,
-                        "my_prometheus",
-                        LangType.SQL,
-                        EMRS_EXECUTION_ROLE,
-                        TEST_CLUSTER_NAME,
-                        sparkSubmitParameterModifier)));
+                    getBaseDispatchQueryRequestBuilder(query).datasource("my_prometheus").build()));
 
     Assertions.assertEquals(
         "UnSupported datasource type for async queries:: PROMETHEUS",
@@ -1187,29 +1096,33 @@ private DataSourceMetadata constructPrometheusDataSourceType() {
         .build();
   }
 
+  private DispatchQueryRequest getBaseDispatchQueryRequest(String query) {
+    return getBaseDispatchQueryRequestBuilder(query).build();
+  }
+
+  private DispatchQueryRequest.DispatchQueryRequestBuilder getBaseDispatchQueryRequestBuilder(
+      String query) {
+    return DispatchQueryRequest.builder()
+        .applicationId(EMRS_APPLICATION_ID)
+        .query(query)
+        .datasource("my_glue")
+        .langType(LangType.SQL)
+        .executionRoleARN(EMRS_EXECUTION_ROLE)
+        .clusterName(TEST_CLUSTER_NAME)
+        .sparkSubmitParameterModifier(sparkSubmitParameterModifier);
+  }
+
   private DispatchQueryRequest constructDispatchQueryRequest(
       String query, LangType langType, String extraParameters) {
-    return new DispatchQueryRequest(
-        EMRS_APPLICATION_ID,
-        query,
-        "my_glue",
-        langType,
-        EMRS_EXECUTION_ROLE,
-        TEST_CLUSTER_NAME,
-        (parameters) -> parameters.setExtraParameters(extraParameters),
-        null);
+    return getBaseDispatchQueryRequestBuilder(query)
+        .langType(langType)
+        .sparkSubmitParameterModifier(
+            (parameters) -> parameters.setExtraParameters(extraParameters))
+        .build();
   }
 
   private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, String sessionId) {
-    return new DispatchQueryRequest(
-        EMRS_APPLICATION_ID,
-        query,
-        "my_glue",
-        LangType.SQL,
-        EMRS_EXECUTION_ROLE,
-        TEST_CLUSTER_NAME,
-        sparkSubmitParameterModifier,
-        sessionId);
+    return getBaseDispatchQueryRequestBuilder(query).sessionId(sessionId).build();
   }
 
   private AsyncQueryJobMetadata asyncQueryJobMetadata() {
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
index 0c606cc5df..29a3a9cba8 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
@@ -47,7 +47,7 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase {
   @Before
   public void setup() {
     emrsClient = new TestEMRServerlessClient();
-    startJobRequest = new StartJobRequest("", "appId", "", "", new HashMap<>(), false, "");
+    startJobRequest = new StartJobRequest("", null, "appId", "", "", new HashMap<>(), false, "");
     StateStore stateStore = new StateStore(client(), clusterService());
     sessionStorageService =
         new OpenSearchSessionStorageService(stateStore, new SessionModelXContentSerializer());
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java
index 6c1514e6e4..06689a15d0 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java
@@ -16,6 +16,7 @@ public class SessionTestUtil {
   public static CreateSessionRequest createSessionRequest() {
     return new CreateSessionRequest(
         TEST_CLUSTER_NAME,
+        null,
         "appId",
         "arn",
         SparkSubmitParameters.builder().build(),
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java
index f0cce5405c..c43a6f936e 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java
@@ -9,10 +9,8 @@
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 
+import org.json.JSONObject;
 import org.junit.jupiter.api.Test;
-import org.opensearch.common.xcontent.LoggingDeprecationHandler;
-import org.opensearch.common.xcontent.XContentType;
-import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
@@ -29,6 +27,7 @@ void toXContentShouldSerializeAsyncQueryJobMetadata() throws Exception {
     AsyncQueryJobMetadata jobMetadata =
         AsyncQueryJobMetadata.builder()
             .queryId("query1")
+            .accountId("account1")
             .applicationId("app1")
             .jobId("job1")
             .resultIndex("result1")
@@ -45,6 +44,7 @@ void toXContentShouldSerializeAsyncQueryJobMetadata() throws Exception {
     assertEquals(true, json.contains("\"queryId\":\"query1\""));
     assertEquals(true, json.contains("\"type\":\"jobmeta\""));
     assertEquals(true, json.contains("\"jobId\":\"job1\""));
+    assertEquals(true, json.contains("\"accountId\":\"account1\""));
     assertEquals(true, json.contains("\"applicationId\":\"app1\""));
     assertEquals(true, json.contains("\"resultIndex\":\"result1\""));
     assertEquals(true, json.contains("\"sessionId\":\"session1\""));
@@ -55,24 +55,14 @@
 
   @Test
   void fromXContentShouldDeserializeAsyncQueryJobMetadata() throws Exception {
-    XContentParser parser =
-        prepareParserForJson(
-            "{\n"
-                + " \"queryId\": \"query1\",\n"
-                + " \"type\": \"jobmeta\",\n"
-                + " \"jobId\": \"job1\",\n"
-                + " \"applicationId\": \"app1\",\n"
-                + " \"resultIndex\": \"result1\",\n"
-                + " \"sessionId\": \"session1\",\n"
-                + " \"dataSourceName\": \"datasource1\",\n"
-                + " \"jobType\": \"interactive\",\n"
-                + " \"indexName\": \"index1\"\n"
-                + "}");
+    String json = getBaseJson().toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L);
 
     assertEquals("query1", jobMetadata.getQueryId());
     assertEquals("job1", jobMetadata.getJobId());
+    assertEquals("account1", jobMetadata.getAccountId());
     assertEquals("app1", jobMetadata.getApplicationId());
     assertEquals("result1", jobMetadata.getResultIndex());
     assertEquals("session1", jobMetadata.getSessionId());
@@ -82,67 +72,39 @@
   }
 
   @Test
-  void fromXContentShouldThrowExceptionWhenMissingRequiredFields() throws Exception {
-    XContentParser parser =
-        prepareParserForJson(
-            "{\n"
-                + " \"queryId\": \"query1\",\n"
-                + " \"type\": \"asyncqueryjobmeta\",\n"
-                + " \"resultIndex\": \"result1\",\n"
-                + " \"sessionId\": \"session1\",\n"
-                + " \"dataSourceName\": \"datasource1\",\n"
-                + " \"jobType\": \"async_query\",\n"
-                + " \"indexName\": \"index1\"\n"
-                + "}");
+  void fromXContentShouldThrowExceptionWhenMissingJobId() throws Exception {
+    String json = getJsonWithout("jobId").toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L));
   }
 
   @Test
-  void fromXContentShouldDeserializeWithMissingApplicationId() throws Exception {
-    XContentParser parser =
-        prepareParserForJson(
-            "{\n"
-                + " \"queryId\": \"query1\",\n"
-                + " \"type\": \"jobmeta\",\n"
-                + " \"jobId\": \"job1\",\n"
-                + " \"resultIndex\": \"result1\",\n"
-                + " \"sessionId\": \"session1\",\n"
-                + " \"dataSourceName\": \"datasource1\",\n"
-                + " \"jobType\": \"interactive\",\n"
-                + " \"indexName\": \"index1\"\n"
-                + "}");
+  void fromXContentShouldThrowExceptionWhenMissingApplicationId() throws Exception {
+    String json = getJsonWithout("applicationId").toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L));
   }
 
   @Test
   void fromXContentShouldThrowExceptionWhenUnknownFields() throws Exception {
-    XContentParser parser = prepareParserForJson("{\"unknownAttr\": \"index1\"}");
+    String json = getBaseJson().put("unknownAttr", "index1").toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L));
   }
 
   @Test
   void fromXContentShouldDeserializeAsyncQueryWithJobTypeNUll() throws Exception {
-    XContentParser parser =
-        prepareParserForJson(
-            "{\n"
-                + " \"queryId\": \"query1\",\n"
-                + " \"type\": \"jobmeta\",\n"
-                + " \"jobId\": \"job1\",\n"
-                + " \"applicationId\": \"app1\",\n"
-                + " \"resultIndex\": \"result1\",\n"
-                + " \"sessionId\": \"session1\",\n"
-                + " \"dataSourceName\": \"datasource1\",\n"
-                + " \"jobType\": \"\",\n"
-                + " \"indexName\": \"index1\"\n"
-                + "}");
+    String json = getBaseJson().put("jobType", "").toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L);
 
     assertEquals("query1", jobMetadata.getQueryId());
     assertEquals("job1", jobMetadata.getJobId());
+    assertEquals("account1", jobMetadata.getAccountId());
     assertEquals("app1", jobMetadata.getApplicationId());
     assertEquals("result1", jobMetadata.getResultIndex());
     assertEquals("session1", jobMetadata.getSessionId());
@@ -152,26 +114,49 @@ void fromXContentShouldDeserializeAsyncQueryWithJobTypeNUll() throws Exception {
   }
 
   @Test
-  void fromXContentShouldDeserializeAsyncQueryWithoutJobId() throws Exception {
-    XContentParser parser =
-        prepareParserForJson("{\"queryId\": \"query1\", \"applicationId\": \"app1\"}");
+  void fromXContentShouldDeserializeAsyncQueryWithAccountIdNull() throws Exception {
+    String json = getJsonWithout("accountId").put("jobType", "").toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
-    assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L));
+    AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L);
+
+    assertEquals("query1", jobMetadata.getQueryId());
+    assertEquals("job1", jobMetadata.getJobId());
+    assertEquals("app1", jobMetadata.getApplicationId());
+    assertEquals("result1", jobMetadata.getResultIndex());
+    assertEquals("session1", jobMetadata.getSessionId());
+    assertEquals("datasource1", jobMetadata.getDatasourceName());
+    assertNull(jobMetadata.getJobType());
+    assertEquals("index1", jobMetadata.getIndexName());
   }
 
   @Test
-  void fromXContentShouldDeserializeAsyncQueryWithoutApplicationId() throws Exception {
-    XContentParser parser = prepareParserForJson("{\"queryId\": \"query1\", \"jobId\": \"job1\"}");
+  void fromXContentShouldDeserializeAsyncQueryWithoutJobId() throws Exception {
+    String json = getJsonWithout("jobId").toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L));
   }
 
-  private XContentParser prepareParserForJson(String json) throws Exception {
-    XContentParser parser =
-        XContentType.JSON
-            .xContent()
-            .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json);
-    parser.nextToken();
-    return parser;
+  private JSONObject getJsonWithout(String... attrs) {
+    JSONObject result = getBaseJson();
+    for (String attr : attrs) {
+      result.remove(attr);
+    }
+    return result;
+  }
+
+  private JSONObject getBaseJson() {
+    return new JSONObject()
+        .put("queryId", "query1")
+        .put("type", "jobmeta")
+        .put("jobId", "job1")
+        .put("accountId", "account1")
+        .put("applicationId", "app1")
+        .put("resultIndex", "result1")
+        .put("sessionId", "session1")
+        .put("dataSourceName", "datasource1")
+        .put("jobType", "interactive")
+        .put("indexName", "index1");
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/FlintIndexStateModelXContentSerializerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/FlintIndexStateModelXContentSerializerTest.java
index be8875d694..0d6d5f3119 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/FlintIndexStateModelXContentSerializerTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/FlintIndexStateModelXContentSerializerTest.java
@@ -6,15 +6,14 @@
 package org.opensearch.sql.spark.execution.xcontent;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.mockito.Mockito.mock;
 
+import org.json.JSONObject;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.junit.jupiter.MockitoExtension;
-import org.opensearch.common.xcontent.LoggingDeprecationHandler;
-import org.opensearch.common.xcontent.XContentType;
-import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
@@ -32,6 +31,7 @@ void toXContentShouldSerializeFlintIndexStateModel() throws Exception {
     FlintIndexStateModel flintIndexStateModel =
         FlintIndexStateModel.builder()
             .indexState(FlintIndexState.ACTIVE)
+            .accountId("account1")
             .applicationId("app1")
             .jobId("job1")
             .latestId("latest1")
@@ -47,6 +47,7 @@ void toXContentShouldSerializeFlintIndexStateModel() throws Exception {
     assertEquals(true, json.contains("\"version\":\"1.0\""));
     assertEquals(true, json.contains("\"type\":\"flintindexstate\""));
     assertEquals(true, json.contains("\"state\":\"active\""));
+    assertEquals(true, json.contains("\"accountId\":\"account1\""));
     assertEquals(true, json.contains("\"applicationId\":\"app1\""));
     assertEquals(true, json.contains("\"jobId\":\"job1\""));
     assertEquals(true, json.contains("\"latestId\":\"latest1\""));
@@ -55,23 +56,56 @@
 
   @Test
   void fromXContentShouldDeserializeFlintIndexStateModel() throws Exception {
-    String json =
-        "{\"version\":\"1.0\",\"type\":\"flintindexstate\",\"state\":\"active\",\"applicationId\":\"app1\",\"jobId\":\"job1\",\"latestId\":\"latest1\",\"dataSourceName\":\"datasource1\",\"lastUpdateTime\":1623456789,\"error\":\"\"}";
-    XContentParser parser =
-        XContentType.JSON
-            .xContent()
-            .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json);
-    parser.nextToken();
+    String json = getBaseJson().toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     FlintIndexStateModel flintIndexStateModel = serializer.fromXContent(parser, 1L, 1L);
 
     assertEquals(FlintIndexState.ACTIVE, flintIndexStateModel.getIndexState());
+    assertEquals("account1", flintIndexStateModel.getAccountId());
assertEquals("app1", flintIndexStateModel.getApplicationId()); assertEquals("job1", flintIndexStateModel.getJobId()); assertEquals("latest1", flintIndexStateModel.getLatestId()); assertEquals("datasource1", flintIndexStateModel.getDatasourceName()); } + @Test + void fromXContentShouldDeserializeFlintIndexStateModelWithoutAccountId() throws Exception { + String json = getJsonWithout("accountId").toString(); + XContentParser parser = XContentSerializerTestUtil.prepareParser(json); + + FlintIndexStateModel flintIndexStateModel = serializer.fromXContent(parser, 1L, 1L); + + assertEquals(FlintIndexState.ACTIVE, flintIndexStateModel.getIndexState()); + assertNull(flintIndexStateModel.getAccountId()); + assertEquals("app1", flintIndexStateModel.getApplicationId()); + assertEquals("job1", flintIndexStateModel.getJobId()); + assertEquals("latest1", flintIndexStateModel.getLatestId()); + assertEquals("datasource1", flintIndexStateModel.getDatasourceName()); + } + + private JSONObject getJsonWithout(String attr) { + JSONObject result = getBaseJson(); + result.remove(attr); + return result; + } + + private JSONObject getBaseJson() { + return new JSONObject() + .put("version", "1.0") + .put("type", "flintindexstate") + .put("state", "active") + .put("statementId", "statement1") + .put("sessionId", "session1") + .put("accountId", "account1") + .put("applicationId", "app1") + .put("jobId", "job1") + .put("latestId", "latest1") + .put("dataSourceName", "datasource1") + .put("lastUpdateTime", 1623456789) + .put("error", ""); + } + @Test void fromXContentThrowsExceptionWhenParsingInvalidContent() { XContentParser parser = mock(XContentParser.class); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/SessionModelXContentSerializerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/SessionModelXContentSerializerTest.java index a5e8696465..36c019485f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/SessionModelXContentSerializerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/SessionModelXContentSerializerTest.java @@ -5,14 +5,13 @@ package org.opensearch.sql.spark.execution.xcontent; +import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; +import org.json.JSONObject; import org.junit.jupiter.api.Test; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -34,6 +33,7 @@ void toXContentShouldSerializeSessionModel() throws Exception { .sessionId(new SessionId("session1")) .sessionState(SessionState.FAIL) .datasourceName("datasource1") + .accountId("account1") .applicationId("app1") .jobId("job1") .lastUpdateTime(System.currentTimeMillis()) @@ -49,30 +49,15 @@ void toXContentShouldSerializeSessionModel() throws Exception { assertEquals(true, json.contains("\"sessionId\":\"session1\"")); assertEquals(true, json.contains("\"state\":\"fail\"")); assertEquals(true, json.contains("\"dataSourceName\":\"datasource1\"")); + assertEquals(true, json.contains("\"accountId\":\"account1\"")); assertEquals(true, json.contains("\"applicationId\":\"app1\"")); assertEquals(true, 
     assertEquals(true, json.contains("\"jobId\":\"job1\""));
   }
 
   @Test
   void fromXContentShouldDeserializeSessionModel() throws Exception {
-    String json =
-        "{\n"
-            + " \"version\": \"1.0\",\n"
-            + " \"type\": \"session\",\n"
-            + " \"sessionType\": \"interactive\",\n"
-            + " \"sessionId\": \"session1\",\n"
-            + " \"state\": \"fail\",\n"
-            + " \"dataSourceName\": \"datasource1\",\n"
-            + " \"applicationId\": \"app1\",\n"
-            + " \"jobId\": \"job1\",\n"
-            + " \"lastUpdateTime\": 1623456789,\n"
-            + " \"error\": \"\"\n"
-            + "}";
-    XContentParser parser =
-        XContentType.JSON
-            .xContent()
-            .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json);
-    parser.nextToken();
+    String json = getBaseJson().toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     SessionModel sessionModel = serializer.fromXContent(parser, 1L, 1L);
 
@@ -81,10 +66,49 @@ void fromXContentShouldDeserializeSessionModel() throws Exception {
     assertEquals("session1", sessionModel.getSessionId().getSessionId());
     assertEquals(SessionState.FAIL, sessionModel.getSessionState());
     assertEquals("datasource1", sessionModel.getDatasourceName());
+    assertEquals("account1", sessionModel.getAccountId());
     assertEquals("app1", sessionModel.getApplicationId());
     assertEquals("job1", sessionModel.getJobId());
   }
 
+  @Test
+  void fromXContentShouldDeserializeSessionModelWithoutAccountId() throws Exception {
+    String json = getJsonWithout("accountId").toString();
+    XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
+
+    SessionModel sessionModel = serializer.fromXContent(parser, 1L, 1L);
+
+    assertEquals("1.0", sessionModel.getVersion());
+    assertEquals(SessionType.INTERACTIVE, sessionModel.getSessionType());
+    assertEquals("session1", sessionModel.getSessionId().getSessionId());
+    assertEquals(SessionState.FAIL, sessionModel.getSessionState());
+    assertEquals("datasource1", sessionModel.getDatasourceName());
+    assertNull(sessionModel.getAccountId());
+    assertEquals("app1", sessionModel.getApplicationId());
+    assertEquals("job1", sessionModel.getJobId());
+  }
+
+  private JSONObject getJsonWithout(String attr) {
+    JSONObject result = getBaseJson();
+    result.remove(attr);
+    return result;
+  }
+
+  private JSONObject getBaseJson() {
+    return new JSONObject()
+        .put("version", "1.0")
+        .put("type", "session")
+        .put("sessionType", "interactive")
+        .put("sessionId", "session1")
+        .put("state", "fail")
+        .put("dataSourceName", "datasource1")
+        .put("accountId", "account1")
+        .put("applicationId", "app1")
+        .put("jobId", "job1")
+        .put("lastUpdateTime", 1623456789)
+        .put("error", "");
+  }
+
   @Test
   void fromXContentThrowsExceptionWhenParsingInvalidContent() {
     XContentParser parser = mock(XContentParser.class);
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/StatementModelXContentSerializerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/StatementModelXContentSerializerTest.java
index 40e5873ce2..cdca39d051 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/StatementModelXContentSerializerTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/StatementModelXContentSerializerTest.java
@@ -6,15 +6,14 @@
 package org.opensearch.sql.spark.execution.xcontent;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.mockito.Mockito.mock;
 
+import org.json.JSONObject;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.junit.jupiter.MockitoExtension;
-import org.opensearch.common.xcontent.LoggingDeprecationHandler;
-import org.opensearch.common.xcontent.XContentType;
-import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
@@ -38,6 +37,7 @@ void toXContentShouldSerializeStatementModel() throws Exception {
             .statementState(StatementState.RUNNING)
             .statementId(new StatementId("statement1"))
             .sessionId(new SessionId("session1"))
+            .accountId("account1")
             .applicationId("app1")
             .jobId("job1")
             .langType(LangType.SQL)
@@ -55,19 +55,16 @@ void toXContentShouldSerializeStatementModel() throws Exception {
     assertEquals(true, json.contains("\"version\":\"1.0\""));
     assertEquals(true, json.contains("\"state\":\"running\""));
     assertEquals(true, json.contains("\"statementId\":\"statement1\""));
+    assertEquals(true, json.contains("\"accountId\":\"account1\""));
+    assertEquals(true, json.contains("\"applicationId\":\"app1\""));
+    assertEquals(true, json.contains("\"jobId\":\"job1\""));
   }
 
   @Test
   void fromXContentShouldDeserializeStatementModel() throws Exception {
     StatementModelXContentSerializer serializer = new StatementModelXContentSerializer();
-    String json =
-        "{\"version\":\"1.0\",\"type\":\"statement\",\"state\":\"running\",\"statementId\":\"statement1\",\"sessionId\":\"session1\",\"applicationId\":\"app1\",\"jobId\":\"job1\",\"lang\":\"SQL\",\"dataSourceName\":\"datasource1\",\"query\":\"SELECT"
-            + " * FROM table\",\"queryId\":\"query1\",\"submitTime\":1623456789,\"error\":\"\"}";
-    XContentParser parser =
-        XContentType.JSON
-            .xContent()
-            .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json);
-    parser.nextToken();
+    String json = getBaseJson().toString();
+    final XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     StatementModel statementModel = serializer.fromXContent(parser, 1L, 1L);
 
@@ -75,21 +72,22 @@ void fromXContentShouldDeserializeStatementModel() throws Exception {
     assertEquals(StatementState.RUNNING, statementModel.getStatementState());
     assertEquals("statement1", statementModel.getStatementId().getId());
     assertEquals("session1", statementModel.getSessionId().getSessionId());
+    assertEquals("account1", statementModel.getAccountId());
   }
 
   @Test
-  void fromXContentShouldDeserializeStatementModelThrowException() throws Exception {
+  void fromXContentShouldDeserializeStatementModelWithoutAccountId() throws Exception {
     StatementModelXContentSerializer serializer = new StatementModelXContentSerializer();
-    String json =
-        "{\"version\":\"1.0\",\"type\":\"statement_state\",\"state\":\"running\",\"statementId\":\"statement1\",\"sessionId\":\"session1\",\"applicationId\":\"app1\",\"jobId\":\"job1\",\"lang\":\"SQL\",\"dataSourceName\":\"datasource1\",\"query\":\"SELECT"
-            + " * FROM table\",\"queryId\":\"query1\",\"submitTime\":1623456789,\"error\":null}";
-    XContentParser parser =
-        XContentType.JSON
-            .xContent()
-            .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json);
-    parser.nextToken();
-
-    assertThrows(IllegalStateException.class, () -> serializer.fromXContent(parser, 1L, 1L));
+    String json = getJsonWithout("accountId").toString();
+    final XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
+
+    StatementModel statementModel = serializer.fromXContent(parser, 1L, 1L);
+
+    assertEquals("1.0", statementModel.getVersion());
+    assertEquals(StatementState.RUNNING, statementModel.getStatementState());
+    assertEquals("statement1", statementModel.getStatementId().getId());
+    assertEquals("session1", statementModel.getSessionId().getSessionId());
+    assertNull(statementModel.getAccountId());
   }
 
   @Test
@@ -102,21 +100,35 @@ void fromXContentThrowsExceptionWhenParsingInvalidContent() {
   @Test
   void fromXContentShouldThrowExceptionForUnexpectedField() throws Exception {
     StatementModelXContentSerializer serializer = new StatementModelXContentSerializer();
-    String jsonWithUnexpectedField =
-        "{\"version\":\"1.0\",\"type\":\"statement\",\"state\":\"running\",\"statementId\":\"statement1\",\"sessionId\":\"session1\",\"applicationId\":\"app1\",\"jobId\":\"job1\",\"lang\":\"SQL\",\"dataSourceName\":\"datasource1\",\"query\":\"SELECT"
-            + " * FROM"
-            + " table\",\"queryId\":\"query1\",\"submitTime\":1623456789,\"error\":\"\",\"unexpectedField\":\"someValue\"}";
-    XContentParser parser =
-        XContentType.JSON
-            .xContent()
-            .createParser(
-                NamedXContentRegistry.EMPTY,
-                LoggingDeprecationHandler.INSTANCE,
-                jsonWithUnexpectedField);
-    parser.nextToken();
+    String json = getBaseJson().put("unexpectedField", "someValue").toString();
+    final XContentParser parser = XContentSerializerTestUtil.prepareParser(json);
 
     IllegalArgumentException exception =
         assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L));
     assertEquals("Unexpected field: unexpectedField", exception.getMessage());
   }
+
+  private JSONObject getJsonWithout(String attr) {
+    JSONObject result = getBaseJson();
+    result.remove(attr);
+    return result;
+  }
+
+  private JSONObject getBaseJson() {
+    return new JSONObject()
+        .put("version", "1.0")
+        .put("type", "statement")
+        .put("state", "running")
+        .put("statementId", "statement1")
+        .put("sessionId", "session1")
+        .put("accountId", "account1")
+        .put("applicationId", "app1")
+        .put("jobId", "job1")
+        .put("lang", "SQL")
+        .put("dataSourceName", "datasource1")
+        .put("query", "SELECT * FROM table")
+        .put("queryId", "query1")
+        .put("submitTime", 1623456789)
+        .put("error", "");
+  }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerTestUtil.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerTestUtil.java
new file mode 100644
index 0000000000..a9356b6908
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerTestUtil.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.xcontent;
+
+import java.io.IOException;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
+
+public class XContentSerializerTestUtil {
+  public static XContentParser prepareParser(String json) throws IOException {
+    XContentParser parser =
+        XContentType.JSON
+            .xContent()
+            .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json);
+    parser.nextToken();
+    return parser;
+  }
+}
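// Illustrative usage (not part of the diff): new serializer tests can reuse the shared helper
// instead of hand-building an XContentParser, mirroring the updated tests above:
//
//   XContentParser parser = XContentSerializerTestUtil.prepareParser(getBaseJson().toString());
//   SessionModel model = serializer.fromXContent(parser, 1L, 1L);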