From 6aa43d6429e30b9180f176d174dcc2fbdb3d8c22 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 20 Oct 2023 17:09:53 +0000 Subject: [PATCH] Integration with REPL Spark job (#2327) * add InteractiveSession and SessionManager Signed-off-by: Peng Huo * add statement Signed-off-by: Peng Huo * add statement Signed-off-by: Peng Huo * fix format Signed-off-by: Peng Huo * snapshot Signed-off-by: Peng Huo * address comments Signed-off-by: Peng Huo * update Signed-off-by: Peng Huo * Update REST and Transport interface Signed-off-by: Peng Huo * Revert on transport layer Signed-off-by: Peng Huo * format code Signed-off-by: Peng Huo * add API doc Signed-off-by: Peng Huo * modify api Signed-off-by: Peng Huo * create query_execution_request index on demand Signed-off-by: Peng Huo * add REPL spark parameters Signed-off-by: Peng Huo * Add IT Signed-off-by: Peng Huo * format code Signed-off-by: Peng Huo * bind request index to datasource Signed-off-by: Peng Huo * fix bug when fetch query result Signed-off-by: Peng Huo * revert entrypoint class Signed-off-by: Peng Huo * update mapping Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo (cherry picked from commit 7b4156e0ad3b9194cc0bf59f43971e67c3941aae) Signed-off-by: github-actions[bot] --- .../org/opensearch/sql/plugin/SQLPlugin.java | 5 +- spark/build.gradle | 1 + .../model/SparkSubmitParameters.java | 14 +- .../spark/data/constants/SparkConstants.java | 4 + .../dispatcher/SparkQueryDispatcher.java | 67 ++-- .../session/CreateSessionRequest.java | 21 +- .../execution/session/InteractiveSession.java | 15 +- .../spark/execution/session/SessionId.java | 21 +- .../execution/session/SessionManager.java | 5 +- .../spark/execution/session/SessionState.java | 7 +- .../spark/execution/session/SessionType.java | 14 +- .../spark/execution/statement/Statement.java | 20 +- .../execution/statement/StatementModel.java | 10 + .../execution/statement/StatementState.java | 7 +- .../execution/statestore/StateStore.java | 203 +++++++--- .../response/JobExecutionResponseReader.java | 4 + .../query_execution_request_mapping.yml | 40 ++ .../query_execution_request_settings.yml | 11 + ...AsyncQueryExecutorServiceImplSpecTest.java | 374 ++++++++++++++++++ .../dispatcher/SparkQueryDispatcherTest.java | 6 +- .../session/InteractiveSessionTest.java | 55 ++- .../execution/statement/StatementTest.java | 63 +-- 22 files changed, 810 insertions(+), 157 deletions(-) create mode 100644 spark/src/main/resources/query_execution_request_mapping.yml create mode 100644 spark/src/main/resources/query_execution_request_settings.yml create mode 100644 spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index eb6eabf988..f714a8366b 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -7,7 +7,6 @@ import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.emrserverless.AWSEMRServerless; @@ -321,9 +320,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( new 
FlintIndexMetadataReaderImpl(client), client, new SessionManager( - new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client), - emrServerlessClient, - pluginSettings)); + new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/build.gradle b/spark/build.gradle index 15f1e200e0..8f4388495e 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -68,6 +68,7 @@ dependencies { because 'allows tests to run from IDEs that bundle older version of launcher' } testImplementation("org.opensearch.test:framework:${opensearch_version}") + testImplementation project(':opensearch') } test { diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index 0609d8903c..db78abb2a8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -12,6 +12,7 @@ import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; import static org.opensearch.sql.spark.data.constants.SparkConstants.*; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import java.net.URI; import java.net.URISyntaxException; @@ -39,7 +40,7 @@ public class SparkSubmitParameters { public static class Builder { - private final String className; + private String className; private final Map<String, String> config; private String extraParameters; @@ -70,6 +71,11 @@ public static Builder builder() { return new Builder(); } + public Builder className(String className) { + this.className = className; + return this; + } + public Builder dataSource(DataSourceMetadata metadata) { if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); @@ -141,6 +147,12 @@ public Builder extraParameters(String params) { return this; } + public Builder sessionExecution(String sessionId, String datasourceName) { + config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + config.put(FLINT_JOB_SESSION_ID, sessionId); + return this; + } + public SparkSubmitParameters build() { return new SparkSubmitParameters(className, config, extraParameters); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 1b248eb15d..85ce3c4989 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -87,4 +87,8 @@ public class SparkConstants { public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; + + public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; + public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; + public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL"; } diff --git 
a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 8d5ae10e91..2bd1ae67b9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; @@ -96,12 +97,19 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result(); } - // either empty json when the result is not available or data with status - // Fetch from Result Index - JSONObject result = - jobExecutionResponseReader.getResultFromOpensearchIndex( - asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); - + JSONObject result; + if (asyncQueryJobMetadata.getSessionId() == null) { + // either empty json when the result is not available or data with status + // Fetch from Result Index + result = + jobExecutionResponseReader.getResultFromOpensearchIndex( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + } else { + // when session is enabled, jobId in asyncQueryJobMetadata is actually queryId. + result = + jobExecutionResponseReader.getResultWithQueryId( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + } // if result index document has a status, we are gonna use the status directly; otherwise, we // will use emr-s job status. // That a job is successful does not mean there is no error in execution. 
For example, even if @@ -230,22 +238,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - StartJobRequest startJobRequest = - new StartJobRequest( - dispatchQueryRequest.getQuery(), - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) - .build() - .toString(), - tags, - false, - dataSourceMetadata.getResultIndex()); + if (sessionManager.isEnabled()) { Session session; if (dispatchQueryRequest.getSessionId() != null) { @@ -260,7 +253,19 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ // create session if not exist session = sessionManager.createSession( - new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName())); + new CreateSessionRequest( + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .className(FLINT_SESSION_CLASS_NAME) + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()), + tags, + dataSourceMetadata.getResultIndex(), + dataSourceMetadata.getName())); } StatementId statementId = session.submit( @@ -272,6 +277,22 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceMetadata.getResultIndex(), session.getSessionId().getSessionId()); } else { + StartJobRequest startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) + .build() + .toString(), + tags, + false, + dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index 17e3346248..ca2b2b4867 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -5,11 +5,30 @@ package org.opensearch.sql.spark.execution.session; +import java.util.Map; import lombok.Data; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.StartJobRequest; @Data public class CreateSessionRequest { - private final StartJobRequest startJobRequest; + private final String jobName; + private final String applicationId; + private final String executionRoleArn; + private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder; + private final Map<String, String> tags; + private final String resultIndex; private final String 
datasourceName; + + public StartJobRequest getStartJobRequest() { + return new StartJobRequest( + "select 1", + jobName, + applicationId, + executionRoleArn, + sparkSubmitParametersBuilder.build().toString(), + tags, + false, + resultIndex); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index e33ef4245a..4428c3b83d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -42,13 +42,17 @@ public class InteractiveSession implements Session { @Override public void open(CreateSessionRequest createSessionRequest) { try { + // append session id. + createSessionRequest + .getSparkSubmitParametersBuilder() + .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName()); String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - createSession(stateStore).apply(sessionModel); + createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -59,7 +63,8 @@ public void open(CreateSessionRequest createSessionRequest) { /** todo. StatementSweeper will delete doc. */ @Override public void close() { - Optional<SessionModel> model = getSession(stateStore).apply(sessionModel.getId()); + Optional<SessionModel> model = + getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { @@ -69,7 +74,8 @@ public void close() { /** Submit statement. If submit successfully, Statement in waiting state. */ public StatementId submit(QueryRequest request) { - Optional<SessionModel> model = getSession(stateStore).apply(sessionModel.getId()); + Optional<SessionModel> model = + getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. 
" + sessionModel.getSessionId()); } else { @@ -84,6 +90,7 @@ public StatementId submit(QueryRequest request) { .stateStore(stateStore) .statementId(statementId) .langType(LangType.SQL) + .datasourceName(sessionModel.getDatasourceName()) .query(request.getQuery()) .queryId(statementId.getId()) .build(); @@ -103,7 +110,7 @@ public StatementId submit(QueryRequest request) { @Override public Optional get(StatementId stID) { - return StateStore.getStatement(stateStore) + return StateStore.getStatement(stateStore, sessionModel.getDatasourceName()) .apply(stID.getId()) .map( model -> diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java index 861d906b9b..b3bd716925 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -5,15 +5,32 @@ package org.opensearch.sql.spark.execution.session; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import lombok.Data; import org.apache.commons.lang3.RandomStringUtils; @Data public class SessionId { + public static final int PREFIX_LEN = 10; + private final String sessionId; - public static SessionId newSessionId() { - return new SessionId(RandomStringUtils.randomAlphanumeric(16)); + public static SessionId newSessionId(String datasourceName) { + return new SessionId(encode(datasourceName)); + } + + public String getDataSourceName() { + return decode(sessionId); + } + + private static String decode(String sessionId) { + return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN); + } + + private static String encode(String datasourceName) { + String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName; + return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index c34be7015f..c0f7bbcde8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -28,7 +28,7 @@ public class SessionManager { public Session createSession(CreateSessionRequest request) { InteractiveSession session = InteractiveSession.builder() - .sessionId(newSessionId()) + .sessionId(newSessionId(request.getDatasourceName())) .stateStore(stateStore) .serverlessClient(emrServerlessClient) .build(); @@ -37,7 +37,8 @@ public Session createSession(CreateSessionRequest request) { } public Optional getSession(SessionId sid) { - Optional model = StateStore.getSession(stateStore).apply(sid.getSessionId()); + Optional model = + StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId()); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java index a4da957f12..bd5d14c603 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; 
import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -32,8 +33,10 @@ public enum SessionState { .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); public static SessionState fromString(String key) { - if (STATES.containsKey(key)) { - return STATES.get(key); + for (SessionState ss : SessionState.values()) { + if (ss.getSessionState().toLowerCase(Locale.ROOT).equals(key)) { + return ss; + } } throw new IllegalArgumentException("Invalid session state: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java index dd179a1dc5..10b9ce7bd5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java @@ -5,9 +5,7 @@ package org.opensearch.sql.spark.execution.session; -import java.util.Arrays; -import java.util.Map; -import java.util.stream.Collectors; +import java.util.Locale; import lombok.Getter; @Getter @@ -20,13 +18,11 @@ public enum SessionType { this.sessionType = sessionType; } - private static Map<String, SessionType> TYPES = - Arrays.stream(SessionType.values()) - .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + public static SessionType fromString(String key) { - if (TYPES.containsKey(key)) { - return TYPES.get(key); + for (SessionType sType : SessionType.values()) { + if (sType.getSessionType().toLowerCase(Locale.ROOT).equals(key)) { + return sType; + } } throw new IllegalArgumentException("Invalid session type: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index 8fcedb5fca..d84c91bdb8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -32,6 +32,7 @@ public class Statement { private final String jobId; private final StatementId statementId; private final LangType langType; + private final String datasourceName; private final String query; private final String queryId; private final StateStore stateStore; @@ -42,8 +43,16 @@ public class Statement { public void open() { try { statementModel = - submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId); - statementModel = createStatement(stateStore).apply(statementModel); + submitStatement( + sessionId, + applicationId, + jobId, + statementId, + langType, + datasourceName, + query, + queryId); + statementModel = createStatement(stateStore, datasourceName).apply(statementModel); } catch (VersionConflictEngineException e) { String errorMsg = "statement already exist. " + statementId; LOG.error(errorMsg); @@ -61,7 +70,8 @@ public void cancel() { } try { this.statementModel = - updateStatementState(stateStore).apply(this.statementModel, StatementState.CANCELLED); + updateStatementState(stateStore, statementModel.getDatasourceName()) + .apply(this.statementModel, StatementState.CANCELLED); } catch (DocumentMissingException e) { String errorMsg = String.format("cancel statement failed. no statement found. 
statement: %s.", statementId); @@ -69,7 +79,9 @@ public void cancel() { throw new IllegalStateException(errorMsg); } catch (VersionConflictEngineException e) { this.statementModel = - getStatement(stateStore).apply(statementModel.getId()).orElse(this.statementModel); + getStatement(stateStore, statementModel.getDatasourceName()) + .apply(statementModel.getId()) + .orElse(this.statementModel); String errorMsg = String.format( "cancel statement failed. current statementState: %s " + "statement: %s.", diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java index c7f681c541..2a1043bf73 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.statement; import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID; +import static org.opensearch.sql.spark.execution.session.SessionModel.DATASOURCE_NAME; import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; @@ -45,6 +46,7 @@ public class StatementModel extends StateModel { private final String applicationId; private final String jobId; private final LangType langType; + private final String datasourceName; private final String query; private final String queryId; private final long submitTime; @@ -65,6 +67,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(APPLICATION_ID, applicationId) .field(JOB_ID, jobId) .field(LANG, langType.getText()) + .field(DATASOURCE_NAME, datasourceName) .field(QUERY, query) .field(QUERY_ID, queryId) .field(SUBMIT_TIME, submitTime) @@ -82,6 +85,7 @@ public static StatementModel copy(StatementModel copy, long seqNo, long primaryT .applicationId(copy.applicationId) .jobId(copy.jobId) .langType(copy.langType) + .datasourceName(copy.datasourceName) .query(copy.query) .queryId(copy.queryId) .submitTime(copy.submitTime) @@ -101,6 +105,7 @@ public static StatementModel copyWithState( .applicationId(copy.applicationId) .jobId(copy.jobId) .langType(copy.langType) + .datasourceName(copy.datasourceName) .query(copy.query) .queryId(copy.queryId) .submitTime(copy.submitTime) @@ -143,6 +148,9 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon case LANG: builder.langType(LangType.fromString(parser.text())); break; + case DATASOURCE_NAME: + builder.datasourceName(parser.text()); + break; case QUERY: builder.query(parser.text()); break; @@ -168,6 +176,7 @@ public static StatementModel submitStatement( String jobId, StatementId statementId, LangType langType, + String datasourceName, String query, String queryId) { return builder() @@ -178,6 +187,7 @@ public static StatementModel submitStatement( .applicationId(applicationId) .jobId(jobId) .langType(langType) + .datasourceName(datasourceName) .query(query) .queryId(queryId) .submitTime(System.currentTimeMillis()) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java index 33f7f5e831..48978ff8f9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java +++ 
b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.statement; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -30,8 +31,10 @@ public enum StatementState { .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); public static StatementState fromString(String key) { - if (STATES.containsKey(key)) { - return STATES.get(key); + for (StatementState ss : StatementState.values()) { + if (ss.getState().toLowerCase(Locale.ROOT).equals(key)) { + return ss; + } } throw new IllegalArgumentException("Invalid statement state: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index bd72b17353..a36ee3ef45 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -5,15 +5,22 @@ package org.opensearch.sql.spark.execution.statestore; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; + import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.Locale; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Function; import lombok.RequiredArgsConstructor; +import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; @@ -22,6 +29,9 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; @@ -33,15 +43,29 @@ import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; +/** + * State Store maintains the state of Session and Statement. State Store creates/updates/gets docs + * on the index regardless of user FGAC permissions. 
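+ * <p>Example: for a datasource named mys3, DATASOURCE_TO_REQUEST_INDEX below yields the index name SPARK_REQUEST_BUFFER_INDEX_NAME + "_mys3", i.e. one request index per datasource (illustrative note; the buffer-index constant's value is defined in SparkConstants and is not shown in this patch).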
+ */ @RequiredArgsConstructor public class StateStore { + public static String SETTINGS_FILE_NAME = "query_execution_request_settings.yml"; + public static String MAPPING_FILE_NAME = "query_execution_request_mapping.yml"; + public static Function<String, String> DATASOURCE_TO_REQUEST_INDEX = + datasourceName -> String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); + public static String ALL_REQUEST_INDEX = String.format("%s_*", SPARK_REQUEST_BUFFER_INDEX_NAME); + private static final Logger LOG = LogManager.getLogger(); - private final String indexName; private final Client client; + private final ClusterService clusterService; - protected <T extends StateModel> T create(T st, StateModel.CopyBuilder<T> builder) { + protected <T extends StateModel> T create( + T st, StateModel.CopyBuilder<T> builder, String indexName) { try { + if (!this.clusterService.state().routingTable().hasIndex(indexName)) { + createIndex(indexName); + } IndexRequest indexRequest = new IndexRequest(indexName) .id(st.getId()) @@ -50,48 +74,60 @@ protected <T extends StateModel> T create(T st, StateModel.CopyBuilde .setIfPrimaryTerm(st.getPrimaryTerm()) .create(true) .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client.index(indexRequest).actionGet(); - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("Successfully created doc. id: {}", st.getId()); - return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed create doc. id: %s, error: %s", - st.getId(), - indexResponse.getResult().getLowercase())); + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", st.getId()); + return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. 
id: %s, error: %s", + st.getId(), + indexResponse.getResult().getLowercase())); + } } } catch (IOException e) { throw new RuntimeException(e); } } - protected <T extends StateModel> Optional<T> get(String sid, StateModel.FromXContent<T> builder) { + protected <T extends StateModel> Optional<T> get( + String sid, StateModel.FromXContent<T> builder, String indexName) { try { - GetRequest getRequest = new GetRequest().index(indexName).id(sid); - GetResponse getResponse = client.get(getRequest).actionGet(); - if (getResponse.isExists()) { - XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - getResponse.getSourceAsString()); - parser.nextToken(); - return Optional.of( - builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); - } else { + if (!this.clusterService.state().routingTable().hasIndex(indexName)) { + createIndex(indexName); return Optional.empty(); } + GetRequest getRequest = new GetRequest().index(indexName).id(sid).refresh(true); + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } } catch (IOException e) { throw new RuntimeException(e); } } protected <T extends StateModel, S> T updateState( - T st, S state, StateModel.StateCopyBuilder<T, S> builder) { + T st, S state, StateModel.StateCopyBuilder<T, S> builder, String indexName) { try { T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm()); UpdateRequest updateRequest = @@ -103,47 +139,110 @@ protected <T extends StateModel, S> T updateState( .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) .fetchSource(true) .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - UpdateResponse updateResponse = client.update(updateRequest).actionGet(); - if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { - LOG.debug("Successfully update doc. id: {}", st.getId()); - return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed update doc. id: %s, error: %s", - st.getId(), - updateResponse.getResult().getLowercase())); + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + UpdateResponse updateResponse = client.update(updateRequest).actionGet(); + if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { + LOG.debug("Successfully update doc. id: {}", st.getId()); + return builder.of( + model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed update doc. 
id: %s, error: %s", + st.getId(), + updateResponse.getResult().getLowercase())); + } } } catch (IOException e) { throw new RuntimeException(e); } } + private void createIndex(String indexName) { + try { + CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); + createIndexRequest + .mapping(loadConfigFromResource(MAPPING_FILE_NAME), XContentType.YAML) + .settings(loadConfigFromResource(SETTINGS_FILE_NAME), XContentType.YAML); + ActionFuture<CreateIndexResponse> createIndexResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); + } + CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); + if (createIndexResponse.isAcknowledged()) { + LOG.info("Index: {} creation Acknowledged", indexName); + } else { + throw new RuntimeException("Index creation is not acknowledged."); + } + } catch (Throwable e) { + throw new RuntimeException( + "Internal server error while creating " + indexName + " index: " + e.getMessage()); + } + } + + private String loadConfigFromResource(String fileName) throws IOException { + InputStream fileStream = StateStore.class.getClassLoader().getResourceAsStream(fileName); + return IOUtils.toString(fileStream, StandardCharsets.UTF_8); + } + /** Helper Functions */ - public static Function<StatementModel, StatementModel> createStatement(StateStore stateStore) { - return (st) -> stateStore.create(st, StatementModel::copy); + public static Function<StatementModel, StatementModel> createStatement( + StateStore stateStore, String datasourceName) { + return (st) -> + stateStore.create( + st, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function<String, Optional<StatementModel>> getStatement(StateStore stateStore) { - return (docId) -> stateStore.get(docId, StatementModel::fromXContent); + public static Function<String, Optional<StatementModel>> getStatement( + StateStore stateStore, String datasourceName) { + return (docId) -> + stateStore.get( + docId, StatementModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } public static BiFunction<StatementModel, StatementState, StatementModel> updateStatementState( - StateStore stateStore) { - return (old, state) -> stateStore.updateState(old, state, StatementModel::copyWithState); + StateStore stateStore, String datasourceName) { + return (old, state) -> + stateStore.updateState( + old, + state, + StatementModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + public static Function<SessionModel, SessionModel> createSession( + StateStore stateStore, String datasourceName) { + return (session) -> + stateStore.create( + session, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function<SessionModel, SessionModel> createSession(StateStore stateStore) { - return (session) -> stateStore.create(session, SessionModel::of); + public static Function<String, Optional<SessionModel>> getSession( + StateStore stateStore, String datasourceName) { + return (docId) -> + stateStore.get( + docId, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function<String, Optional<SessionModel>> getSession(StateStore stateStore) { - return (docId) -> stateStore.get(docId, SessionModel::fromXContent); + public static Function<String, Optional<SessionModel>> searchSession(StateStore stateStore) { + return (docId) -> stateStore.get(docId, SessionModel::fromXContent, ALL_REQUEST_INDEX); } public static BiFunction<SessionModel, SessionState, SessionModel> updateSessionState( - StateStore stateStore) { - return (old, state) -> stateStore.updateState(old, state, SessionModel::copyWithState); + StateStore stateStore, String datasourceName) { + return (old, state) -> 
stateStore.updateState( + old, + state, + SessionModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + public static Runnable createStateStoreIndex(StateStore stateStore, String datasourceName) { + String indexName = String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); + return () -> stateStore.createIndex(indexName); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java index d3cbd68dce..2614992463 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java @@ -39,6 +39,10 @@ public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex) return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultIndex); } + public JSONObject getResultWithQueryId(String queryId, String resultIndex) { + return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultIndex); + } + private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) { SearchRequest searchRequest = new SearchRequest(); String searchResultIndex = resultIndex == null ? SPARK_RESPONSE_BUFFER_INDEX_NAME : resultIndex; diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml new file mode 100644 index 0000000000..87bd927e6e --- /dev/null +++ b/spark/src/main/resources/query_execution_request_mapping.yml @@ -0,0 +1,40 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Schema file for the query_execution_request indices +# Also "dynamic" is set to "false" so that other fields can be added. 
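+# Note: with "dynamic" set to "false", documents may carry fields beyond those mapped below; they are kept in _source but are not indexed for search. +# Illustrative statement doc under this mapping (all values invented): {"type": "statement", "state": "waiting", "statementId": "q-1", "sessionId": "s-1", "lang": "sql", "dataSourceName": "mys3", "query": "select 1", "queryId": "q-1", "submitTime": 1697821793000, "jobId": "jobId"}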
+dynamic: false +properties: + type: + type: keyword + state: + type: keyword + statementId: + type: keyword + applicationId: + type: keyword + sessionId: + type: keyword + sessionType: + type: keyword + error: + type: text + lang: + type: keyword + query: + type: text + dataSourceName: + type: keyword + submitTime: + type: date + format: strict_date_time||epoch_millis + jobId: + type: keyword + lastUpdateTime: + type: date + format: strict_date_time||epoch_millis + queryId: + type: keyword diff --git a/spark/src/main/resources/query_execution_request_settings.yml b/spark/src/main/resources/query_execution_request_settings.yml new file mode 100644 index 0000000000..da2bf07bf1 --- /dev/null +++ b/spark/src/main/resources/query_execution_request_settings.yml @@ -0,0 +1,11 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Settings file for the query_execution_request indices +index: + number_of_shards: "1" + auto_expand_replicas: "0-2" + number_of_replicas: "0" diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java new file mode 100644 index 0000000000..3eb8958eb2 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -0,0 +1,374 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery; + +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_ENABLED_SETTING; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_DOC_TYPE; +import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID; +import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; + +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRun; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import lombok.Getter; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import 
org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.plugins.Plugin; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.datasources.encryptor.EncryptorImpl; +import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; +import org.opensearch.sql.datasources.service.DataSourceMetadataStorage; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.StatementModel; +import org.opensearch.sql.spark.execution.statement.StatementState; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.storage.DataSourceFactory; +import org.opensearch.test.OpenSearchIntegTestCase; + +public class AsyncQueryExecutorServiceImplSpecTest extends OpenSearchIntegTestCase { + public static final String DATASOURCE = "mys3"; + + private ClusterService clusterService; + private org.opensearch.sql.common.setting.Settings pluginSettings; + private NodeClient client; + private DataSourceServiceImpl dataSourceService; + private StateStore stateStore; + private ClusterSettings clusterSettings; + + @Override + protected Collection<Class<? extends Plugin>> nodePlugins() { + return Arrays.asList(TestSettingPlugin.class); + } + + public static class TestSettingPlugin extends Plugin { + @Override + public List<Setting<?>> getSettings() { + return OpenSearchSettings.pluginSettings(); + } + } + + @Before + public void setup() { + clusterService = clusterService(); + clusterSettings = clusterService.getClusterSettings(); + pluginSettings = new OpenSearchSettings(clusterSettings); + client = (NodeClient) cluster().client(); + dataSourceService = createDataSourceService(); + dataSourceService.createDataSource( + new DataSourceMetadata( + DATASOURCE, + DataSourceType.S3GLUE, + ImmutableList.of(), + ImmutableMap.of( + "glue.auth.type", + "iam_role", + "glue.auth.role_arn", + "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", + "glue.indexstore.opensearch.uri", + "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + ".com:9200", + "glue.indexstore.opensearch.auth", + "noauth"), + null)); + stateStore = new StateStore(client, clusterService); + createIndex(SPARK_RESPONSE_BUFFER_INDEX_NAME); + } + + @After + public void clean() { + client + .admin() + .cluster() + .prepareUpdateSettings() + 
.setTransientSettings( + Settings.builder().putNull(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey()).build()) + .get(); + } + + @Test + public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // disable session + enableSession(false); + + // 1. create async query. + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertFalse(clusterService().state().routingTable().hasIndex(SPARK_REQUEST_BUFFER_INDEX_NAME)); + emrsClient.startJobRunCalled(1); + + // 2. fetch async query result. + AsyncQueryExecutionResponse asyncQueryResults = + asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertEquals("RUNNING", asyncQueryResults.getStatus()); + emrsClient.getJobRunResultCalled(1); + + // 3. cancel async query. + String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + assertEquals(response.getQueryId(), cancelQueryId); + emrsClient.cancelJobRunCalled(1); + } + + @Test + public void createAsyncQueryCreateJobWithCorrectParameters() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + enableSession(false); + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + String params = emrsClient.getJobRequest().getSparkSubmitParams(); + assertNull(response.getSessionId()); + assertTrue(params.contains(String.format("--class %s", DEFAULT_CLASS_NAME))); + assertFalse( + params.contains( + String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); + assertFalse( + params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId()))); + + // enable session + enableSession(true); + response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + params = emrsClient.getJobRequest().getSparkSubmitParams(); + assertTrue(params.contains(String.format("--class %s", FLINT_SESSION_CLASS_NAME))); + assertTrue( + params.contains( + String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); + assertTrue( + params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId()))); + } + + @Test + public void withSessionCreateAsyncQueryThenGetResultThenCancel() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(response.getSessionId()); + Optional<StatementModel> statementModel = + getStatement(stateStore, DATASOURCE).apply(response.getQueryId()); + assertTrue(statementModel.isPresent()); + assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); + + // 2. fetch async query result. 
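+ // (Session mode: the queryId above doubles as the statement id. The WAITING status asserted + // next is read from the statement doc in the per-datasource request index, and once the REPL + // job writes results they are fetched via JobExecutionResponseReader.getResultWithQueryId, a + // term query on the result index's "queryId" field instead of the jobId-based lookup used in + // the non-session path.)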
+ AsyncQueryExecutionResponse asyncQueryResults = + asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); + + // 3. cancel async query. + String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + assertEquals(response.getQueryId(), cancelQueryId); + } + + @Test + public void reuseSessionWhenCreateAsyncQuery() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(first.getSessionId()); + + // 2. reuse session id + CreateAsyncQueryResponse second = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, first.getSessionId())); + + assertEquals(first.getSessionId(), second.getSessionId()); + assertNotEquals(first.getQueryId(), second.getQueryId()); + // one session doc. + assertEquals( + 1, + search( + QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("type", SESSION_DOC_TYPE)) + .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); + // two statement docs have the same sessionId. + assertEquals( + 2, + search( + QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE)) + .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); + + Optional<StatementModel> firstModel = + getStatement(stateStore, DATASOURCE).apply(first.getQueryId()); + assertTrue(firstModel.isPresent()); + assertEquals(StatementState.WAITING, firstModel.get().getStatementState()); + assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId()); + assertEquals(first.getQueryId(), firstModel.get().getQueryId()); + Optional<StatementModel> secondModel = + getStatement(stateStore, DATASOURCE).apply(second.getQueryId()); + assertEquals(StatementState.WAITING, secondModel.get().getStatementState()); + assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId()); + assertEquals(second.getQueryId(), secondModel.get().getQueryId()); + } + + private DataSourceServiceImpl createDataSourceService() { + String masterKey = "1234567890"; + DataSourceMetadataStorage dataSourceMetadataStorage = + new OpenSearchDataSourceMetadataStorage( + client, clusterService, new EncryptorImpl(masterKey)); + return new DataSourceServiceImpl( + new ImmutableSet.Builder<DataSourceFactory>() + .add(new GlueDataSourceFactory(pluginSettings)) + .build(), + dataSourceMetadataStorage, + meta -> {}); + } + + private AsyncQueryExecutorService createAsyncQueryExecutorService( + EMRServerlessClient emrServerlessClient) { + AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = + new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, + this.dataSourceService, + new DataSourceUserAuthorizationHelperImpl(client), + jobExecutionResponseReader, + new FlintIndexMetadataReaderImpl(client), + client, + new SessionManager( + new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); + return new AsyncQueryExecutorServiceImpl( 
asyncQueryJobMetadataStorageService, + sparkQueryDispatcher, + this::sparkExecutionEngineConfig); + } + + public static class LocalEMRSClient implements EMRServerlessClient { + + private int startJobRunCalled = 0; + private int cancelJobRunCalled = 0; + private int getJobResult = 0; + + @Getter private StartJobRequest jobRequest; + + @Override + public String startJobRun(StartJobRequest startJobRequest) { + jobRequest = startJobRequest; + startJobRunCalled++; + return "jobId"; + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + getJobResult++; + JobRun jobRun = new JobRun(); + jobRun.setState("RUNNING"); + return new GetJobRunResult().withJobRun(jobRun); + } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + cancelJobRunCalled++; + return new CancelJobRunResult().withJobRunId(jobId); + } + + public void startJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, startJobRunCalled); + } + + public void cancelJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, cancelJobRunCalled); + } + + public void getJobRunResultCalled(int expectedTimes) { + assertEquals(expectedTimes, getJobResult); + } + } + + public SparkExecutionEngineConfig sparkExecutionEngineConfig() { + return new SparkExecutionEngineConfig("appId", "us-west-2", "roleArn", "", "myCluster"); + } + + public void enableSession(boolean enabled) { + client + .admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings( + Settings.builder() + .put(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey(), enabled) + .build()) + .get(); + } + + int search(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(DATASOURCE_TO_REQUEST_INDEX.apply(DATASOURCE)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchRequest.source(searchSourceBuilder); + SearchResponse searchResponse = client.search(searchRequest).actionGet(); + + return searchResponse.getHits().getHits().length; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 58fe626dae..15211dec01 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -726,7 +726,7 @@ void testGetQueryResponseWithSession() { doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); JSONObject result = sparkQueryDispatcher.getQueryResponse( asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); @@ -740,7 +740,7 @@ void testGetQueryResponseWithInvalidSession() { doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID))); doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); IllegalArgumentException exception = Assertions.assertThrows( IllegalArgumentException.class, @@ -759,7 +759,7 @@ void testGetQueryResponseWithStatementNotExist() { doReturn(Optional.empty()).when(session).get(any()); doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + 
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); IllegalArgumentException exception = Assertions.assertThrows( diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 429c970365..06a8d8c73c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -8,10 +8,12 @@ import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.google.common.collect.ImmutableMap; import java.util.HashMap; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -20,15 +22,17 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.test.OpenSearchIntegTestCase; /** mock-maker-inline does not work with OpenSearchTestCase. 
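 * Hence the suites below run against a real OpenSearchIntegTestCase cluster and stub EMR-S with the hand-rolled TestEMRServerlessClient defined further down, rather than relying on Mockito mocks.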
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
index 429c970365..06a8d8c73c 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
@@ -8,10 +8,12 @@
 import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession;
 import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting;
 import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession;
 
 import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
 import com.amazonaws.services.emrserverless.model.GetJobRunResult;
+import com.google.common.collect.ImmutableMap;
 import java.util.HashMap;
 import java.util.Optional;
 import lombok.RequiredArgsConstructor;
@@ -20,15 +22,17 @@
 import org.junit.Test;
 import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
 import org.opensearch.action.delete.DeleteRequest;
+import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.StartJobRequest;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
-import org.opensearch.test.OpenSearchSingleNodeTestCase;
+import org.opensearch.test.OpenSearchIntegTestCase;
 
 /** mock-maker-inline does not work with OpenSearchTestCase. */
-public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase {
+public class InteractiveSessionTest extends OpenSearchIntegTestCase {
 
-  private static final String indexName = "mockindex";
+  private static final String DS_NAME = "mys3";
+  private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME);
 
   private TestEMRServerlessClient emrsClient;
   private StartJobRequest startJobRequest;
@@ -38,20 +42,21 @@ public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase {
   public void setup() {
     emrsClient = new TestEMRServerlessClient();
     startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, "");
-    stateStore = new StateStore(indexName, client());
-    createIndex(indexName);
+    stateStore = new StateStore(client(), clusterService());
   }
 
   @After
   public void clean() {
-    client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet();
+    if (clusterService().state().routingTable().hasIndex(indexName)) {
+      client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet();
+    }
   }
 
   @Test
   public void openCloseSession() {
     InteractiveSession session =
         InteractiveSession.builder()
-            .sessionId(SessionId.newSessionId())
+            .sessionId(SessionId.newSessionId(DS_NAME))
             .stateStore(stateStore)
            .serverlessClient(emrsClient)
            .build();
@@ -59,7 +64,7 @@ public void openCloseSession() {
     // open session
     TestSession testSession = testSession(session, stateStore);
     testSession
-        .open(new CreateSessionRequest(startJobRequest, "datasource"))
+        .open(createSessionRequest())
         .assertSessionState(NOT_STARTED)
         .assertAppId("appId")
         .assertJobId("jobId");
@@ -72,14 +77,14 @@
 
   @Test
   public void openSessionFailedConflict() {
-    SessionId sessionId = new SessionId("duplicate-session-id");
+    SessionId sessionId = SessionId.newSessionId(DS_NAME);
     InteractiveSession session =
         InteractiveSession.builder()
             .sessionId(sessionId)
             .stateStore(stateStore)
             .serverlessClient(emrsClient)
             .build();
-    session.open(new CreateSessionRequest(startJobRequest, "datasource"));
+    session.open(createSessionRequest());
 
     InteractiveSession duplicateSession =
         InteractiveSession.builder()
@@ -89,21 +94,20 @@
             .build();
     IllegalStateException exception =
         assertThrows(
-            IllegalStateException.class,
-            () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource")));
-    assertEquals("session already exist. sessionId=duplicate-session-id", exception.getMessage());
+            IllegalStateException.class, () -> duplicateSession.open(createSessionRequest()));
+    assertEquals("session already exist. " + sessionId, exception.getMessage());
" + sessionId, exception.getMessage()); } @Test public void closeNotExistSession() { - SessionId sessionId = SessionId.newSessionId(); + SessionId sessionId = SessionId.newSessionId(DS_NAME); InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) .build(); - session.open(new CreateSessionRequest(startJobRequest, "datasource")); + session.open(createSessionRequest()); client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); @@ -116,7 +120,7 @@ public void closeNotExistSession() { public void sessionManagerCreateSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); TestSession testSession = testSession(session, stateStore); testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); @@ -126,8 +130,7 @@ public void sessionManagerCreateSession() { public void sessionManagerGetSession() { SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting(false)); - Session session = - sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); + Session session = sessionManager.createSession(createSessionRequest()); Optional managerSession = sessionManager.getSession(session.getSessionId()); assertTrue(managerSession.isPresent()); @@ -139,7 +142,8 @@ public void sessionManagerGetSessionNotExist() { SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting(false)); - Optional managerSession = sessionManager.getSession(new SessionId("no-exist")); + Optional managerSession = + sessionManager.getSession(SessionId.newSessionId("no-exist")); assertTrue(managerSession.isEmpty()); } @@ -156,7 +160,7 @@ public TestSession assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional sessionStoreState = - getSession(stateStore).apply(session.getSessionModel().getId()); + getSession(stateStore, DS_NAME).apply(session.getSessionModel().getId()); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); @@ -184,6 +188,17 @@ public TestSession close() { } } + public static CreateSessionRequest createSessionRequest() { + return new CreateSessionRequest( + "jobName", + "appId", + "arn", + SparkSubmitParameters.Builder.builder(), + ImmutableMap.of(), + "resultIndex", + DS_NAME); + } + public static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 214bcb8258..ff3ddd1bef 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -5,15 +5,16 @@ package org.opensearch.sql.spark.execution.statement; +import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static 
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
index 214bcb8258..ff3ddd1bef 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
@@ -5,15 +5,16 @@
 package org.opensearch.sql.spark.execution.statement;
 
+import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest;
 import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting;
 import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED;
 import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING;
 import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState;
 
-import java.util.HashMap;
 import java.util.Optional;
 import lombok.RequiredArgsConstructor;
 import org.junit.After;
@@ -21,8 +22,6 @@
 import org.junit.Test;
 import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
 import org.opensearch.action.delete.DeleteRequest;
-import org.opensearch.sql.spark.client.StartJobRequest;
-import org.opensearch.sql.spark.execution.session.CreateSessionRequest;
 import org.opensearch.sql.spark.execution.session.InteractiveSessionTest;
 import org.opensearch.sql.spark.execution.session.Session;
 import org.opensearch.sql.spark.execution.session.SessionId;
@@ -30,27 +29,27 @@
 import org.opensearch.sql.spark.execution.session.SessionState;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.rest.model.LangType;
-import org.opensearch.test.OpenSearchSingleNodeTestCase;
+import org.opensearch.test.OpenSearchIntegTestCase;
 
-public class StatementTest extends OpenSearchSingleNodeTestCase {
+public class StatementTest extends OpenSearchIntegTestCase {
 
-  private static final String indexName = "mockindex";
+  private static final String DS_NAME = "mys3";
+  private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME);
 
-  private StartJobRequest startJobRequest;
   private StateStore stateStore;
   private InteractiveSessionTest.TestEMRServerlessClient emrsClient =
       new InteractiveSessionTest.TestEMRServerlessClient();
 
   @Before
   public void setup() {
-    startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, "");
-    stateStore = new StateStore(indexName, client());
-    createIndex(indexName);
+    stateStore = new StateStore(client(), clusterService());
   }
 
   @After
   public void clean() {
-    client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet();
+    if (clusterService().state().routingTable().hasIndex(indexName)) {
+      client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet();
+    }
   }
 
   @Test
@@ -62,6 +61,7 @@ public void openThenCancelStatement() {
         .jobId("jobId")
         .statementId(new StatementId("statementId"))
         .langType(LangType.SQL)
+        .datasourceName(DS_NAME)
         .query("query")
         .queryId("statementId")
         .stateStore(stateStore)
         .build();
@@ -87,6 +87,7 @@ public void openFailedBecauseConflict() {
         .jobId("jobId")
         .statementId(new StatementId("statementId"))
         .langType(LangType.SQL)
+        .datasourceName(DS_NAME)
         .query("query")
         .queryId("statementId")
         .stateStore(stateStore)
@@ -101,6 +102,7 @@
         .jobId("jobId")
         .statementId(new StatementId("statementId"))
         .langType(LangType.SQL)
+        .datasourceName(DS_NAME)
         .query("query")
         .queryId("statementId")
         .stateStore(stateStore)
@@ -119,13 +121,14 @@ public void cancelNotExistStatement() {
         .jobId("jobId")
         .statementId(stId)
         .langType(LangType.SQL)
+        .datasourceName(DS_NAME)
         .query("query")
         .queryId("statementId")
         .stateStore(stateStore)
         .build();
     st.open();
 
-    client().delete(new DeleteRequest(indexName, stId.getId()));
+    client().delete(new DeleteRequest(indexName, stId.getId())).actionGet();
 
     IllegalStateException exception =
        assertThrows(IllegalStateException.class, st::cancel);
     assertEquals(
@@ -143,6 +146,7 @@ public void cancelFailedBecauseOfConflict() {
         .jobId("jobId")
         .statementId(stId)
         .langType(LangType.SQL)
+        .datasourceName(DS_NAME)
         .query("query")
         .queryId("statementId")
         .stateStore(stateStore)
@@ -150,7 +154,7 @@ public void cancelFailedBecauseOfConflict() {
     st.open();
 
     StatementModel running =
-        updateStatementState(stateStore).apply(st.getStatementModel(), CANCELLED);
+        updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), CANCELLED);
 
     assertEquals(StatementState.CANCELLED, running.getStatementState());
 
@@ -172,6 +176,7 @@ public void cancelRunningStatementFailed() {
         .jobId("jobId")
         .statementId(stId)
         .langType(LangType.SQL)
+        .datasourceName(DS_NAME)
         .query("query")
         .queryId("statementId")
         .stateStore(stateStore)
@@ -198,10 +203,10 @@ public void cancelRunningStatementFailed() {
   public void submitStatementInRunningSession() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
-            .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+            .createSession(createSessionRequest());
 
     // App change state to running
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
     StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
     assertFalse(statementId.getId().isEmpty());
@@ -211,7 +216,7 @@ public void submitStatementInNotStartedState() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
-            .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+            .createSession(createSessionRequest());
 
     StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
     assertFalse(statementId.getId().isEmpty());
@@ -221,9 +226,9 @@ public void failToSubmitStatementInDeadState() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
-            .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+            .createSession(createSessionRequest());
 
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD);
 
     IllegalStateException exception =
         assertThrows(
@@ -239,9 +244,9 @@ public void failToSubmitStatementInFailState() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
-            .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+            .createSession(createSessionRequest());
 
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL);
 
     IllegalStateException exception =
         assertThrows(
@@ -257,7 +262,7 @@ public void newStatementFieldAssert() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
-            .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+            .createSession(createSessionRequest());
 
     StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
     Optional<Statement> statement = session.get(statementId);
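The remaining hunks below complete the same pattern: every state transition has to be visible both on the in-memory StatementModel and through the datasource-scoped StateStore read path. A hedged sketch of the persisted-state check the TestStatement helper performs after a cancel (st and stateStore as built in the tests above):

    // After st.cancel(), the stored copy must agree with the in-memory model.
    assertEquals(CANCELLED, st.getStatementModel().getStatementState());
    Optional<StatementModel> persisted =
        getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId());
    assertTrue(persisted.isPresent());
    assertEquals(CANCELLED, persisted.get().getStatementState());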
@@ -275,7 +280,7 @@ public void failToSubmitStatementInDeletedSession() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
-            .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+            .createSession(createSessionRequest());
 
     // other's delete session
     client()
@@ -293,9 +298,9 @@ public void failToSubmitStatementInDeletedSession() {
   public void getStatementSuccess() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
-            .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+            .createSession(createSessionRequest());
     // App change state to running
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
     StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
     Optional<Statement> statement = session.get(statementId);
@@ -308,9 +313,9 @@ public void getStatementNotExist() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
-            .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+            .createSession(createSessionRequest());
     // App change state to running
-    updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING);
+    updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
     Optional<Statement> statement = session.get(StatementId.newStatementId());
     assertFalse(statement.isPresent());
@@ -328,7 +333,8 @@ public static TestStatement testStatement(Statement st, StateStore stateStore) {
     public TestStatement assertSessionState(StatementState expected) {
       assertEquals(expected, st.getStatementModel().getStatementState());
 
-      Optional<StatementModel> model = getStatement(stateStore).apply(st.getStatementId().getId());
+      Optional<StatementModel> model =
+          getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId());
       assertTrue(model.isPresent());
       assertEquals(expected, model.get().getStatementState());
 
@@ -338,7 +344,8 @@ public TestStatement assertStatementId(StatementId expected) {
       assertEquals(expected, st.getStatementModel().getStatementId());
 
-      Optional<StatementModel> model = getStatement(stateStore).apply(st.getStatementId().getId());
+      Optional<StatementModel> model =
+          getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId());
       assertTrue(model.isPresent());
       assertEquals(expected, model.get().getStatementId());
       return this;